from .concurrency.base import (
BaseBackgroundManager,
BasePoolSemaphore,
- BaseReader,
- BaseWriter,
+ BaseStream,
ConcurrencyBackend,
)
from .config import (
"TooManyRedirects",
"WriteTimeout",
"AsyncDispatcher",
- "BaseReader",
- "BaseWriter",
+ "BaseStream",
"ConcurrencyBackend",
"Dispatcher",
"URL",
"""
-The `Reader` and `Writer` classes here provide a lightweight layer over
+The `Stream` class here provides a lightweight layer over
`asyncio.StreamReader` and `asyncio.StreamWriter`.
Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
import typing
from types import TracebackType
+from ..config import PoolLimits, TimeoutConfig
+from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseEvent,
BaseQueue,
- BaseReader,
- BaseWriter,
+ BaseStream,
ConcurrencyBackend,
TimeoutFlag,
)
-from ..config import PoolLimits, TimeoutConfig
-from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
SSL_MONKEY_PATCH_APPLIED = False
MonkeyPatch.write = _fixed_write
-class Reader(BaseReader):
+class Stream(BaseStream):
def __init__(
- self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
- ) -> None:
+ self,
+ stream_reader: asyncio.StreamReader,
+ stream_writer: asyncio.StreamWriter,
+ timeout: TimeoutConfig,
+ ):
self.stream_reader = stream_reader
+ self.stream_writer = stream_writer
self.timeout = timeout
async def read(
return data
- def is_connection_dropped(self) -> bool:
- return self.stream_reader.at_eof()
-
-
-class Writer(BaseWriter):
- def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
- self.stream_writer = stream_writer
- self.timeout = timeout
-
def write_no_block(self, data: bytes) -> None:
self.stream_writer.write(data) # pragma: nocover
if should_raise:
raise WriteTimeout() from None
+ def is_connection_dropped(self) -> bool:
+ return self.stream_reader.at_eof()
+
async def close(self) -> None:
self.stream_writer.close()
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
- ) -> typing.Tuple[BaseReader, BaseWriter, str]:
+ ) -> typing.Tuple[BaseStream, str]:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
if ident is None:
ident = ssl_object.selected_npn_protocol()
- reader = Reader(stream_reader=stream_reader, timeout=timeout)
- writer = Writer(stream_writer=stream_writer, timeout=timeout)
+ stream = Stream(
+ stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
+ )
http_version = "HTTP/2" if ident == "h2" else "HTTP/1.1"
- return reader, writer, http_version
+ return stream, http_version
async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
self.raise_on_write_timeout = True
-class BaseReader:
+class BaseStream:
"""
- A stream reader. Abstracts away any asyncio-specific interfaces
- into a more generic base class, that we can use with alternate
- backend, or for stand-alone test cases.
+ A stream with read/write operations. Abstracts away any asyncio-specific
+ interfaces into a more generic base class, that we can use with alternate
+ backends, or for stand-alone test cases.
"""
async def read(
) -> bytes:
raise NotImplementedError() # pragma: no cover
- def is_connection_dropped(self) -> bool:
- raise NotImplementedError() # pragma: no cover
-
-
-class BaseWriter:
- """
- A stream writer. Abstracts away any asyncio-specific interfaces
- into a more generic base class, that we can use with alternate
- backend, or for stand-alone test cases.
- """
-
def write_no_block(self, data: bytes) -> None:
raise NotImplementedError() # pragma: no cover
async def close(self) -> None:
raise NotImplementedError() # pragma: no cover
+ def is_connection_dropped(self) -> bool:
+ raise NotImplementedError() # pragma: no cover
+
class BaseQueue:
"""
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
- ) -> typing.Tuple[BaseReader, BaseWriter, str]:
+ ) -> typing.Tuple[BaseStream, str]:
raise NotImplementedError() # pragma: no cover
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
else:
on_release = functools.partial(self.release_func, self)
- reader, writer, http_version = await self.backend.connect(
+ stream, http_version = await self.backend.connect(
host, port, ssl_context, timeout
)
if http_version == "HTTP/2":
self.h2_connection = HTTP2Connection(
- reader, writer, self.backend, on_release=on_release
+ stream, self.backend, on_release=on_release
)
else:
assert http_version == "HTTP/1.1"
self.h11_connection = HTTP11Connection(
- reader, writer, self.backend, on_release=on_release
+ stream, self.backend, on_release=on_release
)
async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
import h11
-from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
+from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse
def __init__(
self,
- reader: BaseReader,
- writer: BaseWriter,
+ stream: BaseStream,
backend: ConcurrencyBackend,
on_release: typing.Optional[OnReleaseCallback] = None,
):
- self.reader = reader
- self.writer = writer
+ self.stream = stream
self.backend = backend
self.on_release = on_release
self.h11_state = h11.Connection(our_role=h11.CLIENT)
except h11.LocalProtocolError: # pragma: no cover
# Premature client disconnect
pass
- await self.writer.close()
+ await self.stream.close()
async def _send_request(
self, request: AsyncRequest, timeout: TimeoutConfig = None
drain before returning.
"""
bytes_to_send = self.h11_state.send(event)
- await self.writer.write(bytes_to_send, timeout)
+ await self.stream.write(bytes_to_send, timeout)
async def _receive_response(
self, timeout: TimeoutConfig = None
event = self.h11_state.next_event()
if event is h11.NEED_DATA:
try:
- data = await self.reader.read(
+ data = await self.stream.read(
self.READ_NUM_BYTES, timeout, flag=self.timeout_flag
)
except OSError: # pragma: nocover
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
def is_connection_dropped(self) -> bool:
- return self.reader.is_connection_dropped()
+ return self.stream.is_connection_dropped()
import h2.connection
import h2.events
-from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
+from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse
def __init__(
self,
- reader: BaseReader,
- writer: BaseWriter,
+ stream: BaseStream,
backend: ConcurrencyBackend,
on_release: typing.Callable = None,
):
- self.reader = reader
- self.writer = writer
+ self.stream = stream
self.backend = backend
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
)
async def close(self) -> None:
- await self.writer.close()
+ await self.stream.close()
def initiate_connection(self) -> None:
self.h2_state.initiate_connection()
data_to_send = self.h2_state.data_to_send()
- self.writer.write_no_block(data_to_send)
+ self.stream.write_no_block(data_to_send)
self.initialized = True
async def send_headers(
] + [(k, v) for k, v in request.headers.raw if k != b"host"]
self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
- await self.writer.write(data_to_send, timeout)
+ await self.stream.write(data_to_send, timeout)
return stream_id
async def send_request_data(
chunk = data[idx : idx + chunk_size]
self.h2_state.send_data(stream_id, chunk)
data_to_send = self.h2_state.data_to_send()
- await self.writer.write(data_to_send, timeout)
+ await self.stream.write(data_to_send, timeout)
async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None:
self.h2_state.end_stream(stream_id)
data_to_send = self.h2_state.data_to_send()
- await self.writer.write(data_to_send, timeout)
+ await self.stream.write(data_to_send, timeout)
async def receive_response(
self, stream_id: int, timeout: TimeoutConfig = None
) -> h2.events.Event:
while not self.events[stream_id]:
flag = self.timeout_flags[stream_id]
- data = await self.reader.read(self.READ_NUM_BYTES, timeout, flag=flag)
+ data = await self.stream.read(self.READ_NUM_BYTES, timeout, flag=flag)
events = self.h2_state.receive_data(data)
for event in events:
if getattr(event, "stream_id", 0):
self.events[event.stream_id].append(event)
data_to_send = self.h2_state.data_to_send()
- await self.writer.write(data_to_send, timeout)
+ await self.stream.write(data_to_send, timeout)
return self.events[stream_id].pop(0)
return False
def is_connection_dropped(self) -> bool:
- return self.reader.is_connection_dropped()
+ return self.stream.is_connection_dropped()
import h2.connection
import h2.events
-from httpx import AsyncioBackend, BaseReader, BaseWriter, Request, TimeoutConfig
+from httpx import AsyncioBackend, BaseStream, Request, TimeoutConfig
class MockHTTP2Backend(AsyncioBackend):
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
- ) -> typing.Tuple[BaseReader, BaseWriter, str]:
+ ) -> typing.Tuple[BaseStream, str]:
self.server = MockHTTP2Server(self.app)
- return self.server, self.server, "HTTP/2"
+ return self.server, "HTTP/2"
-class MockHTTP2Server(BaseReader, BaseWriter):
- """
- This class exposes Reader and Writer style interfaces.
- """
-
+class MockHTTP2Server(BaseStream):
def __init__(self, app):
config = h2.config.H2Configuration(client_side=False)
self.conn = h2.connection.H2Connection(config=config)
self.requests = {}
self.close_connection = False
- # BaseReader interface
+ # Stream interface
async def read(self, n, timeout, flag=None) -> bytes:
await asyncio.sleep(0)
send, self.buffer = self.buffer[:n], self.buffer[n:]
return send
- # BaseWriter interface
-
def write_no_block(self, data: bytes) -> None:
events = self.conn.receive_data(data)
self.buffer += self.conn.data_to_send()