git.ipfire.org Git - thirdparty/httpx.git/commitdiff
HTTP/2 refactoring (#612)
author Tom Christie <tom@tomchristie.com>
Sat, 7 Dec 2019 14:14:09 +0000 (14:14 +0000)
committer GitHub <noreply@github.com>
Sat, 7 Dec 2019 14:14:09 +0000 (14:14 +0000)
* HTTP/2 refactoring

* Clean up flow control

* Remove extra blank line

httpx/dispatch/http2.py

index 471ba9b7e012bb03f122da3fdcd51d3d71cf7bec..6021483feb85a2ea3d2d128d8adbbaf53b070b5f 100644 (file)
@@ -1,4 +1,3 @@
-import functools
 import typing
 
 import h2.connection
@@ -9,7 +8,6 @@ from ..concurrency.base import (
     BaseEvent,
     BaseSocketStream,
     ConcurrencyBackend,
-    TimeoutFlag,
     lookup_backend,
 )
 from ..config import Timeout
@@ -32,10 +30,10 @@ class HTTP2Connection:
         self.socket = socket
         self.backend = lookup_backend(backend)
         self.on_release = on_release
-        self.h2_state = h2.connection.H2Connection()
+        self.state = h2.connection.H2Connection()
+
+        self.streams = {}  # type: typing.Dict[int, HTTP2Stream]
         self.events = {}  # type: typing.Dict[int, typing.List[h2.events.Event]]
-        self.timeout_flags = {}  # type: typing.Dict[int, TimeoutFlag]
-        self.window_update_received = {}  # type: typing.Dict[int, BaseEvent]
 
         self.init_started = False
 
@@ -54,54 +52,28 @@ class HTTP2Connection:
             # The very first stream is responsible for initiating the connection.
             self.init_started = True
             await self.send_connection_init(timeout)
+            stream_id = self.state.get_next_available_stream_id()
             self.init_complete.set()
         else:
             # All other streams need to wait until the connection is established.
             await self.init_complete.wait()
+            stream_id = self.state.get_next_available_stream_id()
 
-        stream_id = await self.send_headers(request, timeout)
-
+        stream = HTTP2Stream(stream_id=stream_id, connection=self, state=self.state)
+        self.streams[stream_id] = stream
         self.events[stream_id] = []
-        self.timeout_flags[stream_id] = TimeoutFlag()
-        self.window_update_received[stream_id] = self.backend.create_event()
-
-        status_code: typing.Optional[int] = None
-        headers: typing.Optional[list] = None
-
-        async def receive_response(stream_id: int, timeout: Timeout) -> None:
-            nonlocal status_code, headers
-            status_code, headers = await self.receive_response(stream_id, timeout)
-
-        await self.backend.fork(
-            self.send_request_data,
-            [stream_id, request.stream(), timeout],
-            receive_response,
-            [stream_id, timeout],
-        )
-
-        assert status_code is not None
-        assert headers is not None
-
-        content = self.body_iter(stream_id, timeout)
-        on_close = functools.partial(self.response_closed, stream_id=stream_id)
-
-        return Response(
-            status_code=status_code,
-            http_version="HTTP/2",
-            headers=headers,
-            content=content,
-            on_close=on_close,
-            request=request,
-        )
-
-    async def close(self) -> None:
-        await self.socket.close()
+        return await stream.send(request, timeout)
 
     async def send_connection_init(self, timeout: Timeout) -> None:
+        """
+        The HTTP/2 connection requires some initial setup before we can start
+        using individual request/response streams on it.
+        """
+
         # Need to set these manually here instead of manipulating via
         # __setitem__() otherwise the H2Connection will emit SettingsUpdate
         # frames in addition to sending the undesired defaults.
-        self.h2_state.local_settings = Settings(
+        self.state.local_settings = Settings(
             client=True,
             initial_values={
                 # Disable PUSH_PROMISE frames from the server since we don't do anything
@@ -116,16 +88,113 @@ class HTTP2Connection:
         # Some websites (*cough* Yahoo *cough*) balk at this setting being
         # present in the initial handshake since it's not defined in the original
         # RFC despite the RFC mandating ignoring settings you don't know about.
-        del self.h2_state.local_settings[
-            h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL
-        ]
+        del self.state.local_settings[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL]
 
-        self.h2_state.initiate_connection()
-        data_to_send = self.h2_state.data_to_send()
+        self.state.initiate_connection()
+        data_to_send = self.state.data_to_send()
         await self.socket.write(data_to_send, timeout)
 
-    async def send_headers(self, request: Request, timeout: Timeout) -> int:
-        stream_id = self.h2_state.get_next_available_stream_id()
+    @property
+    def is_closed(self) -> bool:
+        return False
+
+    def is_connection_dropped(self) -> bool:
+        return self.socket.is_connection_dropped()
+
+    async def close(self) -> None:
+        await self.socket.close()
+
+    async def wait_for_outgoing_flow(self, stream_id: int, timeout: Timeout) -> int:
+        """
+        Returns the maximum allowable outgoing flow for a given stream.
+
+        If the allowable flow is zero, then waits on the network until
+        WindowUpdated frames have increased the flow rate.
+
+        https://tools.ietf.org/html/rfc7540#section-6.9
+        """
+        local_flow = self.state.local_flow_control_window(stream_id)
+        connection_flow = self.state.max_outbound_frame_size
+        flow = min(local_flow, connection_flow)
+        while flow == 0:
+            await self.receive_events(timeout)
+            local_flow = self.state.local_flow_control_window(stream_id)
+            connection_flow = self.state.max_outbound_frame_size
+            flow = min(local_flow, connection_flow)
+        return flow
+
+    async def wait_for_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event:
+        """
+        Returns the next event for a given stream.
+
+        If no events are available yet, then waits on the network until
+        an event is available.
+        """
+        while not self.events[stream_id]:
+            await self.receive_events(timeout)
+        return self.events[stream_id].pop(0)
+
+    async def receive_events(self, timeout: Timeout) -> None:
+        """
+        Read some data from the network, and update the H2 state.
+        """
+        data = await self.socket.read(self.READ_NUM_BYTES, timeout)
+        events = self.state.receive_data(data)
+        for event in events:
+            event_stream_id = getattr(event, "stream_id", 0)
+            logger.trace(f"receive_event stream_id={event_stream_id} event={event!r}")
+
+            if hasattr(event, "error_code"):
+                raise ProtocolError(event)
+
+            if event_stream_id in self.events:
+                self.events[event_stream_id].append(event)
+
+        data_to_send = self.state.data_to_send()
+        await self.socket.write(data_to_send, timeout)
+
+    async def send_outgoing_data(self, timeout: Timeout) -> None:
+        data_to_send = self.state.data_to_send()
+        if data_to_send:
+            await self.socket.write(data_to_send, timeout)
+
+    async def close_stream(self, stream_id: int) -> None:
+        del self.streams[stream_id]
+        del self.events[stream_id]
+
+        if not self.streams and self.on_release is not None:
+            await self.on_release()
+
+
+class HTTP2Stream:
+    def __init__(
+        self,
+        stream_id: int,
+        connection: HTTP2Connection,
+        state: h2.connection.H2Connection,
+    ) -> None:
+        self.stream_id = stream_id
+        self.connection = connection
+        self.state = state
+
+    async def send(self, request: Request, timeout: Timeout) -> Response:
+        # Send the request.
+        await self.send_headers(request, timeout)
+        await self.send_body(request, timeout)
+
+        # Receive the response.
+        status_code, headers = await self.receive_response(timeout)
+        content = self.body_iter(timeout)
+        return Response(
+            status_code=status_code,
+            http_version="HTTP/2",
+            headers=headers,
+            content=content,
+            on_close=self.close,
+            request=request,
+        )
+
+    async def send_headers(self, request: Request, timeout: Timeout) -> None:
         headers = [
             (b":method", request.method.encode("ascii")),
             (b":authority", request.url.authority.encode("ascii")),
@@ -135,67 +204,37 @@ class HTTP2Connection:
 
         logger.trace(
             f"send_headers "
-            f"stream_id={stream_id} "
+            f"stream_id={self.stream_id} "
             f"method={request.method!r} "
             f"target={request.url.full_path!r} "
             f"headers={headers!r}"
         )
-        self.h2_state.send_headers(stream_id, headers)
-        data_to_send = self.h2_state.data_to_send()
-        await self.socket.write(data_to_send, timeout)
-        return stream_id
-
-    async def send_request_data(
-        self, stream_id: int, stream: typing.AsyncIterator[bytes], timeout: Timeout,
-    ) -> None:
-        try:
-            async for data in stream:
-                await self.send_data(stream_id, data, timeout)
-            await self.end_stream(stream_id, timeout)
-        finally:
-            # Once we've sent the request we should enable read timeouts.
-            self.timeout_flags[stream_id].set_read_timeouts()
-
-    async def send_data(self, stream_id: int, data: bytes, timeout: Timeout) -> None:
-        while data:
-            # The data will be divided into frames to send based on the flow control
-            # window and the maximum frame size. Because the flow control window
-            # can decrease in size, even possibly to zero, this will loop until all the
-            # data is sent. In http2 specification:
-            # https://tools.ietf.org/html/rfc7540#section-6.9
-            flow_control = self.h2_state.local_flow_control_window(stream_id)
-            chunk_size = min(
-                len(data), flow_control, self.h2_state.max_outbound_frame_size
-            )
-            if chunk_size == 0:
-                # this means that the flow control window is 0 (either for the stream
-                # or the connection one), and no data can be sent until the flow control
-                # window is updated.
-                await self.window_update_received[stream_id].wait()
-                self.window_update_received[stream_id].clear()
-            else:
+        self.state.send_headers(self.stream_id, headers)
+        await self.connection.send_outgoing_data(timeout)
+
+    async def send_body(self, request: Request, timeout: Timeout) -> None:
+        logger.trace(f"send_body stream_id={self.stream_id}")
+        async for data in request.stream():
+            while data:
+                max_flow = await self.connection.wait_for_outgoing_flow(
+                    self.stream_id, timeout
+                )
+                chunk_size = min(len(data), max_flow)
                 chunk, data = data[:chunk_size], data[chunk_size:]
-                self.h2_state.send_data(stream_id, chunk)
-                data_to_send = self.h2_state.data_to_send()
-                await self.socket.write(data_to_send, timeout)
-
-    async def end_stream(self, stream_id: int, timeout: Timeout) -> None:
-        logger.trace(f"end_stream stream_id={stream_id}")
-        self.h2_state.end_stream(stream_id)
-        data_to_send = self.h2_state.data_to_send()
-        await self.socket.write(data_to_send, timeout)
+                self.state.send_data(self.stream_id, chunk)
+                await self.connection.send_outgoing_data(timeout)
+
+        self.state.end_stream(self.stream_id)
+        await self.connection.send_outgoing_data(timeout)
 
     async def receive_response(
-        self, stream_id: int, timeout: Timeout
+        self, timeout: Timeout
     ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
         """
         Read the response status and headers from the network.
         """
         while True:
-            event = await self.receive_event(stream_id, timeout)
-            # As soon as we start seeing response events, we should enable
-            # read timeouts, if we haven't already.
-            self.timeout_flags[stream_id].set_read_timeouts()
+            event = await self.connection.wait_for_event(self.stream_id, timeout)
             if isinstance(event, h2.events.ResponseReceived):
                 break
 
@@ -209,65 +248,17 @@ class HTTP2Connection:
 
         return (status_code, headers)
 
-    async def body_iter(
-        self, stream_id: int, timeout: Timeout
-    ) -> typing.AsyncIterator[bytes]:
+    async def body_iter(self, timeout: Timeout) -> typing.AsyncIterator[bytes]:
         while True:
-            event = await self.receive_event(stream_id, timeout)
+            event = await self.connection.wait_for_event(self.stream_id, timeout)
             if isinstance(event, h2.events.DataReceived):
-                self.h2_state.acknowledge_received_data(
-                    event.flow_controlled_length, stream_id
+                self.state.acknowledge_received_data(
+                    event.flow_controlled_length, self.stream_id
                 )
+                await self.connection.send_outgoing_data(timeout)
                 yield event.data
             elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)):
                 break
 
-    async def receive_event(self, stream_id: int, timeout: Timeout) -> h2.events.Event:
-        while not self.events[stream_id]:
-            flag = self.timeout_flags[stream_id]
-            data = await self.socket.read(self.READ_NUM_BYTES, timeout, flag=flag)
-            events = self.h2_state.receive_data(data)
-            for event in events:
-                event_stream_id = getattr(event, "stream_id", 0)
-                logger.trace(
-                    f"receive_event stream_id={event_stream_id} event={event!r}"
-                )
-
-                if hasattr(event, "error_code"):
-                    raise ProtocolError(event)
-
-                if isinstance(event, h2.events.WindowUpdated):
-                    if event_stream_id == 0:
-                        for window_update_event in self.window_update_received.values():
-                            window_update_event.set()
-                    else:
-                        try:
-                            self.window_update_received[event_stream_id].set()
-                        except KeyError:  # pragma: no cover
-                            # the window_update_received dictionary is only relevant
-                            # when sending data, which should never raise a KeyError
-                            # here.
-                            pass
-
-                if event_stream_id:
-                    self.events[event.stream_id].append(event)
-
-            data_to_send = self.h2_state.data_to_send()
-            await self.socket.write(data_to_send, timeout)
-
-        return self.events[stream_id].pop(0)
-
-    async def response_closed(self, stream_id: int) -> None:
-        del self.events[stream_id]
-        del self.timeout_flags[stream_id]
-        del self.window_update_received[stream_id]
-
-        if not self.events and self.on_release is not None:
-            await self.on_release()
-
-    @property
-    def is_closed(self) -> bool:
-        return False
-
-    def is_connection_dropped(self) -> bool:
-        return self.socket.is_connection_dropped()
+    async def close(self) -> None:
+        await self.connection.close_stream(self.stream_id)