]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Ensure H2 state is only accessed by the connection, not per-stream. (#628)
authorTom Christie <tom@tomchristie.com>
Thu, 12 Dec 2019 10:40:12 +0000 (10:40 +0000)
committerGitHub <noreply@github.com>
Thu, 12 Dec 2019 10:40:12 +0000 (10:40 +0000)
* Ensure H2 state is only accessed by the connection, not per-stream

* Formatting tweak

httpx/dispatch/http2.py

index 3439b6135bfdc41a9d29daa07cf7d5a04a12beea..e0ddf2355cb012c38caa2da8ae2c76e382725034 100644 (file)
@@ -64,7 +64,7 @@ class HTTP2Connection(OpenConnection):
             await self.init_complete.wait()
             stream_id = self.state.get_next_available_stream_id()
 
-        stream = HTTP2Stream(stream_id=stream_id, connection=self, state=self.state)
+        stream = HTTP2Stream(stream_id=stream_id, connection=self)
         self.streams[stream_id] = stream
         self.events[stream_id] = []
         return await stream.send(request, timeout)
@@ -158,10 +158,32 @@ class HTTP2Connection(OpenConnection):
         data_to_send = self.state.data_to_send()
         await self.socket.write(data_to_send, timeout)
 
-    async def send_outgoing_data(self, timeout: Timeout) -> None:
+    async def send_headers(
+        self,
+        stream_id: int,
+        headers: typing.List[typing.Tuple[bytes, bytes]],
+        timeout: Timeout,
+    ) -> None:
+        self.state.send_headers(stream_id, headers)
+        data_to_send = self.state.data_to_send()
+        await self.socket.write(data_to_send, timeout)
+
+    async def send_data(self, stream_id: int, chunk: bytes, timeout: Timeout) -> None:
+        self.state.send_data(stream_id, chunk)
+        data_to_send = self.state.data_to_send()
+        await self.socket.write(data_to_send, timeout)
+
+    async def end_stream(self, stream_id: int, timeout: Timeout) -> None:
+        self.state.end_stream(stream_id)
         data_to_send = self.state.data_to_send()
-        if data_to_send:
-            await self.socket.write(data_to_send, timeout)
+        await self.socket.write(data_to_send, timeout)
+
+    async def acknowledge_received_data(
+        self, stream_id: int, amount: int, timeout: Timeout
+    ) -> None:
+        self.state.acknowledge_received_data(amount, stream_id)
+        data_to_send = self.state.data_to_send()
+        await self.socket.write(data_to_send, timeout)
 
     async def close_stream(self, stream_id: int) -> None:
         del self.streams[stream_id]
@@ -172,15 +194,9 @@ class HTTP2Connection(OpenConnection):
 
 
 class HTTP2Stream:
-    def __init__(
-        self,
-        stream_id: int,
-        connection: HTTP2Connection,
-        state: h2.connection.H2Connection,
-    ) -> None:
+    def __init__(self, stream_id: int, connection: HTTP2Connection) -> None:
         self.stream_id = stream_id
         self.connection = connection
-        self.state = state
 
     async def send(self, request: Request, timeout: Timeout) -> Response:
         # Send the request.
@@ -214,8 +230,7 @@ class HTTP2Stream:
             f"target={request.url.full_path!r} "
             f"headers={headers!r}"
         )
-        self.state.send_headers(self.stream_id, headers)
-        await self.connection.send_outgoing_data(timeout)
+        await self.connection.send_headers(self.stream_id, headers, timeout)
 
     async def send_body(self, request: Request, timeout: Timeout) -> None:
         logger.trace(f"send_body stream_id={self.stream_id}")
@@ -226,11 +241,9 @@ class HTTP2Stream:
                 )
                 chunk_size = min(len(data), max_flow)
                 chunk, data = data[:chunk_size], data[chunk_size:]
-                self.state.send_data(self.stream_id, chunk)
-                await self.connection.send_outgoing_data(timeout)
+                await self.connection.send_data(self.stream_id, chunk, timeout)
 
-        self.state.end_stream(self.stream_id)
-        await self.connection.send_outgoing_data(timeout)
+        await self.connection.end_stream(self.stream_id, timeout)
 
     async def receive_response(
         self, timeout: Timeout
@@ -257,10 +270,10 @@ class HTTP2Stream:
         while True:
             event = await self.connection.wait_for_event(self.stream_id, timeout)
             if isinstance(event, h2.events.DataReceived):
-                self.state.acknowledge_received_data(
-                    event.flow_controlled_length, self.stream_id
+                amount = event.flow_controlled_length
+                await self.connection.acknowledge_received_data(
+                    self.stream_id, amount, timeout
                 )
-                await self.connection.send_outgoing_data(timeout)
                 yield event.data
             elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
                 break