]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Stream refactoring and HTTP/2 test case
authorTom Christie <tom@tomchristie.com>
Thu, 25 Apr 2019 11:05:23 +0000 (12:05 +0100)
committerTom Christie <tom@tomchristie.com>
Thu, 25 Apr 2019 11:05:23 +0000 (12:05 +0100)
httpcore/__init__.py
httpcore/config.py
httpcore/connection.py
httpcore/exceptions.py
httpcore/http11.py
httpcore/http2.py
httpcore/streams.py [new file with mode: 0644]
tests/test_config.py
tests/test_http2.py [new file with mode: 0644]

index 738d7ecd76a002fa15de1319694d0fe99873fdd6..00cb5fb80b54d53c25b47e8d44b79ac137df94fa 100644 (file)
@@ -13,6 +13,7 @@ from .exceptions import (
 from .http2 import HTTP2Connection
 from .http11 import HTTP11Connection
 from .models import URL, Origin, Request, Response
+from .streams import BaseReader, BaseWriter, Protocol, Reader, Writer, connect
 from .sync import SyncClient, SyncConnectionPool
 
 __version__ = "0.2.1"
index 5b7ab4e07bb8cc0bd23a9d2fa62a3efb1abe74d9..b0fadc4017561cf31660ccf5714d1373f58e78d3 100644 (file)
@@ -115,20 +115,24 @@ class TimeoutConfig:
         *,
         connect_timeout: float = None,
         read_timeout: float = None,
+        write_timeout: float = None,
         pool_timeout: float = None,
     ):
         if timeout is not None:
             # Specified as a single timeout value
             assert connect_timeout is None
             assert read_timeout is None
+            assert write_timeout is None
             assert pool_timeout is None
             connect_timeout = timeout
             read_timeout = timeout
+            write_timeout = timeout
             pool_timeout = timeout
 
         self.timeout = timeout
         self.connect_timeout = connect_timeout
         self.read_timeout = read_timeout
+        self.write_timeout = write_timeout
         self.pool_timeout = pool_timeout
 
     def __eq__(self, other: typing.Any) -> bool:
@@ -136,18 +140,24 @@ class TimeoutConfig:
             isinstance(other, self.__class__)
             and self.connect_timeout == other.connect_timeout
             and self.read_timeout == other.read_timeout
+            and self.write_timeout == other.write_timeout
             and self.pool_timeout == other.pool_timeout
         )
 
     def __hash__(self) -> int:
-        as_tuple = (self.connect_timeout, self.read_timeout, self.pool_timeout)
+        as_tuple = (
+            self.connect_timeout,
+            self.read_timeout,
+            self.write_timeout,
+            self.pool_timeout,
+        )
         return hash(as_tuple)
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
         if self.timeout is not None:
             return f"{class_name}(timeout={self.timeout})"
-        return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, pool_timeout={self.pool_timeout})"
+        return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout}, pool_timeout={self.pool_timeout})"
 
 
 class PoolLimits:
index db3a17e6c3ba5bd0d9e61dfe8ccade7933971ba0..f164232ffe6cb893aadec1c02a66ae5ae1f10281 100644 (file)
@@ -1,4 +1,3 @@
-import asyncio
 import typing
 
 import h2.connection
@@ -9,6 +8,7 @@ from .exceptions import ConnectTimeout
 from .http2 import HTTP2Connection
 from .http11 import HTTP11Connection
 from .models import Client, Origin, Request, Response
+from .streams import Protocol, connect
 
 
 class HTTPConnection(Client):
@@ -39,8 +39,14 @@ class HTTPConnection(Client):
             if timeout is None:
                 timeout = self.timeout
 
