From: Tom Christie Date: Thu, 12 Dec 2019 10:40:12 +0000 (+0000) Subject: Ensure H2 state is only accessed by the connection, not per-stream. (#628) X-Git-Tag: 0.9.4~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1d25bd58a86d725754c4fe1ae4fc4cd3c22b9b57;p=thirdparty%2Fhttpx.git Ensure H2 state is only accessed by the connection, not per-stream. (#628) * Ensure H2 state is only accessed by the connection, not per-stream * Formatting tweak --- diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index 3439b613..e0ddf235 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -64,7 +64,7 @@ class HTTP2Connection(OpenConnection): await self.init_complete.wait() stream_id = self.state.get_next_available_stream_id() - stream = HTTP2Stream(stream_id=stream_id, connection=self, state=self.state) + stream = HTTP2Stream(stream_id=stream_id, connection=self) self.streams[stream_id] = stream self.events[stream_id] = [] return await stream.send(request, timeout) @@ -158,10 +158,32 @@ class HTTP2Connection(OpenConnection): data_to_send = self.state.data_to_send() await self.socket.write(data_to_send, timeout) - async def send_outgoing_data(self, timeout: Timeout) -> None: + async def send_headers( + self, + stream_id: int, + headers: typing.List[typing.Tuple[bytes, bytes]], + timeout: Timeout, + ) -> None: + self.state.send_headers(stream_id, headers) + data_to_send = self.state.data_to_send() + await self.socket.write(data_to_send, timeout) + + async def send_data(self, stream_id: int, chunk: bytes, timeout: Timeout) -> None: + self.state.send_data(stream_id, chunk) + data_to_send = self.state.data_to_send() + await self.socket.write(data_to_send, timeout) + + async def end_stream(self, stream_id: int, timeout: Timeout) -> None: + self.state.end_stream(stream_id) data_to_send = self.state.data_to_send() - if data_to_send: - await self.socket.write(data_to_send, timeout) + await self.socket.write(data_to_send, timeout) + + async def acknowledge_received_data( + self, stream_id: int, amount: int, timeout: Timeout + ) -> None: + self.state.acknowledge_received_data(amount, stream_id) + data_to_send = self.state.data_to_send() + await self.socket.write(data_to_send, timeout) async def close_stream(self, stream_id: int) -> None: del self.streams[stream_id] @@ -172,15 +194,9 @@ class HTTP2Connection(OpenConnection): class HTTP2Stream: - def __init__( - self, - stream_id: int, - connection: HTTP2Connection, - state: h2.connection.H2Connection, - ) -> None: + def __init__(self, stream_id: int, connection: HTTP2Connection) -> None: self.stream_id = stream_id self.connection = connection - self.state = state async def send(self, request: Request, timeout: Timeout) -> Response: # Send the request. @@ -214,8 +230,7 @@ class HTTP2Stream: f"target={request.url.full_path!r} " f"headers={headers!r}" ) - self.state.send_headers(self.stream_id, headers) - await self.connection.send_outgoing_data(timeout) + await self.connection.send_headers(self.stream_id, headers, timeout) async def send_body(self, request: Request, timeout: Timeout) -> None: logger.trace(f"send_body stream_id={self.stream_id}") @@ -226,11 +241,9 @@ class HTTP2Stream: ) chunk_size = min(len(data), max_flow) chunk, data = data[:chunk_size], data[chunk_size:] - self.state.send_data(self.stream_id, chunk) - await self.connection.send_outgoing_data(timeout) + await self.connection.send_data(self.stream_id, chunk, timeout) - self.state.end_stream(self.stream_id) - await self.connection.send_outgoing_data(timeout) + await self.connection.end_stream(self.stream_id, timeout) async def receive_response( self, timeout: Timeout @@ -257,10 +270,10 @@ class HTTP2Stream: while True: event = await self.connection.wait_for_event(self.stream_id, timeout) if isinstance(event, h2.events.DataReceived): - self.state.acknowledge_received_data( - event.flow_controlled_length, self.stream_id + amount = event.flow_controlled_length + await self.connection.acknowledge_received_data( + self.stream_id, amount, timeout ) - await self.connection.send_outgoing_data(timeout) yield event.data elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)): break