Initial commit

This commit is contained in:
René Mathieu
2026-01-17 13:49:51 +01:00
commit 0fef8d96c5
1897 changed files with 396119 additions and 0 deletions

View File

@@ -0,0 +1 @@
5276d46021e0e0d7577e0c9155800cbf62932d60a50783fec42aefb63febedec /Users/runner/work/aiohttp/aiohttp/aiohttp/_cparser.pxd

View File

@@ -0,0 +1 @@
d067f01423cddb3c442933b5fcc039b18ab651fcec1bc91c577693aafc25cf78 /Users/runner/work/aiohttp/aiohttp/aiohttp/_find_header.pxd

View File

@@ -0,0 +1 @@
f9823c608638b8a77fec1c2bd28dc5c043f08c775ec59e7e262dd74b4ebb7384 /Users/runner/work/aiohttp/aiohttp/aiohttp/_http_parser.pyx

View File

@@ -0,0 +1 @@
56514404ce87a15bfc6b400026d73a270165b2fdbe70313cfa007de29ddd7e14 /Users/runner/work/aiohttp/aiohttp/aiohttp/_http_writer.pyx

View File

@@ -0,0 +1 @@
dab8f933203eeb245d60f856e542a45b888d5a110094620e4811f90f816628d1 /Users/runner/work/aiohttp/aiohttp/aiohttp/hdrs.py

View File

@@ -0,0 +1,278 @@
__version__ = "3.13.3"
from typing import TYPE_CHECKING, Tuple
from . import hdrs as hdrs
from .client import (
BaseConnector,
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorDNSError,
ClientConnectorError,
ClientConnectorSSLError,
ClientError,
ClientHttpProxyError,
ClientOSError,
ClientPayloadError,
ClientProxyConnectionError,
ClientRequest,
ClientResponse,
ClientResponseError,
ClientSession,
ClientSSLError,
ClientTimeout,
ClientWebSocketResponse,
ClientWSTimeout,
ConnectionTimeoutError,
ContentTypeError,
Fingerprint,
InvalidURL,
InvalidUrlClientError,
InvalidUrlRedirectClientError,
NamedPipeConnector,
NonHttpUrlClientError,
NonHttpUrlRedirectClientError,
RedirectClientError,
RequestInfo,
ServerConnectionError,
ServerDisconnectedError,
ServerFingerprintMismatch,
ServerTimeoutError,
SocketTimeoutError,
TCPConnector,
TooManyRedirects,
UnixConnector,
WSMessageTypeError,
WSServerHandshakeError,
request,
)
from .client_middleware_digest_auth import DigestAuthMiddleware
from .client_middlewares import ClientHandlerType, ClientMiddlewareType
from .compression_utils import set_zlib_backend
from .connector import (
AddrInfoType as AddrInfoType,
SocketFactoryType as SocketFactoryType,
)
from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar
from .formdata import FormData as FormData
from .helpers import BasicAuth, ChainMapProxy, ETag
from .http import (
HttpVersion as HttpVersion,
HttpVersion10 as HttpVersion10,
HttpVersion11 as HttpVersion11,
WebSocketError as WebSocketError,
WSCloseCode as WSCloseCode,
WSMessage as WSMessage,
WSMsgType as WSMsgType,
)
from .multipart import (
BadContentDispositionHeader as BadContentDispositionHeader,
BadContentDispositionParam as BadContentDispositionParam,
BodyPartReader as BodyPartReader,
MultipartReader as MultipartReader,
MultipartWriter as MultipartWriter,
content_disposition_filename as content_disposition_filename,
parse_content_disposition as parse_content_disposition,
)
from .payload import (
PAYLOAD_REGISTRY as PAYLOAD_REGISTRY,
AsyncIterablePayload as AsyncIterablePayload,
BufferedReaderPayload as BufferedReaderPayload,
BytesIOPayload as BytesIOPayload,
BytesPayload as BytesPayload,
IOBasePayload as IOBasePayload,
JsonPayload as JsonPayload,
Payload as Payload,
StringIOPayload as StringIOPayload,
StringPayload as StringPayload,
TextIOPayload as TextIOPayload,
get_payload as get_payload,
payload_type as payload_type,
)
from .payload_streamer import streamer as streamer
from .resolver import (
AsyncResolver as AsyncResolver,
DefaultResolver as DefaultResolver,
ThreadedResolver as ThreadedResolver,
)
from .streams import (
EMPTY_PAYLOAD as EMPTY_PAYLOAD,
DataQueue as DataQueue,
EofStream as EofStream,
FlowControlDataQueue as FlowControlDataQueue,
StreamReader as StreamReader,
)
from .tracing import (
TraceConfig as TraceConfig,
TraceConnectionCreateEndParams as TraceConnectionCreateEndParams,
TraceConnectionCreateStartParams as TraceConnectionCreateStartParams,
TraceConnectionQueuedEndParams as TraceConnectionQueuedEndParams,
TraceConnectionQueuedStartParams as TraceConnectionQueuedStartParams,
TraceConnectionReuseconnParams as TraceConnectionReuseconnParams,
TraceDnsCacheHitParams as TraceDnsCacheHitParams,
TraceDnsCacheMissParams as TraceDnsCacheMissParams,
TraceDnsResolveHostEndParams as TraceDnsResolveHostEndParams,
TraceDnsResolveHostStartParams as TraceDnsResolveHostStartParams,
TraceRequestChunkSentParams as TraceRequestChunkSentParams,
TraceRequestEndParams as TraceRequestEndParams,
TraceRequestExceptionParams as TraceRequestExceptionParams,
TraceRequestHeadersSentParams as TraceRequestHeadersSentParams,
TraceRequestRedirectParams as TraceRequestRedirectParams,
TraceRequestStartParams as TraceRequestStartParams,
TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams,
)
if TYPE_CHECKING:
# At runtime these are lazy-loaded at the bottom of the file.
from .worker import (
GunicornUVLoopWebWorker as GunicornUVLoopWebWorker,
GunicornWebWorker as GunicornWebWorker,
)
__all__: Tuple[str, ...] = (
"hdrs",
# client
"AddrInfoType",
"BaseConnector",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorDNSError",
"ClientConnectorError",
"ClientConnectorSSLError",
"ClientError",
"ClientHttpProxyError",
"ClientOSError",
"ClientPayloadError",
"ClientProxyConnectionError",
"ClientResponse",
"ClientRequest",
"ClientResponseError",
"ClientSSLError",
"ClientSession",
"ClientTimeout",
"ClientWebSocketResponse",
"ClientWSTimeout",
"ConnectionTimeoutError",
"ContentTypeError",
"Fingerprint",
"FlowControlDataQueue",
"InvalidURL",
"InvalidUrlClientError",
"InvalidUrlRedirectClientError",
"NonHttpUrlClientError",
"NonHttpUrlRedirectClientError",
"RedirectClientError",
"RequestInfo",
"ServerConnectionError",
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ServerTimeoutError",
"SocketFactoryType",
"SocketTimeoutError",
"TCPConnector",
"TooManyRedirects",
"UnixConnector",
"NamedPipeConnector",
"WSServerHandshakeError",
"request",
# client_middleware
"ClientMiddlewareType",
"ClientHandlerType",
# cookiejar
"CookieJar",
"DummyCookieJar",
# formdata
"FormData",
# helpers
"BasicAuth",
"ChainMapProxy",
"DigestAuthMiddleware",
"ETag",
"set_zlib_backend",
# http
"HttpVersion",
"HttpVersion10",
"HttpVersion11",
"WSMsgType",
"WSCloseCode",
"WSMessage",
"WebSocketError",
# multipart
"BadContentDispositionHeader",
"BadContentDispositionParam",
"BodyPartReader",
"MultipartReader",
"MultipartWriter",
"content_disposition_filename",
"parse_content_disposition",
# payload
"AsyncIterablePayload",
"BufferedReaderPayload",
"BytesIOPayload",
"BytesPayload",
"IOBasePayload",
"JsonPayload",
"PAYLOAD_REGISTRY",
"Payload",
"StringIOPayload",
"StringPayload",
"TextIOPayload",
"get_payload",
"payload_type",
# payload_streamer
"streamer",
# resolver
"AsyncResolver",
"DefaultResolver",
"ThreadedResolver",
# streams
"DataQueue",
"EMPTY_PAYLOAD",
"EofStream",
"StreamReader",
# tracing
"TraceConfig",
"TraceConnectionCreateEndParams",
"TraceConnectionCreateStartParams",
"TraceConnectionQueuedEndParams",
"TraceConnectionQueuedStartParams",
"TraceConnectionReuseconnParams",
"TraceDnsCacheHitParams",
"TraceDnsCacheMissParams",
"TraceDnsResolveHostEndParams",
"TraceDnsResolveHostStartParams",
"TraceRequestChunkSentParams",
"TraceRequestEndParams",
"TraceRequestExceptionParams",
"TraceRequestHeadersSentParams",
"TraceRequestRedirectParams",
"TraceRequestStartParams",
"TraceResponseChunkReceivedParams",
# workers (imported lazily with __getattr__)
"GunicornUVLoopWebWorker",
"GunicornWebWorker",
"WSMessageTypeError",
)
def __dir__() -> Tuple[str, ...]:
    """Return the public API names plus ``__doc__`` for ``dir(aiohttp)``."""
    return (*__all__, "__doc__")
def __getattr__(name: str) -> object:
    """Lazily resolve the Gunicorn worker classes (PEP 562 module getattr).

    The workers are cached in module globals after the first successful
    lookup, so this hook only runs once per name.
    """
    global GunicornUVLoopWebWorker, GunicornWebWorker

    # Importing gunicorn takes a long time (>100ms), so only import if actually needed.
    if name not in ("GunicornUVLoopWebWorker", "GunicornWebWorker"):
        raise AttributeError(f"module {__name__} has no attribute {name}")
    try:
        from .worker import GunicornUVLoopWebWorker as guv, GunicornWebWorker as gw
    except ImportError:
        # gunicorn is an optional dependency; the attribute resolves to None
        # when it is not installed.
        return None
    # Cache both classes so subsequent attribute access bypasses this hook.
    GunicornUVLoopWebWorker = guv  # type: ignore[misc]
    GunicornWebWorker = gw  # type: ignore[misc]
    return guv if name == "GunicornUVLoopWebWorker" else gw

View File

@@ -0,0 +1,338 @@
"""
Internal cookie handling helpers.
This module contains internal utilities for cookie parsing and manipulation.
These are not part of the public API and may change without notice.
"""
import re
from http.cookies import Morsel
from typing import List, Optional, Sequence, Tuple, cast
from .log import internal_logger
__all__ = (
"parse_set_cookie_headers",
"parse_cookie_header",
"preserve_morsel_with_coded_value",
)
# Cookie parsing constants
# Allow more characters in cookie names to handle real-world cookies
# that don't strictly follow RFC standards (fixes #2683)
# RFC 6265 defines cookie-name token as per RFC 2616 Section 2.2,
# but many servers send cookies with characters like {} [] () etc.
# This makes the cookie parser more tolerant of real-world cookies
# while still providing some validation to catch obviously malformed names.
_COOKIE_NAME_RE = re.compile(r"^[!#$%&\'()*+\-./0-9:<=>?@A-Z\[\]^_`a-z{|}~]+$")
_COOKIE_KNOWN_ATTRS = frozenset( # AKA Morsel._reserved
(
"path",
"domain",
"max-age",
"expires",
"secure",
"httponly",
"samesite",
"partitioned",
"version",
"comment",
)
)
_COOKIE_BOOL_ATTRS = frozenset( # AKA Morsel._flags
("secure", "httponly", "partitioned")
)
# SimpleCookie's pattern for parsing cookies with relaxed validation
# Based on http.cookies pattern but extended to allow more characters in cookie names
# to handle real-world cookies (fixes #2683)
_COOKIE_PATTERN = re.compile(
r"""
\s* # Optional whitespace at start of cookie
(?P<key> # Start of group 'key'
# aiohttp has extended to include [] for compatibility with real-world cookies
[\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\[\]]+ # Any word of at least one letter
) # End of group 'key'
( # Optional group: there may not be a value.
\s*=\s* # Equal Sign
(?P<val> # Start of group 'val'
"(?:[^\\"]|\\.)*" # Any double-quoted string (properly closed)
| # or
"[^";]* # Unmatched opening quote (differs from SimpleCookie - issue #7993)
| # or
# Special case for "expires" attr - RFC 822, RFC 850, RFC 1036, RFC 1123
(\w{3,6}day|\w{3}),\s # Day of the week or abbreviated day (with comma)
[\w\d\s-]{9,11}\s[\d:]{8}\s # Date and time in specific format
(GMT|[+-]\d{4}) # Timezone: GMT or RFC 2822 offset like -0000, +0100
# NOTE: RFC 2822 timezone support is an aiohttp extension
# for issue #4493 - SimpleCookie does NOT support this
| # or
# ANSI C asctime() format: "Wed Jun 9 10:18:14 2021"
# NOTE: This is an aiohttp extension for issue #4327 - SimpleCookie does NOT support this format
\w{3}\s+\w{3}\s+[\s\d]\d\s+\d{2}:\d{2}:\d{2}\s+\d{4}
| # or
[\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=\[\]]* # Any word or empty string
) # End of group 'val'
)? # End of optional value group
\s* # Any number of spaces.
(\s+|;|$) # Ending either at space, semicolon, or EOS.
""",
re.VERBOSE | re.ASCII,
)
def preserve_morsel_with_coded_value(cookie: Morsel[str]) -> Morsel[str]:
    """
    Return a Morsel whose coded_value is kept exactly as the server sent it.

    Python's SimpleCookie re-encodes values when cookies are set through the
    public API, which breaks authentication against old servers with strict
    cookie-format requirements (see
    https://github.com/aio-libs/aiohttp/pull/1453). This helper copies the
    key/value/coded_value triple verbatim instead.

    Args:
        cookie: A Morsel object from SimpleCookie

    Returns:
        A Morsel object with preserved coded_value
    """
    preserved = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
    state = {
        "key": cookie.key,
        "value": cookie.value,
        "coded_value": cookie.coded_value,
    }
    # __setstate__ (rather than the public set() API) bypasses validation and
    # installs the already-validated state. This is more stable than poking
    # protected attributes directly and unlikely to change, since changing it
    # would break pickling.
    preserved.__setstate__(state)  # type: ignore[attr-defined]
    return preserved
_unquote_sub = re.compile(r"\\(?:([0-3][0-7][0-7])|(.))").sub
def _unquote_replace(m: re.Match[str]) -> str:
"""
Replace function for _unquote_sub regex substitution.
Handles escaped characters in cookie values:
- Octal sequences are converted to their character representation
- Other escaped characters are unescaped by removing the backslash
"""
if m[1]:
return chr(int(m[1], 8))
return m[2]
def _unquote(value: str) -> str:
"""
Unquote a cookie value.
Vendored from http.cookies._unquote to ensure compatibility.
Note: The original implementation checked for None, but we've removed
that check since all callers already ensure the value is not None.
"""
# If there aren't any doublequotes,
# then there can't be any special characters. See RFC 2109.
if len(value) < 2:
return value
if value[0] != '"' or value[-1] != '"':
return value
# We have to assume that we must decode this string.
# Down to work.
# Remove the "s
value = value[1:-1]
# Check for special sequences. Examples:
# \012 --> \n
# \" --> "
#
return _unquote_sub(_unquote_replace, value)
def parse_cookie_header(header: str) -> List[Tuple[str, Morsel[str]]]:
    """
    Parse a Cookie header according to RFC 6265 Section 5.4.

    Cookie headers contain only name-value pairs separated by semicolons.
    There are no attributes in Cookie headers - even names that match
    attribute names (like 'path' or 'secure') should be treated as cookies.

    This parser uses the same regex-based approach as parse_set_cookie_headers
    to properly handle quoted values that may contain semicolons. When the
    regex fails to match a malformed cookie, it falls back to simple parsing
    to ensure subsequent cookies are not lost.
    https://github.com/aio-libs/aiohttp/issues/11632

    Args:
        header: The Cookie header value to parse

    Returns:
        List of (name, Morsel) tuples for compatibility with SimpleCookie.update()
    """
    if not header:
        return []
    cookies: List[Tuple[str, Morsel[str]]] = []
    morsel: Morsel[str]
    i = 0
    n = len(header)
    # Invalid names are collected and logged once at the end rather than
    # per-occurrence, to keep the hot loop cheap.
    invalid_names = []
    while i < n:
        # Use the same pattern as parse_set_cookie_headers to find cookies
        match = _COOKIE_PATTERN.match(header, i)
        if not match:
            # Fallback for malformed cookies https://github.com/aio-libs/aiohttp/issues/11632
            # Find next semicolon to skip or attempt simple key=value parsing
            next_semi = header.find(";", i)
            eq_pos = header.find("=", i)
            # Try to extract key=value if '=' comes before ';'
            if eq_pos != -1 and (next_semi == -1 or eq_pos < next_semi):
                end_pos = next_semi if next_semi != -1 else n
                key = header[i:eq_pos].strip()
                value = header[eq_pos + 1 : end_pos].strip()
                # Validate the name (same as regex path)
                if not _COOKIE_NAME_RE.match(key):
                    invalid_names.append(key)
                else:
                    morsel = Morsel()
                    morsel.__setstate__(  # type: ignore[attr-defined]
                        {"key": key, "value": _unquote(value), "coded_value": value}
                    )
                    cookies.append((key, morsel))
            # Move to next cookie or end
            i = next_semi + 1 if next_semi != -1 else n
            continue
        key = match.group("key")
        # A name with no '=' yields val=None; normalize to empty string here.
        value = match.group("val") or ""
        i = match.end(0)
        # Validate the name
        if not key or not _COOKIE_NAME_RE.match(key):
            invalid_names.append(key)
            continue
        # Create new morsel
        morsel = Morsel()
        # Preserve the original value as coded_value (with quotes if present)
        # We use __setstate__ instead of the public set() API because it allows us to
        # bypass validation and set already validated state. This is more stable than
        # setting protected attributes directly and unlikely to change since it would
        # break pickling.
        morsel.__setstate__(  # type: ignore[attr-defined]
            {"key": key, "value": _unquote(value), "coded_value": value}
        )
        cookies.append((key, morsel))
    if invalid_names:
        internal_logger.debug(
            "Cannot load cookie. Illegal cookie names: %r", invalid_names
        )
    return cookies
def parse_set_cookie_headers(headers: Sequence[str]) -> List[Tuple[str, Morsel[str]]]:
    """
    Parse cookie headers using a vendored version of SimpleCookie parsing.

    This implementation is based on SimpleCookie.__parse_string to ensure
    compatibility with how SimpleCookie parses cookies, including handling
    of malformed cookies with missing semicolons.

    This function is used for both Cookie and Set-Cookie headers in order to be
    forgiving. Ideally we would have followed RFC 6265 Section 5.2 (for Cookie
    headers) and RFC 6265 Section 4.2.1 (for Set-Cookie headers), but the
    real world data makes it impossible since we need to be a bit more forgiving.

    NOTE: This implementation differs from SimpleCookie in handling unmatched quotes.
    SimpleCookie will stop parsing when it encounters a cookie value with an unmatched
    quote (e.g., 'cookie="value'), causing subsequent cookies to be silently dropped.
    This implementation handles unmatched quotes more gracefully to prevent cookie loss.
    See https://github.com/aio-libs/aiohttp/issues/7993
    """
    parsed_cookies: List[Tuple[str, Morsel[str]]] = []
    for header in headers:
        if not header:
            continue
        # Parse cookie string using SimpleCookie's algorithm
        i = 0
        n = len(header)
        # current_morsel receives attributes (path, secure, ...) that follow
        # the name=value pair it was created for; None means attributes are
        # silently dropped (e.g. after an illegal cookie name).
        current_morsel: Optional[Morsel[str]] = None
        morsel_seen = False
        while 0 <= i < n:
            # Start looking for a cookie
            match = _COOKIE_PATTERN.match(header, i)
            if not match:
                # No more cookies
                break
            key, value = match.group("key"), match.group("val")
            i = match.end(0)
            lower_key = key.lower()
            if key[0] == "$":
                if not morsel_seen:
                    # We ignore attributes which pertain to the cookie
                    # mechanism as a whole, such as "$Version".
                    continue
                # Process as attribute
                if current_morsel is not None:
                    attr_lower_key = lower_key[1:]
                    if attr_lower_key in _COOKIE_KNOWN_ATTRS:
                        current_morsel[attr_lower_key] = value or ""
            elif lower_key in _COOKIE_KNOWN_ATTRS:
                if not morsel_seen:
                    # Invalid cookie string - attribute before cookie
                    break
                if lower_key in _COOKIE_BOOL_ATTRS:
                    # Boolean attribute with any value should be True
                    if current_morsel is not None and current_morsel.isReservedKey(key):
                        current_morsel[lower_key] = True
                elif value is None:
                    # Invalid cookie string - non-boolean attribute without value
                    break
                elif current_morsel is not None:
                    # Regular attribute with value
                    current_morsel[lower_key] = _unquote(value)
            elif value is not None:
                # This is a cookie name=value pair
                # Validate the name
                if key in _COOKIE_KNOWN_ATTRS or not _COOKIE_NAME_RE.match(key):
                    internal_logger.warning(
                        "Can not load cookies: Illegal cookie name %r", key
                    )
                    current_morsel = None
                else:
                    # Create new morsel
                    current_morsel = Morsel()
                    # Preserve the original value as coded_value (with quotes if present)
                    # We use __setstate__ instead of the public set() API because it allows us to
                    # bypass validation and set already validated state. This is more stable than
                    # setting protected attributes directly and unlikely to change since it would
                    # break pickling.
                    current_morsel.__setstate__(  # type: ignore[attr-defined]
                        {"key": key, "value": _unquote(value), "coded_value": value}
                    )
                    parsed_cookies.append((key, current_morsel))
                # NOTE(review): morsel_seen appears to be set even for an
                # illegal name so that trailing attributes are consumed (with
                # current_morsel=None) instead of aborting — confirm against
                # upstream; original indentation was lost in extraction.
                morsel_seen = True
            else:
                # Invalid cookie string - no value for non-attribute
                break
    return parsed_cookies

View File

@@ -0,0 +1,158 @@
# Cython declarations for the vendored llhttp C parser, as consumed by
# aiohttp/_http_parser.pyx. Field layout and enum order must mirror llhttp.h.
from libc.stdint cimport int32_t, uint8_t, uint16_t, uint64_t


cdef extern from "llhttp.h":
    # Internal parser state struct; aliased as llhttp_t below.
    struct llhttp__internal_s:
        int32_t _index
        void* _span_pos0
        void* _span_cb0
        int32_t error
        const char* reason
        const char* error_pos
        void* data
        void* _current
        uint64_t content_length
        uint8_t type
        uint8_t method
        uint8_t http_major
        uint8_t http_minor
        uint8_t header_state
        uint8_t lenient_flags
        uint8_t upgrade
        uint8_t finish
        uint16_t flags
        uint16_t status_code
        void* settings
    ctypedef llhttp__internal_s llhttp__internal_t
    ctypedef llhttp__internal_t llhttp_t

    # Callback types: "except -1" lets a Python exception raised inside a
    # callback abort parsing and propagate out of llhttp_execute().
    ctypedef int (*llhttp_data_cb)(llhttp_t*, const char *at, size_t length) except -1
    ctypedef int (*llhttp_cb)(llhttp_t*) except -1

    # Callback table installed via llhttp_settings_init()/llhttp_init().
    struct llhttp_settings_s:
        llhttp_cb on_message_begin
        llhttp_data_cb on_url
        llhttp_data_cb on_status
        llhttp_data_cb on_header_field
        llhttp_data_cb on_header_value
        llhttp_cb on_headers_complete
        llhttp_data_cb on_body
        llhttp_cb on_message_complete
        llhttp_cb on_chunk_header
        llhttp_cb on_chunk_complete
        llhttp_cb on_url_complete
        llhttp_cb on_status_complete
        llhttp_cb on_header_field_complete
        llhttp_cb on_header_value_complete
    ctypedef llhttp_settings_s llhttp_settings_t

    # Error codes returned by llhttp_execute()/llhttp_get_errno().
    enum llhttp_errno:
        HPE_OK,
        HPE_INTERNAL,
        HPE_STRICT,
        HPE_LF_EXPECTED,
        HPE_UNEXPECTED_CONTENT_LENGTH,
        HPE_CLOSED_CONNECTION,
        HPE_INVALID_METHOD,
        HPE_INVALID_URL,
        HPE_INVALID_CONSTANT,
        HPE_INVALID_VERSION,
        HPE_INVALID_HEADER_TOKEN,
        HPE_INVALID_CONTENT_LENGTH,
        HPE_INVALID_CHUNK_SIZE,
        HPE_INVALID_STATUS,
        HPE_INVALID_EOF_STATE,
        HPE_INVALID_TRANSFER_ENCODING,
        HPE_CB_MESSAGE_BEGIN,
        HPE_CB_HEADERS_COMPLETE,
        HPE_CB_MESSAGE_COMPLETE,
        HPE_CB_CHUNK_HEADER,
        HPE_CB_CHUNK_COMPLETE,
        HPE_PAUSED,
        HPE_PAUSED_UPGRADE,
        HPE_USER
    ctypedef llhttp_errno llhttp_errno_t

    enum llhttp_flags:
        F_CHUNKED,
        F_CONTENT_LENGTH

    # Parser mode: parse requests, responses, or autodetect.
    enum llhttp_type:
        HTTP_REQUEST,
        HTTP_RESPONSE,
        HTTP_BOTH

    # Known request methods (HTTP plus RTSP extensions).
    enum llhttp_method:
        HTTP_DELETE,
        HTTP_GET,
        HTTP_HEAD,
        HTTP_POST,
        HTTP_PUT,
        HTTP_CONNECT,
        HTTP_OPTIONS,
        HTTP_TRACE,
        HTTP_COPY,
        HTTP_LOCK,
        HTTP_MKCOL,
        HTTP_MOVE,
        HTTP_PROPFIND,
        HTTP_PROPPATCH,
        HTTP_SEARCH,
        HTTP_UNLOCK,
        HTTP_BIND,
        HTTP_REBIND,
        HTTP_UNBIND,
        HTTP_ACL,
        HTTP_REPORT,
        HTTP_MKACTIVITY,
        HTTP_CHECKOUT,
        HTTP_MERGE,
        HTTP_MSEARCH,
        HTTP_NOTIFY,
        HTTP_SUBSCRIBE,
        HTTP_UNSUBSCRIBE,
        HTTP_PATCH,
        HTTP_PURGE,
        HTTP_MKCALENDAR,
        HTTP_LINK,
        HTTP_UNLINK,
        HTTP_SOURCE,
        HTTP_PRI,
        HTTP_DESCRIBE,
        HTTP_ANNOUNCE,
        HTTP_SETUP,
        HTTP_PLAY,
        HTTP_PAUSE,
        HTTP_TEARDOWN,
        HTTP_GET_PARAMETER,
        HTTP_SET_PARAMETER,
        HTTP_REDIRECT,
        HTTP_RECORD,
        HTTP_FLUSH
    ctypedef llhttp_method llhttp_method_t;

    void llhttp_settings_init(llhttp_settings_t* settings)
    void llhttp_init(llhttp_t* parser, llhttp_type type,
                     const llhttp_settings_t* settings)
    llhttp_errno_t llhttp_execute(llhttp_t* parser, const char* data, size_t len)
    int llhttp_should_keep_alive(const llhttp_t* parser)
    void llhttp_resume_after_upgrade(llhttp_t* parser)
    llhttp_errno_t llhttp_get_errno(const llhttp_t* parser)
    const char* llhttp_get_error_reason(const llhttp_t* parser)
    const char* llhttp_get_error_pos(const llhttp_t* parser)
    const char* llhttp_method_name(llhttp_method_t method)

    # Leniency toggles used to accept slightly malformed real-world traffic.
    void llhttp_set_lenient_headers(llhttp_t* parser, int enabled)
    void llhttp_set_lenient_optional_cr_before_lf(llhttp_t* parser, int enabled)
    void llhttp_set_lenient_spaces_after_chunk_size(llhttp_t* parser, int enabled)

View File

@@ -0,0 +1,2 @@
cdef extern from "_find_header.h":
    # Perfect-hash lookup for well-known header names; returns the index into
    # the generated headers tuple (_headers.pxi), or -1 if not found.
    int find_header(char *, int)

View File

@@ -0,0 +1,83 @@
# The file is autogenerated from aiohttp/hdrs.py
# Run ./tools/gen.py to update it after changing the origin.
# NOTE: the tuple order is significant — find_header() (see _find_header.pxd)
# returns an index into this tuple, so do not reorder entries by hand.
from . import hdrs
cdef tuple headers = (
    hdrs.ACCEPT,
    hdrs.ACCEPT_CHARSET,
    hdrs.ACCEPT_ENCODING,
    hdrs.ACCEPT_LANGUAGE,
    hdrs.ACCEPT_RANGES,
    hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
    hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
    hdrs.ACCESS_CONTROL_ALLOW_METHODS,
    hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
    hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
    hdrs.ACCESS_CONTROL_MAX_AGE,
    hdrs.ACCESS_CONTROL_REQUEST_HEADERS,
    hdrs.ACCESS_CONTROL_REQUEST_METHOD,
    hdrs.AGE,
    hdrs.ALLOW,
    hdrs.AUTHORIZATION,
    hdrs.CACHE_CONTROL,
    hdrs.CONNECTION,
    hdrs.CONTENT_DISPOSITION,
    hdrs.CONTENT_ENCODING,
    hdrs.CONTENT_LANGUAGE,
    hdrs.CONTENT_LENGTH,
    hdrs.CONTENT_LOCATION,
    hdrs.CONTENT_MD5,
    hdrs.CONTENT_RANGE,
    hdrs.CONTENT_TRANSFER_ENCODING,
    hdrs.CONTENT_TYPE,
    hdrs.COOKIE,
    hdrs.DATE,
    hdrs.DESTINATION,
    hdrs.DIGEST,
    hdrs.ETAG,
    hdrs.EXPECT,
    hdrs.EXPIRES,
    hdrs.FORWARDED,
    hdrs.FROM,
    hdrs.HOST,
    hdrs.IF_MATCH,
    hdrs.IF_MODIFIED_SINCE,
    hdrs.IF_NONE_MATCH,
    hdrs.IF_RANGE,
    hdrs.IF_UNMODIFIED_SINCE,
    hdrs.KEEP_ALIVE,
    hdrs.LAST_EVENT_ID,
    hdrs.LAST_MODIFIED,
    hdrs.LINK,
    hdrs.LOCATION,
    hdrs.MAX_FORWARDS,
    hdrs.ORIGIN,
    hdrs.PRAGMA,
    hdrs.PROXY_AUTHENTICATE,
    hdrs.PROXY_AUTHORIZATION,
    hdrs.RANGE,
    hdrs.REFERER,
    hdrs.RETRY_AFTER,
    hdrs.SEC_WEBSOCKET_ACCEPT,
    hdrs.SEC_WEBSOCKET_EXTENSIONS,
    hdrs.SEC_WEBSOCKET_KEY,
    hdrs.SEC_WEBSOCKET_KEY1,
    hdrs.SEC_WEBSOCKET_PROTOCOL,
    hdrs.SEC_WEBSOCKET_VERSION,
    hdrs.SERVER,
    hdrs.SET_COOKIE,
    hdrs.TE,
    hdrs.TRAILER,
    hdrs.TRANSFER_ENCODING,
    hdrs.URI,
    hdrs.UPGRADE,
    hdrs.USER_AGENT,
    hdrs.VARY,
    hdrs.VIA,
    hdrs.WWW_AUTHENTICATE,
    hdrs.WANT_DIGEST,
    hdrs.WARNING,
    hdrs.X_FORWARDED_FOR,
    hdrs.X_FORWARDED_HOST,
    hdrs.X_FORWARDED_PROTO,
)