-            reader, writer, protocol = await self.connect(ssl, timeout)
-            if protocol == "h2":
+            hostname = self.origin.hostname
+            port = self.origin.port
+            ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
+
+            reader, writer, protocol = await connect(
+                hostname, port, ssl_context, timeout
+            )
+            if protocol == Protocol.HTTP_2:
                 self.h2_connection = HTTP2Connection(
                     reader,
                     writer,
@@ -68,8 +74,7 @@ class HTTPConnection(Client):
     async def close(self) -> None:
         if self.h2_connection is not None:
             await self.h2_connection.close()
-        else:
-            assert self.h11_connection is not None
+        elif self.h11_connection is not None:
             await self.h11_connection.close()
 
     @property
@@ -79,28 +84,3 @@ class HTTPConnection(Client):
         else:
             assert self.h11_connection is not None
             return self.h11_connection.is_closed
-
-    async def connect(
-        self, ssl: SSLConfig, timeout: TimeoutConfig
-    ) -> typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter, str]:
-        hostname = self.origin.hostname
-        port = self.origin.port
-        ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
-
-        try:
-            reader, writer = await asyncio.wait_for(  # type: ignore
-                asyncio.open_connection(hostname, port, ssl=ssl_context),
-                timeout.connect_timeout,
-            )
-        except asyncio.TimeoutError:
-            raise ConnectTimeout()
-
-        ssl_object = writer.get_extra_info("ssl_object")
-        if ssl_object is None:
-            protocol = "http/1.1"
-        else:
-            protocol = ssl_object.selected_alpn_protocol()
-        if protocol is None:
-            protocol = ssl_object.selected_npn_protocol()
-
-        return (reader, writer, protocol)
index 30814332c3732b6e2bddd55aaf25f28f71b5501d..285b64039b743b4be96d28389aa7363ab2c082ed 100644 (file)
@@ -16,6 +16,12 @@ class ReadTimeout(Timeout):
     """
 
 
+class WriteTimeout(Timeout):
+    """
+    Timeout while writing request data.
+    """
+
+
 class PoolTimeout(Timeout):
     """
     Timeout while waiting to acquire a connection from the pool.
index 45994164ffdda3a2fa13dc3a1f8d74533db5a4d0..253865fe92b8a488c670c6f395add7556eba82ad 100644 (file)
@@ -1,4 +1,3 @@
-import asyncio
 import typing
 
 import h11
@@ -6,6 +5,7 @@ import h11
 from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
 from .exceptions import ConnectTimeout, ReadTimeout
 from .models import Client, Origin, Request, Response
+from .streams import BaseReader, BaseWriter
 
 H11Event = typing.Union[
     h11.Request,
@@ -17,11 +17,16 @@ H11Event = typing.Union[
 ]
 
 
+OptionalTimeout = typing.Optional[TimeoutConfig]
+
+
 class HTTP11Connection(Client):
+    READ_NUM_BYTES = 4096
+
     def __init__(
         self,
-        reader: asyncio.StreamReader,
-        writer: asyncio.StreamWriter,
+        reader: BaseReader,
+        writer: BaseWriter,
         origin: Origin,
         timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
         on_release: typing.Callable = None,
@@ -44,24 +49,21 @@ class HTTP11Connection(Client):
         ssl: typing.Optional[SSLConfig] = None,
         timeout: typing.Optional[TimeoutConfig] = None
     ) -> Response:
-        if timeout is None:
-            timeout = self.timeout
-
         #  Start sending the request.
         method = request.method.encode()
         target = request.url.full_path
         headers = request.headers
         event = h11.Request(method=method, target=target, headers=headers)
-        await self._send_event(event)
+        await self._send_event(event, timeout)
 
         # Send the request body.
         async for data in request.stream():
             event = h11.Data(data=data)
-            await self._send_event(event)
+            await self._send_event(event, timeout)
 
         # Finalize sending the request.
         event = h11.EndOfMessage()
-        await self._send_event(event)
+        await self._send_event(event, timeout)
 
         # Start getting the response.
         event = await self._receive_event(timeout)
@@ -83,27 +85,22 @@ class HTTP11Connection(Client):
             on_close=self._release,
         )
 
-    async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
+    async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
         event = await self._receive_event(timeout)
         while isinstance(event, h11.Data):
             yield event.data
             event = await self._receive_event(timeout)
         assert isinstance(event, h11.EndOfMessage)
 
