-import functools
import typing
import h2.connection
BaseEvent,
BaseSocketStream,
ConcurrencyBackend,
- TimeoutFlag,
lookup_backend,
)
from ..config import Timeout
self.socket = socket
self.backend = lookup_backend(backend)
self.on_release = on_release
- self.h2_state = h2.connection.H2Connection()
+ self.state = h2.connection.H2Connection()
+
+ self.streams = {} # type: typing.Dict[int, HTTP2Stream]
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
- self.timeout_flags = {} # type: typing.Dict[int, TimeoutFlag]
- self.window_update_received = {} # type: typing.Dict[int, BaseEvent]
self.init_started = False
# The very first stream is responsible for initiating the connection.
self.init_started = True
await self.send_connection_init(timeout)
+ stream_id = self.state.get_next_available_stream_id()
self.init_complete.set()
else:
# All other streams need to wait until the connection is established.
await self.init_complete.wait()
+ stream_id = self.state.get_next_available_stream_id()
- stream_id = await self.send_headers(request, timeout)
-
+ stream = HTTP2Stream(stream_id=stream_id, connection=self, state=self.state)
+ self.streams[stream_id] = stream
self.events[stream_id] = []
- self.timeout_flags[stream_id] = TimeoutFlag()
- self.window_update_received[stream_id] = self.backend.create_event()
-
- status_code: typing.Optional[int] = None
- headers: typing.Optional[list] = None
-
- async def receive_response(stream_id: int, timeout: Timeout) -> None:
- nonlocal status_code, headers
- status_code, headers = await self.receive_response(stream_id, timeout)
-
- await self.backend.fork(
- self.send_request_data,
- [stream_id, request.stream(), timeout],
- receive_response,
- [stream_id, timeout],
- )
-
- assert status_code is not None
- assert headers is not None
-
- content = self.body_iter(stream_id, timeout)
- on_close = functools.partial(self.response_closed, stream_id=stream_id)
-
- return Response(
- status_code=status_code,
- http_version="HTTP/2",
- headers=headers,
- content=content,
- on_close=on_close,
- request=request,
- )
-
- async def close(self) -> None:
- await self.socket.close()
+ return await stream.send(request, timeout)
async def send_connection_init(self, timeout: Timeout) -> None:
+ """
+ The HTTP/2 connection requires some initial setup before we can start
+ using individual request/response streams on it.
+ """
+
# Need to set these manually here instead of manipulating via
# __setitem__() otherwise the H2Connection will emit SettingsUpdate
# frames in addition to sending the undesired defaults.
- self.h2_state.local_settings = Settings(
+ self.state.local_settings = Settings(
client=True,
initial_values={
# Disable PUSH_PROMISE frames from the server since we don't do anything
# Some websites (*cough* Yahoo *cough*) balk at this setting being
# present in the initial handshake since it's not defined in the original
# RFC despite the RFC mandating ignoring settings you don't know about.
- del self.h2_state.local_settings[
- h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
- ]
+ del self.state.local_settings[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL]
- self.h2_state.initiate_connection()
- data_to_send = self.h2_state.data_to_send()
+ self.state.initiate_connection()
+ data_to_send = self.state.data_to_send()
await self.socket.write(data_to_send, timeout)
- async def send_headers(self, request: Request, timeout: Timeout) -> int:
- stream_id = self.h2_state.get_next_available_stream_id()
+ @property
+ def is_closed(self) -> bool:
+ return False
+
+ def is_connection_dropped(self) -> bool:
+ return self.socket.is_connection_dropped()
+
+ async def close(self) -> None:
+ await self.socket.close()
+
+ async def wait_for_outgoing_flow(self, stream_id: int, timeout: Timeout) -> int:
+ """
+ Returns the maximum allowable outgoing flow for a given stream.
+
+ If the allowable flow is zero, then waits on the network until
+ WindowUpdated frames have increased the flow rate.
+
+ https://tools.ietf.org/html/rfc7540#section-6.9
+ """
+ local_flow = self.state.local_flow_control_window(stream_id)
+ connection_flow = self.state.max_outbound_frame_size
+ flow = min(local_flow, connection_flow)
+ while flow == 0:
+ await self.receive_events(timeout)
+ local_flow = self.state.local_flow_control_window(stream_id)
+ connection_flow = self.state.max_outbound_frame_size
+ flow = min(local_flow, connection_flow)
+ return flow
+
+ async def wait_for_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event:
+ """
+ Returns the next event for a given stream.
+
+ If no events are available yet, then waits on the network until
+ an event is available.
+ """
+ while not self.events[stream_id]:
+ await self.receive_events(timeout)
+ return self.events[stream_id].pop(0)
+
+ async def receive_events(self, timeout: Timeout) -> None:
+ """
+ Read some data from the network, and update the H2 state.
+ """
+ data = await self.socket.read(self.READ_NUM_BYTES, timeout)
+ events = self.state.receive_data(data)
+ for event in events:
+ event_stream_id = getattr(event, "stream_id", 0)
+ logger.trace(f"receive_event stream_id={event_stream_id} event={event!r}")
+
+ if hasattr(event, "error_code"):
+ raise ProtocolError(event)
+
+ if event_stream_id in self.events:
+ self.events[event_stream_id].append(event)
+
+ data_to_send = self.state.data_to_send()
+ await self.socket.write(data_to_send, timeout)
+
+ async def send_outgoing_data(self, timeout: Timeout) -> None:
+ data_to_send = self.state.data_to_send()
+ if data_to_send:
+ await self.socket.write(data_to_send, timeout)
+
+ async def close_stream(self, stream_id: int) -> None:
+ del self.streams[stream_id]
+ del self.events[stream_id]
+
+ if not self.streams and self.on_release is not None:
+ await self.on_release()
+
+
+class HTTP2Stream:
+ def __init__(
+ self,
+ stream_id: int,
+ connection: HTTP2Connection,
+ state: h2.connection.H2Connection,
+ ) -> None:
+ self.stream_id = stream_id
+ self.connection = connection
+ self.state = state
+
+ async def send(self, request: Request, timeout: Timeout) -> Response:
+ # Send the request.
+ await self.send_headers(request, timeout)
+ await self.send_body(request, timeout)
+
+ # Receive the response.
+ status_code, headers = await self.receive_response(timeout)
+ content = self.body_iter(timeout)
+ return Response(
+ status_code=status_code,
+ http_version="HTTP/2",
+ headers=headers,
+ content=content,
+ on_close=self.close,
+ request=request,
+ )
+
+ async def send_headers(self, request: Request, timeout: Timeout) -> None:
headers = [
(b":method", request.method.encode("ascii")),
(b":authority", request.url.authority.encode("ascii")),
logger.trace(
f"send_headers "
- f"stream_id={stream_id} "
+ f"stream_id={self.stream_id} "
f"method={request.method!r} "
f"target={request.url.full_path!r} "
f"headers={headers!r}"
)
- self.h2_state.send_headers(stream_id, headers)
- data_to_send = self.h2_state.data_to_send()
- await self.socket.write(data_to_send, timeout)
- return stream_id
-
- async def send_request_data(
- self, stream_id: int, stream: typing.AsyncIterator[bytes], timeout: Timeout,
- ) -> None:
- try:
- async for data in stream:
- await self.send_data(stream_id, data, timeout)
- await self.end_stream(stream_id, timeout)
- finally:
- # Once we've sent the request we should enable read timeouts.
- self.timeout_flags[stream_id].set_read_timeouts()
-
- async def send_data(self, stream_id: int, data: bytes, timeout: Timeout) -> None:
- while data:
- # The data will be divided into frames to send based on the flow control
- # window and the maximum frame size. Because the flow control window
- # can decrease in size, even possibly to zero, this will loop until all the
- # data is sent. In http2 specification:
- # https://tools.ietf.org/html/rfc7540#section-6.9
- flow_control = self.h2_state.local_flow_control_window(stream_id)
- chunk_size = min(
- len(data), flow_control, self.h2_state.max_outbound_frame_size
- )
- if chunk_size == 0:
- # this means that the flow control window is 0 (either for the stream
- # or the connection one), and no data can be sent until the flow control
- # window is updated.
- await self.window_update_received[stream_id].wait()
- self.window_update_received[stream_id].clear()
- else:
+ self.state.send_headers(self.stream_id, headers)
+ await self.connection.send_outgoing_data(timeout)
+
+ async def send_body(self, request: Request, timeout: Timeout) -> None:
+ logger.trace(f"send_body stream_id={self.stream_id}")
+ async for data in request.stream():
+ while data:
+ max_flow = await self.connection.wait_for_outgoing_flow(
+ self.stream_id, timeout
+ )
+ chunk_size = min(len(data), max_flow)
chunk, data = data[:chunk_size], data[chunk_size:]
- self.h2_state.send_data(stream_id, chunk)
- data_to_send = self.h2_state.data_to_send()
- await self.socket.write(data_to_send, timeout)
-
- async def end_stream(self, stream_id: int, timeout: Timeout) -> None:
- logger.trace(f"end_stream stream_id={stream_id}")
- self.h2_state.end_stream(stream_id)
- data_to_send = self.h2_state.data_to_send()
- await self.socket.write(data_to_send, timeout)
+ self.state.send_data(self.stream_id, chunk)
+ await self.connection.send_outgoing_data(timeout)
+
+ self.state.end_stream(self.stream_id)
+ await self.connection.send_outgoing_data(timeout)
async def receive_response(
- self, stream_id: int, timeout: Timeout
+ self, timeout: Timeout
) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
"""
Read the response status and headers from the network.
"""
while True:
- event = await self.receive_event(stream_id, timeout)
- # As soon as we start seeing response events, we should enable
- # read timeouts, if we haven't already.
- self.timeout_flags[stream_id].set_read_timeouts()
+ event = await self.connection.wait_for_event(self.stream_id, timeout)
if isinstance(event, h2.events.ResponseReceived):
break
return (status_code, headers)
- async def body_iter(
- self, stream_id: int, timeout: Timeout
- ) -> typing.AsyncIterator[bytes]:
+ async def body_iter(self, timeout: Timeout) -> typing.AsyncIterator[bytes]:
while True:
- event = await self.receive_event(stream_id, timeout)
+ event = await self.connection.wait_for_event(self.stream_id, timeout)
if isinstance(event, h2.events.DataReceived):
- self.h2_state.acknowledge_received_data(
- event.flow_controlled_length, stream_id
+ self.state.acknowledge_received_data(
+ event.flow_controlled_length, self.stream_id
)
+ await self.connection.send_outgoing_data(timeout)
yield event.data
elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
break
- async def receive_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event:
- while not self.events[stream_id]:
- flag = self.timeout_flags[stream_id]
- data = await self.socket.read(self.READ_NUM_BYTES, timeout, flag=flag)
- events = self.h2_state.receive_data(data)
- for event in events:
- event_stream_id = getattr(event, "stream_id", 0)
- logger.trace(
- f"receive_event stream_id={event_stream_id} event={event!r}"
- )
-
- if hasattr(event, "error_code"):
- raise ProtocolError(event)
-
- if isinstance(event, h2.events.WindowUpdated):
- if event_stream_id == 0:
- for window_update_event in self.window_update_received.values():
- window_update_event.set()
- else:
- try:
- self.window_update_received[event_stream_id].set()
- except KeyError: # pragma: no cover
- # the window_update_received dictionary is only relevant
- # when sending data, which should never raise a KeyError
- # here.
- pass
-
- if event_stream_id:
- self.events[event.stream_id].append(event)
-
- data_to_send = self.h2_state.data_to_send()
- await self.socket.write(data_to_send, timeout)
-
- return self.events[stream_id].pop(0)
-
- async def response_closed(self, stream_id: int) -> None:
- del self.events[stream_id]
- del self.timeout_flags[stream_id]
- del self.window_update_received[stream_id]
-
- if not self.events and self.on_release is not None:
- await self.on_release()
-
- @property
- def is_closed(self) -> bool:
- return False
-
- def is_connection_dropped(self) -> bool:
- return self.socket.is_connection_dropped()
+ async def close(self) -> None:
+ await self.connection.close_stream(self.stream_id)