async_dispatch = dispatch
if base_url is None:
- self.base_url = URL('', allow_relative=True)
+ self.base_url = URL("", allow_relative=True)
else:
self.base_url = URL(base_url)
import functools
import ssl
import typing
+from types import TracebackType
from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .interfaces import (
+ BaseBackgroundManager,
BasePoolSemaphore,
BaseReader,
BaseWriter,
_write = MonkeyPatch.write
def _fixed_write(self, data: bytes) -> None: # type: ignore
- if not self._loop.is_closed():
+ if self._loop and not self._loop.is_closed():
_write(self, data)
MonkeyPatch.write = _fixed_write
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return PoolSemaphore(limits)
+
+ def background_manager(
+ self, coroutine: typing.Callable, args: typing.Any
+ ) -> "BackgroundManager":
+ return BackgroundManager(coroutine, args)
+
+
+class BackgroundManager(BaseBackgroundManager):
+ def __init__(self, coroutine: typing.Callable, args: typing.Any) -> None:
+ self.coroutine = coroutine
+ self.args = args
+
+ async def __aenter__(self) -> "BackgroundManager":
+ loop = asyncio.get_event_loop()
+ self.task = loop.create_task(self.coroutine(*self.args))
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: typing.Type[BaseException] = None,
+ exc_value: BaseException = None,
+ traceback: TracebackType = None,
+ ) -> None:
+ await self.task
+ if exc_type is None:
+ self.task.result()
host, port, ssl_context, timeout
)
if protocol == Protocol.HTTP_2:
- self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
+ self.h2_connection = HTTP2Connection(
+ reader, writer, self.backend, on_release=on_release
+ )
else:
self.h11_connection = HTTP11Connection(
- reader, writer, on_release=on_release
+ reader, writer, self.backend, on_release=on_release
)
async def close(self) -> None:
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectTimeout, ReadTimeout
-from ..interfaces import BaseReader, BaseWriter
+from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse
H11Event = typing.Union[
self,
reader: BaseReader,
writer: BaseWriter,
+ backend: ConcurrencyBackend,
on_release: typing.Optional[OnReleaseCallback] = None,
):
self.reader = reader
self.writer = writer
+ self.backend = backend
self.on_release = on_release
self.h11_state = h11.Connection(our_role=h11.CLIENT)
) -> AsyncResponse:
timeout = None if timeout is None else TimeoutConfig(timeout)
- # Start sending the request.
- method = request.method.encode("ascii")
- target = request.url.full_path.encode("ascii")
- headers = request.headers.raw
- if "Host" not in request.headers:
- host = request.url.authority.encode("ascii")
- headers = [(b"host", host)] + headers
- event = h11.Request(method=method, target=target, headers=headers)
- await self._send_event(event, timeout)
-
- # Send the request body.
- async for data in request.stream():
- event = h11.Data(data=data)
- await self._send_event(event, timeout)
-
- # Finalize sending the request.
- event = h11.EndOfMessage()
- await self._send_event(event, timeout)
-
- # Start getting the response.
- event = await self._receive_event(timeout)
- if isinstance(event, h11.InformationalResponse):
- event = await self._receive_event(timeout)
-
- assert isinstance(event, h11.Response)
- status_code = event.status_code
- headers = event.headers
- content = self._body_iter(timeout)
+ await self._send_request(request, timeout)
+ task, args = self._send_request_data, [request.stream(), timeout]
+ async with self.backend.background_manager(task, args=args):
+ status_code, headers = await self._receive_response(timeout)
+ content = self._receive_response_data(timeout)
return AsyncResponse(
status_code=status_code,
async def close(self) -> None:
event = h11.ConnectionClosed()
- self.h11_state.send(event)
+ try:
+ self.h11_state.send(event)
+ except h11.LocalProtocolError as exc: # pragma: no cover
+ # Premature client disconnect
+ pass
await self.writer.close()
- async def _body_iter(
+ async def _send_request(
+ self, request: AsyncRequest, timeout: TimeoutConfig = None
+ ) -> None:
+ """
+ Send the request method, URL, and headers to the network.
+ """
+ method = request.method.encode("ascii")
+ target = request.url.full_path.encode("ascii")
+ headers = request.headers.raw
+ if "Host" not in request.headers:
+ host = request.url.authority.encode("ascii")
+ headers = [(b"host", host)] + headers
+ event = h11.Request(method=method, target=target, headers=headers)
+ await self._send_event(event, timeout)
+
+ async def _send_request_data(
+ self, data: typing.AsyncIterator[bytes], timeout: TimeoutConfig = None
+ ) -> None:
+ """
+ Send the request body to the network.
+ """
+ try:
+ # Send the request body.
+ async for chunk in data:
+ event = h11.Data(data=chunk)
+ await self._send_event(event, timeout)
+
+ # Finalize sending the request.
+ event = h11.EndOfMessage()
+ await self._send_event(event, timeout)
+ except OSError: # pragma: nocover
+ # Once we've sent the initial part of the request we don't actually
+ # care about connection errors that occur when sending the body.
+ # Ignore these, and defer to any exceptions on reading the response.
+ self.h11_state.send_failed()
+
+ async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> None:
+ """
+ Send a single `h11` event to the network, waiting for the data to
+ drain before returning.
+ """
+ bytes_to_send = self.h11_state.send(event)
+ await self.writer.write(bytes_to_send, timeout)
+
+ async def _receive_response(
+ self, timeout: TimeoutConfig = None
+ ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
+ """
+ Read the response status and headers from the network.
+ """
+ while True:
+ event = await self._receive_event(timeout)
+ if isinstance(event, h11.InformationalResponse):
+ continue
+ else:
+ assert isinstance(event, h11.Response)
+ break
+ return (event.status_code, event.headers)
+
+ async def _receive_response_data(
self, timeout: TimeoutConfig = None
) -> typing.AsyncIterator[bytes]:
- event = await self._receive_event(timeout)
- while isinstance(event, h11.Data):
- yield event.data
+ """
+ Read the response data from the network.
+ """
+ while True:
event = await self._receive_event(timeout)
- assert isinstance(event, h11.EndOfMessage)
-
- async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> None:
- data = self.h11_state.send(event)
- await self.writer.write(data, timeout)
+ if isinstance(event, h11.Data):
+ yield event.data
+ else:
+ assert isinstance(event, h11.EndOfMessage)
+ break
async def _receive_event(self, timeout: TimeoutConfig = None) -> H11Event:
- event = self.h11_state.next_event()
-
- while event is h11.NEED_DATA:
- data = await self.reader.read(self.READ_NUM_BYTES, timeout)
- self.h11_state.receive_data(data)
+ """
+ Read a single `h11` event, reading more data from the network if needed.
+ """
+ while True:
event = self.h11_state.next_event()
-
+ if event is h11.NEED_DATA:
+ try:
+ data = await self.reader.read(self.READ_NUM_BYTES, timeout)
+ except OSError: # pragma: nocover
+ data = b""
+ self.h11_state.receive_data(data)
+ else:
+ break
return event
async def response_closed(self) -> None:
from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
from ..exceptions import ConnectTimeout, ReadTimeout
-from ..interfaces import BaseReader, BaseWriter
+from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse
READ_NUM_BYTES = 4096
def __init__(
- self, reader: BaseReader, writer: BaseWriter, on_release: typing.Callable = None
+ self,
+ reader: BaseReader,
+ writer: BaseWriter,
+ backend: ConcurrencyBackend,
+ on_release: typing.Callable = None,
):
self.reader = reader
self.writer = writer
+ self.backend = backend
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
stream_id = await self.send_headers(request, timeout)
self.events[stream_id] = []
- # Send the request body.
- async for data in request.stream():
- await self.send_data(stream_id, data, timeout)
-
- # Finalize sending the request.
- await self.end_stream(stream_id, timeout)
-
- # Start getting the response.
- while True:
- event = await self.receive_event(stream_id, timeout)
- if isinstance(event, h2.events.ResponseReceived):
- break
-
- status_code = 200
- headers = []
- for k, v in event.headers:
- if k == b":status":
- status_code = int(v.decode("ascii", errors="ignore"))
- elif not k.startswith(b":"):
- headers.append((k, v))
-
+ task, args = self.send_request_data, [stream_id, request.stream(), timeout]
+ async with self.backend.background_manager(task, args=args):
+ status_code, headers = await self.receive_response(stream_id, timeout)
content = self.body_iter(stream_id, timeout)
on_close = functools.partial(self.response_closed, stream_id=stream_id)
await self.writer.write(data_to_send, timeout)
return stream_id
+ async def send_request_data(
+ self,
+ stream_id: int,
+ stream: typing.AsyncIterator[bytes],
+ timeout: TimeoutConfig = None,
+ ) -> None:
+ async for data in stream:
+ await self.send_data(stream_id, data, timeout)
+ await self.end_stream(stream_id, timeout)
+
async def send_data(
self, stream_id: int, data: bytes, timeout: TimeoutConfig = None
) -> None:
flow_control = self.h2_state.local_flow_control_window(stream_id)
chunk_size = min(len(data), flow_control)
for idx in range(0, len(data), chunk_size):
- chunk = data[idx:idx+chunk_size]
+ chunk = data[idx : idx + chunk_size]
self.h2_state.send_data(stream_id, chunk)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
+ async def receive_response(
+ self, stream_id: int, timeout: TimeoutConfig = None
+ ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
+ """
+ Read the response status and headers from the network.
+ """
+ while True:
+ event = await self.receive_event(stream_id, timeout)
+ if isinstance(event, h2.events.ResponseReceived):
+ break
+
+ status_code = 200
+ headers = []
+ for k, v in event.headers:
+ if k == b":status":
+ status_code = int(v.decode("ascii", errors="ignore"))
+ elif not k.startswith(b":"):
+ headers.append((k, v))
+ return (status_code, headers)
+
async def body_iter(
self, stream_id: int, timeout: TimeoutConfig = None
) -> typing.AsyncIterator[bytes]:
yield self.run(async_iterator.__anext__)
except StopAsyncIteration:
break
+
+ def background_manager(
+ self, coroutine: typing.Callable, args: typing.Any
+ ) -> "BaseBackgroundManager":
+ raise NotImplementedError() # pragma: no cover
+
+
+class BaseBackgroundManager:
+ async def __aenter__(self) -> "BaseBackgroundManager":
+ raise NotImplementedError() # pragma: no cover
+
+ async def __aexit__(
+ self,
+ exc_type: typing.Type[BaseException] = None,
+ exc_value: BaseException = None,
+ traceback: TracebackType = None,
+ ) -> None:
+ raise NotImplementedError() # pragma: no cover
def test_base_url(server):
base_url = "http://127.0.0.1:8000/"
with http3.Client(base_url=base_url) as http:
- response = http.get('/')
+ response = http.get("/")
assert response.status_code == 200
assert str(response.url) == base_url
+import asyncio
import ssl
import typing
# BaseReader interface
async def read(self, n, timeout) -> bytes:
+ await asyncio.sleep(0)
send, self.buffer = self.buffer[:n], self.buffer[n:]
return send