import typing
from ..config import Timeout
-from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
+from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
SSL_MONKEY_PATCH_APPLIED = False
self.stream_writer.close()
-class PoolSemaphore(BasePoolSemaphore):
- def __init__(self, max_value: int) -> None:
- self.max_value = max_value
-
- @property
- def semaphore(self) -> asyncio.BoundedSemaphore:
- if not hasattr(self, "_semaphore"):
- self._semaphore = asyncio.BoundedSemaphore(value=self.max_value)
- return self._semaphore
-
- async def acquire(self, timeout: float = None) -> None:
- try:
- await asyncio.wait_for(self.semaphore.acquire(), timeout)
- except asyncio.TimeoutError:
- raise PoolTimeout()
-
- def release(self) -> None:
- self.semaphore.release()
-
-
class AsyncioBackend(ConcurrencyBackend):
def __init__(self) -> None:
global SSL_MONKEY_PATCH_APPLIED
finally:
self._loop = loop
- def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
- return PoolSemaphore(max_value)
+ def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
+ return Semaphore(max_value, exc_class)
def create_event(self) -> BaseEvent:
return Event()
async def wait(self) -> None:
await self._event.wait()
+
+
+class Semaphore(BaseSemaphore):
+ def __init__(self, max_value: int, exc_class: type) -> None:
+ self.max_value = max_value
+ self.exc_class = exc_class
+
+ @property
+ def semaphore(self) -> asyncio.BoundedSemaphore:
+ if not hasattr(self, "_semaphore"):
+ self._semaphore = asyncio.BoundedSemaphore(value=self.max_value)
+ return self._semaphore
+
+ async def acquire(self, timeout: float = None) -> None:
+ try:
+ await asyncio.wait_for(self.semaphore.acquire(), timeout)
+ except asyncio.TimeoutError:
+ raise self.exc_class()
+
+ def release(self) -> None:
+ self.semaphore.release()
from ..config import Timeout
from .base import (
BaseEvent,
- BasePoolSemaphore,
+ BaseSemaphore,
BaseSocketStream,
ConcurrencyBackend,
lookup_backend,
def time(self) -> float:
return self.backend.time()
- def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
- return self.backend.get_semaphore(max_value)
-
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
return await self.backend.run_in_threadpool(func, *args, **kwargs)
+ def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
+ return self.backend.create_semaphore(max_value, exc_class)
+
def create_event(self) -> BaseEvent:
return self.backend.create_event()
class BaseEvent:
"""
- An event object. Abstracts away any asyncio-specific interfaces.
+ An abstract interface for Event classes.
+ Abstracts away any asyncio-specific interfaces.
"""
def set(self) -> None:
raise NotImplementedError() # pragma: no cover
-class BasePoolSemaphore:
+class BaseSemaphore:
"""
- A semaphore for use with connection pooling.
-
+ An abstract interface for Semaphore classes.
Abstracts away any asyncio-specific interfaces.
"""
def time(self) -> float:
raise NotImplementedError() # pragma: no cover
- def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
- raise NotImplementedError() # pragma: no cover
-
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
) -> typing.Any:
raise NotImplementedError() # pragma: no cover
+ def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
+ raise NotImplementedError() # pragma: no cover
+
def create_event(self) -> BaseEvent:
raise NotImplementedError() # pragma: no cover
import trio
from ..config import Timeout
-from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
+from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
+from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
def none_as_inf(value: typing.Optional[float]) -> float:
await self.stream.aclose()
-class PoolSemaphore(BasePoolSemaphore):
- def __init__(self, max_value: int):
- self.max_value = max_value
-
- @property
- def semaphore(self) -> trio.Semaphore:
- if not hasattr(self, "_semaphore"):
- self._semaphore = trio.Semaphore(self.max_value, max_value=self.max_value)
- return self._semaphore
-
- async def acquire(self, timeout: float = None) -> None:
- timeout = none_as_inf(timeout)
-
- with trio.move_on_after(timeout):
- await self.semaphore.acquire()
- return
-
- raise PoolTimeout()
-
- def release(self) -> None:
- self.semaphore.release()
-
-
class TrioBackend(ConcurrencyBackend):
async def open_tcp_stream(
self,
def time(self) -> float:
return trio.current_time()
- def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
- return PoolSemaphore(max_value)
+ def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
+ return Semaphore(max_value, exc_class)
def create_event(self) -> BaseEvent:
return Event()
+class Semaphore(BaseSemaphore):
+ def __init__(self, max_value: int, exc_class: type):
+ self.max_value = max_value
+ self.exc_class = exc_class
+
+ @property
+ def semaphore(self) -> trio.Semaphore:
+ if not hasattr(self, "_semaphore"):
+ self._semaphore = trio.Semaphore(self.max_value, max_value=self.max_value)
+ return self._semaphore
+
+ async def acquire(self, timeout: float = None) -> None:
+ timeout = none_as_inf(timeout)
+
+ with trio.move_on_after(timeout):
+ await self.semaphore.acquire()
+ return
+
+ raise self.exc_class()
+
+ def release(self) -> None:
+ self.semaphore.release()
+
+
class Event(BaseEvent):
def __init__(self) -> None:
self._event = trio.Event()
import typing
-from ..concurrency.base import BasePoolSemaphore, ConcurrencyBackend, lookup_backend
+from ..concurrency.base import BaseSemaphore, ConcurrencyBackend, lookup_backend
from ..config import DEFAULT_POOL_LIMITS, CertTypes, PoolLimits, Timeout, VerifyTypes
+from ..exceptions import PoolTimeout
from ..models import Origin, Request, Response
from ..utils import get_logger
from .base import Dispatcher
logger = get_logger(__name__)
-class NullSemaphore(BasePoolSemaphore):
+class NullSemaphore(BaseSemaphore):
async def acquire(self, timeout: float = None) -> None:
return
self.next_keepalive_check = 0.0
@property
- def max_connections(self) -> BasePoolSemaphore:
+ def max_connections(self) -> BaseSemaphore:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_max_connections"):
limit = self.pool_limits.hard_limit
- if not limit:
- self._max_connections = NullSemaphore() # type: BasePoolSemaphore
+ if limit:
+ self._max_connections = self.backend.create_semaphore(
+ limit, exc_class=PoolTimeout
+ )
else:
- self._max_connections = self.backend.get_semaphore(limit)
+ self._max_connections = NullSemaphore()
+
return self._max_connections
@property