View File

@@ -0,0 +1,835 @@
# Based on https://github.com/MagicStack/httptools
#
from cpython cimport (
Py_buffer,
PyBUF_SIMPLE,
PyBuffer_Release,
PyBytes_AsString,
PyBytes_AsStringAndSize,
PyObject_GetBuffer,
)
from cpython.mem cimport PyMem_Free, PyMem_Malloc
from libc.limits cimport ULLONG_MAX
from libc.string cimport memcpy
from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiDictProxy
from yarl import URL as _URL
from aiohttp import hdrs
from aiohttp.helpers import DEBUG, set_exception
from .http_exceptions import (
BadHttpMessage,
BadHttpMethod,
BadStatusLine,
ContentLengthError,
InvalidHeader,
InvalidURLError,
LineTooLong,
PayloadEncodingError,
TransferEncodingError,
)
from .http_parser import DeflateBuffer as _DeflateBuffer
from .http_writer import (
HttpVersion as _HttpVersion,
HttpVersion10 as _HttpVersion10,
HttpVersion11 as _HttpVersion11,
)
from .streams import EMPTY_PAYLOAD as _EMPTY_PAYLOAD, StreamReader as _StreamReader
cimport cython
from aiohttp cimport _cparser as cparser
include "_headers.pxi"
from aiohttp cimport _find_header
# Protocols accepted in an Upgrade header by the parser.
ALLOWED_UPGRADES = frozenset({"websocket"})
# Number of instances preallocated by @cython.freelist on the message classes.
DEF DEFAULT_FREELIST_SIZE = 250

cdef extern from "Python.h":
    # Raw bytearray C-API used by extend() for cheap in-place appends.
    int PyByteArray_Resize(object, Py_ssize_t) except -1
    Py_ssize_t PyByteArray_Size(object) except -1
    char* PyByteArray_AsString(object)

__all__ = ('HttpRequestParser', 'HttpResponseParser',
           'RawRequestMessage', 'RawResponseMessage')

# Cache frequently used Python objects in module-level C globals so the hot
# parsing path avoids repeated module-dict lookups.
cdef object URL = _URL
cdef object URL_build = URL.build
cdef object CIMultiDict = _CIMultiDict
cdef object CIMultiDictProxy = _CIMultiDictProxy
cdef object HttpVersion = _HttpVersion
cdef object HttpVersion10 = _HttpVersion10
cdef object HttpVersion11 = _HttpVersion11
cdef object SEC_WEBSOCKET_KEY1 = hdrs.SEC_WEBSOCKET_KEY1
cdef object CONTENT_ENCODING = hdrs.CONTENT_ENCODING
cdef object EMPTY_PAYLOAD = _EMPTY_PAYLOAD
cdef object StreamReader = _StreamReader
cdef object DeflateBuffer = _DeflateBuffer
# Shared empty-bytes singleton; identity-compared (is/is not) elsewhere.
cdef bytes EMPTY_BYTES = b""
cdef inline object extend(object buf, const char* at, size_t length):
    # Append `length` raw bytes starting at `at` to bytearray `buf` using the
    # bytearray C-API (avoids creating an intermediate bytes object).
    cdef Py_ssize_t s
    cdef char* ptr
    s = PyByteArray_Size(buf)
    PyByteArray_Resize(buf, s + length)
    # Fetch the data pointer only AFTER resizing — the resize may reallocate.
    ptr = PyByteArray_AsString(buf)
    memcpy(ptr + s, at, length)
# Number of llhttp_method values; must match the llhttp_method enum in
# _cparser.pxd (HTTP_DELETE .. HTTP_FLUSH).
DEF METHODS_COUNT = 46;

# Precompute method-name strings once, indexed by llhttp_method_t value.
cdef list _http_method = []

for i in range(METHODS_COUNT):
    _http_method.append(
        cparser.llhttp_method_name(<cparser.llhttp_method_t> i).decode('ascii'))


cdef inline str http_method_str(int i):
    # llhttp stores the method in a uint8, so negative values are not
    # expected from the parser; out-of-range values map to "<unknown>".
    if i < METHODS_COUNT:
        return <str>_http_method[i]
    else:
        return "<unknown>"
cdef inline object find_header(bytes raw_header):
    # Map a raw header name to the shared interned string from the generated
    # headers tuple (_headers.pxi); falls back to decoding the raw bytes
    # (with surrogateescape) for names outside the well-known set.
    cdef Py_ssize_t size
    cdef char *buf
    cdef int idx
    PyBytes_AsStringAndSize(raw_header, &buf, &size)
    idx = _find_header.find_header(buf, size)
    if idx == -1:
        return raw_header.decode('utf-8', 'surrogateescape')
    return headers[idx]
@cython.freelist(DEFAULT_FREELIST_SIZE)
cdef class RawRequestMessage:
    """Parsed request line and headers produced by HttpRequestParser.

    Attributes are read-only from Python; use _replace() to derive a copy
    with selected fields overridden (namedtuple-style).
    """
    cdef readonly str method
    cdef readonly str path
    cdef readonly object version  # HttpVersion
    cdef readonly object headers  # CIMultiDict
    cdef readonly object raw_headers  # tuple
    cdef readonly object should_close
    cdef readonly object compression
    cdef readonly object upgrade
    cdef readonly object chunked
    cdef readonly object url  # yarl.URL

    def __init__(self, method, path, version, headers, raw_headers,
                 should_close, compression, upgrade, chunked, url):
        self.method = method
        self.path = path
        self.version = version
        self.headers = headers
        self.raw_headers = raw_headers
        self.should_close = should_close
        self.compression = compression
        self.upgrade = upgrade
        self.chunked = chunked
        self.url = url

    def __repr__(self):
        info = []
        info.append(("method", self.method))
        info.append(("path", self.path))
        info.append(("version", self.version))
        info.append(("headers", self.headers))
        info.append(("raw_headers", self.raw_headers))
        info.append(("should_close", self.should_close))
        info.append(("compression", self.compression))
        info.append(("upgrade", self.upgrade))
        info.append(("chunked", self.chunked))
        info.append(("url", self.url))
        sinfo = ', '.join(name + '=' + repr(val) for name, val in info)
        return '<RawRequestMessage(' + sinfo + ')>'

    def _replace(self, **dct):
        # Copy-with-overrides. The readonly cdef attributes can only be
        # assigned from inside this module, hence the explicit per-field
        # checks instead of a generic setattr() loop.
        cdef RawRequestMessage ret
        ret = _new_request_message(self.method,
                                   self.path,
                                   self.version,
                                   self.headers,
                                   self.raw_headers,
                                   self.should_close,
                                   self.compression,
                                   self.upgrade,
                                   self.chunked,
                                   self.url)
        if "method" in dct:
            ret.method = dct["method"]
        if "path" in dct:
            ret.path = dct["path"]
        if "version" in dct:
            ret.version = dct["version"]
        if "headers" in dct:
            ret.headers = dct["headers"]
        if "raw_headers" in dct:
            ret.raw_headers = dct["raw_headers"]
        if "should_close" in dct:
            ret.should_close = dct["should_close"]
        if "compression" in dct:
            ret.compression = dct["compression"]
        if "upgrade" in dct:
            ret.upgrade = dct["upgrade"]
        if "chunked" in dct:
            ret.chunked = dct["chunked"]
        if "url" in dct:
            ret.url = dct["url"]
        return ret
cdef _new_request_message(str method,
                          str path,
                          object version,
                          object headers,
                          object raw_headers,
                          bint should_close,
                          object compression,
                          bint upgrade,
                          bint chunked,
                          object url):
    # Fast constructor: bypasses __init__ via __new__ and assigns the
    # readonly cdef fields directly (legal within this module only).
    cdef RawRequestMessage ret
    ret = RawRequestMessage.__new__(RawRequestMessage)
    ret.method = method
    ret.path = path
    ret.version = version
    ret.headers = headers
    ret.raw_headers = raw_headers
    ret.should_close = should_close
    ret.compression = compression
    ret.upgrade = upgrade
    ret.chunked = chunked
    ret.url = url
    return ret
@cython.freelist(DEFAULT_FREELIST_SIZE)
cdef class RawResponseMessage:
    """Parsed status line and headers produced by HttpResponseParser.

    Attributes are read-only from Python.
    """
    cdef readonly object version  # HttpVersion
    cdef readonly int code
    cdef readonly str reason
    cdef readonly object headers  # CIMultiDict
    cdef readonly object raw_headers  # tuple
    cdef readonly object should_close
    cdef readonly object compression
    cdef readonly object upgrade
    cdef readonly object chunked

    def __init__(self, version, code, reason, headers, raw_headers,
                 should_close, compression, upgrade, chunked):
        self.version = version
        self.code = code
        self.reason = reason
        self.headers = headers
        self.raw_headers = raw_headers
        self.should_close = should_close
        self.compression = compression
        self.upgrade = upgrade
        self.chunked = chunked

    def __repr__(self):
        info = []
        info.append(("version", self.version))
        info.append(("code", self.code))
        info.append(("reason", self.reason))
        info.append(("headers", self.headers))
        info.append(("raw_headers", self.raw_headers))
        info.append(("should_close", self.should_close))
        info.append(("compression", self.compression))
        info.append(("upgrade", self.upgrade))
        info.append(("chunked", self.chunked))
        sinfo = ', '.join(name + '=' + repr(val) for name, val in info)
        return '<RawResponseMessage(' + sinfo + ')>'
cdef _new_response_message(object version,
                           int code,
                           str reason,
                           object headers,
                           object raw_headers,
                           bint should_close,
                           object compression,
                           bint upgrade,
                           bint chunked):
    # Fast constructor: bypasses __init__ via __new__ and assigns the
    # readonly cdef fields directly (legal within this module only).
    cdef RawResponseMessage ret
    ret = RawResponseMessage.__new__(RawResponseMessage)
    ret.version = version
    ret.code = code
    ret.reason = reason
    ret.headers = headers
    ret.raw_headers = raw_headers
    ret.should_close = should_close
    ret.compression = compression
    ret.upgrade = upgrade
    ret.chunked = chunked
    return ret
@cython.internal
cdef class HttpParser:
    """Base driver around the vendored llhttp C parser.

    Subclasses pick request vs. response mode through ``_init``.  The
    module-level ``cb_on_*`` C callbacks recover this object from
    ``llhttp_t.data`` and dispatch into the ``_on_*`` methods below.
    """

    cdef:
        cparser.llhttp_t* _cparser            # heap-allocated llhttp parser state
        cparser.llhttp_settings_t* _csettings # callback table handed to llhttp

        # Current header field/value being accumulated (may arrive in chunks).
        bytes _raw_name
        object _name
        bytes _raw_value
        bint _has_value

        object _protocol
        object _loop
        object _timer

        size_t _max_line_size
        size_t _max_field_size
        size_t _max_headers
        bint _response_with_body
        bint _read_until_eof

        bint _started
        object _url
        bytearray _buf        # scratch buffer for URL / status reason bytes
        str _path
        str _reason
        list _headers         # decoded (name, value) pairs
        list _raw_headers     # raw (bytes, bytes) pairs
        bint _upgraded
        list _messages        # completed (message, payload) tuples
        object _payload
        bint _payload_error
        object _payload_exception
        object _last_error
        bint _auto_decompress
        int _limit

        str _content_encoding

        Py_buffer py_buf      # reused buffer view for feed_data()

    def __cinit__(self):
        # Allocate the C parser structs up front; freed in __dealloc__.
        self._cparser = <cparser.llhttp_t*> \
            PyMem_Malloc(sizeof(cparser.llhttp_t))
        if self._cparser is NULL:
            raise MemoryError()

        self._csettings = <cparser.llhttp_settings_t*> \
            PyMem_Malloc(sizeof(cparser.llhttp_settings_t))
        if self._csettings is NULL:
            raise MemoryError()

    def __dealloc__(self):
        PyMem_Free(self._cparser)
        PyMem_Free(self._csettings)

    cdef _init(
        self, cparser.llhttp_type mode,
        object protocol, object loop, int limit,
        object timer=None,
        size_t max_line_size=8190, size_t max_headers=32768,
        size_t max_field_size=8190, payload_exception=None,
        bint response_with_body=True, bint read_until_eof=False,
        bint auto_decompress=True,
    ):
        """Reset parser state and wire the llhttp callback table."""
        cparser.llhttp_settings_init(self._csettings)
        cparser.llhttp_init(self._cparser, mode, self._csettings)
        # Stash a back-pointer so the C callbacks can reach this object.
        self._cparser.data = <void*>self
        self._cparser.content_length = 0

        self._protocol = protocol
        self._loop = loop
        self._timer = timer

        self._buf = bytearray()
        self._payload = None
        self._payload_error = 0
        self._payload_exception = payload_exception
        self._messages = []

        self._raw_name = EMPTY_BYTES
        self._raw_value = EMPTY_BYTES
        self._has_value = False

        self._max_line_size = max_line_size
        self._max_headers = max_headers
        self._max_field_size = max_field_size
        self._response_with_body = response_with_body
        self._read_until_eof = read_until_eof
        self._upgraded = False
        self._auto_decompress = auto_decompress
        self._content_encoding = None

        self._csettings.on_url = cb_on_url
        self._csettings.on_status = cb_on_status
        self._csettings.on_header_field = cb_on_header_field
        self._csettings.on_header_value = cb_on_header_value
        self._csettings.on_headers_complete = cb_on_headers_complete
        self._csettings.on_body = cb_on_body
        self._csettings.on_message_begin = cb_on_message_begin
        self._csettings.on_message_complete = cb_on_message_complete
        self._csettings.on_chunk_header = cb_on_chunk_header
        self._csettings.on_chunk_complete = cb_on_chunk_complete

        self._last_error = None
        self._limit = limit

    cdef _process_header(self):
        """Flush the accumulated raw header pair into the decoded lists."""
        cdef str value
        if self._raw_name is not EMPTY_BYTES:
            name = find_header(self._raw_name)
            value = self._raw_value.decode('utf-8', 'surrogateescape')

            self._headers.append((name, value))

            # Remember Content-Encoding so _on_headers_complete can decide
            # whether to wrap the payload in a decompressor.
            if name is CONTENT_ENCODING:
                self._content_encoding = value

            self._has_value = False
            self._raw_headers.append((self._raw_name, self._raw_value))
            self._raw_name = EMPTY_BYTES
            self._raw_value = EMPTY_BYTES

    cdef _on_header_field(self, char* at, size_t length):
        # A new field after a value means the previous pair is complete.
        if self._has_value:
            self._process_header()

        if self._raw_name is EMPTY_BYTES:
            self._raw_name = at[:length]
        else:
            # llhttp may deliver a field name in several chunks.
            self._raw_name += at[:length]

    cdef _on_header_value(self, char* at, size_t length):
        if self._raw_value is EMPTY_BYTES:
            self._raw_value = at[:length]
        else:
            self._raw_value += at[:length]
        self._has_value = True

    cdef _on_headers_complete(self):
        """Build the message object and select the payload reader."""
        self._process_header()

        should_close = not cparser.llhttp_should_keep_alive(self._cparser)
        upgrade = self._cparser.upgrade
        chunked = self._cparser.flags & cparser.F_CHUNKED

        raw_headers = tuple(self._raw_headers)
        headers = CIMultiDictProxy(CIMultiDict(self._headers))

        if self._cparser.type == cparser.HTTP_REQUEST:
            # Only flag an upgrade for protocols we allow (or CONNECT).
            h_upg = headers.get("upgrade", "")
            allowed = upgrade and h_upg.isascii() and h_upg.lower() in ALLOWED_UPGRADES
            if allowed or self._cparser.method == cparser.HTTP_CONNECT:
                self._upgraded = True
        else:
            if upgrade and self._cparser.status_code == 101:
                self._upgraded = True

        # do not support old websocket spec
        if SEC_WEBSOCKET_KEY1 in headers:
            raise InvalidHeader(SEC_WEBSOCKET_KEY1)

        # Accept only encodings we can actually decode.
        encoding = None
        enc = self._content_encoding
        if enc is not None:
            self._content_encoding = None
            if enc.isascii() and enc.lower() in {"gzip", "deflate", "br", "zstd"}:
                encoding = enc

        if self._cparser.type == cparser.HTTP_REQUEST:
            method = http_method_str(self._cparser.method)
            msg = _new_request_message(
                method, self._path,
                self.http_version(), headers, raw_headers,
                should_close, encoding, upgrade, chunked, self._url)
        else:
            msg = _new_response_message(
                self.http_version(), self._cparser.status_code, self._reason,
                headers, raw_headers, should_close, encoding,
                upgrade, chunked)

        # A real payload reader is needed when a body is expected:
        # explicit content length, chunked framing, CONNECT tunnelling,
        # or read-until-EOF responses without a declared length.
        if (
            ULLONG_MAX > self._cparser.content_length > 0 or chunked or
            self._cparser.method == cparser.HTTP_CONNECT or
            (self._cparser.status_code >= 199 and
             self._cparser.content_length == 0 and
             self._read_until_eof)
        ):
            payload = StreamReader(
                self._protocol, timer=self._timer, loop=self._loop,
                limit=self._limit)
        else:
            payload = EMPTY_PAYLOAD

        self._payload = payload
        if encoding is not None and self._auto_decompress:
            # Feed raw bytes through a decompressor; callers see plain data.
            self._payload = DeflateBuffer(payload, encoding)

        if not self._response_with_body:
            # HEAD-style: expose an empty payload to the caller while the
            # parser still tracks framing internally.
            payload = EMPTY_PAYLOAD

        self._messages.append((msg, payload))

    cdef _on_message_complete(self):
        self._payload.feed_eof()
        self._payload = None

    cdef _on_chunk_header(self):
        self._payload.begin_http_chunk_receiving()

    cdef _on_chunk_complete(self):
        self._payload.end_http_chunk_receiving()

    cdef object _on_status_complete(self):
        # Overridden by request/response subclasses.
        pass

    cdef inline http_version(self):
        cdef cparser.llhttp_t* parser = self._cparser

        # Return the shared singletons for the overwhelmingly common cases.
        if parser.http_major == 1:
            if parser.http_minor == 0:
                return HttpVersion10
            elif parser.http_minor == 1:
                return HttpVersion11

        return HttpVersion(parser.http_major, parser.http_minor)

    ### Public API ###

    def feed_eof(self):
        """Signal end of input; raise if framing promised more data."""
        cdef bytes desc

        if self._payload is not None:
            if self._cparser.flags & cparser.F_CHUNKED:
                raise TransferEncodingError(
                    "Not enough data to satisfy transfer length header.")
            elif self._cparser.flags & cparser.F_CONTENT_LENGTH:
                raise ContentLengthError(
                    "Not enough data to satisfy content length header.")
            elif cparser.llhttp_get_errno(self._cparser) != cparser.HPE_OK:
                desc = cparser.llhttp_get_error_reason(self._cparser)
                raise PayloadEncodingError(desc.decode('latin-1'))
            else:
                self._payload.feed_eof()
        elif self._started:
            self._on_headers_complete()
            if self._messages:
                return self._messages[-1][0]

    def feed_data(self, data):
        """Parse *data*; return (messages, upgraded_flag, tail_bytes)."""
        cdef:
            size_t data_len
            size_t nb
        cdef cparser.llhttp_errno_t errno

        PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE)
        data_len = <size_t>self.py_buf.len

        errno = cparser.llhttp_execute(
            self._cparser,
            <char*>self.py_buf.buf,
            data_len)

        if errno is cparser.HPE_PAUSED_UPGRADE:
            cparser.llhttp_resume_after_upgrade(self._cparser)

            # Offset of the first unconsumed byte (start of upgraded data).
            nb = cparser.llhttp_get_error_pos(self._cparser) - <char*>self.py_buf.buf

        PyBuffer_Release(&self.py_buf)

        if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED_UPGRADE):
            if self._payload_error == 0:
                if self._last_error is not None:
                    ex = self._last_error
                    self._last_error = None
                else:
                    # Build a one-line excerpt around the error position.
                    # NOTE(review): py_buf was already released above; its
                    # .buf pointer is only used here for offset arithmetic
                    # while `data` keeps the bytes object alive — confirm
                    # this invariant holds for all buffer providers.
                    after = cparser.llhttp_get_error_pos(self._cparser)
                    before = data[:after - <char*>self.py_buf.buf]
                    after_b = after.split(b"\r\n", 1)[0]
                    before = before.rsplit(b"\r\n", 1)[-1]
                    data = before + after_b
                    pointer = " " * (len(repr(before))-1) + "^"
                    ex = parser_error_from_errno(self._cparser, data, pointer)
                self._payload = None
                raise ex

        if self._messages:
            messages = self._messages
            self._messages = []
        else:
            messages = ()

        if self._upgraded:
            return messages, True, data[nb:]
        else:
            return messages, False, b""

    def set_upgraded(self, val):
        self._upgraded = val
cdef class HttpRequestParser(HttpParser):
    """HTTP request parser: adds request-target (URL) parsing."""

    def __init__(
        self, protocol, loop, int limit, timer=None,
        size_t max_line_size=8190, size_t max_headers=32768,
        size_t max_field_size=8190, payload_exception=None,
        bint response_with_body=True, bint read_until_eof=False,
        bint auto_decompress=True,
    ):
        self._init(cparser.HTTP_REQUEST, protocol, loop, limit, timer,
                   max_line_size, max_headers, max_field_size,
                   payload_exception, response_with_body, read_until_eof,
                   auto_decompress)

    cdef object _on_status_complete(self):
        """Decode the buffered request target into ``self._url``."""
        cdef int idx1, idx2
        if not self._buf:
            return
        self._path = self._buf.decode('utf-8', 'surrogateescape')
        try:
            idx3 = len(self._path)
            if self._cparser.method == cparser.HTTP_CONNECT:
                # authority-form,
                # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3
                self._url = URL.build(authority=self._path, encoded=True)
            elif idx3 > 1 and self._path[0] == '/':
                # origin-form,
                # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1
                idx1 = self._path.find("?")
                if idx1 == -1:
                    query = ""
                    idx2 = self._path.find("#")
                    if idx2 == -1:
                        path = self._path
                        fragment = ""
                    else:
                        path = self._path[0: idx2]
                        fragment = self._path[idx2+1:]
                else:
                    path = self._path[0:idx1]
                    idx1 += 1
                    # NOTE(review): searching from idx1+1 skips a '#' that
                    # immediately follows '?' (e.g. "/p?#f" yields query
                    # "#f") — confirm this edge is intended.
                    idx2 = self._path.find("#", idx1+1)
                    if idx2 == -1:
                        query = self._path[idx1:]
                        fragment = ""
                    else:
                        query = self._path[idx1: idx2]
                        fragment = self._path[idx2+1:]

                self._url = URL.build(
                    path=path,
                    query_string=query,
                    fragment=fragment,
                    encoded=True,
                )
            else:
                # absolute-form for proxy maybe,
                # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2
                self._url = URL(self._path, encoded=True)
        finally:
            # Reset the scratch buffer for the next message.
            PyByteArray_Resize(self._buf, 0)
cdef class HttpResponseParser(HttpParser):
    """HTTP response parser: adds status-reason handling and leniency."""

    def __init__(
        self, protocol, loop, int limit, timer=None,
        size_t max_line_size=8190, size_t max_headers=32768,
        size_t max_field_size=8190, payload_exception=None,
        bint response_with_body=True, bint read_until_eof=False,
        bint auto_decompress=True
    ):
        self._init(cparser.HTTP_RESPONSE, protocol, loop, limit, timer,
                   max_line_size, max_headers, max_field_size,
                   payload_exception, response_with_body, read_until_eof,
                   auto_decompress)
        # Use strict parsing on dev mode, so users are warned about broken
        # servers; in production, tolerate common server sloppiness.
        if not DEBUG:
            cparser.llhttp_set_lenient_headers(self._cparser, 1)
            cparser.llhttp_set_lenient_optional_cr_before_lf(self._cparser, 1)
            cparser.llhttp_set_lenient_spaces_after_chunk_size(self._cparser, 1)

    cdef object _on_status_complete(self):
        # Decode the buffered reason phrase (may legitimately be empty).
        if self._buf:
            self._reason = self._buf.decode('utf-8', 'surrogateescape')
            PyByteArray_Resize(self._buf, 0)
        else:
            self._reason = self._reason or ''
cdef int cb_on_message_begin(cparser.llhttp_t* parser) except -1:
    # llhttp callback: reset per-message state at the start of a message.
    cdef HttpParser pyparser = <HttpParser>parser.data

    pyparser._started = True
    pyparser._headers = []
    pyparser._raw_headers = []
    PyByteArray_Resize(pyparser._buf, 0)
    pyparser._path = None
    pyparser._reason = None
    return 0
cdef int cb_on_url(cparser.llhttp_t* parser,
                   const char *at, size_t length) except -1:
    # llhttp callback: accumulate (possibly chunked) request-target bytes.
    cdef HttpParser pyparser = <HttpParser>parser.data
    try:
        if length > pyparser._max_line_size:
            raise LineTooLong(
                'Status line is too long', pyparser._max_line_size, length)
        extend(pyparser._buf, at, length)
    except BaseException as ex:
        # Remember the Python exception; -1 aborts llhttp parsing.
        pyparser._last_error = ex
        return -1
    else:
        return 0
cdef int cb_on_status(cparser.llhttp_t* parser,
                      const char *at, size_t length) except -1:
    # llhttp callback: accumulate (possibly chunked) reason-phrase bytes;
    # decoded later in HttpResponseParser._on_status_complete.
    # Fix: removed the unused local declaration `cdef str reason`.
    cdef HttpParser pyparser = <HttpParser>parser.data
    try:
        if length > pyparser._max_line_size:
            raise LineTooLong(
                'Status line is too long', pyparser._max_line_size, length)
        extend(pyparser._buf, at, length)
    except BaseException as ex:
        # Remember the Python exception; -1 aborts llhttp parsing.
        pyparser._last_error = ex
        return -1
    else:
        return 0
cdef int cb_on_header_field(cparser.llhttp_t* parser,
                            const char *at, size_t length) except -1:
    # llhttp callback: a header field begins (or continues); also flushes
    # any pending status buffer first.
    cdef HttpParser pyparser = <HttpParser>parser.data
    cdef Py_ssize_t size
    try:
        pyparser._on_status_complete()
        # Enforce the limit on the accumulated name, not just this chunk.
        size = len(pyparser._raw_name) + length
        if size > pyparser._max_field_size:
            raise LineTooLong(
                'Header name is too long', pyparser._max_field_size, size)
        pyparser._on_header_field(at, length)
    except BaseException as ex:
        pyparser._last_error = ex
        return -1
    else:
        return 0
cdef int cb_on_header_value(cparser.llhttp_t* parser,
                            const char *at, size_t length) except -1:
    # llhttp callback: a header value chunk arrived.
    cdef HttpParser pyparser = <HttpParser>parser.data
    cdef Py_ssize_t size
    try:
        # Enforce the limit on the accumulated value, not just this chunk.
        size = len(pyparser._raw_value) + length
        if size > pyparser._max_field_size:
            raise LineTooLong(
                'Header value is too long', pyparser._max_field_size, size)
        pyparser._on_header_value(at, length)
    except BaseException as ex:
        pyparser._last_error = ex
        return -1
    else:
        return 0
cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1:
    # llhttp callback: header section finished; build the message object.
    cdef HttpParser pyparser = <HttpParser>parser.data
    try:
        pyparser._on_status_complete()
        pyparser._on_headers_complete()
    except BaseException as exc:
        pyparser._last_error = exc
        return -1
    else:
        # Returning 2 tells llhttp to treat the message as having no body
        # to parse (upgrade / CONNECT tunnelling takes over the stream).
        if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT:
            return 2
        else:
            return 0
cdef int cb_on_body(cparser.llhttp_t* parser,
                    const char *at, size_t length) except -1:
    # llhttp callback: feed a body chunk into the payload stream.
    cdef HttpParser pyparser = <HttpParser>parser.data
    cdef bytes body = at[:length]
    try:
        pyparser._payload.feed_data(body, length)
    except BaseException as underlying_exc:
        # Optionally wrap in the caller-supplied payload exception type,
        # keeping the original as the cause for diagnostics.
        reraised_exc = underlying_exc
        if pyparser._payload_exception is not None:
            reraised_exc = pyparser._payload_exception(str(underlying_exc))

        set_exception(pyparser._payload, reraised_exc, underlying_exc)

        pyparser._payload_error = 1
        return -1
    else:
        return 0
cdef int cb_on_message_complete(cparser.llhttp_t* parser) except -1:
    # llhttp callback: message fully parsed; EOF the payload stream.
    cdef HttpParser pyparser = <HttpParser>parser.data
    try:
        pyparser._started = False
        pyparser._on_message_complete()
    except BaseException as exc:
        pyparser._last_error = exc
        return -1
    else:
        return 0
cdef int cb_on_chunk_header(cparser.llhttp_t* parser) except -1:
    # llhttp callback: a new chunked-encoding chunk starts.
    cdef HttpParser pyparser = <HttpParser>parser.data
    try:
        pyparser._on_chunk_header()
    except BaseException as exc:
        pyparser._last_error = exc
        return -1
    else:
        return 0
cdef int cb_on_chunk_complete(cparser.llhttp_t* parser) except -1:
    # llhttp callback: the current chunked-encoding chunk is finished.
    cdef HttpParser pyparser = <HttpParser>parser.data
    try:
        pyparser._on_chunk_complete()
    except BaseException as exc:
        pyparser._last_error = exc
        return -1
    else:
        return 0
cdef parser_error_from_errno(cparser.llhttp_t* parser, data, pointer):
    """Map an llhttp errno to the matching aiohttp exception.

    *data* is the offending excerpt and *pointer* a caret line locating
    the error; both are embedded in the message for debugging.
    """
    cdef cparser.llhttp_errno_t errno = cparser.llhttp_get_errno(parser)
    cdef bytes desc = cparser.llhttp_get_error_reason(parser)

    err_msg = "{}:\n\n  {!r}\n  {}".format(desc.decode("latin-1"), data, pointer)

    if errno in {cparser.HPE_CB_MESSAGE_BEGIN,
                 cparser.HPE_CB_HEADERS_COMPLETE,
                 cparser.HPE_CB_MESSAGE_COMPLETE,
                 cparser.HPE_CB_CHUNK_HEADER,
                 cparser.HPE_CB_CHUNK_COMPLETE,
                 cparser.HPE_INVALID_CONSTANT,
                 cparser.HPE_INVALID_HEADER_TOKEN,
                 cparser.HPE_INVALID_CONTENT_LENGTH,
                 cparser.HPE_INVALID_CHUNK_SIZE,
                 cparser.HPE_INVALID_EOF_STATE,
                 cparser.HPE_INVALID_TRANSFER_ENCODING}:
        return BadHttpMessage(err_msg)
    elif errno == cparser.HPE_INVALID_METHOD:
        return BadHttpMethod(error=err_msg)
    elif errno in {cparser.HPE_INVALID_STATUS,
                   cparser.HPE_INVALID_VERSION}:
        return BadStatusLine(error=err_msg)
    elif errno == cparser.HPE_INVALID_URL:
        return InvalidURLError(err_msg)

    # Fallback for any errno not explicitly classified above.
    return BadHttpMessage(err_msg)

