)
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
-from .models import URL, Origin, Request, Response
+from .models import URL, Headers, Origin, Request, Response
from .streams import BaseReader, BaseWriter, Protocol, Reader, Writer, connect
from .sync import SyncClient, SyncConnectionPool
method: str,
url: typing.Union[str, URL],
*,
- headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+ headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
**options: typing.Any,
) -> Response:
return await self.dispatch.send(request, **options)
async def close(self) -> None:
- self.dispatch.close()
+ await self.dispatch.close()
url: typing.Union[str, URL],
*,
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
- headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+ headers: typing.List[typing.Tuple[bytes, bytes]] = [],
stream: bool = False,
allow_redirects: bool = True,
ssl: typing.Optional[SSLConfig] = None,
self,
url: typing.Union[str, URL],
*,
- headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+ headers: typing.List[typing.Tuple[bytes, bytes]] = [],
stream: bool = False,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
url: typing.Union[str, URL],
*,
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
- headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+ headers: typing.List[typing.Tuple[bytes, bytes]] = [],
stream: bool = False,
ssl: typing.Optional[SSLConfig] = None,
timeout: typing.Optional[TimeoutConfig] = None,
return await self.dispatch.send(request, **options)
async def close(self) -> None:
- self.dispatch.close()
+ await self.dispatch.close()
# Maps content-coding names (as found in Content-Encoding) to decoder classes.
SUPPORTED_DECODERS = {
    "identity": IdentityDecoder,
    "deflate": DeflateDecoder,
    "gzip": GZipDecoder,
    "br": BrotliDecoder,
}
# Brotli support is optional; drop "br" if the brotli package is unavailable.
if brotli is None:
    SUPPORTED_DECODERS.pop("br")  # pragma: nocover
