]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Rationalize backend Semaphore interface slightly (#660)
authorTom Christie <tom@tomchristie.com>
Fri, 20 Dec 2019 15:14:55 +0000 (15:14 +0000)
committerGitHub <noreply@github.com>
Fri, 20 Dec 2019 15:14:55 +0000 (15:14 +0000)
httpx/concurrency/asyncio.py
httpx/concurrency/auto.py
httpx/concurrency/base.py
httpx/concurrency/trio.py
httpx/dispatch/connection_pool.py

index 633801f73ac777867fb0c09963cdac6f7259c086..0265f1b2f1ca9f815832d86a483e5de98f371ab8 100644 (file)
@@ -4,8 +4,8 @@ import ssl
 import typing
 
 from ..config import Timeout
-from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
+from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
 
 SSL_MONKEY_PATCH_APPLIED = False
 
@@ -171,26 +171,6 @@ class SocketStream(BaseSocketStream):
         self.stream_writer.close()
 
 
-class PoolSemaphore(BasePoolSemaphore):
-    def __init__(self, max_value: int) -> None:
-        self.max_value = max_value
-
-    @property
-    def semaphore(self) -> asyncio.BoundedSemaphore:
-        if not hasattr(self, "_semaphore"):
-            self._semaphore = asyncio.BoundedSemaphore(value=self.max_value)
-        return self._semaphore
-
-    async def acquire(self, timeout: float = None) -> None:
-        try:
-            await asyncio.wait_for(self.semaphore.acquire(), timeout)
-        except asyncio.TimeoutError:
-            raise PoolTimeout()
-
-    def release(self) -> None:
-        self.semaphore.release()
-
-
 class AsyncioBackend(ConcurrencyBackend):
     def __init__(self) -> None:
         global SSL_MONKEY_PATCH_APPLIED
@@ -269,8 +249,8 @@ class AsyncioBackend(ConcurrencyBackend):
         finally:
             self._loop = loop
 
-    def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
-        return PoolSemaphore(max_value)
+    def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
+        return Semaphore(max_value, exc_class)
 
     def create_event(self) -> BaseEvent:
         return Event()
@@ -285,3 +265,24 @@ class Event(BaseEvent):
 
     async def wait(self) -> None:
         await self._event.wait()
+
+
+class Semaphore(BaseSemaphore):
+    def __init__(self, max_value: int, exc_class: type) -> None:
+        self.max_value = max_value
+        self.exc_class = exc_class
+
+    @property
+    def semaphore(self) -> asyncio.BoundedSemaphore:
+        if not hasattr(self, "_semaphore"):
+            self._semaphore = asyncio.BoundedSemaphore(value=self.max_value)
+        return self._semaphore
+
+    async def acquire(self, timeout: float = None) -> None:
+        try:
+            await asyncio.wait_for(self.semaphore.acquire(), timeout)
+        except asyncio.TimeoutError:
+            raise self.exc_class()
+
+    def release(self) -> None:
+        self.semaphore.release()
index c11c06469ecc1c7f41a55895bd938bdc128823b8..32fcf798e587b92cae5d97291984f060c27c01e9 100644 (file)
@@ -6,7 +6,7 @@ import sniffio
 from ..config import Timeout
 from .base import (
     BaseEvent,
-    BasePoolSemaphore,
+    BaseSemaphore,
     BaseSocketStream,
     ConcurrencyBackend,
     lookup_backend,
@@ -44,13 +44,13 @@ class AutoBackend(ConcurrencyBackend):
     def time(self) -> float:
         return self.backend.time()
 
-    def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
-        return self.backend.get_semaphore(max_value)
-
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
     ) -> typing.Any:
         return await self.backend.run_in_threadpool(func, *args, **kwargs)
 
+    def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
+        return self.backend.create_semaphore(max_value, exc_class)
+
     def create_event(self) -> BaseEvent:
         return self.backend.create_event()
index 366e5749a0e72e8d2ab2dadb3169efec916fe9f5..16c55cc28bde7df031ceedfbe4c43afd6d16c874 100644 (file)
@@ -56,7 +56,8 @@ class BaseSocketStream:
 
 class BaseEvent:
     """
-    An event object. Abstracts away any asyncio-specific interfaces.
+    An abstract interface for Event classes.
+    Abstracts away any asyncio-specific interfaces.
     """
 
     def set(self) -> None:
@@ -66,10 +67,9 @@ class BaseEvent:
         raise NotImplementedError()  # pragma: no cover
 
 
-class BasePoolSemaphore:
+class BaseSemaphore:
     """
-    A semaphore for use with connection pooling.
-
+    An abstract interface for Semaphore classes.
     Abstracts away any asyncio-specific interfaces.
     """
 
@@ -102,9 +102,6 @@ class ConcurrencyBackend:
     def time(self) -> float:
         raise NotImplementedError()  # pragma: no cover
 
-    def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
-        raise NotImplementedError()  # pragma: no cover
-
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
     ) -> typing.Any:
@@ -115,5 +112,8 @@ class ConcurrencyBackend:
     ) -> typing.Any:
         raise NotImplementedError()  # pragma: no cover
 
+    def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
+        raise NotImplementedError()  # pragma: no cover
+
     def create_event(self) -> BaseEvent:
         raise NotImplementedError()  # pragma: no cover
index 0c64988dbee852fe2f76ed862fe19acced54c80d..8858ca426fcd3d00682008792ca0e28c4216be92 100644 (file)
@@ -5,8 +5,8 @@ import typing
 import trio
 
 from ..config import Timeout
-from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
+from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
 
 
 def none_as_inf(value: typing.Optional[float]) -> float:
@@ -82,29 +82,6 @@ class SocketStream(BaseSocketStream):
         await self.stream.aclose()
 
 
-class PoolSemaphore(BasePoolSemaphore):
-    def __init__(self, max_value: int):
-        self.max_value = max_value
-
-    @property
-    def semaphore(self) -> trio.Semaphore:
-        if not hasattr(self, "_semaphore"):
-            self._semaphore = trio.Semaphore(self.max_value, max_value=self.max_value)
-        return self._semaphore
-
-    async def acquire(self, timeout: float = None) -> None:
-        timeout = none_as_inf(timeout)
-
-        with trio.move_on_after(timeout):
-            await self.semaphore.acquire()
-            return
-
-        raise PoolTimeout()
-
-    def release(self) -> None:
-        self.semaphore.release()
-
-
 class TrioBackend(ConcurrencyBackend):
     async def open_tcp_stream(
         self,
@@ -159,13 +136,37 @@ class TrioBackend(ConcurrencyBackend):
     def time(self) -> float:
         return trio.current_time()
 
-    def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
-        return PoolSemaphore(max_value)
+    def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
+        return Semaphore(max_value, exc_class)
 
     def create_event(self) -> BaseEvent:
         return Event()
 
 
+class Semaphore(BaseSemaphore):
+    def __init__(self, max_value: int, exc_class: type):
+        self.max_value = max_value
+        self.exc_class = exc_class
+
+    @property
+    def semaphore(self) -> trio.Semaphore:
+        if not hasattr(self, "_semaphore"):
+            self._semaphore = trio.Semaphore(self.max_value, max_value=self.max_value)
+        return self._semaphore
+
+    async def acquire(self, timeout: float = None) -> None:
+        timeout = none_as_inf(timeout)
+
+        with trio.move_on_after(timeout):
+            await self.semaphore.acquire()
+            return
+
+        raise self.exc_class()
+
+    def release(self) -> None:
+        self.semaphore.release()
+
+
 class Event(BaseEvent):
     def __init__(self) -> None:
         self._event = trio.Event()
index db23640b574839bc558053282611436500766529..db576d2288eb8d373d1a85bbf5166aaae4edba81 100644 (file)
@@ -1,7 +1,8 @@
 import typing
 
-from ..concurrency.base import BasePoolSemaphore, ConcurrencyBackend, lookup_backend
+from ..concurrency.base import BaseSemaphore, ConcurrencyBackend, lookup_backend
 from ..config import DEFAULT_POOL_LIMITS, CertTypes, PoolLimits, Timeout, VerifyTypes
+from ..exceptions import PoolTimeout
 from ..models import Origin, Request, Response
 from ..utils import get_logger
 from .base import Dispatcher
@@ -13,7 +14,7 @@ CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
 logger = get_logger(__name__)
 
 
-class NullSemaphore(BasePoolSemaphore):
+class NullSemaphore(BaseSemaphore):
     async def acquire(self, timeout: float = None) -> None:
         return
 
@@ -106,15 +107,18 @@ class ConnectionPool(Dispatcher):
         self.next_keepalive_check = 0.0
 
     @property
-    def max_connections(self) -> BasePoolSemaphore:
+    def max_connections(self) -> BaseSemaphore:
         # We do this lazily, to make sure backend autodetection always
         # runs within an async context.
         if not hasattr(self, "_max_connections"):
             limit = self.pool_limits.hard_limit
-            if not limit:
-                self._max_connections = NullSemaphore()  # type: BasePoolSemaphore
+            if limit:
+                self._max_connections = self.backend.create_semaphore(
+                    limit, exc_class=PoolTimeout
+                )
             else:
-                self._max_connections = self.backend.get_semaphore(limit)
+                self._max_connections = NullSemaphore()
+
         return self._max_connections
 
     @property