View File

@@ -0,0 +1,162 @@
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.exc cimport PyErr_NoMemory
from cpython.mem cimport PyMem_Free, PyMem_Malloc, PyMem_Realloc
from cpython.object cimport PyObject_Str
from libc.stdint cimport uint8_t, uint64_t
from libc.string cimport memcpy
from multidict import istr
# Stack-allocated serialization buffer size; grown on the heap if exceeded.
DEF BUF_SIZE = 16 * 1024  # 16KiB

# Cache the multidict istr type for fast identity checks below.
cdef object _istr = istr
# ----------------- writer ---------------------------
# Growable byte sink: starts on a caller-provided stack buffer and moves
# to the heap only when the output exceeds BUF_SIZE.
cdef struct Writer:
    char *buf             # current output buffer (stack or heap)
    Py_ssize_t size       # current capacity of buf
    Py_ssize_t pos        # number of bytes written so far
    bint heap_allocated   # true once buf was moved to PyMem-allocated memory
cdef inline void _init_writer(Writer* writer, char *buf):
    # Point the writer at the caller's stack buffer of BUF_SIZE bytes.
    writer.buf = buf
    writer.size = BUF_SIZE
    writer.pos = 0
    writer.heap_allocated = 0
cdef inline void _release_writer(Writer* writer):
    # Free only if _write_byte moved the buffer to the heap.
    if writer.heap_allocated:
        PyMem_Free(writer.buf)
cdef inline int _write_byte(Writer* writer, uint8_t ch):
    """Append one byte, growing the buffer by BUF_SIZE increments.

    Returns 0 on success, -1 with a MemoryError set on allocation failure.
    """
    cdef char * buf
    cdef Py_ssize_t size

    if writer.pos == writer.size:
        # reallocate
        size = writer.size + BUF_SIZE
        if not writer.heap_allocated:
            # First growth: copy out of the caller's stack buffer.
            buf = <char*>PyMem_Malloc(size)
            if buf == NULL:
                PyErr_NoMemory()
                return -1
            memcpy(buf, writer.buf, writer.size)
        else:
            buf = <char*>PyMem_Realloc(writer.buf, size)
            if buf == NULL:
                PyErr_NoMemory()
                return -1
        writer.buf = buf
        writer.size = size
        writer.heap_allocated = 1

    writer.buf[writer.pos] = <char>ch
    writer.pos += 1
    return 0
cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol):
    """UTF-8 encode one code point into the writer (0 ok, -1 on error).

    Lone surrogates and values above U+10FFFF are silently skipped.
    """
    cdef uint64_t utf = <uint64_t> symbol

    if utf < 0x80:
        # 1-byte sequence (ASCII).
        return _write_byte(writer, <uint8_t>utf)
    elif utf < 0x800:
        # 2-byte sequence.
        if _write_byte(writer, <uint8_t>(0xc0 | (utf >> 6))) < 0:
            return -1
        return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
    elif 0xD800 <= utf <= 0xDFFF:
        # surrogate pair, ignored
        return 0
    elif utf < 0x10000:
        # 3-byte sequence.
        if _write_byte(writer, <uint8_t>(0xe0 | (utf >> 12))) < 0:
            return -1
        if _write_byte(writer, <uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0:
            return -1
        return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
    elif utf > 0x10FFFF:
        # symbol is too large
        return 0
    else:
        # 4-byte sequence.
        if _write_byte(writer, <uint8_t>(0xf0 | (utf >> 18))) < 0:
            return -1
        if _write_byte(writer,
                       <uint8_t>(0x80 | ((utf >> 12) & 0x3f))) < 0:
            return -1
        if _write_byte(writer,
                       <uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0:
            return -1
        return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
cdef inline int _write_str(Writer* writer, str s):
    # Encode every code point of *s*; -1 propagates allocation failure.
    # NOTE(review): no explicit `return 0` on the success path — relies on
    # Cython's implicit zero return for C int functions; confirm intended.
    cdef Py_UCS4 ch
    for ch in s:
        if _write_utf8(writer, ch) < 0:
            return -1
cdef inline int _write_str_raise_on_nlcr(Writer* writer, object s):
    """Write a header token, rejecting CR/LF to block header injection.

    Accepts ``str`` and ``istr`` fast paths; other ``str`` subclasses are
    coerced via ``str()``; non-strings raise TypeError.
    """
    cdef Py_UCS4 ch
    cdef str out_str
    if type(s) is str:
        out_str = <str>s
    elif type(s) is _istr:
        out_str = PyObject_Str(s)
    elif not isinstance(s, str):
        raise TypeError("Cannot serialize non-str key {!r}".format(s))
    else:
        out_str = str(s)

    for ch in out_str:
        if ch == 0x0D or ch == 0x0A:
            raise ValueError(
                "Newline or carriage return detected in headers. "
                "Potential header injection attack."
            )
        if _write_utf8(writer, ch) < 0:
            return -1
# --------------- _serialize_headers ----------------------
# --------------- _serialize_headers ----------------------

def _serialize_headers(str status_line, headers):
    """Serialize a status line plus headers into a single bytes block.

    The bare ``raise`` statements re-raise the MemoryError that
    ``_write_byte`` set via ``PyErr_NoMemory()`` when a helper returns -1.
    """
    cdef Writer writer
    cdef object key
    cdef object val
    cdef char buf[BUF_SIZE]   # stack buffer; heap growth handled by Writer

    _init_writer(&writer, buf)

    try:
        if _write_str(&writer, status_line) < 0:
            raise
        if _write_byte(&writer, b'\r') < 0:
            raise
        if _write_byte(&writer, b'\n') < 0:
            raise

        for key, val in headers.items():
            if _write_str_raise_on_nlcr(&writer, key) < 0:
                raise
            if _write_byte(&writer, b':') < 0:
                raise
            if _write_byte(&writer, b' ') < 0:
                raise
            if _write_str_raise_on_nlcr(&writer, val) < 0:
                raise
            if _write_byte(&writer, b'\r') < 0:
                raise
            if _write_byte(&writer, b'\n') < 0:
                raise

        # Terminating blank line ends the header section.
        if _write_byte(&writer, b'\r') < 0:
            raise
        if _write_byte(&writer, b'\n') < 0:
            raise

        return PyBytes_FromStringAndSize(writer.buf, writer.pos)
    finally:
        _release_writer(&writer)

View File

@@ -0,0 +1 @@
b01999d409b29bd916e067bc963d5f2d9ee63cfc9ae0bccb769910131417bf93 /Users/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pxd

View File

@@ -0,0 +1 @@
0478ceb55d0ed30ef1a7da742cd003449bc69a07cf9fdb06789bd2b347cbfffe /Users/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pyx

View File

@@ -0,0 +1 @@
9e5fe78ed0ebce5414d2b8e01868d90c1facc20b84d2d5ff6c23e86e44a155ae /Users/runner/work/aiohttp/aiohttp/aiohttp/_websocket/reader_c.pxd

View File

@@ -0,0 +1 @@
"""WebSocket protocol versions 13 and 8."""

View File

@@ -0,0 +1,147 @@
"""Helpers for WebSocket protocol versions 13 and 8."""
import functools
import re
from struct import Struct
from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple
from ..helpers import NO_EXTENSIONS
from .models import WSHandshakeError
# Pre-bound struct (un)packers for WebSocket frame headers (RFC 6455 §5.2).
UNPACK_LEN3 = Struct("!Q").unpack_from   # 64-bit extended payload length
UNPACK_CLOSE_CODE = Struct("!H").unpack  # 16-bit close code
PACK_LEN1 = Struct("!BB").pack           # 2-byte header, payload len < 126
PACK_LEN2 = Struct("!BBH").pack          # header + 16-bit length
PACK_LEN3 = Struct("!BBQ").pack          # header + 64-bit length
PACK_CLOSE_CODE = Struct("!H").pack
PACK_RANDBITS = Struct("!L").pack        # 4 random bytes for the client mask

MSG_SIZE: Final[int] = 2**14   # outgoing frame fragmentation threshold
MASK_LEN: Final[int] = 4

# RFC 6455 handshake GUID appended to Sec-WebSocket-Key before hashing.
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
# Used by _websocket_mask_python
@functools.lru_cache
def _xor_table() -> List[bytes]:
return [bytes(a ^ b for a in range(256)) for b in range(256)]
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
    """Apply the RFC 6455 §5.3 XOR mask to *data* in place.

    ``mask`` must be exactly 4 bytes; ``data`` is a ``bytearray`` that is
    mutated.  Works by translating each of the four byte-position strides
    through a precomputed XOR table, which keeps the loop in C.

    This pure-python implementation may be replaced by an optimized
    version when available.
    """
    assert isinstance(data, bytearray), data
    assert len(mask) == 4, mask

    if not data:
        return

    tables = _xor_table()
    for offset, mask_byte in enumerate(mask):
        data[offset::4] = data[offset::4].translate(tables[mask_byte])
# Prefer the C-accelerated mask implementation; fall back to pure python
# when extensions are disabled or the compiled module is unavailable.
if TYPE_CHECKING or NO_EXTENSIONS:  # pragma: no cover
    websocket_mask = _websocket_mask_python
else:
    try:
        from .mask import _websocket_mask_cython  # type: ignore[import-not-found]

        websocket_mask = _websocket_mask_cython
    except ImportError:  # pragma: no cover
        websocket_mask = _websocket_mask_python
# Matches the parameter list of one permessage-deflate offer; groups:
# 1=server_no_context_takeover, 2=client_no_context_takeover,
# 3/4=server_max_window_bits (+value), 5/6=client_max_window_bits (+value).
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
    r"^(?:;\s*(?:"
    r"(server_no_context_takeover)|"
    r"(client_no_context_takeover)|"
    r"(server_max_window_bits(?:=(\d+))?)|"
    r"(client_max_window_bits(?:=(\d+))?)))*$"
)

# Splits a Sec-WebSocket-Extensions header into permessage-deflate offers.
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
    """Parse a Sec-WebSocket-Extensions value for permessage-deflate.

    Returns ``(window_bits, no_context_takeover)``; ``window_bits`` is 0
    when compression is not negotiated.  Servers skip unusable offers and
    try the next one; clients raise :exc:`WSHandshakeError` on a bad
    server response.
    """
    if not extstr:
        return 0, False

    compress = 0
    notakeover = False
    for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
        defext = ext.group(1)
        # Return compress = 15 when get `permessage-deflate`
        if not defext:
            compress = 15
            break
        match = _WS_EXT_RE.match(defext)
        if match:
            compress = 15
            if isserver:
                # Server never fail to detect compress handshake.
                # Server does not need to send max wbit to client
                if match.group(4):
                    compress = int(match.group(4))
                    # Group3 must match if group4 matches
                    # Compress wbit 8 does not support in zlib
                    # If compress level not support,
                    # CONTINUE to next extension
                    if compress > 15 or compress < 9:
                        compress = 0
                        continue
                if match.group(1):
                    notakeover = True
                # Ignore regex group 5 & 6 for client_max_window_bits
                break
            else:
                if match.group(6):
                    compress = int(match.group(6))
                    # Group5 must match if group6 matches
                    # Compress wbit 8 does not support in zlib
                    # If compress level not support,
                    # FAIL the parse progress
                    if compress > 15 or compress < 9:
                        raise WSHandshakeError("Invalid window size")
                if match.group(2):
                    notakeover = True
                # Ignore regex group 5 & 6 for client_max_window_bits
                break
        # Return Fail if client side and not match
        elif not isserver:
            # NOTE(review): no separator between the message and the raw
            # extension text — confirm the concatenation is intended.
            raise WSHandshakeError("Extension for deflate not supported" + ext.group(1))

    return compress, notakeover
def ws_ext_gen(
    compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
    """Build a Sec-WebSocket-Extensions value offering permessage-deflate.

    ``compress`` is the zlib window-bits size (9-15); clients also offer
    ``client_max_window_bits``.  ``client_no_context_takeover`` is never
    emitted here.
    """
    # zlib cannot operate with wbits=8, so only 9..15 is valid.
    if not 9 <= compress <= 15:
        raise ValueError(
            "Compress wbits must between 9 and 15, zlib does not support wbits=8"
        )
    parts = ["permessage-deflate"]
    if not isserver:
        parts.append("client_max_window_bits")
    if compress < 15:
        parts.append(f"server_max_window_bits={compress}")
    if server_notakeover:
        parts.append("server_no_context_takeover")
    return "; ".join(parts)

View File

@@ -0,0 +1,3 @@
"""Cython declarations for websocket masking."""
cpdef void _websocket_mask_cython(bytes mask, bytearray data)

View File

@@ -0,0 +1,48 @@
from cpython cimport PyBytes_AsString
#from cpython cimport PyByteArray_AsString # cython still not exports that
cdef extern from "Python.h":
char* PyByteArray_AsString(bytearray ba) except NULL
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
cpdef void _websocket_mask_cython(bytes mask, bytearray data):
    """Note, this function mutates its `data` argument.

    Applies the 4-byte XOR mask over *data* in place, 8 bytes at a time
    on 64-bit platforms, then 4, then byte-by-byte for the tail.
    """
    cdef:
        Py_ssize_t data_len, i
        # bit operations on signed integers are implementation-specific
        unsigned char * in_buf
        const unsigned char * mask_buf
        uint32_t uint32_msk
        uint64_t uint64_msk

    assert len(mask) == 4

    data_len = len(data)
    in_buf = <unsigned char*>PyByteArray_AsString(data)
    mask_buf = <const unsigned char*>PyBytes_AsString(mask)
    # NOTE(review): these word-sized loads/stores are unaligned; fine on
    # x86/modern ARM — confirm for any stricter target architecture.
    uint32_msk = (<uint32_t*>mask_buf)[0]

    # TODO: align in_data ptr to achieve even faster speeds
    # does it need in python ?! malloc() always aligns to sizeof(long) bytes

    if sizeof(size_t) >= 8:
        # Widen the 4-byte mask to 8 bytes and process two words per step.
        uint64_msk = uint32_msk
        uint64_msk = (uint64_msk << 32) | uint32_msk

        while data_len >= 8:
            (<uint64_t*>in_buf)[0] ^= uint64_msk
            in_buf += 8
            data_len -= 8

    while data_len >= 4:
        (<uint32_t*>in_buf)[0] ^= uint32_msk
        in_buf += 4
        data_len -= 4

    # Remaining 0-3 tail bytes, masked individually.
    for i in range(0, data_len):
        in_buf[i] ^= mask_buf[i]

View File

@@ -0,0 +1,84 @@
"""Models for WebSocket protocol versions 13 and 8."""
import json
from enum import IntEnum
from typing import Any, Callable, Final, NamedTuple, Optional, cast
# Deflate sync-flush trailer bytes used by permessage-deflate framing
# (appended/stripped around compressed message payloads).
WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])
class WSCloseCode(IntEnum):
    """WebSocket close status codes (RFC 6455 §7.4.1 registry values)."""

    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    ABNORMAL_CLOSURE = 1006   # reserved: never sent on the wire
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014
class WSMsgType(IntEnum):
    """WebSocket message/opcode types.

    Values below 0x100 are wire opcodes; the 0x1xx values are synthetic
    aiohttp states that never appear on the wire.  Lowercase aliases are
    kept for backwards compatibility.
    """

    # websocket spec types
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xA
    CLOSE = 0x8
    # aiohttp specific types
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102

    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closing = CLOSING
    closed = CLOSED
    error = ERROR
class WSMessage(NamedTuple):
    """One received WebSocket message: ``(type, data, extra)``."""

    type: WSMsgType
    # To type correctly, this would need some kind of tagged union for each type.
    data: Any
    extra: Optional[str]

    def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
        """Return parsed JSON data.

        .. versionadded:: 0.22
        """
        return loads(self.data)
# Shared sentinel messages for the closed/closing states.
# Constructing the tuple directly to avoid the overhead of
# the lambda and arg processing since NamedTuples are constructed
# with a run time built lambda
# https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441
WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None))
WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None))
class WebSocketError(Exception):
    """Raised when the WebSocket frame parser hits a protocol violation.

    ``args`` holds ``(code, message)``; ``code`` is also exposed as an
    attribute and ``str()`` yields just the message.
    """

    def __init__(self, code: int, message: str) -> None:
        super().__init__(code, message)
        self.code = code

    def __str__(self) -> str:
        message: str = self.args[1]
        return message
class WSHandshakeError(Exception):
    """WebSocket protocol handshake error (bad or unsupported headers)."""

View File

@@ -0,0 +1,31 @@
"""Reader for WebSocket protocol versions 13 and 8."""
from typing import TYPE_CHECKING
from ..helpers import NO_EXTENSIONS
# Prefer the Cython-accelerated reader; fall back to the pure-python
# implementation when extensions are disabled or the build is missing.
if TYPE_CHECKING or NO_EXTENSIONS:  # pragma: no cover
    from .reader_py import (
        WebSocketDataQueue as WebSocketDataQueuePython,
        WebSocketReader as WebSocketReaderPython,
    )

    WebSocketReader = WebSocketReaderPython
    WebSocketDataQueue = WebSocketDataQueuePython
else:
    try:
        from .reader_c import (  # type: ignore[import-not-found]
            WebSocketDataQueue as WebSocketDataQueueCython,
            WebSocketReader as WebSocketReaderCython,
        )

        WebSocketReader = WebSocketReaderCython
        WebSocketDataQueue = WebSocketDataQueueCython
    except ImportError:  # pragma: no cover
        from .reader_py import (
            WebSocketDataQueue as WebSocketDataQueuePython,
            WebSocketReader as WebSocketReaderPython,
        )

        WebSocketReader = WebSocketReaderPython
        WebSocketDataQueue = WebSocketDataQueuePython

View File

@@ -0,0 +1,110 @@
import cython
from .mask cimport _websocket_mask_cython as websocket_mask
# Typed declarations for module-level constants defined in reader_py.py;
# Cython compiles the .py file against this .pxd so these become C ints
# and typed object slots instead of dict lookups.

# Parser states.
cdef unsigned int READ_HEADER
cdef unsigned int READ_PAYLOAD_LENGTH
cdef unsigned int READ_PAYLOAD_MASK
cdef unsigned int READ_PAYLOAD

# WSMsgType opcodes unpacked to plain C ints.
cdef int OP_CODE_NOT_SET
cdef int OP_CODE_CONTINUATION
cdef int OP_CODE_TEXT
cdef int OP_CODE_BINARY
cdef int OP_CODE_CLOSE
cdef int OP_CODE_PING
cdef int OP_CODE_PONG

# Tri-state compressed flag.
cdef int COMPRESSED_NOT_SET
cdef int COMPRESSED_FALSE
cdef int COMPRESSED_TRUE

cdef object UNPACK_LEN3
cdef object UNPACK_CLOSE_CODE
cdef object TUPLE_NEW

cdef object WSMsgType
cdef object WSMessage

cdef object WS_MSG_TYPE_TEXT
cdef object WS_MSG_TYPE_BINARY

cdef set ALLOWED_CLOSE_CODES
cdef set MESSAGE_TYPES_WITH_CONTENT

cdef tuple EMPTY_FRAME
cdef tuple EMPTY_FRAME_ERROR
cdef class WebSocketDataQueue:
    """Typed attribute/method declarations for reader_py.WebSocketDataQueue."""

    cdef unsigned int _size
    cdef public object _protocol
    cdef unsigned int _limit
    cdef object _loop
    cdef bint _eof
    cdef object _waiter
    cdef object _exception
    cdef public object _buffer
    cdef object _get_buffer   # bound deque.popleft
    cdef object _put_buffer   # bound deque.append

    cdef void _release_waiter(self)

    cpdef void feed_data(self, object data, unsigned int size)

    @cython.locals(size="unsigned int")
    cdef _read_from_buffer(self)
cdef class WebSocketReader:
    """Typed attribute/method declarations for reader_py.WebSocketReader.

    Fix: the ``@cython.locals`` directive on ``_feed_data`` listed
    ``data_len=Py_ssize_t`` twice (a repeated keyword argument); the
    duplicate has been removed.
    """

    cdef WebSocketDataQueue queue
    cdef unsigned int _max_msg_size
    cdef Exception _exc
    cdef bytearray _partial
    cdef unsigned int _state
    cdef int _opcode
    cdef bint _frame_fin
    cdef int _frame_opcode
    cdef list _payload_fragments
    cdef Py_ssize_t _frame_payload_len
    cdef bytes _tail
    cdef bint _has_mask
    cdef bytes _frame_mask
    cdef Py_ssize_t _payload_bytes_to_read
    cdef unsigned int _payload_len_flag
    cdef int _compressed
    cdef object _decompressobj
    cdef bint _compress

    cpdef tuple feed_data(self, object data)

    @cython.locals(
        is_continuation=bint,
        fin=bint,
        has_partial=bint,
        payload_merged=bytes,
    )
    cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except *

    @cython.locals(
        start_pos=Py_ssize_t,
        data_len=Py_ssize_t,
        length=Py_ssize_t,
        chunk_size=Py_ssize_t,
        chunk_len=Py_ssize_t,
        data_cstr="const unsigned char *",
        first_byte="unsigned char",
        second_byte="unsigned char",
        f_start_pos=Py_ssize_t,
        f_end_pos=Py_ssize_t,
        has_mask=bint,
        fin=bint,
        had_fragments=Py_ssize_t,
        payload_bytearray=bytearray,
    )
    cpdef void _feed_data(self, bytes data) except *

View File

@@ -0,0 +1,478 @@
"""Reader for WebSocket protocol versions 13 and 8."""
import asyncio
import builtins
from collections import deque
from typing import Deque, Final, Optional, Set, Tuple, Union
from ..base_protocol import BaseProtocol
from ..compression_utils import ZLibDecompressor
from ..helpers import _EXC_SENTINEL, set_exception
from ..streams import EofStream
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
from .models import (
WS_DEFLATE_TRAILING,
WebSocketError,
WSCloseCode,
WSMessage,
WSMsgType,
)
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}

# States for the reader, used to parse the WebSocket frame
# integer values are used so they can be cythonized
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4

WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT

# WSMsgType values unpacked so they can by cythonized to ints
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value

# Shared (error_flag, tail_bytes) results returned by feed_data.
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")

# Tri-state "compressed" flag for the in-flight message.
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1

# Bound once so the hot path skips the attribute lookup.
TUPLE_NEW = tuple.__new__

cython_int = int  # Typed to int in Python, but cython with use a signed int in the pxd
class WebSocketDataQueue:
    """WebSocketDataQueue resumes and pauses an underlying stream.

    It is a destination for WebSocket data: the reader feeds parsed
    messages in, consumers ``await read()`` them out, and transport
    reading is paused/resumed around a high/low watermark.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        self._size = 0                     # bytes currently buffered
        self._protocol = protocol
        self._limit = limit * 2            # pause watermark
        self._loop = loop
        self._eof = False
        self._waiter: Optional[asyncio.Future[None]] = None
        self._exception: Union[BaseException, None] = None
        self._buffer: Deque[Tuple[WSMessage, int]] = deque()
        # Bind the hot-path deque methods once.
        self._get_buffer = self._buffer.popleft
        self._put_buffer = self._buffer.append

    def is_eof(self) -> bool:
        return self._eof

    def exception(self) -> Optional[BaseException]:
        return self._exception

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Mark the queue failed and wake any pending reader with *exc*."""
        self._eof = True
        self._exception = exc
        if (waiter := self._waiter) is not None:
            self._waiter = None
            set_exception(waiter, exc, exc_cause)

    def _release_waiter(self) -> None:
        # Wake a pending read(); tolerate an already-completed future.
        if (waiter := self._waiter) is None:
            return
        self._waiter = None
        if not waiter.done():
            waiter.set_result(None)

    def feed_eof(self) -> None:
        self._eof = True
        self._release_waiter()
        self._exception = None  # Break cyclic references

    def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
        """Enqueue one message and pause the transport above the limit."""
        self._size += size
        self._put_buffer((data, size))
        self._release_waiter()
        if self._size > self._limit and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next message, waiting for one if the buffer is empty."""
        if not self._buffer and not self._eof:
            assert not self._waiter
            self._waiter = self._loop.create_future()
            try:
                await self._waiter
            except (asyncio.CancelledError, asyncio.TimeoutError):
                # Drop the waiter so a later read() can install a fresh one.
                self._waiter = None
                raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop one message; resume the transport below the limit."""
        if self._buffer:
            data, size = self._get_buffer()
            self._size -= size
            if self._size < self._limit and self._protocol._reading_paused:
                self._protocol.resume_reading()
            return data
        if self._exception is not None:
            raise self._exception
        raise EofStream
class WebSocketReader:
    """Incremental parser turning raw bytes into WebSocket messages.

    Frames are parsed by a four-state machine (header, payload length,
    mask, payload); complete messages are pushed into *queue*.
    """

    def __init__(
        self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
    ) -> None:
        self.queue = queue
        self._max_msg_size = max_msg_size
        self._exc: Optional[Exception] = None
        # Accumulated payload of an unfinished fragmented message.
        self._partial = bytearray()
        self._state = READ_HEADER

        # Opcode of the fragmented message in progress (OP_CODE_NOT_SET if none).
        self._opcode: int = OP_CODE_NOT_SET
        self._frame_fin = False
        self._frame_opcode: int = OP_CODE_NOT_SET
        # Pieces of the current frame's payload when it spans feed_data() calls.
        self._payload_fragments: list[bytes] = []
        self._frame_payload_len = 0

        # Bytes left over from the previous feed_data() call.
        self._tail: bytes = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        self._payload_bytes_to_read = 0
        self._payload_len_flag = 0
        self._compressed: int = COMPRESSED_NOT_SET
        self._decompressobj: Optional[ZLibDecompressor] = None
        self._compress = compress

    def feed_eof(self) -> None:
        """Propagate end-of-stream to the message queue."""
        self.queue.feed_eof()

    # data can be bytearray on Windows because proactor event loop uses bytearray
    # and asyncio types this to Union[bytes, bytearray, memoryview] so we need
    # coerce data to bytes if it is not
    def feed_data(
        self, data: Union[bytes, bytearray, memoryview]
    ) -> Tuple[bool, bytes]:
        """Feed raw bytes into the parser.

        Returns an ``(error, unconsumed)`` tuple; *error* is True when the
        parser has failed and no further input should be fed.
        """
        if type(data) is not bytes:
            data = bytes(data)
        if self._exc is not None:
            return True, data
        try:
            self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            set_exception(self.queue, exc)
            return EMPTY_FRAME_ERROR
        return EMPTY_FRAME

    def _handle_frame(
        self,
        fin: bool,
        opcode: Union[int, cython_int],  # Union intended: Cython pxd uses C int
        payload: Union[bytes, bytearray],
        compressed: Union[int, cython_int],  # Union intended: Cython pxd uses C int
    ) -> None:
        """Process one complete frame: assemble, decompress, validate, enqueue."""
        msg: WSMessage
        if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
            # Validate continuation frames before processing
            if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Continuation frame for non started message",
                )

            # load text/binary
            if not fin:
                # got partial frame payload
                if opcode != OP_CODE_CONTINUATION:
                    self._opcode = opcode
                self._partial += payload
                if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Message size {len(self._partial)} "
                        f"exceeds limit {self._max_msg_size}",
                    )
                return

            has_partial = bool(self._partial)
            if opcode == OP_CODE_CONTINUATION:
                opcode = self._opcode
                self._opcode = OP_CODE_NOT_SET
            # previous frame was non finished
            # we should get continuation opcode
            elif has_partial:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "The opcode in non-fin frame is expected "
                    f"to be zero, got {opcode!r}",
                )

            assembled_payload: Union[bytes, bytearray]
            if has_partial:
                assembled_payload = self._partial + payload
                self._partial.clear()
            else:
                assembled_payload = payload

            if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
                raise WebSocketError(
                    WSCloseCode.MESSAGE_TOO_BIG,
                    f"Message size {len(assembled_payload)} "
                    f"exceeds limit {self._max_msg_size}",
                )

            # Decompression must be done after all fragments of the message
            # have been received.
            if compressed:
                if not self._decompressobj:
                    self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
                # XXX: It's possible that the zlib backend (isal is known to
                # do this, maybe others too?) will return max_length bytes,
                # but internally buffer more data such that the payload is
                # >max_length, so we return one extra byte and if we're able
                # to do that, then the message is too big.
                payload_merged = self._decompressobj.decompress_sync(
                    assembled_payload + WS_DEFLATE_TRAILING,
                    (
                        self._max_msg_size + 1
                        if self._max_msg_size
                        else self._max_msg_size
                    ),
                )
                if self._max_msg_size and len(payload_merged) > self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Decompressed message exceeds size limit {self._max_msg_size}",
                    )
            elif type(assembled_payload) is bytes:
                payload_merged = assembled_payload
            else:
                payload_merged = bytes(assembled_payload)

            if opcode == OP_CODE_TEXT:
                try:
                    text = payload_merged.decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc

                # XXX: The Text and Binary messages here can be a performance
                # bottleneck, so we use tuple.__new__ to improve performance.
                # This is not type safe, but many tests should fail in
                # test_client_ws_functional.py if this is wrong.
                self.queue.feed_data(
                    TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
                    len(payload_merged),
                )
            else:
                self.queue.feed_data(
                    TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
                    len(payload_merged),
                )
        elif opcode == OP_CODE_CLOSE:
            if len(payload) >= 2:
                close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close code: {close_code}",
                    )
                try:
                    close_message = payload[2:].decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc
                msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
            elif payload:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    f"Invalid close frame: {fin} {opcode} {payload!r}",
                )
            else:
                msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
            self.queue.feed_data(msg, 0)
        elif opcode == OP_CODE_PING:
            msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
            self.queue.feed_data(msg, len(payload))
        elif opcode == OP_CODE_PONG:
            msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
            self.queue.feed_data(msg, len(payload))
        else:
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
            )

    def _feed_data(self, data: bytes) -> None:
        """Return the next frame from the socket."""
        # Prepend any bytes left over from the previous call.
        if self._tail:
            data, self._tail = self._tail + data, b""

        start_pos: int = 0
        data_len = len(data)
        data_cstr = data

        while True:
            # read header
            if self._state == READ_HEADER:
                if data_len - start_pos < 2:
                    break
                first_byte = data_cstr[start_pos]
                second_byte = data_cstr[start_pos + 1]
                start_pos += 2

                fin = (first_byte >> 7) & 1
                rsv1 = (first_byte >> 6) & 1
                rsv2 = (first_byte >> 5) & 1
                rsv3 = (first_byte >> 4) & 1
                opcode = first_byte & 0xF

                # frame-fin = %x0 ; more frames of this message follow
                #           / %x1 ; final frame of this message
                # frame-rsv1 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv2 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv3 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                #
                # Remove rsv1 from this test for deflate development
                if rsv2 or rsv3 or (rsv1 and not self._compress):
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                if opcode > 0x7 and fin == 0:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received fragmented control frame",
                    )

                has_mask = (second_byte >> 7) & 1
                length = second_byte & 0x7F

                # Control frames MUST have a payload
                # length of 125 bytes or less
                if opcode > 0x7 and length > 125:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Control frame payload cannot be larger than 125 bytes",
                    )

                # Set compress status if last package is FIN
                # OR set compress status if this is first fragment
                # Raise error if not first fragment with rsv1 = 0x1
                if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
                    self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
                elif rsv1:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                self._frame_fin = bool(fin)
                self._frame_opcode = opcode
                self._has_mask = bool(has_mask)
                self._payload_len_flag = length
                self._state = READ_PAYLOAD_LENGTH

            # read payload length
            if self._state == READ_PAYLOAD_LENGTH:
                len_flag = self._payload_len_flag
                if len_flag == 126:
                    # 16-bit extended payload length follows.
                    if data_len - start_pos < 2:
                        break
                    first_byte = data_cstr[start_pos]
                    second_byte = data_cstr[start_pos + 1]
                    start_pos += 2
                    self._payload_bytes_to_read = first_byte << 8 | second_byte
                elif len_flag > 126:
                    # 64-bit extended payload length follows.
                    if data_len - start_pos < 8:
                        break
                    self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
                    start_pos += 8
                else:
                    self._payload_bytes_to_read = len_flag

                self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD

            # read payload mask
            if self._state == READ_PAYLOAD_MASK:
                if data_len - start_pos < 4:
                    break
                self._frame_mask = data_cstr[start_pos : start_pos + 4]
                start_pos += 4
                self._state = READ_PAYLOAD

            if self._state == READ_PAYLOAD:
                chunk_len = data_len - start_pos
                if self._payload_bytes_to_read >= chunk_len:
                    f_end_pos = data_len
                    self._payload_bytes_to_read -= chunk_len
                else:
                    f_end_pos = start_pos + self._payload_bytes_to_read
                    self._payload_bytes_to_read = 0

                had_fragments = self._frame_payload_len
                self._frame_payload_len += f_end_pos - start_pos
                f_start_pos = start_pos
                start_pos = f_end_pos

                if self._payload_bytes_to_read != 0:
                    # If we don't have a complete frame, we need to save the
                    # data for the next call to feed_data.
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    break

                payload: Union[bytes, bytearray]
                if had_fragments:
                    # We have to join the payload fragments to get the payload
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    if self._has_mask:
                        assert self._frame_mask is not None
                        payload_bytearray = bytearray(b"".join(self._payload_fragments))
                        websocket_mask(self._frame_mask, payload_bytearray)
                        payload = payload_bytearray
                    else:
                        payload = b"".join(self._payload_fragments)
                    self._payload_fragments.clear()
                elif self._has_mask:
                    assert self._frame_mask is not None
                    payload_bytearray = data_cstr[f_start_pos:f_end_pos]  # type: ignore[assignment]
                    if type(payload_bytearray) is not bytearray:  # pragma: no branch
                        # Cython will do the conversion for us
                        # but we need to do it for Python and we
                        # will always get here in Python
                        payload_bytearray = bytearray(payload_bytearray)
                    websocket_mask(self._frame_mask, payload_bytearray)
                    payload = payload_bytearray
                else:
                    payload = data_cstr[f_start_pos:f_end_pos]

                self._handle_frame(
                    self._frame_fin, self._frame_opcode, payload, self._compressed
                )
                self._frame_payload_len = 0
                self._state = READ_HEADER

        # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
        self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""

