]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Backend operations like .read(), .write() now have a manadatory timeout argument...
authorTom Christie <tom@tomchristie.com>
Sat, 7 Dec 2019 11:09:58 +0000 (11:09 +0000)
committerGitHub <noreply@github.com>
Sat, 7 Dec 2019 11:09:58 +0000 (11:09 +0000)
httpx/concurrency/asyncio.py
httpx/concurrency/trio.py
tests/test_concurrency.py

index eee4ec32d5e8b68a4d2a2de46db3b5b2873e29f9..89384bbb95aaeb0bc5c94ac54b3eb0b4e5dabcff 100644 (file)
@@ -40,14 +40,10 @@ def ssl_monkey_patch() -> None:
 
 class SocketStream(BaseSocketStream):
     def __init__(
-        self,
-        stream_reader: asyncio.StreamReader,
-        stream_writer: asyncio.StreamWriter,
-        timeout: Timeout,
+        self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter,
     ):
         self.stream_reader = stream_reader
         self.stream_writer = stream_writer
-        self.timeout = timeout
         self.read_lock = asyncio.Lock()
 
         self._inner: typing.Optional[SocketStream] = None
@@ -114,7 +110,7 @@ class SocketStream(BaseSocketStream):
             transport=transport, protocol=protocol, reader=stream_reader, loop=loop
         )
 
-        ssl_stream = SocketStream(stream_reader, stream_writer, self.timeout)
+        ssl_stream = SocketStream(stream_reader, stream_writer)
         # When we return a new SocketStream with new StreamReader/StreamWriter instances
         # we need to keep references to the old StreamReader/StreamWriter so that they
         # are not garbage collected and closed while we're still using them.
@@ -130,12 +126,7 @@ class SocketStream(BaseSocketStream):
         ident = ssl_object.selected_alpn_protocol()
         return "HTTP/2" if ident == "h2" else "HTTP/1.1"
 
-    async def read(
-        self, n: int, timeout: Timeout = None, flag: TimeoutFlag = None
-    ) -> bytes:
-        if timeout is None:
-            timeout = self.timeout
-
+    async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes:
         while True:
             # Check our flag at the first possible moment, and use a fine
             # grained retry loop if we're not yet in read-timeout mode.
@@ -161,14 +152,11 @@ class SocketStream(BaseSocketStream):
         return data
 
     async def write(
-        self, data: bytes, timeout: Timeout = None, flag: TimeoutFlag = None
+        self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None
     ) -> None:
         if not data:
             return
 
-        if timeout is None:
-            timeout = self.timeout
-
         self.stream_writer.write(data)
         while True:
             try:
@@ -269,9 +257,7 @@ class AsyncioBackend(ConcurrencyBackend):
         except asyncio.TimeoutError:
             raise ConnectTimeout()
 
