]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Refactor get_semaphore (#625)
authorTom Christie <tom@tomchristie.com>
Wed, 11 Dec 2019 12:30:56 +0000 (12:30 +0000)
committerGitHub <noreply@github.com>
Wed, 11 Dec 2019 12:30:56 +0000 (12:30 +0000)
httpx/concurrency/asyncio.py
httpx/concurrency/auto.py
httpx/concurrency/base.py
httpx/concurrency/trio.py
httpx/dispatch/connection_pool.py

index e06dc2091fd2eaadac9433043a2b56d6ed7e1d29..1912fc0c1c9947e85407516478d88a6a4c76c23d 100644 (file)
@@ -4,7 +4,7 @@ import ssl
 import sys
 import typing
 
-from ..config import PoolLimits, Timeout
+from ..config import Timeout
 from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
 from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
 
@@ -168,32 +168,22 @@ class SocketStream(BaseSocketStream):
 
 
 class PoolSemaphore(BasePoolSemaphore):
-    def __init__(self, pool_limits: PoolLimits):
-        self.pool_limits = pool_limits
+    def __init__(self, max_value: int) -> None:
+        self.max_value = max_value
 
     @property
-    def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
+    def semaphore(self) -> asyncio.BoundedSemaphore:
         if not hasattr(self, "_semaphore"):
-            max_connections = self.pool_limits.hard_limit
-            if max_connections is None:
-                self._semaphore = None
-            else:
-                self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
+            self._semaphore = asyncio.BoundedSemaphore(value=self.max_value)
         return self._semaphore
 
     async def acquire(self, timeout: float = None) -> None:
-        if self.semaphore is None:
-            return
-
         try:
             await asyncio.wait_for(self.semaphore.acquire(), timeout)
         except asyncio.TimeoutError:
             raise PoolTimeout()
 
     def release(self) -> None:
-        if self.semaphore is None:
-            return
-
         self.semaphore.release()
 
 
@@ -271,8 +261,8 @@ class AsyncioBackend(ConcurrencyBackend):
         finally:
             self._loop = loop
 
-    def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
-        return PoolSemaphore(limits)
+    def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
+        return PoolSemaphore(max_value)
 
     def create_event(self) -> BaseEvent:
         return typing.cast(BaseEvent, asyncio.Event())
index 3b57e5674dde029d01019ad3da115722f5865ae0..c90ce2410a86d727c92a8b99d8afb483b0c4d1dd 100644 (file)
@@ -3,7 +3,7 @@ import typing
 
 import sniffio
 
-from ..config import PoolLimits, Timeout
+from ..config import Timeout
 from .base import (
     BaseEvent,
     BasePoolSemaphore,
@@ -41,8 +41,8 @@ class AutoBackend(ConcurrencyBackend):
     ) -> BaseSocketStream:
         return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout)
 
-    def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
-        return self.backend.get_semaphore(limits)
+    def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
+        return self.backend.get_semaphore(max_value)
 
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
index a735c87e73872af582738fe83e1cd05d30857af7..f0c7dc9531ccc21809b257fe75bdefca3316a5b6 100644 (file)
@@ -1,7 +1,7 @@
 import ssl
 import typing
 
-from ..config import PoolLimits, Timeout
+from ..config import Timeout
 
 
 def lookup_backend(
@@ -105,7 +105,7 @@ class ConcurrencyBackend:
     ) -> BaseSocketStream:
         raise NotImplementedError()  # pragma: no cover
 
-    def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
+    def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
         raise NotImplementedError()  # pragma: no cover
 
     async def run_in_threadpool(
index 4fa3001983f341c9cb45ddf4646a1b451a022cac..d607f9c7ffe172c77f2c28af482eee4d2cb08dbe 100644 (file)
@@ -4,7 +4,7 @@ import typing
 
 import trio
 
-from ..config import PoolLimits, Timeout
+from ..config import Timeout
 from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
 from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
 
@@ -83,25 +83,16 @@ class SocketStream(BaseSocketStream):
 
 
 class PoolSemaphore(BasePoolSemaphore):
-    def __init__(self, pool_limits: PoolLimits):
-        self.pool_limits = pool_limits
+    def __init__(self, max_value: int):
+        self.max_value = max_value
 
     @property
-    def semaphore(self) -> typing.Optional[trio.Semaphore]:
+    def semaphore(self) -> trio.Semaphore:
         if not hasattr(self, "_semaphore"):
-            max_connections = self.pool_limits.hard_limit
-            if max_connections is None:
-                self._semaphore = None
-            else:
-                self._semaphore = trio.Semaphore(
-                    max_connections, max_value=max_connections
-                )
+            self._semaphore = trio.Semaphore(self.max_value, max_value=self.max_value)
         return self._semaphore
 
     async def acquire(self, timeout: float = None) -> None:
-        if self.semaphore is None:
-            return
-
         timeout = none_as_inf(timeout)
 
         with trio.move_on_after(timeout):
@@ -111,9 +102,6 @@ class PoolSemaphore(BasePoolSemaphore):
         raise PoolTimeout()
 
     def release(self) -> None:
-        if self.semaphore is None:
-            return
-
         self.semaphore.release()
 
 
@@ -168,8 +156,8 @@ class TrioBackend(ConcurrencyBackend):
             functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
         )
 
-    def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
-        return PoolSemaphore(limits)
+    def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
+        return PoolSemaphore(max_value)
 
     def create_event(self) -> BaseEvent:
         return Event()
index f11137fbd02ca18f5d13c1f84f51c6b804ef77e6..4fe3998814d0380b47e2778620d7142219dd56be 100644 (file)
@@ -13,6 +13,14 @@ CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
 logger = get_logger(__name__)
 
 
+class NullSemaphore(BasePoolSemaphore):
+    async def acquire(self, timeout: float = None) -> None:
+        return
+
+    def release(self) -> None:
+        return
+
+
 class ConnectionStore:
     """
     We need to maintain collections of connections in a way that allows us to:
@@ -99,7 +107,11 @@ class ConnectionPool(Dispatcher):
         # We do this lazily, to make sure backend autodetection always
         # runs within an async context.
         if not hasattr(self, "_max_connections"):
-            self._max_connections = self.backend.get_semaphore(self.pool_limits)
+            limit = self.pool_limits.hard_limit
+            if not limit:
+                self._max_connections = NullSemaphore()  # type: BasePoolSemaphore
+            else:
+                self._max_connections = self.backend.get_semaphore(limit)
         return self._max_connections
 
     @property