]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
asyncio: Wait for the stream to close when closing (#494)
authorJamie Hewland <jhewland@gmail.com>
Tue, 22 Oct 2019 20:01:40 +0000 (22:01 +0200)
committerFlorimond Manca <florimond.manca@gmail.com>
Tue, 22 Oct 2019 20:01:40 +0000 (22:01 +0200)
httpx/concurrency/asyncio.py
tests/test_concurrency.py

index 7083345426c1805406fe44e6dbeb54ea26c74d81..4aeb7ca53dc3f9864381a9ffde1830d3c8ad3a3d 100644 (file)
@@ -1,6 +1,7 @@
 import asyncio
 import functools
 import ssl
+import sys
 import typing
 from types import TracebackType
 
@@ -170,6 +171,8 @@ class TCPStream(BaseTCPStream):
 
     async def close(self) -> None:
         self.stream_writer.close()
+        if sys.version_info >= (3, 7):
+            await self.stream_writer.wait_closed()
 
 
 class PoolSemaphore(BasePoolSemaphore):
index 27bbeaf28048286f9466f1498f56402807ec495f..8bb933b6975b946849502991ff0fb5c97acf5bef 100644 (file)
@@ -50,7 +50,21 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
         assert get_cipher(stream) is not None
 
         await stream.write(b"GET / HTTP/1.1\r\n\r\n")
-        assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")
+
+        # stream.read() only gives us *up to* as much data as we ask for. In order to
+        # cleanly close the stream, we must read until the end of the HTTP response.
+        read = b""
+        ended = False
+        for _ in range(5):  # Try read some (not too large) number of times...
+            read += await stream.read(8192, timeout)
+            # We know we're at the end of the response when we've received the body plus
+            # the terminating CRLFs.
+            if b"Hello, world!" in read and read.endswith(b"\r\n\r\n"):
+                ended = True
+                break
+
+        assert ended
+        assert read.startswith(b"HTTP/1.1 200 OK\r\n")
 
     finally:
         await stream.close()