]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
First pass at HTTP/2 support
authorTom Christie <tom@tomchristie.com>
Wed, 24 Apr 2019 14:48:18 +0000 (15:48 +0100)
committerTom Christie <tom@tomchristie.com>
Wed, 24 Apr 2019 14:48:18 +0000 (15:48 +0100)
httpcore/__init__.py
httpcore/config.py
httpcore/connection.py [new file with mode: 0644]
httpcore/connectionpool.py
httpcore/datastructures.py
httpcore/http11.py
httpcore/http2.py [new file with mode: 0644]
requirements.txt
tests/test_connections.py
tests/test_requests.py

index 9824e854e54223571979e8ac7ffcbf3606be02e4..48e3426ac46faa2cff68c6f26657136821adcb79 100644 (file)
@@ -1,4 +1,5 @@
 from .config import PoolLimits, SSLConfig, TimeoutConfig
+from .connection import HTTPConnection
 from .connectionpool import ConnectionPool
 from .datastructures import URL, Origin, Request, Response
 from .exceptions import (
@@ -10,6 +11,7 @@ from .exceptions import (
     StreamConsumed,
     Timeout,
 )
+from .http2 import HTTP2Connection
 from .http11 import HTTP11Connection
 from .sync import SyncClient, SyncConnectionPool
 
index 8cc784b3f78197dc6a8b6c806544f88b5a45e7e9..5b7ab4e07bb8cc0bd23a9d2fa62a3efb1abe74d9 100644 (file)
@@ -73,7 +73,23 @@ class SSLConfig:
                 "invalid path: {}".format(self.verify)
             )
 
-        context = ssl.create_default_context()
+        context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
+
+        context.options |= ssl.OP_NO_SSLv2
+        context.options |= ssl.OP_NO_SSLv3
+        context.options |= ssl.OP_NO_COMPRESSION
+
+        # RFC 7540 Section 9.2.2: "deployments of HTTP/2 that use TLS 1.2 MUST
+        # support TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256". In practice, the
+        # blacklist defined in this section allows only the AES GCM and ChaCha20
+        # cipher suites with ephemeral key negotiation.
+        context.set_ciphers("ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20")
+
+        if ssl.HAS_ALPN:
+            context.set_alpn_protocols(["h2", "http/1.1"])
+        if ssl.HAS_NPN:
+            context.set_npn_protocols(["h2", "http/1.1"])
+
         if os.path.isfile(ca_bundle_path):
             context.load_verify_locations(cafile=ca_bundle_path)
         elif os.path.isdir(ca_bundle_path):
