from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
- BasePoolSemaphore,
BaseEvent,
+ BasePoolSemaphore,
BaseQueue,
BaseStream,
ConcurrencyBackend,
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)
+ async def start_tls(
+ self,
+ stream: BaseStream,
+ hostname: str,
+ ssl_context: ssl.SSLContext,
+ timeout: TimeoutConfig,
+ ) -> BaseStream:
+
+ loop = self.loop
+ if not hasattr(loop, "start_tls"): # pragma: no cover
+ raise NotImplementedError(
+ "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
+ )
+
+ assert isinstance(stream, Stream)
+
+ stream_reader = asyncio.StreamReader()
+ protocol = asyncio.StreamReaderProtocol(stream_reader)
+ transport = stream.stream_writer.transport
+
+ loop_start_tls = loop.start_tls # type: ignore
+ transport = await asyncio.wait_for(
+ loop_start_tls(
+ transport=transport,
+ protocol=protocol,
+ sslcontext=ssl_context,
+ server_hostname=hostname,
+ ),
+ timeout=timeout.connect_timeout,
+ )
+
+ stream_reader.set_transport(transport)
+ stream.stream_reader = stream_reader
+ stream.stream_writer = asyncio.StreamWriter(
+ transport=transport, protocol=protocol, reader=stream_reader, loop=loop
+ )
+ return stream
+
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
) -> BaseStream:
raise NotImplementedError() # pragma: no cover
+ async def start_tls(
+ self,
+ stream: BaseStream,
+ hostname: str,
+ ssl_context: ssl.SSLContext,
+ timeout: TimeoutConfig,
+ ) -> BaseStream:
+ raise NotImplementedError() # pragma: no cover
+
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover
import typing
-from .base import AsyncDispatcher
-from ..concurrency.base import ConcurrencyBackend
from ..concurrency.asyncio import AsyncioBackend
+from ..concurrency.base import ConcurrencyBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import AsyncRequest, AsyncResponse
+from .base import AsyncDispatcher
class ASGIDispatch(AsyncDispatcher):
import ssl
import typing
-from .base import AsyncDispatcher
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..config import (
VerifyTypes,
)
from ..models import AsyncRequest, AsyncResponse, Origin
+from .base import AsyncDispatcher
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
import typing
-from .base import AsyncDispatcher
from ..concurrency.asyncio import AsyncioBackend
from ..concurrency.base import ConcurrencyBackend
from ..config import (
VerifyTypes,
)
from ..models import AsyncRequest, AsyncResponse, Origin
+from .base import AsyncDispatcher
from .connection import HTTPConnection
CONNECTIONS_DICT = typing.Dict[Origin, typing.List[HTTPConnection]]
-from .base import AsyncDispatcher, Dispatcher
from ..concurrency.base import ConcurrencyBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import (
Response,
ResponseContent,
)
+from .base import AsyncDispatcher, Dispatcher
class ThreadedDispatcher(AsyncDispatcher):
import io
import typing
-from .base import Dispatcher
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..models import Request, Response
+from .base import Dispatcher
class WSGIDispatch(Dispatcher):
--- /dev/null
+import sys
+
+import pytest
+
+from httpx import AsyncioBackend, HTTPVersionConfig, SSLConfig, TimeoutConfig
+
+
+@pytest.mark.xfail(
+ sys.version_info < (3, 7),
+ reason="Requires Python 3.7+ for AbstractEventLoop.start_tls()",
+)
+@pytest.mark.asyncio
+async def test_start_tls_on_socket_stream(https_server):
+ """
+ See that the backend can make a connection without TLS then
+ start TLS on an existing connection.
+ """
+ backend = AsyncioBackend()
+ ctx = SSLConfig().load_ssl_context_no_verify(HTTPVersionConfig())
+ timeout = TimeoutConfig(5)
+
+ stream = await backend.connect("127.0.0.1", 8001, None, timeout)
+ assert stream.is_connection_dropped() is False
+ assert stream.stream_writer.get_extra_info("cipher", default=None) is None
+
+ stream = await backend.start_tls(stream, "127.0.0.1", ctx, timeout)
+ assert stream.is_connection_dropped() is False
+ assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
+
+ await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+ assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")