From: Tom Christie Date: Thu, 25 Apr 2019 11:05:23 +0000 (+0100) Subject: Stream refactoring and HTTP/2 test case X-Git-Tag: 0.3.0~66^2~18 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a44190ff24558c2ee5efc9dd422f836af376150b;p=thirdparty%2Fhttpx.git Stream refactoring and HTTP/2 test case --- diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 738d7ecd..00cb5fb8 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -13,6 +13,7 @@ from .exceptions import ( from .http2 import HTTP2Connection from .http11 import HTTP11Connection from .models import URL, Origin, Request, Response +from .streams import BaseReader, BaseWriter, Protocol, Reader, Writer, connect from .sync import SyncClient, SyncConnectionPool __version__ = "0.2.1" diff --git a/httpcore/config.py b/httpcore/config.py index 5b7ab4e0..b0fadc40 100644 --- a/httpcore/config.py +++ b/httpcore/config.py @@ -115,20 +115,24 @@ class TimeoutConfig: *, connect_timeout: float = None, read_timeout: float = None, + write_timeout: float = None, pool_timeout: float = None, ): if timeout is not None: # Specified as a single timeout value assert connect_timeout is None assert read_timeout is None + assert write_timeout is None assert pool_timeout is None connect_timeout = timeout read_timeout = timeout + write_timeout = timeout pool_timeout = timeout self.timeout = timeout self.connect_timeout = connect_timeout self.read_timeout = read_timeout + self.write_timeout = write_timeout self.pool_timeout = pool_timeout def __eq__(self, other: typing.Any) -> bool: @@ -136,18 +140,24 @@ class TimeoutConfig: isinstance(other, self.__class__) and self.connect_timeout == other.connect_timeout and self.read_timeout == other.read_timeout + and self.write_timeout == other.write_timeout and self.pool_timeout == other.pool_timeout ) def __hash__(self) -> int: - as_tuple = (self.connect_timeout, self.read_timeout, self.pool_timeout) + as_tuple = ( + self.connect_timeout, + self.read_timeout, + self.write_timeout, + self.pool_timeout, + ) return hash(as_tuple) def __repr__(self) -> str: class_name = self.__class__.__name__ if self.timeout is not None: return f"{class_name}(timeout={self.timeout})" - return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, pool_timeout={self.pool_timeout})" + return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, write_timeout={self.write_timeout}, pool_timeout={self.pool_timeout})" class PoolLimits: diff --git a/httpcore/connection.py b/httpcore/connection.py index db3a17e6..f164232f 100644 --- a/httpcore/connection.py +++ b/httpcore/connection.py @@ -1,4 +1,3 @@ -import asyncio import typing import h2.connection @@ -9,6 +8,7 @@ from .exceptions import ConnectTimeout from .http2 import HTTP2Connection from .http11 import HTTP11Connection from .models import Client, Origin, Request, Response +from .streams import Protocol, connect class HTTPConnection(Client): @@ -39,8 +39,14 @@ class HTTPConnection(Client): if timeout is None: timeout = self.timeout - reader, writer, protocol = await self.connect(ssl, timeout) - if protocol == "h2": + hostname = self.origin.hostname + port = self.origin.port + ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None + + reader, writer, protocol = await connect( + hostname, port, ssl_context, timeout + ) + if protocol == Protocol.HTTP_2: self.h2_connection = HTTP2Connection( reader, writer, @@ -68,8 +74,7 @@ class HTTPConnection(Client): async def close(self) -> None: if self.h2_connection is not None: await self.h2_connection.close() - else: - assert self.h11_connection is not None + elif self.h11_connection is not None: await self.h11_connection.close() @property @@ -79,28 +84,3 @@ class HTTPConnection(Client): else: assert self.h11_connection is not None return self.h11_connection.is_closed - - async def connect( - self, ssl: SSLConfig, timeout: TimeoutConfig - ) -> typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter, str]: - hostname = self.origin.hostname - port = self.origin.port - ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None - - try: - reader, writer = await asyncio.wait_for( # type: ignore - asyncio.open_connection(hostname, port, ssl=ssl_context), - timeout.connect_timeout, - ) - except asyncio.TimeoutError: - raise ConnectTimeout() - - ssl_object = writer.get_extra_info("ssl_object") - if ssl_object is None: - protocol = "http/1.1" - else: - protocol = ssl_object.selected_alpn_protocol() - if protocol is None: - protocol = ssl_object.selected_npn_protocol() - - return (reader, writer, protocol) diff --git a/httpcore/exceptions.py b/httpcore/exceptions.py index 30814332..285b6403 100644 --- a/httpcore/exceptions.py +++ b/httpcore/exceptions.py @@ -16,6 +16,12 @@ class ReadTimeout(Timeout): """ +class WriteTimeout(Timeout): + """ + Timeout while writing request data. + """ + + class PoolTimeout(Timeout): """ Timeout while waiting to acquire a connection from the pool. diff --git a/httpcore/http11.py b/httpcore/http11.py index 45994164..253865fe 100644 --- a/httpcore/http11.py +++ b/httpcore/http11.py @@ -1,4 +1,3 @@ -import asyncio import typing import h11 @@ -6,6 +5,7 @@ import h11 from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig from .exceptions import ConnectTimeout, ReadTimeout from .models import Client, Origin, Request, Response +from .streams import BaseReader, BaseWriter H11Event = typing.Union[ h11.Request, @@ -17,11 +17,16 @@ H11Event = typing.Union[ ] +OptionalTimeout = typing.Optional[TimeoutConfig] + + class HTTP11Connection(Client): + READ_NUM_BYTES = 4096 + def __init__( self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, + reader: BaseReader, + writer: BaseWriter, origin: Origin, timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, on_release: typing.Callable = None, @@ -44,24 +49,21 @@ class HTTP11Connection(Client): ssl: typing.Optional[SSLConfig] = None, timeout: typing.Optional[TimeoutConfig] = None ) -> Response: - if timeout is None: - timeout = self.timeout - #  Start sending the request. method = request.method.encode() target = request.url.full_path headers = request.headers event = h11.Request(method=method, target=target, headers=headers) - await self._send_event(event) + await self._send_event(event, timeout) # Send the request body. async for data in request.stream(): event = h11.Data(data=data) - await self._send_event(event) + await self._send_event(event, timeout) # Finalize sending the request. event = h11.EndOfMessage() - await self._send_event(event) + await self._send_event(event, timeout) # Start getting the response. event = await self._receive_event(timeout) @@ -83,27 +85,22 @@ class HTTP11Connection(Client): on_close=self._release, ) - async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]: + async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]: event = await self._receive_event(timeout) while isinstance(event, h11.Data): yield event.data event = await self._receive_event(timeout) assert isinstance(event, h11.EndOfMessage) - async def _send_event(self, event: H11Event) -> None: + async def _send_event(self, event: H11Event, timeout: OptionalTimeout) -> None: data = self.h11_state.send(event) - self.writer.write(data) + await self.writer.write(data, timeout) - async def _receive_event(self, timeout: TimeoutConfig) -> H11Event: + async def _receive_event(self, timeout: OptionalTimeout) -> H11Event: event = self.h11_state.next_event() while event is h11.NEED_DATA: - try: - data = await asyncio.wait_for( - self.reader.read(2048), timeout.read_timeout - ) - except asyncio.TimeoutError: - raise ReadTimeout() + data = await self.reader.read(self.READ_NUM_BYTES, timeout) self.h11_state.receive_data(data) event = self.h11_state.next_event() @@ -131,5 +128,4 @@ class HTTP11Connection(Client): # and we'll end up in h11.ERROR. pass - if self.writer is not None: - self.writer.close() + await self.writer.close() diff --git a/httpcore/http2.py b/httpcore/http2.py index 08904388..41a0900c 100644 --- a/httpcore/http2.py +++ b/httpcore/http2.py @@ -1,4 +1,3 @@ -import asyncio import typing import h2.connection @@ -7,13 +6,18 @@ import h2.events from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig from .exceptions import ConnectTimeout, ReadTimeout from .models import Client, Origin, Request, Response +from .streams import BaseReader, BaseWriter + +OptionalTimeout = typing.Optional[TimeoutConfig] class HTTP2Connection(Client): + READ_NUM_BYTES = 4096 + def __init__( self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, + reader: BaseReader, + writer: BaseWriter, origin: Origin, timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, on_release: typing.Callable = None, @@ -45,15 +49,15 @@ class HTTP2Connection(Client): self.initiate_connection() #  Start sending the request. - stream_id = await self.send_headers(stream_id, request) + stream_id = await self.send_headers(request, timeout) self.events[stream_id] = [] # Send the request body. async for data in request.stream(): - await self.send_data(stream_id, data) + await self.send_data(stream_id, data, timeout) # Finalize sending the request. - await self.end_stream(stream_id) + await self.end_stream(stream_id, timeout) # Start getting the response. while True: @@ -81,10 +85,10 @@ class HTTP2Connection(Client): def initiate_connection(self) -> None: self.h2_state.initiate_connection() data_to_send = self.h2_state.data_to_send() - self.writer.write(data_to_send) + self.writer.write_no_block(data_to_send) self.initialized = True - async def send_headers(self, stream_id: int, request: Request) -> int: + async def send_headers(self, request: Request, timeout: OptionalTimeout) -> int: stream_id = self.h2_state.get_next_available_stream_id() headers = [ (b":method", request.method.encode()), @@ -94,21 +98,23 @@ class HTTP2Connection(Client): ] + request.headers self.h2_state.send_headers(stream_id, headers) data_to_send = self.h2_state.data_to_send() - self.writer.write(data_to_send) + await self.writer.write(data_to_send, timeout) return stream_id - async def send_data(self, stream_id: int, data: bytes) -> None: + async def send_data( + self, stream_id: int, data: bytes, timeout: OptionalTimeout + ) -> None: self.h2_state.send_data(stream_id, data) data_to_send = self.h2_state.data_to_send() - self.writer.write(data_to_send) + await self.writer.write(data_to_send, timeout) - async def end_stream(self, stream_id: int) -> None: + async def end_stream(self, stream_id: int, timeout: OptionalTimeout) -> None: self.h2_state.end_stream(stream_id) data_to_send = self.h2_state.data_to_send() - self.writer.write(data_to_send) + await self.writer.write(data_to_send, timeout) async def body_iter( - self, stream_id: int, timeout: TimeoutConfig + self, stream_id: int, timeout: OptionalTimeout ) -> typing.AsyncIterator[bytes]: while True: event = await self.receive_event(stream_id, timeout) @@ -119,24 +125,17 @@ class HTTP2Connection(Client): break async def receive_event( - self, stream_id: int, timeout: TimeoutConfig + self, stream_id: int, timeout: OptionalTimeout ) -> h2.events.Event: while not self.events[stream_id]: - try: - data = await asyncio.wait_for( - self.reader.read(2048), timeout.read_timeout - ) - except asyncio.TimeoutError: - raise ReadTimeout() - + data = await self.reader.read(self.READ_NUM_BYTES, timeout) events = self.h2_state.receive_data(data) for event in events: if getattr(event, "stream_id", 0): self.events[event.stream_id].append(event) data_to_send = self.h2_state.data_to_send() - if data_to_send: - self.writer.write(data_to_send) + await self.writer.write(data_to_send, timeout) return self.events[stream_id].pop(0) diff --git a/httpcore/streams.py b/httpcore/streams.py new file mode 100644 index 00000000..5a9a0abb --- /dev/null +++ b/httpcore/streams.py @@ -0,0 +1,115 @@ +""" +The `Reader` and `Writer` classes here provide a lightweight layer over +`asyncio.StreamReader` and `asyncio.StreamWriter`. + +They help encapsulate the timeout logic, make it easier to unit-test +protocols, and help keep the rest of the package more `async`/`await` +based, and less strictly `asyncio`-specific. +""" +import asyncio +import enum +import ssl +import typing + +from .config import TimeoutConfig, DEFAULT_TIMEOUT_CONFIG +from .exceptions import ConnectTimeout, ReadTimeout, WriteTimeout + +OptionalTimeout = typing.Optional[TimeoutConfig] + + +class Protocol(enum.Enum): + HTTP_11 = 1 + HTTP_2 = 2 + + +class BaseReader: + async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes: + raise NotImplementedError() # pragma: no cover + + +class BaseWriter: + def write_no_block(self, data: bytes) -> None: + raise NotImplementedError() # pragma: no cover + + async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None: + raise NotImplementedError() # pragma: no cover + + async def close(self) -> None: + raise NotImplementedError() # pragma: no cover + + +class Reader(BaseReader): + def __init__( + self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig + ) -> None: + self.stream_reader = stream_reader + self.timeout = timeout + + async def read(self, n: int, timeout: OptionalTimeout = None) -> bytes: + if timeout is None: + timeout = self.timeout + + try: + data = await asyncio.wait_for( + self.stream_reader.read(n), timeout.read_timeout + ) + except asyncio.TimeoutError: + raise ReadTimeout() + + return data + + +class Writer(BaseWriter): + def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig): + self.stream_writer = stream_writer + self.timeout = timeout + + def write_no_block(self, data: bytes) -> None: + self.stream_writer.write(data) + + async def write(self, data: bytes, timeout: OptionalTimeout = None) -> None: + if not data: + return + + if timeout is None: + timeout = self.timeout + + self.stream_writer.write(data) + try: + data = await asyncio.wait_for( # type: ignore + self.stream_writer.drain(), timeout.write_timeout + ) + except asyncio.TimeoutError: + raise WriteTimeout() + + async def close(self) -> None: + self.stream_writer.close() + + +async def connect( + hostname: str, + port: int, + ssl_context: typing.Optional[ssl.SSLContext] = None, + timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, +) -> typing.Tuple[Reader, Writer, Protocol]: + try: + stream_reader, stream_writer = await asyncio.wait_for( # type: ignore + asyncio.open_connection(hostname, port, ssl=ssl_context), + timeout.connect_timeout, + ) + except asyncio.TimeoutError: + raise ConnectTimeout() + + ssl_object = stream_writer.get_extra_info("ssl_object") + if ssl_object is None: + ident = "http/1.1" + else: + ident = ssl_object.selected_alpn_protocol() + if ident is None: + ident = ssl_object.selected_npn_protocol() + + reader = Reader(stream_reader=stream_reader, timeout=timeout) + writer = Writer(stream_writer=stream_writer, timeout=timeout) + protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11 + + return (reader, writer, protocol) diff --git a/tests/test_config.py b/tests/test_config.py index daf0e1ec..8112d7c2 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,7 +13,7 @@ def test_timeout_repr(): timeout = httpcore.TimeoutConfig(read_timeout=5.0) assert ( repr(timeout) - == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, pool_timeout=None)" + == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, write_timeout=None, pool_timeout=None)" ) diff --git a/tests/test_http2.py b/tests/test_http2.py new file mode 100644 index 00000000..7bfe3013 --- /dev/null +++ b/tests/test_http2.py @@ -0,0 +1,76 @@ +import h2.config +import h2.connection +import h2.events +import pytest + +import httpcore + + +class MockServer(httpcore.BaseReader, httpcore.BaseWriter): + """ + This class exposes Reader and Writer style interfaces + """ + + def __init__(self): + config = h2.config.H2Configuration(client_side=False) + self.conn = h2.connection.H2Connection(config=config) + self.buffer = b"" + self.requests = {} + + # BaseReader interface + + async def read(self, n, timeout) -> bytes: + send, self.buffer = self.buffer[:n], self.buffer[n:] + return send + + # BaseWriter interface + + def write_no_block(self, data: bytes) -> None: + events = self.conn.receive_data(data) + self.buffer += self.conn.data_to_send() + for event in events: + if isinstance(event, h2.events.RequestReceived): + self.request_received(event.headers, event.stream_id) + elif isinstance(event, h2.events.DataReceived): + self.receive_data(event.data, event.stream_id) + elif isinstance(event, h2.events.StreamEnded): + self.stream_complete(event.stream_id) + + async def write(self, data: bytes, timeout) -> None: + self.write_no_block(data) + + async def close(self) -> None: + pass + + # Server implementation + + def request_received(self, headers, stream_id): + if stream_id not in self.requests: + self.requests[stream_id] = [] + self.requests[stream_id].append({"headers": headers, "data": b""}) + + def receive_data(self, data, stream_id): + self.requests[stream_id][-1]["data"] += data + + def stream_complete(self, stream_id): + requests = self.requests[stream_id].pop(0) + if not self.requests[stream_id]: + del self.requests[stream_id] + + response_headers = ( + (b":status", b"200"), + ) + response_body = b"Hello, world!" + self.conn.send_headers(stream_id, response_headers) + self.conn.send_data(stream_id, response_body, end_stream=True) + self.buffer += self.conn.data_to_send() + + +@pytest.mark.asyncio +async def test_http2(): + server = MockServer() + origin = httpcore.Origin("http://example.org") + client = httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) + response = await client.request("GET", "http://example.org") + assert response.status_code == 200 + assert response.body == b"Hello, world!"