]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Tighten up connection acquiry/release
authorTom Christie <tom@tomchristie.com>
Thu, 25 Apr 2019 14:31:47 +0000 (15:31 +0100)
committerTom Christie <tom@tomchristie.com>
Thu, 25 Apr 2019 14:31:47 +0000 (15:31 +0100)
httpcore/connection.py
httpcore/connection_pool.py
httpcore/http11.py
httpcore/http2.py
httpcore/streams.py
tests/test_connection_pools.py

index f164232ffe6cb893aadec1c02a66ae5ae1f10281..a07a062aadd8d415a845ef75e59f85e1149f95db 100644 (file)
@@ -1,3 +1,4 @@
+import functools
 import typing
 
 import h2.connection
@@ -17,12 +18,12 @@ class HTTPConnection(Client):
         origin: typing.Union[str, Origin],
         ssl: SSLConfig = DEFAULT_SSL_CONFIG,
         timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
-        on_release: typing.Callable = None,
+        pool_release_func: typing.Callable = None,
     ):
         self.origin = Origin(origin) if isinstance(origin, str) else origin
         self.ssl = ssl
         self.timeout = timeout
-        self.on_release = on_release
+        self.pool_release_func = pool_release_func
         self.h11_connection = None  # type: typing.Optional[HTTP11Connection]
         self.h2_connection = None  # type: typing.Optional[HTTP2Connection]
 
@@ -43,6 +44,11 @@ class HTTPConnection(Client):
             port = self.origin.port
             ssl_context = await ssl.load_ssl_context() if self.origin.is_ssl else None
 
+            if self.pool_release_func is None:
+                on_release = None
+            else:
+                on_release = functools.partial(self.pool_release_func, self)
+
             reader, writer, protocol = await connect(
                 hostname, port, ssl_context, timeout
             )
@@ -52,7 +58,7 @@ class HTTPConnection(Client):
                     writer,
                     origin=self.origin,
                     timeout=self.timeout,
-                    on_release=self.on_release,
+                    on_release=on_release,
                 )
             else:
                 self.h11_connection = HTTP11Connection(
@@ -60,7 +66,7 @@ class HTTPConnection(Client):
                     writer,
                     origin=self.origin,
                     timeout=self.timeout,
-                    on_release=self.on_release,
+                    on_release=on_release,
                 )
 
         if self.h2_connection is not None:
