From: Tom Christie Date: Tue, 14 May 2019 09:25:31 +0000 (+0100) Subject: Add ConcurrencyBackend interface X-Git-Tag: 0.3.0~15^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2694c48492077d93b7f4c2958e7dce0f2319b0ea;p=thirdparty%2Fhttpx.git Add ConcurrencyBackend interface --- diff --git a/httpcore/concurrency.py b/httpcore/concurrency.py index 78676659..3e63d858 100644 --- a/httpcore/concurrency.py +++ b/httpcore/concurrency.py @@ -14,29 +14,39 @@ import typing from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout -from .interfaces import BasePoolSemaphore, BaseReader, BaseWriter, Protocol +from .interfaces import ( + BasePoolSemaphore, + BaseReader, + BaseWriter, + ConcurrencyBackend, + Protocol, +) OptionalTimeout = typing.Optional[TimeoutConfig] -# Monky-patch for https://bugs.python.org/issue36709 -# -# This prevents console errors when outstanding HTTPS connections -# still exist at the point of exiting. -# -# Clients which have been opened using a `with` block, or which have -# had `close()` closed, will not exhibit this issue in the first place. +SSL_MONKEY_PATCH_APPLIED = False -_write = asyncio.selector_events._SelectorSocketTransport.write # type: ignore +def ssl_monkey_patch() -> None: + """ + Monky-patch for https://bugs.python.org/issue36709 + This prevents console errors when outstanding HTTPS connections + still exist at the point of exiting. -def _fixed_write(self, data: bytes) -> None: # type: ignore - if not self._loop.is_closed(): - _write(self, data) + Clients which have been opened using a `with` block, or which have + had `close()` closed, will not exhibit this issue in the first place. + """ + MonkeyPatch = asyncio.selector_events._SelectorSocketTransport # type: ignore + _write = MonkeyPatch.write -asyncio.selector_events._SelectorSocketTransport.write = _fixed_write # type: ignore + def _fixed_write(self, data: bytes) -> None: # type: ignore + if not self._loop.is_closed(): + _write(self, data) + + MonkeyPatch.write = _fixed_write class Reader(BaseReader): @@ -118,30 +128,42 @@ class PoolSemaphore(BasePoolSemaphore): self.semaphore.release() -async def connect( - hostname: str, - port: int, - ssl_context: typing.Optional[ssl.SSLContext] = None, - timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, -) -> typing.Tuple[Reader, Writer, Protocol]: - try: - stream_reader, stream_writer = await asyncio.wait_for( # type: ignore - asyncio.open_connection(hostname, port, ssl=ssl_context), - timeout.connect_timeout, - ) - except asyncio.TimeoutError: - raise ConnectTimeout() - - ssl_object = stream_writer.get_extra_info("ssl_object") - if ssl_object is None: - ident = "http/1.1" - else: - ident = ssl_object.selected_alpn_protocol() - if ident is None: - ident = ssl_object.selected_npn_protocol() - - reader = Reader(stream_reader=stream_reader, timeout=timeout) - writer = Writer(stream_writer=stream_writer, timeout=timeout) - protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11 - - return (reader, writer, protocol) +class AsyncioBackend(ConcurrencyBackend): + def __init__(self) -> None: + global SSL_MONKEY_PATCH_APPLIED + + if not SSL_MONKEY_PATCH_APPLIED: + ssl_monkey_patch() + SSL_MONKEY_PATCH_APPLIED = True + + async def connect( + self, + hostname: str, + port: int, + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, + ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: + try: + stream_reader, stream_writer = await asyncio.wait_for( # type: ignore + asyncio.open_connection(hostname, port, ssl=ssl_context), + timeout.connect_timeout, + ) + except asyncio.TimeoutError: + raise ConnectTimeout() + + ssl_object = stream_writer.get_extra_info("ssl_object") + if ssl_object is None: + ident = "http/1.1" + else: + ident = ssl_object.selected_alpn_protocol() + if ident is None: + ident = ssl_object.selected_npn_protocol() + + reader = Reader(stream_reader=stream_reader, timeout=timeout) + writer = Writer(stream_writer=stream_writer, timeout=timeout) + protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11 + + return (reader, writer, protocol) + + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: + return PoolSemaphore(limits) diff --git a/httpcore/dispatch/connection.py b/httpcore/dispatch/connection.py index d0beea19..053a9980 100644 --- a/httpcore/dispatch/connection.py +++ b/httpcore/dispatch/connection.py @@ -4,7 +4,7 @@ import typing import h2.connection import h11 -from ..concurrency import connect +from ..concurrency import AsyncioBackend from ..config import ( DEFAULT_SSL_CONFIG, DEFAULT_TIMEOUT_CONFIG, @@ -12,7 +12,7 @@ from ..config import ( TimeoutConfig, ) from ..exceptions import ConnectTimeout -from ..interfaces import Dispatcher, Protocol +from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol from ..models import Origin, Request, Response from .http2 import HTTP2Connection from .http11 import HTTP11Connection @@ -27,11 +27,13 @@ class HTTPConnection(Dispatcher): origin: typing.Union[str, Origin], ssl: SSLConfig = DEFAULT_SSL_CONFIG, timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, + backend: ConcurrencyBackend = None, release_func: typing.Optional[ReleaseCallback] = None, ): self.origin = Origin(origin) if isinstance(origin, str) else origin self.ssl = ssl self.timeout = timeout + self.backend = AsyncioBackend() if backend is None else backend self.release_func = release_func self.h11_connection = None # type: typing.Optional[HTTP11Connection] self.h2_connection = None # type: typing.Optional[HTTP2Connection] @@ -75,7 +77,9 @@ class HTTPConnection(Dispatcher): else: on_release = functools.partial(self.release_func, self) - reader, writer, protocol = await connect(host, port, ssl_context, timeout) + reader, writer, protocol = await self.backend.connect( + host, port, ssl_context, timeout + ) if protocol == Protocol.HTTP_2: self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release) else: diff --git a/httpcore/dispatch/connection_pool.py b/httpcore/dispatch/connection_pool.py index 2489da95..e0fb642e 100644 --- a/httpcore/dispatch/connection_pool.py +++ b/httpcore/dispatch/connection_pool.py @@ -1,7 +1,7 @@ import collections.abc import typing -from ..concurrency import PoolSemaphore +from ..concurrency import AsyncioBackend from ..config import ( DEFAULT_CA_BUNDLE_PATH, DEFAULT_POOL_LIMITS, @@ -13,7 +13,7 @@ from ..config import ( ) from ..decoders import ACCEPT_ENCODING from ..exceptions import PoolTimeout -from ..interfaces import Dispatcher +from ..interfaces import ConcurrencyBackend, Dispatcher from ..models import Origin, Request, Response from .connection import HTTPConnection @@ -90,16 +90,19 @@ class ConnectionPool(Dispatcher): ssl: SSLConfig = DEFAULT_SSL_CONFIG, timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG, pool_limits: PoolLimits = DEFAULT_POOL_LIMITS, + backend: ConcurrencyBackend = None, ): self.ssl = ssl self.timeout = timeout self.pool_limits = pool_limits self.is_closed = False - self.max_connections = PoolSemaphore(pool_limits) self.keepalive_connections = ConnectionStore() self.active_connections = ConnectionStore() + self.backend = AsyncioBackend() if backend is None else backend + self.max_connections = self.backend.get_semaphore(pool_limits) + @property def num_connections(self) -> int: return len(self.keepalive_connections) + len(self.active_connections) @@ -133,6 +136,7 @@ class ConnectionPool(Dispatcher): origin, ssl=self.ssl, timeout=self.timeout, + backend=self.backend, release_func=self.release_connection, ) diff --git a/httpcore/interfaces.py b/httpcore/interfaces.py index 8025de3e..6baf941f 100644 --- a/httpcore/interfaces.py +++ b/httpcore/interfaces.py @@ -1,8 +1,9 @@ import enum +import ssl import typing from types import TracebackType -from .config import SSLConfig, TimeoutConfig +from .config import PoolLimits, SSLConfig, TimeoutConfig from .models import ( URL, Headers, @@ -117,3 +118,17 @@ class BasePoolSemaphore: def release(self) -> None: raise NotImplementedError() # pragma: no cover + + +class ConcurrencyBackend: + async def connect( + self, + hostname: str, + port: int, + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, + ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: + raise NotImplementedError() # pragma: no cover + + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: + raise NotImplementedError() # pragma: no cover diff --git a/tests/conftest.py b/tests/conftest.py index cc84fbb4..762c8326 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,4 @@ import asyncio -import os import pytest import trustme