]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Gracefully end_stream early on no-body requests. (#682)
authorTom Christie <tom@tomchristie.com>
Mon, 23 Dec 2019 10:49:36 +0000 (10:49 +0000)
committerGitHub <noreply@github.com>
Mon, 23 Dec 2019 10:49:36 +0000 (10:49 +0000)
httpx/dispatch/http2.py

index adf74f5afbf901234cd46866be806d9a3ad3622a..285625dad6b3679316536736a10c3832f40f5ebe 100644 (file)
@@ -166,9 +166,10 @@ class HTTP2Connection(OpenConnection):
         self,
         stream_id: int,
         headers: typing.List[typing.Tuple[bytes, bytes]],
+        end_stream: bool,
         timeout: Timeout,
     ) -> None:
-        self.state.send_headers(stream_id, headers)
+        self.state.send_headers(stream_id, headers, end_stream=end_stream)
         self.state.increment_flow_control_window(2 ** 24, stream_id=stream_id)
         data_to_send = self.state.data_to_send()
         await self.socket.write(data_to_send, timeout)
@@ -205,8 +206,14 @@ class HTTP2Stream:
 
     async def send(self, request: Request, timeout: Timeout) -> Response:
         # Send the request.
-        await self.send_headers(request, timeout)
-        await self.send_body(request, timeout)
+        has_body = (
+            "Content-Length" in request.headers
+            or "Transfer-Encoding" in request.headers
+        )
+
+        await self.send_headers(request, has_body, timeout)
+        if has_body:
+            await self.send_body(request, timeout)
 
         # Receive the response.
         status_code, headers = await self.receive_response(timeout)
@@ -222,13 +229,20 @@ class HTTP2Stream:
             request=request,
         )
 
-    async def send_headers(self, request: Request, timeout: Timeout) -> None:
+    async def send_headers(
+        self, request: Request, has_body: bool, timeout: Timeout
+    ) -> None:
         headers = [
             (b":method", request.method.encode("ascii")),
             (b":authority", request.url.authority.encode("ascii")),
             (b":scheme", request.url.scheme.encode("ascii")),
             (b":path", request.url.full_path.encode("ascii")),
-        ] + [(k, v) for k, v in request.headers.raw if k != b"host"]
+        ] + [
+            (k, v)
+            for k, v in request.headers.raw
+            if k not in (b"host", b"transfer-encoding")
+        ]
+        end_stream = not has_body
 
         logger.trace(
             f"send_headers "
@@ -237,7 +251,7 @@ class HTTP2Stream:
             f"target={request.url.full_path!r} "
             f"headers={headers!r}"
         )
-        await self.connection.send_headers(self.stream_id, headers, timeout)
+        await self.connection.send_headers(self.stream_id, headers, end_stream, timeout)
 
     async def send_body(self, request: Request, timeout: Timeout) -> None:
         logger.trace(f"send_body stream_id={self.stream_id}")