]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add ConcurrencyBackend interface
authorTom Christie <tom@tomchristie.com>
Tue, 14 May 2019 09:25:31 +0000 (10:25 +0100)
committerTom Christie <tom@tomchristie.com>
Tue, 14 May 2019 09:25:31 +0000 (10:25 +0100)
httpcore/concurrency.py
httpcore/dispatch/connection.py
httpcore/dispatch/connection_pool.py
httpcore/interfaces.py
tests/conftest.py

index 78676659748d323c1dee0f8fc74d405f501ec6fe..3e63d858baaead634f32cb7b8e82c198ba4bdb86 100644 (file)
@@ -14,29 +14,39 @@ import typing
 
 from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
 from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .interfaces import BasePoolSemaphore, BaseReader, BaseWriter, Protocol
+from .interfaces import (
+    BasePoolSemaphore,
+    BaseReader,
+    BaseWriter,
+    ConcurrencyBackend,
+    Protocol,
+)
 
 OptionalTimeout = typing.Optional[TimeoutConfig]
 
 
-# Monky-patch for https://bugs.python.org/issue36709
-#
-# This prevents console errors when outstanding HTTPS connections
-# still exist at the point of exiting.
-#
-# Clients which have been opened using a `with` block, or which have
-# had `close()` closed, will not exhibit this issue in the first place.
+SSL_MONKEY_PATCH_APPLIED = False
 
 
-_write = asyncio.selector_events._SelectorSocketTransport.write  # type: ignore
+def ssl_monkey_patch() -> None:
+    """
+    Monky-patch for https://bugs.python.org/issue36709
 
+    This prevents console errors when outstanding HTTPS connections
+    still exist at the point of exiting.
 
-def _fixed_write(self, data: bytes) -> None:  # type: ignore
-    if not self._loop.is_closed():
-        _write(self, data)
+    Clients which have been opened using a `with` block, or which have
+    had `close()` closed, will not exhibit this issue in the first place.
+    """
+    MonkeyPatch = asyncio.selector_events._SelectorSocketTransport  # type: ignore
 
+    _write = MonkeyPatch.write
 
-asyncio.selector_events._SelectorSocketTransport.write = _fixed_write  # type: ignore
+    def _fixed_write(self, data: bytes) -> None:  # type: ignore
+        if not self._loop.is_closed():
+            _write(self, data)
+
+    MonkeyPatch.write = _fixed_write
 
 
 class Reader(BaseReader):
@@ -118,30 +128,42 @@ class PoolSemaphore(BasePoolSemaphore):
         self.semaphore.release()
 
 
-async def connect(
-    hostname: str,
-    port: int,
-    ssl_context: typing.Optional[ssl.SSLContext] = None,
-    timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
-) -> typing.Tuple[Reader, Writer, Protocol]:
-    try:
-        stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
-            asyncio.open_connection(hostname, port, ssl=ssl_context),
-            timeout.connect_timeout,
-        )
-    except asyncio.TimeoutError:
-        raise ConnectTimeout()
-
-    ssl_object = stream_writer.get_extra_info("ssl_object")
-    if ssl_object is None:
-        ident = "http/1.1"
-    else:
-        ident = ssl_object.selected_alpn_protocol()
-        if ident is None:
-            ident = ssl_object.selected_npn_protocol()
-
-    reader = Reader(stream_reader=stream_reader, timeout=timeout)
-    writer = Writer(stream_writer=stream_writer, timeout=timeout)
-    protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
-
-    return (reader, writer, protocol)
+class AsyncioBackend(ConcurrencyBackend):
+    def __init__(self) -> None:
+        global SSL_MONKEY_PATCH_APPLIED
+
+        if not SSL_MONKEY_PATCH_APPLIED:
+            ssl_monkey_patch()
+        SSL_MONKEY_PATCH_APPLIED = True
+
+    async def connect(
+        self,
+        hostname: str,
+        port: int,
+        ssl_context: typing.Optional[ssl.SSLContext],
+        timeout: TimeoutConfig,
+    ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
+        try:
+            stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
+                asyncio.open_connection(hostname, port, ssl=ssl_context),
+                timeout.connect_timeout,
+            )
+        except asyncio.TimeoutError:
+            raise ConnectTimeout()
+
+        ssl_object = stream_writer.get_extra_info("ssl_object")
+        if ssl_object is None:
+            ident = "http/1.1"
+        else:
+            ident = ssl_object.selected_alpn_protocol()
+            if ident is None:
+                ident = ssl_object.selected_npn_protocol()
+
+        reader = Reader(stream_reader=stream_reader, timeout=timeout)
+        writer = Writer(stream_writer=stream_writer, timeout=timeout)
+        protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11
+
+        return (reader, writer, protocol)
+
+    def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
+        return PoolSemaphore(limits)
index d0beea19f5522f9c35577ccef67a924c473471a7..053a998081b6fe4d88c6b4fcae2af2d475948218 100644 (file)
@@ -4,7 +4,7 @@ import typing
 import h2.connection
 import h11
 
