from .concurrency.base import (
BaseBackgroundManager,
BasePoolSemaphore,
- BaseTCPStream,
+ BaseSocketStream,
ConcurrencyBackend,
)
from .config import (
"TooManyRedirects",
"WriteTimeout",
"AsyncDispatcher",
- "BaseTCPStream",
+ "BaseSocketStream",
"ConcurrencyBackend",
"Dispatcher",
"URL",
BaseEvent,
BasePoolSemaphore,
BaseQueue,
- BaseTCPStream,
+ BaseSocketStream,
ConcurrencyBackend,
TimeoutFlag,
)
MonkeyPatch.write = _fixed_write
-class TCPStream(BaseTCPStream):
+class SocketStream(BaseSocketStream):
def __init__(
self,
stream_reader: asyncio.StreamReader,
self.stream_writer = stream_writer
self.timeout = timeout
- self._inner: typing.Optional[TCPStream] = None
+ self._inner: typing.Optional[SocketStream] = None
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
- ) -> BaseTCPStream:
+ ) -> "SocketStream":
loop = asyncio.get_event_loop()
if not hasattr(loop, "start_tls"): # pragma: no cover
raise NotImplementedError(
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
)
- ssl_stream = TCPStream(stream_reader, stream_writer, self.timeout)
- # When we return a new TCPStream with new StreamReader/StreamWriter instances,
+ ssl_stream = SocketStream(stream_reader, stream_writer, self.timeout)
+ # When we return a new SocketStream with new StreamReader/StreamWriter instances
# we need to keep references to the old StreamReader/StreamWriter so that they
# are not garbage collected and closed while we're still using them.
ssl_stream._inner = self
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
- ) -> BaseTCPStream:
+ ) -> SocketStream:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
except asyncio.TimeoutError:
raise ConnectTimeout()
- return TCPStream(
+ return SocketStream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)
self.raise_on_write_timeout = True
-class BaseTCPStream:
+class BaseSocketStream:
"""
- A TCP stream with read/write operations. Abstracts away any asyncio-specific
+ A socket stream with read/write operations. Abstracts away any asyncio-specific
interfaces into a more generic base class, that we can use with alternate
backends, or for stand-alone test cases.
"""
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
- ) -> "BaseTCPStream":
+ ) -> "BaseSocketStream":
raise NotImplementedError() # pragma: no cover
async def read(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
- ) -> BaseTCPStream:
+ ) -> BaseSocketStream:
raise NotImplementedError() # pragma: no cover
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
BaseEvent,
BasePoolSemaphore,
BaseQueue,
- BaseTCPStream,
+ BaseSocketStream,
ConcurrencyBackend,
TimeoutFlag,
)
return value if value is not None else float("inf")
-class TCPStream(BaseTCPStream):
+class SocketStream(BaseSocketStream):
def __init__(
self,
stream: typing.Union[trio.SocketStream, trio.SSLStream],
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
- ) -> BaseTCPStream:
+ ) -> "SocketStream":
# Check that the write buffer is empty. We should never start a TLS stream
# while there is still pending data to write.
assert self.write_buffer == b""
if cancel_scope.cancelled_caught:
raise ConnectTimeout()
- return TCPStream(ssl_stream, self.timeout)
+ return SocketStream(ssl_stream, self.timeout)
def get_http_version(self) -> str:
if not isinstance(self.stream, trio.SSLStream):
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
- ) -> TCPStream:
+ ) -> SocketStream:
connect_timeout = _or_inf(timeout.connect_timeout)
with trio.move_on_after(connect_timeout) as cancel_scope:
if cancel_scope.cancelled_caught:
raise ConnectTimeout()
- return TCPStream(stream=stream, timeout=timeout)
+ return SocketStream(stream=stream, timeout=timeout)
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
import h11
-from ..concurrency.base import BaseTCPStream, ConcurrencyBackend, TimeoutFlag
+from ..concurrency.base import BaseSocketStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse
from ..utils import get_logger
def __init__(
self,
- stream: BaseTCPStream,
+ stream: BaseSocketStream,
backend: ConcurrencyBackend,
on_release: typing.Optional[OnReleaseCallback] = None,
):
import h2.events
from h2.settings import SettingCodes, Settings
-from ..concurrency.base import BaseEvent, BaseTCPStream, ConcurrencyBackend, TimeoutFlag
+from ..concurrency.base import (
+ BaseEvent,
+ BaseSocketStream,
+ ConcurrencyBackend,
+ TimeoutFlag,
+)
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import ProtocolError
from ..models import AsyncRequest, AsyncResponse
def __init__(
self,
- stream: BaseTCPStream,
+ stream: BaseSocketStream,
backend: ConcurrencyBackend,
on_release: typing.Callable = None,
):
stream = http_connection.stream
# If we need to start TLS again for the target server
- # we need to pull the TCP stream off the internal
+ # we need to pull the socket stream off the internal
# HTTP connection object and run start_tls()
if origin.is_ssl:
ssl_config = SSLConfig(cert=self.cert, verify=self.verify)
import h2.connection
import h2.events
-from httpx import AsyncioBackend, BaseTCPStream, Request, TimeoutConfig
+from httpx import AsyncioBackend, BaseSocketStream, Request, TimeoutConfig
from tests.concurrency import sleep
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
- ) -> BaseTCPStream:
+ ) -> BaseSocketStream:
self.server = MockHTTP2Server(self.app, backend=self.backend)
return self.server
return getattr(self.backend, name)
-class MockHTTP2Server(BaseTCPStream):
+class MockHTTP2Server(BaseSocketStream):
def __init__(self, app, backend):
config = h2.config.H2Configuration(client_side=False)
self.conn = h2.connection.H2Connection(config=config)
self.returning = {}
self.settings_changed = []
- # TCP stream interface
+ # Socket stream interface
def get_http_version(self) -> str:
return "HTTP/2"
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
- ) -> BaseTCPStream:
+ ) -> BaseSocketStream:
self.received_data.append(
b"--- CONNECT(%s, %d) ---" % (hostname.encode(), port)
)
return getattr(self.backend, name)
-class MockRawSocketStream(BaseTCPStream):
+class MockRawSocketStream(BaseSocketStream):
def __init__(self, backend: MockRawSocketBackend):
self.backend = backend
async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
- ) -> BaseTCPStream:
+ ) -> BaseSocketStream:
self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
return MockRawSocketStream(self.backend)