import sys
import typing
-from ..config import PoolLimits, Timeout
+from ..config import Timeout
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
class PoolSemaphore(BasePoolSemaphore):
- def __init__(self, pool_limits: PoolLimits):
- self.pool_limits = pool_limits
+ def __init__(self, max_value: int) -> None:
+ self.max_value = max_value
@property
- def semaphore(self) -> typing.Optional[asyncio.BoundedSemaphore]:
+ def semaphore(self) -> asyncio.BoundedSemaphore:
if not hasattr(self, "_semaphore"):
- max_connections = self.pool_limits.hard_limit
- if max_connections is None:
- self._semaphore = None
- else:
- self._semaphore = asyncio.BoundedSemaphore(value=max_connections)
+ self._semaphore = asyncio.BoundedSemaphore(value=self.max_value)
return self._semaphore
async def acquire(self, timeout: float = None) -> None:
- if self.semaphore is None:
- return
-
try:
await asyncio.wait_for(self.semaphore.acquire(), timeout)
except asyncio.TimeoutError:
raise PoolTimeout()
def release(self) -> None:
- if self.semaphore is None:
- return
-
self.semaphore.release()
finally:
self._loop = loop
- def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
- return PoolSemaphore(limits)
+ def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
+ return PoolSemaphore(max_value)
def create_event(self) -> BaseEvent:
return typing.cast(BaseEvent, asyncio.Event())
import sniffio
-from ..config import PoolLimits, Timeout
+from ..config import Timeout
from .base import (
BaseEvent,
BasePoolSemaphore,
) -> BaseSocketStream:
return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout)
- def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
- return self.backend.get_semaphore(limits)
+ def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
+ return self.backend.get_semaphore(max_value)
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
import ssl
import typing
-from ..config import PoolLimits, Timeout
+from ..config import Timeout
def lookup_backend(
) -> BaseSocketStream:
raise NotImplementedError() # pragma: no cover
- def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
+ def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover
async def run_in_threadpool(
import trio
-from ..config import PoolLimits, Timeout
+from ..config import Timeout
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
class PoolSemaphore(BasePoolSemaphore):
- def __init__(self, pool_limits: PoolLimits):
- self.pool_limits = pool_limits
+ def __init__(self, max_value: int):
+ self.max_value = max_value
@property
- def semaphore(self) -> typing.Optional[trio.Semaphore]:
+ def semaphore(self) -> trio.Semaphore:
if not hasattr(self, "_semaphore"):
- max_connections = self.pool_limits.hard_limit
- if max_connections is None:
- self._semaphore = None
- else:
- self._semaphore = trio.Semaphore(
- max_connections, max_value=max_connections
- )
+ self._semaphore = trio.Semaphore(self.max_value, max_value=self.max_value)
return self._semaphore
async def acquire(self, timeout: float = None) -> None:
- if self.semaphore is None:
- return
-
timeout = none_as_inf(timeout)
with trio.move_on_after(timeout):
raise PoolTimeout()
def release(self) -> None:
- if self.semaphore is None:
- return
-
self.semaphore.release()
functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
)
- def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
- return PoolSemaphore(limits)
+ def get_semaphore(self, max_value: int) -> BasePoolSemaphore:
+ return PoolSemaphore(max_value)
def create_event(self) -> BaseEvent:
return Event()
logger = get_logger(__name__)
+class NullSemaphore(BasePoolSemaphore):
+ async def acquire(self, timeout: float = None) -> None:
+ return
+
+ def release(self) -> None:
+ return
+
+
class ConnectionStore:
"""
We need to maintain collections of connections in a way that allows us to:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
if not hasattr(self, "_max_connections"):
- self._max_connections = self.backend.get_semaphore(self.pool_limits)
+ limit = self.pool_limits.hard_limit
+ if not limit:
+ self._max_connections = NullSemaphore() # type: BasePoolSemaphore
+ else:
+ self._max_connections = self.backend.get_semaphore(limit)
return self._max_connections
@property