]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Connections
authorTom Christie <tom@tomchristie.com>
Sat, 6 Apr 2019 12:18:39 +0000 (13:18 +0100)
committerTom Christie <tom@tomchristie.com>
Sat, 6 Apr 2019 12:18:39 +0000 (13:18 +0100)
13 files changed:
README.md
httpcore/__init__.py
httpcore/api.py [deleted file]
httpcore/config.py
httpcore/connections.py
httpcore/datastructures.py [new file with mode: 0644]
httpcore/decoders.py
httpcore/models.py [deleted file]
httpcore/pool.py [new file with mode: 0644]
requirements.txt
tests/conftest.py [new file with mode: 0644]
tests/test_api.py
tests/test_responses.py

index c6b8d257175e47c8018db26812506398b0e1cad4..63eb03ddb79b37c0d791f2524e7bc7e3349d38a8 100644 (file)
--- a/README.md
+++ b/README.md
@@ -63,7 +63,8 @@ of it, and exposes only plain datastructures that reflect the network response.
 ```python
 import httpcore
 
-response = await httpcore.request('GET', 'http://example.com')
+http = httpcore.ConnectionPool()
+response = await http.request('GET', 'http://example.com')
 assert response.status_code == 200
 assert response.body == b'Hello, world'
 ```
@@ -71,20 +72,22 @@ assert response.body == b'Hello, world'
 Top-level API...
 
 ```python
-response = await httpcore.request(method, url, [headers], [body], [stream])
+http = httpcore.ConnectionPool([ssl], [timeout], [limits])
+response = await http.request(method, url, [headers], [body], [stream])
 ```
 
-Explicit PoolManager...
+ConnectionPool as a context-manager...
 
 ```python
-async with httpcore.PoolManager([ssl], [timeout], [limits]) as pool:
-    response = await pool.request(method, url, [headers], [body], [stream])
+async with httpcore.ConnectionPool([ssl], [timeout], [limits]) as http:
+    response = await http.request(method, url, [headers], [body], [stream])
 ```
 
 Streaming...
 
 ```python
-response = await httpcore.request(method, url, stream=True)
+http = httpcore.ConnectionPool()
+response = await http.request(method, url, stream=True)
 async for part in response.stream():
     ...
 ```
@@ -100,7 +103,7 @@ import httpcore
 class GatewayServer:
     def __init__(self, base_url):
         self.base_url = base_url
-        self.pool = httpcore.PoolManager()
+        self.http = httpcore.ConnectionPool()
 
     async def __call__(self, scope, receive, send):
         assert scope['type'] == 'http'
@@ -122,7 +125,7 @@ class GatewayServer:
                 if not message.get('more_body', False):
                     break
 