-        return SocketStream(
-            stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
-        )
+        return SocketStream(stream_reader=stream_reader, stream_writer=stream_writer)
 
     async def open_uds_stream(
         self,
@@ -292,9 +278,7 @@ class AsyncioBackend(ConcurrencyBackend):
         except asyncio.TimeoutError:
             raise ConnectTimeout()
 
-        return SocketStream(
-            stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
-        )
+        return SocketStream(stream_reader=stream_reader, stream_writer=stream_writer)
 
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
index 39ac51c5c5bafe5ba12aaeddf4d44b48441687da..11737f723c03a3b7c0e45684fee2c815545014cb 100644 (file)
@@ -21,10 +21,9 @@ def _or_inf(value: typing.Optional[float]) -> float:
 
 class SocketStream(BaseSocketStream):
     def __init__(
-        self, stream: typing.Union[trio.SocketStream, trio.SSLStream], timeout: Timeout,
+        self, stream: typing.Union[trio.SocketStream, trio.SSLStream],
     ) -> None:
         self.stream = stream
-        self.timeout = timeout
         self.read_lock = trio.Lock()
         self.write_lock = trio.Lock()
 
@@ -42,7 +41,7 @@ class SocketStream(BaseSocketStream):
         if cancel_scope.cancelled_caught:
             raise ConnectTimeout()
 
-        return SocketStream(ssl_stream, self.timeout)
+        return SocketStream(ssl_stream)
 
     def get_http_version(self) -> str:
         if not isinstance(self.stream, trio.SSLStream):
@@ -51,12 +50,7 @@ class SocketStream(BaseSocketStream):
         ident = self.stream.selected_alpn_protocol()
         return "HTTP/2" if ident == "h2" else "HTTP/1.1"
 
-    async def read(
-        self, n: int, timeout: Timeout = None, flag: TimeoutFlag = None
-    ) -> bytes:
-        if timeout is None:
-            timeout = self.timeout
-
+    async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes:
         while True:
             # Check our flag at the first possible moment, and use a fine
             # grained retry loop if we're not yet in read-timeout mode.
@@ -86,14 +80,11 @@ class SocketStream(BaseSocketStream):
         return stream.socket.is_readable()
 
     async def write(
-        self, data: bytes, timeout: Timeout = None, flag: TimeoutFlag = None
+        self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None
     ) -> None:
         if not data:
             return
 
-        if timeout is None:
-            timeout = self.timeout
-
         write_timeout = _or_inf(timeout.write_timeout)
 
         while True:
@@ -166,7 +157,7 @@ class TrioBackend(ConcurrencyBackend):
         if cancel_scope.cancelled_caught:
             raise ConnectTimeout()
 
-        return SocketStream(stream=stream, timeout=timeout)
+        return SocketStream(stream=stream)
 
     async def open_uds_stream(
         self,
@@ -186,7 +177,7 @@ class TrioBackend(ConcurrencyBackend):
         if cancel_scope.cancelled_caught:
             raise ConnectTimeout()
 
-        return SocketStream(stream=stream, timeout=timeout)
+        return SocketStream(stream=stream)
 
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
index a3b1df7106c2abd1d88119a39674bef76209bb4b..a3286c5569551d34537b38d6d5ec54c21624a55a 100644 (file)
@@ -16,7 +16,7 @@ def get_trio_cipher(stream):
     return stream.stream.cipher() if isinstance(stream.stream, trio.SSLStream) else None
 
 
-async def read_response(stream, timeout: float, should_contain: bytes) -> bytes:
+async def read_response(stream, timeout: Timeout, should_contain: bytes) -> bytes:
     # stream.read() only gives us *up to* as much data as we ask for. In order to
     # cleanly close the stream, we must read until the end of the HTTP response.
     response = b""
@@ -57,7 +57,7 @@ async def test_start_tls_on_tcp_socket_stream(https_server, backend, get_cipher)
         assert stream.is_connection_dropped() is False
         assert get_cipher(stream) is not None
 
-        await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+        await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
 
         response = await read_response(stream, timeout, should_contain=b"Hello, world")
         assert response.startswith(b"HTTP/1.1 200 OK\r\n")
@@ -89,7 +89,7 @@ async def test_start_tls_on_uds_socket_stream(https_uds_server, backend, get_cip
         assert stream.is_connection_dropped() is False
         assert get_cipher(stream) is not None
 
-        await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+        await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
 
         response = await read_response(stream, timeout, should_contain=b"Hello, world")
         assert response.startswith(b"HTTP/1.1 200 OK\r\n")
@@ -105,10 +105,11 @@ async def test_concurrent_read(server, backend):
     stream = await backend.open_tcp_stream(
         server.url.host, server.url.port, ssl_context=None, timeout=Timeout(5)
     )
+    timeout = Timeout(5)
     try:
-        await stream.write(b"GET / HTTP/1.1\r\n\r\n")
+        await stream.write(b"GET / HTTP/1.1\r\n\r\n", timeout)
         await run_concurrently(
-            backend, lambda: stream.read(10), lambda: stream.read(10)
+            backend, lambda: stream.read(10, timeout), lambda: stream.read(10, timeout)
         )
     finally:
         await stream.close()