]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add PoolSemaphore
authorTom Christie <tom@tomchristie.com>
Thu, 25 Apr 2019 11:57:18 +0000 (12:57 +0100)
committerTom Christie <tom@tomchristie.com>
Thu, 25 Apr 2019 11:57:18 +0000 (12:57 +0100)
httpcore/connection_pool.py
httpcore/streams.py
tests/test_http2.py

index 6ec30289912e00028cf6f7735f0794685d4ac2a1..894212ab6dea9008f1741bd232f952708aa4a9ab 100644 (file)
@@ -1,4 +1,3 @@
-import asyncio
 import typing
 
 from .config import (
@@ -13,6 +12,7 @@ from .config import (
 from .connection import HTTPConnection
 from .exceptions import PoolTimeout
 from .models import Client, Origin, Request, Response
+from .streams import PoolSemaphore
 
 
 class ConnectionPool(Client):
@@ -32,9 +32,7 @@ class ConnectionPool(Client):
         self._keepalive_connections = (
             {}
         )  # type: typing.Dict[Origin, typing.List[HTTPConnection]]
-        self._max_connections = ConnectionSemaphore(
-            max_connections=self.limits.hard_limit
-        )
+        self._max_connections = PoolSemaphore(limits, timeout)
 
     async def send(
         self,
@@ -62,15 +60,7 @@ class ConnectionPool(Client):
             self.num_active_connections += 1
 
         except (KeyError, IndexError):
-            if timeout is None:
-                pool_timeout = self.timeout.pool_timeout
-            else:
-                pool_timeout = timeout.pool_timeout
-
-            try:
-                await asyncio.wait_for(self._max_connections.acquire(), pool_timeout)
-            except asyncio.TimeoutError:
-                raise PoolTimeout()
+            await self._max_connections.acquire(timeout)
             connection = HTTPConnection(
                 origin,
                 ssl=self.ssl,
@@ -108,25 +98,3 @@ class ConnectionPool(Client):
         self._keepalive_connections.clear()
         for connection in all_connections:
             await connection.close()
-
-
-class ConnectionSemaphore:
-    def __init__(self, max_connections: int = None):
-        self.max_connections = max_connections
-
-    @property
-    def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
-        if not hasattr(self, "_semaphore"):
-            if self.max_connections is None:
-                self._semaphore = None
-            else:
-                self._semaphore = asyncio.BoundedSemaphore(value=self.max_connections)
-        return self._semaphore
-
-    async def acquire(self) -> None:
-        if self.semaphore is not None:
-            await self.semaphore.acquire()
-
-    def release(self) -> None:
-        if self.semaphore is not None:
-            self.semaphore.release()
index 5a9a0abb314f0002b3e6e56d0c96fde8a4c50cb1..cba51fd7315618ffdc14396ad8ec8d0117128d99 100644 (file)
@@ -2,7 +2,9 @@
 The `Reader` and `Writer` classes here provide a lightweight layer over
 `asyncio.StreamReader` and `asyncio.StreamWriter`.
 
-They help encapsulate the timeout logic, make it easier to unit-test
+Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
+
+These classes help encapsulate the timeout logic, make it easier to unit-test
 protocols, and help keep the rest of the package more `async`/`await`
 based, and less strictly `asyncio`-specific.
 """
@@ -11,8 +13,8 @@ import enum
 import ssl
 import typing
 
-from .config import TimeoutConfig, DEFAULT_TIMEOUT_CONFIG
-from .exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
+from .exceptions import ConnectTimeout, ReadTimeout, PoolTimeout, WriteTimeout
 
 OptionalTimeout = typing.Optional[TimeoutConfig]
 
@@ -38,6 +40,17 @@ class BaseWriter:
         raise NotImplementedError()  # pragma: no cover
 
 
+class BasePoolSemaphore:
+    def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
+        raise NotImplementedError()  # pragma: no cover
+
+    async def acquire(self, timeout: OptionalTimeout = None) -> None:
+        raise NotImplementedError()  # pragma: no cover
+
+    def release(self) -> None:
+        raise NotImplementedError()  # pragma: no cover
+
+
 class Reader(BaseReader):
     def __init__(
         self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
@@ -86,6 +99,40 @@ class Writer(BaseWriter):
         self.stream_writer.close()
 
 
+class PoolSemaphore(BasePoolSemaphore):
+    def __init__(self, limits: PoolLimits, timeout: TimeoutConfig):
+        self.limits = limits
+        self.timeout = timeout
+
+    @property
+    def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
+        if not hasattr(self, "_semaphore"):
+            max_connections = self.limits.hard_limit
+            if max_connections is None:
+                self._semaphore = None
+            else:
+                self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
+        return self._semaphore
+
+    async def acquire(self, timeout: OptionalTimeout = None) -> None:
+        if self.semaphore is None:
+            return
+
+        if timeout is None:
+            timeout = self.timeout
+
+        try:
+            await asyncio.wait_for(self.semaphore.acquire(), timeout.pool_timeout)
+        except asyncio.TimeoutError:
+            raise PoolTimeout()
+
+    def release(self) -> None:
+        if self.semaphore is None:
+            return
+
+        self.semaphore.release()
+
+
 async def connect(
     hostname: str,
     port: int,
index fbdb52d8b3b04b2db83667ad487d5bf5601f45dc..f17d0f98225ab3f63f142678bc16288e1820d23e 100644 (file)
@@ -1,8 +1,9 @@
+import json
+
 import h2.config
 import h2.connection
 import h2.events
 import pytest
-import json
 
 import httpcore
 
@@ -61,15 +62,15 @@ class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
         request_headers = dict(request["headers"])
         request_data = request["data"]
 
-        response_body = json.dumps({
-            "method": request_headers[b":method"].decode(),
-            "path": request_headers[b":path"].decode(),
-            "body": request_data.decode()
-        }).encode()
+        response_body = json.dumps(
+            {
+                "method": request_headers[b":method"].decode(),
+                "path": request_headers[b":path"].decode(),
+                "body": request_data.decode(),
+            }
+        ).encode()
 
-        response_headers = (
-            (b":status", b"200"),
-        )
+        response_headers = ((b":status", b"200"),)
         self.conn.send_headers(stream_id, response_headers)
         self.conn.send_data(stream_id, response_body, end_stream=True)
         self.buffer += self.conn.data_to_send()
@@ -79,7 +80,9 @@ class MockServer(httpcore.BaseReader, httpcore.BaseWriter):
 async def test_http2_get_request():
     server = MockServer()
     origin = httpcore.Origin("http://example.org")
-    async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
+    async with httpcore.HTTP2Connection(
+        reader=server, writer=server, origin=origin
+    ) as client:
         response = await client.request("GET", "http://example.org")
     assert response.status_code == 200
     assert json.loads(response.body) == {"method": "GET", "path": "/", "body": ""}
@@ -89,17 +92,25 @@ async def test_http2_get_request():
 async def test_http2_post_request():
     server = MockServer()
     origin = httpcore.Origin("http://example.org")
-    async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
+    async with httpcore.HTTP2Connection(
+        reader=server, writer=server, origin=origin
+    ) as client:
         response = await client.request("POST", "http://example.org", body=b"<data>")
     assert response.status_code == 200
-    assert json.loads(response.body) == {"method": "POST", "path": "/", "body": "<data>"}
+    assert json.loads(response.body) == {
+        "method": "POST",
+        "path": "/",
+        "body": "<data>",
+    }
 
 
 @pytest.mark.asyncio
 async def test_http2_multiple_requests():
     server = MockServer()
     origin = httpcore.Origin("http://example.org")
-    async with httpcore.HTTP2Connection(reader=server, writer=server, origin=origin) as client:
+    async with httpcore.HTTP2Connection(
+        reader=server, writer=server, origin=origin
+    ) as client:
         response_1 = await client.request("GET", "http://example.org/1")
         response_2 = await client.request("GET", "http://example.org/2")
         response_3 = await client.request("GET", "http://example.org/3")