From: Tom Christie Date: Wed, 24 Apr 2019 14:48:18 +0000 (+0100) Subject: First pass at HTTP/2 support X-Git-Tag: 0.3.0~66^2~22 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=53f3dc4a6628dabe2732fe7dfec334dacb235da9;p=thirdparty%2Fhttpx.git First pass at HTTP/2 support --- diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 9824e854..48e3426a 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -1,4 +1,5 @@ from .config import PoolLimits, SSLConfig, TimeoutConfig +from .connection import HTTPConnection from .connectionpool import ConnectionPool from .datastructures import URL, Origin, Request, Response from .exceptions import ( @@ -10,6 +11,7 @@ from .exceptions import ( StreamConsumed, Timeout, ) +from .http2 import HTTP2Connection from .http11 import HTTP11Connection from .sync import SyncClient, SyncConnectionPool diff --git a/httpcore/config.py b/httpcore/config.py index 8cc784b3..5b7ab4e0 100644 --- a/httpcore/config.py +++ b/httpcore/config.py @@ -73,7 +73,23 @@ class SSLConfig: "invalid path: {}".format(self.verify) ) - context = ssl.create_default_context() + context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH) + + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_COMPRESSION + + # RFC 7540 Section 9.2.2: "deployments of HTTP/2 that use TLS 1.2 MUST + # support TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256". In practice, the + # blacklist defined in this section allows only the AES GCM and ChaCha20 + # cipher suites with ephemeral key negotiation. + context.set_ciphers("ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20") + + if ssl.HAS_ALPN: + context.set_alpn_protocols(["h2", "http/1.1"]) + if ssl.HAS_NPN: + context.set_npn_protocols(["h2", "http/1.1"]) + if os.path.isfile(ca_bundle_path): context.load_verify_locations(cafile=ca_bundle_path) elif os.path.isdir(ca_bundle_path): diff --git a/httpcore/connection.py b/httpcore/connection.py new file mode 100644 index 00000000..33f34594 --- /dev/null +++ b/httpcore/connection.py @@ -0,0 +1,106 @@ +import asyncio +import typing + +import h2.connection +import h11 + +from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig +from .datastructures import Client, Origin, Request, Response +from .exceptions import ConnectTimeout +from .http2 import HTTP2Connection +from .http11 import HTTP11Connection + + +class HTTPConnection(Client): + def __init__( + self, + origin: typing.Union[str, Origin], + ssl: SSLConfig = DEFAULT_SSL_CONFIG, + timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + on_release: typing.Callable = None, + ): + self.origin = Origin(origin) if isinstance(origin, str) else origin + self.ssl = ssl + self.timeout = timeout + self.on_release = on_release + self.h11_connection = None # type: typing.Optional[HTTP11Connection] + self.h2_connection = None # type: typing.Optional[HTTP2Connection] + + async def send( + self, + request: Request, + *, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, + ) -> Response: + if self.h11_connection is None and self.h2_connection is None: + if ssl is None: + ssl = self.ssl + if timeout is None: + timeout = self.timeout + + reader, writer, protocol = await self.connect(ssl, timeout) + if protocol == "h2": + self.h2_connection = HTTP2Connection( + reader, + writer, + origin=self.origin, + timeout=self.timeout, + on_release=self.on_release, + ) + else: + self.h11_connection = HTTP11Connection( + reader, + writer, + origin=self.origin, + timeout=self.timeout, + on_release=self.on_release, + ) + + if self.h2_connection is not None: + response = await self.h2_connection.send(request, ssl=ssl, timeout=timeout) + else: + assert self.h11_connection is not None + response = await self.h11_connection.send(request, ssl=ssl, timeout=timeout) + + return response + + async def close(self) -> None: + if self.h2_connection is not None: + await self.h2_connection.close() + else: + assert self.h11_connection is not None + await self.h11_connection.close() + + @property + def is_closed(self) -> bool: + if self.h2_connection is not None: + return self.h2_connection.is_closed + else: + assert self.h11_connection is not None + return self.h11_connection.is_closed + + async def connect( + self, ssl: SSLConfig, timeout: TimeoutConfig + ) -> typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter, str]: + hostname = self.origin.hostname + port = self.origin.port + ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None + + try: + reader, writer = await asyncio.wait_for( # type: ignore + asyncio.open_connection(hostname, port, ssl=ssl_context), + timeout.connect_timeout, + ) + except asyncio.TimeoutError: + raise ConnectTimeout() + + ssl_object = writer.get_extra_info("ssl_object") + if ssl_object is None: + protocol = "http/1.1" + else: + protocol = ssl_object.selected_alpn_protocol() + if protocol is None: + protocol = ssl_object.selected_npn_protocol() + + return (reader, writer, protocol) diff --git a/httpcore/connectionpool.py b/httpcore/connectionpool.py index a2040dbc..54bb32f5 100644 --- a/httpcore/connectionpool.py +++ b/httpcore/connectionpool.py @@ -10,9 +10,9 @@ from .config import ( SSLConfig, TimeoutConfig, ) +from .connection import HTTPConnection from .datastructures import Client, Origin, Request, Response from .exceptions import PoolTimeout -from .http11 import HTTP11Connection class ConnectionPool(Client): @@ -31,7 +31,7 @@ class ConnectionPool(Client): self.num_keepalive_connections = 0 self._keepalive_connections = ( {} - ) # type: typing.Dict[Origin, typing.List[HTTP11Connection]] + ) # type: typing.Dict[Origin, typing.List[HTTPConnection]] self._max_connections = ConnectionSemaphore( max_connections=self.limits.hard_limit ) @@ -53,7 +53,7 @@ class ConnectionPool(Client): async def acquire_connection( self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None - ) -> HTTP11Connection: + ) -> HTTPConnection: try: connection = self._keepalive_connections[origin].pop() if not self._keepalive_connections[origin]: @@ -71,7 +71,7 @@ class ConnectionPool(Client): await asyncio.wait_for(self._max_connections.acquire(), pool_timeout) except asyncio.TimeoutError: raise PoolTimeout() - connection = HTTP11Connection( + connection = HTTPConnection( origin, ssl=self.ssl, timeout=self.timeout, @@ -81,7 +81,7 @@ class ConnectionPool(Client): return connection - async def release_connection(self, connection: HTTP11Connection) -> None: + async def release_connection(self, connection: HTTPConnection) -> None: if connection.is_closed: self._max_connections.release() self.num_active_connections -= 1 diff --git a/httpcore/datastructures.py b/httpcore/datastructures.py index de5ee2f5..e4a809af 100644 --- a/httpcore/datastructures.py +++ b/httpcore/datastructures.py @@ -52,7 +52,7 @@ class URL: return port @property - def target(self) -> str: + def full_path(self) -> str: path = self.path or "/" query = self.query if query: @@ -138,10 +138,11 @@ class Request: return headers async def stream(self) -> typing.AsyncIterator[bytes]: - assert self.is_streaming - - async for part in self.body_aiter: - yield part + if self.is_streaming: + async for part in self.body_aiter: + yield part + elif self.body: + yield self.body class Response: @@ -150,6 +151,7 @@ class Response: status_code: int, *, reason: typing.Optional[str] = None, + protocol: typing.Optional[str] = None, headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", on_close: typing.Callable = None, @@ -162,6 +164,7 @@ class Response: self.reason = "" else: self.reason = reason + self.protocol = protocol self.headers = list(headers) self.on_close = on_close self.is_closed = False diff --git a/httpcore/http11.py b/httpcore/http11.py index 23cc27ce..f660867a 100644 --- a/httpcore/http11.py +++ b/httpcore/http11.py @@ -20,55 +20,43 @@ H11Event = typing.Union[ class HTTP11Connection(Client): def __init__( self, - origin: typing.Union[str, Origin], - ssl: SSLConfig = DEFAULT_SSL_CONFIG, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + origin: Origin, timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, on_release: typing.Callable = None, ): - self.origin = Origin(origin) if isinstance(origin, str) else origin - self.ssl = ssl + self.origin = origin + self.reader = reader + self.writer = writer self.timeout = timeout self.on_release = on_release - self._reader = None - self._writer = None - self._h11_state = h11.Connection(our_role=h11.CLIENT) + self.h11_state = h11.Connection(our_role=h11.CLIENT) @property def is_closed(self) -> bool: - return self._h11_state.our_state in (h11.CLOSED, h11.ERROR) + return self.h11_state.our_state in (h11.CLOSED, h11.ERROR) async def send( self, request: Request, *, ssl: typing.Optional[SSLConfig] = None, - timeout: typing.Optional[TimeoutConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None ) -> Response: - assert request.url.origin == self.origin - - if ssl is None: - ssl = self.ssl if timeout is None: timeout = self.timeout - # Make the connection - if self._reader is None: - await self._connect(ssl, timeout) - #  Start sending the request. method = request.method.encode() - target = request.url.target + target = request.url.full_path headers = request.headers event = h11.Request(method=method, target=target, headers=headers) await self._send_event(event) # Send the request body. - if request.is_streaming: - async for data in request.stream(): - event = h11.Data(data=data) - await self._send_event(event) - elif request.body: - event = h11.Data(data=request.body) + async for data in request.stream(): + event = h11.Data(data=data) await self._send_event(event) # Finalize sending the request. @@ -79,32 +67,22 @@ class HTTP11Connection(Client): event = await self._receive_event(timeout) if isinstance(event, h11.InformationalResponse): event = await self._receive_event(timeout) + assert isinstance(event, h11.Response) reason = event.reason.decode("latin1") status_code = event.status_code headers = event.headers body = self._body_iter(timeout) + return Response( status_code=status_code, reason=reason, + protocol="HTTP/1.1", headers=headers, body=body, on_close=self._release, ) - async def _connect(self, ssl: SSLConfig, timeout: TimeoutConfig) -> None: - hostname = self.origin.hostname - port = self.origin.port - ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None - - try: - self._reader, self._writer = await asyncio.wait_for( # type: ignore - asyncio.open_connection(hostname, port, ssl=ssl_context), - timeout.connect_timeout, - ) - except asyncio.TimeoutError: - raise ConnectTimeout() - async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]: event = await self._receive_event(timeout) while isinstance(event, h11.Data): @@ -113,36 +91,30 @@ class HTTP11Connection(Client): assert isinstance(event, h11.EndOfMessage) async def _send_event(self, event: H11Event) -> None: - assert self._writer is not None - - data = self._h11_state.send(event) - self._writer.write(data) + data = self.h11_state.send(event) + self.writer.write(data) async def _receive_event(self, timeout: TimeoutConfig) -> H11Event: - assert self._reader is not None - - event = self._h11_state.next_event() + event = self.h11_state.next_event() while event is h11.NEED_DATA: try: data = await asyncio.wait_for( - self._reader.read(2048), timeout.read_timeout + self.reader.read(2048), timeout.read_timeout ) except asyncio.TimeoutError: raise ReadTimeout() - self._h11_state.receive_data(data) - event = self._h11_state.next_event() + self.h11_state.receive_data(data) + event = self.h11_state.next_event() return event async def _release(self) -> None: - assert self._writer is not None - if ( - self._h11_state.our_state is h11.DONE - and self._h11_state.their_state is h11.DONE + self.h11_state.our_state is h11.DONE + and self.h11_state.their_state is h11.DONE ): - self._h11_state.start_next_cycle() + self.h11_state.start_next_cycle() else: await self.close() @@ -153,11 +125,11 @@ class HTTP11Connection(Client): event = h11.ConnectionClosed() try: # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED. - self._h11_state.send(event) + self.h11_state.send(event) except h11.ProtocolError: # If we're in some other state then it's a premature close, # and we'll end up in h11.ERROR. pass - if self._writer is not None: - self._writer.close() + if self.writer is not None: + self.writer.close() diff --git a/httpcore/http2.py b/httpcore/http2.py new file mode 100644 index 00000000..084a87ed --- /dev/null +++ b/httpcore/http2.py @@ -0,0 +1,152 @@ +import asyncio +import typing + +import h2.connection +import h2.events + +from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig +from .datastructures import Client, Origin, Request, Response +from .exceptions import ConnectTimeout, ReadTimeout + + +class HTTP2Connection(Client): + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + origin: Origin, + timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + on_release: typing.Callable = None, + ): + self.origin = origin + self.reader = reader + self.writer = writer + self.timeout = timeout + self.on_release = on_release + self.h2_state = h2.connection.H2Connection() + self.events = [] # type: typing.List[h2.events.Event] + + @property + def is_closed(self) -> bool: + return False + + async def send( + self, + request: Request, + *, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None + ) -> Response: + if timeout is None: + timeout = self.timeout + + #  Start sending the request. + await self._initiate_connection() + await self._send_headers(request) + + # Send the request body. + if request.body: + await self._send_data(request.body) + + # Finalize sending the request. + await self._end_stream() + + # Start getting the response. + while True: + event = await self._receive_event(timeout) + if isinstance(event, h2.events.ResponseReceived): + break + + status_code = 200 + headers = [] + for k, v in event.headers: + if k == b":status": + status_code = int(v.decode()) + elif not k.startswith(b":"): + headers.append((k, v)) + + body = self._body_iter(timeout) + return Response( + status_code=status_code, + protocol="HTTP/2", + headers=headers, + body=body, + on_close=self._release, + ) + + async def _initiate_connection(self) -> None: + self.h2_state.initiate_connection() + data_to_send = self.h2_state.data_to_send() + self.writer.write(data_to_send) + + async def _send_headers(self, request: Request) -> None: + headers = [ + (b":method", request.method.encode()), + (b":authority", request.url.hostname.encode()), + (b":scheme", request.url.scheme.encode()), + (b":path", request.url.full_path.encode()), + ] + request.headers + self.h2_state.send_headers(1, headers) + data_to_send = self.h2_state.data_to_send() + self.writer.write(data_to_send) + + async def _send_data(self, data: bytes) -> None: + self.h2_state.send_data(1, data) + data_to_send = self.h2_state.data_to_send() + self.writer.write(data_to_send) + + async def _end_stream(self) -> None: + self.h2_state.end_stream(1) + data_to_send = self.h2_state.data_to_send() + self.writer.write(data_to_send) + + async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]: + while True: + event = await self._receive_event(timeout) + if isinstance(event, h2.events.DataReceived): + yield event.data + elif isinstance(event, h2.events.StreamEnded): + break + + async def _receive_event(self, timeout: TimeoutConfig) -> h2.events.Event: + while not self.events: + try: + data = await asyncio.wait_for( + self.reader.read(2048), timeout.read_timeout + ) + except asyncio.TimeoutError: + raise ReadTimeout() + + events = self.h2_state.receive_data(data) + self.events.extend(events) + + data_to_send = self.h2_state.data_to_send() + if data_to_send: + self.writer.write(data_to_send) + + return self.events.pop(0) + + async def _release(self) -> None: + # if ( + # self.h11_state.our_state is h11.DONE + # and self.h11_state.their_state is h11.DONE + # ): + # self.h11_state.start_next_cycle() + # else: + # await self.close() + + if self.on_release is not None: + await self.on_release(self) + + async def close(self) -> None: + # event = h11.ConnectionClosed() + # try: + # # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED. + # self.h11_state.send(event) + # except h11.ProtocolError: + # # If we're in some other state then it's a premature close, + # # and we'll end up in h11.ERROR. + # pass + + if self.writer is not None: + self.writer.close() diff --git a/requirements.txt b/requirements.txt index 5108a8d6..18f9c5fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ certifi h11 +h2 # Optional brotlipy diff --git a/tests/test_connections.py b/tests/test_connections.py index f1590140..11031106 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -5,7 +5,7 @@ import httpcore @pytest.mark.asyncio async def test_get(server): - http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/") + http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/") response = await http.request("GET", "http://127.0.0.1:8000/") assert response.status_code == 200 assert response.body == b"Hello, world!" @@ -13,7 +13,7 @@ async def test_get(server): @pytest.mark.asyncio async def test_post(server): - http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/") + http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/") response = await http.request( "POST", "http://127.0.0.1:8000/", body=b"Hello, world!" ) diff --git a/tests/test_requests.py b/tests/test_requests.py index bdbf2caa..c88b70a0 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -73,12 +73,12 @@ def test_url(): request = httpcore.Request("GET", "http://example.org") assert request.url.scheme == "http" assert request.url.port == 80 - assert request.url.target == "/" + assert request.url.full_path == "/" request = httpcore.Request("GET", "https://example.org/abc?foo=bar") assert request.url.scheme == "https" assert request.url.port == 443 - assert request.url.target == "/abc?foo=bar" + assert request.url.full_path == "/abc?foo=bar" def test_invalid_urls():