-    async def _send_event(self, event: H11Event) -> None:
+    async def _send_event(self, event: H11Event, timeout: OptionalTimeout) -> None:
         data = self.h11_state.send(event)
-        self.writer.write(data)
+        await self.writer.write(data, timeout)
 
-    async def _receive_event(self, timeout: TimeoutConfig) -> H11Event:
+    async def _receive_event(self, timeout: OptionalTimeout) -> H11Event:
         event = self.h11_state.next_event()
 
         while event is h11.NEED_DATA:
-            try:
-                data = await asyncio.wait_for(
-                    self.reader.read(2048), timeout.read_timeout
-                )
-            except asyncio.TimeoutError:
-                raise ReadTimeout()
+            data = await self.reader.read(self.READ_NUM_BYTES, timeout)
             self.h11_state.receive_data(data)
             event = self.h11_state.next_event()
 
@@ -131,5 +128,4 @@ class HTTP11Connection(Client):
             # and we'll end up in h11.ERROR.
             pass
 
-        if self.writer is not None:
-            self.writer.close()
+        await self.writer.close()
index 08904388da19f37012fe4f54f040fb5864c0ab52..41a0900cad70f690b42401c777e4d87cb6467a42 100644 (file)
@@ -1,4 +1,3 @@
-import asyncio
 import typing
 
 import h2.connection
@@ -7,13 +6,18 @@ import h2.events
 from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
 from .exceptions import ConnectTimeout, ReadTimeout
 from .models import Client, Origin, Request, Response
+from .streams import BaseReader, BaseWriter
+
+OptionalTimeout = typing.Optional[TimeoutConfig]
 
 
 class HTTP2Connection(Client):
