]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Adapters
authorTom Christie <tom@tomchristie.com>
Fri, 26 Apr 2019 14:34:04 +0000 (15:34 +0100)
committerTom Christie <tom@tomchristie.com>
Fri, 26 Apr 2019 14:34:04 +0000 (15:34 +0100)
20 files changed:
httpcore/__init__.py
httpcore/adapters.py [new file with mode: 0644]
httpcore/auth.py [new file with mode: 0644]
httpcore/client.py [new file with mode: 0644]
httpcore/config.py
httpcore/connection.py
httpcore/connection_pool.py
httpcore/cookies.py [new file with mode: 0644]
httpcore/environment.py [new file with mode: 0644]
httpcore/exceptions.py
httpcore/http11.py
httpcore/http2.py
httpcore/models.py
httpcore/redirects.py [new file with mode: 0644]
httpcore/streams.py
httpcore/sync.py
tests/test_api.py
tests/test_config.py
tests/test_http2.py
tests/test_timeouts.py

index 00cb5fb80b54d53c25b47e8d44b79ac137df94fa..30ed38e6072eb39dadacd4fb39027d39661263c2 100644 (file)
@@ -1,3 +1,5 @@
+from .adapters import Adapter
+from .client import Client
 from .config import PoolLimits, SSLConfig, TimeoutConfig
 from .connection import HTTPConnection
 from .connection_pool import ConnectionPool