diff --git a/httpcore/connection.py b/httpcore/connection.py
new file mode 100644 (file)
index 0000000..33f3459
--- /dev/null
@@ -0,0 +1,106 @@
+import asyncio
+import typing
+
+import h2.connection
+import h11
+
+from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
+from .datastructures import Client, Origin, Request, Response
+from .exceptions import ConnectTimeout
+from .http2 import HTTP2Connection
+from .http11 import HTTP11Connection
+
+
+class HTTPConnection(Client):
+    def __init__(
+        self,
+        origin: typing.Union[str, Origin],
+        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
+        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+        on_release: typing.Callable = None,
+    ):
+        self.origin = Origin(origin) if isinstance(origin, str) else origin
+        self.ssl = ssl
+        self.timeout = timeout
+        self.on_release = on_release
+        self.h11_connection = None  # type: typing.Optional[HTTP11Connection]
+        self.h2_connection = None  # type: typing.Optional[HTTP2Connection]
+
+    async def send(
+        self,
+        request: Request,
+        *,
+        ssl: typing.Optional[SSLConfig] = None,
+        timeout: typing.Optional[TimeoutConfig] = None,
+    ) -> Response:
+        if self.h11_connection is None and self.h2_connection is None:
+            if ssl is None:
+                ssl = self.ssl
+            if timeout is None:
+                timeout = self.timeout
+
+            reader, writer, protocol = await self.connect(ssl, timeout)
+            if protocol == "h2":
+                self.h2_connection = HTTP2Connection(
+                    reader,
+                    writer,
+                    origin=self.origin,
+                    timeout=self.timeout,
+                    on_release=self.on_release,
+                )
+            else:
+                self.h11_connection = HTTP11Connection(
+                    reader,
+                    writer,
+                    origin=self.origin,
+                    timeout=self.timeout,
+                    on_release=self.on_release,
+                )
+
+        if self.h2_connection is not None:
+            response = await self.h2_connection.send(request, ssl=ssl, timeout=timeout)
+        else:
+            assert self.h11_connection is not None
+            response = await self.h11_connection.send(request, ssl=ssl, timeout=timeout)
+
+        return response
+
+    async def close(self) -> None:
+        if self.h2_connection is not None:
+            await self.h2_connection.close()
+        else:
+            assert self.h11_connection is not None
+            await self.h11_connection.close()
+
+    @property
+    def is_closed(self) -> bool:
+        if self.h2_connection is not None:
+            return self.h2_connection.is_closed
+        else:
+            assert self.h11_connection is not None
+            return self.h11_connection.is_closed
+
+    async def connect(
+        self, ssl: SSLConfig, timeout: TimeoutConfig
+    ) -> typing.Tuple[asyncio.StreamReader, asyncio.StreamWriter, str]:
+        hostname = self.origin.hostname
+        port = self.origin.port
+        ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
+
+        try:
+            reader, writer = await asyncio.wait_for(  # type: ignore
+                asyncio.open_connection(hostname, port, ssl=ssl_context),
+                timeout.connect_timeout,
+            )
+        except asyncio.TimeoutError:
+            raise ConnectTimeout()
+
+        ssl_object = writer.get_extra_info("ssl_object")
+        if ssl_object is None:
+            protocol = "http/1.1"
+        else:
+            protocol = ssl_object.selected_alpn_protocol()
+        if protocol is None:
+            protocol = ssl_object.selected_npn_protocol()
+
+        return (reader, writer, protocol)
index a2040dbc410ca6671bfdcec645d16fefa13d0101..54bb32f5916f709935eb28f7791378aead5ca4dc 100644 (file)
@@ -10,9 +10,9 @@ from .config import (
     SSLConfig,
     TimeoutConfig,
 )
+from .connection import HTTPConnection
 from .datastructures import Client, Origin, Request, Response
 from .exceptions import PoolTimeout
-from .http11 import HTTP11Connection
 
 
 class ConnectionPool(Client):
@@ -31,7 +31,7 @@ class ConnectionPool(Client):
         self.num_keepalive_connections = 0
         self._keepalive_connections = (
             {}
-        )  # type: typing.Dict[Origin, typing.List[HTTP11Connection]]
+        )  # type: typing.Dict[Origin, typing.List[HTTPConnection]]
         self._max_connections = ConnectionSemaphore(
             max_connections=self.limits.hard_limit
         )
@@ -53,7 +53,7 @@ class ConnectionPool(Client):
 
     async def acquire_connection(
         self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
-    ) -> HTTP11Connection:
+    ) -> HTTPConnection:
         try:
             connection = self._keepalive_connections[origin].pop()
             if not self._keepalive_connections[origin]:
@@ -71,7 +71,7 @@ class ConnectionPool(Client):
                 await asyncio.wait_for(self._max_connections.acquire(), pool_timeout)
             except asyncio.TimeoutError:
                 raise PoolTimeout()