index 894212ab6dea9008f1741bd232f952708aa4a9ab..b541435787c597293b5cc91873ac6f7dbffc873f 100644 (file)
@@ -1,3 +1,4 @@
+import collections.abc
 import typing
 
 from .config import (
@@ -14,6 +15,66 @@ from .exceptions import PoolTimeout
 from .models import Client, Origin, Request, Response
 from .streams import PoolSemaphore
 
+CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
+
+
+class ConnectionStore(collections.abc.Sequence):
+    """
+    We need to maintain collections of connections in a way that allows us to:
+
+    * Lookup connections by origin.
+    * Iterate over connections by insertion time.
+    * Return the total number of connections.
+    """
+
+    def __init__(self) -> None:
+        self.all = {}  # type: typing.Dict[HTTPConnection, float]
+        self.by_origin = (
+            {}
+        )  # type: typing.Dict[Origin, typing.Dict[HTTPConnection, float]]
+
+    def pop_by_origin(self, origin: Origin) -> typing.Optional[HTTPConnection]:
+        try:
+            connections = self.by_origin[origin]
+        except KeyError:
+            return None
+
+        connection = next(reversed(list(connections.keys())))
+        del connections[connection]
+        if not connections:
+            del self.by_origin[origin]
+        del self.all[connection]
+
+        return connection
+
+    def add(self, connection: HTTPConnection) -> None:
+        self.all[connection] = 0.0
+        try:
+            self.by_origin[connection.origin][connection] = 0.0
+        except KeyError:
+            self.by_origin[connection.origin] = {connection: 0.0}
+
+    def remove(self, connection: HTTPConnection) -> None:
+        del self.all[connection]
+        del self.by_origin[connection.origin][connection]
+        if not self.by_origin[connection.origin]:
+            del self.by_origin[connection.origin]
+
+    def clear(self) -> None:
+        self.all.clear()
+        self.by_origin.clear()
+
+    def __iter__(self) -> typing.Iterator[HTTPConnection]:
+        return iter(self.all.keys())
+
+    def __getitem__(self, key: typing.Any) -> typing.Any:
+        if key in self.all:
+            return key
+        return None
+
+    def __len__(self) -> int:
+        return len(self.all)
+
 
 class ConnectionPool(Client):
     def __init__(
@@ -27,12 +88,14 @@ class ConnectionPool(Client):
         self.timeout = timeout
         self.limits = limits
         self.is_closed = False
-        self.num_active_connections = 0
-        self.num_keepalive_connections = 0
-        self._keepalive_connections = (
-            {}
-        )  # type: typing.Dict[Origin, typing.List[HTTPConnection]]
-        self._max_connections = PoolSemaphore(limits, timeout)
+
+        self.max_connections = PoolSemaphore(limits, timeout)
+        self.keepalive_connections = ConnectionStore()
+        self.active_connections = ConnectionStore()
+
+    @property
+    def num_connections(self) -> int:
+        return len(self.keepalive_connections) + len(self.active_connections)
 
     async def send(
         self,
@@ -45,56 +108,42 @@ class ConnectionPool(Client):
         response = await connection.send(request, ssl=ssl, timeout=timeout)
         return response
 
-    @property
-    def num_connections(self) -> int:
-        return self.num_active_connections + self.num_keepalive_connections
-
     async def acquire_connection(
         self, origin: Origin, timeout: typing.Optional[TimeoutConfig] = None
     ) -> HTTPConnection:
-        try:
-            connection = self._keepalive_connections[origin].pop()
-            if not self._keepalive_connections[origin]:
-                del self._keepalive_connections[origin]
-            self.num_keepalive_connections -= 1
-            self.num_active_connections += 1
-
-        except (KeyError, IndexError):
-            await self._max_connections.acquire(timeout)
+        connection = self.keepalive_connections.pop_by_origin(origin)
+
+        if connection is None:
+            await self.max_connections.acquire(timeout)
             connection = HTTPConnection(
                 origin,
                 ssl=self.ssl,
                 timeout=self.timeout,
-                on_release=self.release_connection,
+                pool_release_func=self.release_connection,
             )
-            self.num_active_connections += 1
+
+        self.active_connections.add(connection)
 
         return connection
 
     async def release_connection(self, connection: HTTPConnection) -> None:
         if connection.is_closed:
-            self._max_connections.release()
-            self.num_active_connections -= 1
+            self.active_connections.remove(connection)
+            self.max_connections.release()
         elif (
             self.limits.soft_limit is not None
             and self.num_connections > self.limits.soft_limit
         ):
-            self._max_connections.release()
-            self.num_active_connections -= 1
+            self.active_connections.remove(connection)
+            self.max_connections.release()
             await connection.close()
         else:
-            self.num_active_connections -= 1
-            self.num_keepalive_connections += 1
-            try:
-                self._keepalive_connections[connection.origin].append(connection)
-            except KeyError:
-                self._keepalive_connections[connection.origin] = [connection]
+            self.active_connections.remove(connection)
+            self.keepalive_connections.add(connection)
 
     async def close(self) -> None:
         self.is_closed = True
-        all_connections = []
-        for connections in self._keepalive_connections.values():
-            all_connections.extend(list(connections))
-        self._keepalive_connections.clear()
-        for connection in all_connections:
+        connections = list(self.keepalive_connections)
+        self.keepalive_connections.clear()
+        for connection in connections:
             await connection.close()
index 253865fe92b8a488c670c6f395add7556eba82ad..3280e1def8ff15a84f379db9b059b92fd66b51c6 100644 (file)
@@ -82,7 +82,7 @@ class HTTP11Connection(Client):
             protocol="HTTP/1.1",
             headers=headers,
             body=body,
-            on_close=self._release,
+            on_close=self.response_closed,
         )
 
     async def _body_iter(self, timeout: OptionalTimeout) -> typing.AsyncIterator[bytes]:
@@ -106,7 +106,7 @@ class HTTP11Connection(Client):
 
         return event
 
-    async def _release(self) -> None:
+    async def response_closed(self) -> None:
         if (
             self.h11_state.our_state is h11.DONE
             and self.h11_state.their_state is h11.DONE
@@ -116,7 +116,7 @@ class HTTP11Connection(Client):
             await self.close()
 
         if self.on_release is not None:
-            await self.on_release(self)
+            await self.on_release()
 
     async def close(self) -> None:
         event = h11.ConnectionClosed()
index f8d2b648bfc89ffee19b2c3b855f7630ef4beddd..2b6c8e6dba956b447c3f0bf95b751bf83f6214a6 100644 (file)
@@ -141,7 +141,7 @@ class HTTP2Connection(Client):
 
     async def release(self) -> None:
         if self.on_release is not None:
-            await self.on_release(self)
+            await self.on_release()
 
     async def close(self) -> None:
         await self.writer.close()
index cba51fd7315618ffdc14396ad8ec8d0117128d99..e46ffee4856174b8011b136f77d25325e6c6a724 100644 (file)
@@ -14,7 +14,7 @@ import ssl
 import typing
 
 from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
-from .exceptions import ConnectTimeout, ReadTimeout, PoolTimeout, WriteTimeout
+from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
 
 OptionalTimeout = typing.Optional[TimeoutConfig]
 
index 77a221575441fda934bd4799ed8a4a33b00726f0..7d478c5ac6fdf2b67d6cf98b3ee22df0f58af70f 100644 (file)
@@ -10,12 +10,12 @@ async def test_keepalive_connections(server):
     """
     async with httpcore.ConnectionPool() as http:
         response = await http.request("GET", "http://127.0.0.1:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
 
         response = await http.request("GET", "http://127.0.0.1:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
 
 
 @pytest.mark.asyncio
@@ -25,12 +25,12 @@ async def test_differing_connection_keys(server):
     """
     async with httpcore.ConnectionPool() as http:
         response = await http.request("GET", "http://127.0.0.1:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
 
         response = await http.request("GET", "http://localhost:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 2
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 2
 
 
 @pytest.mark.asyncio
@@ -42,12 +42,12 @@ async def test_soft_limit(server):
 
     async with httpcore.ConnectionPool(limits=limits) as http:
         response = await http.request("GET", "http://127.0.0.1:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
 
         response = await http.request("GET", "http://localhost:8000/")
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
 
 
 @pytest.mark.asyncio
@@ -57,13 +57,13 @@ async def test_streaming_response_holds_connection(server):
     """
     async with httpcore.ConnectionPool() as http:
         response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
-        assert http.num_active_connections == 1
-        assert http.num_keepalive_connections == 0
+        assert len(http.active_connections) == 1
+        assert len(http.keepalive_connections) == 0
 
         await response.read()
 
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
 
 
 @pytest.mark.asyncio
@@ -73,20 +73,20 @@ async def test_multiple_concurrent_connections(server):
     """
     async with httpcore.ConnectionPool() as http:
         response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
-        assert http.num_active_connections == 1
-        assert http.num_keepalive_connections == 0
+        assert len(http.active_connections) == 1
+        assert len(http.keepalive_connections) == 0
 
         response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
-        assert http.num_active_connections == 2
-        assert http.num_keepalive_connections == 0
+        assert len(http.active_connections) == 2
+        assert len(http.keepalive_connections) == 0
 
         await response_b.read()
-        assert http.num_active_connections == 1
-        assert http.num_keepalive_connections == 1
+        assert len(http.active_connections) == 1
+        assert len(http.keepalive_connections) == 1
 
         await response_a.read()
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 2
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 2
 
 
 @pytest.mark.asyncio
@@ -97,8 +97,8 @@ async def test_close_connections(server):
     headers = [(b"connection", b"close")]
     async with httpcore.ConnectionPool() as http:
         response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 0
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 0
 
 
 @pytest.mark.asyncio
@@ -110,8 +110,8 @@ async def test_standard_response_close(server):
         response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
         await response.read()
         await response.close()
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 1
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 1
 
 
 @pytest.mark.asyncio
@@ -122,5 +122,5 @@ async def test_premature_response_close(server):
     async with httpcore.ConnectionPool() as http:
         response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
         await response.close()
-        assert http.num_active_connections == 0
-        assert http.num_keepalive_connections == 0
+        assert len(http.active_connections) == 0
+        assert len(http.keepalive_connections) == 0