]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Use Python 3.8 asyncio.Stream where possible (#369)
authorJamie Hewland <jhewland@gmail.com>
Sat, 28 Sep 2019 17:23:14 +0000 (19:23 +0200)
committerSeth Michael Larson <sethmichaellarson@gmail.com>
Sat, 28 Sep 2019 17:23:14 +0000 (12:23 -0500)
httpx/concurrency/asyncio/__init__.py [new file with mode: 0644]
httpx/concurrency/asyncio/backend.py [moved from httpx/concurrency/asyncio.py with 79% similarity]
httpx/concurrency/asyncio/compat.py [new file with mode: 0644]
tests/test_concurrency.py

diff --git a/httpx/concurrency/asyncio/__init__.py b/httpx/concurrency/asyncio/__init__.py
new file mode 100644 (file)
index 0000000..3543542
--- /dev/null
@@ -0,0 +1,3 @@
+from .backend import AsyncioBackend, BackgroundManager, PoolSemaphore, TCPStream
+
+__all__ = ["AsyncioBackend", "BackgroundManager", "PoolSemaphore", "TCPStream"]
similarity index 79%
rename from httpx/concurrency/asyncio.py
rename to httpx/concurrency/asyncio/backend.py
index 5d3501447dad3600baf73db1410739593ee02dee..51082ad93ba3725b3bae55737da4d4f316d2546d 100644 (file)
@@ -4,9 +4,10 @@ import ssl
 import typing
 from types import TracebackType
 
-from ..config import PoolLimits, TimeoutConfig
-from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .base import (
+from httpx.config import PoolLimits, TimeoutConfig
+from httpx.exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
+
+from ..base import (
     BaseBackgroundManager,
     BaseEvent,
     BasePoolSemaphore,
@@ -15,6 +16,7 @@ from .base import (
     ConcurrencyBackend,
     TimeoutFlag,
 )
+from .compat import Stream, connect_compat
 
 SSL_MONKEY_PATCH_APPLIED = False
 
@@ -41,18 +43,12 @@ def ssl_monkey_patch() -> None:
 
 
 class TCPStream(BaseTCPStream):
-    def __init__(
-        self,
-        stream_reader: asyncio.StreamReader,
-        stream_writer: asyncio.StreamWriter,
-        timeout: TimeoutConfig,
-    ):
-        self.stream_reader = stream_reader
-        self.stream_writer = stream_writer
+    def __init__(self, stream: Stream, timeout: TimeoutConfig):
+        self.stream = stream
         self.timeout = timeout
 
     def get_http_version(self) -> str:
-        ssl_object = self.stream_writer.get_extra_info("ssl_object")
+        ssl_object = self.stream.get_extra_info("ssl_object")
 
         if ssl_object is None:
             return "HTTP/1.1"
@@ -76,7 +72,7 @@ class TCPStream(BaseTCPStream):
             should_raise = flag is None or flag.raise_on_read_timeout
             read_timeout = timeout.read_timeout if should_raise else 0.01
             try:
-                data = await asyncio.wait_for(self.stream_reader.read(n), read_timeout)
+                data = await asyncio.wait_for(self.stream.read(n), read_timeout)
                 break
             except asyncio.TimeoutError:
                 if should_raise:
@@ -91,7 +87,7 @@ class TCPStream(BaseTCPStream):
         return data
 
     def write_no_block(self, data: bytes) -> None:
-        self.stream_writer.write(data)  # pragma: nocover
+        self.stream.write(data)  # pragma: nocover
 
     async def write(
         self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
@@ -102,11 +98,11 @@ class TCPStream(BaseTCPStream):
         if timeout is None:
             timeout = self.timeout
 
-        self.stream_writer.write(data)
+        self.stream.write(data)
         while True:
             try:
                 await asyncio.wait_for(  # type: ignore
-                    self.stream_writer.drain(), timeout.write_timeout
+                    self.stream.drain(), timeout.write_timeout
                 )
                 break
             except asyncio.TimeoutError:
@@ -132,10 +128,12 @@ class TCPStream(BaseTCPStream):
         # (For a solution that uses private asyncio APIs, see:
         # https://github.com/encode/httpx/pull/143#issuecomment-515202982)
 
-        return self.stream_reader.at_eof()
+        return self.stream.at_eof()
 
     async def close(self) -> None:
-        self.stream_writer.close()
+        # FIXME: We should await on this call, but need a workaround for this first:
+        # https://github.com/aio-libs/aiohttp/issues/3535
+        self.stream.close()
 
 
 class PoolSemaphore(BasePoolSemaphore):
@@ -194,16 +192,13 @@ class AsyncioBackend(ConcurrencyBackend):
         timeout: TimeoutConfig,
     ) -> BaseTCPStream:
         try:
-            stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
-                asyncio.open_connection(hostname, port, ssl=ssl_context),
-                timeout.connect_timeout,
+            stream = await asyncio.wait_for(  # type: ignore
+                connect_compat(hostname, port, ssl=ssl_context), timeout.connect_timeout
             )
         except asyncio.TimeoutError:
             raise ConnectTimeout()
 
-        return TCPStream(
-            stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
-        )
+        return TCPStream(stream=stream, timeout=timeout)
 
     async def start_tls(
         self,
@@ -212,35 +207,13 @@ class AsyncioBackend(ConcurrencyBackend):
         ssl_context: ssl.SSLContext,
         timeout: TimeoutConfig,
     ) -> BaseTCPStream:
-
-        loop = self.loop
-        if not hasattr(loop, "start_tls"):  # pragma: no cover
-            raise NotImplementedError(
-                "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
-            )
-
         assert isinstance(stream, TCPStream)
 
-        stream_reader = asyncio.StreamReader()
-        protocol = asyncio.StreamReaderProtocol(stream_reader)
-        transport = stream.stream_writer.transport
-
-        loop_start_tls = loop.start_tls  # type: ignore
-        transport = await asyncio.wait_for(
-            loop_start_tls(
-                transport=transport,
-                protocol=protocol,
-                sslcontext=ssl_context,
-                server_hostname=hostname,
-            ),
+        await asyncio.wait_for(
+            stream.stream.start_tls(ssl_context, server_hostname=hostname),
             timeout=timeout.connect_timeout,
         )
 
-        stream_reader.set_transport(transport)
-        stream.stream_reader = stream_reader
-        stream.stream_writer = asyncio.StreamWriter(
-            transport=transport, protocol=protocol, reader=stream_reader, loop=loop
-        )
         return stream
 
     async def run_in_threadpool(
diff --git a/httpx/concurrency/asyncio/compat.py b/httpx/concurrency/asyncio/compat.py
new file mode 100644 (file)
index 0000000..d83b209
--- /dev/null
@@ -0,0 +1,137 @@
+import asyncio
+import ssl
+import sys
+import typing
+
+if sys.version_info >= (3, 8):
+    from typing import Protocol
+else:
+    from typing_extensions import Protocol
+
+
+class Stream(Protocol):  # pragma: no cover
+    """Protocol defining just the methods we use from asyncio.Stream."""
+
+    def at_eof(self) -> bool:
+        ...
+
+    def close(self) -> typing.Awaitable[None]:
+        ...
+
+    async def drain(self) -> None:
+        ...
+
+    def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
+        ...
+
+    async def read(self, n: int = -1) -> bytes:
+        ...
+
+    async def start_tls(
+        self,
+        sslContext: ssl.SSLContext,
+        *,
+        server_hostname: typing.Optional[str] = None,
+        ssl_handshake_timeout: typing.Optional[float] = None,
+    ) -> None:
+        ...
+
+    def write(self, data: bytes) -> typing.Awaitable[None]:
+        ...
+
+
+async def connect_compat(*args: typing.Any, **kwargs: typing.Any) -> Stream:
+    if sys.version_info >= (3, 8):
+        return await asyncio.connect(*args, **kwargs)
+    else:
+        reader, writer = await asyncio.open_connection(*args, **kwargs)
+        return StreamCompat(reader, writer)
+
+
+class StreamCompat:
+    """
+    Thin wrapper around asyncio.StreamReader/StreamWriter to make them look and
+    behave similarly to an asyncio.Stream.
+    """
+
+    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
+        self.reader = reader
+        self.writer = writer
+
+    def at_eof(self) -> bool:
+        return self.reader.at_eof()
+
+    def close(self) -> typing.Awaitable[None]:
+        self.writer.close()
+        return _OptionalAwait(self.wait_closed)
+
+    async def drain(self) -> None:
+        await self.writer.drain()
+
+    def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
+        return self.writer.get_extra_info(name, default)
+
+    async def read(self, n: int = -1) -> bytes:
+        return await self.reader.read(n)
+
+    async def start_tls(
+        self,
+        sslContext: ssl.SSLContext,
+        *,
+        server_hostname: typing.Optional[str] = None,
+        ssl_handshake_timeout: typing.Optional[float] = None,
+    ) -> None:
+        if not sys.version_info >= (3, 7):  # pragma: no cover
+            raise NotImplementedError(
+                "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
+            )
+        else:
+            # This code is in an else branch to appease mypy on Python < 3.7
+
+            reader = asyncio.StreamReader()
+            protocol = asyncio.StreamReaderProtocol(reader)
+            transport = self.writer.transport
+
+            loop = asyncio.get_event_loop()
+            loop_start_tls = loop.start_tls  # type: ignore
+            tls_transport = await loop_start_tls(
+                transport=transport,
+                protocol=protocol,
+                sslcontext=sslContext,
+                server_hostname=server_hostname,
+                ssl_handshake_timeout=ssl_handshake_timeout,
+            )
+
+            reader.set_transport(tls_transport)
+            self.reader = reader
+            self.writer = asyncio.StreamWriter(
+                transport=tls_transport, protocol=protocol, reader=reader, loop=loop
+            )
+
+    def write(self, data: bytes) -> typing.Awaitable[None]:
+        self.writer.write(data)
+        return _OptionalAwait(self.drain)
+
+    async def wait_closed(self) -> None:
+        if sys.version_info >= (3, 7):
+            await self.writer.wait_closed()
+        # else not much we can do to wait for the connection to close
+
+
+# This code is copied from cPython 3.8 but with type annotations added:
+# https://github.com/python/cpython/blob/v3.8.0b4/Lib/asyncio/streams.py#L1262-L1273
+_T = typing.TypeVar("_T")
+
+
+class _OptionalAwait(typing.Generic[_T]):
+    # The class doesn't create a coroutine
+    # if not awaited
+    # It prevents "coroutine is never awaited" message
+
+    __slots___ = ("_method",)
+
+    def __init__(self, method: typing.Callable[[], typing.Awaitable[_T]]):
+        self._method = method
+
+    def __await__(self) -> typing.Generator[typing.Any, None, _T]:
+        return self._method().__await__()
index ab93b302829d2293eb35845fd77276d69811de95..03f8c0aaea8dd00bf4d3b103b330087a302b4f51 100644 (file)
@@ -25,11 +25,11 @@ async def test_start_tls_on_socket_stream(https_server):
 
     try:
         assert stream.is_connection_dropped() is False
-        assert stream.stream_writer.get_extra_info("cipher", default=None) is None
+        assert stream.stream.get_extra_info("cipher", default=None) is None
 
         stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
         assert stream.is_connection_dropped() is False
-        assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
+        assert stream.stream.get_extra_info("cipher", default=None) is not None
 
         await stream.write(b"GET / HTTP/1.1\r\n\r\n")
         assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")