git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Fix bug in calculating maximum frame size (#153)
author IsaacBenghiat <isaac_benghiat@brown.edu>
Fri, 6 Sep 2019 18:06:20 +0000 (14:06 -0400)
committer Seth Michael Larson <sethmichaellarson@gmail.com>
Fri, 6 Sep 2019 18:06:20 +0000 (13:06 -0500)
httpx/concurrency/base.py
httpx/dispatch/http2.py
tests/dispatch/test_http2.py
tests/dispatch/utils.py

diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py
index bf2aed4f1f9e2577baf3daa505c8f41eb94cb30b..e62253ddc4dc6193ddcaa98769f9a93120e79080 100644
@@ -88,6 +88,9 @@ class BaseEvent:
     def is_set(self) -> bool:
         raise NotImplementedError()  # pragma: no cover
 
+    def clear(self) -> None:
+        raise NotImplementedError()  # pragma: no cover
+
     async def wait(self) -> None:
         raise NotImplementedError()  # pragma: no cover
 
diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py
index bb83d489b65117ef2ebe23cce0fba79a8c143da1..786110c92044a631d057921930790d59c1c93e0b 100644
@@ -4,7 +4,7 @@ import typing
 import h2.connection
 import h2.events
 
-from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
+from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag, BaseEvent
 from ..config import TimeoutConfig, TimeoutTypes
 from ..models import AsyncRequest, AsyncResponse
 from ..utils import get_logger
@@ -28,6 +28,7 @@ class HTTP2Connection:
         self.events = {}  # type: typing.Dict[int, typing.List[h2.events.Event]]
         self.timeout_flags = {}  # type: typing.Dict[int, TimeoutFlag]
         self.initialized = False
+        self.window_update_received = {}  # type: typing.Dict[int, BaseEvent]
 
     async def send(
         self, request: AsyncRequest, timeout: TimeoutTypes = None
@@ -42,6 +43,7 @@ class HTTP2Connection:
 
         self.events[stream_id] = []
         self.timeout_flags[stream_id] = TimeoutFlag()
+        self.window_update_received[stream_id] = self.backend.create_event()
 
         task, args = self.send_request_data, [stream_id, request.stream(), timeout]
         async with self.backend.background_manager(task, *args):
@@ -108,18 +110,27 @@ class HTTP2Connection:
     async def send_data(
         self, stream_id: int, data: bytes, timeout: TimeoutConfig = None
     ) -> None:
-        flow_control = self.h2_state.local_flow_control_window(stream_id)
-        chunk_size = min(len(data), flow_control)
-        for idx in range(0, len(data), chunk_size):
-            chunk = data[idx : idx + chunk_size]
-
-            logger.debug(
-                f"send_data stream_id={stream_id} data=Data(<{len(chunk)} bytes>)"
+        while data:
+            # The data is divided into frames based on the flow control window
+            # and the maximum frame size. Because the flow control window can
+            # shrink, possibly all the way to zero, this loops until all the
+            # data has been sent. See the HTTP/2 specification:
+            # https://tools.ietf.org/html/rfc7540#section-6.9
+            flow_control = self.h2_state.local_flow_control_window(stream_id)
+            chunk_size = min(
+                len(data), flow_control, self.h2_state.max_outbound_frame_size
             )
-
-            self.h2_state.send_data(stream_id, chunk)
-            data_to_send = self.h2_state.data_to_send()
-            await self.stream.write(data_to_send, timeout)
+            if chunk_size == 0:
+                # The flow control window is 0 (either for this stream or for
+                # the connection as a whole), so no data can be sent until the
+                # window is updated.
+                await self.window_update_received[stream_id].wait()
+                self.window_update_received[stream_id].clear()
+            else:
+                chunk, data = data[:chunk_size], data[chunk_size:]
+                self.h2_state.send_data(stream_id, chunk)
+                data_to_send = self.h2_state.data_to_send()
+                await self.stream.write(data_to_send, timeout)
 
     async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None:
         logger.debug(f"end_stream stream_id={stream_id}")
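The heart of the fix is the chunk-size calculation inside the new while loop. The old code bounded each chunk only by the flow control window, so a chunk could exceed the peer's SETTINGS_MAX_FRAME_SIZE (which h2 will refuse to send), and an exhausted window produced range(0, len(data), 0), which raises ValueError. A standalone sketch of the corrected bound (illustration only; the function name is mine, not httpx's):

    def next_chunk_size(remaining: int, flow_control: int, max_frame_size: int) -> int:
        # Each DATA frame is bounded by the data left to send, the current
        # flow control window, and the peer's advertised maximum frame size.
        return min(remaining, flow_control, max_frame_size)

    # 100 kB left to send, a 65535-byte window, the default 16384-byte frame size:
    assert next_chunk_size(100_000, 65_535, 16_384) == 16_384

    # An exhausted window now yields 0, which send_data() treats as
    # "wait for a WINDOW_UPDATE" instead of looping with a zero step.
    assert next_chunk_size(100_000, 0, 16_384) == 0
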
@@ -148,7 +159,8 @@ class HTTP2Connection:
                 status_code = int(v.decode("ascii", errors="ignore"))
             elif not k.startswith(b":"):
                 headers.append((k, v))
-        return status_code, headers
+
+        return (status_code, headers)
 
     async def body_iter(
         self, stream_id: int, timeout: TimeoutConfig = None
@@ -156,7 +168,9 @@ class HTTP2Connection:
         while True:
             event = await self.receive_event(stream_id, timeout)
             if isinstance(event, h2.events.DataReceived):
-                self.h2_state.acknowledge_received_data(len(event.data), stream_id)
+                self.h2_state.acknowledge_received_data(
+                    event.flow_controlled_length, stream_id
+                )
                 yield event.data
             elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
                 break
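Switching from len(event.data) to event.flow_controlled_length matters for padded DATA frames: the padding and the one-byte Pad Length field count against the receive window even though they carry no payload, so acknowledging only the payload gradually leaks window and can eventually stall the peer. A worked illustration with made-up numbers:

    payload = b"hello"   # 5 bytes of application data
    pad_length = 10      # hypothetical padding chosen by the peer

    # Per RFC 7540, the whole DATA frame payload is flow controlled,
    # including the padding and the Pad Length octet.
    flow_controlled_length = len(payload) + pad_length + 1

    # Acknowledging only len(payload) would leave 11 bytes of window
    # unreturned for every padded frame received.
    assert flow_controlled_length - len(payload) == pad_length + 1
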
@@ -173,6 +187,19 @@ class HTTP2Connection:
                 logger.debug(
                     f"receive_event stream_id={event_stream_id} event={event!r}"
                 )
+                if isinstance(event, h2.events.WindowUpdated):
+                    if event_stream_id == 0:
+                        for window_update_event in self.window_update_received.values():
+                            window_update_event.set()
+                    else:
+                        try:
+                            self.window_update_received[event_stream_id].set()
+                        except KeyError:
+                            # The window_update_received dictionary only tracks
+                            # streams that are currently sending data, so an
+                            # update for an untracked stream is safely ignored.
+                            pass
+
                 if event_stream_id:
                     self.events[event.stream_id].append(event)
 
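This hunk and send_data() above form a wait/set/clear handshake: the sender parks on its per-stream event when the window hits zero, receive_event() sets the event when a WindowUpdated arrives (all of them when the update is connection-wide, i.e. stream 0), and the sender clears the event before re-checking the window so a stale signal cannot make it spin. A minimal sketch of that pattern in plain asyncio (not httpx's backend abstraction; all names here are illustrative):

    import asyncio

    async def sender(window: dict, window_updated: asyncio.Event) -> None:
        pending = 3  # pretend three chunks remain to be sent
        while pending:
            if window["size"] == 0:
                await window_updated.wait()  # park until a WINDOW_UPDATE arrives
                window_updated.clear()       # then re-check the window
            else:
                window["size"] -= 1          # pretend one chunk was sent
                pending -= 1

    async def receiver(window: dict, window_updated: asyncio.Event) -> None:
        await asyncio.sleep(0.01)            # pretend the peer sent WINDOW_UPDATE
        window["size"] = 3                   # the window is open again
        window_updated.set()                 # wake the parked sender

    async def main() -> None:
        window = {"size": 0}
        updated = asyncio.Event()
        await asyncio.gather(sender(window, updated), receiver(window, updated))

    asyncio.run(main())
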
@@ -184,6 +211,7 @@ class HTTP2Connection:
     async def response_closed(self, stream_id: int) -> None:
         del self.events[stream_id]
         del self.timeout_flags[stream_id]
+        del self.window_update_received[stream_id]
 
         if not self.events and self.on_release is not None:
             await self.on_release()
diff --git a/tests/dispatch/test_http2.py b/tests/dispatch/test_http2.py
index 64bbb6a0aae95e12ba01a60afa58ec726de692fb..c286592e043cc6a66f4922c95bf70b223a9f30ee 100644
@@ -65,6 +65,34 @@ async def test_async_http2_post_request(backend):
     }
 
 
+def test_http2_large_post_request():
+    backend = MockHTTP2Backend(app=app)
+
+    data = b"a" * 100000
+    with Client(backend=backend) as client:
+        response = client.post("http://example.org", data=data)
+    assert response.status_code == 200
+    assert json.loads(response.content) == {
+        "method": "POST",
+        "path": "/",
+        "body": data.decode(),
+    }
+
+
+async def test_async_http2_large_post_request(backend):
+    backend = MockHTTP2Backend(app=app, backend=backend)
+
+    data = b"a" * 100000
+    async with AsyncClient(backend=backend) as client:
+        response = await client.post("http://example.org", data=data)
+    assert response.status_code == 200
+    assert json.loads(response.content) == {
+        "method": "POST",
+        "path": "/",
+        "body": data.decode(),
+    }
+
+
 def test_http2_multiple_requests():
     backend = MockHTTP2Backend(app=app)
 
diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py
index b86552dc6b4da37c577c0558fc858a5ffde10e94..a9ab231712699f2dc87d38dc2672c5013fb72d8e 100644
@@ -39,6 +39,8 @@ class MockHTTP2Server(BaseStream):
         self.buffer = b""
         self.requests = {}
         self.close_connection = False
+        self.return_data = {}
+        self.returning = {}
 
     # Stream interface
 
@@ -58,8 +60,27 @@ class MockHTTP2Server(BaseStream):
                 self.request_received(event.headers, event.stream_id)
             elif isinstance(event, h2.events.DataReceived):
                 self.receive_data(event.data, event.stream_id)
+                # Send a WINDOW_UPDATE for both the connection and the stream,
+                # incrementing each by the amount consumed so that the flow
+                # control window stays constant.
+                flow_control_consumed = event.flow_controlled_length
+                if flow_control_consumed > 0:
+                    self.conn.increment_flow_control_window(flow_control_consumed)
+                    self.buffer += self.conn.data_to_send()
+                    self.conn.increment_flow_control_window(
+                        flow_control_consumed, event.stream_id
+                    )
+                    self.buffer += self.conn.data_to_send()
             elif isinstance(event, h2.events.StreamEnded):
                 self.stream_complete(event.stream_id)
+            elif isinstance(event, h2.events.WindowUpdated):
+                if event.stream_id == 0:
+                    for key, value in self.returning.items():
+                        if value:
+                            self.send_return_data(key)
+                # This will raise a KeyError if the event is for a not-yet-created stream
+                elif self.returning[event.stream_id]:
+                    self.send_return_data(event.stream_id)
 
     async def write(self, data: bytes, timeout) -> None:
         self.write_no_block(data)
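The two increment_flow_control_window() calls above credit back exactly what each DATA frame consumed, once for the connection (no stream_id argument) and once for the stream, so the mock's receive windows stay constant. A real server built on h2 would more commonly let acknowledge_received_data() do the same bookkeeping; a hedged sketch of that equivalent handler (names are illustrative):

    import h2.connection
    import h2.events

    def on_data_received(
        conn: h2.connection.H2Connection, event: h2.events.DataReceived
    ) -> None:
        # acknowledge_received_data() credits the consumed bytes back to both
        # the connection-level and stream-level windows; h2 may batch the
        # resulting WINDOW_UPDATE frames rather than emitting one per call.
        if event.flow_controlled_length > 0:
            conn.acknowledge_received_data(
                event.flow_controlled_length, event.stream_id
            )
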
@@ -114,5 +135,28 @@ class MockHTTP2Server(BaseStream):
         response_headers = [(b":status", status_code_bytes)] + response.headers.raw
 
         self.conn.send_headers(stream_id, response_headers)
-        self.conn.send_data(stream_id, response.content, end_stream=True)
+        self.buffer += self.conn.data_to_send()
+        self.return_data[stream_id] = response.content
+        self.returning[stream_id] = True
+        self.send_return_data(stream_id)
+
+    def send_return_data(self, stream_id):
+        while self.return_data[stream_id]:
+            flow_control = self.conn.local_flow_control_window(stream_id)
+            chunk_size = min(
+                len(self.return_data[stream_id]),
+                flow_control,
+                self.conn.max_outbound_frame_size,
+            )
+            if chunk_size == 0:
+                return
+            else:
+                chunk, self.return_data[stream_id] = (
+                    self.return_data[stream_id][:chunk_size],
+                    self.return_data[stream_id][chunk_size:],
+                )
+                self.conn.send_data(stream_id, chunk)
+                self.buffer += self.conn.data_to_send()
+        self.returning[stream_id] = False
+        self.conn.end_stream(stream_id)
         self.buffer += self.conn.data_to_send()