-from ..concurrency import connect
+from ..concurrency import AsyncioBackend
 from ..config import (
     DEFAULT_SSL_CONFIG,
     DEFAULT_TIMEOUT_CONFIG,
@@ -12,7 +12,7 @@ from ..config import (
     TimeoutConfig,
 )
 from ..exceptions import ConnectTimeout
-from ..interfaces import Dispatcher, Protocol
+from ..interfaces import ConcurrencyBackend, Dispatcher, Protocol
 from ..models import Origin, Request, Response
 from .http2 import HTTP2Connection
 from .http11 import HTTP11Connection
@@ -27,11 +27,13 @@ class HTTPConnection(Dispatcher):
         origin: typing.Union[str, Origin],
         ssl: SSLConfig = DEFAULT_SSL_CONFIG,
         timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
+        backend: ConcurrencyBackend = None,
         release_func: typing.Optional[ReleaseCallback] = None,
     ):
         self.origin = Origin(origin) if isinstance(origin, str) else origin
         self.ssl = ssl
         self.timeout = timeout
+        self.backend = AsyncioBackend() if backend is None else backend
         self.release_func = release_func
         self.h11_connection = None  # type: typing.Optional[HTTP11Connection]
         self.h2_connection = None  # type: typing.Optional[HTTP2Connection]
@@ -75,7 +77,9 @@ class HTTPConnection(Dispatcher):
         else:
             on_release = functools.partial(self.release_func, self)
 
-        reader, writer, protocol = await connect(host, port, ssl_context, timeout)
+        reader, writer, protocol = await self.backend.connect(
+            host, port, ssl_context, timeout
+        )
         if protocol == Protocol.HTTP_2:
             self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
         else:
index 2489da9509e62aedffa55de9bb776d4e64fb9f98..e0fb642e35dd0566b11b0da2dd47004a9183aa30 100644 (file)
@@ -1,7 +1,7 @@
 import collections.abc
 import typing
 
-from ..concurrency import PoolSemaphore
+from ..concurrency import AsyncioBackend
 from ..config import (
     DEFAULT_CA_BUNDLE_PATH,
     DEFAULT_POOL_LIMITS,
@@ -13,7 +13,7 @@ from ..config import (
 )
 from ..decoders import ACCEPT_ENCODING
 from ..exceptions import PoolTimeout
-from ..interfaces import Dispatcher
+from ..interfaces import ConcurrencyBackend, Dispatcher
 from ..models import Origin, Request, Response
 from .connection import HTTPConnection
 
@@ -90,16 +90,19 @@ class ConnectionPool(Dispatcher):
         ssl: SSLConfig = DEFAULT_SSL_CONFIG,
         timeout: TimeoutConfig = DEFAULT_TIMEOUT_CONFIG,
         pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
+        backend: ConcurrencyBackend = None,
     ):
         self.ssl = ssl
         self.timeout = timeout
         self.pool_limits = pool_limits
         self.is_closed = False
 
-        self.max_connections = PoolSemaphore(pool_limits)
         self.keepalive_connections = ConnectionStore()
         self.active_connections = ConnectionStore()
 
+        self.backend = AsyncioBackend() if backend is None else backend
+        self.max_connections = self.backend.get_semaphore(pool_limits)
+
     @property
     def num_connections(self) -> int:
         return len(self.keepalive_connections) + len(self.active_connections)
@@ -133,6 +136,7 @@ class ConnectionPool(Dispatcher):
                 origin,
                 ssl=self.ssl,
                 timeout=self.timeout,
+                backend=self.backend,
                 release_func=self.release_connection,
             )
 
index 8025de3e1bcb0fe1b1a2861b66731ecf96e133c2..6baf941f51a8f9eb2e08e66638a55c1cd5ca39df 100644 (file)
@@ -1,8 +1,9 @@
 import enum
+import ssl
 import typing
 from types import TracebackType
 
-from .config import SSLConfig, TimeoutConfig
+from .config import PoolLimits, SSLConfig, TimeoutConfig
 from .models import (
     URL,
     Headers,
@@ -117,3 +118,17 @@ class BasePoolSemaphore:
 
     def release(self) -> None:
         raise NotImplementedError()  # pragma: no cover
+
+
+class ConcurrencyBackend:
+    async def connect(
+        self,
+        hostname: str,
+        port: int,
+        ssl_context: typing.Optional[ssl.SSLContext],
+        timeout: TimeoutConfig,
+    ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]:
+        raise NotImplementedError()  # pragma: no cover
+
+    def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
+        raise NotImplementedError()  # pragma: no cover
index cc84fbb4f7639a118b4a38d5a4dad6b0bebbb974..762c832637dd4d327bc393ba6d0e96497b3a7ef7 100644 (file)
@@ -1,5 +1,4 @@
 import asyncio
-import os
 
 import pytest
 import trustme