-        response = await self.pool.request(
+        response = await self.http.request(
             method, url, headers=headers, body=body, stream=True
         )
 
index 69894f361a383a753b3c70c024915ddf274db194..24a6bbfb26cff3a4e603dbdbd278f1ada8d896d3 100644 (file)
@@ -1,5 +1,6 @@
-from .api import PoolManager, Response, request
 from .config import PoolLimits, SSLConfig, TimeoutConfig
+from .datastructures import URL, Request, Response
 from .exceptions import ResponseClosed, StreamConsumed
+from .pool import ConnectionPool
 
 __version__ = "0.0.2"
diff --git a/httpcore/api.py b/httpcore/api.py
deleted file mode 100644 (file)
index 24c4fec..0000000
+++ /dev/null
@@ -1,67 +0,0 @@
-import typing
-from types import TracebackType
-
-from .config import (
-    DEFAULT_POOL_LIMITS,
-    DEFAULT_SSL_CONFIG,
-    DEFAULT_TIMEOUT_CONFIG,
-    PoolLimits,
-    SSLConfig,
-    TimeoutConfig,
-)
-from .models import Response
-
-
-async def request(
-    method: str,
-    url: str,
-    *,
-    headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
-    body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
-    stream: bool = False,
-    ssl: SSLConfig = DEFAULT_SSL_CONFIG,
-    timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
-) -> Response:
-    async with PoolManager(ssl=ssl, timeout=timeout) as pool:
-       return await pool.request(
-           method=method, url=url, headers=headers, body=body, stream=stream
-       )
-
-
-class PoolManager:
-    def __init__(
-        self,
-        *,
-        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
-        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
-        limits: PoolLimits = DEFAULT_POOL_LIMITS,
-    ):
-        self.ssl = ssl
-        self.timeout = timeout
-        self.limits = limits
-        self.is_closed = False
-
-    async def request(
-        self,
-        method: str,
-        url: str,
-        *,
-        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
-        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
-        stream: bool = False,
-    ) -> Response:
-        raise NotImplementedError()
-
-    async def close(self) -> None:
-        self.is_closed = True
-
-    async def __aenter__(self) -> "PoolManager":
-        return self
-
-    async def __aexit__(
-        self,
-        exc_type: typing.Type[BaseException] = None,
-        exc_value: BaseException = None,
-        traceback: TracebackType = None,
-    ) -> None:
-        await self.close()
index aa0c0718737dc3d03d4dc0e2c5b73b432f188eca..d169e0afb7ba22368aacec235a18e78f955e1b41 100644 (file)
@@ -1,5 +1,7 @@
 import typing
 
+import certifi
+
 
 class SSLConfig:
     """
@@ -52,3 +54,4 @@ class PoolLimits:
 DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
 DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
 DEFAULT_POOL_LIMITS = PoolLimits(max_hosts=10, conns_per_host=10, hard_limit=False)
+DEFAULT_CA_BUNDLE_PATH = certifi.where()
index dc9c4ff0636d3e602ebd805bd712f22377c20773..db27b03c153da7cac17bfe2728841a3d5bee00fa 100644 (file)
-from config import TimeoutConfig
-
 import asyncio
-import h11
 import ssl
+import typing
+
+import h11
+
+from .config import TimeoutConfig
+from .datastructures import Request, Response
+from .exceptions import ConnectTimeout, ReadTimeout
+
+H11Event = typing.Union[
+    h11.Request,
+    h11.Response,
+    h11.InformationalResponse,
+    h11.Data,
+    h11.EndOfMessage,
+    h11.ConnectionClosed,
+]
 
 
 class Connection:
-    def __init__(self):
+    def __init__(self, timeout: TimeoutConfig):
         self.reader = None
         self.writer = None
         self.state = h11.Connection(our_role=h11.CLIENT)
+        self.timeout = timeout
 
-    async def open(self, host: str, port: int, ssl: ssl.SSLContext):
+    async def open(
+        self,
+        hostname: str,
+        port: int,
+        *,
+        ssl: typing.Union[bool, ssl.SSLContext] = False
+    ) -> None:
         try:
-            self.reader, self.writer = await asyncio.wait_for(
-                asyncio.open_connection(host, port, ssl=ssl), timeout
+            self.reader, self.writer = await asyncio.wait_for(  # type: ignore
+                asyncio.open_connection(hostname, port, ssl=ssl),
+                self.timeout.connect_timeout,
             )
         except asyncio.TimeoutError:
             raise ConnectTimeout()
 
-    async def send(self, request: Request) -> Response:
-        method = request.method
-
-        target = request.url.path
-        if request.url.query:
-            target += "?" + request.url.query
+    async def send(self, request: Request, stream: bool=False) -> Response:
+        method = request.method.encode()
+        target = request.url.target
+        host_header = (b"host", request.url.netloc.encode("ascii"))
+        if request.is_streaming:
+            content_length = (b"transfer-encoding", b"chunked")
+        else:
+            content_length = (b"content-length", str(len(request.body)).encode())
 
-        headers = [
-            ("host", request.url.netloc)
-        ] += request.headers
+        headers = [host_header, content_length] + request.headers
 
-        # Send the request method, path/query, and headers.
+        #  Start sending the request.
         event = h11.Request(method=method, target=target, headers=headers)
         await self._send_event(event)
 
         # Send the request body.
         if request.is_streaming:
-            async for data in request.raw():
+            async for data in request.stream():
                 event = h11.Data(data=data)
                 await self._send_event(event)
-        else:
+        elif request.body:
             event = h11.Data(data=request.body)
             await self._send_event(event)
 
         # Finalize sending the request.
         event = h11.EndOfMessage()
-        await connection.send_event(event)
+        await self._send_event(event)
+
+        # Start getting the response.
+        event = await self._receive_event()
+        if isinstance(event, h11.InformationalResponse):
+            event = await self._receive_event()
+        assert isinstance(event, h11.Response)
+        status_code = event.status_code
+        headers = event.headers
+
+        if stream:
+            return Response(status_code=status_code, headers=headers, body=self.body_iter())
+
+        #  Get the response body.
+        body = b""
+        event = await self._receive_event()
+        while isinstance(event, h11.Data):
+            body += event.data
+            event = await self._receive_event()
+        assert isinstance(event, h11.EndOfMessage)
+        await self.close()
 
-    async def _send_event(self, message):
-        data = self.state.send(message)
+        return Response(status_code=status_code, headers=headers, body=body)
+
+    async def body_iter(self) -> typing.Iterable[bytes]:
+        event = await self._receive_event()
+        while isinstance(event, h11.Data):
+            yield event.data
+            event = await self._receive_event()
+        assert isinstance(event, h11.EndOfMessage)
+        await self.close()
+
+    async def _send_event(self, event: H11Event) -> None:
+        assert self.writer is not None
+
+        data = self.state.send(event)
         self.writer.write(data)
 
-    async def _receive_event(self, timeout):
+    async def _receive_event(self) -> H11Event:
+        assert self.reader is not None
+
         event = self.state.next_event()
 
-        while type(event) is h11.NEED_DATA:
+        while event is h11.NEED_DATA:
             try:
-                data = await asyncio.wait_for(self.reader.read(2048), timeout)
+                data = await asyncio.wait_for(
+                    self.reader.read(2048), self.timeout.read_timeout
+                )
             except asyncio.TimeoutError:
                 raise ReadTimeout()
             self.state.receive_data(data)
@@ -64,7 +121,8 @@ class Connection:
 
         return event
 
-    async def close(self):
-        self.writer.close()
-        if hasattr(self.writer, "wait_closed"):
-            await self.writer.wait_closed()
+    async def close(self) -> None:
+        if self.writer is not None:
+            self.writer.close()
+            if hasattr(self.writer, "wait_closed"):
+                await self.writer.wait_closed()
diff --git a/httpcore/datastructures.py b/httpcore/datastructures.py
new file mode 100644 (file)
index 0000000..d60e18a
--- /dev/null
@@ -0,0 +1,145 @@
+import typing
+from urllib.parse import urlsplit
+
+from .decoders import IdentityDecoder
+from .exceptions import ResponseClosed, StreamConsumed
+
+
+class URL:
+    def __init__(self, url: str = "") -> None:
+        self.components = urlsplit(url)
+        if not self.components.scheme:
+            raise ValueError("No scheme included in URL.")
+        if self.components.scheme not in ("http", "https"):
+            raise ValueError('URL scheme must be "http" or "https".')
+        if not self.components.hostname:
+            raise ValueError("No hostname included in URL.")
+
+    @property
+    def scheme(self) -> str:
+        return self.components.scheme
+
+    @property
+    def netloc(self) -> str:
+        return self.components.netloc
+
+    @property
+    def path(self) -> str:
+        return self.components.path
+
+    @property
+    def query(self) -> str:
+        return self.components.query
+
+    @property
+    def hostname(self) -> str:
+        return self.components.hostname
+
+    @property
+    def port(self) -> int:
+        port = self.components.port
+        if port is None:
+            return {"https": 443, "http": 80}[self.scheme]
+        return port
+
+    @property
+    def target(self) -> str:
+        path = self.path or "/"
+        query = self.query
+        if query:
+            return path + "?" + query
+        return path
+
+    @property
+    def is_secure(self) -> bool:
+        return self.components.scheme == "https"
+
+
+class Request:
+    def __init__(
+        self,
+        method: str,
+        url: URL,
+        *,
+        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
+    ):
+        self.method = method
+        self.url = url
+        self.headers = list(headers)
+        if isinstance(body, bytes):
+            self.is_streaming = False
+            self.body = body
+        else:
+            self.is_streaming = True
+            self.body_aiter = body
+
+    async def stream(self) -> typing.AsyncIterator[bytes]:
+        assert self.is_streaming
+
+        async for part in self.body_aiter:
+            yield part
+
+
+class Response:
+    def __init__(
+        self,
+        status_code: int,
+        *,
+        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
+        on_close: typing.Callable = None,
+    ):
+        self.status_code = status_code
+        self.headers = list(headers)
+        self.on_close = on_close
+        self.is_closed = False
+        self.is_streamed = False
+        self.decoder = IdentityDecoder()
+        if isinstance(body, bytes):
+            self.is_closed = True
+            self.body = body
+        else:
+            self.body_aiter = body
+
+    async def read(self) -> bytes:
+        """
+        Read and return the response content.
+        """
+        if not hasattr(self, "body"):
+            body = b""
+            async for part in self.stream():
+                body += part
+            self.body = body
+        return self.body
+
+    async def stream(self) -> typing.AsyncIterator[bytes]:
+        """
+        A byte-iterator over the decoded response content.
+        This will allow us to handle gzip, deflate, and brotli encoded responses.
+        """
+        if hasattr(self, "body"):
+            yield self.body
+        else:
+            async for chunk in self.raw():
+                yield self.decoder.decode(chunk)
+            yield self.decoder.flush()
+
+    async def raw(self) -> typing.AsyncIterator[bytes]:
+        """
+        A byte-iterator over the raw response content.
+        """
+        if self.is_streamed:
+            raise StreamConsumed()
+        if self.is_closed:
+            raise ResponseClosed()
+        self.is_streamed = True
+        async for part in self.body_aiter:
+            yield part
+        await self.close()
+
+    async def close(self) -> None:
+        if not self.is_closed:
+            self.is_closed = True
+            if self.on_close is not None:
+                await self.on_close()
index 09b9336bbe690043d70ad61aae1b81ae74f08983..2d35a44f53d2e625ef4ec47afeafaccea277e271 100644 (file)
@@ -8,7 +8,7 @@ class IdentityDecoder:
         return chunk
 
     def flush(self) -> bytes:
-        return b''
+        return b""
 
 
 # class DeflateDecoder:
diff --git a/httpcore/models.py b/httpcore/models.py
deleted file mode 100644 (file)
index edf174b..0000000
+++ /dev/null
@@ -1,68 +0,0 @@
-import typing
-
-from .decoders import IdentityDecoder
-from .exceptions import ResponseClosed, StreamConsumed
-
-
-class Response:
-    def __init__(
-        self,
-        status_code: int,
-        *,
-        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
-        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
-        on_close: typing.Callable = None,
-    ):
-        self.status_code = status_code
-        self.headers = list(headers)
-        self.on_close = on_close
-        self.is_closed = False
-        self.is_streamed = False
-        self.decoder = IdentityDecoder()
-        if isinstance(body, bytes):
-            self.is_closed = True
-            self.body = body
-        else:
-            self.body_aiter = body
-
-    async def read(self) -> bytes:
-        """
-        Read and return the response content.
-        """
-        if not hasattr(self, "body"):
-            body = b""
-            async for part in self.stream():
-                body += part
-            self.body = body
-        return self.body
-
-    async def stream(self):
-        """
-        A byte-iterator over the decoded response content.
-        This will allow us to handle gzip, deflate, and brotli encoded responses.
-        """
-        if hasattr(self, "body"):
-            yield self.body
-        else:
-            async for chunk in self.raw():
-                yield self.decoder.decode(chunk)
-            yield self.decoder.flush()
-
-    async def raw(self) -> typing.AsyncIterator[bytes]:
-        """
-        A byte-iterator over the raw response content.
-        """
-        if self.is_streamed:
-            raise StreamConsumed()
-        if self.is_closed:
-            raise ResponseClosed()
-        self.is_streamed = True
-        async for part in self.body_aiter():
-            yield part
-        await self.close()
-
-    async def close(self) -> None:
-        if not self.is_closed:
-            self.is_closed = True
-            if self.on_close is not None:
-                await self.on_close()
diff --git a/httpcore/pool.py b/httpcore/pool.py
new file mode 100644 (file)
index 0000000..7594847
--- /dev/null
@@ -0,0 +1,126 @@
+import asyncio
+import os
+import ssl
+import typing
+from types import TracebackType
+
+from .config import (
+    DEFAULT_CA_BUNDLE_PATH,
+    DEFAULT_POOL_LIMITS,
+    DEFAULT_SSL_CONFIG,
+    DEFAULT_TIMEOUT_CONFIG,
+    PoolLimits,
+    SSLConfig,
+    TimeoutConfig,
+)
+from .connections import Connection
+from .datastructures import URL, Request, Response
+
+
+class ConnectionPool:
+    def __init__(
+        self,
+        *,
+        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
+        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+        limits: PoolLimits = DEFAULT_POOL_LIMITS,
+    ):
+        self.ssl_config = ssl
+        self.timeout = timeout
+        self.limits = limits
+        self.is_closed = False
+
+    async def request(
+        self,
+        method: str,
+        url: str,
+        *,
+        headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
+        body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
+        stream: bool = False,
+    ) -> Response:
+        parsed_url = URL(url)
+        request = Request(method, parsed_url, headers=headers, body=body)
+        ssl_context = await self.get_ssl_context(parsed_url)
+        connection = await self.acquire_connection(parsed_url, ssl=ssl_context)
+        response = await connection.send(request, stream=stream)
+        return response
+
+    async def acquire_connection(
+        self, url: URL, *, ssl: typing.Union[bool, ssl.SSLContext] = False
+    ) -> Connection:
+        connection = Connection(timeout=self.timeout)
+        await connection.open(url.hostname, url.port, ssl=ssl)
+        return connection
+
+    async def get_ssl_context(self, url: URL) -> typing.Union[bool, ssl.SSLContext]:
+        if not url.is_secure:
+            return False
+
+        if not hasattr(self, "ssl_context"):
+            if not self.ssl_config.verify:
+                self.ssl_context = self.get_ssl_context_no_verify()
+            else:
+                # Run the SSL loading in a threadpool, since it makes disk accesses.
+                loop = asyncio.get_event_loop()
+                self.ssl_context = await loop.run_in_executor(
+                    None, self.get_ssl_context_verify
+                )
+
+        return self.ssl_context
+
+    def get_ssl_context_no_verify(self) -> ssl.SSLContext:
+        """
+        Return an SSL context for unverified connections.
+        """
+        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+        context.options |= ssl.OP_NO_SSLv2
+        context.options |= ssl.OP_NO_SSLv3
+        context.options |= ssl.OP_NO_COMPRESSION
+        context.set_default_verify_paths()
+        return context
+
+    def get_ssl_context_verify(self) -> ssl.SSLContext:
+        """
+        Return an SSL context for verified connections.
+        """
+        cert = self.ssl_config.cert
+        verify = self.ssl_config.verify
+
+        if isinstance(verify, bool):
+            ca_bundle_path = DEFAULT_CA_BUNDLE_PATH
+        elif os.path.exists(verify):
+            ca_bundle_path = verify
+        else:
+            raise IOError(
+                "Could not find a suitable TLS CA certificate bundle, "
+                "invalid path: {}".format(verify)
+            )
+
+        context = ssl.create_default_context()
+        if os.path.isfile(ca_bundle_path):
+            context.load_verify_locations(cafile=ca_bundle_path)
+        elif os.path.isdir(ca_bundle_path):
+            context.load_verify_locations(capath=ca_bundle_path)
+
+        if cert is not None:
+            if isinstance(cert, str):
+                context.load_cert_chain(certfile=cert)
+            else:
+                context.load_cert_chain(certfile=cert[0], keyfile=cert[1])
+
+        return context
+
+    async def close(self) -> None:
+        self.is_closed = True
+
+    async def __aenter__(self) -> "ConnectionPool":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        await self.close()
index 16c770639363bf4d94994631259c7b0af0bf8da1..1baef34140f65d0d1b3378c8044f2b618d7165f4 100644 (file)
@@ -1,3 +1,4 @@
+certifi
 h11
 
 # Testing
@@ -9,3 +10,4 @@ mypy
 pytest
 pytest-asyncio
 pytest-cov
+uvicorn
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644 (file)
index 0000000..08e3670
--- /dev/null
@@ -0,0 +1,34 @@
+import asyncio
+import json
+
+import pytest
+from uvicorn.config import Config
+from uvicorn.main import Server
+
+
+async def app(scope, receive, send):
+    assert scope['type'] == 'http'
+    await send({
+        'type': 'http.response.start',
+        'status': 200,
+        'headers': [
+            [b'content-type', b'text/plain'],
+        ]
+    })
+    await send({
+        'type': 'http.response.body',
+        'body': b'Hello, world!',
+    })
+
+
+@pytest.fixture
+async def server():
+    config = Config(app=app, lifespan="off")
+    server = Server(config=config)
+    task = asyncio.ensure_future(server.serve())
+    try:
+        while not server.started:
+            await asyncio.sleep(0.0001)
+        yield server
+    finally:
+        task.cancel()
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..c24e7373316e7b24d7a6d56e9828c38690035de5 100644 (file)
@@ -0,0 +1,38 @@
+import pytest
+import httpcore
+
+
+@pytest.mark.asyncio
+async def test_get(server):
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request('GET', "http://127.0.0.1:8000/")
+    assert response.status_code == 200
+    assert response.body == b'Hello, world!'
+
+
+@pytest.mark.asyncio
+async def test_post(server):
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request('POST', "http://127.0.0.1:8000/", body=b"Hello, world!")
+    assert response.status_code == 200
+
+
+@pytest.mark.asyncio
+async def test_stream_response(server):
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request('GET', "http://127.0.0.1:8000/", stream=True)
+    assert response.status_code == 200
+    assert not hasattr(response, 'body')
+    body = await response.read()
+    assert body == b'Hello, world!'
+
+
+@pytest.mark.asyncio
+async def test_stream_request(server):
+    async def hello_world():
+        yield b"Hello, "
+        yield b"world!"
+
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request('POST', "http://127.0.0.1:8000/", body=hello_world())
+    assert response.status_code == 200
index 19a387bab91e617fe61b8d1cc6e5b2d980b2e791..ae754b404a1126690ce74720b2053ca87a93867c 100644 (file)
@@ -3,17 +3,21 @@ import pytest
 import httpcore
 
 
-class MockRequests(httpcore.PoolManager):
-    async def request(self, method, url, *, headers = (), body = b'', stream = False) -> httpcore.Response:
+class MockHTTP(httpcore.ConnectionPool):
+    async def request(
+        self, method, url, *, headers=(), body=b"", stream=False
+    ) -> httpcore.Response:
         if stream:
+
             async def streaming_body():
                 yield b"Hello, "
                 yield b"world!"
-            return httpcore.Response(200, body=streaming_body)
+
+            return httpcore.Response(200, body=streaming_body())
         return httpcore.Response(200, body=b"Hello, world!")
 
 
-http = MockRequests()
+http = MockHTTP()
 
 
 @pytest.mark.asyncio
@@ -47,7 +51,7 @@ async def test_stream_response():
     assert response.body == b"Hello, world!"
     assert response.is_closed
 
-    body = b''
+    body = b""
     async for part in response.stream():
         body += part
 
@@ -61,7 +65,7 @@ async def test_read_streaming_response():
     response = await http.request("GET", "http://example.com", stream=True)
 
     assert response.status_code == 200
-    assert not hasattr(response, 'body')
+    assert not hasattr(response, "body")
     assert not response.is_closed
 
     body = await response.read()
@@ -76,15 +80,15 @@ async def test_stream_streaming_response():
     response = await http.request("GET", "http://example.com", stream=True)
 
     assert response.status_code == 200
-    assert not hasattr(response, 'body')
+    assert not hasattr(response, "body")
     assert not response.is_closed
 
-    body = b''
+    body = b""
     async for part in response.stream():
         body += part
 
     assert body == b"Hello, world!"
-    assert not hasattr(response, 'body')
+    assert not hasattr(response, "body")
     assert response.is_closed
 
 
@@ -92,13 +96,14 @@ async def test_stream_streaming_response():
 async def test_cannot_read_after_stream_consumed():
     response = await http.request("GET", "http://example.com", stream=True)
 
-    body = b''
+    body = b""
     async for part in response.stream():
         body += part
 
     with pytest.raises(httpcore.StreamConsumed):
         await response.read()
 
+
 @pytest.mark.asyncio
 async def test_cannot_read_after_response_closed():
     response = await http.request("GET", "http://example.com", stream=True)