View File

@@ -0,0 +1,478 @@
"""Reader for WebSocket protocol versions 13 and 8."""
import asyncio
import builtins
from collections import deque
from typing import Deque, Final, Optional, Set, Tuple, Union
from ..base_protocol import BaseProtocol
from ..compression_utils import ZLibDecompressor
from ..helpers import _EXC_SENTINEL, set_exception
from ..streams import EofStream
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
from .models import (
WS_DEFLATE_TRAILING,
WebSocketError,
WSCloseCode,
WSMessage,
WSMsgType,
)
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}

# States for the reader, used to parse the WebSocket frame
# integer values are used so they can be cythonized
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4

WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT

# WSMsgType values unpacked so they can be cythonized to ints
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value

# Sentinel return values for feed_data(): (error_flag, unconsumed_bytes)
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")

# Tri-state compression flag for the message currently being assembled
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1

TUPLE_NEW = tuple.__new__

cython_int = int  # Typed to int in Python, but cython will use a signed int in the pxd
class WebSocketDataQueue:
    """Buffer of parsed WebSocket messages with flow control.

    Stores ``(message, size)`` pairs for consumers of :meth:`read` and
    pauses/resumes the underlying protocol's reads as the buffered byte
    count crosses a high-water mark.
    """

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        self._size = 0
        self._protocol = protocol
        # High-water mark: twice the configured limit.
        self._limit = limit * 2
        self._loop = loop
        self._eof = False
        self._waiter: Optional[asyncio.Future[None]] = None
        self._exception: Union[BaseException, None] = None
        self._buffer: Deque[Tuple[WSMessage, int]] = deque()
        # Cache the bound deque methods once; they are hot-path calls.
        self._get_buffer = self._buffer.popleft
        self._put_buffer = self._buffer.append

    def is_eof(self) -> bool:
        """Return True once end-of-stream has been signalled."""
        return self._eof

    def exception(self) -> Optional[BaseException]:
        """Return the stored exception, or None if none was set."""
        return self._exception

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: builtins.BaseException = _EXC_SENTINEL,
    ) -> None:
        """Store *exc*, mark EOF, and fail any pending read() with it."""
        self._eof = True
        self._exception = exc
        pending = self._waiter
        if pending is not None:
            self._waiter = None
            set_exception(pending, exc, exc_cause)

    def _release_waiter(self) -> None:
        """Wake the coroutine currently blocked in read(), if any."""
        pending = self._waiter
        if pending is None:
            return
        self._waiter = None
        if not pending.done():
            pending.set_result(None)

    def feed_eof(self) -> None:
        """Mark end-of-stream and wake any pending reader."""
        self._eof = True
        self._release_waiter()
        self._exception = None  # Break cyclic references

    def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
        """Enqueue *data* and pause reading once over the high-water mark."""
        self._size += size
        self._put_buffer((data, size))
        self._release_waiter()
        if self._size > self._limit and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    async def read(self) -> WSMessage:
        """Return the next message, waiting for one if the buffer is empty."""
        if not self._buffer and not self._eof:
            assert not self._waiter
            self._waiter = self._loop.create_future()
            try:
                await self._waiter
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._waiter = None
                raise
        return self._read_from_buffer()

    def _read_from_buffer(self) -> WSMessage:
        """Pop one message, resuming reads when back under the limit."""
        if not self._buffer:
            # Nothing buffered: surface a stored error, otherwise EOF.
            if self._exception is not None:
                raise self._exception
            raise EofStream
        data, size = self._get_buffer()
        self._size -= size
        if self._size < self._limit and self._protocol._reading_paused:
            self._protocol.resume_reading()
        return data
class WebSocketReader:
    """Incremental parser turning raw bytes into WebSocket messages.

    Frames are parsed by a four-state machine (header, payload length,
    mask, payload); complete messages are pushed into *queue*.
    """

    def __init__(
        self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
    ) -> None:
        self.queue = queue
        self._max_msg_size = max_msg_size
        self._exc: Optional[Exception] = None
        # Accumulated payload of an unfinished fragmented message.
        self._partial = bytearray()
        self._state = READ_HEADER

        # Opcode of the fragmented message in progress (OP_CODE_NOT_SET if none).
        self._opcode: int = OP_CODE_NOT_SET
        self._frame_fin = False
        self._frame_opcode: int = OP_CODE_NOT_SET
        # Pieces of the current frame's payload when it spans feed_data() calls.
        self._payload_fragments: list[bytes] = []
        self._frame_payload_len = 0

        # Bytes left over from the previous feed_data() call.
        self._tail: bytes = b""
        self._has_mask = False
        self._frame_mask: Optional[bytes] = None
        self._payload_bytes_to_read = 0
        self._payload_len_flag = 0
        self._compressed: int = COMPRESSED_NOT_SET
        self._decompressobj: Optional[ZLibDecompressor] = None
        self._compress = compress

    def feed_eof(self) -> None:
        """Propagate end-of-stream to the message queue."""
        self.queue.feed_eof()

    # data can be bytearray on Windows because proactor event loop uses bytearray
    # and asyncio types this to Union[bytes, bytearray, memoryview] so we need
    # coerce data to bytes if it is not
    def feed_data(
        self, data: Union[bytes, bytearray, memoryview]
    ) -> Tuple[bool, bytes]:
        """Feed raw bytes into the parser.

        Returns an ``(error, unconsumed)`` tuple; *error* is True when the
        parser has failed and no further input should be fed.
        """
        if type(data) is not bytes:
            data = bytes(data)
        if self._exc is not None:
            return True, data
        try:
            self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            set_exception(self.queue, exc)
            return EMPTY_FRAME_ERROR
        return EMPTY_FRAME

    def _handle_frame(
        self,
        fin: bool,
        opcode: Union[int, cython_int],  # Union intended: Cython pxd uses C int
        payload: Union[bytes, bytearray],
        compressed: Union[int, cython_int],  # Union intended: Cython pxd uses C int
    ) -> None:
        """Process one complete frame: assemble, decompress, validate, enqueue."""
        msg: WSMessage
        if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
            # Validate continuation frames before processing
            if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Continuation frame for non started message",
                )

            # load text/binary
            if not fin:
                # got partial frame payload
                if opcode != OP_CODE_CONTINUATION:
                    self._opcode = opcode
                self._partial += payload
                if self._max_msg_size and len(self._partial) >= self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Message size {len(self._partial)} "
                        f"exceeds limit {self._max_msg_size}",
                    )
                return

            has_partial = bool(self._partial)
            if opcode == OP_CODE_CONTINUATION:
                opcode = self._opcode
                self._opcode = OP_CODE_NOT_SET
            # previous frame was non finished
            # we should get continuation opcode
            elif has_partial:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "The opcode in non-fin frame is expected "
                    f"to be zero, got {opcode!r}",
                )

            assembled_payload: Union[bytes, bytearray]
            if has_partial:
                assembled_payload = self._partial + payload
                self._partial.clear()
            else:
                assembled_payload = payload

            if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
                raise WebSocketError(
                    WSCloseCode.MESSAGE_TOO_BIG,
                    f"Message size {len(assembled_payload)} "
                    f"exceeds limit {self._max_msg_size}",
                )

            # Decompression must be done after all fragments of the message
            # have been received.
            if compressed:
                if not self._decompressobj:
                    self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
                # XXX: It's possible that the zlib backend (isal is known to
                # do this, maybe others too?) will return max_length bytes,
                # but internally buffer more data such that the payload is
                # >max_length, so we return one extra byte and if we're able
                # to do that, then the message is too big.
                payload_merged = self._decompressobj.decompress_sync(
                    assembled_payload + WS_DEFLATE_TRAILING,
                    (
                        self._max_msg_size + 1
                        if self._max_msg_size
                        else self._max_msg_size
                    ),
                )
                if self._max_msg_size and len(payload_merged) > self._max_msg_size:
                    raise WebSocketError(
                        WSCloseCode.MESSAGE_TOO_BIG,
                        f"Decompressed message exceeds size limit {self._max_msg_size}",
                    )
            elif type(assembled_payload) is bytes:
                payload_merged = assembled_payload
            else:
                payload_merged = bytes(assembled_payload)

            if opcode == OP_CODE_TEXT:
                try:
                    text = payload_merged.decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc

                # XXX: The Text and Binary messages here can be a performance
                # bottleneck, so we use tuple.__new__ to improve performance.
                # This is not type safe, but many tests should fail in
                # test_client_ws_functional.py if this is wrong.
                self.queue.feed_data(
                    TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
                    len(payload_merged),
                )
            else:
                self.queue.feed_data(
                    TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
                    len(payload_merged),
                )
        elif opcode == OP_CODE_CLOSE:
            if len(payload) >= 2:
                close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        f"Invalid close code: {close_code}",
                    )
                try:
                    close_message = payload[2:].decode("utf-8")
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
                    ) from exc
                msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
            elif payload:
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    f"Invalid close frame: {fin} {opcode} {payload!r}",
                )
            else:
                msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
            self.queue.feed_data(msg, 0)
        elif opcode == OP_CODE_PING:
            msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
            self.queue.feed_data(msg, len(payload))
        elif opcode == OP_CODE_PONG:
            msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
            self.queue.feed_data(msg, len(payload))
        else:
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
            )

    def _feed_data(self, data: bytes) -> None:
        """Return the next frame from the socket."""
        # Prepend any bytes left over from the previous call.
        if self._tail:
            data, self._tail = self._tail + data, b""

        start_pos: int = 0
        data_len = len(data)
        data_cstr = data

        while True:
            # read header
            if self._state == READ_HEADER:
                if data_len - start_pos < 2:
                    break
                first_byte = data_cstr[start_pos]
                second_byte = data_cstr[start_pos + 1]
                start_pos += 2

                fin = (first_byte >> 7) & 1
                rsv1 = (first_byte >> 6) & 1
                rsv2 = (first_byte >> 5) & 1
                rsv3 = (first_byte >> 4) & 1
                opcode = first_byte & 0xF

                # frame-fin = %x0 ; more frames of this message follow
                #           / %x1 ; final frame of this message
                # frame-rsv1 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv2 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                # frame-rsv3 = %x0 ;
                #    1 bit, MUST be 0 unless negotiated otherwise
                #
                # Remove rsv1 from this test for deflate development
                if rsv2 or rsv3 or (rsv1 and not self._compress):
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                if opcode > 0x7 and fin == 0:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received fragmented control frame",
                    )

                has_mask = (second_byte >> 7) & 1
                length = second_byte & 0x7F

                # Control frames MUST have a payload
                # length of 125 bytes or less
                if opcode > 0x7 and length > 125:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Control frame payload cannot be larger than 125 bytes",
                    )

                # Set compress status if last package is FIN
                # OR set compress status if this is first fragment
                # Raise error if not first fragment with rsv1 = 0x1
                if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
                    self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
                elif rsv1:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        "Received frame with non-zero reserved bits",
                    )

                self._frame_fin = bool(fin)
                self._frame_opcode = opcode
                self._has_mask = bool(has_mask)
                self._payload_len_flag = length
                self._state = READ_PAYLOAD_LENGTH

            # read payload length
            if self._state == READ_PAYLOAD_LENGTH:
                len_flag = self._payload_len_flag
                if len_flag == 126:
                    # 16-bit extended payload length follows.
                    if data_len - start_pos < 2:
                        break
                    first_byte = data_cstr[start_pos]
                    second_byte = data_cstr[start_pos + 1]
                    start_pos += 2
                    self._payload_bytes_to_read = first_byte << 8 | second_byte
                elif len_flag > 126:
                    # 64-bit extended payload length follows.
                    if data_len - start_pos < 8:
                        break
                    self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
                    start_pos += 8
                else:
                    self._payload_bytes_to_read = len_flag

                self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD

            # read payload mask
            if self._state == READ_PAYLOAD_MASK:
                if data_len - start_pos < 4:
                    break
                self._frame_mask = data_cstr[start_pos : start_pos + 4]
                start_pos += 4
                self._state = READ_PAYLOAD

            if self._state == READ_PAYLOAD:
                chunk_len = data_len - start_pos
                if self._payload_bytes_to_read >= chunk_len:
                    f_end_pos = data_len
                    self._payload_bytes_to_read -= chunk_len
                else:
                    f_end_pos = start_pos + self._payload_bytes_to_read
                    self._payload_bytes_to_read = 0

                had_fragments = self._frame_payload_len
                self._frame_payload_len += f_end_pos - start_pos
                f_start_pos = start_pos
                start_pos = f_end_pos

                if self._payload_bytes_to_read != 0:
                    # If we don't have a complete frame, we need to save the
                    # data for the next call to feed_data.
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    break

                payload: Union[bytes, bytearray]
                if had_fragments:
                    # We have to join the payload fragments to get the payload
                    self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
                    if self._has_mask:
                        assert self._frame_mask is not None
                        payload_bytearray = bytearray(b"".join(self._payload_fragments))
                        websocket_mask(self._frame_mask, payload_bytearray)
                        payload = payload_bytearray
                    else:
                        payload = b"".join(self._payload_fragments)
                    self._payload_fragments.clear()
                elif self._has_mask:
                    assert self._frame_mask is not None
                    payload_bytearray = data_cstr[f_start_pos:f_end_pos]  # type: ignore[assignment]
                    if type(payload_bytearray) is not bytearray:  # pragma: no branch
                        # Cython will do the conversion for us
                        # but we need to do it for Python and we
                        # will always get here in Python
                        payload_bytearray = bytearray(payload_bytearray)
                    websocket_mask(self._frame_mask, payload_bytearray)
                    payload = payload_bytearray
                else:
                    payload = data_cstr[f_start_pos:f_end_pos]

                self._handle_frame(
                    self._frame_fin, self._frame_opcode, payload, self._compressed
                )
                self._frame_payload_len = 0
                self._state = READ_HEADER

        # XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
        self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""

View File

@@ -0,0 +1,262 @@
"""WebSocket protocol versions 13 and 8."""
import asyncio
import random
import sys
from functools import partial
from typing import Final, Optional, Set, Union
from ..base_protocol import BaseProtocol
from ..client_exceptions import ClientConnectionResetError
from ..compression_utils import ZLibBackend, ZLibCompressor
from .helpers import (
MASK_LEN,
MSG_SIZE,
PACK_CLOSE_CODE,
PACK_LEN1,
PACK_LEN2,
PACK_LEN3,
PACK_RANDBITS,
websocket_mask,
)
from .models import WS_DEFLATE_TRAILING, WSMsgType
# Output buffer high-water mark (bytes) before the writer drains.
DEFAULT_LIMIT: Final[int] = 2**16

# WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames
# Control frames (ping, pong, close) are never compressed
WS_CONTROL_FRAME_OPCODE: Final[int] = 8

# For websockets, keeping latency low is extremely important as implementations
# generally expect to be able to send and receive messages quickly. We use a
# larger chunk size to reduce the number of executor calls and avoid task
# creation overhead, since both are significant sources of latency when chunks
# are small. A size of 16KiB was chosen as a balance between avoiding task
# overhead and not blocking the event loop too long with synchronous compression.
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024
class WebSocketWriter:
    """WebSocket writer.

    The writer is responsible for sending messages to the client. It is
    created by the protocol when a connection is established. The writer
    should avoid implementing any application logic and should only be
    concerned with the low-level details of the WebSocket protocol.
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        random: random.Random = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        """Initialize a WebSocket writer."""
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        # Pre-bound callable producing a fresh 32-bit mask per frame.
        self.get_random_bits = partial(random.getrandbits, 32)
        self.compress = compress
        self.notakeover = notakeover
        self._closing = False
        self._limit = limit
        self._output_size = 0
        self._compressobj: Optional[ZLibCompressor] = None
        # Serializes compressed sends so the shared compressor state
        # is never interleaved between concurrent frames.
        self._send_lock = asyncio.Lock()
        # Strong references to shielded send tasks (see send_frame).
        self._background_tasks: Set[asyncio.Task[None]] = set()

    async def send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket with message as its payload."""
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ClientConnectionResetError("Cannot write to closing transport")

        if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE:
            # Non-compressed frames don't need lock or shield
            self._write_websocket_frame(message, opcode, 0)
        elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
            # Small compressed payloads - compress synchronously in event loop
            # We need the lock even though sync compression has no await points.
            # This prevents small frames from interleaving with large frames that
            # compress in the executor, avoiding compressor state corruption.
            async with self._send_lock:
                self._send_compressed_frame_sync(message, opcode, compress)
        else:
            # Large compressed frames need shield to prevent corruption
            # For large compressed frames, the entire compress+send
            # operation must be atomic. If cancelled after compression but
            # before send, the compressor state would be advanced but data
            # not sent, corrupting subsequent frames.
            # Create a task to shield from cancellation
            # The lock is acquired inside the shielded task so the entire
            # operation (lock + compress + send) completes atomically.
            # Use eager_start on Python 3.12+ to avoid scheduling overhead
            loop = asyncio.get_running_loop()
            coro = self._send_compressed_frame_async_locked(message, opcode, compress)
            if sys.version_info >= (3, 12):
                send_task = asyncio.Task(coro, loop=loop, eager_start=True)
            else:
                send_task = loop.create_task(coro)
            # Keep a strong reference to prevent garbage collection
            self._background_tasks.add(send_task)
            send_task.add_done_callback(self._background_tasks.discard)
            await asyncio.shield(send_task)

        # It is safe to return control to the event loop when using compression
        # after this point as we have already sent or buffered all the data.

        # Once we have written output_size up to the limit, we call the
        # drain helper which waits for the transport to be ready to accept
        # more data. This is a flow control mechanism to prevent the buffer
        # from growing too large. The drain helper will return right away
        # if the writer is not paused.
        if self._output_size > self._limit:
            self._output_size = 0
            if self.protocol._paused:
                await self.protocol._drain_helper()

    def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
        """
        Write a websocket frame to the transport.

        This method handles frame header construction, masking, and writing to transport.
        It does not handle compression or flow control - those are the responsibility
        of the caller.
        """
        msg_length = len(message)

        use_mask = self.use_mask
        mask_bit = 0x80 if use_mask else 0

        # Depending on the message length, the header is assembled differently.
        # The first byte is reserved for the opcode and the RSV bits.
        first_byte = 0x80 | rsv | opcode
        if msg_length < 126:
            header = PACK_LEN1(first_byte, msg_length | mask_bit)
            header_len = 2
        elif msg_length < 65536:
            header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
            header_len = 4
        else:
            header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
            header_len = 10

        if self.transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")

        # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
        # If we are using a mask, we need to generate it randomly
        # and apply it to the message before sending it. A mask is
        # a 32-bit value that is applied to the message using a
        # bitwise XOR operation. It is used to prevent certain types
        # of attacks on the websocket protocol. The mask is only used
        # when aiohttp is acting as a client. Servers do not use a mask.
        if use_mask:
            mask = PACK_RANDBITS(self.get_random_bits())
            message = bytearray(message)
            websocket_mask(mask, message)
            self.transport.write(header + mask + message)
            self._output_size += MASK_LEN
        elif msg_length > MSG_SIZE:
            # Large payloads: avoid copying the message into a new buffer.
            self.transport.write(header)
            self.transport.write(message)
        else:
            self.transport.write(header + message)

        self._output_size += header_len + msg_length

    def _get_compressor(self, compress: Optional[int]) -> ZLibCompressor:
        """Get or create a compressor object for the given compression level."""
        if compress:
            # Do not set self._compress if compressing is for this frame
            return ZLibCompressor(
                level=ZLibBackend.Z_BEST_SPEED,
                wbits=-compress,
                max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
            )
        # Reuse the connection-level compressor, creating it lazily.
        if not self._compressobj:
            self._compressobj = ZLibCompressor(
                level=ZLibBackend.Z_BEST_SPEED,
                wbits=-self.compress,
                max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
            )
        return self._compressobj

    def _send_compressed_frame_sync(
        self, message: bytes, opcode: int, compress: Optional[int]
    ) -> None:
        """
        Synchronous send for small compressed frames.

        This is used for small compressed payloads that compress synchronously in the event loop.
        Since there are no await points, this is inherently cancellation-safe.
        """
        # RSV are the reserved bits in the frame header. They are used to
        # indicate that the frame is using an extension.
        # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
        compressobj = self._get_compressor(compress)
        # (0x40) RSV1 is set for compressed frames
        # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
        self._write_websocket_frame(
            (
                compressobj.compress_sync(message)
                + compressobj.flush(
                    ZLibBackend.Z_FULL_FLUSH
                    if self.notakeover
                    else ZLibBackend.Z_SYNC_FLUSH
                )
            ).removesuffix(WS_DEFLATE_TRAILING),
            opcode,
            0x40,
        )

    async def _send_compressed_frame_async_locked(
        self, message: bytes, opcode: int, compress: Optional[int]
    ) -> None:
        """
        Async send for large compressed frames with lock.

        Acquires the lock and compresses large payloads asynchronously in
        the executor. The lock is held for the entire operation to ensure
        the compressor state is not corrupted by concurrent sends.

        MUST be run shielded from cancellation. If cancelled after
        compression but before sending, the compressor state would be
        advanced but data not sent, corrupting subsequent frames.
        """
        async with self._send_lock:
            # RSV are the reserved bits in the frame header. They are used to
            # indicate that the frame is using an extension.
            # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
            compressobj = self._get_compressor(compress)
            # (0x40) RSV1 is set for compressed frames
            # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
            self._write_websocket_frame(
                (
                    await compressobj.compress(message)
                    + compressobj.flush(
                        ZLibBackend.Z_FULL_FLUSH
                        if self.notakeover
                        else ZLibBackend.Z_SYNC_FLUSH
                    )
                ).removesuffix(WS_DEFLATE_TRAILING),
                opcode,
                0x40,
            )

    async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
        """Close the websocket, sending the specified code and message."""
        if isinstance(message, str):
            message = message.encode("utf-8")
        try:
            await self.send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
            )
        finally:
            # Mark closing even if the send fails, so later writes are rejected.
            self._closing = True

View File

@@ -0,0 +1,268 @@
import asyncio
import logging
import socket
from abc import ABC, abstractmethod
from collections.abc import Sized
from http.cookies import BaseCookie, Morsel
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Tuple,
TypedDict,
Union,
)
from multidict import CIMultiDict
from yarl import URL
from ._cookie_helpers import parse_set_cookie_headers
from .typedefs import LooseCookies
if TYPE_CHECKING:
from .web_app import Application
from .web_exceptions import HTTPException
from .web_request import BaseRequest, Request
from .web_response import StreamResponse
else:
BaseRequest = Request = Application = StreamResponse = None
HTTPException = None
class AbstractRouter(ABC):
    """Abstract URL router: maps an incoming request to an AbstractMatchInfo."""
    def __init__(self) -> None:
        # Routers are mutable until freeze() is called at application startup.
        self._frozen = False
    def post_init(self, app: Application) -> None:
        """Post init stage.

        Not an abstract method for sake of backward compatibility,
        but if the router wants to be aware of the application
        it can override this.
        """
    @property
    def frozen(self) -> bool:
        """True once freeze() has been called; the router is then read-only."""
        return self._frozen
    def freeze(self) -> None:
        """Freeze router."""
        self._frozen = True
    @abstractmethod
    async def resolve(self, request: Request) -> "AbstractMatchInfo":
        """Return MATCH_INFO for given request"""
class AbstractMatchInfo(ABC):
    """Matched-route information produced by AbstractRouter.resolve()."""
    __slots__ = ()
    @property  # pragma: no branch
    @abstractmethod
    def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]:
        """Execute matched request handler"""
    @property
    @abstractmethod
    def expect_handler(
        self,
    ) -> Callable[[Request], Awaitable[Optional[StreamResponse]]]:
        """Expect handler for 100-continue processing"""
    @property  # pragma: no branch
    @abstractmethod
    def http_exception(self) -> Optional[HTTPException]:
        """HTTPException instance raised on router's resolving, or None"""
    @abstractmethod  # pragma: no branch
    def get_info(self) -> Dict[str, Any]:
        """Return a dict with additional info useful for introspection"""
    @property  # pragma: no branch
    @abstractmethod
    def apps(self) -> Tuple[Application, ...]:
        """Stack of nested applications.

        Top level application is left-most element.
        """
    @abstractmethod
    def add_app(self, app: Application) -> None:
        """Add application to the nested apps stack."""
    @abstractmethod
    def freeze(self) -> None:
        """Freeze the match info.

        The method is called after route resolution.
        After the call .add_app() is forbidden.
        """
class AbstractView(ABC):
    """Abstract class based view."""
    def __init__(self, request: Request) -> None:
        # Request the view was instantiated for; exposed via the property below.
        self._request = request
    @property
    def request(self) -> Request:
        """Request instance."""
        return self._request
    @abstractmethod
    def __await__(self) -> Generator[None, None, StreamResponse]:
        """Execute the view handler."""
class ResolveResult(TypedDict):
    """Resolve result.

    This is the result returned from an AbstractResolver's
    resolve method.

    :param hostname: The hostname that was provided.
    :param host: The IP address that was resolved.
    :param port: The port that was resolved.
    :param family: The address family that was resolved.
    :param proto: The protocol that was resolved.
    :param flags: The flags that were resolved.
    """
    hostname: str  # hostname as originally passed to resolve()
    host: str  # resolved IP address
    port: int  # resolved port number
    family: int  # address family (e.g. socket.AF_INET)
    proto: int  # socket protocol number
    flags: int  # resolver flags
class AbstractResolver(ABC):
    """Abstract DNS resolver."""
    @abstractmethod
    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Return a list of ResolveResult entries for the given hostname."""
    @abstractmethod
    async def close(self) -> None:
        """Release resolver resources."""
# The subscripted form is only used under TYPE_CHECKING so the runtime base
# class of AbstractCookieJar stays a plain Iterable (no dependency on runtime
# generic support for Morsel).
if TYPE_CHECKING:
    IterableBase = Iterable[Morsel[str]]
else:
    IterableBase = Iterable
# Predicate used by AbstractCookieJar.clear() to pick cookies for removal.
ClearCookiePredicate = Callable[["Morsel[str]"], bool]
class AbstractCookieJar(Sized, IterableBase):
    """Abstract Cookie Jar."""
    def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        # Requires a running event loop when no explicit loop is passed.
        self._loop = loop or asyncio.get_running_loop()
    @property
    @abstractmethod
    def quote_cookie(self) -> bool:
        """Return True if cookies should be quoted."""
    @abstractmethod
    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Clear all cookies if no predicate is passed."""
    @abstractmethod
    def clear_domain(self, domain: str) -> None:
        """Clear all cookies for domain and all subdomains."""
    @abstractmethod
    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies."""
    def update_cookies_from_headers(
        self, headers: Sequence[str], response_url: URL
    ) -> None:
        """Update cookies from raw Set-Cookie headers."""
        # Skip the update entirely when no parseable cookies are present.
        if headers and (cookies_to_update := parse_set_cookie_headers(headers)):
            self.update_cookies(cookies_to_update, response_url)
    @abstractmethod
    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        """Return the jar's cookies filtered by their attributes."""