-            connection = HTTP11Connection(
+            connection = HTTPConnection(
                 origin,
                 ssl=self.ssl,
                 timeout=self.timeout,
@@ -81,7 +81,7 @@ class ConnectionPool(Client):
 
         return connection
 
-    async def release_connection(self, connection: HTTP11Connection) -> None:
+    async def release_connection(self, connection: HTTPConnection) -> None:
         if connection.is_closed:
             self._max_connections.release()
             self.num_active_connections -= 1
index de5ee2f5a3a8451f5f1314124c19ee70f57329b0..e4a809af5b97c0e4c54eb522365262d4cc6fe14b 100644 (file)
@@ -52,7 +52,7 @@ class URL:
         return port
 
     @property
-    def target(self) -> str:
+    def full_path(self) -> str:
         path = self.path or "/"
         query = self.query
         if query:
@@ -138,10 +138,11 @@ class Request:
         return headers
 
     async def stream(self) -> typing.AsyncIterator[bytes]:
-        assert self.is_streaming
-
-        async for part in self.body_aiter:
-            yield part
+        if self.is_streaming:
+            async for part in self.body_aiter:
+                yield part
+        elif self.body:
+            yield self.body
 
 
 class Response:
@@ -150,6 +151,7 @@ class Response:
         status_code: int,
         *,
         reason: typing.Optional[str] = None,
+        protocol: typing.Optional[str] = None,
         headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
         on_close: typing.Callable = None,
@@ -162,6 +164,7 @@ class Response:
                 self.reason = ""
         else:
             self.reason = reason
+        self.protocol = protocol
         self.headers = list(headers)
         self.on_close = on_close
         self.is_closed = False
index 23cc27ceca14e3130148a8aa9fa6fd17b89101c2..f660867abc3275a1341062c9b8d20a8179cf2bc6 100644 (file)
@@ -20,55 +20,43 @@ H11Event = typing.Union[
 class HTTP11Connection(Client):
     def __init__(
         self,
-        origin: typing.Union[str, Origin],
-        ssl: SSLConfig = DEFAULT_SSL_CONFIG,
+        reader: asyncio.StreamReader,
+        writer: asyncio.StreamWriter,
+        origin: Origin,
         timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
         on_release: typing.Callable = None,
     ):
-        self.origin = Origin(origin) if isinstance(origin, str) else origin
-        self.ssl = ssl
+        self.origin = origin
+        self.reader = reader
+        self.writer = writer
         self.timeout = timeout
         self.on_release = on_release
-        self._reader = None
-        self._writer = None
-        self._h11_state = h11.Connection(our_role=h11.CLIENT)
+        self.h11_state = h11.Connection(our_role=h11.CLIENT)
 
     @property
     def is_closed(self) -> bool:
-        return self._h11_state.our_state in (h11.CLOSED, h11.ERROR)
+        return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
 
     async def send(
         self,
         request: Request,
         *,
         ssl: typing.Optional[SSLConfig] = None,
-        timeout: typing.Optional[TimeoutConfig] = None,
+        timeout: typing.Optional[TimeoutConfig] = None
     ) -> Response:
-        assert request.url.origin == self.origin
-
-        if ssl is None:
-            ssl = self.ssl
         if timeout is None:
             timeout = self.timeout
 
-        # Make the connection
-        if self._reader is None:
-            await self._connect(ssl, timeout)
-
         #  Start sending the request.
         method = request.method.encode()
-        target = request.url.target
+        target = request.url.full_path
         headers = request.headers
         event = h11.Request(method=method, target=target, headers=headers)
         await self._send_event(event)
 
         # Send the request body.
-        if request.is_streaming:
-            async for data in request.stream():
-                event = h11.Data(data=data)
-                await self._send_event(event)
-        elif request.body:
-            event = h11.Data(data=request.body)
+        async for data in request.stream():
+            event = h11.Data(data=data)
             await self._send_event(event)
 
         # Finalize sending the request.
@@ -79,32 +67,22 @@ class HTTP11Connection(Client):
         event = await self._receive_event(timeout)
         if isinstance(event, h11.InformationalResponse):
             event = await self._receive_event(timeout)
+
         assert isinstance(event, h11.Response)
         reason = event.reason.decode("latin1")
         status_code = event.status_code
         headers = event.headers
         body = self._body_iter(timeout)
+
         return Response(
             status_code=status_code,
             reason=reason,
+            protocol="HTTP/1.1",
             headers=headers,
             body=body,
             on_close=self._release,
         )
 
-    async def _connect(self, ssl: SSLConfig, timeout: TimeoutConfig) -> None:
-        hostname = self.origin.hostname
-        port = self.origin.port
-        ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
-
-        try:
-            self._reader, self._writer = await asyncio.wait_for(  # type: ignore
-                asyncio.open_connection(hostname, port, ssl=ssl_context),
-                timeout.connect_timeout,
-            )
-        except asyncio.TimeoutError:
-            raise ConnectTimeout()
-
     async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
         event = await self._receive_event(timeout)
         while isinstance(event, h11.Data):
@@ -113,36 +91,30 @@ class HTTP11Connection(Client):
         assert isinstance(event, h11.EndOfMessage)
 
     async def _send_event(self, event: H11Event) -> None:
-        assert self._writer is not None
-
-        data = self._h11_state.send(event)
-        self._writer.write(data)
+        data = self.h11_state.send(event)
+        self.writer.write(data)
 
     async def _receive_event(self, timeout: TimeoutConfig) -> H11Event:
-        assert self._reader is not None
-
-        event = self._h11_state.next_event()
+        event = self.h11_state.next_event()
 
         while event is h11.NEED_DATA:
             try:
                 data = await asyncio.wait_for(
-                    self._reader.read(2048), timeout.read_timeout
+                    self.reader.read(2048), timeout.read_timeout
                 )
             except asyncio.TimeoutError:
                 raise ReadTimeout()
-            self._h11_state.receive_data(data)
-            event = self._h11_state.next_event()
+            self.h11_state.receive_data(data)
+            event = self.h11_state.next_event()
 
         return event
 
     async def _release(self) -> None:
-        assert self._writer is not None
-
         if (
-            self._h11_state.our_state is h11.DONE
-            and self._h11_state.their_state is h11.DONE
+            self.h11_state.our_state is h11.DONE
+            and self.h11_state.their_state is h11.DONE
         ):
-            self._h11_state.start_next_cycle()
+            self.h11_state.start_next_cycle()
         else:
             await self.close()
 
@@ -153,11 +125,11 @@ class HTTP11Connection(Client):
         event = h11.ConnectionClosed()
         try:
             # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
-            self._h11_state.send(event)
+            self.h11_state.send(event)
         except h11.ProtocolError:
             # If we're in some other state then it's a premature close,
             # and we'll end up in h11.ERROR.
             pass
 
-        if self._writer is not None:
-            self._writer.close()
+        if self.writer is not None:
+            self.writer.close()
diff --git a/httpcore/http2.py b/httpcore/http2.py
new file mode 100644 (file)
index 0000000..084a87e
--- /dev/null
@@ -0,0 +1,152 @@
+import asyncio
+import typing
+
+import h2.connection
+import h2.events
+
+from .config import DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, SSLConfig, TimeoutConfig
+from .datastructures import Client, Origin, Request, Response
+from .exceptions import ConnectTimeout, ReadTimeout
+
+
+class HTTP2Connection(Client):
+    def __init__(
+        self,
+        reader: asyncio.StreamReader,
+        writer: asyncio.StreamWriter,
+        origin: Origin,
+        timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+        on_release: typing.Callable = None,
+    ):
+        self.origin = origin
+        self.reader = reader
+        self.writer = writer
+        self.timeout = timeout
+        self.on_release = on_release
+        self.h2_state = h2.connection.H2Connection()
+        self.events = []  # type: typing.List[h2.events.Event]
+
+    @property
+    def is_closed(self) -> bool:
+        return False
+
+    async def send(
+        self,
+        request: Request,
+        *,
+        ssl: typing.Optional[SSLConfig] = None,
+        timeout: typing.Optional[TimeoutConfig] = None
+    ) -> Response:
+        if timeout is None:
+            timeout = self.timeout
+
+        #  Start sending the request.
+        await self._initiate_connection()
+        await self._send_headers(request)
+
+        # Send the request body.
+        if request.body:
+            await self._send_data(request.body)
+
+        # Finalize sending the request.
+        await self._end_stream()
+
+        # Start getting the response.
+        while True:
+            event = await self._receive_event(timeout)
+            if isinstance(event, h2.events.ResponseReceived):
+                break
+
+        status_code = 200
+        headers = []
+        for k, v in event.headers:
+            if k == b":status":
+                status_code = int(v.decode())
+            elif not k.startswith(b":"):
+                headers.append((k, v))
+
+        body = self._body_iter(timeout)
+        return Response(
+            status_code=status_code,
+            protocol="HTTP/2",
+            headers=headers,
+            body=body,
+            on_close=self._release,
+        )
+
+    async def _initiate_connection(self) -> None:
+        self.h2_state.initiate_connection()
+        data_to_send = self.h2_state.data_to_send()
+        self.writer.write(data_to_send)
+
+    async def _send_headers(self, request: Request) -> None:
+        headers = [
+            (b":method", request.method.encode()),
+            (b":authority", request.url.hostname.encode()),
+            (b":scheme", request.url.scheme.encode()),
+            (b":path", request.url.full_path.encode()),
+        ] + request.headers
+        self.h2_state.send_headers(1, headers)
+        data_to_send = self.h2_state.data_to_send()
+        self.writer.write(data_to_send)
+
+    async def _send_data(self, data: bytes) -> None:
+        self.h2_state.send_data(1, data)
+        data_to_send = self.h2_state.data_to_send()
+        self.writer.write(data_to_send)
+
+    async def _end_stream(self) -> None:
+        self.h2_state.end_stream(1)
+        data_to_send = self.h2_state.data_to_send()
+        self.writer.write(data_to_send)
+
+    async def _body_iter(self, timeout: TimeoutConfig) -> typing.AsyncIterator[bytes]:
+        while True:
+            event = await self._receive_event(timeout)
+            if isinstance(event, h2.events.DataReceived):
+                yield event.data
+            elif isinstance(event, h2.events.StreamEnded):
+                break
+
+    async def _receive_event(self, timeout: TimeoutConfig) -> h2.events.Event:
+        while not self.events:
+            try:
+                data = await asyncio.wait_for(
+                    self.reader.read(2048), timeout.read_timeout
+                )
+            except asyncio.TimeoutError:
+                raise ReadTimeout()
+
+            events = self.h2_state.receive_data(data)
+            self.events.extend(events)
+
+            data_to_send = self.h2_state.data_to_send()
+            if data_to_send:
+                self.writer.write(data_to_send)
+
+        return self.events.pop(0)
+
+    async def _release(self) -> None:
+        # if (
+        #     self.h11_state.our_state is h11.DONE
+        #     and self.h11_state.their_state is h11.DONE
+        # ):
+        #     self.h11_state.start_next_cycle()
+        # else:
+        #     await self.close()
+
+        if self.on_release is not None:
+            await self.on_release(self)
+
+    async def close(self) -> None:
+        # event = h11.ConnectionClosed()
+        # try:
+        #     # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
+        #     self.h11_state.send(event)
+        # except h11.ProtocolError:
+        #     # If we're in some other state then it's a premature close,
+        #     # and we'll end up in h11.ERROR.
+        #     pass
+
+        if self.writer is not None:
+            self.writer.close()
index 5108a8d6d38d4f9c2116fd33c917c95c5c3f74ba..18f9c5fe4c5970c0fa84775876b6df9352fb1c22 100644 (file)
@@ -1,5 +1,6 @@
 certifi
 h11
+h2
 
 # Optional
 brotlipy
index f15901404c535e21279914995ebbc541b71a1265..110311067740321848ad8d84b2573fb75f51d62b 100644 (file)
@@ -5,7 +5,7 @@ import httpcore
 
 @pytest.mark.asyncio
 async def test_get(server):
-    http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/")
+    http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/")
     response = await http.request("GET", "http://127.0.0.1:8000/")
     assert response.status_code == 200
     assert response.body == b"Hello, world!"
@@ -13,7 +13,7 @@ async def test_get(server):
 
 @pytest.mark.asyncio
 async def test_post(server):
-    http = httpcore.HTTP11Connection(origin="http://127.0.0.1:8000/")
+    http = httpcore.HTTPConnection(origin="http://127.0.0.1:8000/")
     response = await http.request(
         "POST", "http://127.0.0.1:8000/", body=b"Hello, world!"
     )
index bdbf2caa08c27078c6945b4f685a06cbcf15d623..c88b70a037d7205b084c648cf9ef28e2495580a2 100644 (file)
@@ -73,12 +73,12 @@ def test_url():
     request = httpcore.Request("GET", "http://example.org")
     assert request.url.scheme == "http"
     assert request.url.port == 80
-    assert request.url.target == "/"
+    assert request.url.full_path == "/"
 
     request = httpcore.Request("GET", "https://example.org/abc?foo=bar")
     assert request.url.scheme == "https"
     assert request.url.port == 443
-    assert request.url.target == "/abc?foo=bar"
+    assert request.url.full_path == "/abc?foo=bar"
 
 
 def test_invalid_urls():