]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Revert "Use Python 3.8 asyncio.Stream where possible (#369)" (#423)
authorJosep Cugat <jcugat@gmail.com>
Thu, 3 Oct 2019 08:18:10 +0000 (10:18 +0200)
committerTom Christie <tom@tomchristie.com>
Thu, 3 Oct 2019 08:18:10 +0000 (09:18 +0100)
This reverts commit 71cbde8ba4cab32dd96ccb91e9db20216587fae9.

httpx/concurrency/asyncio.py [moved from httpx/concurrency/asyncio/backend.py with 79% similarity]
httpx/concurrency/asyncio/__init__.py [deleted file]
httpx/concurrency/asyncio/compat.py [deleted file]
tests/test_concurrency.py

similarity index 79%
rename from httpx/concurrency/asyncio/backend.py
rename to httpx/concurrency/asyncio.py
index 00887409b584f2773e8510ca29e2d44ca7db9bf2..dfee42545688c7517f4b48a8a29d9af0b10662ce 100644 (file)
@@ -4,10 +4,9 @@ import ssl
 import typing
 from types import TracebackType
 
-from httpx.config import PoolLimits, TimeoutConfig
-from httpx.exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-
-from ..base import (
+from ..config import PoolLimits, TimeoutConfig
+from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
+from .base import (
     BaseBackgroundManager,
     BaseEvent,
     BasePoolSemaphore,
@@ -16,7 +15,6 @@ from ..base import (
     ConcurrencyBackend,
     TimeoutFlag,
 )
-from .compat import Stream, connect_compat
 
 SSL_MONKEY_PATCH_APPLIED = False
 
@@ -43,12 +41,18 @@ def ssl_monkey_patch() -> None:
 
 
 class TCPStream(BaseTCPStream):
-    def __init__(self, stream: Stream, timeout: TimeoutConfig):
-        self.stream = stream
+    def __init__(
+        self,
+        stream_reader: asyncio.StreamReader,
+        stream_writer: asyncio.StreamWriter,
+        timeout: TimeoutConfig,
+    ):
+        self.stream_reader = stream_reader
+        self.stream_writer = stream_writer
         self.timeout = timeout
 
     def get_http_version(self) -> str:
-        ssl_object = self.stream.get_extra_info("ssl_object")
+        ssl_object = self.stream_writer.get_extra_info("ssl_object")
 
         if ssl_object is None:
             return "HTTP/1.1"
@@ -68,7 +72,7 @@ class TCPStream(BaseTCPStream):
             should_raise = flag is None or flag.raise_on_read_timeout
             read_timeout = timeout.read_timeout if should_raise else 0.01
             try:
-                data = await asyncio.wait_for(self.stream.read(n), read_timeout)
+                data = await asyncio.wait_for(self.stream_reader.read(n), read_timeout)
                 break
             except asyncio.TimeoutError:
                 if should_raise:
@@ -83,7 +87,7 @@ class TCPStream(BaseTCPStream):
         return data
 
     def write_no_block(self, data: bytes) -> None:
-        self.stream.write(data)  # pragma: nocover
+        self.stream_writer.write(data)  # pragma: nocover
 
     async def write(
         self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
@@ -94,11 +98,11 @@ class TCPStream(BaseTCPStream):
         if timeout is None:
             timeout = self.timeout
 
-        self.stream.write(data)
+        self.stream_writer.write(data)
         while True:
             try:
                 await asyncio.wait_for(  # type: ignore
-                    self.stream.drain(), timeout.write_timeout
+                    self.stream_writer.drain(), timeout.write_timeout
                 )
                 break
             except asyncio.TimeoutError:
@@ -124,12 +128,10 @@ class TCPStream(BaseTCPStream):
         # (For a solution that uses private asyncio APIs, see:
         # https://github.com/encode/httpx/pull/143#issuecomment-515202982)
 
-        return self.stream.at_eof()
+        return self.stream_reader.at_eof()
 
     async def close(self) -> None:
-        # FIXME: We should await on this call, but need a workaround for this first:
-        # https://github.com/aio-libs/aiohttp/issues/3535
-        self.stream.close()
+        self.stream_writer.close()
 
 
 class PoolSemaphore(BasePoolSemaphore):
@@ -188,13 +190,16 @@ class AsyncioBackend(ConcurrencyBackend):
         timeout: TimeoutConfig,
     ) -> BaseTCPStream:
         try:
-            stream = await asyncio.wait_for(  # type: ignore
-                connect_compat(hostname, port, ssl=ssl_context), timeout.connect_timeout
+            stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
+                asyncio.open_connection(hostname, port, ssl=ssl_context),
+                timeout.connect_timeout,
             )
         except asyncio.TimeoutError:
             raise ConnectTimeout()
 
-        return TCPStream(stream=stream, timeout=timeout)
+        return TCPStream(
+            stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
+        )
 
     async def start_tls(
         self,
@@ -203,13 +208,35 @@ class AsyncioBackend(ConcurrencyBackend):
         ssl_context: ssl.SSLContext,
         timeout: TimeoutConfig,
     ) -> BaseTCPStream:
+
+        loop = self.loop
+        if not hasattr(loop, "start_tls"):  # pragma: no cover
+            raise NotImplementedError(
+                "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
+            )
+
         assert isinstance(stream, TCPStream)
 
-        await asyncio.wait_for(
-            stream.stream.start_tls(ssl_context, server_hostname=hostname),
+        stream_reader = asyncio.StreamReader()
+        protocol = asyncio.StreamReaderProtocol(stream_reader)
+        transport = stream.stream_writer.transport
+
+        loop_start_tls = loop.start_tls  # type: ignore
+        transport = await asyncio.wait_for(
+            loop_start_tls(
+                transport=transport,
+                protocol=protocol,
+                sslcontext=ssl_context,
+                server_hostname=hostname,
+            ),
             timeout=timeout.connect_timeout,
         )
 
+        stream_reader.set_transport(transport)
+        stream.stream_reader = stream_reader
+        stream.stream_writer = asyncio.StreamWriter(
+            transport=transport, protocol=protocol, reader=stream_reader, loop=loop
+        )
         return stream
 
     async def run_in_threadpool(
diff --git a/httpx/concurrency/asyncio/__init__.py b/httpx/concurrency/asyncio/__init__.py
deleted file mode 100644 (file)
index 3543542..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-from .backend import AsyncioBackend, BackgroundManager, PoolSemaphore, TCPStream
-
-__all__ = ["AsyncioBackend", "BackgroundManager", "PoolSemaphore", "TCPStream"]
diff --git a/httpx/concurrency/asyncio/compat.py b/httpx/concurrency/asyncio/compat.py
deleted file mode 100644 (file)
index d83b209..0000000
+++ /dev/null
@@ -1,137 +0,0 @@
-import asyncio
-import ssl
-import sys
-import typing
-
-if sys.version_info >= (3, 8):
-    from typing import Protocol
-else:
-    from typing_extensions import Protocol
-
-
-class Stream(Protocol):  # pragma: no cover
-    """Protocol defining just the methods we use from asyncio.Stream."""
-
-    def at_eof(self) -> bool:
-        ...
-
-    def close(self) -> typing.Awaitable[None]:
-        ...
-
-    async def drain(self) -> None:
-        ...
-
-    def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
-        ...
-
-    async def read(self, n: int = -1) -> bytes:
-        ...
-
-    async def start_tls(
-        self,
-        sslContext: ssl.SSLContext,
-        *,
-        server_hostname: typing.Optional[str] = None,
-        ssl_handshake_timeout: typing.Optional[float] = None,
-    ) -> None:
-        ...
-
-    def write(self, data: bytes) -> typing.Awaitable[None]:
-        ...
-
-
-async def connect_compat(*args: typing.Any, **kwargs: typing.Any) -> Stream:
-    if sys.version_info >= (3, 8):
-        return await asyncio.connect(*args, **kwargs)
-    else:
-        reader, writer = await asyncio.open_connection(*args, **kwargs)
-        return StreamCompat(reader, writer)
-
-
-class StreamCompat:
-    """
-    Thin wrapper around asyncio.StreamReader/StreamWriter to make them look and
-    behave similarly to an asyncio.Stream.
-    """
-
-    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
-        self.reader = reader
-        self.writer = writer
-
-    def at_eof(self) -> bool:
-        return self.reader.at_eof()
-
-    def close(self) -> typing.Awaitable[None]:
-        self.writer.close()
-        return _OptionalAwait(self.wait_closed)
-
-    async def drain(self) -> None:
-        await self.writer.drain()
-
-    def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
-        return self.writer.get_extra_info(name, default)
-
-    async def read(self, n: int = -1) -> bytes:
-        return await self.reader.read(n)
-
-    async def start_tls(
-        self,
-        sslContext: ssl.SSLContext,
-        *,
-        server_hostname: typing.Optional[str] = None,
-        ssl_handshake_timeout: typing.Optional[float] = None,
-    ) -> None:
-        if not sys.version_info >= (3, 7):  # pragma: no cover
-            raise NotImplementedError(
-                "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
-            )
-        else:
-            # This code is in an else branch to appease mypy on Python < 3.7
-
-            reader = asyncio.StreamReader()
-            protocol = asyncio.StreamReaderProtocol(reader)
-            transport = self.writer.transport
-
-            loop = asyncio.get_event_loop()
-            loop_start_tls = loop.start_tls  # type: ignore
-            tls_transport = await loop_start_tls(
-                transport=transport,
-                protocol=protocol,
-                sslcontext=sslContext,
-                server_hostname=server_hostname,
-                ssl_handshake_timeout=ssl_handshake_timeout,
-            )
-
-            reader.set_transport(tls_transport)
-            self.reader = reader
-            self.writer = asyncio.StreamWriter(
-                transport=tls_transport, protocol=protocol, reader=reader, loop=loop
-            )
-
-    def write(self, data: bytes) -> typing.Awaitable[None]:
-        self.writer.write(data)
-        return _OptionalAwait(self.drain)
-
-    async def wait_closed(self) -> None:
-        if sys.version_info >= (3, 7):
-            await self.writer.wait_closed()
-        # else not much we can do to wait for the connection to close
-
-
-# This code is copied from cPython 3.8 but with type annotations added:
-# https://github.com/python/cpython/blob/v3.8.0b4/Lib/asyncio/streams.py#L1262-L1273
-_T = typing.TypeVar("_T")
-
-
-class _OptionalAwait(typing.Generic[_T]):
-    # The class doesn't create a coroutine
-    # if not awaited
-    # It prevents "coroutine is never awaited" message
-
-    __slots___ = ("_method",)
-
-    def __init__(self, method: typing.Callable[[], typing.Awaitable[_T]]):
-        self._method = method
-
-    def __await__(self) -> typing.Generator[typing.Any, None, _T]:
-        return self._method().__await__()
index 03f8c0aaea8dd00bf4d3b103b330087a302b4f51..ab93b302829d2293eb35845fd77276d69811de95 100644 (file)
@@ -25,11 +25,11 @@ async def test_start_tls_on_socket_stream(https_server):
 
     try:
         assert stream.is_connection_dropped() is False
-        assert stream.stream.get_extra_info("cipher", default=None) is None
+        assert stream.stream_writer.get_extra_info("cipher", default=None) is None
 
         stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
         assert stream.is_connection_dropped() is False
-        assert stream.stream.get_extra_info("cipher", default=None) is not None
+        assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
 
         await stream.write(b"GET / HTTP/1.1\r\n\r\n")
         assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")