From 00a875db934b26868a8bb740d6bbf5ce344b1f65 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 11 Dec 2019 12:30:56 +0000 Subject: [PATCH] Refactor get_semaphore (#625) --- httpx/concurrency/asyncio.py | 24 +++++++----------------- httpx/concurrency/auto.py | 6 +++--- httpx/concurrency/base.py | 4 ++-- httpx/concurrency/trio.py | 26 +++++++------------------- httpx/dispatch/connection_pool.py | 14 +++++++++++++- 5 files changed, 32 insertions(+), 42 deletions(-) diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index e06dc209..1912fc0c 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -4,7 +4,7 @@ import ssl import sys import typing -from ..config import PoolLimits, Timeout +from ..config import Timeout from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend @@ -168,32 +168,22 @@ class SocketStream(BaseSocketStream): class PoolSemaphore(BasePoolSemaphore): - def __init__(self, pool_limits: PoolLimits): - self.pool_limits = pool_limits + def __init__(self, max_value: int) -> None: + self.max_value = max_value @property - def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]: + def semaphore(self) -> asyncio.BoundedSemaphore: if not hasattr(self, "_semaphore"): - max_connections = self.pool_limits.hard_limit - if max_connections is None: - self._semaphore = None - else: - self._semaphore = asyncio.BoundedSemaphore(value=max_connections) + self._semaphore = asyncio.BoundedSemaphore(value=self.max_value) return self._semaphore async def acquire(self, timeout: float = None) -> None: - if self.semaphore is None: - return - try: await asyncio.wait_for(self.semaphore.acquire(), timeout) except asyncio.TimeoutError: raise PoolTimeout() def release(self) -> None: - if self.semaphore is None: - return - self.semaphore.release() @@ -271,8 +261,8 @@ class AsyncioBackend(ConcurrencyBackend): finally: self._loop = loop - def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: - return PoolSemaphore(limits) + def get_semaphore(self, max_value: int) -> BasePoolSemaphore: + return PoolSemaphore(max_value) def create_event(self) -> BaseEvent: return typing.cast(BaseEvent, asyncio.Event()) diff --git a/httpx/concurrency/auto.py b/httpx/concurrency/auto.py index 3b57e567..c90ce241 100644 --- a/httpx/concurrency/auto.py +++ b/httpx/concurrency/auto.py @@ -3,7 +3,7 @@ import typing import sniffio -from ..config import PoolLimits, Timeout +from ..config import Timeout from .base import ( BaseEvent, BasePoolSemaphore, @@ -41,8 +41,8 @@ class AutoBackend(ConcurrencyBackend): ) -> BaseSocketStream: return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout) - def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: - return self.backend.get_semaphore(limits) + def get_semaphore(self, max_value: int) -> BasePoolSemaphore: + return self.backend.get_semaphore(max_value) async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index a735c87e..f0c7dc95 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -1,7 +1,7 @@ import ssl import typing -from ..config import PoolLimits, Timeout +from ..config import Timeout def lookup_backend( @@ -105,7 +105,7 @@ class ConcurrencyBackend: ) -> BaseSocketStream: raise NotImplementedError() # pragma: no cover - def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: + def get_semaphore(self, max_value: int) -> BasePoolSemaphore: raise NotImplementedError() # pragma: no cover async def run_in_threadpool( diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index 4fa30019..d607f9c7 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -4,7 +4,7 @@ import typing import trio -from ..config import PoolLimits, Timeout +from ..config import Timeout from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend @@ -83,25 +83,16 @@ class SocketStream(BaseSocketStream): class PoolSemaphore(BasePoolSemaphore): - def __init__(self, pool_limits: PoolLimits): - self.pool_limits = pool_limits + def __init__(self, max_value: int): + self.max_value = max_value @property - def semaphore(self) -> typing.Optional[trio.Semaphore]: + def semaphore(self) -> trio.Semaphore: if not hasattr(self, "_semaphore"): - max_connections = self.pool_limits.hard_limit - if max_connections is None: - self._semaphore = None - else: - self._semaphore = trio.Semaphore( - max_connections, max_value=max_connections - ) + self._semaphore = trio.Semaphore(self.max_value, max_value=self.max_value) return self._semaphore async def acquire(self, timeout: float = None) -> None: - if self.semaphore is None: - return - timeout = none_as_inf(timeout) with trio.move_on_after(timeout): @@ -111,9 +102,6 @@ class PoolSemaphore(BasePoolSemaphore): raise PoolTimeout() def release(self) -> None: - if self.semaphore is None: - return - self.semaphore.release() @@ -168,8 +156,8 @@ class TrioBackend(ConcurrencyBackend): functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args ) - def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: - return PoolSemaphore(limits) + def get_semaphore(self, max_value: int) -> BasePoolSemaphore: + return PoolSemaphore(max_value) def create_event(self) -> BaseEvent: return Event() diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index f11137fb..4fe39988 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -13,6 +13,14 @@ CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]] logger = get_logger(__name__) +class NullSemaphore(BasePoolSemaphore): + async def acquire(self, timeout: float = None) -> None: + return + + def release(self) -> None: + return + + class ConnectionStore: """ We need to maintain collections of connections in a way that allows us to: @@ -99,7 +107,11 @@ class ConnectionPool(Dispatcher): # We do this lazily, to make sure backend autodetection always # runs within an async context. if not hasattr(self, "_max_connections"): - self._max_connections = self.backend.get_semaphore(self.pool_limits) + limit = self.pool_limits.hard_limit + if not limit: + self._max_connections = NullSemaphore() # type: BasePoolSemaphore + else: + self._max_connections = self.backend.get_semaphore(limit) return self._max_connections @property -- 2.47.3