]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add http2 connection re-use
authorTom Christie <tom@tomchristie.com>
Thu, 25 Apr 2019 15:17:42 +0000 (16:17 +0100)
committerTom Christie <tom@tomchristie.com>
Thu, 25 Apr 2019 15:17:42 +0000 (16:17 +0100)
httpcore/connection.py
httpcore/connection_pool.py
httpcore/http2.py

index a07a062aadd8d415a845ef75e59f85e1149f95db..d5caadeeef23d1a051e61f641911a8584b46ea13 100644 (file)
@@ -35,39 +35,7 @@ class HTTPConnection(Client):
         timeout: typing.Optional[TimeoutConfig] = None,
     ) -> Response:
         if self.h11_connection is None and self.h2_connection is None:
-            if ssl is None:
-                ssl = self.ssl
-            if timeout is None:
-                timeout = self.timeout
-
-            hostname = self.origin.hostname
-            port = self.origin.port
-            ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
-
-            if self.pool_release_func is None:
-                on_release = None
-            else:
-                on_release = functools.partial(self.pool_release_func, self)
-
-            reader, writer, protocol = await connect(
-                hostname, port, ssl_context, timeout
-            )
-            if protocol == Protocol.HTTP_2:
-                self.h2_connection = HTTP2Connection(
-                    reader,
-                    writer,
-                    origin=self.origin,
-                    timeout=self.timeout,
-                    on_release=on_release,
-                )
-            else:
-                self.h11_connection = HTTP11Connection(
-                    reader,
-                    writer,
-                    origin=self.origin,
-                    timeout=self.timeout,
-                    on_release=on_release,
-                )
+            await self.connect(ssl, timeout)
 
         if self.h2_connection is not None:
             response = await self.h2_connection.send(request, ssl=ssl, timeout=timeout)
@@ -77,12 +45,53 @@ class HTTPConnection(Client):
 
         return response
 
+    async def connect(
+        self,
+        ssl: typing.Optional[SSLConfig] = None,
+        timeout: typing.Optional[TimeoutConfig] = None,
+    ) -> None:
+        if ssl is None:
+            ssl = self.ssl
+        if timeout is None:
+            timeout = self.timeout
+
+        hostname = self.origin.hostname
+        port = self.origin.port
+        ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
+
+        if self.pool_release_func is None:
+            on_release = None
+        else:
+            on_release = functools.partial(self.pool_release_func, self)
+
+        reader, writer, protocol = await connect(hostname, port, ssl_context, timeout)
+        if protocol == Protocol.HTTP_2:
+            self.h2_connection = HTTP2Connection(
+                reader,
+                writer,
+                origin=self.origin,
+                timeout=self.timeout,
+                on_release=on_release,
+            )
+        else:
+            self.h11_connection = HTTP11Connection(
+                reader,
+                writer,
+                origin=self.origin,
+                timeout=self.timeout,
+                on_release=on_release,
+            )
+
     async def close(self) -> None:
         if self.h2_connection is not None:
             await self.h2_connection.close()
         elif self.h11_connection is not None:
             await self.h11_connection.close()
 
+    @property
+    def is_http2(self) -> bool:
+        return self.h2_connection is not None
+
     @property
     def is_closed(self) -> bool:
         if self.h2_connection is not None:
index b541435787c597293b5cc91873ac6f7dbffc873f..709ad0f71384633b8cac7baca021369073f6e764 100644 (file)
@@ -33,13 +33,18 @@ class ConnectionStore(collections.abc.Sequence):
             {}
         )  # type: typing.Dict[Origin, typing.Dict[HTTPConnection, float]]
 
-    def pop_by_origin(self, origin: Origin) -> typing.Optional[HTTPConnection]:
+    def pop_by_origin(
+        self, origin: Origin, http2_only: bool = False
+    ) -> typing.Optional[HTTPConnection]:
         try:
             connections = self.by_origin[origin]
         except KeyError:
             return None
 
         connection = next(reversed(list(connections.keys())))
+        if http2_only and not connection.is_http2:
+            return None
+
         del connections[connection]
         if not connections:
             del self.by_origin[origin]
@@ -111,7 +116,9 @@ class ConnectionPool(Client):
     async def acquire_connection(
         self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
     ) -> HTTPConnection:
-        connection = self.keepalive_connections.pop_by_origin(origin)
+        connection = self.active_connections.pop_by_origin(origin, http2_only=True)
+        if connection is None:
+            connection = self.keepalive_connections.pop_by_origin(origin)
 
         if connection is None:
             await self.max_connections.acquire(timeout)
index 2b6c8e6dba956b447c3f0bf95b751bf83f6214a6..e48c4de8e0d1d64993de75f7345027938c21ee81 100644 (file)
@@ -1,3 +1,4 @@
+import functools
 import typing
 
 import h2.connection
@@ -74,12 +75,14 @@ class HTTP2Connection(Client):
                 headers.append((k, v))
 
         body = self.body_iter(stream_id, timeout)
+        on_close = functools.partial(self.response_closed, stream_id=stream_id)
+
         return Response(
             status_code=status_code,
             protocol="HTTP/2",
             headers=headers,
             body=body,
-            on_close=self.release,
+            on_close=on_close,
         )
 
     def initiate_connection(self) -> None:
@@ -121,7 +124,6 @@ class HTTP2Connection(Client):
             if isinstance(event, h2.events.DataReceived):
                 yield event.data
             elif isinstance(event, h2.events.StreamEnded):
-                del self.events[stream_id]
                 break
 
     async def receive_event(
@@ -139,8 +141,10 @@ class HTTP2Connection(Client):
 
         return self.events[stream_id].pop(0)
 
-    async def release(self) -> None:
-        if self.on_release is not None:
+    async def response_closed(self, stream_id: int) -> None:
+        del self.events[stream_id]
+
+        if not self.events and self.on_release is not None:
             await self.on_release()
 
     async def close(self) -> None: