From: Tom Christie Date: Fri, 20 Dec 2019 15:14:55 +0000 (+0000) Subject: Rationalize backend Semaphore interface slightly (#660) X-Git-Tag: 0.10.0~18 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=36af9d959764a4b82d2a8bd30fb0904da0223e16;p=thirdparty%2Fhttpx.git Rationalize backend Semaphore interface slightly (#660) --- diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 633801f7..0265f1b2 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -4,8 +4,8 @@ import ssl import typing from ..config import Timeout -from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout -from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend +from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout +from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend SSL_MONKEY_PATCH_APPLIED = False @@ -171,26 +171,6 @@ class SocketStream(BaseSocketStream): self.stream_writer.close() -class PoolSemaphore(BasePoolSemaphore): - def __init__(self, max_value: int) -> None: - self.max_value = max_value - - @property - def semaphore(self) -> asyncio.BoundedSemaphore: - if not hasattr(self, "_semaphore"): - self._semaphore = asyncio.BoundedSemaphore(value=self.max_value) - return self._semaphore - - async def acquire(self, timeout: float = None) -> None: - try: - await asyncio.wait_for(self.semaphore.acquire(), timeout) - except asyncio.TimeoutError: - raise PoolTimeout() - - def release(self) -> None: - self.semaphore.release() - - class AsyncioBackend(ConcurrencyBackend): def __init__(self) -> None: global SSL_MONKEY_PATCH_APPLIED @@ -269,8 +249,8 @@ class AsyncioBackend(ConcurrencyBackend): finally: self._loop = loop - def get_semaphore(self, max_value: int) -> BasePoolSemaphore: - return PoolSemaphore(max_value) + def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore: + return Semaphore(max_value, exc_class) def create_event(self) -> BaseEvent: return Event() @@ -285,3 +265,24 @@ class Event(BaseEvent): async def wait(self) -> None: await self._event.wait() + + +class Semaphore(BaseSemaphore): + def __init__(self, max_value: int, exc_class: type) -> None: + self.max_value = max_value + self.exc_class = exc_class + + @property + def semaphore(self) -> asyncio.BoundedSemaphore: + if not hasattr(self, "_semaphore"): + self._semaphore = asyncio.BoundedSemaphore(value=self.max_value) + return self._semaphore + + async def acquire(self, timeout: float = None) -> None: + try: + await asyncio.wait_for(self.semaphore.acquire(), timeout) + except asyncio.TimeoutError: + raise self.exc_class() + + def release(self) -> None: + self.semaphore.release() diff --git a/httpx/concurrency/auto.py b/httpx/concurrency/auto.py index c11c0646..32fcf798 100644 --- a/httpx/concurrency/auto.py +++ b/httpx/concurrency/auto.py @@ -6,7 +6,7 @@ import sniffio from ..config import Timeout from .base import ( BaseEvent, - BasePoolSemaphore, + BaseSemaphore, BaseSocketStream, ConcurrencyBackend, lookup_backend, @@ -44,13 +44,13 @@ class AutoBackend(ConcurrencyBackend): def time(self) -> float: return self.backend.time() - def get_semaphore(self, max_value: int) -> BasePoolSemaphore: - return self.backend.get_semaphore(max_value) - async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: return await self.backend.run_in_threadpool(func, *args, **kwargs) + def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore: + return self.backend.create_semaphore(max_value, exc_class) + def create_event(self) -> BaseEvent: return self.backend.create_event() diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index 366e5749..16c55cc2 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -56,7 +56,8 @@ class BaseSocketStream: class BaseEvent: """ - An event object. Abstracts away any asyncio-specific interfaces. + An abstract interface for Event classes. + Abstracts away any asyncio-specific interfaces. """ def set(self) -> None: @@ -66,10 +67,9 @@ class BaseEvent: raise NotImplementedError() # pragma: no cover -class BasePoolSemaphore: +class BaseSemaphore: """ - A semaphore for use with connection pooling. - + An abstract interface for Semaphore classes. Abstracts away any asyncio-specific interfaces. """ @@ -102,9 +102,6 @@ class ConcurrencyBackend: def time(self) -> float: raise NotImplementedError() # pragma: no cover - def get_semaphore(self, max_value: int) -> BasePoolSemaphore: - raise NotImplementedError() # pragma: no cover - async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any ) -> typing.Any: @@ -115,5 +112,8 @@ class ConcurrencyBackend: ) -> typing.Any: raise NotImplementedError() # pragma: no cover + def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore: + raise NotImplementedError() # pragma: no cover + def create_event(self) -> BaseEvent: raise NotImplementedError() # pragma: no cover diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index 0c64988d..8858ca42 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -5,8 +5,8 @@ import typing import trio from ..config import Timeout -from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout -from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend +from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout +from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend def none_as_inf(value: typing.Optional[float]) -> float: @@ -82,29 +82,6 @@ class SocketStream(BaseSocketStream): await self.stream.aclose() -class PoolSemaphore(BasePoolSemaphore): - def __init__(self, max_value: int): - self.max_value = max_value - - @property - def semaphore(self) -> trio.Semaphore: - if not hasattr(self, "_semaphore"): - self._semaphore = trio.Semaphore(self.max_value, max_value=self.max_value) - return self._semaphore - - async def acquire(self, timeout: float = None) -> None: - timeout = none_as_inf(timeout) - - with trio.move_on_after(timeout): - await self.semaphore.acquire() - return - - raise PoolTimeout() - - def release(self) -> None: - self.semaphore.release() - - class TrioBackend(ConcurrencyBackend): async def open_tcp_stream( self, @@ -159,13 +136,37 @@ class TrioBackend(ConcurrencyBackend): def time(self) -> float: return trio.current_time() - def get_semaphore(self, max_value: int) -> BasePoolSemaphore: - return PoolSemaphore(max_value) + def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore: + return Semaphore(max_value, exc_class) def create_event(self) -> BaseEvent: return Event() +class Semaphore(BaseSemaphore): + def __init__(self, max_value: int, exc_class: type): + self.max_value = max_value + self.exc_class = exc_class + + @property + def semaphore(self) -> trio.Semaphore: + if not hasattr(self, "_semaphore"): + self._semaphore = trio.Semaphore(self.max_value, max_value=self.max_value) + return self._semaphore + + async def acquire(self, timeout: float = None) -> None: + timeout = none_as_inf(timeout) + + with trio.move_on_after(timeout): + await self.semaphore.acquire() + return + + raise self.exc_class() + + def release(self) -> None: + self.semaphore.release() + + class Event(BaseEvent): def __init__(self) -> None: self._event = trio.Event() diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index db23640b..db576d22 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -1,7 +1,8 @@ import typing -from ..concurrency.base import BasePoolSemaphore, ConcurrencyBackend, lookup_backend +from ..concurrency.base import BaseSemaphore, ConcurrencyBackend, lookup_backend from ..config import DEFAULT_POOL_LIMITS, CertTypes, PoolLimits, Timeout, VerifyTypes +from ..exceptions import PoolTimeout from ..models import Origin, Request, Response from ..utils import get_logger from .base import Dispatcher @@ -13,7 +14,7 @@ CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]] logger = get_logger(__name__) -class NullSemaphore(BasePoolSemaphore): +class NullSemaphore(BaseSemaphore): async def acquire(self, timeout: float = None) -> None: return @@ -106,15 +107,18 @@ class ConnectionPool(Dispatcher): self.next_keepalive_check = 0.0 @property - def max_connections(self) -> BasePoolSemaphore: + def max_connections(self) -> BaseSemaphore: # We do this lazily, to make sure backend autodetection always # runs within an async context. if not hasattr(self, "_max_connections"): limit = self.pool_limits.hard_limit - if not limit: - self._max_connections = NullSemaphore() # type: BasePoolSemaphore + if limit: + self._max_connections = self.backend.create_semaphore( + limit, exc_class=PoolTimeout + ) else: - self._max_connections = self.backend.get_semaphore(limit) + self._max_connections = NullSemaphore() + return self._max_connections @property