]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Connection pooling
authorTom Christie <tom@tomchristie.com>
Tue, 16 Apr 2019 14:58:38 +0000 (15:58 +0100)
committerTom Christie <tom@tomchristie.com>
Tue, 16 Apr 2019 14:58:38 +0000 (15:58 +0100)
httpcore/connections.py
httpcore/pool.py
tests/conftest.py
tests/test_pool.py [new file with mode: 0644]

index f39a365f97705f8c00574cebc590706dca660afc..8d5c13bd76fdf2c05c5e0e74a409d9e757136c49 100644 (file)
@@ -19,18 +19,19 @@ H11Event = typing.Union[
 
 
 class Connection:
-    def __init__(self, timeout: TimeoutConfig):
+    def __init__(self, timeout: TimeoutConfig, on_release: typing.Callable = None):
         self.reader = None
         self.writer = None
         self.state = h11.Connection(our_role=h11.CLIENT)
         self.timeout = timeout
+        self.on_release = on_release
+
+    @property
+    def is_closed(self) -> bool:
+        return self.state.our_state in (h11.CLOSED, h11.ERROR)
 
     async def open(
-        self,
-        hostname: str,
-        port: int,
-        *,
-        ssl: typing.Union[bool, ssl.SSLContext] = False
+        self, hostname: str, port: int, *, ssl: typing.Optional[ssl.SSLContext] = None
     ) -> None:
         try:
             self.reader, self.writer = await asyncio.wait_for(  # type: ignore
@@ -69,18 +70,17 @@ class Connection:
         assert isinstance(event, h11.Response)
         status_code = event.status_code
         headers = event.headers
-        body = self.body_iter()
+        body = self._body_iter()
         return Response(
-            status_code=status_code, headers=headers, body=body, on_close=self.close
+            status_code=status_code, headers=headers, body=body, on_close=self._release
         )
 
-    async def body_iter(self) -> typing.AsyncIterator[bytes]:
+    async def _body_iter(self) -> typing.AsyncIterator[bytes]:
         event = await self._receive_event()
         while isinstance(event, h11.Data):
             yield event.data
             event = await self._receive_event()
         assert isinstance(event, h11.EndOfMessage)
-        await self.close()
 
     async def _send_event(self, event: H11Event) -> None:
         assert self.writer is not None
@@ -105,8 +105,25 @@ class Connection:
 
         return event
 
-    async def close(self) -> None:
-        if self.writer is not None:
+    async def _release(self) -> None:
+        assert self.writer is not None
+
+        if self.state.our_state is h11.DONE and self.state.their_state is h11.DONE:
+            self.state.start_next_cycle()
+        else:
+            event = h11.ConnectionClosed()
+            try:
+                # If we're in h11.MUST_CLOSE then we'll end up in h11.CLOSED.
+                self.state.send(event)
+            except h11.ProtocolError:
+                # If we're in some other state then it's a premature close,
+                # and we'll end up in h11.ERROR.
+                pass
+
+        if self.is_closed:
             self.writer.close()
             if hasattr(self.writer, "wait_closed"):
                 await self.writer.wait_closed()
+
+        if self.on_release is not None:
+            await self.on_release(self)
index 6446c39a90d5ae2c81c43161a1f44baa2542febb..6b4d328ada5c6919a50f3f13a37c5c49c5afa5b0 100644 (file)
@@ -1,4 +1,5 @@
 import asyncio
+import functools
 import os
 import ssl
 import typing
@@ -16,6 +17,8 @@ from .config import (
 from .connections import Connection
 from .datastructures import URL, Request, Response
 
+ConnectionKey = typing.Tuple[str, str, int]  # (scheme, host, port)
+
 
 class ConnectionPool:
     def __init__(
@@ -29,6 +32,11 @@ class ConnectionPool:
         self.timeout = timeout
         self.limits = limits
         self.is_closed = False
+        self.num_active_connections = 0
+        self.num_keepalive_connections = 0
+        self._connections = (
+            {}
+        )  # type: typing.Dict[ConnectionKey, typing.List[Connection]]
 
     async def request(
         self,
@@ -52,15 +60,38 @@ class ConnectionPool:
         return response
 
     async def acquire_connection(
-        self, url: URL, *, ssl: typing.Union[bool, ssl.SSLContext] = False
+        self, url: URL, *, ssl: typing.Optional[ssl.SSLContext] = None
     ) -> Connection:
-        connection = Connection(timeout=self.timeout)
-        await connection.open(url.hostname, url.port, ssl=ssl)
+        key = (url.scheme, url.hostname, url.port)
+        try:
+            connection = self._connections[key].pop()
+            if not self._connections[key]:
+                del self._connections[key]
+            self.num_keepalive_connections -= 1
+            self.num_active_connections += 1
+
+        except (KeyError, IndexError):
+            release = functools.partial(self.release_connection, key=key)
+            connection = Connection(timeout=self.timeout, on_release=release)
+            self.num_active_connections += 1
+            await connection.open(url.hostname, url.port, ssl=ssl)
+
         return connection
 
-    async def get_ssl_context(self, url: URL) -> typing.Union[bool, ssl.SSLContext]:
+    async def release_connection(
+        self, connection: Connection, key: ConnectionKey
+    ) -> None:
+        self.num_active_connections -= 1
+        if not connection.is_closed:
+            self.num_keepalive_connections += 1
+            try:
+                self._connections[key].append(connection)
+            except KeyError:
+                self._connections[key] = [connection]
+
+    async def get_ssl_context(self, url: URL) -> typing.Optional[ssl.SSLContext]:
         if not url.is_secure:
-            return False
+            return None
 
         if not hasattr(self, "ssl_context"):
             if not self.ssl_config.verify:
index 234cf43bb8f22670996a5bb95ee6260cccc84131..efb79df195cd1383eac1dba70a09cf5d4334a13b 100644 (file)
@@ -27,4 +27,6 @@ async def server():
             await asyncio.sleep(0.0001)
         yield server
     finally:
-        task.cancel()
+        server.should_exit = True
+        server.force_exit = True
+        await task
diff --git a/tests/test_pool.py b/tests/test_pool.py
new file mode 100644 (file)
index 0000000..444d51c
--- /dev/null
@@ -0,0 +1,109 @@
+import pytest
+
+import httpcore
+
+
+@pytest.mark.asyncio
+async def test_keepalive_connections(server):
+    """
+    Connections should default to staying in a keep-alive state.
+    """
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
+
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
+
+
+@pytest.mark.asyncio
+async def test_differing_connection_keys(server):
+    """
+    Connnections to differing connection keys should result in multiple connections.
+    """
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
+
+        response = await http.request("GET", "http://localhost:8000/")
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 2
+
+
+@pytest.mark.asyncio
+async def test_streaming_response_holds_connection(server):
+    """
+    A streaming request should hold the connection open until the response is read.
+    """
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        assert http.num_active_connections == 1
+        assert http.num_keepalive_connections == 0
+
+        await response.read()
+
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
+
+
+@pytest.mark.asyncio
+async def test_multiple_concurrent_connections(server):
+    """
+    Multiple conncurrent requests should open multiple conncurrent connections.
+    """
+    async with httpcore.ConnectionPool() as http:
+        response_a = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        assert http.num_active_connections == 1
+        assert http.num_keepalive_connections == 0
+
+        response_b = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        assert http.num_active_connections == 2
+        assert http.num_keepalive_connections == 0
+
+        await response_b.read()
+        assert http.num_active_connections == 1
+        assert http.num_keepalive_connections == 1
+
+        await response_a.read()
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 2
+
+
+@pytest.mark.asyncio
+async def test_close_connections(server):
+    """
+    Using a `Connection: close` header should close the connection.
+    """
+    headers = [(b"connection", b"close")]
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers)
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 0
+
+
+@pytest.mark.asyncio
+async def test_standard_response_close(server):
+    """
+    A standard close should keep the connection open.
+    """
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        await response.read()
+        await response.close()
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 1
+
+
+@pytest.mark.asyncio
+async def test_premature_response_close(server):
+    """
+    A premature close should close the connection.
+    """
+    async with httpcore.ConnectionPool() as http:
+        response = await http.request("GET", "http://127.0.0.1:8000/", stream=True)
+        await response.close()
+        assert http.num_active_connections == 0
+        assert http.num_keepalive_connections == 0