+    READ_NUM_BYTES = 4096
+
     def __init__(
         self,
-        reader: asyncio.StreamReader,
-        writer: asyncio.StreamWriter,
+        reader: BaseReader,
+        writer: BaseWriter,
         origin: Origin,
         timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
         on_release: typing.Callable = None,
@@ -45,15 +49,15 @@ class HTTP2Connection(Client):
             self.initiate_connection()
 
         #  Start sending the request.
-        stream_id = await self.send_headers(stream_id, request)
+        stream_id = await self.send_headers(request, timeout)
         self.events[stream_id] = []
 
         # Send the request body.
         async for data in request.stream():
-            await self.send_data(stream_id, data)
+            await self.send_data(stream_id, data, timeout)
 
         # Finalize sending the request.
-        await self.end_stream(stream_id)
+        await self.end_stream(stream_id, timeout)
 
         # Start getting the response.
         while True:
@@ -81,10 +85,10 @@ class HTTP2Connection(Client):
     def initiate_connection(self) -> None:
         self.h2_state.initiate_connection()
         data_to_send = self.h2_state.data_to_send()
-        self.writer.write(data_to_send)
+        self.writer.write_no_block(data_to_send)
         self.initialized = True
 
-    async def send_headers(self, stream_id: int, request: Request) -> int:
+    async def send_headers(self, request: Request, timeout: OptionalTimeout) -> int:
         stream_id = self.h2_state.get_next_available_stream_id()
         headers = [
             (b":method", request.method.encode()),
@@ -94,21 +98,23 @@ class HTTP2Connection(Client):
         ] + request.headers
         self.h2_state.send_headers(stream_id, headers)
         data_to_send = self.h2_state.data_to_send()
-        self.writer.write(data_to_send)
+        await self.writer.write(data_to_send, timeout)
         return stream_id
 
-    async def send_data(self, stream_id: int, data: bytes) -> None:
+    async def send_data(
+        self, stream_id: int, data: bytes, timeout: OptionalTimeout
+    ) -> None:
         self.h2_state.send_data(stream_id, data)
         data_to_send = self.h2_state.data_to_send()
-        self.writer.write(data_to_send)
+        await self.writer.write(data_to_send, timeout)
 
-    async def end_stream(self, stream_id: int) -> None:
+    async def end_stream(self, stream_id: int, timeout: OptionalTimeout) -> None:
         self.h2_state.end_stream(stream_id)
         data_to_send = self.h2_state.data_to_send()
-        self.writer.write(data_to_send)
+        await self.writer.write(data_to_send, timeout)
 
     async def body_iter(
-        self, stream_id: int, timeout: TimeoutConfig
+        self, stream_id: int, timeout: OptionalTimeout
     ) -> typing.AsyncIterator[bytes]:
         while True:
             event = await self.receive_event(stream_id, timeout)
@@ -119,24 +125,17 @@ class HTTP2Connection(Client):
                 break
 
     async def receive_event(
-        self, stream_id: int, timeout: TimeoutConfig
+        self, stream_id: int, timeout: OptionalTimeout
     ) -> h2.events.Event:
         while not self.events[stream_id]:
-            try:
-                data = await asyncio.wait_for(
-                    self.reader.read(2048), timeout.read_timeout
-                )
-            except asyncio.TimeoutError:
-                raise ReadTimeout()
-
+            data = await self.reader.read(self.READ_NUM_BYTES, timeout)
             events = self.h2_state.receive_data(data)
             for event in events:
                 if getattr(event, "stream_id", 0):
                     self.events[event.stream_id].append(event)
 
             data_to_send = self.h2_state.data_to_send()
-            if data_to_send:
-                self.writer.write(data_to_send)
+            await self.writer.write(data_to_send, timeout)
 
         return self.events[stream_id].pop(0)
 
diff --git a/httpcore/streams.py b/httpcore/streams.py
new file mode 100644 (file)
index 0000000..5a9a0ab
--- /dev/null
@@ -0,0 +1,115 @@
+"""
+The `Reader` and `Writer` classes here provide a lightweight layer over
+`asyncio.StreamReader` and `asyncio.StreamWriter`.
+
+They help encapsulate the timeout logic, make it easier to unit-test
+protocols, and help keep the rest of the package more `async`/`await`
+based, and less strictly `asyncio`-specific.
+"""
+import asyncio
+import enum
+import ssl
+import typing
+
+from .config import TimeoutConfig, DEFAULT_TIMEOUT_CONFIG
+from .exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+
+OptionalTimeout = typing.Optional[TimeoutConfig]
+
+
+class Protocol(enum.Enum):
+    HTTP_11 = 1
+    HTTP_2 = 2
+
+
+class BaseReader:
+    async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
+        raise NotImplementedError()  # pragma: no cover
+
+
+class BaseWriter:
+    def write_no_block(self, data: bytes) -> None:
+        raise NotImplementedError()  # pragma: no cover
+
+    async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
+        raise NotImplementedError()  # pragma: no cover
+
+    async def close(self) -> None:
+        raise NotImplementedError()  # pragma: no cover
+
+
+class Reader(BaseReader):
+    def __init__(
+        self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
+    ) -> None:
+        self.stream_reader = stream_reader
+        self.timeout = timeout
+
+    async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes:
+        if timeout is None:
+            timeout = self.timeout
+
+        try:
+            data = await asyncio.wait_for(
+                self.stream_reader.read(n), timeout.read_timeout
+            )
+        except asyncio.TimeoutError:
+            raise ReadTimeout()
+
+        return data
+
+
+class Writer(BaseWriter):
+    def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
+        self.stream_writer = stream_writer
+        self.timeout = timeout
+
+    def write_no_block(self, data: bytes) -> None:
+        self.stream_writer.write(data)
+
+    async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None:
+        if not data:
+            return
+
+        if timeout is None:
+            timeout = self.timeout
+
+        self.stream_writer.write(data)
+        try:
+            data = await asyncio.wait_for(  # type: ignore
+                self.stream_writer.drain(), timeout.write_timeout
+            )
+        except asyncio.TimeoutError:
+            raise WriteTimeout()
+
+    async def close(self) -> None:
+        self.stream_writer.close()
+
+
+async def connect(
+    hostname: str,
+    port: int,
+    ssl_context: typing.Optional[ssl.SSLContext] = None,
+    timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+) -> typing.Tuple[Reader, Writer, Protocol]:
+    try:
+        stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
+            asyncio.open_connection(hostname, port, ssl=ssl_context),
+            timeout.connect_timeout,
+        )
+    except asyncio.TimeoutError:
+        raise ConnectTimeout()
+
+    ssl_object = stream_writer.get_extra_info("ssl_object")
+    if ssl_object is None:
+        ident = "http/1.1"
+    else:
+        ident = ssl_object.selected_alpn_protocol()
+        if ident is None:
+            ident = ssl_object.selected_npn_protocol()
+
+    reader = Reader(stream_reader=stream_reader, timeout=timeout)
+    writer = Writer(stream_writer=stream_writer, timeout=timeout)
+    protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
+
+    return (reader, writer, protocol)
index daf0e1ec973e55863976eba7e474793834a1f80e..8112d7c24741dca101535d7e8c9dd5771e18731b 100644 (file)
@@ -13,7 +13,7 @@ def test_timeout_repr():
     timeout = httpcore.TimeoutConfig(read_timeout=5.0)
     assert (
         repr(timeout)
-        == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, pool_timeout=None)"
+        == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None, pool_timeout=None)"
     )
 
 
