]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add ConcurrencyBackend.start_tls() (#263)
authorSeth Michael Larson <sethmichaellarson@gmail.com>
Sat, 24 Aug 2019 15:04:14 +0000 (10:04 -0500)
committerGitHub <noreply@github.com>
Sat, 24 Aug 2019 15:04:14 +0000 (10:04 -0500)
httpx/concurrency/asyncio.py
httpx/concurrency/base.py
httpx/dispatch/asgi.py
httpx/dispatch/connection.py
httpx/dispatch/connection_pool.py
httpx/dispatch/threaded.py
httpx/dispatch/wsgi.py
tests/test_concurrency.py [new file with mode: 0644]

index 9b9778326627991d85caeabdd8580aeefb45f556..8fc625f1c06b82074e86582ec03c9a2f188248b7 100644 (file)
@@ -18,8 +18,8 @@ from ..config import PoolLimits, TimeoutConfig
 from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
 from .base import (
     BaseBackgroundManager,
-    BasePoolSemaphore,
     BaseEvent,
+    BasePoolSemaphore,
     BaseQueue,
     BaseStream,
     ConcurrencyBackend,
@@ -194,6 +194,44 @@ class AsyncioBackend(ConcurrencyBackend):
             stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
         )
 
+    async def start_tls(
+        self,
+        stream: BaseStream,
+        hostname: str,
+        ssl_context: ssl.SSLContext,
+        timeout: TimeoutConfig,
+    ) -> BaseStream:
+
+        loop = self.loop
+        if not hasattr(loop, "start_tls"):  # pragma: no cover
+            raise NotImplementedError(
+                "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
+            )
+
+        assert isinstance(stream, Stream)
+
+        stream_reader = asyncio.StreamReader()
+        protocol = asyncio.StreamReaderProtocol(stream_reader)
+        transport = stream.stream_writer.transport
+
+        loop_start_tls = loop.start_tls  # type: ignore
+        transport = await asyncio.wait_for(
+            loop_start_tls(
+                transport=transport,
+                protocol=protocol,
+                sslcontext=ssl_context,
+                server_hostname=hostname,
+            ),
+            timeout=timeout.connect_timeout,
+        )
+
+        stream_reader.set_transport(transport)
+        stream.stream_reader = stream_reader
+        stream.stream_writer = asyncio.StreamWriter(
+            transport=transport, protocol=protocol, reader=stream_reader, loop=loop
+        )
+        return stream
+
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
     ) -> typing.Any:
index 9bfd54d4a48a2685cbd675fd3b43a8647ca7eb9b..bf2aed4f1f9e2577baf3daa505c8f41eb94cb30b 100644 (file)
@@ -116,6 +116,15 @@ class ConcurrencyBackend:
     ) -> BaseStream:
         raise NotImplementedError()  # pragma: no cover
 
+    async def start_tls(
+        self,
+        stream: BaseStream,
+        hostname: str,
+        ssl_context: ssl.SSLContext,
+        timeout: TimeoutConfig,
+    ) -> BaseStream:
+        raise NotImplementedError()  # pragma: no cover
+
     def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
         raise NotImplementedError()  # pragma: no cover
 
index 6c1fc267da636da64a77b016b162f944be29eced..c56d757c71b1926bbba57ace43816e99e8d66223 100644 (file)
@@ -1,10 +1,10 @@
 import typing
 
-from .base import AsyncDispatcher
-from ..concurrency.base import ConcurrencyBackend
 from ..concurrency.asyncio import AsyncioBackend
+from ..concurrency.base import ConcurrencyBackend
 from ..config import CertTypes, TimeoutTypes, VerifyTypes
 from ..models import AsyncRequest, AsyncResponse
+from .base import AsyncDispatcher
 
 
 class ASGIDispatch(AsyncDispatcher):
index 0e9819cb981cee7886f3b2bcc64b8774c2beb192..87c2a6489ef0d981822a5cf1357be752d73d84cf 100644 (file)
@@ -2,7 +2,6 @@ import functools
 import ssl
 import typing
 
-from .base import AsyncDispatcher
 from ..concurrency.asyncio import AsyncioBackend
 from ..concurrency.base import ConcurrencyBackend
 from ..config import (
@@ -16,6 +15,7 @@ from ..config import (
     VerifyTypes,
 )
 from ..models import AsyncRequest, AsyncResponse, Origin
+from .base import AsyncDispatcher
 from .http2 import HTTP2Connection
 from .http11 import HTTP11Connection
 
index 0e14bf8354a3db119243d24d2d4a07be862731d8..eb990a9618ee4c351b1a32c13c2b2d5e1e5ba111 100644 (file)
@@ -1,6 +1,5 @@
 import typing
 
-from .base import AsyncDispatcher
 from ..concurrency.asyncio import AsyncioBackend
 from ..concurrency.base import ConcurrencyBackend
 from ..config import (
@@ -13,6 +12,7 @@ from ..config import (
     VerifyTypes,
 )
 from ..models import AsyncRequest, AsyncResponse, Origin
+from .base import AsyncDispatcher
 from .connection import HTTPConnection
 
 CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
index 8176608729493cb0d8fe7c1f77696bf8b0cda72e..7454a9e0ad1092641efe20355567a59c87493070 100644 (file)
@@ -1,4 +1,3 @@
-from .base import AsyncDispatcher, Dispatcher
 from ..concurrency.base import ConcurrencyBackend
 from ..config import CertTypes, TimeoutTypes, VerifyTypes
 from ..models import (
@@ -11,6 +10,7 @@ from ..models import (
     Response,
     ResponseContent,
 )
+from .base import AsyncDispatcher, Dispatcher
 
 
 class ThreadedDispatcher(AsyncDispatcher):
index 0cbe1095e2c510bcabb64c456ee0c649710097cd..73a6fc1f5cdab54408284e024258ec3efe8d37a9 100644 (file)
@@ -1,9 +1,9 @@
 import io
 import typing
 
-from .base import Dispatcher
 from ..config import CertTypes, TimeoutTypes, VerifyTypes
 from ..models import Request, Response
+from .base import Dispatcher
 
 
 class WSGIDispatch(Dispatcher):
diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py
new file mode 100644 (file)
index 0000000..870a592
--- /dev/null
@@ -0,0 +1,31 @@
+import sys
+
+import pytest
+
+from httpx import AsyncioBackend, HTTPVersionConfig, SSLConfig, TimeoutConfig
+
+
+@pytest.mark.xfail(
+    sys.version_info < (3, 7),
+    reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()",
+)
+@pytest.mark.asyncio
+async def test_start_tls_on_socket_stream(https_server):
+    """
+    See that the backend can make a connection without TLS then
+    start TLS on an existing connection.
+    """
+    backend = AsyncioBackend()
+    ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
+    timeout = TimeoutConfig(5)
+
+    stream = await backend.connect("127.0.0.1", 8001, None, timeout)
+    assert stream.is_connection_dropped() is False
+    assert stream.stream_writer.get_extra_info("cipher", default=None) is None
+
+    stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout)
+    assert stream.is_connection_dropped() is False
+    assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
+
+    await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+    assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")