class AbstractStreamWriter(ABC):
    """Abstract stream writer."""
    # Class-level defaults; concrete writers maintain these counters.
    buffer_size: int = 0
    output_size: int = 0
    # Expected body length (Optional; exact semantics defined by subclasses).
    length: Optional[int] = 0
    @abstractmethod
    async def write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
        """Write chunk into stream."""
    @abstractmethod
    async def write_eof(self, chunk: bytes = b"") -> None:
        """Write last chunk."""
    @abstractmethod
    async def drain(self) -> None:
        """Flush the write buffer."""
    @abstractmethod
    def enable_compression(
        self, encoding: str = "deflate", strategy: Optional[int] = None
    ) -> None:
        """Enable HTTP body compression"""
    @abstractmethod
    def enable_chunking(self) -> None:
        """Enable HTTP chunked mode"""
    @abstractmethod
    async def write_headers(
        self, status_line: str, headers: "CIMultiDict[str]"
    ) -> None:
        """Write HTTP headers"""
    def send_headers(self) -> None:
        """Force sending buffered headers if not already sent.

        Required only if write_headers() buffers headers instead of sending immediately.
        For backwards compatibility, this method does nothing by default.
        """
class AbstractAccessLogger(ABC):
    """Abstract writer to access log."""
    __slots__ = ("logger", "log_format")
    def __init__(self, logger: logging.Logger, log_format: str) -> None:
        # Destination logger and the access-log format string to render.
        self.logger = logger
        self.log_format = log_format
    @abstractmethod
    def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
        """Emit log to logger."""
    @property
    def enabled(self) -> bool:
        """Check if logger is enabled."""
        # Subclasses may override to skip formatting work when logging is off.
        return True

View File

@@ -0,0 +1,100 @@
import asyncio
from typing import Optional, cast
from .client_exceptions import ClientConnectionResetError
from .helpers import set_exception
from .tcp_helpers import tcp_nodelay
class BaseProtocol(asyncio.Protocol):
    """Common asyncio protocol base with write flow-control (drain) support."""
    __slots__ = (
        "_loop",
        "_paused",
        "_drain_waiter",
        # NOTE(review): "_connection_lost" is reserved here but never assigned
        # in this class -- presumably kept for subclasses or backward
        # compatibility; confirm before removing.
        "_connection_lost",
        "_reading_paused",
        "transport",
    )
    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop: asyncio.AbstractEventLoop = loop
        self._paused = False
        # Future awaited by _drain_helper() while writing is paused.
        self._drain_waiter: Optional[asyncio.Future[None]] = None
        self._reading_paused = False
        self.transport: Optional[asyncio.Transport] = None
    @property
    def connected(self) -> bool:
        """Return True if the connection is open."""
        return self.transport is not None
    @property
    def writing_paused(self) -> bool:
        # True while the transport is applying write backpressure.
        return self._paused
    def pause_writing(self) -> None:
        """Transport callback: write buffer crossed the high-water mark."""
        assert not self._paused
        self._paused = True
    def resume_writing(self) -> None:
        """Transport callback: buffer drained; wake any pending drain waiter."""
        assert self._paused
        self._paused = False
        waiter = self._drain_waiter
        if waiter is not None:
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)
    def pause_reading(self) -> None:
        """Pause the transport's read side (best effort)."""
        if not self._reading_paused and self.transport is not None:
            try:
                self.transport.pause_reading()
            except (AttributeError, NotImplementedError, RuntimeError):
                # Not every transport supports pausing; record the state anyway.
                pass
            self._reading_paused = True
    def resume_reading(self) -> None:
        """Resume the transport's read side (best effort)."""
        if self._reading_paused and self.transport is not None:
            try:
                self.transport.resume_reading()
            except (AttributeError, NotImplementedError, RuntimeError):
                pass
            self._reading_paused = False
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Store the transport and disable Nagle's algorithm on it."""
        tr = cast(asyncio.Transport, transport)
        tcp_nodelay(tr, True)
        self.transport = tr
    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Drop the transport and wake the writer if currently paused."""
        # Wake up the writer if currently paused.
        self.transport = None
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        if exc is None:
            waiter.set_result(None)
        else:
            # Attach the original exception as the cause of the new one.
            set_exception(
                waiter,
                ConnectionError("Connection lost"),
                exc,
            )
    async def _drain_helper(self) -> None:
        """Wait until the transport's write buffer drains (flow control)."""
        if self.transport is None:
            raise ClientConnectionResetError("Connection lost")
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            waiter = self._loop.create_future()
            self._drain_waiter = waiter
        # Shield so a cancelled caller does not invalidate the shared waiter.
        await asyncio.shield(waiter)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,421 @@
"""HTTP related errors."""
import asyncio
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
from multidict import MultiMapping
from .typedefs import StrOrURL
if TYPE_CHECKING:
import ssl
SSLContext = ssl.SSLContext
else:
try:
import ssl
SSLContext = ssl.SSLContext
except ImportError: # pragma: no cover
ssl = SSLContext = None # type: ignore[assignment]
if TYPE_CHECKING:
from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo
from .http_parser import RawResponseMessage
else:
RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None
# Public exception surface of this module.  TooManyRedirects and
# UnixClientConnectorError are defined below and are part of the public
# hierarchy, so they are exported here as well (they were previously missing
# from __all__ while every sibling class was listed).
__all__ = (
    "ClientError",
    "ClientConnectionError",
    "ClientConnectionResetError",
    "ClientOSError",
    "ClientConnectorError",
    "ClientProxyConnectionError",
    "ClientSSLError",
    "ClientConnectorDNSError",
    "ClientConnectorSSLError",
    "ClientConnectorCertificateError",
    "UnixClientConnectorError",
    "ConnectionTimeoutError",
    "SocketTimeoutError",
    "ServerConnectionError",
    "ServerTimeoutError",
    "ServerDisconnectedError",
    "ServerFingerprintMismatch",
    "ClientResponseError",
    "ClientHttpProxyError",
    "TooManyRedirects",
    "WSServerHandshakeError",
    "ContentTypeError",
    "ClientPayloadError",
    "InvalidURL",
    "InvalidUrlClientError",
    "RedirectClientError",
    "NonHttpUrlClientError",
    "InvalidUrlRedirectClientError",
    "NonHttpUrlRedirectClientError",
    "WSMessageTypeError",
)
class ClientError(Exception):
    """Base class for client connection errors.

    Every exception defined in this module derives from this class, with the
    single exception of :class:`WSMessageTypeError`, which subclasses
    :class:`TypeError` directly.
    """
class ClientResponseError(ClientError):
    """Raised for errors that occur after a response has been received.

    Attributes:
        request_info: An instance of RequestInfo.
        history: A sequence of responses, if redirects occurred.
        status: HTTP status code.
        message: Error message.
        headers: Response headers.
    """
    def __init__(
        self,
        request_info: RequestInfo,
        history: Tuple[ClientResponse, ...],
        *,
        code: Optional[int] = None,
        status: Optional[int] = None,
        message: str = "",
        headers: Optional[MultiMapping[str]] = None,
    ) -> None:
        self.request_info = request_info
        if code is not None:
            if status is not None:
                raise ValueError(
                    "Both code and status arguments are provided; "
                    "code is deprecated, use status instead"
                )
            warnings.warn(
                "code argument is deprecated, use status instead",
                DeprecationWarning,
                stacklevel=2,
            )
        # ``status`` wins over the deprecated ``code``; fall back to 0.
        resolved = status if status is not None else code
        self.status = resolved if resolved is not None else 0
        self.message = message
        self.headers = headers
        self.history = history
        self.args = (request_info, history)
    def __str__(self) -> str:
        return (
            f"{self.status}, message={self.message!r}, "
            f"url={str(self.request_info.real_url)!r}"
        )
    def __repr__(self) -> str:
        parts = [f"{self.request_info!r}", f"{self.history!r}"]
        if self.status != 0:
            parts.append(f"status={self.status!r}")
        if self.message != "":
            parts.append(f"message={self.message!r}")
        if self.headers is not None:
            parts.append(f"headers={self.headers!r}")
        return f"{type(self).__name__}({', '.join(parts)})"
    @property
    def code(self) -> int:
        """Deprecated alias for :attr:`status`."""
        warnings.warn(
            "code property is deprecated, use status instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.status
    @code.setter
    def code(self, value: int) -> None:
        warnings.warn(
            "code property is deprecated, use status instead",
            DeprecationWarning,
            stacklevel=2,
        )
        self.status = value
class ContentTypeError(ClientResponseError):
    """ContentType found is not valid."""
class WSServerHandshakeError(ClientResponseError):
    """WebSocket server handshake error."""
class ClientHttpProxyError(ClientResponseError):
    """HTTP proxy error.

    Raised in :class:`aiohttp.connector.TCPConnector` if
    proxy responds with status other than ``200 OK``
    on ``CONNECT`` request.
    """
class TooManyRedirects(ClientResponseError):
    """Client was redirected too many times."""
class ClientConnectionError(ClientError):
    """Base class for client socket errors."""
class ClientConnectionResetError(ClientConnectionError, ConnectionResetError):
    """Connection reset: also catchable as the builtin ConnectionResetError."""
class ClientOSError(ClientConnectionError, OSError):
    """OS-level error: also catchable as the builtin OSError."""
class ClientConnectorError(ClientOSError):
    """Client connector error.

    Raised in :class:`aiohttp.connector.TCPConnector` if
    a connection can not be established.
    """
    def __init__(self, connection_key: ConnectionKey, os_error: OSError) -> None:
        self._conn_key = connection_key
        self._os_error = os_error
        # Populate errno/strerror via OSError, then expose the original
        # objects through args for introspection and pickling.
        super().__init__(os_error.errno, os_error.strerror)
        self.args = (connection_key, os_error)
    @property
    def os_error(self) -> OSError:
        """The underlying OSError that caused the connection failure."""
        return self._os_error
    @property
    def host(self) -> str:
        """Target host from the connection key."""
        return self._conn_key.host
    @property
    def port(self) -> Optional[int]:
        """Target port from the connection key."""
        return self._conn_key.port
    @property
    def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]:
        """SSL configuration from the connection key."""
        return self._conn_key.ssl
    def __str__(self) -> str:
        ssl_info = "default" if self.ssl is True else self.ssl
        return (
            f"Cannot connect to host {self.host}:{self.port} "
            f"ssl:{ssl_info} [{self.strerror}]"
        )
    # OSError.__reduce__ mangles constructor arguments; stick to the plain
    # BaseException pickling protocol instead.
    __reduce__ = BaseException.__reduce__
class ClientConnectorDNSError(ClientConnectorError):
    """DNS resolution failed during client connection.

    Raised in :class:`aiohttp.connector.TCPConnector` if
    DNS resolution fails.
    """
class ClientProxyConnectionError(ClientConnectorError):
    """Proxy connection error.

    Raised in :class:`aiohttp.connector.TCPConnector` if
    connection to proxy can not be established.
    """
class UnixClientConnectorError(ClientConnectorError):
    """Unix connector error.

    Raised in :py:class:`aiohttp.connector.UnixConnector`
    if connection to unix socket can not be established.
    """
    def __init__(
        self, path: str, connection_key: ConnectionKey, os_error: OSError
    ) -> None:
        # Record the socket path before delegating the remaining state
        # handling to ClientConnectorError.
        self._path = path
        super().__init__(connection_key, os_error)
    @property
    def path(self) -> str:
        """Filesystem path of the unix socket that failed to connect."""
        return self._path
    def __str__(self) -> str:
        return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format(
            self, "default" if self.ssl is True else self.ssl, self.strerror
        )
class ServerConnectionError(ClientConnectionError):
    """Server connection errors."""
class ServerDisconnectedError(ServerConnectionError):
    """Server disconnected."""
    def __init__(self, message: Union[RawResponseMessage, str, None] = None) -> None:
        if message is None:
            message = "Server disconnected"
        # args/message are set directly (no super().__init__ call) so the raw
        # message object is preserved unchanged on the instance.
        self.args = (message,)
        self.message = message
class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError):
    """Server timeout error; also catchable as asyncio.TimeoutError."""
class ConnectionTimeoutError(ServerTimeoutError):
    """Connection timeout error."""
class SocketTimeoutError(ServerTimeoutError):
    """Socket timeout error."""
class ServerFingerprintMismatch(ServerConnectionError):
    """SSL certificate does not match expected fingerprint."""
    def __init__(self, expected: bytes, got: bytes, host: str, port: int) -> None:
        # Keep both fingerprints plus the endpoint for diagnostics.
        self.expected = expected
        self.got = got
        self.host = host
        self.port = port
        self.args = (expected, got, host, port)
    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (
            f"<{cls_name} expected={self.expected!r} got={self.got!r} "
            f"host={self.host!r} port={self.port!r}>"
        )
class ClientPayloadError(ClientError):
    """Response payload error."""
    # NOTE(review): raised by response-body reading code outside this module;
    # precise semantics are defined at the raise sites.
class InvalidURL(ClientError, ValueError):
    """Invalid URL.

    URL used for fetching is malformed, e.g. it doesn't contains host
    part.
    """
    # ValueError is kept as a base class for backward compatibility.
    def __init__(self, url: StrOrURL, description: Union[str, None] = None) -> None:
        # The url is kept as passed in (str or URL): the exception can be
        # raised while URL(url) itself is being constructed.
        self._url = url
        self._description = description
        args = (url, description) if description else (url,)
        super().__init__(*args)
    @property
    def url(self) -> StrOrURL:
        """The offending URL value, exactly as supplied."""
        return self._url
    @property
    def description(self) -> "str | None":
        """Optional human-readable explanation of what is wrong."""
        return self._description
    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} {self}>"
    def __str__(self) -> str:
        base = str(self._url)
        return f"{base} - {self._description}" if self._description else base
# Fine-grained URL / redirect error taxonomy: the multiple-inheritance classes
# let callers catch either the URL-shape aspect or the redirect aspect.
class InvalidUrlClientError(InvalidURL):
    """Invalid URL client error."""
class RedirectClientError(ClientError):
    """Client redirect error."""
class NonHttpUrlClientError(ClientError):
    """Non http URL client error."""
class InvalidUrlRedirectClientError(InvalidUrlClientError, RedirectClientError):
    """Invalid URL redirect client error."""
class NonHttpUrlRedirectClientError(NonHttpUrlClientError, RedirectClientError):
    """Non http URL redirect client error."""
class ClientSSLError(ClientConnectorError):
    """Base error for ssl.*Errors."""
# Compute the concrete base classes for the SSL exception types below at
# import time.  When the ssl module is unavailable, certificate errors fall
# back to ValueError and the SSL error keeps only ClientSSLError as base.
if ssl is not None:
    cert_errors = (ssl.CertificateError,)
    cert_errors_bases = (
        ClientSSLError,
        ssl.CertificateError,
    )
    ssl_errors = (ssl.SSLError,)
    ssl_error_bases = (ClientSSLError, ssl.SSLError)
else:  # pragma: no cover
    cert_errors = tuple()
    cert_errors_bases = (
        ClientSSLError,
        ValueError,
    )
    ssl_errors = tuple()
    ssl_error_bases = (ClientSSLError,)
# Base classes are computed at import time (see ssl_error_bases above).
class ClientConnectorSSLError(*ssl_error_bases):  # type: ignore[misc]
    """Response ssl error."""
class ClientConnectorCertificateError(*cert_errors_bases):  # type: ignore[misc]
    """Response certificate error."""
    def __init__(
        self, connection_key: ConnectionKey, certificate_error: Exception
    ) -> None:
        # args is set directly rather than via super().__init__.
        self._conn_key = connection_key
        self._certificate_error = certificate_error
        self.args = (connection_key, certificate_error)
    @property
    def certificate_error(self) -> Exception:
        """The underlying certificate validation exception."""
        return self._certificate_error
    @property
    def host(self) -> str:
        """Target host from the connection key."""
        return self._conn_key.host
    @property
    def port(self) -> Optional[int]:
        """Target port from the connection key."""
        return self._conn_key.port
    @property
    def ssl(self) -> bool:
        # Unlike ClientConnectorError.ssl, this exposes the boolean
        # is_ssl flag rather than the SSL configuration object.
        return self._conn_key.is_ssl
    def __str__(self) -> str:
        return (
            "Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} "
            "[{0.certificate_error.__class__.__name__}: "
            "{0.certificate_error.args}]".format(self)
        )
class WSMessageTypeError(TypeError):
    """WebSocket message type is not valid.

    Note: unlike the other exceptions in this module, this derives from
    TypeError rather than ClientError.
    """

View File

@@ -0,0 +1,480 @@
"""
Digest authentication middleware for aiohttp client.
This middleware implements HTTP Digest Authentication according to RFC 7616,
providing a more secure alternative to Basic Authentication. It supports all
standard hash algorithms including MD5, SHA, SHA-256, SHA-512 and their session
variants, as well as both 'auth' and 'auth-int' quality of protection (qop) options.
"""
import hashlib
import os
import re
import sys
import time
from typing import (
Callable,
Dict,
Final,
FrozenSet,
List,
Literal,
Tuple,
TypedDict,
Union,
)
from yarl import URL
from . import hdrs
from .client_exceptions import ClientError
from .client_middlewares import ClientHandlerType
from .client_reqrep import ClientRequest, ClientResponse
from .payload import Payload
class DigestAuthChallenge(TypedDict, total=False):
    """Parameters parsed from a ``WWW-Authenticate: Digest`` challenge.

    ``total=False`` because servers may omit any individual parameter.
    """
    realm: str
    nonce: str
    qop: str
    algorithm: str
    opaque: str
    domain: str
    stale: str
# Map of digest algorithm tokens (upper-cased) to hashlib constructors.
# "-SESS" variants share the base hash; the extra session-key step is
# applied separately in DigestAuthMiddleware._encode.
DigestFunctions: Dict[str, Callable[[bytes], "hashlib._Hash"]] = {
    "MD5": hashlib.md5,
    "MD5-SESS": hashlib.md5,
    "SHA": hashlib.sha1,
    "SHA-SESS": hashlib.sha1,
    "SHA256": hashlib.sha256,
    "SHA256-SESS": hashlib.sha256,
    "SHA-256": hashlib.sha256,
    "SHA-256-SESS": hashlib.sha256,
    "SHA512": hashlib.sha512,
    "SHA512-SESS": hashlib.sha512,
    "SHA-512": hashlib.sha512,
    "SHA-512-SESS": hashlib.sha512,
}
# Compile the regex pattern once at module level for performance.
# Python 3.11+ supports atomic groups ((?>...)); older versions get an
# equivalent pattern without them.
_HEADER_PAIRS_PATTERN = re.compile(
    r'(?:^|\s|,\s*)(\w+)\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))'
    if sys.version_info < (3, 11)
    else r'(?:^|\s|,\s*)((?>\w+))\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))'
    # +------------|--------|--|-|-|--|----|------|----|--||-----|-> Match valid start/sep
    #              +--------|--|-|-|--|----|------|----|--||-----|-> alphanumeric key (atomic
    #                       |  | | |  |    |      |    |  ||     |   group reduces backtracking)
    #                       +--|-|-|--|----|------|----|--||-----|-> maybe whitespace
    #                          | | |  |    |      |    |  ||     |
    #                          +-|-|--|----|------|----|--||-----|-> = (delimiter)
    #                            +-|--|----|------|----|--||-----|-> maybe whitespace
    #                              |  |    |      |    |  ||     |
    #                              +--|----|------|----|--||-----|-> group quoted or unquoted
    #                                 |    |      |    |  ||     |
    #                                 +----|------|----|--||-----|-> if quoted...
    #                                      +------|----|--||-----|-> anything but " or \
    #                                             +----|--||-----|-> escaped characters allowed
    #                                                  +--||-----|-> or can be empty string
    #                                                     ||     |
    #                                                     +|-----|-> if unquoted...
    #                                                      +-----|-> anything but , or <space>
    #                                                            +-> at least one char req'd
)
# RFC 7616: Challenge parameters to extract from the WWW-Authenticate header.
CHALLENGE_FIELDS: Final[
    Tuple[
        Literal["realm", "nonce", "qop", "algorithm", "opaque", "domain", "stale"], ...
    ]
] = (
    "realm",
    "nonce",
    "qop",
    "algorithm",
    "opaque",
    "domain",
    "stale",
)
# Supported digest authentication algorithms.
# Use a tuple of sorted keys for predictable documentation and error messages.
SUPPORTED_ALGORITHMS: Final[Tuple[str, ...]] = tuple(sorted(DigestFunctions.keys()))
# RFC 7616: Fields that require quoting in the Digest auth header.
# These fields must be enclosed in double quotes in the Authorization header.
# Algorithm, qop, and nc are never quoted per RFC specifications.
# This frozen set is used by the template-based header construction to
# automatically determine which fields need quotes.
QUOTED_AUTH_FIELDS: Final[FrozenSet[str]] = frozenset(
    {"username", "realm", "nonce", "uri", "response", "opaque", "cnonce"}
)
def escape_quotes(value: str) -> str:
    """Backslash-escape double quotes so the value can live in a quoted header parameter."""
    return '\\"'.join(value.split('"'))
def unescape_quotes(value: str) -> str:
    """Turn backslash-escaped double quotes back into plain double quotes."""
    return '"'.join(value.split('\\"'))
def parse_header_pairs(header: str) -> Dict[str, str]:
    """
    Parse key-value pairs from WWW-Authenticate or similar HTTP headers.

    Handles quoted and unquoted parameter values, commas embedded in quoted
    values, and whitespace variations per RFC 7616, e.g.::

        key1="value1", key2=value2
        key1 = "value1" , key2="value, with, commas"
        key1=value1,key2="value2"
        realm="example.com", nonce="12345", qop="auth"

    Args:
        header: The header value string to parse

    Returns:
        Dictionary mapping parameter names to their values
    """
    pairs: Dict[str, str] = {}
    for raw_key, quoted_val, unquoted_val in _HEADER_PAIRS_PATTERN.findall(header):
        key = raw_key.strip()
        if not key:
            continue
        pairs[key] = unescape_quotes(quoted_val) if quoted_val else unquoted_val
    return pairs