diff --git a/tests/test_http2.py b/tests/test_http2.py
new file mode 100644 (file)
index 0000000..7bfe301
--- /dev/null
@@ -0,0 +1,76 @@
+import h2.config
+import h2.connection
+import h2.events
+import pytest
+
+import httpcore
+
+
+class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
+    """
+    This class exposes Reader and Writer style interfaces
+    """
+
+    def __init__(self):
+        config = h2.config.H2Configuration(client_side=False)
+        self.conn = h2.connection.H2Connection(config=config)
+        self.buffer = b""
+        self.requests = {}
+
+    # BaseReader interface
+
+    async def read(self, n, timeout) -> bytes:
+        send, self.buffer = self.buffer[:n], self.buffer[n:]
+        return send
+
+    # BaseWriter interface
+
+    def write_no_block(self, data: bytes) -> None:
+        events = self.conn.receive_data(data)
+        self.buffer += self.conn.data_to_send()
+        for event in events:
+            if isinstance(event, h2.events.RequestReceived):
+                self.request_received(event.headers, event.stream_id)
+            elif isinstance(event, h2.events.DataReceived):
+                self.receive_data(event.data, event.stream_id)
+            elif isinstance(event, h2.events.StreamEnded):
+                self.stream_complete(event.stream_id)
+
+    async def write(self, data: bytes, timeout) -> None:
+        self.write_no_block(data)
+
+    async def close(self) -> None:
+        pass
+
+    # Server implementation
+
+    def request_received(self, headers, stream_id):
+        if stream_id not in self.requests:
+            self.requests[stream_id] = []
+        self.requests[stream_id].append({"headers": headers, "data": b""})
+
+    def receive_data(self, data, stream_id):
+        self.requests[stream_id][-1]["data"] += data
+
+    def stream_complete(self, stream_id):
+        requests = self.requests[stream_id].pop(0)
+        if not self.requests[stream_id]:
+            del self.requests[stream_id]
+
+        response_headers = (
+            (b":status", b"200"),
+        )
+        response_body = b"Hello, world!"
+        self.conn.send_headers(stream_id, response_headers)
+        self.conn.send_data(stream_id, response_body, end_stream=True)
+        self.buffer += self.conn.data_to_send()
+
+
+@pytest.mark.asyncio
+async def test_http2():
+    server = MockServer()
+    origin = httpcore.Origin("http://example.org")
+    client = httpcore.HTTP2Connection(reader=server, writer=server, origin=origin)
+    response = await client.request("GET", "http://example.org")
+    assert response.status_code == 200
+    assert response.body == b"Hello, world!"