From: Tom Christie Date: Fri, 26 Apr 2019 14:34:04 +0000 (+0100) Subject: Adapters X-Git-Tag: 0.3.0~66^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fab6fcd397e43e97286ba284816762cf6e7e55c9;p=thirdparty%2Fhttpx.git Adapters --- diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 00cb5fb8..30ed38e6 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -1,3 +1,5 @@ +from .adapters import Adapter +from .client import Client from .config import PoolLimits, SSLConfig, TimeoutConfig from .connection import HTTPConnection from .connection_pool import ConnectionPool diff --git a/httpcore/adapters.py b/httpcore/adapters.py new file mode 100644 index 00000000..0c14e89c --- /dev/null +++ b/httpcore/adapters.py @@ -0,0 +1,40 @@ +import typing +from types import TracebackType + +from .models import URL, Request, Response + + +class Adapter: + async def request( + self, + method: str, + url: typing.Union[str, URL], + *, + headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + **options: typing.Any, + ) -> Response: + request = Request(method, url, headers=headers, body=body) + self.prepare_request(request) + response = await self.send(request, **options) + return response + + def prepare_request(self, request: Request) -> None: + raise NotImplementedError() # pragma: nocover + + async def send(self, request: Request, **options: typing.Any) -> Response: + raise NotImplementedError() # pragma: nocover + + async def close(self) -> None: + raise NotImplementedError() # pragma: nocover + + async def __aenter__(self) -> "Adapter": + return self + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.close() diff --git a/httpcore/auth.py b/httpcore/auth.py new file mode 100644 index 00000000..949d8a9f --- /dev/null +++ b/httpcore/auth.py @@ -0,0 +1,18 @@ +import typing + +from .adapters import Adapter +from .models import Request, Response + + +class AuthAdapter(Adapter): + def __init__(self, dispatch: Adapter): + self.dispatch = dispatch + + def prepare_request(self, request: Request) -> None: + self.dispatch.prepare_request(request) + + async def send(self, request: Request, **options: typing.Any) -> Response: + return await self.dispatch.send(request, **options) + + async def close(self) -> None: + self.dispatch.close() diff --git a/httpcore/client.py b/httpcore/client.py new file mode 100644 index 00000000..022ae7ff --- /dev/null +++ b/httpcore/client.py @@ -0,0 +1,124 @@ +import typing +from types import TracebackType + +from .auth import AuthAdapter +from .config import ( + DEFAULT_MAX_REDIRECTS, + DEFAULT_POOL_LIMITS, + DEFAULT_SSL_CONFIG, + DEFAULT_TIMEOUT_CONFIG, + PoolLimits, + SSLConfig, + TimeoutConfig, +) +from .connection_pool import ConnectionPool +from .cookies import CookieAdapter +from .environment import EnvironmentAdapter +from .models import URL, Request, Response +from .redirects import RedirectAdapter + + +class Client: + def __init__( + self, + ssl: SSLConfig = DEFAULT_SSL_CONFIG, + timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + limits: PoolLimits = DEFAULT_POOL_LIMITS, + max_redirects: int = DEFAULT_MAX_REDIRECTS, + ): + connection_pool = ConnectionPool(ssl=ssl, timeout=timeout, limits=limits) + cookie_adapter = CookieAdapter(dispatch=connection_pool) + auth_adapter = AuthAdapter(dispatch=cookie_adapter) + redirect_adapter = RedirectAdapter( + dispatch=auth_adapter, max_redirects=max_redirects + ) + self.adapter = EnvironmentAdapter(dispatch=redirect_adapter) + + async def request( + self, + method: str, + url: typing.Union[str, URL], + *, + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), + stream: bool = False, + allow_redirects: bool = True, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + request = Request(method, url, headers=headers, body=body) + self.prepare_request(request) + response = await self.send( + request, + stream=stream, + allow_redirects=allow_redirects, + ssl=ssl, + timeout=timeout, + ) + return response + + async def get( + self, + url: typing.Union[str, URL], + *, + headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), + stream: bool = False, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + return await self.request( + "GET", url, headers=headers, stream=stream, ssl=ssl, timeout=timeout + ) + + async def post( + self, + url: typing.Union[str, URL], + *, + body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", + headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), + stream: bool = False, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + return await self.request( + "POST", + url, + body=body, + headers=headers, + stream=stream, + ssl=ssl, + timeout=timeout, + ) + + def prepare_request(self, request: Request) -> None: + self.adapter.prepare_request(request) + + async def send( + self, + request: Request, + *, + stream: bool = False, + allow_redirects: bool = True, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + options = {"stream": stream} # type: typing.Dict[str, typing.Any] + if ssl is not None: + options["ssl"] = ssl + if timeout is not None: + options["timeout"] = timeout + return await self.adapter.send(request, **options) + + async def close(self) -> None: + await self.adapter.close() + + async def __aenter__(self) -> "Client": + return self + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.close() diff --git a/httpcore/config.py b/httpcore/config.py index 7166c05f..5ce24707 100644 --- a/httpcore/config.py +++ b/httpcore/config.py @@ -112,24 +112,20 @@ class TimeoutConfig: connect_timeout: float = None, read_timeout: float = None, write_timeout: float = None, - pool_timeout: float = None, ): if timeout is not None: # Specified as a single timeout value assert connect_timeout is None assert read_timeout is None assert write_timeout is None - assert pool_timeout is None connect_timeout = timeout read_timeout = timeout write_timeout = timeout - pool_timeout = timeout self.timeout = timeout self.connect_timeout = connect_timeout self.read_timeout = read_timeout self.write_timeout = write_timeout - self.pool_timeout = pool_timeout def __eq__(self, other: typing.Any) -> bool: return ( @@ -137,14 +133,13 @@ class TimeoutConfig: and self.connect_timeout == other.connect_timeout and self.read_timeout == other.read_timeout and self.write_timeout == other.write_timeout - and self.pool_timeout == other.pool_timeout ) def __repr__(self) -> str: class_name = self.__class__.__name__ if self.timeout is not None: return f"{class_name}(timeout={self.timeout})" - return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout}, pool_timeout={self.pool_timeout})" + return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})" class PoolLimits: @@ -155,27 +150,29 @@ class PoolLimits: def __init__( self, *, - soft_limit: typing.Optional[int] = None, - hard_limit: typing.Optional[int] = None, + soft_limit: int = None, + hard_limit: int = None, + pool_timeout: float = None, ): self.soft_limit = soft_limit self.hard_limit = hard_limit + self.pool_timeout = pool_timeout def __eq__(self, other: typing.Any) -> bool: return ( isinstance(other, self.__class__) and self.soft_limit == other.soft_limit and self.hard_limit == other.hard_limit + and self.pool_timeout == other.pool_timeout ) def __repr__(self) -> str: class_name = self.__class__.__name__ - return ( - f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit})" - ) + return f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit}, pool_timeout={self.pool_timeout})" DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True) DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0) -DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100) +DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100, pool_timeout=5.0) DEFAULT_CA_BUNDLE_PATH = certifi.where() +DEFAULT_MAX_REDIRECTS = 30 diff --git a/httpcore/connection.py b/httpcore/connection.py index a9c9890e..1662018a 100644 --- a/httpcore/connection.py +++ b/httpcore/connection.py @@ -4,18 +4,19 @@ import typing import h2.connection import h11 +from .adapters import Adapter from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig from .exceptions import ConnectTimeout from .http2 import HTTP2Connection from .http11 import HTTP11Connection -from .models import Client, Origin, Request, Response +from .models import Origin, Request, Response from .streams import Protocol, connect # Callback signature: async def callback(conn: HTTPConnection) -> None ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]] -class HTTPConnection(Client): +class HTTPConnection(Adapter): def __init__( self, origin: typing.Union[str, Origin], @@ -30,33 +31,26 @@ class HTTPConnection(Client): self.h11_connection = None # type: typing.Optional[HTTP11Connection] self.h2_connection = None # type: typing.Optional[HTTP2Connection] - async def send( - self, - request: Request, - *, - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - ) -> Response: + def prepare_request(self, request: Request) -> None: + pass + + async def send(self, request: Request, **options: typing.Any) -> Response: if self.h11_connection is None and self.h2_connection is None: - await self.connect(ssl, timeout) + await self.connect(**options) if self.h2_connection is not None: - response = await self.h2_connection.send(request, ssl=ssl, timeout=timeout) + response = await self.h2_connection.send(request, **options) else: assert self.h11_connection is not None - response = await self.h11_connection.send(request, ssl=ssl, timeout=timeout) + response = await self.h11_connection.send(request, **options) return response - async def connect( - self, - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - ) -> None: - if ssl is None: - ssl = self.ssl - if timeout is None: - timeout = self.timeout + async def connect(self, **options: typing.Any) -> None: + ssl = options.get("ssl", self.ssl) + timeout = options.get("timeout", self.timeout) + assert isinstance(ssl, SSLConfig) + assert isinstance(timeout, TimeoutConfig) hostname = self.origin.hostname port = self.origin.port @@ -69,20 +63,10 @@ class HTTPConnection(Client): reader, writer, protocol = await connect(hostname, port, ssl_context, timeout) if protocol == Protocol.HTTP_2: - self.h2_connection = HTTP2Connection( - reader, - writer, - origin=self.origin, - timeout=self.timeout, - on_release=on_release, - ) + self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release) else: self.h11_connection = HTTP11Connection( - reader, - writer, - origin=self.origin, - timeout=self.timeout, - on_release=on_release, + reader, writer, on_release=on_release ) async def close(self) -> None: diff --git a/httpcore/connection_pool.py b/httpcore/connection_pool.py index 6eaef9be..ab394a19 100644 --- a/httpcore/connection_pool.py +++ b/httpcore/connection_pool.py @@ -1,6 +1,7 @@ import collections.abc import typing +from .adapters import Adapter from .config import ( DEFAULT_CA_BUNDLE_PATH, DEFAULT_POOL_LIMITS, @@ -12,7 +13,7 @@ from .config import ( ) from .connection import HTTPConnection from .exceptions import PoolTimeout -from .models import Client, Origin, Request, Response +from .models import Origin, Request, Response from .streams import PoolSemaphore CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]] @@ -81,7 +82,7 @@ class ConnectionStore(collections.abc.Sequence): return len(self.all) -class ConnectionPool(Client): +class ConnectionPool(Adapter): def __init__( self, *, @@ -94,7 +95,7 @@ class ConnectionPool(Client): self.limits = limits self.is_closed = False - self.max_connections = PoolSemaphore(limits, timeout) + self.max_connections = PoolSemaphore(limits) self.keepalive_connections = ConnectionStore() self.active_connections = ConnectionStore() @@ -102,31 +103,26 @@ class ConnectionPool(Client): def num_connections(self) -> int: return len(self.keepalive_connections) + len(self.active_connections) - async def send( - self, - request: Request, - *, - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - ) -> Response: - connection = await self.acquire_connection(request.url.origin, timeout=timeout) + def prepare_request(self, request: Request) -> None: + pass + + async def send(self, request: Request, **options: typing.Any) -> Response: + connection = await self.acquire_connection(request.url.origin) try: - response = await connection.send(request, ssl=ssl, timeout=timeout) + response = await connection.send(request, **options) except BaseException as exc: self.active_connections.remove(connection) self.max_connections.release() raise exc return response - async def acquire_connection( - self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None - ) -> HTTPConnection: + async def acquire_connection(self, origin: Origin) -> HTTPConnection: connection = self.active_connections.pop_by_origin(origin, http2_only=True) if connection is None: connection = self.keepalive_connections.pop_by_origin(origin) if connection is None: - await self.max_connections.acquire(timeout) + await self.max_connections.acquire() connection = HTTPConnection( origin, ssl=self.ssl, diff --git a/httpcore/cookies.py b/httpcore/cookies.py new file mode 100644 index 00000000..f6fd2b03 --- /dev/null +++ b/httpcore/cookies.py @@ -0,0 +1,18 @@ +import typing + +from .adapters import Adapter +from .models import Request, Response + + +class CookieAdapter(Adapter): + def __init__(self, dispatch: Adapter): + self.dispatch = dispatch + + def prepare_request(self, request: Request) -> None: + self.dispatch.prepare_request(request) + + async def send(self, request: Request, **options: typing.Any) -> Response: + return await self.dispatch.send(request, **options) + + async def close(self) -> None: + self.dispatch.close() diff --git a/httpcore/environment.py b/httpcore/environment.py new file mode 100644 index 00000000..5065eed8 --- /dev/null +++ b/httpcore/environment.py @@ -0,0 +1,27 @@ +import typing + +from .adapters import Adapter +from .models import Request, Response + + +class EnvironmentAdapter(Adapter): + def __init__(self, dispatch: Adapter, trust_env: bool = True): + self.dispatch = dispatch + self.trust_env = trust_env + + def prepare_request(self, request: Request) -> None: + self.dispatch.prepare_request(request) + + async def send(self, request: Request, **options: typing.Any) -> Response: + if self.trust_env: + self.merge_environment_options(options) + return await self.dispatch.send(request, **options) + + async def close(self) -> None: + await self.dispatch.close() + + def merge_environment_options(self, options: dict) -> None: + """ + Add environment options. + """ + #  TODO diff --git a/httpcore/exceptions.py b/httpcore/exceptions.py index 285b6403..94154b07 100644 --- a/httpcore/exceptions.py +++ b/httpcore/exceptions.py @@ -28,6 +28,12 @@ class PoolTimeout(Timeout): """ +class TooManyRedirects(Exception): + """ + Too many redirects. + """ + + class ProtocolError(Exception): """ Malformed HTTP. diff --git a/httpcore/http11.py b/httpcore/http11.py index 3fda6d1f..0075b524 100644 --- a/httpcore/http11.py +++ b/httpcore/http11.py @@ -2,9 +2,10 @@ import typing import h11 +from .adapters import Adapter from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig from .exceptions import ConnectTimeout, ReadTimeout -from .models import Client, Origin, Request, Response +from .models import Request, Response from .streams import BaseReader, BaseWriter H11Event = typing.Union[ @@ -25,35 +26,28 @@ OptionalTimeout = typing.Optional[TimeoutConfig] OnReleaseCallback = typing.Callable[[], typing.Awaitable[None]] -class HTTP11Connection(Client): +class HTTP11Connection(Adapter): READ_NUM_BYTES = 4096 def __init__( self, reader: BaseReader, writer: BaseWriter, - origin: Origin, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, on_release: typing.Optional[OnReleaseCallback] = None, ): self.reader = reader self.writer = writer - self.origin = origin - self.timeout = timeout self.on_release = on_release self.h11_state = h11.Connection(our_role=h11.CLIENT) - @property - def is_closed(self) -> bool: - return self.h11_state.our_state in (h11.CLOSED, h11.ERROR) + def prepare_request(self, request: Request) -> None: + pass + + async def send(self, request: Request, **options: typing.Any) -> Response: + timeout = options.get("timeout") + stream = options.get("stream", False) + assert timeout is None or isinstance(timeout, TimeoutConfig) - async def send( - self, - request: Request, - *, - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None - ) -> Response: #  Start sending the request. method = request.method.encode() target = request.url.full_path @@ -81,7 +75,7 @@ class HTTP11Connection(Client): headers = event.headers body = self._body_iter(timeout) - return Response( + response = Response( status_code=status_code, reason=reason, protocol="HTTP/1.1", @@ -90,6 +84,26 @@ class HTTP11Connection(Client): on_close=self.response_closed, ) + if not stream: + try: + await response.read() + finally: + await response.close() + + return response + + async def close(self) -> None: + event = h11.ConnectionClosed() + try: + # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED. + self.h11_state.send(event) + except h11.ProtocolError: + # If we're in some other state then it's a premature close, + # and we'll end up in h11.ERROR. + pass + + await self.writer.close() + async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]: event = await self._receive_event(timeout) while isinstance(event, h11.Data): @@ -123,14 +137,6 @@ class HTTP11Connection(Client): if self.on_release is not None: await self.on_release() - async def close(self) -> None: - event = h11.ConnectionClosed() - try: - # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED. - self.h11_state.send(event) - except h11.ProtocolError: - # If we're in some other state then it's a premature close, - # and we'll end up in h11.ERROR. - pass - - await self.writer.close() + @property + def is_closed(self) -> bool: + return self.h11_state.our_state in (h11.CLOSED, h11.ERROR) diff --git a/httpcore/http2.py b/httpcore/http2.py index c60490ce..e89029cf 100644 --- a/httpcore/http2.py +++ b/httpcore/http2.py @@ -4,52 +4,39 @@ import typing import h2.connection import h2.events +from .adapters import Adapter from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig from .exceptions import ConnectTimeout, ReadTimeout -from .models import Client, Origin, Request, Response +from .models import Request, Response from .streams import BaseReader, BaseWriter OptionalTimeout = typing.Optional[TimeoutConfig] -class HTTP2Connection(Client): +class HTTP2Connection(Adapter): READ_NUM_BYTES = 4096 def __init__( - self, - reader: BaseReader, - writer: BaseWriter, - origin: Origin, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, - on_release: typing.Callable = None, + self, reader: BaseReader, writer: BaseWriter, on_release: typing.Callable = None ): self.reader = reader self.writer = writer - self.origin = origin - self.timeout = timeout self.on_release = on_release self.h2_state = h2.connection.H2Connection() self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]] self.initialized = False - @property - def is_closed(self) -> bool: - return False + def prepare_request(self, request: Request) -> None: + pass - async def send( - self, - request: Request, - *, - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None - ) -> Response: - if timeout is None: - timeout = self.timeout + async def send(self, request: Request, **options: typing.Any) -> Response: + timeout = options.get("timeout") + stream = options.get("stream", False) + assert timeout is None or isinstance(timeout, TimeoutConfig) + #  Start sending the request. if not self.initialized: self.initiate_connection() - - #  Start sending the request. stream_id = await self.send_headers(request, timeout) self.events[stream_id] = [] @@ -77,7 +64,7 @@ class HTTP2Connection(Client): body = self.body_iter(stream_id, timeout) on_close = functools.partial(self.response_closed, stream_id=stream_id) - return Response( + response = Response( status_code=status_code, protocol="HTTP/2", headers=headers, @@ -85,6 +72,17 @@ class HTTP2Connection(Client): on_close=on_close, ) + if not stream: + try: + await response.read() + finally: + await response.close() + + return response + + async def close(self) -> None: + await self.writer.close() + def initiate_connection(self) -> None: self.h2_state.initiate_connection() data_to_send = self.h2_state.data_to_send() @@ -147,5 +145,6 @@ class HTTP2Connection(Client): if not self.events and self.on_release is not None: await self.on_release() - async def close(self) -> None: - await self.writer.close() + @property + def is_closed(self) -> bool: + return False diff --git a/httpcore/models.py b/httpcore/models.py index e4a809af..7229a76c 100644 --- a/httpcore/models.py +++ b/httpcore/models.py @@ -1,6 +1,5 @@ import http import typing -from types import TracebackType from urllib.parse import urlsplit from .config import SSLConfig, TimeoutConfig @@ -237,47 +236,6 @@ class Response: if self.on_close is not None: await self.on_close() - -class Client: - async def request( - self, - method: str, - url: typing.Union[str, URL], - *, - headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), - body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - stream: bool = False, - ) -> Response: - request = Request(method, url, headers=headers, body=body) - response = await self.send(request, ssl=ssl, timeout=timeout) - if not stream: - try: - await response.read() - finally: - await response.close() - return response - - async def send( - self, - request: Request, - *, - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - ) -> Response: - raise NotImplementedError() # pragma: nocover - - async def close(self) -> None: - raise NotImplementedError() # pragma: nocover - - async def __aenter__(self) -> "Client": - return self - - async def __aexit__( - self, - exc_type: typing.Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - await self.close() + @property + def is_redirect(self) -> bool: + return self.status_code in (301, 302, 303, 307, 308) diff --git a/httpcore/redirects.py b/httpcore/redirects.py new file mode 100644 index 00000000..0657ebca --- /dev/null +++ b/httpcore/redirects.py @@ -0,0 +1,35 @@ +import typing + +from .adapters import Adapter +from .exceptions import TooManyRedirects +from .models import Request, Response + + +class RedirectAdapter(Adapter): + def __init__(self, dispatch: Adapter, max_redirects: int): + self.dispatch = dispatch + self.max_redirects = max_redirects + + def prepare_request(self, request: Request) -> None: + self.dispatch.prepare_request(request) + + async def send(self, request: Request, **options: typing.Any) -> Response: + allow_redirects = options.pop("allow_redirects", True) + history = [] + + while True: + response = await self.dispatch.send(request, **options) + if not allow_redirects or not response.is_redirect: + break + history.append(response) + if len(history) > self.max_redirects: + raise TooManyRedirects() + request = self.build_redirect_request(request, response) + + return response + + async def close(self) -> None: + self.dispatch.close() + + def build_redirect_request(self, request: Request, response: Response) -> Request: + raise NotImplementedError() diff --git a/httpcore/streams.py b/httpcore/streams.py index e46ffee4..03bba17c 100644 --- a/httpcore/streams.py +++ b/httpcore/streams.py @@ -41,10 +41,7 @@ class BaseWriter: class BasePoolSemaphore: - def __init__(self, limits: PoolLimits, timeout: TimeoutConfig): - raise NotImplementedError() # pragma: no cover - - async def acquire(self, timeout: OptionalTimeout = None) -> None: + async def acquire(self) -> None: raise NotImplementedError() # pragma: no cover def release(self) -> None: @@ -100,9 +97,8 @@ class Writer(BaseWriter): class PoolSemaphore(BasePoolSemaphore): - def __init__(self, limits: PoolLimits, timeout: TimeoutConfig): + def __init__(self, limits: PoolLimits): self.limits = limits - self.timeout = timeout @property def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]: @@ -114,15 +110,13 @@ class PoolSemaphore(BasePoolSemaphore): self._semaphore = asyncio.BoundedSemaphore(value=max_connections) return self._semaphore - async def acquire(self, timeout: OptionalTimeout = None) -> None: + async def acquire(self) -> None: if self.semaphore is None: return - if timeout is None: - timeout = self.timeout - + timeout = self.limits.pool_timeout try: - await asyncio.wait_for(self.semaphore.acquire(), timeout.pool_timeout) + await asyncio.wait_for(self.semaphore.acquire(), timeout) except asyncio.TimeoutError: raise PoolTimeout() diff --git a/httpcore/sync.py b/httpcore/sync.py index b1f98f50..737d3fcf 100644 --- a/httpcore/sync.py +++ b/httpcore/sync.py @@ -2,9 +2,10 @@ import asyncio import typing from types import TracebackType +from .adapters import Adapter from .config import SSLConfig, TimeoutConfig from .connection_pool import ConnectionPool -from .models import URL, Client, Response +from .models import URL, Response class SyncResponse: @@ -44,8 +45,8 @@ class SyncResponse: class SyncClient: - def __init__(self, client: Client): - self._client = client + def __init__(self, adapter: Adapter): + self._client = adapter self._loop = asyncio.new_event_loop() def request( @@ -55,20 +56,10 @@ class SyncClient: *, headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", - ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, - stream: bool = False, + **options: typing.Any ) -> SyncResponse: response = self._loop.run_until_complete( - self._client.request( - method, - url, - headers=headers, - body=body, - ssl=ssl, - timeout=timeout, - stream=stream, - ) + self._client.request(method, url, headers=headers, body=body, **options) ) return SyncResponse(response, self._loop) diff --git a/tests/test_api.py b/tests/test_api.py index 6b80587d..4622849b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -5,16 +5,16 @@ import httpcore @pytest.mark.asyncio async def test_get(server): - async with httpcore.ConnectionPool() as http: - response = await http.request("GET", "http://127.0.0.1:8000/") + async with httpcore.Client() as client: + response = await client.request("GET", "http://127.0.0.1:8000/") assert response.status_code == 200 assert response.body == b"Hello, world!" @pytest.mark.asyncio async def test_post(server): - async with httpcore.ConnectionPool() as http: - response = await http.request( + async with httpcore.Client() as client: + response = await client.request( "POST", "http://127.0.0.1:8000/", body=b"Hello, world!" ) assert response.status_code == 200 @@ -22,8 +22,8 @@ async def test_post(server): @pytest.mark.asyncio async def test_stream_response(server): - async with httpcore.ConnectionPool() as http: - response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) + async with httpcore.Client() as client: + response = await client.request("GET", "http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 assert not hasattr(response, "body") body = await response.read() @@ -36,8 +36,8 @@ async def test_stream_request(server): yield b"Hello, " yield b"world!" - async with httpcore.ConnectionPool() as http: - response = await http.request( + async with httpcore.Client() as client: + response = await client.request( "POST", "http://127.0.0.1:8000/", body=hello_world() ) assert response.status_code == 200 diff --git a/tests/test_config.py b/tests/test_config.py index 74c1267e..e4ce64a4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,13 +13,15 @@ def test_timeout_repr(): timeout = httpcore.TimeoutConfig(read_timeout=5.0) assert ( repr(timeout) - == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None, pool_timeout=None)" + == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None)" ) def test_limits_repr(): limits = httpcore.PoolLimits(hard_limit=100) - assert repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100)" + assert ( + repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100, pool_timeout=None)" + ) def test_ssl_eq(): diff --git a/tests/test_http2.py b/tests/test_http2.py index f17d0f98..dc287bce 100644 --- a/tests/test_http2.py +++ b/tests/test_http2.py @@ -79,11 +79,8 @@ class MockServer(httpcore.BaseReader, httpcore.BaseWriter): @pytest.mark.asyncio async def test_http2_get_request(): server = MockServer() - origin = httpcore.Origin("http://example.org") - async with httpcore.HTTP2Connection( - reader=server, writer=server, origin=origin - ) as client: - response = await client.request("GET", "http://example.org") + async with httpcore.HTTP2Connection(reader=server, writer=server) as conn: + response = await conn.request("GET", "http://example.org") assert response.status_code == 200 assert json.loads(response.body) == {"method": "GET", "path": "/", "body": ""} @@ -91,11 +88,8 @@ async def test_http2_get_request(): @pytest.mark.asyncio async def test_http2_post_request(): server = MockServer() - origin = httpcore.Origin("http://example.org") - async with httpcore.HTTP2Connection( - reader=server, writer=server, origin=origin - ) as client: - response = await client.request("POST", "http://example.org", body=b"") + async with httpcore.HTTP2Connection(reader=server, writer=server) as conn: + response = await conn.request("POST", "http://example.org", body=b"") assert response.status_code == 200 assert json.loads(response.body) == { "method": "POST", @@ -107,13 +101,10 @@ async def test_http2_post_request(): @pytest.mark.asyncio async def test_http2_multiple_requests(): server = MockServer() - origin = httpcore.Origin("http://example.org") - async with httpcore.HTTP2Connection( - reader=server, writer=server, origin=origin - ) as client: - response_1 = await client.request("GET", "http://example.org/1") - response_2 = await client.request("GET", "http://example.org/2") - response_3 = await client.request("GET", "http://example.org/3") + async with httpcore.HTTP2Connection(reader=server, writer=server) as conn: + response_1 = await conn.request("GET", "http://example.org/1") + response_2 = await conn.request("GET", "http://example.org/2") + response_3 = await conn.request("GET", "http://example.org/3") assert response_1.status_code == 200 assert json.loads(response_1.body) == {"method": "GET", "path": "/1", "body": ""} diff --git a/tests/test_timeouts.py b/tests/test_timeouts.py index b1ceef93..d91cb799 100644 --- a/tests/test_timeouts.py +++ b/tests/test_timeouts.py @@ -24,10 +24,9 @@ async def test_connect_timeout(server): @pytest.mark.asyncio async def test_pool_timeout(server): - timeout = httpcore.TimeoutConfig(pool_timeout=0.0001) - limits = httpcore.PoolLimits(hard_limit=1) + limits = httpcore.PoolLimits(hard_limit=1, pool_timeout=0.0001) - async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as http: + async with httpcore.ConnectionPool(limits=limits) as http: response = await http.request("GET", "http://127.0.0.1:8000/", stream=True) with pytest.raises(httpcore.PoolTimeout):