from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .interfaces import BasePoolSemaphore, BaseReader, BaseWriter, Protocol
+from .interfaces import (
+ BasePoolSemaphore,
+ BaseReader,
+ BaseWriter,
+ ConcurrencyBackend,
+ Protocol,
+)
OptionalTimeout = typing.Optional[TimeoutConfig]
-# Monky-patch for https://bugs.python.org/issue36709
-#
-# This prevents console errors when outstanding HTTPS connections
-# still exist at the point of exiting.
-#
-# Clients which have been opened using a `with` block, or which have
-# had `close()` closed, will not exhibit this issue in the first place.
+SSL_MONKEY_PATCH_APPLIED = False
-_write = asyncio.selector_events._SelectorSocketTransport.write # type: ignore
+def ssl_monkey_patch() -> None:
+ """
+ Monky-patch for https://bugs.python.org/issue36709
+ This prevents console errors when outstanding HTTPS connections
+ still exist at the point of exiting.
-def _fixed_write(self, data: bytes) -> None: # type: ignore
- if not self._loop.is_closed():
- _write(self, data)
+ Clients which have been opened using a `with` block, or which have
+ had `close()` closed, will not exhibit this issue in the first place.
+ """
+ MonkeyPatch = asyncio.selector_events._SelectorSocketTransport # type: ignore
+ _write = MonkeyPatch.write
-asyncio.selector_events._SelectorSocketTransport.write = _fixed_write # type: ignore
+ def _fixed_write(self, data: bytes) -> None: # type: ignore
+ if not self._loop.is_closed():
+ _write(self, data)
+
+ MonkeyPatch.write = _fixed_write
class Reader(BaseReader):
self.semaphore.release()
-async def connect(
- hostname: str,
- port: int,
- ssl_context: typing.Optional[ssl.SSLContext] = None,
- timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
-) -> typing.Tuple[Reader, Writer, Protocol]:
- try:
- stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
- asyncio.open_connection(hostname, port, ssl=ssl_context),
- timeout.connect_timeout,
- )
- except asyncio.TimeoutError:
- raise ConnectTimeout()
-
- ssl_object = stream_writer.get_extra_info("ssl_object")
- if ssl_object is None:
- ident = "http/1.1"
- else:
- ident = ssl_object.selected_alpn_protocol()
- if ident is None:
- ident = ssl_object.selected_npn_protocol()
-
- reader = Reader(stream_reader=stream_reader, timeout=timeout)
- writer = Writer(stream_writer=stream_writer, timeout=timeout)
- protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
-
- return (reader, writer, protocol)
+class AsyncioBackend(ConcurrencyBackend):
+ def __init__(self) -> None:
+ global SSL_MONKEY_PATCH_APPLIED
+
+ if not SSL_MONKEY_PATCH_APPLIED:
+ ssl_monkey_patch()
+ SSL_MONKEY_PATCH_APPLIED = True
+
+ async def connect(
+ self,
+ hostname: str,
+ port: int,
+ ssl_context: typing.Optional[ssl.SSLContext],
+ timeout: TimeoutConfig,
+ ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
+ try:
+ stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
+ asyncio.open_connection(hostname, port, ssl=ssl_context),
+ timeout.connect_timeout,
+ )
+ except asyncio.TimeoutError:
+ raise ConnectTimeout()
+
+ ssl_object = stream_writer.get_extra_info("ssl_object")
+ if ssl_object is None:
+ ident = "http/1.1"
+ else:
+ ident = ssl_object.selected_alpn_protocol()
+ if ident is None:
+ ident = ssl_object.selected_npn_protocol()
+
+ reader = Reader(stream_reader=stream_reader, timeout=timeout)
+ writer = Writer(stream_writer=stream_writer, timeout=timeout)
+ protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
+
+ return (reader, writer, protocol)
+
+ def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
+ return PoolSemaphore(limits)
import h2.connection
import h11
-from ..concurrency import connect
+from ..concurrency import AsyncioBackend
from ..config import (
DEFAULT_SSL_CONFIG,
DEFAULT_TIMEOUT_CONFIG,
TimeoutConfig,
)
from ..exceptions import ConnectTimeout
-from ..interfaces import Dispatcher, Protocol
+from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol
from ..models import Origin, Request, Response
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
origin: typing.Union[str, Origin],
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+ backend: ConcurrencyBackend = None,
release_func: typing.Optional[ReleaseCallback] = None,
):
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = ssl
self.timeout = timeout
+ self.backend = AsyncioBackend() if backend is None else backend
self.release_func = release_func
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
else:
on_release = functools.partial(self.release_func, self)
- reader, writer, protocol = await connect(host, port, ssl_context, timeout)
+ reader, writer, protocol = await self.backend.connect(
+ host, port, ssl_context, timeout
+ )
if protocol == Protocol.HTTP_2:
self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
else:
import collections.abc
import typing
-from ..concurrency import PoolSemaphore
+from ..concurrency import AsyncioBackend
from ..config import (
DEFAULT_CA_BUNDLE_PATH,
DEFAULT_POOL_LIMITS,
)
from ..decoders import ACCEPT_ENCODING
from ..exceptions import PoolTimeout
-from ..interfaces import Dispatcher
+from ..interfaces import ConcurrencyBackend, Dispatcher
from ..models import Origin, Request, Response
from .connection import HTTPConnection
ssl: SSLConfig = DEFAULT_SSL_CONFIG,
timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
+ backend: ConcurrencyBackend = None,
):
self.ssl = ssl
self.timeout = timeout
self.pool_limits = pool_limits
self.is_closed = False
- self.max_connections = PoolSemaphore(pool_limits)
self.keepalive_connections = ConnectionStore()
self.active_connections = ConnectionStore()
+ self.backend = AsyncioBackend() if backend is None else backend
+ self.max_connections = self.backend.get_semaphore(pool_limits)
+
@property
def num_connections(self) -> int:
return len(self.keepalive_connections) + len(self.active_connections)
origin,
ssl=self.ssl,
timeout=self.timeout,
+ backend=self.backend,
release_func=self.release_connection,
)