diff --git a/httpcore/adapters.py b/httpcore/adapters.py
new file mode 100644 (file)
index 0000000..0c14e89
--- /dev/null
@@ -0,0 +1,40 @@
+import typing
+from types import TracebackType
+
+from .models import URL, Request, Response
+
+
+class Adapter:
+    async def request(
+        self,
+        method: str,
+        url: typing.Union[str, URL],
+        *,
+        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
+        **options: typing.Any,
+    ) -> Response:
+        request = Request(method, url, headers=headers, body=body)
+        self.prepare_request(request)
+        response = await self.send(request, **options)
+        return response
+
+    def prepare_request(self, request: Request) -> None:
+        raise NotImplementedError()  # pragma: nocover
+
+    async def send(self, request: Request, **options: typing.Any) -> Response:
+        raise NotImplementedError()  # pragma: nocover
+
+    async def close(self) -> None:
+        raise NotImplementedError()  # pragma: nocover
+
+    async def __aenter__(self) -> "Adapter":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        await self.close()
diff --git a/httpcore/auth.py b/httpcore/auth.py
new file mode 100644 (file)
index 0000000..949d8a9
--- /dev/null
@@ -0,0 +1,18 @@
+import typing
+
+from .adapters import Adapter
+from .models import Request, Response
+
+
+class AuthAdapter(Adapter):
+    def __init__(self, dispatch: Adapter):
+        self.dispatch = dispatch
+
+    def prepare_request(self, request: Request) -> None:
+        self.dispatch.prepare_request(request)
+
+    async def send(self, request: Request, **options: typing.Any) -> Response:
+        return await self.dispatch.send(request, **options)
+
+    async def close(self) -> None:
+        self.dispatch.close()
diff --git a/httpcore/client.py b/httpcore/client.py
new file mode 100644 (file)
index 0000000..022ae7f
--- /dev/null
@@ -0,0 +1,124 @@
+import typing
+from types import TracebackType
+
+from .auth import AuthAdapter
+from .config import (
+    DEFAULT_MAX_REDIRECTS,
+    DEFAULT_POOL_LIMITS,
+    DEFAULT_SSL_CONFIG,
+    DEFAULT_TIMEOUT_CONFIG,
+    PoolLimits,
+    SSLConfig,
+    TimeoutConfig,
+)
+from .connection_pool import ConnectionPool
+from .cookies import CookieAdapter
+from .environment import EnvironmentAdapter
+from .models import URL, Request, Response
+from .redirects import RedirectAdapter
+
+
+class Client:
+    def __init__(
+        self,
+        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
+        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+        limits: PoolLimits = DEFAULT_POOL_LIMITS,
+        max_redirects: int = DEFAULT_MAX_REDIRECTS,
+    ):
+        connection_pool = ConnectionPool(ssl=ssl, timeout=timeout, limits=limits)
+        cookie_adapter = CookieAdapter(dispatch=connection_pool)
+        auth_adapter = AuthAdapter(dispatch=cookie_adapter)
+        redirect_adapter = RedirectAdapter(
+            dispatch=auth_adapter, max_redirects=max_redirects
+        )
+        self.adapter = EnvironmentAdapter(dispatch=redirect_adapter)
+
+    async def request(
+        self,
+        method: str,
+        url: typing.Union[str, URL],
+        *,
+        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
+        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        stream: bool = False,
+        allow_redirects: bool = True,
+        ssl: typing.Optional[SSLConfig] = None,
+        timeout: typing.Optional[TimeoutConfig] = None,
+    ) -> Response:
+        request = Request(method, url, headers=headers, body=body)
+        self.prepare_request(request)
+        response = await self.send(
+            request,
+            stream=stream,
+            allow_redirects=allow_redirects,
+            ssl=ssl,
+            timeout=timeout,
+        )
+        return response
+
+    async def get(
+        self,
+        url: typing.Union[str, URL],
+        *,
+        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        stream: bool = False,
+        ssl: typing.Optional[SSLConfig] = None,
+        timeout: typing.Optional[TimeoutConfig] = None,
+    ) -> Response:
+        return await self.request(
+            "GET", url, headers=headers, stream=stream, ssl=ssl, timeout=timeout
+        )
+
+    async def post(
+        self,
+        url: typing.Union[str, URL],
+        *,
+        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
+        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        stream: bool = False,
+        ssl: typing.Optional[SSLConfig] = None,
+        timeout: typing.Optional[TimeoutConfig] = None,
+    ) -> Response:
+        return await self.request(
+            "POST",
+            url,
+            body=body,
+            headers=headers,
+            stream=stream,
+            ssl=ssl,
+            timeout=timeout,
+        )
+
+    def prepare_request(self, request: Request) -> None:
+        self.adapter.prepare_request(request)
+
+    async def send(
+        self,
+        request: Request,
+        *,
+        stream: bool = False,
+        allow_redirects: bool = True,
+        ssl: typing.Optional[SSLConfig] = None,
+        timeout: typing.Optional[TimeoutConfig] = None,
+    ) -> Response:
+        options = {"stream": stream}  # type: typing.Dict[str, typing.Any]
+        if ssl is not None:
+            options["ssl"] = ssl
+        if timeout is not None:
+            options["timeout"] = timeout
+        return await self.adapter.send(request, **options)
+
+    async def close(self) -> None:
+        await self.adapter.close()
+
+    async def __aenter__(self) -> "Client":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        await self.close()
index 7166c05f75cf2e6fe922e422cc37a88a22e6365d..5ce24707c3bf8a0bcdd5e11733f7c0c51ec481b2 100644 (file)
@@ -112,24 +112,20 @@ class TimeoutConfig:
         connect_timeout: float = None,
         read_timeout: float = None,
         write_timeout: float = None,
-        pool_timeout: float = None,
     ):
         if timeout is not None:
             # Specified as a single timeout value
             assert connect_timeout is None
             assert read_timeout is None
             assert write_timeout is None
-            assert pool_timeout is None
             connect_timeout = timeout
             read_timeout = timeout
             write_timeout = timeout
-            pool_timeout = timeout
 
         self.timeout = timeout
         self.connect_timeout = connect_timeout
         self.read_timeout = read_timeout
         self.write_timeout = write_timeout
-        self.pool_timeout = pool_timeout
 
     def __eq__(self, other: typing.Any) -> bool:
         return (
@@ -137,14 +133,13 @@ class TimeoutConfig:
             and self.connect_timeout == other.connect_timeout
             and self.read_timeout == other.read_timeout
             and self.write_timeout == other.write_timeout
-            and self.pool_timeout == other.pool_timeout
         )
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
         if self.timeout is not None:
             return f"{class_name}(timeout={self.timeout})"
-        return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout}, pool_timeout={self.pool_timeout})"
+        return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout})"
 
 
 class PoolLimits:
@@ -155,27 +150,29 @@ class PoolLimits:
     def __init__(
         self,
         *,
-        soft_limit: typing.Optional[int] = None,
-        hard_limit: typing.Optional[int] = None,
+        soft_limit: int = None,
+        hard_limit: int = None,
+        pool_timeout: float = None,
     ):
         self.soft_limit = soft_limit
         self.hard_limit = hard_limit
+        self.pool_timeout = pool_timeout
 
     def __eq__(self, other: typing.Any) -> bool:
         return (
             isinstance(other, self.__class__)
             and self.soft_limit == other.soft_limit
             and self.hard_limit == other.hard_limit
+            and self.pool_timeout == other.pool_timeout
         )
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
-        return (
-            f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit})"
-        )
+        return f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit}, pool_timeout={self.pool_timeout})"
 
 
 DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
 DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
-DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100)
+DEFAULT_POOL_LIMITS = PoolLimits(soft_limit=10, hard_limit=100, pool_timeout=5.0)
 DEFAULT_CA_BUNDLE_PATH = certifi.where()
+DEFAULT_MAX_REDIRECTS = 30
index a9c9890e8005bd2d51b2c27133a33e32fe25d84e..1662018a59067f719c9b7f31162d78dfc3faba51 100644 (file)
@@ -4,18 +4,19 @@ import typing
 import h2.connection
 import h11
 
+from .adapters import Adapter
 from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
 from .exceptions import ConnectTimeout
 from .http2 import HTTP2Connection
 from .http11 import HTTP11Connection
-from .models import Client, Origin, Request, Response
+from .models import Origin, Request, Response
 from .streams import Protocol, connect
 
 # Callback signature: async def callback(conn: HTTPConnection) -> None
 ReleaseCallback = typing.Callable[["HTTPConnection"], typing.Awaitable[None]]
 
 
