]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Read/Write timeout modes (#104)
authorTom Christie <tom@tomchristie.com>
Tue, 25 Jun 2019 11:54:14 +0000 (12:54 +0100)
committerGitHub <noreply@github.com>
Tue, 25 Jun 2019 11:54:14 +0000 (12:54 +0100)
* Read/Write timeout modes

* Read/Write timeout modes

http3/concurrency.py
http3/dispatch/http11.py
http3/dispatch/http2.py
http3/interfaces.py
tests/dispatch/utils.py

index fd6af36833390624af477e5f4b054bbb6e103053..bd04c2da1cdca3f4121de1c6ba5314aabdbcbd3b 100644 (file)
@@ -49,6 +49,38 @@ def ssl_monkey_patch() -> None:
     MonkeyPatch.write = _fixed_write
 
 
+class TimeoutFlag:
+    """
+    A timeout flag holds a state of either read-timeout or write-timeout mode.
+
+    We use this so that we can attempt both reads and writes concurrently, while
+    only enforcing timeouts in one direction.
+
+    During a request/response cycle we start in write-timeout mode.
+
+    Once we've sent a request fully, or once we start seeing a response,
+    then we switch to read-timeout mode instead.
+    """
+
+    def __init__(self) -> None:
+        self.raise_on_read_timeout = False
+        self.raise_on_write_timeout = True
+
+    def set_read_timeouts(self) -> None:
+        """
+        Set the flag to read-timeout mode.
+        """
+        self.raise_on_read_timeout = True
+        self.raise_on_write_timeout = False
+
+    def set_write_timeouts(self) -> None:
+        """
+        Set the flag to write-timeout mode.
+        """
+        self.raise_on_read_timeout = False
+        self.raise_on_write_timeout = True
+
+
 class Reader(BaseReader):
     def __init__(
         self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
@@ -56,16 +88,22 @@ class Reader(BaseReader):
         self.stream_reader = stream_reader
         self.timeout = timeout
 
-    async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes:
+    async def read(
+        self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
+    ) -> bytes:
         if timeout is None:
             timeout = self.timeout
 
-        try:
-            data = await asyncio.wait_for(
-                self.stream_reader.read(n), timeout.read_timeout
-            )
-        except asyncio.TimeoutError:
-            raise ReadTimeout()
+        while True:
+            should_raise = flag is None or flag.raise_on_read_timeout
+            try:
+                data = await asyncio.wait_for(
+                    self.stream_reader.read(n), timeout.read_timeout
+                )
+                break
+            except asyncio.TimeoutError:
+                if should_raise:
+                    raise ReadTimeout()
 
         return data
 
@@ -78,7 +116,9 @@ class Writer(BaseWriter):
     def write_no_block(self, data: bytes) -> None:
         self.stream_writer.write(data)  # pragma: nocover
 
-    async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None:
+    async def write(
+        self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
+    ) -> None:
         if not data:
             return
 
@@ -86,12 +126,16 @@ class Writer(BaseWriter):
             timeout = self.timeout
 
         self.stream_writer.write(data)
-        try:
-            await asyncio.wait_for(  # type: ignore
-                self.stream_writer.drain(), timeout.write_timeout
-            )
-        except asyncio.TimeoutError:
-            raise WriteTimeout()
+        while True:
+            try:
+                await asyncio.wait_for(  # type: ignore
+                    self.stream_writer.drain(), timeout.write_timeout
+                )
+                break
+            except asyncio.TimeoutError:
+                should_raise = flag is None or flag.raise_on_write_timeout
+                if should_raise:
+                    raise WriteTimeout()
 
     async def close(self) -> None:
         self.stream_writer.close()
index 1f632d8eb29e1bfb5a00df98bebe409c2976c5b3..ae9e74603b874a16bd3400d66a6d3274a73daa83 100644 (file)
@@ -2,6 +2,7 @@ import typing
 
 import h11
 
+from ..concurrency import TimeoutFlag
 from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
 from ..exceptions import ConnectTimeout, ReadTimeout
 from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
@@ -38,6 +39,7 @@ class HTTP11Connection:
         self.backend = backend
         self.on_release = on_release
         self.h11_state = h11.Connection(our_role=h11.CLIENT)
+        self.timeout_flag = TimeoutFlag()
 
     async def send(
         self, request: AsyncRequest, timeout: TimeoutTypes = None
@@ -103,6 +105,9 @@ class HTTP11Connection:
             # care about connection errors that occur when sending the body.
             # Ignore these, and defer to any exceptions on reading the response.
             self.h11_state.send_failed()
+        finally:
+            # Once we've sent the request, we enable read timeouts.
+            self.timeout_flag.set_read_timeouts()
 
     async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> None:
         """
@@ -120,6 +125,9 @@ class HTTP11Connection:
         """
         while True:
             event = await self._receive_event(timeout)
+            # As soon as we start seeing response events, we should enable
+            # read timeouts, if we haven't already.
+            self.timeout_flag.set_read_timeouts()
             if isinstance(event, h11.InformationalResponse):
                 continue
             else:
@@ -149,7 +157,9 @@ class HTTP11Connection:
             event = self.h11_state.next_event()
             if event is h11.NEED_DATA:
                 try:
-                    data = await self.reader.read(self.READ_NUM_BYTES, timeout)
+                    data = await self.reader.read(
+                        self.READ_NUM_BYTES, timeout, flag=self.timeout_flag
+                    )
                 except OSError:  # pragma: nocover
                     data = b""
                 self.h11_state.receive_data(data)
@@ -162,7 +172,9 @@ class HTTP11Connection:
             self.h11_state.our_state is h11.DONE
             and self.h11_state.their_state is h11.DONE
         ):
+            # Get ready for another request/response cycle.
             self.h11_state.start_next_cycle()
+            self.timeout_flag.set_write_timeouts()
         else:
             await self.close()
 
index ae42b27309fd2e5d83b3e4d27893fd6934d7639e..9bd35eaf2dab6a5a3cf8b7af89353aabda7c7f46 100644 (file)
@@ -4,6 +4,7 @@ import typing
 import h2.connection
 import h2.events
 
+from ..concurrency import TimeoutFlag
 from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
 from ..exceptions import ConnectTimeout, ReadTimeout
 from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
@@ -26,6 +27,7 @@ class HTTP2Connection:
         self.on_release = on_release
         self.h2_state = h2.connection.H2Connection()
         self.events = {}  # type: typing.Dict[int, typing.List[h2.events.Event]]
+        self.timeout_flags = {}  # type: typing.Dict[int, TimeoutFlag]
         self.initialized = False
 
     async def send(
@@ -39,6 +41,7 @@ class HTTP2Connection:
 
         stream_id = await self.send_headers(request, timeout)
         self.events[stream_id] = []
+        self.timeout_flags[stream_id] = TimeoutFlag()
 
         task, args = self.send_request_data, [stream_id, request.stream(), timeout]
         async with self.backend.background_manager(task, args=args):
@@ -85,9 +88,13 @@ class HTTP2Connection:
         stream: typing.AsyncIterator[bytes],
         timeout: TimeoutConfig = None,
     ) -> None:
-        async for data in stream:
-            await self.send_data(stream_id, data, timeout)
-        await self.end_stream(stream_id, timeout)
+        try:
+            async for data in stream:
+                await self.send_data(stream_id, data, timeout)
+            await self.end_stream(stream_id, timeout)
+        finally:
+            # Once we've sent the request we should enable read timeouts.
+            self.timeout_flags[stream_id].set_read_timeouts()
 
     async def send_data(
         self, stream_id: int, data: bytes, timeout: TimeoutConfig = None
@@ -113,6 +120,9 @@ class HTTP2Connection:
         """
         while True:
             event = await self.receive_event(stream_id, timeout)
+            # As soon as we start seeing response events, we should enable
+            # read timeouts, if we haven't already.
+            self.timeout_flags[stream_id].set_read_timeouts()
             if isinstance(event, h2.events.ResponseReceived):
                 break
 
@@ -140,7 +150,8 @@ class HTTP2Connection:
         self, stream_id: int, timeout: TimeoutConfig = None
     ) -> h2.events.Event:
         while not self.events[stream_id]:
-            data = await self.reader.read(self.READ_NUM_BYTES, timeout)
+            flag = self.timeout_flags[stream_id]
+            data = await self.reader.read(self.READ_NUM_BYTES, timeout, flag=flag)
             events = self.h2_state.receive_data(data)
             for event in events:
                 if getattr(event, "stream_id", 0):
@@ -153,6 +164,7 @@ class HTTP2Connection:
 
     async def response_closed(self, stream_id: int) -> None:
         del self.events[stream_id]
+        del self.timeout_flags[stream_id]
 
         if not self.events and self.on_release is not None:
             await self.on_release()
index 231263978381d1d468dc3c1e8040fef74557ba3b..5d9b99c781bd5d973349c11d1c23fe25b0da4b81 100644 (file)
@@ -127,7 +127,9 @@ class BaseReader:
     backend, or for stand-alone test cases.
     """
 
-    async def read(self, n: int, timeout: TimeoutConfig = None) -> bytes:
+    async def read(
+        self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
+    ) -> bytes:
         raise NotImplementedError()  # pragma: no cover
 
 
index cdb7c03161dfb7df3f2b7266946206f096074c15..4764f3186a2bc73f3d0373c7a02b7d08d364945c 100644 (file)
@@ -45,7 +45,7 @@ class MockHTTP2Server(BaseReader, BaseWriter):
 
     # BaseReader interface
 
-    async def read(self, n, timeout) -> bytes:
+    async def read(self, n, timeout, flag=None) -> bytes:
         await asyncio.sleep(0)
         send, self.buffer = self.buffer[:n], self.buffer[n:]
         return send