import typing
from types import TracebackType
-from httpx.config import PoolLimits, TimeoutConfig
-from httpx.exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-
-from ..base import (
+from ..config import PoolLimits, TimeoutConfig
+from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
+from .base import (
BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
ConcurrencyBackend,
TimeoutFlag,
)
-from .compat import Stream, connect_compat
SSL_MONKEY_PATCH_APPLIED = False
class TCPStream(BaseTCPStream):
- def __init__(self, stream: Stream, timeout: TimeoutConfig):
- self.stream = stream
+ def __init__(
+ self,
+ stream_reader: asyncio.StreamReader,
+ stream_writer: asyncio.StreamWriter,
+ timeout: TimeoutConfig,
+ ):
+ self.stream_reader = stream_reader
+ self.stream_writer = stream_writer
self.timeout = timeout
def get_http_version(self) -> str:
- ssl_object = self.stream.get_extra_info("ssl_object")
+ ssl_object = self.stream_writer.get_extra_info("ssl_object")
if ssl_object is None:
return "HTTP/1.1"
should_raise = flag is None or flag.raise_on_read_timeout
read_timeout = timeout.read_timeout if should_raise else 0.01
try:
- data = await asyncio.wait_for(self.stream.read(n), read_timeout)
+ data = await asyncio.wait_for(self.stream_reader.read(n), read_timeout)
break
except asyncio.TimeoutError:
if should_raise:
return data
def write_no_block(self, data: bytes) -> None:
- self.stream.write(data) # pragma: nocover
+ self.stream_writer.write(data) # pragma: nocover
async def write(
self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
if timeout is None:
timeout = self.timeout
- self.stream.write(data)
+ self.stream_writer.write(data)
while True:
try:
await asyncio.wait_for( # type: ignore
- self.stream.drain(), timeout.write_timeout
+ self.stream_writer.drain(), timeout.write_timeout
)
break
except asyncio.TimeoutError:
# (For a solution that uses private asyncio APIs, see:
# https://github.com/encode/httpx/pull/143#issuecomment-515202982)
- return self.stream.at_eof()
+ return self.stream_reader.at_eof()
async def close(self) -> None:
- # FIXME: We should await on this call, but need a workaround for this first:
- # https://github.com/aio-libs/aiohttp/issues/3535
- self.stream.close()
+ self.stream_writer.close()
class PoolSemaphore(BasePoolSemaphore):
timeout: TimeoutConfig,
) -> BaseTCPStream:
try:
- stream = await asyncio.wait_for( # type: ignore
- connect_compat(hostname, port, ssl=ssl_context), timeout.connect_timeout
+ stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
+ asyncio.open_connection(hostname, port, ssl=ssl_context),
+ timeout.connect_timeout,
)
except asyncio.TimeoutError:
raise ConnectTimeout()
- return TCPStream(stream=stream, timeout=timeout)
+ return TCPStream(
+ stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
+ )
async def start_tls(
self,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:
+
+ loop = self.loop
+ if not hasattr(loop, "start_tls"): # pragma: no cover
+ raise NotImplementedError(
+ "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
+ )
+
assert isinstance(stream, TCPStream)
- await asyncio.wait_for(
- stream.stream.start_tls(ssl_context, server_hostname=hostname),
+ stream_reader = asyncio.StreamReader()
+ protocol = asyncio.StreamReaderProtocol(stream_reader)
+ transport = stream.stream_writer.transport
+
+ loop_start_tls = loop.start_tls # type: ignore
+ transport = await asyncio.wait_for(
+ loop_start_tls(
+ transport=transport,
+ protocol=protocol,
+ sslcontext=ssl_context,
+ server_hostname=hostname,
+ ),
timeout=timeout.connect_timeout,
)
+ stream_reader.set_transport(transport)
+ stream.stream_reader = stream_reader
+ stream.stream_writer = asyncio.StreamWriter(
+ transport=transport, protocol=protocol, reader=stream_reader, loop=loop
+ )
return stream
async def run_in_threadpool(
+++ /dev/null
-from .backend import AsyncioBackend, BackgroundManager, PoolSemaphore, TCPStream
-
-__all__ = ["AsyncioBackend", "BackgroundManager", "PoolSemaphore", "TCPStream"]
+++ /dev/null
-import asyncio
-import ssl
-import sys
-import typing
-
-if sys.version_info >= (3, 8):
- from typing import Protocol
-else:
- from typing_extensions import Protocol
-
-
-class Stream(Protocol): # pragma: no cover
- """Protocol defining just the methods we use from asyncio.Stream."""
-
- def at_eof(self) -> bool:
- ...
-
- def close(self) -> typing.Awaitable[None]:
- ...
-
- async def drain(self) -> None:
- ...
-
- def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
- ...
-
- async def read(self, n: int = -1) -> bytes:
- ...
-
- async def start_tls(
- self,
- sslContext: ssl.SSLContext,
- *,
- server_hostname: typing.Optional[str] = None,
- ssl_handshake_timeout: typing.Optional[float] = None,
- ) -> None:
- ...
-
- def write(self, data: bytes) -> typing.Awaitable[None]:
- ...
-
-
-async def connect_compat(*args: typing.Any, **kwargs: typing.Any) -> Stream:
- if sys.version_info >= (3, 8):
- return await asyncio.connect(*args, **kwargs)
- else:
- reader, writer = await asyncio.open_connection(*args, **kwargs)
- return StreamCompat(reader, writer)
-
-
-class StreamCompat:
- """
- Thin wrapper around asyncio.StreamReader/StreamWriter to make them look and
- behave similarly to an asyncio.Stream.
- """
-
- def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
- self.reader = reader
- self.writer = writer
-
- def at_eof(self) -> bool:
- return self.reader.at_eof()
-
- def close(self) -> typing.Awaitable[None]:
- self.writer.close()
- return _OptionalAwait(self.wait_closed)
-
- async def drain(self) -> None:
- await self.writer.drain()
-
- def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
- return self.writer.get_extra_info(name, default)
-
- async def read(self, n: int = -1) -> bytes:
- return await self.reader.read(n)
-
- async def start_tls(
- self,
- sslContext: ssl.SSLContext,
- *,
- server_hostname: typing.Optional[str] = None,
- ssl_handshake_timeout: typing.Optional[float] = None,
- ) -> None:
- if not sys.version_info >= (3, 7): # pragma: no cover
- raise NotImplementedError(
- "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
- )
- else:
- # This code is in an else branch to appease mypy on Python < 3.7
-
- reader = asyncio.StreamReader()
- protocol = asyncio.StreamReaderProtocol(reader)
- transport = self.writer.transport
-
- loop = asyncio.get_event_loop()
- loop_start_tls = loop.start_tls # type: ignore
- tls_transport = await loop_start_tls(
- transport=transport,
- protocol=protocol,
- sslcontext=sslContext,
- server_hostname=server_hostname,
- ssl_handshake_timeout=ssl_handshake_timeout,
- )
-
- reader.set_transport(tls_transport)
- self.reader = reader
- self.writer = asyncio.StreamWriter(
- transport=tls_transport, protocol=protocol, reader=reader, loop=loop
- )
-
- def write(self, data: bytes) -> typing.Awaitable[None]:
- self.writer.write(data)
- return _OptionalAwait(self.drain)
-
- async def wait_closed(self) -> None:
- if sys.version_info >= (3, 7):
- await self.writer.wait_closed()
- # else not much we can do to wait for the connection to close
-
-
-# This code is copied from cPython 3.8 but with type annotations added:
-# https://github.com/python/cpython/blob/v3.8.0b4/Lib/asyncio/streams.py#L1262-L1273
-_T = typing.TypeVar("_T")
-
-
-class _OptionalAwait(typing.Generic[_T]):
- # The class doesn't create a coroutine
- # if not awaited
- # It prevents "coroutine is never awaited" message
-
- __slots___ = ("_method",)
-
- def __init__(self, method: typing.Callable[[], typing.Awaitable[_T]]):
- self._method = method
-
- def __await__(self) -> typing.Generator[typing.Any, None, _T]:
- return self._method().__await__()
try:
assert stream.is_connection_dropped() is False
- assert stream.stream.get_extra_info("cipher", default=None) is None
+ assert stream.stream_writer.get_extra_info("cipher", default=None) is None
stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
- assert stream.stream.get_extra_info("cipher", default=None) is not None
+ assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
await stream.write(b"GET / HTTP/1.1\r\n\r\n")
assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")