-class HTTPConnection(Client):
+class HTTPConnection(Adapter):
     def __init__(
         self,
         origin: typing.Union[str, Origin],
@@ -30,33 +31,26 @@ class HTTPConnection(Client):
         self.h11_connection = None  # type: typing.Optional[HTTP11Connection]
         self.h2_connection = None  # type: typing.Optional[HTTP2Connection]
 
-    async def send(
-        self,
-        request: Request,
-        *,
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None,
-    ) -> Response:
+    def prepare_request(self, request: Request) -> None:
+        pass
+
+    async def send(self, request: Request, **options: typing.Any) -> Response:
         if self.h11_connection is None and self.h2_connection is None:
-            await self.connect(ssl, timeout)
+            await self.connect(**options)
 
         if self.h2_connection is not None:
-            response = await self.h2_connection.send(request, ssl=ssl, timeout=timeout)
+            response = await self.h2_connection.send(request, **options)
         else:
             assert self.h11_connection is not None
-            response = await self.h11_connection.send(request, ssl=ssl, timeout=timeout)
+            response = await self.h11_connection.send(request, **options)
 
         return response
 
-    async def connect(
-        self,
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None,
-    ) -> None:
-        if ssl is None:
-            ssl = self.ssl
-        if timeout is None:
-            timeout = self.timeout
+    async def connect(self, **options: typing.Any) -> None:
+        ssl = options.get("ssl", self.ssl)
+        timeout = options.get("timeout", self.timeout)
+        assert isinstance(ssl, SSLConfig)
+        assert isinstance(timeout, TimeoutConfig)
 
         hostname = self.origin.hostname
         port = self.origin.port
@@ -69,20 +63,10 @@ class HTTPConnection(Client):
 
         reader, writer, protocol = await connect(hostname, port, ssl_context, timeout)
         if protocol == Protocol.HTTP_2:
-            self.h2_connection = HTTP2Connection(
-                reader,
-                writer,
-                origin=self.origin,
-                timeout=self.timeout,
-                on_release=on_release,
-            )
+            self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
         else:
             self.h11_connection = HTTP11Connection(
-                reader,
-                writer,
-                origin=self.origin,
-                timeout=self.timeout,
-                on_release=on_release,
+                reader, writer, on_release=on_release
             )
 
     async def close(self) -> None:
index 6eaef9bef2beb1768c2d7baca5767e64d2ec46c0..ab394a19630cfd5168c507c2160cc242db547ded 100644 (file)
@@ -1,6 +1,7 @@
 import collections.abc
 import typing
 
+from .adapters import Adapter
 from .config import (
     DEFAULT_CA_BUNDLE_PATH,
     DEFAULT_POOL_LIMITS,
@@ -12,7 +13,7 @@ from .config import (
 )
 from .connection import HTTPConnection
 from .exceptions import PoolTimeout
-from .models import Client, Origin, Request, Response
+from .models import Origin, Request, Response
 from .streams import PoolSemaphore
 
 CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
@@ -81,7 +82,7 @@ class ConnectionStore(collections.abc.Sequence):
         return len(self.all)
 
 
-class ConnectionPool(Client):
+class ConnectionPool(Adapter):
     def __init__(
         self,
         *,
@@ -94,7 +95,7 @@ class ConnectionPool(Client):
         self.limits = limits
         self.is_closed = False
 
-        self.max_connections = PoolSemaphore(limits, timeout)
+        self.max_connections = PoolSemaphore(limits)
         self.keepalive_connections = ConnectionStore()
         self.active_connections = ConnectionStore()
 
@@ -102,31 +103,26 @@ class ConnectionPool(Client):
     def num_connections(self) -> int:
         return len(self.keepalive_connections) + len(self.active_connections)
 
-    async def send(
-        self,
-        request: Request,
-        *,
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None,
-    ) -> Response:
-        connection = await self.acquire_connection(request.url.origin, timeout=timeout)
+    def prepare_request(self, request: Request) -> None:
+        pass
+
+    async def send(self, request: Request, **options: typing.Any) -> Response:
+        connection = await self.acquire_connection(request.url.origin)
         try:
-            response = await connection.send(request, ssl=ssl, timeout=timeout)
+            response = await connection.send(request, **options)
         except BaseException as exc:
             self.active_connections.remove(connection)
             self.max_connections.release()
             raise exc
         return response
 
-    async def acquire_connection(
-        self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
-    ) -> HTTPConnection:
+    async def acquire_connection(self, origin: Origin) -> HTTPConnection:
         connection = self.active_connections.pop_by_origin(origin, http2_only=True)
         if connection is None:
             connection = self.keepalive_connections.pop_by_origin(origin)
 
         if connection is None:
-            await self.max_connections.acquire(timeout)
+            await self.max_connections.acquire()
             connection = HTTPConnection(
                 origin,
                 ssl=self.ssl,
diff --git a/httpcore/cookies.py b/httpcore/cookies.py
new file mode 100644 (file)
index 0000000..f6fd2b0
--- /dev/null
@@ -0,0 +1,18 @@
+import typing
+
+from .adapters import Adapter
+from .models import Request, Response
+
+
+class CookieAdapter(Adapter):
+    def __init__(self, dispatch: Adapter):
+        self.dispatch = dispatch
+
+    def prepare_request(self, request: Request) -> None:
+        self.dispatch.prepare_request(request)
+
+    async def send(self, request: Request, **options: typing.Any) -> Response:
+        return await self.dispatch.send(request, **options)
+
+    async def close(self) -> None:
+        self.dispatch.close()
diff --git a/httpcore/environment.py b/httpcore/environment.py
new file mode 100644 (file)
index 0000000..5065eed
--- /dev/null
@@ -0,0 +1,27 @@
+import typing
+
+from .adapters import Adapter
+from .models import Request, Response
+
+
+class EnvironmentAdapter(Adapter):
+    def __init__(self, dispatch: Adapter, trust_env: bool = True):
+        self.dispatch = dispatch
+        self.trust_env = trust_env
+
+    def prepare_request(self, request: Request) -> None:
+        self.dispatch.prepare_request(request)
+
+    async def send(self, request: Request, **options: typing.Any) -> Response:
+        if self.trust_env:
+            self.merge_environment_options(options)
+        return await self.dispatch.send(request, **options)
+
+    async def close(self) -> None:
+        await self.dispatch.close()
+
+    def merge_environment_options(self, options: dict) -> None:
+        """
+        Add environment options.
+        """
+        #  TODO
index 285b64039b743b4be96d28389aa7363ab2c082ed..94154b073e29b9e644e3595ba4514237f84177a5 100644 (file)
@@ -28,6 +28,12 @@ class PoolTimeout(Timeout):
     """
 
 
+class TooManyRedirects(Exception):
+    """
+    Too many redirects.
+    """
+
+
 class ProtocolError(Exception):
     """
     Malformed HTTP.
index 3fda6d1f0d30e0a41c4d6cbb65257c865a69c5c5..0075b524cead5ffb4c0addebd3dc8e0ec25d0034 100644 (file)
@@ -2,9 +2,10 @@ import typing
 
 import h11
 
+from .adapters import Adapter
 from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
 from .exceptions import ConnectTimeout, ReadTimeout
-from .models import Client, Origin, Request, Response
+from .models import Request, Response
 from .streams import BaseReader, BaseWriter
 
 H11Event = typing.Union[
@@ -25,35 +26,28 @@ OptionalTimeout = typing.Optional[TimeoutConfig]
 OnReleaseCallback = typing.Callable[[], typing.Awaitable[None]]
 
 
-class HTTP11Connection(Client):
+class HTTP11Connection(Adapter):
     READ_NUM_BYTES = 4096
 
     def __init__(
         self,
         reader: BaseReader,
         writer: BaseWriter,
-        origin: Origin,
-        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
         on_release: typing.Optional[OnReleaseCallback] = None,
     ):
         self.reader = reader
         self.writer = writer
-        self.origin = origin
-        self.timeout = timeout
         self.on_release = on_release
         self.h11_state = h11.Connection(our_role=h11.CLIENT)
 
-    @property
-    def is_closed(self) -> bool:
-        return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
+    def prepare_request(self, request: Request) -> None:
+        pass
+
+    async def send(self, request: Request, **options: typing.Any) -> Response:
+        timeout = options.get("timeout")
+        stream = options.get("stream", False)
+        assert timeout is None or isinstance(timeout, TimeoutConfig)
 
-    async def send(
-        self,
-        request: Request,
-        *,
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None
-    ) -> Response:
         #  Start sending the request.
         method = request.method.encode()
         target = request.url.full_path
@@ -81,7 +75,7 @@ class HTTP11Connection(Client):
         headers = event.headers
         body = self._body_iter(timeout)
 
-        return Response(
+        response = Response(
             status_code=status_code,
             reason=reason,
             protocol="HTTP/1.1",
@@ -90,6 +84,26 @@ class HTTP11Connection(Client):
             on_close=self.response_closed,
         )
 
+        if not stream:
+            try:
+                await response.read()
+            finally:
+                await response.close()
+
+        return response
+
+    async def close(self) -> None:
+        event = h11.ConnectionClosed()
+        try:
+            # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
+            self.h11_state.send(event)
+        except h11.ProtocolError:
+            # If we're in some other state then it's a premature close,
+            # and we'll end up in h11.ERROR.
+            pass
+
+        await self.writer.close()
+
     async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
         event = await self._receive_event(timeout)
         while isinstance(event, h11.Data):
@@ -123,14 +137,6 @@ class HTTP11Connection(Client):
         if self.on_release is not None:
             await self.on_release()
 
-    async def close(self) -> None:
-        event = h11.ConnectionClosed()
-        try:
-            # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
-            self.h11_state.send(event)
-        except h11.ProtocolError:
-            # If we're in some other state then it's a premature close,
-            # and we'll end up in h11.ERROR.
-            pass
-
-        await self.writer.close()
+    @property
+    def is_closed(self) -> bool:
+        return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
index c60490ce1203b6526e9fb0b9bbb6c317bb439d22..e89029cf9467cb8554c98d791a3eabb58a28fba0 100644 (file)
@@ -4,52 +4,39 @@ import typing
 import h2.connection
 import h2.events
 
+from .adapters import Adapter
 from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
 from .exceptions import ConnectTimeout, ReadTimeout
-from .models import Client, Origin, Request, Response
+from .models import Request, Response
 from .streams import BaseReader, BaseWriter
 
 OptionalTimeout = typing.Optional[TimeoutConfig]
 
 
-class HTTP2Connection(Client):
+class HTTP2Connection(Adapter):
     READ_NUM_BYTES = 4096
 
     def __init__(
-        self,
-        reader: BaseReader,
-        writer: BaseWriter,
-        origin: Origin,
-        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
-        on_release: typing.Callable = None,
+        self, reader: BaseReader, writer: BaseWriter, on_release: typing.Callable = None
     ):
         self.reader = reader
         self.writer = writer
-        self.origin = origin
-        self.timeout = timeout
         self.on_release = on_release
         self.h2_state = h2.connection.H2Connection()
         self.events = {}  # type: typing.Dict[int, typing.List[h2.events.Event]]
         self.initialized = False
 
-    @property
-    def is_closed(self) -> bool:
-        return False
+    def prepare_request(self, request: Request) -> None:
+        pass
 
-    async def send(
-        self,
-        request: Request,
-        *,
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None
-    ) -> Response:
-        if timeout is None:
-            timeout = self.timeout
+    async def send(self, request: Request, **options: typing.Any) -> Response:
+        timeout = options.get("timeout")
+        stream = options.get("stream", False)
+        assert timeout is None or isinstance(timeout, TimeoutConfig)
 
+        #  Start sending the request.
         if not self.initialized:
             self.initiate_connection()
-
-        #  Start sending the request.
         stream_id = await self.send_headers(request, timeout)
         self.events[stream_id] = []
 
@@ -77,7 +64,7 @@ class HTTP2Connection(Client):
         body = self.body_iter(stream_id, timeout)
         on_close = functools.partial(self.response_closed, stream_id=stream_id)
 
-        return Response(
+        response = Response(
             status_code=status_code,
             protocol="HTTP/2",
             headers=headers,
@@ -85,6 +72,17 @@ class HTTP2Connection(Client):
             on_close=on_close,
         )
 
+        if not stream:
+            try:
+                await response.read()
+            finally:
+                await response.close()
+
+        return response
+
+    async def close(self) -> None:
+        await self.writer.close()
+
     def initiate_connection(self) -> None:
         self.h2_state.initiate_connection()
         data_to_send = self.h2_state.data_to_send()
@@ -147,5 +145,6 @@ class HTTP2Connection(Client):
         if not self.events and self.on_release is not None:
             await self.on_release()
 
-    async def close(self) -> None:
-        await self.writer.close()
+    @property
+    def is_closed(self) -> bool:
+        return False
index e4a809af5b97c0e4c54eb522365262d4cc6fe14b..7229a76c18c8c0db594e407a61a2ec7623452829 100644 (file)
@@ -1,6 +1,5 @@
 import http
 import typing
-from types import TracebackType
 from urllib.parse import urlsplit
 
 from .config import SSLConfig, TimeoutConfig
@@ -237,47 +236,6 @@ class Response:
             if self.on_close is not None:
                 await self.on_close()
 
-
-class Client:
-    async def request(
-        self,
-        method: str,
-        url: typing.Union[str, URL],
-        *,
-        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
-        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None,
-        stream: bool = False,
-    ) -> Response:
-        request = Request(method, url, headers=headers, body=body)
-        response = await self.send(request, ssl=ssl, timeout=timeout)
-        if not stream:
-            try:
-                await response.read()
-            finally:
-                await response.close()
-        return response
-
-    async def send(
-        self,
-        request: Request,
-        *,
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None,
-    ) -> Response:
-        raise NotImplementedError()  # pragma: nocover
-
-    async def close(self) -> None:
-        raise NotImplementedError()  # pragma: nocover
-
-    async def __aenter__(self) -> "Client":
-        return self
-
-    async def __aexit__(
-        self,
-        exc_type: typing.Type[BaseException] = None,
-        exc_value: BaseException = None,
-        traceback: TracebackType = None,
-    ) -> None:
-        await self.close()
+    @property
+    def is_redirect(self) -> bool:
+        return self.status_code in (301, 302, 303, 307, 308)
diff --git a/httpcore/redirects.py b/httpcore/redirects.py
new file mode 100644 (file)
index 0000000..0657ebc
--- /dev/null
@@ -0,0 +1,35 @@
+import typing
+
+from .adapters import Adapter
+from .exceptions import TooManyRedirects
+from .models import Request, Response
+
+
+class RedirectAdapter(Adapter):
+    def __init__(self, dispatch: Adapter, max_redirects: int):
+        self.dispatch = dispatch
+        self.max_redirects = max_redirects
+
+    def prepare_request(self, request: Request) -> None:
+        self.dispatch.prepare_request(request)
+
+    async def send(self, request: Request, **options: typing.Any) -> Response:
+        allow_redirects = options.pop("allow_redirects", True)
+        history = []
+
+        while True:
+            response = await self.dispatch.send(request, **options)
+            if not allow_redirects or not response.is_redirect:
+                break
+            history.append(response)
+            if len(history) > self.max_redirects:
+                raise TooManyRedirects()
+            request = self.build_redirect_request(request, response)
+
+        return response
+
+    async def close(self) -> None:
+        self.dispatch.close()
+
+    def build_redirect_request(self, request: Request, response: Response) -> Request:
+        raise NotImplementedError()
index e46ffee4856174b8011b136f77d25325e6c6a724..03bba17ca6719ac74a9b2d0e8b47129b16bb0b3a 100644 (file)
@@ -41,10 +41,7 @@ class BaseWriter:
 
 
 class BasePoolSemaphore:
-    def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
-        raise NotImplementedError()  # pragma: no cover
-
-    async def acquire(self, timeout: OptionalTimeout = None) -> None:
+    async def acquire(self) -> None:
         raise NotImplementedError()  # pragma: no cover
 
     def release(self) -> None:
@@ -100,9 +97,8 @@ class Writer(BaseWriter):
 
 
 class PoolSemaphore(BasePoolSemaphore):
-    def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
+    def __init__(self, limits: PoolLimits):
         self.limits = limits
-        self.timeout = timeout
 
     @property
     def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
@@ -114,15 +110,13 @@ class PoolSemaphore(BasePoolSemaphore):
                 self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
         return self._semaphore
 
-    async def acquire(self, timeout: OptionalTimeout = None) -> None:
+    async def acquire(self) -> None:
         if self.semaphore is None:
             return
 
-        if timeout is None:
-            timeout = self.timeout
-
+        timeout = self.limits.pool_timeout
         try:
-            await asyncio.wait_for(self.semaphore.acquire(), timeout.pool_timeout)
+            await asyncio.wait_for(self.semaphore.acquire(), timeout)
         except asyncio.TimeoutError:
             raise PoolTimeout()
 
index b1f98f50e128ff83cd69b1fb3bc2dc1b48a552c7..737d3fcfe2e8c61557d18895ce60f55067a18c38 100644 (file)
@@ -2,9 +2,10 @@ import asyncio
 import typing
 from types import TracebackType
 
+from .adapters import Adapter
 from .config import SSLConfig, TimeoutConfig
 from .connection_pool import ConnectionPool
-from .models import URL, Client, Response
+from .models import URL, Response
 
 
 class SyncResponse:
@@ -44,8 +45,8 @@ class SyncResponse:
 
 
 class SyncClient:
-    def __init__(self, client: Client):
-        self._client = client
+    def __init__(self, adapter: Adapter):
+        self._client = adapter
         self._loop = asyncio.new_event_loop()
 
     def request(
@@ -55,20 +56,10 @@ class SyncClient:
         *,
         headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
-        ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None,
-        stream: bool = False,
+        **options: typing.Any
     ) -> SyncResponse:
         response = self._loop.run_until_complete(
-            self._client.request(
-                method,
-                url,
-                headers=headers,
-                body=body,
-                ssl=ssl,
-                timeout=timeout,
-                stream=stream,
-            )
+            self._client.request(method, url, headers=headers, body=body, **options)
         )
         return SyncResponse(response, self._loop)
 
index 6b80587dbc4f475f27175ba9175093c06b257431..4622849b561fe225a0c4ad54c273e52f7089bf15 100644 (file)
@@ -5,16 +5,16 @@ import httpcore
 
 @pytest.mark.asyncio
 async def test_get(server):
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/")
+    async with httpcore.Client() as client:
+        response = await client.request("GET", "http://127.0.0.1:8000/")
     assert response.status_code == 200
     assert response.body == b"Hello, world!"
 
 
 @pytest.mark.asyncio
 async def test_post(server):
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request(
+    async with httpcore.Client() as client:
+        response = await client.request(
             "POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
         )
     assert response.status_code == 200
@@ -22,8 +22,8 @@ async def test_post(server):
 
 @pytest.mark.asyncio
 async def test_stream_response(server):
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+    async with httpcore.Client() as client:
+        response = await client.request("GET", "http://127.0.0.1:8000/", stream=True)
     assert response.status_code == 200
     assert not hasattr(response, "body")
     body = await response.read()
@@ -36,8 +36,8 @@ async def test_stream_request(server):
         yield b"Hello, "
         yield b"world!"
 
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request(
+    async with httpcore.Client() as client:
+        response = await client.request(
             "POST", "http://127.0.0.1:8000/", body=hello_world()
         )
     assert response.status_code == 200
index 74c1267ea179586cc0fe4dfa4f1f1bbb1e815677..e4ce64a46c3ebc9bbe0292e73cdd418b164aed56 100644 (file)
@@ -13,13 +13,15 @@ def test_timeout_repr():
     timeout = httpcore.TimeoutConfig(read_timeout=5.0)
     assert (
         repr(timeout)
-        == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None, pool_timeout=None)"
+        == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None)"
     )
 
 
 def test_limits_repr():
     limits = httpcore.PoolLimits(hard_limit=100)
-    assert repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100)"
+    assert (
+        repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100, pool_timeout=None)"
+    )
 
 
 def test_ssl_eq():
index f17d0f98225ab3f63f142678bc16288e1820d23e..dc287bce333784bbb152ae80203657b1dc85dc71 100644 (file)
@@ -79,11 +79,8 @@ class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
 @pytest.mark.asyncio
 async def test_http2_get_request():
     server = MockServer()
-    origin = httpcore.Origin("http://example.org")
-    async with httpcore.HTTP2Connection(
-        reader=server, writer=server, origin=origin
-    ) as client:
-        response = await client.request("GET", "http://example.org")
+    async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
+        response = await conn.request("GET", "http://example.org")
     assert response.status_code == 200
     assert json.loads(response.body) == {"method": "GET", "path": "/", "body": ""}
 
@@ -91,11 +88,8 @@ async def test_http2_get_request():
 @pytest.mark.asyncio
 async def test_http2_post_request():
     server = MockServer()
-    origin = httpcore.Origin("http://example.org")
-    async with httpcore.HTTP2Connection(
-        reader=server, writer=server, origin=origin
-    ) as client:
-        response = await client.request("POST", "http://example.org", body=b"<data>")
+    async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
+        response = await conn.request("POST", "http://example.org", body=b"<data>")
     assert response.status_code == 200
     assert json.loads(response.body) == {
         "method": "POST",
@@ -107,13 +101,10 @@ async def test_http2_post_request():
 @pytest.mark.asyncio
 async def test_http2_multiple_requests():
     server = MockServer()
-    origin = httpcore.Origin("http://example.org")
-    async with httpcore.HTTP2Connection(
-        reader=server, writer=server, origin=origin
-    ) as client:
-        response_1 = await client.request("GET", "http://example.org/1")
-        response_2 = await client.request("GET", "http://example.org/2")
-        response_3 = await client.request("GET", "http://example.org/3")
+    async with httpcore.HTTP2Connection(reader=server, writer=server) as conn:
+        response_1 = await conn.request("GET", "http://example.org/1")
+        response_2 = await conn.request("GET", "http://example.org/2")
+        response_3 = await conn.request("GET", "http://example.org/3")
 
     assert response_1.status_code == 200
     assert json.loads(response_1.body) == {"method": "GET", "path": "/1", "body": ""}
index b1ceef93d455cc0fd0f1a5a73bcb4432c7fae006..d91cb799494462da180d47ffb7422ce6d3fc63e7 100644 (file)
@@ -24,10 +24,9 @@ async def test_connect_timeout(server):
 
 @pytest.mark.asyncio
 async def test_pool_timeout(server):
-    timeout = httpcore.TimeoutConfig(pool_timeout=0.0001)
-    limits = httpcore.PoolLimits(hard_limit=1)
+    limits = httpcore.PoolLimits(hard_limit=1, pool_timeout=0.0001)
 
-    async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as http:
+    async with httpcore.ConnectionPool(limits=limits) as http:
         response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
 
         with pytest.raises(httpcore.PoolTimeout):