From f57bb2f14217baec4e88c4f10941c57aeb627d8b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 7 Dec 2019 14:14:09 +0000 Subject: [PATCH] HTTP/2 refactoring (#612) * HTTP/2 refactoring * Clean up flow control * Remove extra blank line --- httpx/dispatch/http2.py | 297 +++++++++++++++++++--------------------- 1 file changed, 144 insertions(+), 153 deletions(-) diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index 471ba9b7..6021483f 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -1,4 +1,3 @@ -import functools import typing import h2.connection @@ -9,7 +8,6 @@ from ..concurrency.base import ( BaseEvent, BaseSocketStream, ConcurrencyBackend, - TimeoutFlag, lookup_backend, ) from ..config import Timeout @@ -32,10 +30,10 @@ class HTTP2Connection: self.socket = socket self.backend = lookup_backend(backend) self.on_release = on_release - self.h2_state = h2.connection.H2Connection() + self.state = h2.connection.H2Connection() + + self.streams = {} # type: typing.Dict[int, HTTP2Stream] self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]] - self.timeout_flags = {} # type: typing.Dict[int, TimeoutFlag] - self.window_update_received = {} # type: typing.Dict[int, BaseEvent] self.init_started = False @@ -54,54 +52,28 @@ class HTTP2Connection: # The very first stream is responsible for initiating the connection. self.init_started = True await self.send_connection_init(timeout) + stream_id = self.state.get_next_available_stream_id() self.init_complete.set() else: # All other streams need to wait until the connection is established. await self.init_complete.wait() + stream_id = self.state.get_next_available_stream_id() - stream_id = await self.send_headers(request, timeout) - + stream = HTTP2Stream(stream_id=stream_id, connection=self, state=self.state) + self.streams[stream_id] = stream self.events[stream_id] = [] - self.timeout_flags[stream_id] = TimeoutFlag() - self.window_update_received[stream_id] = self.backend.create_event() - - status_code: typing.Optional[int] = None - headers: typing.Optional[list] = None - - async def receive_response(stream_id: int, timeout: Timeout) -> None: - nonlocal status_code, headers - status_code, headers = await self.receive_response(stream_id, timeout) - - await self.backend.fork( - self.send_request_data, - [stream_id, request.stream(), timeout], - receive_response, - [stream_id, timeout], - ) - - assert status_code is not None - assert headers is not None - - content = self.body_iter(stream_id, timeout) - on_close = functools.partial(self.response_closed, stream_id=stream_id) - - return Response( - status_code=status_code, - http_version="HTTP/2", - headers=headers, - content=content, - on_close=on_close, - request=request, - ) - - async def close(self) -> None: - await self.socket.close() + return await stream.send(request, timeout) async def send_connection_init(self, timeout: Timeout) -> None: + """ + The HTTP/2 connection requires some initial setup before we can start + using individual request/response streams on it. + """ + # Need to set these manually here instead of manipulating via # __setitem__() otherwise the H2Connection will emit SettingsUpdate # frames in addition to sending the undesired defaults. - self.h2_state.local_settings = Settings( + self.state.local_settings = Settings( client=True, initial_values={ # Disable PUSH_PROMISE frames from the server since we don't do anything @@ -116,16 +88,113 @@ class HTTP2Connection: # Some websites (*cough* Yahoo *cough*) balk at this setting being # present in the initial handshake since it's not defined in the original # RFC despite the RFC mandating ignoring settings you don't know about. - del self.h2_state.local_settings[ - h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL - ] + del self.state.local_settings[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL] - self.h2_state.initiate_connection() - data_to_send = self.h2_state.data_to_send() + self.state.initiate_connection() + data_to_send = self.state.data_to_send() await self.socket.write(data_to_send, timeout) - async def send_headers(self, request: Request, timeout: Timeout) -> int: - stream_id = self.h2_state.get_next_available_stream_id() + @property + def is_closed(self) -> bool: + return False + + def is_connection_dropped(self) -> bool: + return self.socket.is_connection_dropped() + + async def close(self) -> None: + await self.socket.close() + + async def wait_for_outgoing_flow(self, stream_id: int, timeout: Timeout) -> int: + """ + Returns the maximum allowable outgoing flow for a given stream. + + If the allowable flow is zero, then waits on the network until + WindowUpdated frames have increased the flow rate. + + https://tools.ietf.org/html/rfc7540#section-6.9 + """ + local_flow = self.state.local_flow_control_window(stream_id) + connection_flow = self.state.max_outbound_frame_size + flow = min(local_flow, connection_flow) + while flow == 0: + await self.receive_events(timeout) + local_flow = self.state.local_flow_control_window(stream_id) + connection_flow = self.state.max_outbound_frame_size + flow = min(local_flow, connection_flow) + return flow + + async def wait_for_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event: + """ + Returns the next event for a given stream. + + If no events are available yet, then waits on the network until + an event is available. + """ + while not self.events[stream_id]: + await self.receive_events(timeout) + return self.events[stream_id].pop(0) + + async def receive_events(self, timeout: Timeout) -> None: + """ + Read some data from the network, and update the H2 state. + """ + data = await self.socket.read(self.READ_NUM_BYTES, timeout) + events = self.state.receive_data(data) + for event in events: + event_stream_id = getattr(event, "stream_id", 0) + logger.trace(f"receive_event stream_id={event_stream_id} event={event!r}") + + if hasattr(event, "error_code"): + raise ProtocolError(event) + + if event_stream_id in self.events: + self.events[event_stream_id].append(event) + + data_to_send = self.state.data_to_send() + await self.socket.write(data_to_send, timeout) + + async def send_outgoing_data(self, timeout: Timeout) -> None: + data_to_send = self.state.data_to_send() + if data_to_send: + await self.socket.write(data_to_send, timeout) + + async def close_stream(self, stream_id: int) -> None: + del self.streams[stream_id] + del self.events[stream_id] + + if not self.streams and self.on_release is not None: + await self.on_release() + + +class HTTP2Stream: + def __init__( + self, + stream_id: int, + connection: HTTP2Connection, + state: h2.connection.H2Connection, + ) -> None: + self.stream_id = stream_id + self.connection = connection + self.state = state + + async def send(self, request: Request, timeout: Timeout) -> Response: + # Send the request. + await self.send_headers(request, timeout) + await self.send_body(request, timeout) + + # Receive the response. + status_code, headers = await self.receive_response(timeout) + content = self.body_iter(timeout) + return Response( + status_code=status_code, + http_version="HTTP/2", + headers=headers, + content=content, + on_close=self.close, + request=request, + ) + + async def send_headers(self, request: Request, timeout: Timeout) -> None: headers = [ (b":method", request.method.encode("ascii")), (b":authority", request.url.authority.encode("ascii")), @@ -135,67 +204,37 @@ class HTTP2Connection: logger.trace( f"send_headers " - f"stream_id={stream_id} " + f"stream_id={self.stream_id} " f"method={request.method!r} " f"target={request.url.full_path!r} " f"headers={headers!r}" ) - self.h2_state.send_headers(stream_id, headers) - data_to_send = self.h2_state.data_to_send() - await self.socket.write(data_to_send, timeout) - return stream_id - - async def send_request_data( - self, stream_id: int, stream: typing.AsyncIterator[bytes], timeout: Timeout, - ) -> None: - try: - async for data in stream: - await self.send_data(stream_id, data, timeout) - await self.end_stream(stream_id, timeout) - finally: - # Once we've sent the request we should enable read timeouts. - self.timeout_flags[stream_id].set_read_timeouts() - - async def send_data(self, stream_id: int, data: bytes, timeout: Timeout) -> None: - while data: - # The data will be divided into frames to send based on the flow control - # window and the maximum frame size. Because the flow control window - # can decrease in size, even possibly to zero, this will loop until all the - # data is sent. In http2 specification: - # https://tools.ietf.org/html/rfc7540#section-6.9 - flow_control = self.h2_state.local_flow_control_window(stream_id) - chunk_size = min( - len(data), flow_control, self.h2_state.max_outbound_frame_size - ) - if chunk_size == 0: - # this means that the flow control window is 0 (either for the stream - # or the connection one), and no data can be sent until the flow control - # window is updated. - await self.window_update_received[stream_id].wait() - self.window_update_received[stream_id].clear() - else: + self.state.send_headers(self.stream_id, headers) + await self.connection.send_outgoing_data(timeout) + + async def send_body(self, request: Request, timeout: Timeout) -> None: + logger.trace(f"send_body stream_id={self.stream_id}") + async for data in request.stream(): + while data: + max_flow = await self.connection.wait_for_outgoing_flow( + self.stream_id, timeout + ) + chunk_size = min(len(data), max_flow) chunk, data = data[:chunk_size], data[chunk_size:] - self.h2_state.send_data(stream_id, chunk) - data_to_send = self.h2_state.data_to_send() - await self.socket.write(data_to_send, timeout) - - async def end_stream(self, stream_id: int, timeout: Timeout) -> None: - logger.trace(f"end_stream stream_id={stream_id}") - self.h2_state.end_stream(stream_id) - data_to_send = self.h2_state.data_to_send() - await self.socket.write(data_to_send, timeout) + self.state.send_data(self.stream_id, chunk) + await self.connection.send_outgoing_data(timeout) + + self.state.end_stream(self.stream_id) + await self.connection.send_outgoing_data(timeout) async def receive_response( - self, stream_id: int, timeout: Timeout + self, timeout: Timeout ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]: """ Read the response status and headers from the network. """ while True: - event = await self.receive_event(stream_id, timeout) - # As soon as we start seeing response events, we should enable - # read timeouts, if we haven't already. - self.timeout_flags[stream_id].set_read_timeouts() + event = await self.connection.wait_for_event(self.stream_id, timeout) if isinstance(event, h2.events.ResponseReceived): break @@ -209,65 +248,17 @@ class HTTP2Connection: return (status_code, headers) - async def body_iter( - self, stream_id: int, timeout: Timeout - ) -> typing.AsyncIterator[bytes]: + async def body_iter(self, timeout: Timeout) -> typing.AsyncIterator[bytes]: while True: - event = await self.receive_event(stream_id, timeout) + event = await self.connection.wait_for_event(self.stream_id, timeout) if isinstance(event, h2.events.DataReceived): - self.h2_state.acknowledge_received_data( - event.flow_controlled_length, stream_id + self.state.acknowledge_received_data( + event.flow_controlled_length, self.stream_id ) + await self.connection.send_outgoing_data(timeout) yield event.data elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)): break - async def receive_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event: - while not self.events[stream_id]: - flag = self.timeout_flags[stream_id] - data = await self.socket.read(self.READ_NUM_BYTES, timeout, flag=flag) - events = self.h2_state.receive_data(data) - for event in events: - event_stream_id = getattr(event, "stream_id", 0) - logger.trace( - f"receive_event stream_id={event_stream_id} event={event!r}" - ) - - if hasattr(event, "error_code"): - raise ProtocolError(event) - - if isinstance(event, h2.events.WindowUpdated): - if event_stream_id == 0: - for window_update_event in self.window_update_received.values(): - window_update_event.set() - else: - try: - self.window_update_received[event_stream_id].set() - except KeyError: # pragma: no cover - # the window_update_received dictionary is only relevant - # when sending data, which should never raise a KeyError - # here. - pass - - if event_stream_id: - self.events[event.stream_id].append(event) - - data_to_send = self.h2_state.data_to_send() - await self.socket.write(data_to_send, timeout) - - return self.events[stream_id].pop(0) - - async def response_closed(self, stream_id: int) -> None: - del self.events[stream_id] - del self.timeout_flags[stream_id] - del self.window_update_received[stream_id] - - if not self.events and self.on_release is not None: - await self.on_release() - - @property - def is_closed(self) -> bool: - return False - - def is_connection_dropped(self) -> bool: - return self.socket.is_connection_dropped() + async def close(self) -> None: + await self.connection.close_stream(self.stream_id) -- 2.47.3