-from .adapters import Adapter
+from .adapters.redirects import RedirectAdapter
from .client import Client
from .config import PoolLimits, SSLConfig, TimeoutConfig
-from .connection import HTTPConnection
-from .connection_pool import ConnectionPool
+from .dispatch.connection import HTTPConnection
+from .dispatch.connection_pool import ConnectionPool
+from .dispatch.http2 import HTTP2Connection
+from .dispatch.http11 import HTTP11Connection
from .exceptions import (
ConnectTimeout,
PoolTimeout,
StreamConsumed,
Timeout,
)
-from .http2 import HTTP2Connection
-from .http11 import HTTP11Connection
+from .interfaces import Adapter
from .models import URL, Headers, Origin, Request, Response
+from .status_codes import codes
from .streams import BaseReader, BaseWriter, Protocol, Reader, Writer, connect
from .sync import SyncClient, SyncConnectionPool
--- /dev/null
+"""
+Adapter classes layer additional behavior over the raw dispatching of the
+HTTP request/response.
+"""
import typing
-from .adapters import Adapter
-from .models import Request, Response
+from ..interfaces import Adapter
+from ..models import Request, Response
-class AuthAdapter(Adapter):
+class AuthenticationAdapter(Adapter):
def __init__(self, dispatch: Adapter):
self.dispatch = dispatch
import typing
-from .adapters import Adapter
-from .models import Request, Response
+from ..interfaces import Adapter
+from ..models import Request, Response
class CookieAdapter(Adapter):
import typing
-from .adapters import Adapter
-from .models import Request, Response
+from ..interfaces import Adapter
+from ..models import Request, Response
class EnvironmentAdapter(Adapter):
import typing
from urllib.parse import urljoin, urlparse
-from .adapters import Adapter
-from .exceptions import TooManyRedirects
-from .models import URL, Request, Response
-from .status_codes import codes
-from .utils import requote_uri
+from ..config import DEFAULT_MAX_REDIRECTS
+from ..exceptions import TooManyRedirects
+from ..interfaces import Adapter
+from ..models import URL, Request, Response
+from ..status_codes import codes
+from ..utils import requote_uri
class RedirectAdapter(Adapter):
- def __init__(self, dispatch: Adapter, max_redirects: int):
+ def __init__(self, dispatch: Adapter, max_redirects: int = DEFAULT_MAX_REDIRECTS):
self.dispatch = dispatch
self.max_redirects = max_redirects
def build_redirect_request(self, request: Request, response: Response) -> Request:
method = self.redirect_method(request, response)
url = self.redirect_url(request, response)
- raise NotImplementedError()
+ return Request(method=method, url=url)
def redirect_method(self, request: Request, response: Response) -> str:
"""
method = request.method
# https://tools.ietf.org/html/rfc7231#section-6.4.4
- if response.status_code == codes["see_other"] and method != "HEAD":
+ if response.status_code == codes.see_other and method != "HEAD":
method = "GET"
# Do what the browsers do, despite standards...
- # First, turn 302s into GETs.
- if response.status_code == codes["found"] and method != "HEAD":
+ # Turn 302s into GETs.
+ if response.status_code == codes.found and method != "HEAD":
method = "GET"
- # Second, if a POST is responded to with a 301, turn it into a GET.
- # This bizarre behaviour is explained in Issue 1704.
- if response.status_code == codes["moved"] and method == "POST":
+ # If a POST is responded to with a 301, turn it into a GET.
+ # This bizarre behaviour is explained in 'requests' issue 1704.
+ if response.status_code == codes.moved_permanently and method == "POST":
method = "GET"
return method
import typing
from types import TracebackType
-from .auth import AuthAdapter
+from .adapters.authentication import AuthenticationAdapter
+from .adapters.cookies import CookieAdapter
+from .adapters.environment import EnvironmentAdapter
+from .adapters.redirects import RedirectAdapter
from .config import (
DEFAULT_MAX_REDIRECTS,
DEFAULT_POOL_LIMITS,
SSLConfig,
TimeoutConfig,
)
-from .connection_pool import ConnectionPool
-from .cookies import CookieAdapter
-from .environment import EnvironmentAdapter
+from .dispatch.connection_pool import ConnectionPool
from .models import URL, Request, Response
-from .redirects import RedirectAdapter
class Client:
):
connection_pool = ConnectionPool(ssl=ssl, timeout=timeout, limits=limits)
cookie_adapter = CookieAdapter(dispatch=connection_pool)
- auth_adapter = AuthAdapter(dispatch=cookie_adapter)
+ auth_adapter = AuthenticationAdapter(dispatch=cookie_adapter)
redirect_adapter = RedirectAdapter(
dispatch=auth_adapter, max_redirects=max_redirects
)
DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100, pool_timeout=5.0)
DEFAULT_CA_BUNDLE_PATH = certifi.where()
-DEFAULT_MAX_REDIRECTS = 30
+DEFAULT_MAX_REDIRECTS = 20
--- /dev/null
+"""
+Dispatch classes handle the raw network connections and the implementation
+details of making the HTTP request and receiving the response.
+"""
import h2.connection
import h11
-from .adapters import Adapter
-from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
-from .exceptions import ConnectTimeout
+from ..config import (
+ DEFAULT_SSL_CONFIG,
+ DEFAULT_TIMEOUT_CONFIG,
+ SSLConfig,
+ TimeoutConfig,
+)
+from ..exceptions import ConnectTimeout
+from ..interfaces import Adapter
+from ..models import Origin, Request, Response
+from ..streams import Protocol, connect
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
-from .models import Origin, Request, Response
-from .streams import Protocol, connect
# Callback signature: async def callback(conn: HTTPConnection) -> None
ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]]
import collections.abc
import typing
-from .adapters import Adapter
-from .config import (
+from ..config import (
DEFAULT_CA_BUNDLE_PATH,
DEFAULT_POOL_LIMITS,
DEFAULT_SSL_CONFIG,
SSLConfig,
TimeoutConfig,
)
+from ..exceptions import PoolTimeout
+from ..interfaces import Adapter
+from ..models import Origin, Request, Response
+from ..streams import PoolSemaphore
from .connection import HTTPConnection
-from .exceptions import PoolTimeout
-from .models import Origin, Request, Response
-from .streams import PoolSemaphore
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
import h11
-from .adapters import Adapter
-from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
-from .exceptions import ConnectTimeout, ReadTimeout
-from .models import Request, Response
-from .streams import BaseReader, BaseWriter
+from ..config import (
+ DEFAULT_SSL_CONFIG,
+ DEFAULT_TIMEOUT_CONFIG,
+ SSLConfig,
+ TimeoutConfig,
+)
+from ..exceptions import ConnectTimeout, ReadTimeout
+from ..interfaces import Adapter
+from ..models import Request, Response
+from ..streams import BaseReader, BaseWriter
H11Event = typing.Union[
h11.Request,
import h2.connection
import h2.events
-from .adapters import Adapter
-from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
-from .exceptions import ConnectTimeout, ReadTimeout
-from .models import Request, Response
-from .streams import BaseReader, BaseWriter
+from ..config import (
+ DEFAULT_SSL_CONFIG,
+ DEFAULT_TIMEOUT_CONFIG,
+ SSLConfig,
+ TimeoutConfig,
+)
+from ..exceptions import ConnectTimeout, ReadTimeout
+from ..interfaces import Adapter
+from ..models import Request, Response
+from ..streams import BaseReader, BaseWriter
OptionalTimeout = typing.Optional[TimeoutConfig]
import typing
from types import TracebackType
+from .config import TimeoutConfig
from .models import URL, Request, Response
+OptionalTimeout = typing.Optional[TimeoutConfig]
+
class Adapter:
async def request(
traceback: TracebackType = None,
) -> None:
await self.close()
+
+
+class BaseReader:
+ async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
+ raise NotImplementedError() # pragma: no cover
+
+
+class BaseWriter:
+ def write_no_block(self, data: bytes) -> None:
+ raise NotImplementedError() # pragma: no cover
+
+ async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
+ raise NotImplementedError() # pragma: no cover
+
+ async def close(self) -> None:
+ raise NotImplementedError() # pragma: no cover
+
+
+class BasePoolSemaphore:
+ async def acquire(self) -> None:
+ raise NotImplementedError() # pragma: no cover
+
+ def release(self) -> None:
+ raise NotImplementedError() # pragma: no cover
def origin(self) -> "Origin":
return Origin(self)
+ def __str__(self) -> str:
+ return self.components.geturl()
+
class Origin:
def __init__(self, url: typing.Union[str, URL]) -> None:
-"""
-The ``codes`` object defines a mapping from common names for HTTP statuses
-to their numerical codes, accessible either as attributes or as dictionary
-items.
->>> requests.codes['temporary_redirect']
-307
->>> requests.codes.teapot
-418
-Some codes have multiple names, and both upper- and lower-case versions of
-the names are allowed. For example, ``codes.ok``, ``codes.OK``, and
-``codes.okay`` all correspond to the HTTP status code 200.
-"""
-
-import typing
-
-from .structures import LookupDict
-
-_codes = {
- # Informational.
- 100: ("continue",),
- 101: ("switching_protocols",),
- 102: ("processing",),
- 103: ("checkpoint",),
- 122: ("uri_too_long", "request_uri_too_long"),
- 200: ("ok", "okay", "all_ok", "all_okay", "all_good", "\\o/", "✓"),
- 201: ("created",),
- 202: ("accepted",),
- 203: ("non_authoritative_info", "non_authoritative_information"),
- 204: ("no_content",),
- 205: ("reset_content", "reset"),
- 206: ("partial_content", "partial"),
- 207: ("multi_status", "multiple_status", "multi_stati", "multiple_stati"),
- 208: ("already_reported",),
- 226: ("im_used",),
- # Redirection.
- 300: ("multiple_choices",),
- 301: ("moved_permanently", "moved", "\\o-"),
- 302: ("found",),
- 303: ("see_other", "other"),
- 304: ("not_modified",),
- 305: ("use_proxy",),
- 306: ("switch_proxy",),
- 307: ("temporary_redirect", "temporary_moved", "temporary"),
- 308: (
- "permanent_redirect",
- "resume_incomplete",
- "resume",
- ), # These 2 to be removed in 3.0
- # Client Error.
- 400: ("bad_request", "bad"),
- 401: ("unauthorized",),
- 402: ("payment_required", "payment"),
- 403: ("forbidden",),
- 404: ("not_found", "-o-"),
- 405: ("method_not_allowed", "not_allowed"),
- 406: ("not_acceptable",),
- 407: ("proxy_authentication_required", "proxy_auth", "proxy_authentication"),
- 408: ("request_timeout", "timeout"),
- 409: ("conflict",),
- 410: ("gone",),
- 411: ("length_required",),
- 412: ("precondition_failed", "precondition"),
- 413: ("request_entity_too_large",),
- 414: ("request_uri_too_large",),
- 415: ("unsupported_media_type", "unsupported_media", "media_type"),
- 416: (
- "requested_range_not_satisfiable",
- "requested_range",
- "range_not_satisfiable",
- ),
- 417: ("expectation_failed",),
- 418: ("im_a_teapot", "teapot", "i_am_a_teapot"),
- 421: ("misdirected_request",),
- 422: ("unprocessable_entity", "unprocessable"),
- 423: ("locked",),
- 424: ("failed_dependency", "dependency"),
- 425: ("unordered_collection", "unordered"),
- 426: ("upgrade_required", "upgrade"),
- 428: ("precondition_required", "precondition"),
- 429: ("too_many_requests", "too_many"),
- 431: ("header_fields_too_large", "fields_too_large"),
- 444: ("no_response", "none"),
- 449: ("retry_with", "retry"),
- 450: ("blocked_by_windows_parental_controls", "parental_controls"),
- 451: ("unavailable_for_legal_reasons", "legal_reasons"),
- 499: ("client_closed_request",),
- # Server Error.
- 500: ("internal_server_error", "server_error", "/o\\", "✗"),
- 501: ("not_implemented",),
- 502: ("bad_gateway",),
- 503: ("service_unavailable", "unavailable"),
- 504: ("gateway_timeout",),
- 505: ("http_version_not_supported", "http_version"),
- 506: ("variant_also_negotiates",),
- 507: ("insufficient_storage",),
- 509: ("bandwidth_limit_exceeded", "bandwidth"),
- 510: ("not_extended",),
- 511: ("network_authentication_required", "network_auth", "network_authentication"),
-} # type: typing.Dict[int, typing.Sequence[str]]
-
-codes = LookupDict(name="status_codes")
-
-
-def _init() -> None:
- for code, titles in _codes.items():
- for title in titles:
- setattr(codes, title, code)
- if not title.startswith(("\\", "/")):
- setattr(codes, title.upper(), code)
-
- def doc(code: int) -> str:
- names = ", ".join("``%s``" % n for n in _codes[code])
- return "* %d: %s" % (code, names)
-
- global __doc__
- __doc__ = (
- __doc__ + "\n" + "\n".join(doc(code) for code in sorted(_codes))
- if __doc__ is not None
- else None
- )
-
-
-_init()
+import enum
+
+codes = enum.IntEnum(
+ "StatusCode",
+ [
+ ("continue", 100),
+ ("switching_protocols", 101),
+ ("ok", 200),
+ ("created", 201),
+ ("accepted", 202),
+ ("non_authoritative_information", 203),
+ ("no_content", 204),
+ ("reset_content", 205),
+ ("partial_content", 206),
+ ("multi_status", 207),
+ ("multiple_choices", 300),
+ ("moved_permanently", 301),
+ ("found", 302),
+ ("see_other", 303),
+ ("not_modified", 304),
+ ("use_proxy", 305),
+ ("reserved", 306),
+ ("temporary_redirect", 307),
+ ("bad_request", 400),
+ ("unauthorized", 401),
+ ("payment_required", 402),
+ ("forbidden", 403),
+ ("not_found", 404),
+ ("method_not_allowed", 405),
+ ("not_acceptable", 406),
+ ("proxy_authentication_required", 407),
+ ("request_timeout", 408),
+ ("conflict", 409),
+ ("gone", 410),
+ ("length_required", 411),
+ ("precondition_failed", 412),
+ ("request_entity_too_large", 413),
+ ("request_uri_too_long", 414),
+ ("unsupported_media_type", 415),
+ ("requested_range_not_satisfiable", 416),
+ ("expectation_failed", 417),
+ ("unprocessable_entity", 422),
+ ("locked", 423),
+ ("failed_dependency", 424),
+ ("precondition_required", 428),
+ ("too_many_requests", 429),
+ ("request_header_fields_too_large", 431),
+ ("unavailable_for_legal_reasons", 451),
+ ("internal_server_error", 500),
+ ("not_implemented", 501),
+ ("bad_gateway", 502),
+ ("service_unavailable", 503),
+ ("gateway_timeout", 504),
+ ("http_version_not_supported", 505),
+ ("insufficient_storage", 507),
+ ("network_authentication_required", 511),
+ ],
+)
from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
+from .interfaces import BasePoolSemaphore, BaseReader, BaseWriter
OptionalTimeout = typing.Optional[TimeoutConfig]
HTTP_2 = 2
-class BaseReader:
- async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
- raise NotImplementedError() # pragma: no cover
-
-
-class BaseWriter:
- def write_no_block(self, data: bytes) -> None:
- raise NotImplementedError() # pragma: no cover
-
- async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
- raise NotImplementedError() # pragma: no cover
-
- async def close(self) -> None:
- raise NotImplementedError() # pragma: no cover
-
-
-class BasePoolSemaphore:
- async def acquire(self) -> None:
- raise NotImplementedError() # pragma: no cover
-
- def release(self) -> None:
- raise NotImplementedError() # pragma: no cover
-
-
class Reader(BaseReader):
def __init__(
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
+++ /dev/null
-import typing
-
-
-class LookupDict(dict):
- """Dictionary lookup object."""
-
- def __init__(self, name: str = None) -> None:
- self.name = name
- super(LookupDict, self).__init__()
-
- def __repr__(self) -> str:
- return "<lookup '%s'>" % (self.name)
-
- def __getitem__(self, key: typing.Any) -> typing.Any:
- # We allow fall-through here, so values default to None
-
- return self.__dict__.get(key, None)
-
- def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
- return self.__dict__.get(key, default)
import typing
from types import TracebackType
-from .adapters import Adapter
from .config import SSLConfig, TimeoutConfig
-from .connection_pool import ConnectionPool
+from .dispatch.connection_pool import ConnectionPool
+from .interfaces import Adapter
from .models import URL, Headers, Response
def unquote_unreserved(uri: str) -> str:
- """Un-escape any percent-escape sequences in a URI that are unreserved
+ """
+ Un-escape any percent-escape sequences in a URI that are unreserved
characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
- :rtype: str
"""
parts = uri.split("%")
for i in range(1, len(parts)):
--- /dev/null
+import pytest
+
+from httpcore import Adapter, RedirectAdapter, Request, Response, codes
+
+
+class MockDispatch(Adapter):
+ def prepare_request(self, request: Request) -> None:
+ pass
+
+ async def send(self, request: Request, **options) -> Response:
+ if request.url.path == "/redirect_303":
+ return Response(
+ codes.see_other, headers=[(b"location", b"https://example.org/")]
+ )
+ elif request.url.path == "/relative_redirect":
+ return Response(codes.see_other, headers=[(b"location", b"/")])
+ return Response(codes.ok, body=b"Hello, world!")
+
+
+@pytest.mark.asyncio
+async def test_redirect_303():
+ client = RedirectAdapter(MockDispatch())
+ response = await client.request("GET", "https://example.org/redirect_303")
+ assert response.status_code == codes.ok
+
+
+@pytest.mark.asyncio
+async def test_relative_redirect():
+ client = RedirectAdapter(MockDispatch())
+ response = await client.request("GET", "https://example.org/relative_redirect")
+ assert response.status_code == codes.ok