class DigestAuthMiddleware:
    """
    HTTP digest authentication middleware for aiohttp client.

    This middleware intercepts 401 Unauthorized responses containing a Digest
    authentication challenge, calculates the appropriate digest credentials,
    and automatically retries the request with the proper Authorization header.

    Features:
    - Handles all aspects of Digest authentication handshake automatically
    - Supports all standard hash algorithms:
      MD5, MD5-SESS, SHA, SHA-SESS, SHA256, SHA256-SESS, SHA-256,
      SHA-256-SESS, SHA512, SHA512-SESS, SHA-512, SHA-512-SESS
    - Supports 'auth' and 'auth-int' quality of protection modes
    - Properly handles quoted strings and parameter parsing
    - Includes replay attack protection with client nonce count tracking
    - Supports preemptive authentication per RFC 7616 Section 3.6

    Standards compliance:
    - RFC 7616: HTTP Digest Access Authentication (primary reference)
    - RFC 2617: HTTP Authentication (deprecated by RFC 7616)
    - RFC 1945: Section 11.1 (username restrictions)

    Implementation notes:
    The core digest calculation is inspired by the implementation in
    https://github.com/requests/requests/blob/v2.18.4/requests/auth.py
    with added support for modern digest auth features and error handling.
    """
    def __init__(
        self,
        login: str,
        password: str,
        preemptive: bool = True,
    ) -> None:
        if login is None:
            raise ValueError("None is not allowed as login value")
        if password is None:
            raise ValueError("None is not allowed as password value")
        if ":" in login:
            raise ValueError('A ":" is not allowed in username (RFC 1945#section-11.1)')
        self._login_str: Final[str] = login
        self._login_bytes: Final[bytes] = login.encode("utf-8")
        self._password_bytes: Final[bytes] = password.encode("utf-8")
        # Nonce from the previous challenge plus its use-count: together they
        # implement the RFC 7616 nc (nonce count) replay protection.
        self._last_nonce_bytes = b""
        self._nonce_count = 0
        self._challenge: DigestAuthChallenge = {}
        self._preemptive: bool = preemptive
        # Set of URLs defining the protection space
        self._protection_space: List[str] = []
    async def _encode(
        self, method: str, url: URL, body: Union[Payload, Literal[b""]]
    ) -> str:
        """
        Build digest authorization header for the current challenge.

        Args:
            method: The HTTP method (GET, POST, etc.)
            url: The request URL
            body: The request body (used for qop=auth-int)

        Returns:
            A fully formatted Digest authorization header string

        Raises:
            ClientError: If the challenge is missing required parameters or
                contains unsupported values
        """
        challenge = self._challenge
        if "realm" not in challenge:
            raise ClientError(
                "Malformed Digest auth challenge: Missing 'realm' parameter"
            )
        if "nonce" not in challenge:
            raise ClientError(
                "Malformed Digest auth challenge: Missing 'nonce' parameter"
            )
        # Empty realm values are allowed per RFC 7616 (SHOULD, not MUST, contain host name)
        realm = challenge["realm"]
        nonce = challenge["nonce"]
        # Empty nonce values are not allowed as they are security-critical for replay protection
        if not nonce:
            raise ClientError(
                "Security issue: Digest auth challenge contains empty 'nonce' value"
            )
        qop_raw = challenge.get("qop", "")
        # Preserve original algorithm case for response while using uppercase for processing
        algorithm_original = challenge.get("algorithm", "MD5")
        algorithm = algorithm_original.upper()
        opaque = challenge.get("opaque", "")
        # Convert string values to bytes once
        nonce_bytes = nonce.encode("utf-8")
        realm_bytes = realm.encode("utf-8")
        path = URL(url).path_qs
        # Process QoP
        qop = ""
        qop_bytes = b""
        if qop_raw:
            # The server may offer several qop values; prefer auth-int.
            valid_qops = {"auth", "auth-int"}.intersection(
                {q.strip() for q in qop_raw.split(",") if q.strip()}
            )
            if not valid_qops:
                raise ClientError(
                    f"Digest auth error: Unsupported Quality of Protection (qop) value(s): {qop_raw}"
                )
            qop = "auth-int" if "auth-int" in valid_qops else "auth"
            qop_bytes = qop.encode("utf-8")
        if algorithm not in DigestFunctions:
            raise ClientError(
                f"Digest auth error: Unsupported hash algorithm: {algorithm}. "
                f"Supported algorithms: {', '.join(SUPPORTED_ALGORITHMS)}"
            )
        hash_fn: Final = DigestFunctions[algorithm]
        def H(x: bytes) -> bytes:
            """RFC 7616 Section 3: Hash function H(data) = hex(hash(data))."""
            return hash_fn(x).hexdigest().encode()
        def KD(s: bytes, d: bytes) -> bytes:
            """RFC 7616 Section 3: KD(secret, data) = H(concat(secret, ":", data))."""
            return H(b":".join((s, d)))
        # Calculate A1 and A2
        A1 = b":".join((self._login_bytes, realm_bytes, self._password_bytes))
        A2 = f"{method.upper()}:{path}".encode()
        if qop == "auth-int":
            # auth-int additionally covers the entity body in A2.
            if isinstance(body, Payload):  # will always be empty bytes unless Payload
                entity_bytes = await body.as_bytes()  # Get bytes from Payload
            else:
                entity_bytes = body
            entity_hash = H(entity_bytes)
            A2 = b":".join((A2, entity_hash))
        HA1 = H(A1)
        HA2 = H(A2)
        # Nonce count handling
        if nonce_bytes == self._last_nonce_bytes:
            self._nonce_count += 1
        else:
            self._nonce_count = 1
            self._last_nonce_bytes = nonce_bytes
        ncvalue = f"{self._nonce_count:08x}"
        ncvalue_bytes = ncvalue.encode("utf-8")
        # Generate client nonce
        # NOTE(review): sha1 is only used to derive a short unique client
        # nonce here, not for the digest itself.
        cnonce = hashlib.sha1(
            b"".join(
                [
                    str(self._nonce_count).encode("utf-8"),
                    nonce_bytes,
                    time.ctime().encode("utf-8"),
                    os.urandom(8),
                ]
            )
        ).hexdigest()[:16]
        cnonce_bytes = cnonce.encode("utf-8")
        # Special handling for session-based algorithms
        if algorithm.upper().endswith("-SESS"):
            HA1 = H(b":".join((HA1, nonce_bytes, cnonce_bytes)))
        # Calculate the response digest
        if qop:
            noncebit = b":".join(
                (nonce_bytes, ncvalue_bytes, cnonce_bytes, qop_bytes, HA2)
            )
            response_digest = KD(HA1, noncebit)
        else:
            response_digest = KD(HA1, b":".join((nonce_bytes, HA2)))
        # Define a dict mapping of header fields to their values
        # Group fields into always-present, optional, and qop-dependent
        header_fields = {
            # Always present fields
            "username": escape_quotes(self._login_str),
            "realm": escape_quotes(realm),
            "nonce": escape_quotes(nonce),
            "uri": path,
            "response": response_digest.decode(),
            "algorithm": algorithm_original,
        }
        # Optional fields
        if opaque:
            header_fields["opaque"] = escape_quotes(opaque)
        # QoP-dependent fields
        if qop:
            header_fields["qop"] = qop
            header_fields["nc"] = ncvalue
            header_fields["cnonce"] = cnonce
        # Build header using templates for each field type
        pairs: List[str] = []
        for field, value in header_fields.items():
            if field in QUOTED_AUTH_FIELDS:
                pairs.append(f'{field}="{value}"')
            else:
                pairs.append(f"{field}={value}")
        return f"Digest {', '.join(pairs)}"
    def _in_protection_space(self, url: URL) -> bool:
        """
        Check if the given URL is within the current protection space.

        According to RFC 7616, a URI is in the protection space if any URI
        in the protection space is a prefix of it (after both have been made absolute).
        """
        request_str = str(url)
        for space_str in self._protection_space:
            # Check if request starts with space URL
            if not request_str.startswith(space_str):
                continue
            # Exact match or space ends with / (proper directory prefix)
            if len(request_str) == len(space_str) or space_str[-1] == "/":
                return True
            # Check next char is / to ensure proper path boundary
            if request_str[len(space_str)] == "/":
                return True
        return False
    def _authenticate(self, response: ClientResponse) -> bool:
        """
        Takes the given response and tries digest-auth, if needed.

        Returns true if the original request must be resent.
        """
        if response.status != 401:
            return False
        auth_header = response.headers.get("www-authenticate", "")
        if not auth_header:
            return False  # No authentication header present
        method, sep, headers = auth_header.partition(" ")
        if not sep:
            # No space found in www-authenticate header
            return False  # Malformed auth header, missing scheme separator
        if method.lower() != "digest":
            # Not a digest auth challenge (could be Basic, Bearer, etc.)
            return False
        if not headers:
            # We have a digest scheme but no parameters
            return False  # Malformed digest header, missing parameters
        # We have a digest auth header with content
        if not (header_pairs := parse_header_pairs(headers)):
            # Failed to parse any key-value pairs
            return False  # Malformed digest header, no valid parameters
        # Extract challenge parameters
        self._challenge = {}
        for field in CHALLENGE_FIELDS:
            if value := header_pairs.get(field):
                self._challenge[field] = value
        # Update protection space based on domain parameter or default to origin
        origin = response.url.origin()
        if domain := self._challenge.get("domain"):
            # Parse space-separated list of URIs
            self._protection_space = []
            for uri in domain.split():
                # Remove quotes if present
                uri = uri.strip('"')
                if uri.startswith("/"):
                    # Path-absolute, relative to origin
                    self._protection_space.append(str(origin.join(URL(uri))))
                else:
                    # Absolute URI
                    self._protection_space.append(str(URL(uri)))
        else:
            # No domain specified, protection space is entire origin
            self._protection_space = [str(origin)]
        # Return True only if we found at least one challenge parameter
        return bool(self._challenge)
    async def __call__(
        self, request: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        """Run the digest auth middleware."""
        response = None
        # At most one retry: the initial attempt plus one resend after a 401.
        for retry_count in range(2):
            # Apply authorization header if:
            # 1. This is a retry after 401 (retry_count > 0), OR
            # 2. Preemptive auth is enabled AND we have a challenge AND the URL is in protection space
            if retry_count > 0 or (
                self._preemptive
                and self._challenge
                and self._in_protection_space(request.url)
            ):
                request.headers[hdrs.AUTHORIZATION] = await self._encode(
                    request.method, request.url, request.body
                )
            # Send the request
            response = await handler(request)
            # Check if we need to authenticate
            if not self._authenticate(response):
                break
        # At this point, response is guaranteed to be defined
        assert response is not None
        return response

View File

@@ -0,0 +1,55 @@
"""Client middleware support."""
from collections.abc import Awaitable, Callable, Sequence
from .client_reqrep import ClientRequest, ClientResponse
# Public API of this module.
__all__ = ("ClientMiddlewareType", "ClientHandlerType", "build_client_middlewares")
# Type alias for client request handlers - functions that process requests and return responses
ClientHandlerType = Callable[[ClientRequest], Awaitable[ClientResponse]]
# Type alias for client middleware: receives the request plus the next handler
# in the chain and returns the (possibly transformed) response.
ClientMiddlewareType = Callable[
    [ClientRequest, ClientHandlerType], Awaitable[ClientResponse]
]
def build_client_middlewares(
    handler: ClientHandlerType,
    middlewares: Sequence[ClientMiddlewareType],
) -> ClientHandlerType:
    """
    Compose *middlewares* around *handler*.

    The first middleware in the sequence becomes the outermost wrapper, so
    it sees the request first and the response last.  The chain is rebuilt
    on every call (no caching, so stateful middleware is not retained) and
    avoids partial/update_wrapper to keep per-request overhead low.
    """

    def _bind(
        mw: ClientMiddlewareType, nxt: ClientHandlerType
    ) -> ClientHandlerType:
        # Capture mw/nxt in their own scope so the loop below does not
        # suffer from late-binding closures.
        async def invoke(req: ClientRequest) -> ClientResponse:
            return await mw(req, nxt)

        return invoke

    chain = handler
    for mw in reversed(middlewares):
        chain = _bind(mw, chain)
    return chain

View File

@@ -0,0 +1,359 @@
import asyncio
from contextlib import suppress
from typing import Any, Optional, Tuple, Union
from .base_protocol import BaseProtocol
from .client_exceptions import (
ClientConnectionError,
ClientOSError,
ClientPayloadError,
ServerDisconnectedError,
SocketTimeoutError,
)
from .helpers import (
_EXC_SENTINEL,
EMPTY_BODY_STATUS_CODES,
BaseTimerContext,
set_exception,
set_result,
)
from .http import HttpResponseParser, RawResponseMessage
from .http_exceptions import HttpProcessingError
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamReader]]):
    """Helper class to adapt between Protocol and StreamReader."""
    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        BaseProtocol.__init__(self, loop=loop)
        DataQueue.__init__(self, loop)
        # True once this connection must not be reused for another request.
        self._should_close = False
        self._payload: Optional[StreamReader] = None
        self._skip_payload = False
        self._payload_parser = None
        self._timer = None
        # Bytes received while no suitable parser was attached (e.g. during
        # a connection upgrade); replayed once a parser is set.
        self._tail = b""
        self._upgraded = False
        self._parser: Optional[HttpResponseParser] = None
        self._read_timeout: Optional[float] = None
        self._read_timeout_handle: Optional[asyncio.TimerHandle] = None
        self._timeout_ceil_threshold: Optional[float] = 5
        # Lazily created future resolved in connection_lost(); see `closed`.
        self._closed: Union[None, asyncio.Future[None]] = None
        self._connection_lost_called = False
    @property
    def closed(self) -> Union[None, asyncio.Future[None]]:
        """Future that is set when the connection is closed.
        This property returns a Future that will be completed when the connection
        is closed. The Future is created lazily on first access to avoid creating
        futures that will never be awaited.
        Returns:
        - A Future[None] if the connection is still open or was closed after
        this property was accessed
        - None if connection_lost() was already called before this property
        was ever accessed (indicating no one is waiting for the closure)
        """
        if self._closed is None and not self._connection_lost_called:
            self._closed = self._loop.create_future()
        return self._closed
    @property
    def upgraded(self) -> bool:
        """True once the HTTP parser reported a protocol upgrade."""
        return self._upgraded
    @property
    def should_close(self) -> bool:
        """True when the connection cannot be safely returned to the pool."""
        return bool(
            self._should_close
            or (self._payload is not None and not self._payload.is_eof())
            or self._upgraded
            or self._exception is not None
            or self._payload_parser is not None
            or self._buffer
            or self._tail
        )
    def force_close(self) -> None:
        """Mark the connection as non-reusable."""
        self._should_close = True
    def close(self) -> None:
        """Close the transport gracefully and drop internal references."""
        self._exception = None  # Break cyclic references
        transport = self.transport
        if transport is not None:
            transport.close()
            self.transport = None
            self._payload = None
        self._drop_timeout()
    def abort(self) -> None:
        """Abort the transport immediately without flushing buffered data."""
        self._exception = None  # Break cyclic references
        transport = self.transport
        if transport is not None:
            transport.abort()
            self.transport = None
            self._payload = None
        self._drop_timeout()
    def is_connected(self) -> bool:
        """Return True while the transport exists and is not closing."""
        return self.transport is not None and not self.transport.is_closing()
    def connection_lost(self, exc: Optional[BaseException]) -> None:
        # Resolve the `closed` future (if anyone asked for it), propagate the
        # failure into any in-flight payload/parser, and reset parser state.
        self._connection_lost_called = True
        self._drop_timeout()
        original_connection_error = exc
        reraised_exc = original_connection_error
        connection_closed_cleanly = original_connection_error is None
        if self._closed is not None:
            # If someone is waiting for the closed future,
            # we should set it to None or an exception. If
            # self._closed is None, it means that
            # connection_lost() was called already
            # or nobody is waiting for it.
            if connection_closed_cleanly:
                set_result(self._closed, None)
            else:
                assert original_connection_error is not None
                set_exception(
                    self._closed,
                    ClientConnectionError(
                        f"Connection lost: {original_connection_error !s}",
                    ),
                    original_connection_error,
                )
        if self._payload_parser is not None:
            with suppress(Exception):  # FIXME: log this somehow?
                self._payload_parser.feed_eof()
        uncompleted = None
        if self._parser is not None:
            try:
                uncompleted = self._parser.feed_eof()
            except Exception as underlying_exc:
                if self._payload is not None:
                    client_payload_exc_msg = (
                        f"Response payload is not completed: {underlying_exc !r}"
                    )
                    if not connection_closed_cleanly:
                        client_payload_exc_msg = (
                            f"{client_payload_exc_msg !s}. "
                            f"{original_connection_error !r}"
                        )
                    set_exception(
                        self._payload,
                        ClientPayloadError(client_payload_exc_msg),
                        underlying_exc,
                    )
        if not self.is_eof():
            if isinstance(original_connection_error, OSError):
                reraised_exc = ClientOSError(*original_connection_error.args)
            if connection_closed_cleanly:
                reraised_exc = ServerDisconnectedError(uncompleted)
            # assigns self._should_close to True as side effect,
            # we do it anyway below
            underlying_non_eof_exc = (
                _EXC_SENTINEL
                if connection_closed_cleanly
                else original_connection_error
            )
            assert underlying_non_eof_exc is not None
            assert reraised_exc is not None
            self.set_exception(reraised_exc, underlying_non_eof_exc)
        self._should_close = True
        self._parser = None
        self._payload = None
        self._payload_parser = None
        self._reading_paused = False
        super().connection_lost(reraised_exc)
    def eof_received(self) -> None:
        # should call parser.feed_eof() most likely
        self._drop_timeout()
    def pause_reading(self) -> None:
        """Pause the transport and suspend the read timeout."""
        super().pause_reading()
        self._drop_timeout()
    def resume_reading(self) -> None:
        """Resume the transport and restart the read timeout."""
        super().resume_reading()
        self._reschedule_timeout()
    def set_exception(
        self,
        exc: BaseException,
        exc_cause: BaseException = _EXC_SENTINEL,
    ) -> None:
        # Any error makes the connection unusable for keep-alive.
        self._should_close = True
        self._drop_timeout()
        super().set_exception(exc, exc_cause)
    def set_parser(self, parser: Any, payload: Any) -> None:
        # TODO: actual types are:
        #   parser: WebSocketReader
        #   payload: WebSocketDataQueue
        # but they are not generic enough
        # Need an ABC for both types
        self._payload = payload
        self._payload_parser = parser
        self._drop_timeout()
        # Replay any bytes that arrived before the parser was attached.
        if self._tail:
            data, self._tail = self._tail, b""
            self.data_received(data)
    def set_response_params(
        self,
        *,
        timer: Optional[BaseTimerContext] = None,
        skip_payload: bool = False,
        read_until_eof: bool = False,
        auto_decompress: bool = True,
        read_timeout: Optional[float] = None,
        read_bufsize: int = 2**16,
        timeout_ceil_threshold: float = 5,
        max_line_size: int = 8190,
        max_field_size: int = 8190,
    ) -> None:
        """Create the HTTP response parser for the next request/response cycle."""
        self._skip_payload = skip_payload
        self._read_timeout = read_timeout
        self._timeout_ceil_threshold = timeout_ceil_threshold
        self._parser = HttpResponseParser(
            self,
            self._loop,
            read_bufsize,
            timer=timer,
            payload_exception=ClientPayloadError,
            response_with_body=not skip_payload,
            read_until_eof=read_until_eof,
            auto_decompress=auto_decompress,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
        )
        # Replay any bytes buffered before the parser existed.
        if self._tail:
            data, self._tail = self._tail, b""
            self.data_received(data)
    def _drop_timeout(self) -> None:
        # Cancel a pending read-timeout callback, if any.
        if self._read_timeout_handle is not None:
            self._read_timeout_handle.cancel()
            self._read_timeout_handle = None
    def _reschedule_timeout(self) -> None:
        # Restart the read timeout from "now"; a falsy timeout disables it.
        timeout = self._read_timeout
        if self._read_timeout_handle is not None:
            self._read_timeout_handle.cancel()
        if timeout:
            self._read_timeout_handle = self._loop.call_later(
                timeout, self._on_read_timeout
            )
        else:
            self._read_timeout_handle = None
    def start_timeout(self) -> None:
        """Arm the read timeout (used when a request starts waiting)."""
        self._reschedule_timeout()
    @property
    def read_timeout(self) -> Optional[float]:
        """Current read timeout in seconds (None disables it)."""
        return self._read_timeout
    @read_timeout.setter
    def read_timeout(self, read_timeout: Optional[float]) -> None:
        self._read_timeout = read_timeout
    def _on_read_timeout(self) -> None:
        # Fail both the message queue and the in-flight payload stream.
        exc = SocketTimeoutError("Timeout on reading data from socket")
        self.set_exception(exc)
        if self._payload is not None:
            set_exception(self._payload, exc)
    def data_received(self, data: bytes) -> None:
        """Feed incoming bytes to the active parser (HTTP or websocket)."""
        self._reschedule_timeout()
        if not data:
            return
        # custom payload parser - currently always WebSocketReader
        if self._payload_parser is not None:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self._payload = None
                self._payload_parser = None
                if tail:
                    self.data_received(tail)
            return
        if self._upgraded or self._parser is None:
            # i.e. websocket connection, websocket parser is not set yet
            self._tail += data
            return
        # parse http messages
        try:
            messages, upgraded, tail = self._parser.feed_data(data)
        except BaseException as underlying_exc:
            if self.transport is not None:
                # connection.release() could be called BEFORE
                # data_received(), the transport is already
                # closed in this case
                self.transport.close()
            # should_close is True after the call
            if isinstance(underlying_exc, HttpProcessingError):
                exc = HttpProcessingError(
                    code=underlying_exc.code,
                    message=underlying_exc.message,
                    headers=underlying_exc.headers,
                )
            else:
                exc = HttpProcessingError()
            self.set_exception(exc, underlying_exc)
            return
        self._upgraded = upgraded
        payload: Optional[StreamReader] = None
        for message, payload in messages:
            if message.should_close:
                self._should_close = True
            self._payload = payload
            if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
                self.feed_data((message, EMPTY_PAYLOAD), 0)
            else:
                self.feed_data((message, payload), 0)
        if payload is not None:
            # new message(s) was processed
            # register timeout handler unsubscribing
            # either on end-of-stream or immediately for
            # EMPTY_PAYLOAD
            if payload is not EMPTY_PAYLOAD:
                payload.on_eof(self._drop_timeout)
            else:
                self._drop_timeout()
        if upgraded and tail:
            self.data_received(tail)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,428 @@
"""WebSocket client for asyncio."""
import asyncio
import sys
from types import TracebackType
from typing import Any, Optional, Type, cast
import attr
from ._websocket.reader import WebSocketDataQueue
from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
from .client_reqrep import ClientResponse
from .helpers import calculate_timeout_when, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WebSocketError,
WSCloseCode,
WSMessage,
WSMsgType,
)
from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter
from .streams import EofStream
from .typedefs import (
DEFAULT_JSON_DECODER,
DEFAULT_JSON_ENCODER,
JSONDecoder,
JSONEncoder,
)
if sys.version_info >= (3, 11):
import asyncio as async_timeout
else:
import async_timeout
@attr.s(frozen=True, slots=True)
class ClientWSTimeout:
    """Timeouts for a client websocket connection."""
    # Timeout applied to each receive() call; None disables it.
    ws_receive = attr.ib(type=Optional[float], default=None)
    # Time allowed for the closing handshake in close(); None disables it.
    ws_close = attr.ib(type=Optional[float], default=None)
# Default: no receive timeout, 10 seconds to complete the closing handshake.
DEFAULT_WS_CLIENT_TIMEOUT = ClientWSTimeout(ws_receive=None, ws_close=10.0)
class ClientWebSocketResponse:
    """Client-side websocket connection over an established HTTP response."""
    def __init__(
        self,
        reader: WebSocketDataQueue,
        writer: WebSocketWriter,
        protocol: Optional[str],
        response: ClientResponse,
        timeout: ClientWSTimeout,
        autoclose: bool,
        autoping: bool,
        loop: asyncio.AbstractEventLoop,
        *,
        heartbeat: Optional[float] = None,
        compress: int = 0,
        client_notakeover: bool = False,
    ) -> None:
        self._response = response
        self._conn = response.connection
        self._writer = writer
        self._reader = reader
        self._protocol = protocol
        self._closed = False
        self._closing = False
        self._close_code: Optional[int] = None
        self._timeout = timeout
        self._autoclose = autoclose
        self._autoping = autoping
        self._heartbeat = heartbeat
        self._heartbeat_cb: Optional[asyncio.TimerHandle] = None
        # Loop time at which the next heartbeat PING is due.
        self._heartbeat_when: float = 0.0
        if heartbeat is not None:
            # PONG must arrive within half the heartbeat interval.
            self._pong_heartbeat = heartbeat / 2.0
        self._pong_response_cb: Optional[asyncio.TimerHandle] = None
        self._loop = loop
        # True while a receive() call is in progress.
        self._waiting: bool = False
        self._close_wait: Optional[asyncio.Future[None]] = None
        self._exception: Optional[BaseException] = None
        self._compress = compress
        self._client_notakeover = client_notakeover
        self._ping_task: Optional[asyncio.Task[None]] = None
        self._reset_heartbeat()
    def _cancel_heartbeat(self) -> None:
        # Stop all heartbeat machinery: pong watchdog, timer, in-flight ping.
        self._cancel_pong_response_cb()
        if self._heartbeat_cb is not None:
            self._heartbeat_cb.cancel()
            self._heartbeat_cb = None
        if self._ping_task is not None:
            self._ping_task.cancel()
            self._ping_task = None
    def _cancel_pong_response_cb(self) -> None:
        if self._pong_response_cb is not None:
            self._pong_response_cb.cancel()
            self._pong_response_cb = None
    def _reset_heartbeat(self) -> None:
        """Push the next heartbeat PING out by one heartbeat interval."""
        if self._heartbeat is None:
            return
        self._cancel_pong_response_cb()
        loop = self._loop
        assert loop is not None
        conn = self._conn
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        now = loop.time()
        when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
        self._heartbeat_when = when
        if self._heartbeat_cb is None:
            # We do not cancel the previous heartbeat_cb here because
            # it generates a significant amount of TimerHandle churn
            # which causes asyncio to rebuild the heap frequently.
            # Instead _send_heartbeat() will reschedule the next
            # heartbeat if it fires too early.
            self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)
    def _send_heartbeat(self) -> None:
        # Timer callback: send a PING and arm the pong watchdog.
        self._heartbeat_cb = None
        loop = self._loop
        now = loop.time()
        if now < self._heartbeat_when:
            # Heartbeat fired too early, reschedule
            self._heartbeat_cb = loop.call_at(
                self._heartbeat_when, self._send_heartbeat
            )
            return
        conn = self._conn
        timeout_ceil_threshold = (
            conn._connector._timeout_ceil_threshold if conn is not None else 5
        )
        when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
        self._cancel_pong_response_cb()
        self._pong_response_cb = loop.call_at(when, self._pong_not_received)
        coro = self._writer.send_frame(b"", WSMsgType.PING)
        if sys.version_info >= (3, 12):
            # Optimization for Python 3.12, try to send the ping
            # immediately to avoid having to schedule
            # the task on the event loop.
            ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            ping_task = loop.create_task(coro)
        if not ping_task.done():
            self._ping_task = ping_task
            ping_task.add_done_callback(self._ping_task_done)
        else:
            self._ping_task_done(ping_task)
    def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
        """Callback for when the ping task completes."""
        if not task.cancelled() and (exc := task.exception()):
            self._handle_ping_pong_exception(exc)
        self._ping_task = None
    def _pong_not_received(self) -> None:
        # Pong watchdog expired: treat the connection as dead.
        self._handle_ping_pong_exception(
            ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
        )
    def _handle_ping_pong_exception(self, exc: BaseException) -> None:
        """Handle exceptions raised during ping/pong processing."""
        if self._closed:
            return
        self._set_closed()
        self._close_code = WSCloseCode.ABNORMAL_CLOSURE
        self._exception = exc
        self._response.close()
        if self._waiting and not self._closing:
            # Wake up a pending receive() with an ERROR message.
            self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)
    def _set_closed(self) -> None:
        """Set the connection to closed.
        Cancel any heartbeat timers and set the closed flag.
        """
        self._closed = True
        self._cancel_heartbeat()
    def _set_closing(self) -> None:
        """Set the connection to closing.
        Cancel any heartbeat timers and set the closing flag.
        """
        self._closing = True
        self._cancel_heartbeat()
    @property
    def closed(self) -> bool:
        """True once the websocket is fully closed."""
        return self._closed
    @property
    def close_code(self) -> Optional[int]:
        """Close code received from the peer, if any."""
        return self._close_code
    @property
    def protocol(self) -> Optional[str]:
        """Negotiated websocket subprotocol, if any."""
        return self._protocol
    @property
    def compress(self) -> int:
        """Negotiated permessage-deflate window size (0 = no compression)."""
        return self._compress
    @property
    def client_notakeover(self) -> bool:
        """True when client_no_context_takeover was negotiated."""
        return self._client_notakeover
    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """extra info from connection transport"""
        conn = self._response.connection
        if conn is None:
            return default
        transport = conn.transport
        if transport is None:
            return default
        return transport.get_extra_info(name, default)
    def exception(self) -> Optional[BaseException]:
        """Last exception recorded on this websocket, if any."""
        return self._exception
    async def ping(self, message: bytes = b"") -> None:
        """Send a PING frame with an optional payload."""
        await self._writer.send_frame(message, WSMsgType.PING)
    async def pong(self, message: bytes = b"") -> None:
        """Send a PONG frame with an optional payload."""
        await self._writer.send_frame(message, WSMsgType.PONG)
    async def send_frame(
        self, message: bytes, opcode: WSMsgType, compress: Optional[int] = None
    ) -> None:
        """Send a frame over the websocket."""
        await self._writer.send_frame(message, opcode, compress)
    async def send_str(self, data: str, compress: Optional[int] = None) -> None:
        """Send *data* as a TEXT frame; raises TypeError for non-str input."""
        if not isinstance(data, str):
            raise TypeError("data argument must be str (%r)" % type(data))
        await self._writer.send_frame(
            data.encode("utf-8"), WSMsgType.TEXT, compress=compress
        )
    async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None:
        """Send *data* as a BINARY frame; raises TypeError for non-bytes input."""
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError("data argument must be byte-ish (%r)" % type(data))
        await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)
    async def send_json(
        self,
        data: Any,
        compress: Optional[int] = None,
        *,
        dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
    ) -> None:
        """Serialize *data* with *dumps* and send it as a TEXT frame."""
        await self.send_str(dumps(data), compress=compress)
    async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
        """Perform the closing handshake; returns True if we initiated it."""
        # we need to break `receive()` cycle first,
        # `close()` may be called from different task
        if self._waiting and not self._closing:
            assert self._loop is not None
            self._close_wait = self._loop.create_future()
            self._set_closing()
            self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
            await self._close_wait
        if self._closed:
            return False
        self._set_closed()
        try:
            await self._writer.close(code, message)
        except asyncio.CancelledError:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._response.close()
            raise
        except Exception as exc:
            self._close_code = WSCloseCode.ABNORMAL_CLOSURE
            self._exception = exc
            self._response.close()
            return True
        if self._close_code:
            self._response.close()
            return True
        # Drain incoming frames until the peer's CLOSE arrives (bounded by
        # the ws_close timeout per read).
        while True:
            try:
                async with async_timeout.timeout(self._timeout.ws_close):
                    msg = await self._reader.read()
            except asyncio.CancelledError:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._response.close()
                raise
            except Exception as exc:
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                self._exception = exc
                self._response.close()
                return True
            if msg.type is WSMsgType.CLOSE:
                self._close_code = msg.data
                self._response.close()
                return True
    async def receive(self, timeout: Optional[float] = None) -> WSMessage:
        """Receive the next message, transparently handling control frames."""
        receive_timeout = timeout or self._timeout.ws_receive
        while True:
            if self._waiting:
                raise RuntimeError("Concurrent call to receive() is not allowed")
            if self._closed:
                return WS_CLOSED_MESSAGE
            elif self._closing:
                await self.close()
                return WS_CLOSED_MESSAGE
            try:
                self._waiting = True
                try:
                    if receive_timeout:
                        # Entering the context manager and creating
                        # Timeout() object can take almost 50% of the
                        # run time in this loop so we avoid it if
                        # there is no read timeout.
                        async with async_timeout.timeout(receive_timeout):
                            msg = await self._reader.read()
                    else:
                        msg = await self._reader.read()
                    self._reset_heartbeat()
                finally:
                    self._waiting = False
                    if self._close_wait:
                        set_result(self._close_wait, None)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                raise
            except EofStream:
                self._close_code = WSCloseCode.OK
                await self.close()
                return WSMessage(WSMsgType.CLOSED, None, None)
            except ClientError:
                # Likely ServerDisconnectedError when connection is lost
                self._set_closed()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                return WS_CLOSED_MESSAGE
            except WebSocketError as exc:
                self._close_code = exc.code
                await self.close(code=exc.code)
                return WSMessage(WSMsgType.ERROR, exc, None)
            except Exception as exc:
                self._exception = exc
                self._set_closing()
                self._close_code = WSCloseCode.ABNORMAL_CLOSURE
                await self.close()
                return WSMessage(WSMsgType.ERROR, exc, None)
            if msg.type not in _INTERNAL_RECEIVE_TYPES:
                # If its not a close/closing/ping/pong message
                # we can return it immediately
                return msg
            if msg.type is WSMsgType.CLOSE:
                self._set_closing()
                self._close_code = msg.data
                if not self._closed and self._autoclose:
                    await self.close()
            elif msg.type is WSMsgType.CLOSING:
                self._set_closing()
            elif msg.type is WSMsgType.PING and self._autoping:
                await self.pong(msg.data)
                continue
            elif msg.type is WSMsgType.PONG and self._autoping:
                continue
            return msg
    async def receive_str(self, *, timeout: Optional[float] = None) -> str:
        """Receive a TEXT message; raises WSMessageTypeError otherwise."""
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.TEXT:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
            )
        return cast(str, msg.data)
    async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes:
        """Receive a BINARY message; raises WSMessageTypeError otherwise."""
        msg = await self.receive(timeout)
        if msg.type is not WSMsgType.BINARY:
            raise WSMessageTypeError(
                f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
            )
        return cast(bytes, msg.data)
    async def receive_json(
        self,
        *,
        loads: JSONDecoder = DEFAULT_JSON_DECODER,
        timeout: Optional[float] = None,
    ) -> Any:
        """Receive a TEXT message and decode it with *loads*."""
        data = await self.receive_str(timeout=timeout)
        return loads(data)
    def __aiter__(self) -> "ClientWebSocketResponse":
        return self
    async def __anext__(self) -> WSMessage:
        # Iteration stops on any close-related message.
        msg = await self.receive()
        if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
            raise StopAsyncIteration
        return msg
    async def __aenter__(self) -> "ClientWebSocketResponse":
        return self
    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        await self.close()

View File

@@ -0,0 +1,348 @@
import asyncio
import sys
import zlib
from abc import ABC, abstractmethod
from concurrent.futures import Executor
from typing import Any, Final, Optional, Protocol, TypedDict, cast
if sys.version_info >= (3, 12):
from collections.abc import Buffer
else:
from typing import Union
Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
try:
try:
import brotlicffi as brotli
except ImportError:
import brotli
HAS_BROTLI = True
except ImportError: # pragma: no cover
HAS_BROTLI = False
try:
if sys.version_info >= (3, 14):
from compression.zstd import ZstdDecompressor # noqa: I900
else: # TODO(PY314): Remove mentions of backports.zstd across codebase
from backports.zstd import ZstdDecompressor
HAS_ZSTD = True
except ImportError:
HAS_ZSTD = False
# Payloads larger than this are (de)compressed in an executor instead of
# blocking the event loop.
MAX_SYNC_CHUNK_SIZE = 4096
DEFAULT_MAX_DECOMPRESS_SIZE = 2**25  # 32MiB
# Unlimited decompression constants - different libraries use different conventions
ZLIB_MAX_LENGTH_UNLIMITED = 0  # zlib uses 0 to mean unlimited
ZSTD_MAX_LENGTH_UNLIMITED = -1  # zstd uses -1 to mean unlimited
class ZLibCompressObjProtocol(Protocol):
    """Structural type of a zlib-style compressobj instance."""
    def compress(self, data: Buffer) -> bytes: ...
    def flush(self, mode: int = ..., /) -> bytes: ...
class ZLibDecompressObjProtocol(Protocol):
    """Structural type of a zlib-style decompressobj instance."""
    def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
    def flush(self, length: int = ..., /) -> bytes: ...
    @property
    def eof(self) -> bool: ...
class ZLibBackendProtocol(Protocol):
    """Structural type of a zlib-compatible module (e.g. zlib, isal)."""
    MAX_WBITS: int
    Z_FULL_FLUSH: int
    Z_SYNC_FLUSH: int
    Z_BEST_SPEED: int
    Z_FINISH: int
    def compressobj(
        self,
        level: int = ...,
        method: int = ...,
        wbits: int = ...,
        memLevel: int = ...,
        strategy: int = ...,
        zdict: Optional[Buffer] = ...,
    ) -> ZLibCompressObjProtocol: ...
    def decompressobj(
        self, wbits: int = ..., zdict: Buffer = ...
    ) -> ZLibDecompressObjProtocol: ...
    def compress(
        self, data: Buffer, /, level: int = ..., wbits: int = ...
    ) -> bytes: ...
    def decompress(
        self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
    ) -> bytes: ...
class CompressObjArgs(TypedDict, total=False):
    """Optional keyword arguments forwarded to compressobj()."""
    wbits: int
    strategy: int
    level: int
class ZLibBackendWrapper:
    """Thin delegating wrapper so the zlib backend can be swapped at runtime."""
    def __init__(self, _zlib_backend: ZLibBackendProtocol):
        self._zlib_backend: ZLibBackendProtocol = _zlib_backend
    @property
    def name(self) -> str:
        # Module name of the wrapped backend; "undefined" if it has none.
        return getattr(self._zlib_backend, "__name__", "undefined")
    @property
    def MAX_WBITS(self) -> int:
        return self._zlib_backend.MAX_WBITS
    @property
    def Z_FULL_FLUSH(self) -> int:
        return self._zlib_backend.Z_FULL_FLUSH
    @property
    def Z_SYNC_FLUSH(self) -> int:
        return self._zlib_backend.Z_SYNC_FLUSH
    @property
    def Z_BEST_SPEED(self) -> int:
        return self._zlib_backend.Z_BEST_SPEED
    @property
    def Z_FINISH(self) -> int:
        return self._zlib_backend.Z_FINISH
    def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
        return self._zlib_backend.compressobj(*args, **kwargs)
    def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
        return self._zlib_backend.decompressobj(*args, **kwargs)
    def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        return self._zlib_backend.compress(data, *args, **kwargs)
    def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        return self._zlib_backend.decompress(data, *args, **kwargs)
    # Everything not explicitly listed in the Protocol we just pass through
    def __getattr__(self, attrname: str) -> Any:
        return getattr(self._zlib_backend, attrname)
# Module-wide zlib backend; replace with set_zlib_backend().
ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)
def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
    """Replace the zlib implementation used by newly created (de)compressors."""
    ZLibBackend._zlib_backend = new_zlib_backend
def encoding_to_mode(
    encoding: Optional[str] = None,
    suppress_deflate_header: bool = False,
) -> int:
    """Translate a content-encoding name into a zlib ``wbits`` value.

    "gzip" selects gzip framing; otherwise plain deflate is used, with a
    negative value (raw stream) when the header should be suppressed.
    """
    max_wbits = ZLibBackend.MAX_WBITS
    if encoding == "gzip":
        return 16 + max_wbits
    if suppress_deflate_header:
        return -max_wbits
    return max_wbits
