]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Refactor to request = Client.send(request)
authorTom Christie <tom@tomchristie.com>
Mon, 22 Apr 2019 16:33:43 +0000 (17:33 +0100)
committerTom Christie <tom@tomchristie.com>
Mon, 22 Apr 2019 16:33:43 +0000 (17:33 +0100)
httpcore/__init__.py
httpcore/config.py
httpcore/connections.py
httpcore/datastructures.py
httpcore/pool.py
httpcore/sync.py [new file with mode: 0644]
tests/test_api.py
tests/test_connections.py [new file with mode: 0644]
tests/test_pool.py
tests/test_responses.py
tests/test_timeouts.py

index d4f6465f8f8d4f235e29c3150fb24ddf2d21045a..32ff66d6cb1f6cd447b440de22ad80baef357d0c 100644 (file)
@@ -1,5 +1,6 @@
 from .config import PoolLimits, SSLConfig, TimeoutConfig
-from .datastructures import URL, Request, Response
+from .connections import Connection
+from .datastructures import URL, Origin, Request, Response
 from .exceptions import (
     BadResponse,
     ConnectTimeout,
index f694fc9a5e289168260a928a7e4488e0cd95b5db..8cc784b3f78197dc6a8b6c806544f88b5a45e7e9 100644 (file)
@@ -1,3 +1,5 @@
+import asyncio
+import os
 import ssl
 import typing
 
index 16f91408a32fca23615eb0c4b0a29ad927f0a093..dadcfd90cb1e4f278042aea04a4bc8078aa00d41 100644 (file)
@@ -1,11 +1,10 @@
 import asyncio
-import ssl
 import typing
 
 import h11
 
-from .config import TimeoutConfig
-from .datastructures import Request, Response
+from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
+from .datastructures import Client, Origin, Request, Response
 from .exceptions import ConnectTimeout, ReadTimeout
 
 H11Event = typing.Union[
@@ -19,34 +18,47 @@ H11Event = typing.Union[
 
 
 class Connection:
-    def __init__(self, timeout: TimeoutConfig, on_release: typing.Callable = None):
-        self.reader = None
-        self.writer = None
-        self.state = h11.Connection(our_role=h11.CLIENT)
+    def __init__(
+        self,
+        origin: typing.Union[str, Origin],
+        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
+        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+        on_release: typing.Callable = None,
+    ):
+        self.origin = Origin(origin) if isinstance(origin, str) else origin
+        self.ssl = ssl
         self.timeout = timeout
         self.on_release = on_release
+        self._reader = None
+        self._writer = None
+        self._h11_state = h11.Connection(our_role=h11.CLIENT)
 
     @property
     def is_closed(self) -> bool:
-        return self.state.our_state in (h11.CLOSED, h11.ERROR)
+        return self._h11_state.our_state in (h11.CLOSED, h11.ERROR)
+
+    async def send(
+        self,
+        request: Request,
+        ssl: typing.Optional[SSLConfig] = None,
+        timeout: typing.Optional[TimeoutConfig] = None,
+        stream: bool = False,
+    ) -> Response:
+        assert request.url.origin == self.origin
+
+        if ssl is None:
+            ssl = self.ssl
+        if timeout is None:
+            timeout = self.timeout
+
+        # Make the connection
+        if self._reader is None:
+            await self._connect(ssl, timeout)
 
-    async def open(
-        self, hostname: str, port: int, *, ssl: typing.Optional[ssl.SSLContext] = None
-    ) -> None:
-        try:
-            self.reader, self.writer = await asyncio.wait_for(  # type: ignore
-                asyncio.open_connection(hostname, port, ssl=ssl),
-                self.timeout.connect_timeout,
-            )
-        except asyncio.TimeoutError:
-            raise ConnectTimeout()
-
-    async def send(self, request: Request) -> Response:
+        #  Start sending the request.
         method = request.method.encode()
         target = request.url.target
         headers = request.headers
-
-        #  Start sending the request.
         event = h11.Request(method=method, target=target, headers=headers)
         await self._send_event(event)
 
@@ -64,69 +76,97 @@ class Connection:
         await self._send_event(event)
 
         # Start getting the response.
-        event = await self._receive_event()
+        event = await self._receive_event(timeout)
         if isinstance(event, h11.InformationalResponse):
-            event = await self._receive_event()
+            event = await self._receive_event(timeout)
         assert isinstance(event, h11.Response)
-        reason = event.reason.decode('latin1')
+        reason = event.reason.decode("latin1")
         status_code = event.status_code
         headers = event.headers
-        body = self._body_iter()
-        return Response(
-            status_code=status_code, reason=reason, headers=headers, body=body, on_close=self._release
+        body = self._body_iter(timeout)
+        response = Response(
+            status_code=status_code,
+            reason=reason,
+            headers=headers,
+            body=body,
+            on_close=self._release,
         )
 
-    async def _body_iter(self) -> typing.AsyncIterator[bytes]:
-        event = await self._receive_event()
+        if not stream:
+            # Read the response body.
+            try:
+                await response.read()
+            finally:
+                await response.close()
+
+        return response
+
+    async def _connect(self, ssl: SSLConfig, timeout: TimeoutConfig) -> None:
+        ssl_context = await ssl.load_ssl_context() if self.origin.is_secure else None
+
+        try:
+            self._reader, self._writer = await asyncio.wait_for(  # type: ignore
+                asyncio.open_connection(
+                    self.origin.hostname, self.origin.port, ssl=ssl_context
+                ),
+                timeout.connect_timeout,
+            )
+        except asyncio.TimeoutError:
+            raise ConnectTimeout()
+
+    async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
+        event = await self._receive_event(timeout)
         while isinstance(event, h11.Data):
             yield event.data
-            event = await self._receive_event()
+            event = await self._receive_event(timeout)
         assert isinstance(event, h11.EndOfMessage)
 
     async def _send_event(self, event: H11Event) -> None:
-        assert self.writer is not None
+        assert self._writer is not None
 
-        data = self.state.send(event)
-        self.writer.write(data)
+        data = self._h11_state.send(event)
+        self._writer.write(data)
 
-    async def _receive_event(self) -> H11Event:
-        assert self.reader is not None
+    async def _receive_event(self, timeout: TimeoutConfig) -> H11Event:
+        assert self._reader is not None
 
-        event = self.state.next_event()
+        event = self._h11_state.next_event()
 
         while event is h11.NEED_DATA:
             try:
                 data = await asyncio.wait_for(
-                    self.reader.read(2048), self.timeout.read_timeout
+                    self._reader.read(2048), timeout.read_timeout
                 )
             except asyncio.TimeoutError:
                 raise ReadTimeout()
-            self.state.receive_data(data)
-            event = self.state.next_event()
+            self._h11_state.receive_data(data)
+            event = self._h11_state.next_event()
 
         return event
 
     async def _release(self) -> None:
-        assert self.writer is not None
+        assert self._writer is not None
 
-        if self.state.our_state is h11.DONE and self.state.their_state is h11.DONE:
-            self.state.start_next_cycle()
+        if (
+            self._h11_state.our_state is h11.DONE
+            and self._h11_state.their_state is h11.DONE
+        ):
+            self._h11_state.start_next_cycle()
         else:
-            self.close()
+            await self.close()
 
         if self.on_release is not None:
             await self.on_release(self)
 
-    def close(self) -> None:
-        assert self.writer is not None
-
+    async def close(self) -> None:
         event = h11.ConnectionClosed()
         try:
             # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
-            self.state.send(event)
+            self._h11_state.send(event)
         except h11.ProtocolError:
             # If we're in some other state then it's a premature close,
             # and we'll end up in h11.ERROR.
             pass
 
-        self.writer.close()
+        if self._writer is not None:
+            self._writer.close()
index 49f85bdc96fc7ef546c060404bf9d10efd8a1c95..d6cadc422ad4c0dc81b8ff0f420a0831e442e76f 100644 (file)
@@ -61,6 +61,34 @@ class URL:
     def is_secure(self) -> bool:
         return self.components.scheme == "https"
 
+    @property
+    def origin(self) -> "Origin":
+        return Origin(self)
+
+
+class Origin:
+    def __init__(self, url: typing.Union[str, URL]) -> None:
+        if isinstance(url, str):
+            url = URL(url)
+        self.scheme = url.scheme
+        self.hostname = url.hostname
+        self.port = url.port
+
+    @property
+    def is_secure(self) -> bool:
+        return self.scheme == "https"
+
+    def __eq__(self, other: typing.Any) -> bool:
+        return (
+            isinstance(other, self.__class__)
+            and self.scheme == other.scheme
+            and self.hostname == other.hostname
+            and self.port == other.port
+        )
+
+    def __hash__(self) -> int:
+        return hash((self.scheme, self.hostname, self.port))
+
 
 class Request:
     def __init__(
@@ -207,3 +235,11 @@ class Response:
             self.is_closed = True
             if self.on_close is not None:
                 await self.on_close()
+
+
+class Client:
+    async def send(self, request: Request, **options: typing.Any) -> Response:
+        raise NotImplementedError()  # pragma: nocover
+
+    async def close(self) -> None:
+        raise NotImplementedError()  # pragma: nocover
index f09e01e04019a6cce3b9fa4be60f8a34f4f46e76..1a8466288667206665f0b3df2ce4fea6017fbf64 100644 (file)
@@ -1,7 +1,4 @@
 import asyncio
-import functools
-import os
-import ssl
 import typing
 from types import TracebackType
 
@@ -15,11 +12,9 @@ from .config import (
     TimeoutConfig,
 )
 from .connections import Connection
-from .datastructures import URL, Request, Response
+from .datastructures import Client, Origin, Request, Response
 from .exceptions import PoolTimeout
 
-ConnectionKey = typing.Tuple[str, str, int, SSLConfig, TimeoutConfig]
-
 
 class ConnectionSemaphore:
     def __init__(self, max_connections: int = None):
@@ -43,7 +38,7 @@ class ConnectionPool:
         timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
         limits: PoolLimits = DEFAULT_POOL_LIMITS,
     ):
-        self.ssl_config = ssl
+        self.ssl = ssl
         self.timeout = timeout
         self.limits = limits
         self.is_closed = False
@@ -51,36 +46,22 @@ class ConnectionPool:
         self.num_keepalive_connections = 0
         self._keepalive_connections = (
             {}
-        )  # type: typing.Dict[ConnectionKey, typing.List[Connection]]
+        )  # type: typing.Dict[Origin, typing.List[Connection]]
         self._max_connections = ConnectionSemaphore(
             max_connections=self.limits.hard_limit
         )
 
-    async def request(
+    async def send(
         self,
-        method: str,
-        url: str,
-        *,
-        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
-        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
-        stream: bool = False,
+        request: Request,
         ssl: typing.Optional[SSLConfig] = None,
         timeout: typing.Optional[TimeoutConfig] = None,
+        stream: bool = False,
     ) -> Response:
-        if ssl is None:
-            ssl = self.ssl_config
-        if timeout is None:
-            timeout = self.timeout
-
-        parsed_url = URL(url)
-        request = Request(method, parsed_url, headers=headers, body=body)
-        connection = await self.acquire_connection(parsed_url, ssl=ssl, timeout=timeout)
-        response = await connection.send(request)
-        if not stream:
-            try:
-                await response.read()
-            finally:
-                await response.close()
+        connection = await self.acquire_connection(request.url.origin, timeout=timeout)
+        response = await connection.send(
+            request, ssl=ssl, timeout=timeout, stream=stream
+        )
         return response
 
     @property
@@ -88,38 +69,36 @@ class ConnectionPool:
         return self.num_active_connections + self.num_keepalive_connections
 
     async def acquire_connection(
-        self, url: URL, ssl: SSLConfig, timeout: TimeoutConfig
+        self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
     ) -> Connection:
-        key = (url.scheme, url.hostname, url.port, ssl, timeout)
         try:
-            connection = self._keepalive_connections[key].pop()
-            if not self._keepalive_connections[key]:
-                del self._keepalive_connections[key]
+            connection = self._keepalive_connections[origin].pop()
+            if not self._keepalive_connections[origin]:
+                del self._keepalive_connections[origin]
             self.num_keepalive_connections -= 1
             self.num_active_connections += 1
 
         except (KeyError, IndexError):
-            if url.is_secure:
-                ssl_context = await ssl.load_ssl_context()
+            if timeout is None:
+                pool_timeout = self.timeout.pool_timeout
             else:
-                ssl_context = None
+                pool_timeout = timeout.pool_timeout
 
             try:
-                await asyncio.wait_for(
-                    self._max_connections.acquire(), timeout.pool_timeout
-                )
+                await asyncio.wait_for(self._max_connections.acquire(), pool_timeout)
             except asyncio.TimeoutError:
                 raise PoolTimeout()
-            release = functools.partial(self.release_connection, key=key)
-            connection = Connection(timeout=timeout, on_release=release)
+            connection = Connection(
+                origin,
+                ssl=self.ssl,
+                timeout=self.timeout,
+                on_release=self.release_connection,
+            )
             self.num_active_connections += 1
-            await connection.open(url.hostname, url.port, ssl=ssl_context)
 
         return connection
 
-    async def release_connection(
-        self, connection: Connection, key: ConnectionKey
-    ) -> None:
+    async def release_connection(self, connection: Connection) -> None:
         if connection.is_closed:
             self._max_connections.release()
             self.num_active_connections -= 1
@@ -129,14 +108,14 @@ class ConnectionPool:
         ):
             self._max_connections.release()
             self.num_active_connections -= 1
-            connection.close()
+            await connection.close()
         else:
             self.num_active_connections -= 1
             self.num_keepalive_connections += 1
             try:
-                self._keepalive_connections[key].append(connection)
+                self._keepalive_connections[connection.origin].append(connection)
             except KeyError:
-                self._keepalive_connections[key] = [connection]
+                self._keepalive_connections[connection.origin] = [connection]
 
     async def close(self) -> None:
         self.is_closed = True
diff --git a/httpcore/sync.py b/httpcore/sync.py
new file mode 100644 (file)
index 0000000..e69de29
index 6b80587dbc4f475f27175ba9175093c06b257431..30199c93de4641264540bd028e5a4d714cec9044 100644 (file)
@@ -5,25 +5,28 @@ import httpcore
 
 @pytest.mark.asyncio
 async def test_get(server):
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/")
+    async with httpcore.ConnectionPool() as client:
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response = await client.send(request)
     assert response.status_code == 200
     assert response.body == b"Hello, world!"
 
 
 @pytest.mark.asyncio
 async def test_post(server):
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request(
+    async with httpcore.ConnectionPool() as client:
+        request = httpcore.Request(
             "POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
         )
+        response = await client.send(request)
     assert response.status_code == 200
 
 
 @pytest.mark.asyncio
 async def test_stream_response(server):
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+    async with httpcore.ConnectionPool() as client:
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response = await client.send(request, stream=True)
     assert response.status_code == 200
     assert not hasattr(response, "body")
     body = await response.read()
@@ -36,8 +39,7 @@ async def test_stream_request(server):
         yield b"Hello, "
         yield b"world!"
 
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request(
-            "POST", "http://127.0.0.1:8000/", body=hello_world()
-        )
+    async with httpcore.ConnectionPool() as client:
+        request = httpcore.Request("POST", "http://127.0.0.1:8000/", body=hello_world())
+        response = await client.send(request)
     assert response.status_code == 200
diff --git a/tests/test_connections.py b/tests/test_connections.py
new file mode 100644 (file)
index 0000000..958c1b5
--- /dev/null
@@ -0,0 +1,22 @@
+import pytest
+
+import httpcore
+
+
+@pytest.mark.asyncio
+async def test_get(server):
+    client = httpcore.Connection(origin="http://127.0.0.1:8000/")
+    request = httpcore.Request(method="GET", url="http://127.0.0.1:8000/")
+    response = await client.send(request)
+    assert response.status_code == 200
+    assert response.body == b"Hello, world!"
+
+
+@pytest.mark.asyncio
+async def test_post(server):
+    client = httpcore.Connection(origin="http://127.0.0.1:8000/")
+    request = httpcore.Request(
+        method="POST", url="http://127.0.0.1:8000/", body=b"Hello, world!"
+    )
+    response = await client.send(request)
+    assert response.status_code == 200
index 77a221575441fda934bd4799ed8a4a33b00726f0..de25fe65a792e4d206865459618754dc407ef38b 100644 (file)
@@ -8,14 +8,16 @@ async def test_keepalive_connections(server):
     """
     Connections should default to staying in a keep-alive state.
     """
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+    async with httpcore.ConnectionPool() as client:
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response = await client.send(request)
+        assert client.num_active_connections == 0
+        assert client.num_keepalive_connections == 1
 
-        response = await http.request("GET", "http://127.0.0.1:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response = await client.send(request)
+        assert client.num_active_connections == 0
+        assert client.num_keepalive_connections == 1
 
 
 @pytest.mark.asyncio
@@ -23,14 +25,16 @@ async def test_differing_connection_keys(server):
     """
     Connnections to differing connection keys should result in multiple connections.
     """
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+    async with httpcore.ConnectionPool() as client:
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response = await client.send(request)
+        assert client.num_active_connections == 0
+        assert client.num_keepalive_connections == 1
 
-        response = await http.request("GET", "http://localhost:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 2
+        request = httpcore.Request("GET", "http://localhost:8000/")
+        response = await client.send(request)
+        assert client.num_active_connections == 0
+        assert client.num_keepalive_connections == 2
 
 
 @pytest.mark.asyncio
@@ -40,14 +44,16 @@ async def test_soft_limit(server):
     """
     limits = httpcore.PoolLimits(soft_limit=1)
 
-    async with httpcore.ConnectionPool(limits=limits) as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+    async with httpcore.ConnectionPool(limits=limits) as client:
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response = await client.send(request)
+        assert client.num_active_connections == 0
+        assert client.num_keepalive_connections == 1
 
-        response = await http.request("GET", "http://localhost:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+        request = httpcore.Request("GET", "http://localhost:8000/")
+        response = await client.send(request)
+        assert client.num_active_connections == 0
+        assert client.num_keepalive_connections == 1
 
 
 @pytest.mark.asyncio
@@ -55,15 +61,16 @@ async def test_streaming_response_holds_connection(server):
     """
     A streaming request should hold the connection open until the response is read.
     """
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
-        assert http.num_active_connections == 1
-        assert http.num_keepalive_connections == 0
+    async with httpcore.ConnectionPool() as client:
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response = await client.send(request, stream=True)
+        assert client.num_active_connections == 1
+        assert client.num_keepalive_connections == 0
 
         await response.read()
 
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+        assert client.num_active_connections == 0
+        assert client.num_keepalive_connections == 1
 
 
 @pytest.mark.asyncio
@@ -71,22 +78,24 @@ async def test_multiple_concurrent_connections(server):
     """
     Multiple conncurrent requests should open multiple conncurrent connections.
     """
-    async with httpcore.ConnectionPool() as http:
-        response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
-        assert http.num_active_connections == 1
-        assert http.num_keepalive_connections == 0
+    async with httpcore.ConnectionPool() as client:
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response_a = await client.send(request, stream=True)
+        assert client.num_active_connections == 1
+        assert client.num_keepalive_connections == 0
 
-        response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
-        assert http.num_active_connections == 2
-        assert http.num_keepalive_connections == 0
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response_b = await client.send(request, stream=True)
+        assert client.num_active_connections == 2
+        assert client.num_keepalive_connections == 0
 
         await response_b.read()
-        assert http.num_active_connections == 1
-        assert http.num_keepalive_connections == 1
+        assert client.num_active_connections == 1
+        assert client.num_keepalive_connections == 1
 
         await response_a.read()
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 2
+        assert client.num_active_connections == 0
+        assert client.num_keepalive_connections == 2
 
 
 @pytest.mark.asyncio
@@ -95,10 +104,11 @@ async def test_close_connections(server):
     Using a `Connection: close` header should close the connection.
     """
     headers = [(b"connection", b"close")]
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 0
+    async with httpcore.ConnectionPool() as client:
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/", headers=headers)
+        response = await client.send(request)
+        assert client.num_active_connections == 0
+        assert client.num_keepalive_connections == 0
 
 
 @pytest.mark.asyncio
@@ -106,12 +116,13 @@ async def test_standard_response_close(server):
     """
     A standard close should keep the connection open.
     """
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+    async with httpcore.ConnectionPool() as client:
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response = await client.send(request, stream=True)
         await response.read()
         await response.close()
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+        assert client.num_active_connections == 0
+        assert client.num_keepalive_connections == 1
 
 
 @pytest.mark.asyncio
@@ -119,8 +130,9 @@ async def test_premature_response_close(server):
     """
     A premature close should close the connection.
     """
-    async with httpcore.ConnectionPool() as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+    async with httpcore.ConnectionPool() as client:
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response = await client.send(request, stream=True)
         await response.close()
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 0
+        assert client.num_active_connections == 0
+        assert client.num_keepalive_connections == 0
index 3efef890b7a88ec97b593953ad9019d4a8a3363f..bb930bdb204c21ff75259821024045740696d4f8 100644 (file)
@@ -3,26 +3,13 @@ import pytest
 import httpcore
 
 
-class MockHTTP(httpcore.ConnectionPool):
-    async def request(
-        self, method, url, *, headers=(), body=b"", stream=False
-    ) -> httpcore.Response:
-        if stream:
+async def streaming_body():
+    yield b"Hello, "
+    yield b"world!"
 
-            async def streaming_body():
-                yield b"Hello, "
-                yield b"world!"
 
-            return httpcore.Response(200, body=streaming_body())
-        return httpcore.Response(200, body=b"Hello, world!")
-
-
-http = MockHTTP()
-
-
-@pytest.mark.asyncio
-async def test_request():
-    response = await http.request("GET", "http://example.com")
+def test_response():
+    response = httpcore.Response(200, body=b"Hello, world!")
     assert response.status_code == 200
     assert response.reason == "OK"
     assert response.body == b"Hello, world!"
@@ -31,7 +18,7 @@ async def test_request():
 
 @pytest.mark.asyncio
 async def test_read_response():
-    response = await http.request("GET", "http://example.com")
+    response = httpcore.Response(200, body=b"Hello, world!")
 
     assert response.status_code == 200
     assert response.body == b"Hello, world!"
@@ -45,25 +32,8 @@ async def test_read_response():
 
 
 @pytest.mark.asyncio
-async def test_stream_response():
-    response = await http.request("GET", "http://example.com")
-
-    assert response.status_code == 200
-    assert response.body == b"Hello, world!"
-    assert response.is_closed
-
-    body = b""
-    async for part in response.stream():
-        body += part
-
-    assert body == b"Hello, world!"
-    assert response.body == b"Hello, world!"
-    assert response.is_closed
-
-
-@pytest.mark.asyncio
-async def test_read_streaming_response():
-    response = await http.request("GET", "http://example.com", stream=True)
+async def test_streaming_response():
+    response = httpcore.Response(200, body=streaming_body())
 
     assert response.status_code == 200
     assert not hasattr(response, "body")
@@ -76,26 +46,9 @@ async def test_read_streaming_response():
     assert response.is_closed
 
 
-@pytest.mark.asyncio
-async def test_stream_streaming_response():
-    response = await http.request("GET", "http://example.com", stream=True)
-
-    assert response.status_code == 200
-    assert not hasattr(response, "body")
-    assert not response.is_closed
-
-    body = b""
-    async for part in response.stream():
-        body += part
-
-    assert body == b"Hello, world!"
-    assert not hasattr(response, "body")
-    assert response.is_closed
-
-
 @pytest.mark.asyncio
 async def test_cannot_read_after_stream_consumed():
-    response = await http.request("GET", "http://example.com", stream=True)
+    response = httpcore.Response(200, body=streaming_body())
 
     body = b""
     async for part in response.stream():
@@ -107,7 +60,7 @@ async def test_cannot_read_after_stream_consumed():
 
 @pytest.mark.asyncio
 async def test_cannot_read_after_response_closed():
-    response = await http.request("GET", "http://example.com", stream=True)
+    response = httpcore.Response(200, body=streaming_body())
 
     await response.close()
 
index b1ceef93d455cc0fd0f1a5a73bcb4432c7fae006..5b61aee21c0403f43a6aaddb1c039fa9ad7dffef 100644 (file)
@@ -7,19 +7,21 @@ import httpcore
 async def test_read_timeout(server):
     timeout = httpcore.TimeoutConfig(read_timeout=0.0001)
 
-    async with httpcore.ConnectionPool(timeout=timeout) as http:
+    async with httpcore.ConnectionPool(timeout=timeout) as client:
         with pytest.raises(httpcore.ReadTimeout):
-            await http.request("GET", "http://127.0.0.1:8000/slow_response")
+            request = httpcore.Request("GET", "http://127.0.0.1:8000/slow_response")
+            await client.send(request)
 
 
 @pytest.mark.asyncio
 async def test_connect_timeout(server):
     timeout = httpcore.TimeoutConfig(connect_timeout=0.0001)
 
-    async with httpcore.ConnectionPool(timeout=timeout) as http:
+    async with httpcore.ConnectionPool(timeout=timeout) as client:
         with pytest.raises(httpcore.ConnectTimeout):
             # See https://stackoverflow.com/questions/100841/
-            await http.request("GET", "http://10.255.255.1/")
+            request = httpcore.Request("GET", "http://10.255.255.1/")
+            await client.send(request)
 
 
 @pytest.mark.asyncio
@@ -27,10 +29,12 @@ async def test_pool_timeout(server):
     timeout = httpcore.TimeoutConfig(pool_timeout=0.0001)
     limits = httpcore.PoolLimits(hard_limit=1)
 
-    async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as http:
-        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+    async with httpcore.ConnectionPool(timeout=timeout, limits=limits) as client:
+        request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+        response = await client.send(request, stream=True)
 
         with pytest.raises(httpcore.PoolTimeout):
-            await http.request("GET", "http://localhost:8000/")
+            request = httpcore.Request("GET", "http://127.0.0.1:8000/")
+            await client.send(request)
 
         await response.read()