--- /dev/null
+from .backend import AsyncioBackend, BackgroundManager, PoolSemaphore, TCPStream
+
+__all__ = ["AsyncioBackend", "BackgroundManager", "PoolSemaphore", "TCPStream"]
import typing
from types import TracebackType
-from ..config import PoolLimits, TimeoutConfig
-from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .base import (
+from httpx.config import PoolLimits, TimeoutConfig
+from httpx.exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
+
+from ..base import (
BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
ConcurrencyBackend,
TimeoutFlag,
)
+from .compat import Stream, connect_compat
SSL_MONKEY_PATCH_APPLIED = False
class TCPStream(BaseTCPStream):
- def __init__(
- self,
- stream_reader: asyncio.StreamReader,
- stream_writer: asyncio.StreamWriter,
- timeout: TimeoutConfig,
- ):
- self.stream_reader = stream_reader
- self.stream_writer = stream_writer
+ def __init__(self, stream: Stream, timeout: TimeoutConfig):
+ self.stream = stream
self.timeout = timeout
def get_http_version(self) -> str:
- ssl_object = self.stream_writer.get_extra_info("ssl_object")
+ ssl_object = self.stream.get_extra_info("ssl_object")
if ssl_object is None:
return "HTTP/1.1"
should_raise = flag is None or flag.raise_on_read_timeout
read_timeout = timeout.read_timeout if should_raise else 0.01
try:
- data = await asyncio.wait_for(self.stream_reader.read(n), read_timeout)
+ data = await asyncio.wait_for(self.stream.read(n), read_timeout)
break
except asyncio.TimeoutError:
if should_raise:
return data
def write_no_block(self, data: bytes) -> None:
- self.stream_writer.write(data) # pragma: nocover
+ self.stream.write(data) # pragma: nocover
async def write(
self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
if timeout is None:
timeout = self.timeout
- self.stream_writer.write(data)
+ self.stream.write(data)
while True:
try:
await asyncio.wait_for( # type: ignore
- self.stream_writer.drain(), timeout.write_timeout
+ self.stream.drain(), timeout.write_timeout
)
break
except asyncio.TimeoutError:
# (For a solution that uses private asyncio APIs, see:
# https://github.com/encode/httpx/pull/143#issuecomment-515202982)
- return self.stream_reader.at_eof()
+ return self.stream.at_eof()
async def close(self) -> None:
- self.stream_writer.close()
+ # FIXME: We should await on this call, but need a workaround for this first:
+ # https://github.com/aio-libs/aiohttp/issues/3535
+ self.stream.close()
class PoolSemaphore(BasePoolSemaphore):
timeout: TimeoutConfig,
) -> BaseTCPStream:
try:
- stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
- asyncio.open_connection(hostname, port, ssl=ssl_context),
- timeout.connect_timeout,
+ stream = await asyncio.wait_for( # type: ignore
+ connect_compat(hostname, port, ssl=ssl_context), timeout.connect_timeout
)
except asyncio.TimeoutError:
raise ConnectTimeout()
- return TCPStream(
- stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
- )
+ return TCPStream(stream=stream, timeout=timeout)
async def start_tls(
self,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:
-
- loop = self.loop
- if not hasattr(loop, "start_tls"): # pragma: no cover
- raise NotImplementedError(
- "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
- )
-
assert isinstance(stream, TCPStream)
- stream_reader = asyncio.StreamReader()
- protocol = asyncio.StreamReaderProtocol(stream_reader)
- transport = stream.stream_writer.transport
-
- loop_start_tls = loop.start_tls # type: ignore
- transport = await asyncio.wait_for(
- loop_start_tls(
- transport=transport,
- protocol=protocol,
- sslcontext=ssl_context,
- server_hostname=hostname,
- ),
+ await asyncio.wait_for(
+ stream.stream.start_tls(ssl_context, server_hostname=hostname),
timeout=timeout.connect_timeout,
)
- stream_reader.set_transport(transport)
- stream.stream_reader = stream_reader
- stream.stream_writer = asyncio.StreamWriter(
- transport=transport, protocol=protocol, reader=stream_reader, loop=loop
- )
return stream
async def run_in_threadpool(
--- /dev/null
+import asyncio
+import ssl
+import sys
+import typing
+
+if sys.version_info >= (3, 8):
+ from typing import Protocol
+else:
+ from typing_extensions import Protocol
+
+
+class Stream(Protocol): # pragma: no cover
+ """Protocol defining just the methods we use from asyncio.Stream."""
+
+ def at_eof(self) -> bool:
+ ...
+
+ def close(self) -> typing.Awaitable[None]:
+ ...
+
+ async def drain(self) -> None:
+ ...
+
+ def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
+ ...
+
+ async def read(self, n: int = -1) -> bytes:
+ ...
+
+ async def start_tls(
+ self,
+ sslContext: ssl.SSLContext,
+ *,
+ server_hostname: typing.Optional[str] = None,
+ ssl_handshake_timeout: typing.Optional[float] = None,
+ ) -> None:
+ ...
+
+ def write(self, data: bytes) -> typing.Awaitable[None]:
+ ...
+
+
+async def connect_compat(*args: typing.Any, **kwargs: typing.Any) -> Stream:
+ if sys.version_info >= (3, 8):
+ return await asyncio.connect(*args, **kwargs)
+ else:
+ reader, writer = await asyncio.open_connection(*args, **kwargs)
+ return StreamCompat(reader, writer)
+
+
+class StreamCompat:
+ """
+ Thin wrapper around asyncio.StreamReader/StreamWriter to make them look and
+ behave similarly to an asyncio.Stream.
+ """
+
+ def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
+ self.reader = reader
+ self.writer = writer
+
+ def at_eof(self) -> bool:
+ return self.reader.at_eof()
+
+ def close(self) -> typing.Awaitable[None]:
+ self.writer.close()
+ return _OptionalAwait(self.wait_closed)
+
+ async def drain(self) -> None:
+ await self.writer.drain()
+
+ def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
+ return self.writer.get_extra_info(name, default)
+
+ async def read(self, n: int = -1) -> bytes:
+ return await self.reader.read(n)
+
+ async def start_tls(
+ self,
+ sslContext: ssl.SSLContext,
+ *,
+ server_hostname: typing.Optional[str] = None,
+ ssl_handshake_timeout: typing.Optional[float] = None,
+ ) -> None:
+ if not sys.version_info >= (3, 7): # pragma: no cover
+ raise NotImplementedError(
+ "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
+ )
+ else:
+ # This code is in an else branch to appease mypy on Python < 3.7
+
+ reader = asyncio.StreamReader()
+ protocol = asyncio.StreamReaderProtocol(reader)
+ transport = self.writer.transport
+
+ loop = asyncio.get_event_loop()
+ loop_start_tls = loop.start_tls # type: ignore
+ tls_transport = await loop_start_tls(
+ transport=transport,
+ protocol=protocol,
+ sslcontext=sslContext,
+ server_hostname=server_hostname,
+ ssl_handshake_timeout=ssl_handshake_timeout,
+ )
+
+ reader.set_transport(tls_transport)
+ self.reader = reader
+ self.writer = asyncio.StreamWriter(
+ transport=tls_transport, protocol=protocol, reader=reader, loop=loop
+ )
+
+ def write(self, data: bytes) -> typing.Awaitable[None]:
+ self.writer.write(data)
+ return _OptionalAwait(self.drain)
+
+ async def wait_closed(self) -> None:
+ if sys.version_info >= (3, 7):
+ await self.writer.wait_closed()
+ # else not much we can do to wait for the connection to close
+
+
+# This code is copied from cPython 3.8 but with type annotations added:
+# https://github.com/python/cpython/blob/v3.8.0b4/Lib/asyncio/streams.py#L1262-L1273
+_T = typing.TypeVar("_T")
+
+
+class _OptionalAwait(typing.Generic[_T]):
+ # The class doesn't create a coroutine
+ # if not awaited
+ # It prevents "coroutine is never awaited" message
+
+ __slots___ = ("_method",)
+
+ def __init__(self, method: typing.Callable[[], typing.Awaitable[_T]]):
+ self._method = method
+
+ def __await__(self) -> typing.Generator[typing.Any, None, _T]:
+ return self._method().__await__()
try:
assert stream.is_connection_dropped() is False
- assert stream.stream_writer.get_extra_info("cipher", default=None) is None
+ assert stream.stream.get_extra_info("cipher", default=None) is None
stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
- assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
+ assert stream.stream.get_extra_info("cipher", default=None) is not None
await stream.write(b"GET / HTTP/1.1\r\n\r\n")
assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")