# Default value advertised in outgoing "accept-encoding" headers: every
# supported coding except the no-op "identity".
ACCEPT_ENCODING = ", ".join(
    [key for key in SUPPORTED_DECODERS.keys() if key != "identity"]
)
Attempted to read or stream response content, but the request has been
closed without loading the body.
"""
+
+
class InvalidURL(Exception):
    """
    The URL was somehow invalid — e.g. a malformed percent-escape sequence.
    """
# Start sending the request.
method = request.method.encode()
target = request.url.full_path
- headers = request.headers
+ headers = request.headers.raw
event = h11.Request(method=method, target=target, headers=headers)
await self._send_event(event, timeout)
(b":authority", request.url.hostname.encode()),
(b":scheme", request.url.scheme.encode()),
(b":path", request.url.full_path.encode()),
- ] + request.headers
+ ] + request.headers.raw
self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
def query(self) -> str:
return self.components.query
    @property
    def fragment(self) -> str:
        """The fragment component of the parsed URL ('' when absent)."""
        return self.components.fragment
+
@property
def hostname(self) -> str:
return self.components.hostname
return hash((self.is_ssl, self.hostname, self.port))
class Headers(typing.MutableMapping[str, str]):
    """
    A case-insensitive multidict of HTTP headers.

    Entries are held internally as `(key, value)` byte pairs, with keys
    normalised to lower case on the way in. All `str`-level access encodes
    and decodes using latin-1.
    """

    def __init__(self, headers: typing.List[typing.Tuple[bytes, bytes]]) -> None:
        self._list = [(key.lower(), value) for key, value in headers]

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        """The underlying `(key, value)` byte pairs, with lower-cased keys."""
        return self._list

    def keys(self) -> typing.List[str]:  # type: ignore
        return [key.decode("latin-1") for key, _ in self._list]

    def values(self) -> typing.List[str]:  # type: ignore
        return [value.decode("latin-1") for _, value in self._list]

    def items(self) -> typing.List[typing.Tuple[str, str]]:  # type: ignore
        return [
            (key.decode("latin-1"), value.decode("latin-1"))
            for key, value in self._list
        ]

    def get(self, key: str, default: typing.Any = None) -> typing.Any:
        """Return the first value for `key`, or `default` when absent."""
        try:
            return self[key]
        except KeyError:
            return default

    def getlist(self, key: str) -> typing.List[str]:
        """Return every value for `key`, in insertion order."""
        target = key.lower().encode("latin-1")
        return [
            value.decode("latin-1") for item, value in self._list if item == target
        ]

    def __getitem__(self, key: str) -> str:
        target = key.lower().encode("latin-1")
        for item, value in self._list:
            if item == target:
                return value.decode("latin-1")
        raise KeyError(key)

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        new_key = key.lower().encode("latin-1")
        new_value = value.encode("latin-1")

        matches = [
            index for index, (item, _) in enumerate(self._list) if item == new_key
        ]
        if not matches:
            self._list.append((new_key, new_value))
            return

        # Overwrite the first occurrence in place, then drop the duplicates.
        self._list[matches[0]] = (new_key, new_value)
        for index in reversed(matches[1:]):
            del self._list[index]

    def __delitem__(self, key: str) -> None:
        """
        Remove every entry for `key`. A missing key is silently ignored.
        """
        target = key.lower().encode("latin-1")
        matches = [
            index for index, (item, _) in enumerate(self._list) if item == target
        ]
        for index in reversed(matches):
            del self._list[index]

    def __contains__(self, key: typing.Any) -> bool:
        target = key.lower().encode("latin-1")
        return any(item == target for item, _ in self._list)

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        # Order-insensitive comparison against other Headers instances only.
        return isinstance(other, Headers) and sorted(self._list) == sorted(
            other._list
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        as_dict = dict(self.items())
        if len(as_dict) == len(self):
            return f"{class_name}({as_dict!r})"
        return f"{class_name}(raw={self.raw!r})"
+
+
class Request:
def __init__(
self,
method: str,
url: typing.Union[str, URL],
*,
- headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+ headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
):
self.method = method.upper()
self.url = URL(url) if isinstance(url, str) else url
- self.headers = list(headers)
if isinstance(body, bytes):
self.is_streaming = False
self.body = body
else:
self.is_streaming = True
self.body_aiter = body
- self.headers = self._auto_headers() + self.headers
+ self.headers = self.build_headers(headers)
- def _auto_headers(self) -> typing.List[typing.Tuple[bytes, bytes]]:
+ def build_headers(
+ self, init_headers: typing.List[typing.Tuple[bytes, bytes]]
+ ) -> Headers:
has_host = False
has_content_length = False
has_accept_encoding = False
- for header, value in self.headers:
+ for header, value in init_headers:
header = header.strip().lower()
if header == b"host":
has_host = True
elif header == b"accept-encoding":
has_accept_encoding = True
- headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
+ auto_headers = [] # type: typing.List[typing.Tuple[bytes, bytes]]
if not has_host:
- headers.append((b"host", self.url.netloc.encode("ascii")))
+ auto_headers.append((b"host", self.url.netloc.encode("ascii")))
if not has_content_length:
if self.is_streaming:
- headers.append((b"transfer-encoding", b"chunked"))
+ auto_headers.append((b"transfer-encoding", b"chunked"))
elif self.body:
content_length = str(len(self.body)).encode()
- headers.append((b"content-length", content_length))
+ auto_headers.append((b"content-length", content_length))
if not has_accept_encoding:
- headers.append((b"accept-encoding", ACCEPT_ENCODING))
+ auto_headers.append((b"accept-encoding", ACCEPT_ENCODING.encode()))
- return headers
+ return Headers(auto_headers + init_headers)
async def stream(self) -> typing.AsyncIterator[bytes]:
if self.is_streaming:
*,
reason: typing.Optional[str] = None,
protocol: typing.Optional[str] = None,
- headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+ headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
on_close: typing.Callable = None,
):
else:
self.reason = reason
self.protocol = protocol
- self.headers = list(headers)
+ self.headers = Headers(headers)
self.on_close = on_close
self.is_closed = False
self.is_streamed = False
decoders = [] # type: typing.List[Decoder]
- for header, value in self.headers:
- if header.strip().lower() == b"content-encoding":
- for part in value.split(b","):
- part = part.strip().lower()
- decoder_cls = SUPPORTED_DECODERS[part]
- decoders.append(decoder_cls())
+ value = self.headers.get("content-encoding", "identity")
+ for part in value.split(","):
+ part = part.strip().lower()
+ decoder_cls = SUPPORTED_DECODERS[part]
+ decoders.append(decoder_cls())
if len(decoders) == 0:
self.decoder = IdentityDecoder() # type: Decoder
import typing
+from urllib.parse import urljoin, urlparse
from .adapters import Adapter
from .exceptions import TooManyRedirects
-from .models import Request, Response
+from .models import URL, Request, Response
+from .status_codes import codes
+from .utils import requote_uri
class RedirectAdapter(Adapter):
return response
async def close(self) -> None:
- self.dispatch.close()
+ await self.dispatch.close()
def build_redirect_request(self, request: Request, response: Response) -> Request:
+ method = self.redirect_method(request, response)
+ url = self.redirect_url(request, response)
raise NotImplementedError()
+
+ def redirect_method(self, request: Request, response: Response) -> str:
+ """
+ When being redirected we may want to change the method of the request
+ based on certain specs or browser behavior.
+ """
+ method = request.method
+
+ # https://tools.ietf.org/html/rfc7231#section-6.4.4
+ if response.status_code == codes["see_other"] and method != "HEAD":
+ method = "GET"
+
+ # Do what the browsers do, despite standards...
+ # First, turn 302s into GETs.
+ if response.status_code == codes["found"] and method != "HEAD":
+ method = "GET"
+
+ # Second, if a POST is responded to with a 301, turn it into a GET.
+ # This bizarre behaviour is explained in Issue 1704.
+ if response.status_code == codes["moved"] and method == "POST":
+ method = "GET"
+
+ return method
+
+ def redirect_url(self, request: Request, response: Response) -> URL:
+ location = response.headers["Location"]
+
+ # Handle redirection without scheme (see: RFC 1808 Section 4)
+ if location.startswith("//"):
+ location = f"{request.url.scheme}:{location}"
+
+ # Normalize url case and attach previous fragment if needed (RFC 7231 7.1.2)
+ parsed = urlparse(location)
+ if parsed.fragment == "" and request.url.fragment:
+ parsed = parsed._replace(fragment=request.url.fragment)
+ url = parsed.geturl()
+
+ # Facilitate relative 'location' headers, as allowed by RFC 7231.
+ # (e.g. '/path/to/resource' instead of 'http://domain.tld/path/to/resource')
+ # Compliant with RFC3986, we percent encode the url.
+ if not parsed.netloc:
+ url = urljoin(str(request.url), requote_uri(url))
+ else:
+ url = requote_uri(url)
+
+ return URL(url)
--- /dev/null
+"""
+The ``codes`` object defines a mapping from common names for HTTP statuses
+to their numerical codes, accessible either as attributes or as dictionary
+items.
+>>> codes['temporary_redirect']
+307
+>>> codes.teapot
+418
+Some codes have multiple names, and both upper- and lower-case versions of
+the names are allowed. For example, ``codes.ok``, ``codes.OK``, and
+``codes.okay`` all correspond to the HTTP status code 200.
+"""
+
+import typing
+
+from .structures import LookupDict
+
# Status code -> tuple of symbolic alias names. _init() installs each alias
# (and, for non-symbol aliases, its upper-case form) on `codes` below.
_codes = {
    # Informational.
    100: ("continue",),
    101: ("switching_protocols",),
    102: ("processing",),
    103: ("checkpoint",),
    122: ("uri_too_long", "request_uri_too_long"),
    200: ("ok", "okay", "all_ok", "all_okay", "all_good", "\\o/", "✓"),
    201: ("created",),
    202: ("accepted",),
    203: ("non_authoritative_info", "non_authoritative_information"),
    204: ("no_content",),
    205: ("reset_content", "reset"),
    206: ("partial_content", "partial"),
    207: ("multi_status", "multiple_status", "multi_stati", "multiple_stati"),
    208: ("already_reported",),
    226: ("im_used",),
    # Redirection.
    300: ("multiple_choices",),
    301: ("moved_permanently", "moved", "\\o-"),
    302: ("found",),
    303: ("see_other", "other"),
    304: ("not_modified",),
    305: ("use_proxy",),
    306: ("switch_proxy",),
    307: ("temporary_redirect", "temporary_moved", "temporary"),
    308: (
        "permanent_redirect",
        "resume_incomplete",
        "resume",
    ),  # These 2 to be removed in 3.0
    # Client Error.
    400: ("bad_request", "bad"),
    401: ("unauthorized",),
    402: ("payment_required", "payment"),
    403: ("forbidden",),
    404: ("not_found", "-o-"),
    405: ("method_not_allowed", "not_allowed"),
    406: ("not_acceptable",),
    407: ("proxy_authentication_required", "proxy_auth", "proxy_authentication"),
    408: ("request_timeout", "timeout"),
    409: ("conflict",),
    410: ("gone",),
    411: ("length_required",),
    412: ("precondition_failed", "precondition"),
    413: ("request_entity_too_large",),
    414: ("request_uri_too_large",),
    415: ("unsupported_media_type", "unsupported_media", "media_type"),
    416: (
        "requested_range_not_satisfiable",
        "requested_range",
        "range_not_satisfiable",
    ),
    417: ("expectation_failed",),
    418: ("im_a_teapot", "teapot", "i_am_a_teapot"),
    421: ("misdirected_request",),
    422: ("unprocessable_entity", "unprocessable"),
    423: ("locked",),
    424: ("failed_dependency", "dependency"),
    425: ("unordered_collection", "unordered"),
    426: ("upgrade_required", "upgrade"),
    # NOTE(review): "precondition" duplicates a 412 alias above; _init()
    # assigns names in dict order, so codes["precondition"] resolves to 428.
    428: ("precondition_required", "precondition"),
    429: ("too_many_requests", "too_many"),
    431: ("header_fields_too_large", "fields_too_large"),
    444: ("no_response", "none"),
    449: ("retry_with", "retry"),
    450: ("blocked_by_windows_parental_controls", "parental_controls"),
    451: ("unavailable_for_legal_reasons", "legal_reasons"),
    499: ("client_closed_request",),
    # Server Error.
    500: ("internal_server_error", "server_error", "/o\\", "✗"),
    501: ("not_implemented",),
    502: ("bad_gateway",),
    503: ("service_unavailable", "unavailable"),
    504: ("gateway_timeout",),
    505: ("http_version_not_supported", "http_version"),
    506: ("variant_also_negotiates",),
    507: ("insufficient_storage",),
    509: ("bandwidth_limit_exceeded", "bandwidth"),
    510: ("not_extended",),
    511: ("network_authentication_required", "network_auth", "network_authentication"),
}  # type: typing.Dict[int, typing.Sequence[str]]

# Attribute/dict-style lookup object; populated by _init() below.
codes = LookupDict(name="status_codes")
+
+
def _init() -> None:
    """
    Install every status-code alias (plus its upper-case variant) as an
    attribute on `codes`, then append the full listing to the module docs.
    """
    for status, aliases in _codes.items():
        for alias in aliases:
            setattr(codes, alias, status)
            # Pure-symbol aliases such as "\\o/" get no upper-case variant.
            if not alias.startswith(("\\", "/")):
                setattr(codes, alias.upper(), status)

    def describe(status: int) -> str:
        names = ", ".join("``%s``" % n for n in _codes[status])
        return "* %d: %s" % (status, names)

    global __doc__
    if __doc__ is not None:
        __doc__ = __doc__ + "\n" + "\n".join(describe(s) for s in sorted(_codes))


_init()
--- /dev/null
+import typing
+
+
class LookupDict(dict):
    """
    Dictionary lookup object.

    Values are stored as instance attributes (via `setattr`), not in the
    underlying dict storage: `d[key]` and `d.get(key)` read `__dict__`,
    and a missing key yields `None` rather than raising `KeyError`.
    """

    def __init__(self, name: typing.Optional[str] = None) -> None:
        # Annotation fixed: the default is None, so the parameter is Optional.
        self.name = name
        super().__init__()

    def __repr__(self) -> str:
        return "<lookup '%s'>" % (self.name)

    def __getitem__(self, key: typing.Any) -> typing.Any:
        # We allow fall-through here, so values default to None

        return self.__dict__.get(key, None)

    def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        return self.__dict__.get(key, default)
from .adapters import Adapter
from .config import SSLConfig, TimeoutConfig
from .connection_pool import ConnectionPool
-from .models import URL, Response
+from .models import URL, Headers, Response
class SyncResponse:
return self._response.reason
    @property
    def headers(self) -> Headers:
        """The headers of the wrapped response, as a `Headers` instance."""
        return self._response.headers
@property
method: str,
url: typing.Union[str, URL],
*,
- headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+ headers: typing.List[typing.Tuple[bytes, bytes]] = [],
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
**options: typing.Any
) -> SyncResponse:
--- /dev/null
+from urllib.parse import quote
+
+from .exceptions import InvalidURL
+
# The unreserved URI characters (RFC 3986)
UNRESERVED_SET = frozenset(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
)


def unquote_unreserved(uri: str) -> str:
    """
    Un-escape any percent-escape sequences in a URI that encode unreserved
    characters (RFC 3986), leaving reserved, illegal and non-ASCII bytes
    encoded. Raises `InvalidURL` on a malformed percent-escape sequence.
    """
    segments = uri.split("%")
    unquoted = [segments[0]]
    for segment in segments[1:]:
        hex_pair = segment[:2]
        if len(hex_pair) != 2 or not hex_pair.isalnum():
            # Not a complete escape sequence; keep the '%' literally.
            unquoted.append("%" + segment)
            continue
        try:
            char = chr(int(hex_pair, 16))
        except ValueError:
            raise InvalidURL("Invalid percent-escape sequence: '%s'" % hex_pair)
        if char in UNRESERVED_SET:
            unquoted.append(char + segment[2:])
        else:
            unquoted.append("%" + segment)
    return "".join(unquoted)
+
+
def requote_uri(uri: str) -> str:
    """
    Re-quote the given URI, passing it through an unquote/quote cycle so
    that it ends up fully and consistently percent-encoded.
    """
    # '%' is only safe to leave alone when existing escapes could be unquoted.
    reserved_and_percent = "!#$%&'()*+,/:;=?@[]~"
    reserved_only = "!#$&'()*+,/:;=?@[]~"
    try:
        # Unquote only the unreserved characters, then quote only illegal
        # characters (do not quote reserved, unreserved, or '%').
        return quote(unquote_unreserved(uri), safe=reserved_and_percent)
    except InvalidURL:
        # We couldn't unquote the given URI, so let's try quoting it, but
        # there may be unquoted '%'s in the URI. We need to make sure they're
        # properly quoted so they do not cause issues elsewhere.
        return quote(uri, safe=reserved_only)
def test_host_header():
    # A bare GET acquires auto-generated "host" and "accept-encoding" headers.
    request = httpcore.Request("GET", "http://example.org")
    assert request.headers == httpcore.Headers(
        [(b"host", b"example.org"), (b"accept-encoding", b"deflate, gzip, br")]
    )
def test_content_length_header():
    # A fixed-size bytes body produces an auto-generated "content-length".
    request = httpcore.Request("POST", "http://example.org", body=b"test 123")
    assert request.headers == httpcore.Headers(
        [
            (b"host", b"example.org"),
            (b"content-length", b"8"),
            (b"accept-encoding", b"deflate, gzip, br"),
        ]
    )
def test_transfer_encoding_header():
    # A streaming body has no known length, so "transfer-encoding: chunked"
    # is generated instead of "content-length".
    body = streaming_body(b"test 123")
    request = httpcore.Request("POST", "http://example.org", body=body)
    assert request.headers == httpcore.Headers(
        [
            (b"host", b"example.org"),
            (b"transfer-encoding", b"chunked"),
            (b"accept-encoding", b"deflate, gzip, br"),
        ]
    )
def test_override_host_header():
    # An explicitly supplied "host" header suppresses the auto-generated one.
    headers = [(b"host", b"1.2.3.4:80")]
    request = httpcore.Request("GET", "http://example.org", headers=headers)
    assert request.headers == httpcore.Headers(
        [(b"accept-encoding", b"deflate, gzip, br"), (b"host", b"1.2.3.4:80")]
    )
def test_override_accept_encoding_header():
    # An explicit "accept-encoding" header suppresses the auto-generated one.
    headers = [(b"accept-encoding", b"identity")]
    request = httpcore.Request("GET", "http://example.org", headers=headers)
    assert request.headers == httpcore.Headers(
        [(b"host", b"example.org"), (b"accept-encoding", b"identity")]
    )
def test_override_content_length_header():
headers = [(b"content-length", b"8")]
request = httpcore.Request("POST", "http://example.org", body=body, headers=headers)
- assert request.headers == [
- (b"host", b"example.org"),
- (b"accept-encoding", b"deflate, gzip, br"),
- (b"content-length", b"8"),
- ]
+ assert request.headers == httpcore.Headers(
+ [
+ (b"host", b"example.org"),
+ (b"accept-encoding", b"deflate, gzip, br"),
+ (b"content-length", b"8"),
+ ]
+ )
def test_url():