class DecompressionBaseHandler(ABC):
    """Shared async-offloading logic for the decompressor wrappers.

    Subclasses implement decompress_sync(); decompress() transparently
    runs large payloads in an executor to keep the event loop responsive.
    """
    def __init__(
        self,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        """Base class for decompression handlers.

        executor: executor for large payloads (None uses the loop default).
        max_sync_chunk_size: payloads larger than this are decompressed in
            the executor; None disables offloading entirely.
        """
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size
    @abstractmethod
    def decompress_sync(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data synchronously."""
    async def decompress(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data, offloading large payloads to the executor."""
        if (
            self._max_sync_chunk_size is not None
            and len(data) > self._max_sync_chunk_size
        ):
            # get_running_loop() instead of the deprecated get_event_loop():
            # this coroutine always runs inside a loop, and it matches
            # ZLibCompressor.compress() elsewhere in this module.
            return await asyncio.get_running_loop().run_in_executor(
                self._executor, self.decompress_sync, data, max_length
            )
        return self.decompress_sync(data, max_length)
class ZLibCompressor:
    """Incremental zlib/gzip/deflate compressor bound to the current backend."""
    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        level: Optional[int] = None,
        wbits: Optional[int] = None,
        strategy: Optional[int] = None,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size
        # An explicit wbits overrides the encoding-derived mode.
        self._mode = (
            encoding_to_mode(encoding, suppress_deflate_header)
            if wbits is None
            else wbits
        )
        # Snapshot the backend so later set_zlib_backend() calls do not
        # affect this instance mid-stream.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
        kwargs: CompressObjArgs = {}
        kwargs["wbits"] = self._mode
        if strategy is not None:
            kwargs["strategy"] = strategy
        if level is not None:
            kwargs["level"] = level
        self._compressor = self._zlib_backend.compressobj(**kwargs)
    def compress_sync(self, data: bytes) -> bytes:
        """Compress *data* synchronously on the calling thread."""
        return self._compressor.compress(data)
    async def compress(self, data: bytes) -> bytes:
        """Compress the data and return the compressed bytes.
        Note that flush() must be called after the last call to compress()
        If the data size is larger than the max_sync_chunk_size, the compression
        will be done in the executor. Otherwise, the compression will be done
        in the event loop.
        **WARNING: This method is NOT cancellation-safe when used with flush().**
        If this operation is cancelled, the compressor state may be corrupted.
        The connection MUST be closed after cancellation to avoid data corruption
        in subsequent compress operations.
        For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
        compress() + flush() + send operations in a shield and lock to ensure atomicity.
        """
        # For large payloads, offload compression to executor to avoid blocking event loop
        should_use_executor = (
            self._max_sync_chunk_size is not None
            and len(data) > self._max_sync_chunk_size
        )
        if should_use_executor:
            return await asyncio.get_running_loop().run_in_executor(
                self._executor, self._compressor.compress, data
            )
        return self.compress_sync(data)
    def flush(self, mode: Optional[int] = None) -> bytes:
        """Flush the compressor synchronously.
        **WARNING: This method is NOT cancellation-safe when called after compress().**
        The flush() operation accesses shared compressor state. If compress() was
        cancelled, calling flush() may result in corrupted data. The connection MUST
        be closed after compress() cancellation.
        For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
        compress() + flush() + send operations in a shield and lock to ensure atomicity.
        """
        return self._compressor.flush(
            mode if mode is not None else self._zlib_backend.Z_FINISH
        )
class ZLibDecompressor(DecompressionBaseHandler):
    """Incremental zlib/gzip/deflate decompressor bound to the current backend."""
    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
        self._mode = encoding_to_mode(encoding, suppress_deflate_header)
        # Snapshot the backend so later set_zlib_backend() calls do not
        # affect this instance mid-stream.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
        self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)
    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress *data* synchronously (max_length=0 means unlimited)."""
        return self._decompressor.decompress(data, max_length)
    def flush(self, length: int = 0) -> bytes:
        """Return any remaining buffered output."""
        return (
            self._decompressor.flush(length)
            if length > 0
            else self._decompressor.flush()
        )
    @property
    def eof(self) -> bool:
        """True once the end of the compressed stream was reached."""
        return self._decompressor.eof
class BrotliDecompressor(DecompressionBaseHandler):
    # Supports both 'brotlipy' and 'Brotli' packages
    # since they share an import name. The top branches
    # are for 'brotlipy' and bottom branches for 'Brotli'
    def __init__(
        self,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        """Decompress data using the Brotli library."""
        if not HAS_BROTLI:
            raise RuntimeError(
                "The brotli decompression is not available. "
                "Please install `Brotli` module"
            )
        self._obj = brotli.Decompressor()
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data."""
        # brotlipy exposes decompress(); the Brotli package exposes process().
        if hasattr(self._obj, "decompress"):
            return cast(bytes, self._obj.decompress(data, max_length))
        return cast(bytes, self._obj.process(data, max_length))
    def flush(self) -> bytes:
        """Flush the decompressor."""
        # Only brotlipy has flush(); Brotli buffers nothing to flush.
        if hasattr(self._obj, "flush"):
            return cast(bytes, self._obj.flush())
        return b""
class ZSTDDecompressor(DecompressionBaseHandler):
    """Incremental Zstandard decompressor."""

    def __init__(
        self,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        if not HAS_ZSTD:
            raise RuntimeError(
                "The zstd decompression is not available. "
                "Please install `backports.zstd` module"
            )
        self._obj = ZstdDecompressor()
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)

    def decompress_sync(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress *data*, translating the zlib "unlimited" marker to zstd's."""
        # zlib signals "no limit" with 0 while zstd expects -1; translate here.
        if max_length == ZLIB_MAX_LENGTH_UNLIMITED:
            return self._obj.decompress(data, ZSTD_MAX_LENGTH_UNLIMITED)
        return self._obj.decompress(data, max_length)

    def flush(self) -> bytes:
        # Nothing to emit on flush; return empty bytes.
        return b""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,522 @@
import asyncio
import calendar
import contextlib
import datetime
import heapq
import itertools
import os # noqa
import pathlib
import pickle
import re
import time
import warnings
from collections import defaultdict
from collections.abc import Mapping
from http.cookies import BaseCookie, Morsel, SimpleCookie
from typing import (
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
)
from yarl import URL
from ._cookie_helpers import preserve_morsel_with_coded_value
from .abc import AbstractCookieJar, ClearCookiePredicate
from .helpers import is_ip_address
from .typedefs import LooseCookies, PathLike, StrOrURL
__all__ = ("CookieJar", "DummyCookieJar")

# Either a raw cookie value string or an already-parsed Morsel.
CookieItem = Union[str, "Morsel[str]"]

# We cache these string methods here as their use is in performance critical code.
_FORMAT_PATH = "{}/{}".format
_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format

# The minimum number of scheduled cookie expirations before we start cleaning up
# the expiration heap. This is a performance optimization to avoid cleaning up the
# heap too often when there are only a few scheduled expirations.
_MIN_SCHEDULED_COOKIE_EXPIRATION = 100

# Shared instance, used only for its value_encode() quoting helper.
_SIMPLE_COOKIE = SimpleCookie()
class CookieJar(AbstractCookieJar):
    """Implements cookie storage adhering to RFC 6265."""

    # Regexes implementing the cookie-date tokenizer/matchers (RFC 6265 §5.1.1).
    DATE_TOKENS_RE = re.compile(
        r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
        r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
    )

    DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")

    DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")

    DATE_MONTH_RE = re.compile(
        "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
        re.I,
    )

    DATE_YEAR_RE = re.compile(r"(\d{2,4})")

    # calendar.timegm() fails for timestamps after datetime.datetime.max
    # Minus one as a loss of precision occurs when timestamp() is called.
    MAX_TIME = (
        int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
    )
    try:
        calendar.timegm(time.gmtime(MAX_TIME))
    except (OSError, ValueError):
        # Hit the maximum representable time on Windows
        # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
        # Throws ValueError on PyPy 3.9, OSError elsewhere
        MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
    except OverflowError:
        # #4515: datetime.max may not be representable on 32-bit platforms
        MAX_TIME = 2**31 - 1

    # Avoid minuses in the future, 3x faster
    SUB_MAX_TIME = MAX_TIME - 1

    def __init__(
        self,
        *,
        unsafe: bool = False,
        quote_cookie: bool = True,
        treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        super().__init__(loop=loop)
        # Cookies stored per (domain, path) key.
        self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict(
            SimpleCookie
        )
        # Send-ready Morsels built lazily by filter_cookies(), keyed like
        # _cookies; entries are invalidated when the cookie changes.
        self._morsel_cache: DefaultDict[Tuple[str, str], Dict[str, Morsel[str]]] = (
            defaultdict(dict)
        )
        # (domain, name) pairs that must only match the exact response host.
        self._host_only_cookies: Set[Tuple[str, str]] = set()
        self._unsafe = unsafe
        self._quote_cookie = quote_cookie
        # Normalize treat_as_secure_origin to a list of origin() URLs.
        if treat_as_secure_origin is None:
            treat_as_secure_origin = []
        elif isinstance(treat_as_secure_origin, URL):
            treat_as_secure_origin = [treat_as_secure_origin.origin()]
        elif isinstance(treat_as_secure_origin, str):
            treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
        else:
            treat_as_secure_origin = [
                URL(url).origin() if isinstance(url, str) else url.origin()
                for url in treat_as_secure_origin
            ]
        self._treat_as_secure_origin = treat_as_secure_origin
        # Min-heap of (when, (domain, path, name)) scheduled expirations, plus
        # the authoritative map used to detect stale heap entries.
        self._expire_heap: List[Tuple[float, Tuple[str, str, str]]] = []
        self._expirations: Dict[Tuple[str, str, str], float] = {}

    @property
    def quote_cookie(self) -> bool:
        """Whether outgoing cookie values are (re)quoted (see _build_morsel)."""
        return self._quote_cookie

    def save(self, file_path: PathLike) -> None:
        """Pickle this jar's cookies to *file_path*."""
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="wb") as f:
            pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

    def load(self, file_path: PathLike) -> None:
        """Replace this jar's cookies with those previously written by save().

        NOTE(review): pickle.load() can execute arbitrary code from the file;
        only load files this application wrote itself.
        """
        file_path = pathlib.Path(file_path)
        with file_path.open(mode="rb") as f:
            self._cookies = pickle.load(f)

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """Remove cookies matching *predicate* (all cookies when it is None)."""
        if predicate is None:
            # Fast path: drop everything, including all bookkeeping state.
            self._expire_heap.clear()
            self._cookies.clear()
            self._morsel_cache.clear()
            self._host_only_cookies.clear()
            self._expirations.clear()
            return

        now = time.time()
        # Delete cookies that are already expired OR that match the predicate.
        to_del = [
            key
            for (domain, path), cookie in self._cookies.items()
            for name, morsel in cookie.items()
            if (
                (key := (domain, path, name)) in self._expirations
                and self._expirations[key] <= now
            )
            or predicate(morsel)
        ]
        if to_del:
            self._delete_cookies(to_del)

    def clear_domain(self, domain: str) -> None:
        """Remove all cookies whose domain matches *domain* (RFC 6265 matching)."""
        self.clear(lambda x: self._is_domain_match(domain, x["domain"]))

    def __iter__(self) -> "Iterator[Morsel[str]]":
        """Iterate over all non-expired cookies."""
        self._do_expiration()
        for val in self._cookies.values():
            yield from val.values()

    def __len__(self) -> int:
        """Return number of cookies.

        This function does not iterate self to avoid unnecessary expiration
        checks.
        """
        return sum(len(cookie.values()) for cookie in self._cookies.values())

    def _do_expiration(self) -> None:
        """Remove expired cookies."""
        if not (expire_heap_len := len(self._expire_heap)):
            return
        # If the expiration heap grows larger than the number expirations
        # times two, we clean it up to avoid keeping expired entries in
        # the heap and consuming memory. We guard this with a minimum
        # threshold to avoid cleaning up the heap too often when there are
        # only a few scheduled expirations.
        if (
            expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
            and expire_heap_len > len(self._expirations) * 2
        ):
            # Remove any expired entries from the expiration heap
            # that do not match the expiration time in the expirations
            # as it means the cookie has been re-added to the heap
            # with a different expiration time.
            self._expire_heap = [
                entry
                for entry in self._expire_heap
                if self._expirations.get(entry[1]) == entry[0]
            ]
            heapq.heapify(self._expire_heap)

        now = time.time()
        to_del: List[Tuple[str, str, str]] = []
        # Find any expired cookies and add them to the to-delete list
        while self._expire_heap:
            when, cookie_key = self._expire_heap[0]
            if when > now:
                break
            heapq.heappop(self._expire_heap)
            # Check if the cookie hasn't been re-added to the heap
            # with a different expiration time as it will be removed
            # later when it reaches the top of the heap and its
            # expiration time is met.
            if self._expirations.get(cookie_key) == when:
                to_del.append(cookie_key)

        if to_del:
            self._delete_cookies(to_del)

    def _delete_cookies(self, to_del: List[Tuple[str, str, str]]) -> None:
        """Drop the given (domain, path, name) cookies and their bookkeeping."""
        for domain, path, name in to_del:
            self._host_only_cookies.discard((domain, name))
            self._cookies[(domain, path)].pop(name, None)
            self._morsel_cache[(domain, path)].pop(name, None)
            self._expirations.pop((domain, path, name), None)

    def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
        """Schedule the named cookie to expire at timestamp *when*."""
        cookie_key = (domain, path, name)
        if self._expirations.get(cookie_key) == when:
            # Avoid adding duplicates to the heap
            return
        heapq.heappush(self._expire_heap, (when, cookie_key))
        self._expirations[cookie_key] = when

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Update cookies."""
        hostname = response_url.raw_host

        if not self._unsafe and is_ip_address(hostname):
            # Don't accept cookies from IPs
            return

        if isinstance(cookies, Mapping):
            cookies = cookies.items()

        for name, cookie in cookies:
            if not isinstance(cookie, Morsel):
                # Re-parse a bare string value into a Morsel.
                tmp = SimpleCookie()
                tmp[name] = cookie  # type: ignore[assignment]
                cookie = tmp[name]

            domain = cookie["domain"]

            # ignore domains with trailing dots
            if domain and domain[-1] == ".":
                domain = ""
                del cookie["domain"]

            if not domain and hostname is not None:
                # Set the cookie's domain to the response hostname
                # and set its host-only-flag
                self._host_only_cookies.add((hostname, name))
                domain = cookie["domain"] = hostname

            if domain and domain[0] == ".":
                # Remove leading dot
                domain = domain[1:]
                cookie["domain"] = domain

            if hostname and not self._is_domain_match(domain, hostname):
                # Setting cookies for different domains is not allowed
                continue

            path = cookie["path"]
            if not path or path[0] != "/":
                # Set the cookie's path to the response path
                path = response_url.path
                if not path.startswith("/"):
                    path = "/"
                else:
                    # Cut everything from the last slash to the end
                    path = "/" + path[1 : path.rfind("/")]
                cookie["path"] = path
            path = path.rstrip("/")

            if max_age := cookie["max-age"]:
                try:
                    delta_seconds = int(max_age)
                    # Clamp to MAX_TIME so calendar/time math stays representable.
                    max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
                    self._expire_cookie(max_age_expiration, domain, path, name)
                except ValueError:
                    cookie["max-age"] = ""

            elif expires := cookie["expires"]:
                if expire_time := self._parse_date(expires):
                    self._expire_cookie(expire_time, domain, path, name)
                else:
                    cookie["expires"] = ""

            key = (domain, path)
            if self._cookies[key].get(name) != cookie:
                # Don't blow away the cache if the same
                # cookie gets set again
                self._cookies[key][name] = cookie
                self._morsel_cache[key].pop(name, None)

        self._do_expiration()

    def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
        """Returns this jar's cookies filtered by their attributes."""
        # We always use BaseCookie now since all
        # cookies set on filtered are fully constructed
        # Morsels, not just names and values.
        filtered: BaseCookie[str] = BaseCookie()
        if not self._cookies:
            # Skip do_expiration() if there are no cookies.
            return filtered
        self._do_expiration()
        if not self._cookies:
            # Skip rest of function if no non-expired cookies.
            return filtered
        if type(request_url) is not URL:
            warnings.warn(
                "filter_cookies expects yarl.URL instances only,"
                f"and will stop working in 4.x, got {type(request_url)}",
                DeprecationWarning,
                stacklevel=2,
            )
            request_url = URL(request_url)
        hostname = request_url.raw_host or ""

        is_not_secure = request_url.scheme not in ("https", "wss")
        if is_not_secure and self._treat_as_secure_origin:
            # Allow configured plain-text origins to be treated as secure.
            request_origin = URL()
            with contextlib.suppress(ValueError):
                request_origin = request_url.origin()
            is_not_secure = request_origin not in self._treat_as_secure_origin

        # Send shared cookie
        key = ("", "")
        for c in self._cookies[key].values():
            # Check cache first
            if c.key in self._morsel_cache[key]:
                filtered[c.key] = self._morsel_cache[key][c.key]
                continue

            # Build and cache the morsel
            mrsl_val = self._build_morsel(c)
            self._morsel_cache[key][c.key] = mrsl_val
            filtered[c.key] = mrsl_val

        if is_ip_address(hostname):
            if not self._unsafe:
                return filtered
            domains: Iterable[str] = (hostname,)
        else:
            # Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
            domains = itertools.accumulate(
                reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
            )

        # Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
        paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
        # Create every combination of (domain, path) pairs.
        pairs = itertools.product(domains, paths)

        path_len = len(request_url.path)
        # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
        for p in pairs:
            if p not in self._cookies:
                continue
            for name, cookie in self._cookies[p].items():
                domain = cookie["domain"]

                if (domain, name) in self._host_only_cookies and domain != hostname:
                    continue

                # Skip edge case when the cookie has a trailing slash but request doesn't.
                if len(cookie["path"]) > path_len:
                    continue

                if is_not_secure and cookie["secure"]:
                    continue

                # We already built the Morsel so reuse it here
                if name in self._morsel_cache[p]:
                    filtered[name] = self._morsel_cache[p][name]
                    continue

                # Build and cache the morsel
                mrsl_val = self._build_morsel(cookie)
                self._morsel_cache[p][name] = mrsl_val
                filtered[name] = mrsl_val

        return filtered

    def _build_morsel(self, cookie: Morsel[str]) -> Morsel[str]:
        """Build a morsel for sending, respecting quote_cookie setting."""
        if self._quote_cookie and cookie.coded_value and cookie.coded_value[0] == '"':
            # Value arrived already quoted; keep the original coded form intact.
            return preserve_morsel_with_coded_value(cookie)
        morsel: Morsel[str] = Morsel()
        if self._quote_cookie:
            value, coded_value = _SIMPLE_COOKIE.value_encode(cookie.value)
        else:
            coded_value = value = cookie.value
        # We use __setstate__ instead of the public set() API because it allows us to
        # bypass validation and set already validated state. This is more stable than
        # setting protected attributes directly and unlikely to change since it would
        # break pickling.
        morsel.__setstate__({"key": cookie.key, "value": value, "coded_value": coded_value})  # type: ignore[attr-defined]
        return morsel

    @staticmethod
    def _is_domain_match(domain: str, hostname: str) -> bool:
        """Implements domain matching adhering to RFC 6265."""
        if hostname == domain:
            return True

        if not hostname.endswith(domain):
            return False

        non_matching = hostname[: -len(domain)]

        if not non_matching.endswith("."):
            return False

        # A suffix match never applies when the host is an IP address.
        return not is_ip_address(hostname)

    @classmethod
    def _parse_date(cls, date_str: str) -> Optional[int]:
        """Implements date string parsing adhering to RFC 6265."""
        if not date_str:
            return None

        found_time = False
        found_day = False
        found_month = False
        found_year = False
        hour = minute = second = 0
        day = 0
        month = 0
        year = 0

        # Scan tokens left-to-right, taking the first match for each component.
        for token_match in cls.DATE_TOKENS_RE.finditer(date_str):

            token = token_match.group("token")

            if not found_time:
                time_match = cls.DATE_HMS_TIME_RE.match(token)
                if time_match:
                    found_time = True
                    hour, minute, second = (int(s) for s in time_match.groups())
                    continue

            if not found_day:
                day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                if day_match:
                    found_day = True
                    day = int(day_match.group())
                    continue

            if not found_month:
                month_match = cls.DATE_MONTH_RE.match(token)
                if month_match:
                    found_month = True
                    assert month_match.lastindex is not None
                    # The index of the matched alternative (1..12) is the month.
                    month = month_match.lastindex
                    continue

            if not found_year:
                year_match = cls.DATE_YEAR_RE.match(token)
                if year_match:
                    found_year = True
                    year = int(year_match.group())

        # Normalize two-digit years per RFC 6265 §5.1.1.
        if 70 <= year <= 99:
            year += 1900
        elif 0 <= year <= 69:
            year += 2000

        if False in (found_day, found_month, found_year, found_time):
            return None

        if not 1 <= day <= 31:
            return None

        if year < 1601 or hour > 23 or minute > 59 or second > 59:
            return None

        return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
class DummyCookieJar(AbstractCookieJar):
    """Implements a dummy cookie storage.

    It can be used with the ClientSession when no cookie processing is needed.
    """

    def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        super().__init__(loop=loop)

    def __iter__(self) -> "Iterator[Morsel[str]]":
        # Empty generator: the jar never holds anything.
        yield from ()

    def __len__(self) -> int:
        return 0

    @property
    def quote_cookie(self) -> bool:
        return True

    def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None:
        """No-op: nothing is ever stored."""

    def clear_domain(self, domain: str) -> None:
        """No-op: nothing is ever stored for any domain."""

    def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
        """Discard all incoming cookies."""

    def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
        """Always answer with an empty cookie set."""
        return SimpleCookie()

View File

@@ -0,0 +1,179 @@
import io
import warnings
from typing import Any, Iterable, List, Optional
from urllib.parse import urlencode
from multidict import MultiDict, MultiDictProxy
from . import hdrs, multipart, payload
from .helpers import guess_filename
from .payload import Payload
# The module's public API.
__all__ = ("FormData",)
class FormData:
    """Helper class for form body generation.

    Supports multipart/form-data and application/x-www-form-urlencoded.
    """

    def __init__(
        self,
        fields: Iterable[Any] = (),
        quote_fields: bool = True,
        charset: Optional[str] = None,
        *,
        default_to_multipart: bool = False,
    ) -> None:
        self._writer = multipart.MultipartWriter("form-data")
        self._fields: List[Any] = []
        # Starts per *default_to_multipart* and flips to True permanently once
        # any field requires multipart encoding (file objects, filenames,
        # explicit content types or transfer encodings).
        self._is_multipart = default_to_multipart
        self._quote_fields = quote_fields
        self._charset = charset

        # Normalize *fields* into a flat sequence for add_fields().
        if isinstance(fields, dict):
            fields = list(fields.items())
        elif not isinstance(fields, (list, tuple)):
            fields = (fields,)
        self.add_fields(*fields)

    @property
    def is_multipart(self) -> bool:
        # True when the accumulated fields force multipart/form-data encoding.
        return self._is_multipart

    def add_field(
        self,
        name: str,
        value: Any,
        *,
        content_type: Optional[str] = None,
        filename: Optional[str] = None,
        content_transfer_encoding: Optional[str] = None,
    ) -> None:
        """Register a single form field.

        :param name: field name used in the Content-Disposition header.
        :param value: field payload; IO objects (and, deprecated, raw bytes)
            become file fields and force multipart encoding.
        :param content_type: optional explicit Content-Type for the part.
        :param filename: optional filename; guessed from IO objects if omitted.
        :param content_transfer_encoding: deprecated, kept for compatibility.
        :raises TypeError: if *filename*, *content_type* or
            *content_transfer_encoding* is not a str.
        """
        if isinstance(value, io.IOBase):
            self._is_multipart = True
        elif isinstance(value, (bytes, bytearray, memoryview)):
            msg = (
                "In v4, passing bytes will no longer create a file field. "
                "Please explicitly use the filename parameter or pass a BytesIO object."
            )
            if filename is None and content_transfer_encoding is None:
                warnings.warn(msg, DeprecationWarning)
                filename = name

        type_options: MultiDict[str] = MultiDict({"name": name})
        if filename is not None and not isinstance(filename, str):
            raise TypeError("filename must be an instance of str. Got: %s" % filename)
        if filename is None and isinstance(value, io.IOBase):
            # Derive a filename from the IO object (falls back to *name*).
            filename = guess_filename(value, name)
        if filename is not None:
            type_options["filename"] = filename
            self._is_multipart = True

        headers = {}
        if content_type is not None:
            if not isinstance(content_type, str):
                raise TypeError(
                    "content_type must be an instance of str. Got: %s" % content_type
                )
            headers[hdrs.CONTENT_TYPE] = content_type
            self._is_multipart = True
        if content_transfer_encoding is not None:
            if not isinstance(content_transfer_encoding, str):
                raise TypeError(
                    "content_transfer_encoding must be an instance"
                    " of str. Got: %s" % content_transfer_encoding
                )
            msg = (
                "content_transfer_encoding is deprecated. "
                "To maintain compatibility with v4 please pass a BytesPayload."
            )
            warnings.warn(msg, DeprecationWarning)
            self._is_multipart = True

        self._fields.append((type_options, headers, value))

    def add_fields(self, *fields: Any) -> None:
        """Add many fields at once; accepts IO objects, multidicts and
        (name, value) pairs, flattening multidicts as needed."""
        to_add = list(fields)

        while to_add:
            rec = to_add.pop(0)

            if isinstance(rec, io.IOBase):
                k = guess_filename(rec, "unknown")
                self.add_field(k, rec)  # type: ignore[arg-type]

            elif isinstance(rec, (MultiDictProxy, MultiDict)):
                # Flatten multidicts into individual (name, value) pairs.
                to_add.extend(rec.items())

            elif isinstance(rec, (list, tuple)) and len(rec) == 2:
                k, fp = rec
                self.add_field(k, fp)

            else:
                raise TypeError(
                    "Only io.IOBase, multidict and (name, file) "
                    "pairs allowed, use .add_field() for passing "
                    "more complex parameters, got {!r}".format(rec)
                )

    def _gen_form_urlencoded(self) -> payload.BytesPayload:
        # form data (x-www-form-urlencoded)
        data = []
        for type_options, _, value in self._fields:
            data.append((type_options["name"], value))

        charset = self._charset if self._charset is not None else "utf-8"

        # Omit the charset parameter for the utf-8 default.
        if charset == "utf-8":
            content_type = "application/x-www-form-urlencoded"
        else:
            content_type = "application/x-www-form-urlencoded; charset=%s" % charset

        return payload.BytesPayload(
            urlencode(data, doseq=True, encoding=charset).encode(),
            content_type=content_type,
        )

    def _gen_form_data(self) -> multipart.MultipartWriter:
        """Encode a list of fields using the multipart/form-data MIME format"""
        for dispparams, headers, value in self._fields:
            try:
                # Explicit part Content-Type takes precedence over guessing.
                if hdrs.CONTENT_TYPE in headers:
                    part = payload.get_payload(
                        value,
                        content_type=headers[hdrs.CONTENT_TYPE],
                        headers=headers,
                        encoding=self._charset,
                    )
                else:
                    part = payload.get_payload(
                        value, headers=headers, encoding=self._charset
                    )
            except Exception as exc:
                raise TypeError(
                    "Can not serialize value type: %r\n "
                    "headers: %r\n value: %r" % (type(value), headers, value)
                ) from exc

            if dispparams:
                part.set_content_disposition(
                    "form-data", quote_fields=self._quote_fields, **dispparams
                )
                # FIXME cgi.FieldStorage doesn't likes body parts with
                # Content-Length which were sent via chunked transfer encoding
                assert part.headers is not None
                part.headers.popall(hdrs.CONTENT_LENGTH, None)

            self._writer.append_payload(part)

        # Fields are consumed; subsequent calls see an empty list.
        self._fields.clear()
        return self._writer

    def __call__(self) -> Payload:
        # Pick the encoding the accumulated fields require.
        if self._is_multipart:
            return self._gen_form_data()
        else:
            return self._gen_form_urlencoded()

View File

@@ -0,0 +1,121 @@
"""HTTP Headers constants."""
# After changing the file content call ./tools/gen.py
# to regenerate the headers parser
import itertools
from typing import Final, Set
from multidict import istr
# HTTP request method name constants, plus the "*" wildcard.
METH_ANY: Final[str] = "*"
METH_CONNECT: Final[str] = "CONNECT"
METH_HEAD: Final[str] = "HEAD"
METH_GET: Final[str] = "GET"
METH_DELETE: Final[str] = "DELETE"
METH_OPTIONS: Final[str] = "OPTIONS"
METH_PATCH: Final[str] = "PATCH"
METH_POST: Final[str] = "POST"
METH_PUT: Final[str] = "PUT"
METH_TRACE: Final[str] = "TRACE"

# All concrete methods (the METH_ANY wildcard is intentionally excluded).
METH_ALL: Final[Set[str]] = {
    METH_CONNECT,
    METH_HEAD,
    METH_GET,
    METH_DELETE,
    METH_OPTIONS,
    METH_PATCH,
    METH_POST,
    METH_PUT,
    METH_TRACE,
}

# Well-known header names as case-insensitive strings (multidict.istr).
ACCEPT: Final[istr] = istr("Accept")
ACCEPT_CHARSET: Final[istr] = istr("Accept-Charset")
ACCEPT_ENCODING: Final[istr] = istr("Accept-Encoding")
ACCEPT_LANGUAGE: Final[istr] = istr("Accept-Language")
ACCEPT_RANGES: Final[istr] = istr("Accept-Ranges")
ACCESS_CONTROL_MAX_AGE: Final[istr] = istr("Access-Control-Max-Age")
ACCESS_CONTROL_ALLOW_CREDENTIALS: Final[istr] = istr("Access-Control-Allow-Credentials")
ACCESS_CONTROL_ALLOW_HEADERS: Final[istr] = istr("Access-Control-Allow-Headers")
ACCESS_CONTROL_ALLOW_METHODS: Final[istr] = istr("Access-Control-Allow-Methods")
ACCESS_CONTROL_ALLOW_ORIGIN: Final[istr] = istr("Access-Control-Allow-Origin")
ACCESS_CONTROL_EXPOSE_HEADERS: Final[istr] = istr("Access-Control-Expose-Headers")
ACCESS_CONTROL_REQUEST_HEADERS: Final[istr] = istr("Access-Control-Request-Headers")
ACCESS_CONTROL_REQUEST_METHOD: Final[istr] = istr("Access-Control-Request-Method")
AGE: Final[istr] = istr("Age")
ALLOW: Final[istr] = istr("Allow")
AUTHORIZATION: Final[istr] = istr("Authorization")
CACHE_CONTROL: Final[istr] = istr("Cache-Control")
CONNECTION: Final[istr] = istr("Connection")
CONTENT_DISPOSITION: Final[istr] = istr("Content-Disposition")
CONTENT_ENCODING: Final[istr] = istr("Content-Encoding")
CONTENT_LANGUAGE: Final[istr] = istr("Content-Language")
CONTENT_LENGTH: Final[istr] = istr("Content-Length")
CONTENT_LOCATION: Final[istr] = istr("Content-Location")
CONTENT_MD5: Final[istr] = istr("Content-MD5")
CONTENT_RANGE: Final[istr] = istr("Content-Range")
CONTENT_TRANSFER_ENCODING: Final[istr] = istr("Content-Transfer-Encoding")
CONTENT_TYPE: Final[istr] = istr("Content-Type")
COOKIE: Final[istr] = istr("Cookie")
DATE: Final[istr] = istr("Date")
DESTINATION: Final[istr] = istr("Destination")
DIGEST: Final[istr] = istr("Digest")
ETAG: Final[istr] = istr("Etag")
EXPECT: Final[istr] = istr("Expect")
EXPIRES: Final[istr] = istr("Expires")
FORWARDED: Final[istr] = istr("Forwarded")
FROM: Final[istr] = istr("From")
HOST: Final[istr] = istr("Host")
IF_MATCH: Final[istr] = istr("If-Match")
IF_MODIFIED_SINCE: Final[istr] = istr("If-Modified-Since")
IF_NONE_MATCH: Final[istr] = istr("If-None-Match")
IF_RANGE: Final[istr] = istr("If-Range")
IF_UNMODIFIED_SINCE: Final[istr] = istr("If-Unmodified-Since")
KEEP_ALIVE: Final[istr] = istr("Keep-Alive")
LAST_EVENT_ID: Final[istr] = istr("Last-Event-ID")
LAST_MODIFIED: Final[istr] = istr("Last-Modified")
LINK: Final[istr] = istr("Link")
LOCATION: Final[istr] = istr("Location")
MAX_FORWARDS: Final[istr] = istr("Max-Forwards")
ORIGIN: Final[istr] = istr("Origin")
PRAGMA: Final[istr] = istr("Pragma")
PROXY_AUTHENTICATE: Final[istr] = istr("Proxy-Authenticate")
PROXY_AUTHORIZATION: Final[istr] = istr("Proxy-Authorization")
RANGE: Final[istr] = istr("Range")
REFERER: Final[istr] = istr("Referer")
RETRY_AFTER: Final[istr] = istr("Retry-After")
SEC_WEBSOCKET_ACCEPT: Final[istr] = istr("Sec-WebSocket-Accept")
SEC_WEBSOCKET_VERSION: Final[istr] = istr("Sec-WebSocket-Version")
SEC_WEBSOCKET_PROTOCOL: Final[istr] = istr("Sec-WebSocket-Protocol")
SEC_WEBSOCKET_EXTENSIONS: Final[istr] = istr("Sec-WebSocket-Extensions")
SEC_WEBSOCKET_KEY: Final[istr] = istr("Sec-WebSocket-Key")
SEC_WEBSOCKET_KEY1: Final[istr] = istr("Sec-WebSocket-Key1")
SERVER: Final[istr] = istr("Server")
SET_COOKIE: Final[istr] = istr("Set-Cookie")
TE: Final[istr] = istr("TE")
TRAILER: Final[istr] = istr("Trailer")
TRANSFER_ENCODING: Final[istr] = istr("Transfer-Encoding")
UPGRADE: Final[istr] = istr("Upgrade")
URI: Final[istr] = istr("URI")
USER_AGENT: Final[istr] = istr("User-Agent")
VARY: Final[istr] = istr("Vary")
VIA: Final[istr] = istr("Via")
WANT_DIGEST: Final[istr] = istr("Want-Digest")
WARNING: Final[istr] = istr("Warning")
WWW_AUTHENTICATE: Final[istr] = istr("WWW-Authenticate")
X_FORWARDED_FOR: Final[istr] = istr("X-Forwarded-For")
X_FORWARDED_HOST: Final[istr] = istr("X-Forwarded-Host")
X_FORWARDED_PROTO: Final[istr] = istr("X-Forwarded-Proto")

# These are the upper/lower case variants of the headers/methods
# Example: {'hOst', 'host', 'HoST', 'HOSt', 'hOsT', 'HosT', 'hoSt', ...}
METH_HEAD_ALL: Final = frozenset(
    map("".join, itertools.product(*zip(METH_HEAD.upper(), METH_HEAD.lower())))
)
METH_CONNECT_ALL: Final = frozenset(
    map("".join, itertools.product(*zip(METH_CONNECT.upper(), METH_CONNECT.lower())))
)
HOST_ALL: Final = frozenset(
    map("".join, itertools.product(*zip(HOST.upper(), HOST.lower())))
)

View File

@@ -0,0 +1,986 @@
"""Various helper functions"""
import asyncio
import base64
import binascii
import contextlib
import datetime
import enum
import functools
import inspect
import netrc
import os
import platform
import re
import sys
import time
import weakref
from collections import namedtuple
from contextlib import suppress
from email.message import EmailMessage
from email.parser import HeaderParser
from email.policy import HTTP
from email.utils import parsedate
from math import ceil
from pathlib import Path
from types import MappingProxyType, TracebackType
from typing import (
Any,
Callable,
ContextManager,
Dict,
Generator,
Generic,
Iterable,
Iterator,
List,
Mapping,
Optional,
Protocol,
Tuple,
Type,
TypeVar,
Union,
get_args,
overload,
)
from urllib.parse import quote
from urllib.request import getproxies, proxy_bypass
import attr
from multidict import MultiDict, MultiDictProxy, MultiMapping
from propcache.api import under_cached_property as reify
from yarl import URL
from . import hdrs
from .log import client_logger
if sys.version_info >= (3, 11):
import asyncio as async_timeout
else:
import async_timeout
__all__ = ("BasicAuth", "ChainMapProxy", "ETag", "reify")

# Platform detection flags.
IS_MACOS = platform.system() == "Darwin"
IS_WINDOWS = platform.system() == "Windows"

# Python version feature flags.
PY_310 = sys.version_info >= (3, 10)
PY_311 = sys.version_info >= (3, 11)

_T = TypeVar("_T")
_S = TypeVar("_S")

# Module-wide sentinel for "argument not supplied" (distinct from None).
_SENTINEL = enum.Enum("_SENTINEL", "sentinel")
sentinel = _SENTINEL.sentinel

# Set via the environment to disable the C accelerator extensions.
NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))

# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
EMPTY_BODY_STATUS_CODES = frozenset((204, 304, *range(100, 200)))
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2
EMPTY_BODY_METHODS = hdrs.METH_HEAD_ALL

DEBUG = sys.flags.dev_mode or (
    not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)

# HTTP character classes used for header token classification.
CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# NOTE(review): chr(9) is a member of both CTL and SEPARATORS, so the double
# symmetric difference re-admits the tab character into TOKEN — confirm this
# is intended rather than plain set subtraction.
TOKEN = CHAR ^ CTL ^ SEPARATORS
class noop:
    """Awaitable that completes after a single scheduler yield, producing None."""

    def __await__(self) -> Generator[None, None, None]:
        yield None
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper."""

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        # Guard clauses: reject values that cannot be encoded per RFC 1945.
        if login is None:
            raise ValueError("None is not allowed as login value")
        if password is None:
            raise ValueError("None is not allowed as password value")
        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')
        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header."""
        split = auth_header.split(" ", 1)
        if len(split) != 2:
            raise ValueError("Could not parse authorization header.")
        auth_type, encoded_credentials = split

        if auth_type.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % auth_type)

        try:
            decoded = base64.b64decode(
                encoded_credentials.encode("ascii"), validate=True
            ).decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        # RFC 2617 HTTP Authentication
        # https://www.ietf.org/rfc/rfc2617.txt
        # the colon must be present, but the username and password may be
        # otherwise blank.
        username, sep, password = decoded.partition(":")
        if not sep:
            raise ValueError("Invalid credentials.")

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]:
        """Create BasicAuth from url."""
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        # raw_user/raw_password are checked first because yarl has likely
        # already parsed and cached them from the netloc.
        if url.raw_user is None and url.raw_password is None:
            return None
        return cls(url.user or "", url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Encode credentials."""
        raw = f"{self.login}:{self.password}".encode(self.encoding)
        token = base64.b64encode(raw).decode(self.encoding)
        return "Basic %s" % token
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Remove user and password from URL if present and return BasicAuth object."""
    # raw_user/raw_password are checked first because yarl has likely
    # already parsed and cached them from the netloc.
    if url.raw_user is None and url.raw_password is None:
        return url, None
    credentials = BasicAuth(url.user or "", url.password or "")
    return url.with_user(None), credentials
def netrc_from_env() -> Optional[netrc.netrc]:
    """Load netrc from file.

    Attempt to load it from the path specified by the env-var
    NETRC or in the default location in the user's home directory.

    Returns None if it couldn't be found or fails to parse.
    """
    netrc_env = os.environ.get("NETRC")

    if netrc_env is None:
        # Fall back to the platform default location in the home directory.
        try:
            home_dir = Path.home()
        except RuntimeError as e:  # pragma: no cover
            # if pathlib can't resolve home, it may raise a RuntimeError
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                e,
            )
            return None
        netrc_path = home_dir / ("_netrc" if IS_WINDOWS else ".netrc")
    else:
        netrc_path = Path(netrc_env)

    try:
        return netrc.netrc(str(netrc_path))
    except netrc.NetrcParseError as e:
        client_logger.warning("Could not parse .netrc file: %s", e)
    except OSError as e:
        # we couldn't read the file (doesn't exist, permissions, etc.)
        # only warn if the environment wanted us to load it,
        # or it appears like the default file does actually exist
        should_warn = bool(netrc_env)
        if not should_warn:
            with contextlib.suppress(OSError):
                should_warn = netrc_path.is_file()
        if should_warn:
            client_logger.warning("Could not read .netrc file: %s", e)

    return None
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ProxyInfo:
    """A proxy endpoint URL paired with its optional basic-auth credentials."""

    proxy: URL
    proxy_auth: Optional[BasicAuth]
def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth:
    """
    Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``.

    :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no
            entry is found for the ``host``.
    """
    if netrc_obj is None:
        raise LookupError("No .netrc file found")

    auth_from_netrc = netrc_obj.authenticators(host)
    if auth_from_netrc is None:
        raise LookupError(f"No entry for {host!s} found in the `.netrc` file.")
    login, account, password = auth_from_netrc

    # TODO(PY311): username = login or account
    # Up to Python 3.10, ``account`` may be None when unspecified while
    # ``login`` is an empty string; from 3.11 both default to "".
    if login or account is None:
        username = login
    else:
        username = account

    # TODO(PY311): password is never None from 3.11 on.
    return BasicAuth(username, password or "")
def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Build a scheme -> ProxyInfo mapping from the environment.

    Only http/https/ws/wss proxy settings are considered; https and wss
    proxies are unsupported and skipped with a warning.
    """
    netrc_obj = netrc_from_env()
    ret: Dict[str, ProxyInfo] = {}
    for proto, raw_url in getproxies().items():
        if proto not in ("http", "https", "ws", "wss"):
            continue
        proxy, auth = strip_auth_from_url(URL(raw_url))
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue
        # Fall back to .netrc credentials when the proxy URL carried none.
        if netrc_obj and auth is None and proxy.host is not None:
            try:
                auth = basicauth_from_netrc(netrc_obj, proxy.host)
            except LookupError:
                auth = None
        ret[proto] = ProxyInfo(proxy, auth)
    return ret
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Get a permitted proxy for the given URL from the env."""
    host = url.host
    if host is not None and proxy_bypass(host):
        raise LookupError(f"Proxying is disallowed for `{host!r}`")

    try:
        proxy_info = proxies_from_env()[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return proxy_info.proxy, proxy_info.proxy_auth
@attr.s(auto_attribs=True, frozen=True, slots=True)
class MimeType:
    """Parsed MIME type as produced by parse_mimetype()."""

    # Major type, e.g. "text" in "text/html".
    type: str
    # Subtype, e.g. "html" in "text/html".
    subtype: str
    # Suffix after "+", e.g. "json" in "application/hal+json" ("" if absent).
    suffix: str
    # Remaining ";"-separated parameters, e.g. {"charset": "utf-8"}.
    parameters: "MultiDictProxy[str]"
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Parses a MIME type into its components.

    mimetype is a MIME type string.

    Returns a MimeType object.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters={'charset': 'utf-8'})

    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    # First chunk is the type/subtype, the rest are key=value parameters.
    fulltype, *raw_params = mimetype.split(";")
    params: MultiDict[str] = MultiDict()
    for raw in raw_params:
        if not raw:
            continue
        key, _, value = raw.partition("=")
        params.add(key.lower().strip(), value.strip(' "'))

    fulltype = fulltype.strip().lower()
    if fulltype == "*":
        # Bare "*" is shorthand for "*/*".
        fulltype = "*/*"

    mtype, _, stype = fulltype.partition("/")
    stype, _, suffix = stype.partition("+")

    return MimeType(
        type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
    )
class EnsureOctetStream(EmailMessage):
    """EmailMessage whose fallback content type is application/octet-stream."""

    def __init__(self) -> None:
        super().__init__()
        # https://www.rfc-editor.org/rfc/rfc9110#section-8.3-5
        self.set_default_type("application/octet-stream")

    def get_content_type(self) -> str:
        """Re-implementation from Message

        Returns application/octet-stream in place of plain/text when
        value is wrong.

        The way this class is used guarantees that content-type will
        be present so simplify the checks wrt to the base implementation.
        """
        raw = self.get("content-type", "").lower()
        # Mirror _splitparam from the stdlib: keep only the type/subtype part.
        ctype = raw.partition(";")[0].strip()
        # A valid media type contains exactly one "/"; otherwise fall back.
        return ctype if ctype.count("/") == 1 else self.get_default_type()
@functools.lru_cache(maxsize=56)
def parse_content_type(raw: str) -> Tuple[str, MappingProxyType[str, str]]:
    """Parse Content-Type header.

    Returns a tuple of the parsed content type and a
    MappingProxyType of parameters. The default returned value
    is `application/octet-stream`
    """
    msg = HeaderParser(EnsureOctetStream, policy=HTTP).parsestr(f"Content-Type: {raw}")
    params = msg.get_params(())
    # get_params()[0] repeats the content type itself; keep only the rest.
    return msg.get_content_type(), MappingProxyType(dict(params[1:]))
def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Best-effort file name for *obj* (e.g. an open file), else *default*.

    Pseudo-names such as ``<stdin>`` are rejected.
    """
    name = getattr(obj, "name", None)
    if not name or not isinstance(name, str):
        return default
    if name.startswith("<") or name.endswith(">"):
        return default
    return Path(name).name
# Characters that are NOT qtext (RFC 5322) and therefore must be
# backslash-escaped inside a quoted-string (see quoted_string below).
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Printable US-ASCII (0x20-0x7E) plus TAB: the characters accepted as
# quoted-string content.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}
def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Format content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Content must be in usascii or
    a ValueError is raised.
    """
    if not set(content) < QCONTENT:
        raise ValueError(f"bad content for quoted-string {content!r}")
    # Escape every non-qtext character with a backslash.
    return not_qtext_re.sub(lambda m: "\\" + m.group(0), content)
def content_disposition_header(
    disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
    """Sets ``Content-Disposition`` header for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7579 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set to quote_fields to False if recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    :raises ValueError: if the disposition type or a parameter name
        contains characters outside the token set.
    """
    if not disptype or not (TOKEN > set(disptype)):
        raise ValueError(f"bad content disposition type {disptype!r}")

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not (TOKEN > set(key)):
                raise ValueError(f"bad content disposition parameter {key!r}={val!r}")
            if quote_fields:
                if key.lower() == "filename":
                    # filename is always percent-encoded and double-quoted
                    qval = quote(val, "", encoding=_charset)
                    lparams.append((key, '"%s"' % qval))
                else:
                    try:
                        # Try the plain RFC 5322 quoted-string form first.
                        qval = quoted_string(val)
                    except ValueError:
                        # Non-ASCII value: use the extended
                        # key*=charset''percent-encoded-value notation.
                        qval = "".join(
                            (_charset, "''", quote(val, "", encoding=_charset))
                        )
                        lparams.append((key + "*", qval))
                    else:
                        lparams.append((key, '"%s"' % qval))
            else:
                # 8-bit mode: only escape backslashes and double quotes.
                qval = val.replace("\\", "\\\\").replace('"', '\\"')
                lparams.append((key, '"%s"' % qval))
        sparams = "; ".join("=".join(pair) for pair in lparams)
        value = "; ".join((value, sparams))
    return value
def is_ip_address(host: Optional[str]) -> bool:
    """Check if host looks like an IP Address.

    This check is only meant as a heuristic to ensure that
    a host is not a domain name.
    """
    if not host:
        return False
    if ":" in host:
        # Only IPv6 literals contain a colon.
        return True
    # IPv4 candidates consist of digits and dots only.
    return host.replace(".", "").isdigit()
# Per-second cache for rfc822_formatted_time(): the last whole second that
# was formatted, and its formatted string.
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""
def rfc822_formatted_time() -> str:
    """Return the current UTC time as an RFC 822 date string, cached per second."""
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Weekday and month names for HTTP date/time formatting;
    # always English!
    # Tuples are constants stored in codeobject!
    weekdays = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    months = (
        "",  # Dummy so we can use 1-based month numbers
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    )

    year, month, day, hh, mm, ss, wd, *_rest = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekdays[wd],
        day,
        months[month],
        year,
        hh,
        mm,
        ss,
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime
def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
ref, name = info
ob = ref()
if ob is not None:
with suppress(Exception):
getattr(ob, name)()
def weakref_handle(
    ob: object,
    name: str,
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule calling ``ob.name()`` after *timeout*, holding only a weakref.

    Returns the timer handle, or None when no positive timeout was given.
    """
    if timeout is None or timeout <= 0:
        return None
    when = loop.time() + timeout
    # Round long deadlines up to a whole second so many timers can share
    # the same event-loop slot.
    if timeout >= timeout_ceil_threshold:
        when = ceil(when)
    return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name))
def call_later(
    cb: Callable[[], Any],
    timeout: float,
    loop: asyncio.AbstractEventLoop,
    timeout_ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule *cb* on *loop* after *timeout* seconds; None when disabled."""
    if timeout is not None and timeout > 0:
        when = calculate_timeout_when(loop.time(), timeout, timeout_ceil_threshold)
        return loop.call_at(when, cb)
    return None
def calculate_timeout_when(
    loop_time: float,
    timeout: float,
    timeout_ceiling_threshold: float,
) -> float:
    """Calculate when to execute a timeout.

    Deadlines for timeouts above the ceiling threshold are rounded up
    to a whole second so nearby timers can be coalesced.
    """
    when = loop_time + timeout
    return ceil(when) if timeout > timeout_ceiling_threshold else when
class TimeoutHandle:
    """Timeout handle"""

    __slots__ = ("_timeout", "_loop", "_ceil_threshold", "_callbacks")

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        timeout: Optional[float],
        ceil_threshold: float = 5,
    ) -> None:
        self._loop = loop
        self._timeout = timeout
        self._ceil_threshold = ceil_threshold
        # Registered (callback, args, kwargs) triples fired on timeout.
        self._callbacks: List[
            Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
        ] = []

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Add a callback to run when the timeout fires."""
        self._callbacks.append((callback, args, kwargs))

    def close(self) -> None:
        """Drop all registered callbacks without firing them."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.TimerHandle]:
        """Arm the timer; returns None when no positive timeout is set."""
        timeout = self._timeout
        if timeout is None or timeout <= 0:
            return None
        when = self._loop.time() + timeout
        # Round long deadlines up to a whole second for timer coalescing.
        if timeout >= self._ceil_threshold:
            when = ceil(when)
        return self._loop.call_at(when, self.__call__)

    def timer(self) -> "BaseTimerContext":
        """Return a timer context; a no-op one when no timeout is set."""
        if self._timeout is None or self._timeout <= 0:
            return TimerNoop()
        ctx = TimerContext(self._loop)
        self.register(ctx.timeout)
        return ctx

    def __call__(self) -> None:
        """Fire every callback (ignoring their exceptions), then forget them."""
        for cb, args, kwargs in self._callbacks:
            with suppress(Exception):
                cb(*args, **kwargs)
        self._callbacks.clear()
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Base class for timeout timer contexts."""

    __slots__ = ()

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timeout has been exceeded."""
class TimerNoop(BaseTimerContext):
    """Timer context that does nothing and never times out."""

    __slots__ = ()

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        return
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager"""

    __slots__ = ("_loop", "_tasks", "_cancelled", "_cancelling")

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        # Stack of tasks currently inside this context (see __enter__/__exit__).
        self._tasks: List[asyncio.Task[Any]] = []
        # True once timeout() has fired and cancelled the tasks.
        self._cancelled = False
        # Task.cancelling() snapshot taken on __enter__ (used on 3.11+ only).
        self._cancelling = 0

    def assert_timeout(self) -> None:
        """Raise TimeoutError if timer has already been cancelled."""
        if self._cancelled:
            raise asyncio.TimeoutError from None

    def __enter__(self) -> BaseTimerContext:
        """Register the current task; raise immediately if already timed out."""
        task = asyncio.current_task(loop=self._loop)
        if task is None:
            raise RuntimeError("Timeout context manager should be used inside a task")

        if sys.version_info >= (3, 11):
            # Remember if the task was already cancelling
            # so when we __exit__ we can decide if we should
            # raise asyncio.TimeoutError or let the cancellation propagate
            self._cancelling = task.cancelling()

        if self._cancelled:
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        """Translate our own cancellation into asyncio.TimeoutError."""
        enter_task: Optional[asyncio.Task[Any]] = None
        if self._tasks:
            enter_task = self._tasks.pop()

        if exc_type is asyncio.CancelledError and self._cancelled:
            assert enter_task is not None
            # The timeout was hit, and the task was cancelled
            # so we need to uncancel the last task that entered the context manager
            # since the cancellation should not leak out of the context manager
            if sys.version_info >= (3, 11):
                # If the task was already cancelling don't raise
                # asyncio.TimeoutError and instead return None
                # to allow the cancellation to propagate
                if enter_task.uncancel() > self._cancelling:
                    return None
            raise asyncio.TimeoutError from exc_val
        return None

    def timeout(self) -> None:
        """Cancel every task inside the context and mark the timer expired."""
        if not self._cancelled:
            for task in set(self._tasks):
                task.cancel()

            self._cancelled = True
def ceil_timeout(
    delay: Optional[float], ceil_threshold: float = 5
) -> async_timeout.Timeout:
    """Return an async_timeout context; long delays are ceiled to whole seconds."""
    if delay is None or delay <= 0:
        return async_timeout.timeout(None)

    deadline = asyncio.get_running_loop().time() + delay
    if delay > ceil_threshold:
        deadline = ceil(deadline)
    return async_timeout.timeout_at(deadline)
class HeadersMixin:
    """Mixin for handling headers."""

    ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"])

    _headers: MultiMapping[str]

    # Cached parse results for the Content-Type header; _stored_content_type
    # remembers the raw header value the cache was computed from.
    _content_type: Optional[str] = None
    _content_dict: Optional[Dict[str, str]] = None
    _stored_content_type: Union[str, None, _SENTINEL] = sentinel

    def _parse_content_type(self, raw: Optional[str]) -> None:
        """Parse *raw* Content-Type and refresh the cached values."""
        self._stored_content_type = raw
        if raw is None:
            # default value according to RFC 2616
            self._content_type = "application/octet-stream"
            self._content_dict = {}
        else:
            content_type, content_mapping_proxy = parse_content_type(raw)
            self._content_type = content_type
            # _content_dict needs to be mutable so we can update it
            self._content_dict = content_mapping_proxy.copy()

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != raw:
            # Header changed since the last parse; refresh the cache.
            self._parse_content_type(raw)
        assert self._content_type is not None
        return self._content_type

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)
        if self._stored_content_type != raw:
            # Header changed since the last parse; refresh the cache.
            self._parse_content_type(raw)
        assert self._content_dict is not None
        return self._content_dict.get("charset")

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header."""
        content_length = self._headers.get(hdrs.CONTENT_LENGTH)
        return None if content_length is None else int(content_length)
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Store *result* on *fut* unless the future is already done."""
    if fut.done():
        return
    fut.set_result(result)
# Sentinel meaning "no exception cause was provided" in set_exception().
_EXC_SENTINEL = BaseException()
class ErrorableProtocol(Protocol):
    """Structural type for objects accepting ``set_exception(exc, cause)``."""

    def set_exception(
        self,
        exc: BaseException,
        exc_cause: BaseException = ...,
    ) -> None: ...  # pragma: no cover
def set_exception(
    fut: "asyncio.Future[_T] | ErrorableProtocol",
    exc: BaseException,
    exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
    """Set future exception.

    If the future is marked as complete, this function is a no-op.

    :param exc_cause: An exception that is a direct cause of ``exc``.
                      Only set if provided.
    """
    if asyncio.isfuture(fut) and fut.done():
        return

    # Attach the cause only when one was explicitly supplied and it is
    # not the exception itself.
    if exc_cause is not _EXC_SENTINEL and exc_cause is not exc:
        exc.__cause__ = exc_cause

    fut.set_exception(exc)
@functools.total_ordering
class AppKey(Generic[_T]):
    """Keys for static typing support in Application."""

    __slots__ = ("_name", "_t", "__orig_class__")

    # This may be set by Python when instantiating with a generic type. We need to
    # support this, in order to support types that are not concrete classes,
    # like Iterable, which can't be passed as the second parameter to __init__.
    __orig_class__: Type[object]

    def __init__(self, name: str, t: Optional[Type[_T]] = None):
        # Prefix with module name to help deduplicate key names.
        frame = inspect.currentframe()
        while frame:
            if frame.f_code.co_name == "<module>":
                module: str = frame.f_globals["__name__"]
                break
            frame = frame.f_back
        else:
            # No "<module>" frame was found (e.g. instantiated via exec() or
            # from an embedded interpreter).  Previously this left ``module``
            # unbound and raised NameError below; fall back instead.
            module = "<unknown>"

        self._name = module + "." + name
        self._t = t

    def __lt__(self, other: object) -> bool:
        if isinstance(other, AppKey):
            return self._name < other._name
        return True  # Order AppKey above other types.

    def __repr__(self) -> str:
        t = self._t
        if t is None:
            with suppress(AttributeError):
                # Set to type arg.
                t = get_args(self.__orig_class__)[0]

        if t is None:
            t_repr = "<<Unknown>>"
        elif isinstance(t, type):
            if t.__module__ == "builtins":
                t_repr = t.__qualname__
            else:
                t_repr = f"{t.__module__}.{t.__qualname__}"
        else:
            t_repr = repr(t)
        return f"<AppKey({self._name}, type={t_repr})>"
class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]):
    """Read-only view over a sequence of mappings, searched in order."""

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None:
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        # This proxy is deliberately final.
        raise TypeError(
            "Inheritance class {} from ChainMapProxy "
            "is forbidden".format(cls.__name__)
        )

    @overload  # type: ignore[override]
    def __getitem__(self, key: AppKey[_T]) -> _T: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...

    def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any:
        # First mapping that yields the key wins; EAFP keeps the semantics
        # of mappings whose __getitem__ has side effects.
        for mapping in self._maps:
            try:
                return mapping[key]
            except KeyError:
                pass
        raise KeyError(key)

    @overload  # type: ignore[override]
    def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: ...

    @overload
    def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: ...

    @overload
    def get(self, key: str, default: Any = ...) -> Any: ...

    def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any:
        """Return self[key] when present in any map, else *default*."""
        try:
            return self[key]
        except KeyError:
            return default

    def __len__(self) -> int:
        # reuses stored hash values if possible
        return len(set().union(*self._maps))

    def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]:
        merged: Dict[Union[str, AppKey[Any]], Any] = {}
        # Fill from the back so earlier maps take priority; dict.update
        # reuses stored hash values if possible.
        for mapping in reversed(self._maps):
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        return any(key in m for m in self._maps)

    def __bool__(self) -> bool:
        return any(self._maps)

    def __repr__(self) -> str:
        return "ChainMapProxy({})".format(", ".join(map(repr, self._maps)))
# https://tools.ietf.org/html/rfc7232#section-2.3
# etagc: the characters allowed inside an entity-tag value.
_ETAGC = r"[!\x23-\x7E\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
# A single quoted etag, optionally weak ("W/" prefix).
_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# Comma-separated list of quoted etags; the final (.) branch matches any
# other character, flagging a malformed list.
LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")
# Wildcard etag value.
ETAG_ANY = "*"
@attr.s(auto_attribs=True, frozen=True, slots=True)
class ETag:
    """Entity tag per RFC 7232 section 2.3."""

    # The opaque tag value, without surrounding quotes.
    value: str
    # True for weak validators (the "W/" prefix).
    is_weak: bool = False
def validate_etag_value(value: str) -> None:
    """Raise ValueError when *value* is neither ``*`` nor a valid etag."""
    if value == ETAG_ANY:
        return
    if not _ETAGC_RE.fullmatch(value):
        raise ValueError(
            f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
        )
def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
    """Process a date string, return a datetime object"""
    if date_str is None:
        return None
    timetuple = parsedate(date_str)
    if timetuple is None:
        return None
    with suppress(ValueError):
        # Out-of-range components make the datetime constructor raise.
        return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
    return None
@functools.lru_cache
def must_be_empty_body(method: str, code: int) -> bool:
    """Check if a request must return an empty body."""
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    if method in EMPTY_BODY_METHODS:
        return True
    # Any 2xx response to a CONNECT request has no body.
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL
def should_remove_content_length(method: str, code: int) -> bool:
    """Check if a Content-Length header should be removed.

    This should always be a subset of must_be_empty_body
    """
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8
    # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4
    if code in EMPTY_BODY_STATUS_CODES:
        return True
    return 200 <= code < 300 and method in hdrs.METH_CONNECT_ALL

View File

@@ -0,0 +1,72 @@
import sys
from http import HTTPStatus
from typing import Mapping, Tuple
from . import __version__
from .http_exceptions import HttpProcessingError as HttpProcessingError
from .http_parser import (
HeadersParser as HeadersParser,
HttpParser as HttpParser,
HttpRequestParser as HttpRequestParser,
HttpResponseParser as HttpResponseParser,
RawRequestMessage as RawRequestMessage,
RawResponseMessage as RawResponseMessage,
)
from .http_websocket import (
WS_CLOSED_MESSAGE as WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE as WS_CLOSING_MESSAGE,
WS_KEY as WS_KEY,
WebSocketError as WebSocketError,
WebSocketReader as WebSocketReader,
WebSocketWriter as WebSocketWriter,
WSCloseCode as WSCloseCode,
WSMessage as WSMessage,
WSMsgType as WSMsgType,
ws_ext_gen as ws_ext_gen,
ws_ext_parse as ws_ext_parse,
)
from .http_writer import (
HttpVersion as HttpVersion,
HttpVersion10 as HttpVersion10,
HttpVersion11 as HttpVersion11,
StreamWriter as StreamWriter,
)
# Public API of this module; names are grouped by the submodule they
# are re-exported from.
__all__ = (
    "HttpProcessingError",
    "RESPONSES",
    "SERVER_SOFTWARE",
    # .http_writer
    "StreamWriter",
    "HttpVersion",
    "HttpVersion10",
    "HttpVersion11",
    # .http_parser
    "HeadersParser",
    "HttpParser",
    "HttpRequestParser",
    "HttpResponseParser",
    "RawRequestMessage",
    "RawResponseMessage",
    # .http_websocket
    "WS_CLOSED_MESSAGE",
    "WS_CLOSING_MESSAGE",
    "WS_KEY",
    "WebSocketReader",
    "WebSocketWriter",
    "ws_ext_gen",
    "ws_ext_parse",
    "WSMessage",
    "WebSocketError",
    "WSMsgType",
    "WSCloseCode",
)
# Default software token, e.g. "Python/3.12 aiohttp/3.13.3".
SERVER_SOFTWARE: str = "Python/{0[0]}.{0[1]} aiohttp/{1}".format(
    sys.version_info, __version__
)
# HTTP status code -> (reason phrase, description), built from the
# stdlib HTTPStatus enum.
RESPONSES: Mapping[int, Tuple[str, str]] = {
    v: (v.phrase, v.description) for v in HTTPStatus.__members__.values()
}

Some files were not shown because too many files have changed in this diff Show More