]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Drop TimeoutFlag (#618)
authorTom Christie <tom@tomchristie.com>
Sun, 8 Dec 2019 19:43:33 +0000 (19:43 +0000)
committerGitHub <noreply@github.com>
Sun, 8 Dec 2019 19:43:33 +0000 (19:43 +0000)
httpx/concurrency/asyncio.py
httpx/concurrency/base.py
httpx/concurrency/trio.py
httpx/dispatch/http11.py

index 89384bbb95aaeb0bc5c94ac54b3eb0b4e5dabcff..51006959f5c9e053dae3532549ef9ccf1e9ce666 100644 (file)
@@ -6,13 +6,7 @@ import typing
 
 from ..config import PoolLimits, Timeout
 from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .base import (
-    BaseEvent,
-    BasePoolSemaphore,
-    BaseSocketStream,
-    ConcurrencyBackend,
-    TimeoutFlag,
-)
+from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
 
 SSL_MONKEY_PATCH_APPLIED = False
 
@@ -126,51 +120,26 @@ class SocketStream(BaseSocketStream):
         ident = ssl_object.selected_alpn_protocol()
         return "HTTP/2" if ident == "h2" else "HTTP/1.1"
 
-    async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes:
-        while True:
-            # Check our flag at the first possible moment, and use a fine
-            # grained retry loop if we're not yet in read-timeout mode.
-            should_raise = flag is None or flag.raise_on_read_timeout
-            read_timeout = timeout.read_timeout if should_raise else 0.01
-            try:
-                async with self.read_lock:
-                    data = await asyncio.wait_for(
-                        self.stream_reader.read(n), read_timeout
-                    )
-            except asyncio.TimeoutError:
-                if should_raise:
-                    raise ReadTimeout() from None
-                # FIX(py3.6): yield control back to the event loop to give it a chance
-                # to cancel `.read(n)` before we retry.
-                # This prevents concurrent `.read()` calls, which asyncio
-                # doesn't seem to allow on 3.6.
-                # See: https://github.com/encode/httpx/issues/382
-                await asyncio.sleep(0)
-            else:
-                break
-
-        return data
+    async def read(self, n: int, timeout: Timeout) -> bytes:
+        try:
+            async with self.read_lock:
+                return await asyncio.wait_for(
+                    self.stream_reader.read(n), timeout.read_timeout
+                )
+        except asyncio.TimeoutError:
+            raise ReadTimeout() from None
 
-    async def write(
-        self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None
-    ) -> None:
+    async def write(self, data: bytes, timeout: Timeout) -> None:
         if not data:
             return
 
         self.stream_writer.write(data)
-        while True:
-            try:
-                await asyncio.wait_for(  # type: ignore
-                    self.stream_writer.drain(), timeout.write_timeout
-                )
-                break
-            except asyncio.TimeoutError:
-                # We check our flag at the first possible moment, in order to
-                # allow us to suppress write timeouts, if we've since
-                # switched over to read-timeout mode.
-                should_raise = flag is None or flag.raise_on_write_timeout
-                if should_raise:
-                    raise WriteTimeout() from None
+        try:
+            return await asyncio.wait_for(
+                self.stream_writer.drain(), timeout.write_timeout
+            )
+        except asyncio.TimeoutError:
+            raise WriteTimeout() from None
 
     def is_connection_dropped(self) -> bool:
         # Counter-intuitively, what we really want to know here is whether the socket is
index 24b47232022e52fc6bbf15ca504590b4e11b4d0f..27ce9c89bd098ddc587f5a39d83ee4d923c09718 100644 (file)
@@ -26,38 +26,6 @@ def lookup_backend(
     raise RuntimeError(f"Unknown or unsupported concurrency backend {backend!r}")
 
 
-class TimeoutFlag:
-    """
-    A timeout flag holds a state of either read-timeout or write-timeout mode.
-
-    We use this so that we can attempt both reads and writes concurrently, while
-    only enforcing timeouts in one direction.
-
-    During a request/response cycle we start in write-timeout mode.
-
-    Once we've sent a request fully, or once we start seeing a response,
-    then we switch to read-timeout mode instead.
-    """
-
-    def __init__(self) -> None:
-        self.raise_on_read_timeout = False
-        self.raise_on_write_timeout = True
-
-    def set_read_timeouts(self) -> None:
-        """
-        Set the flag to read-timeout mode.
-        """
-        self.raise_on_read_timeout = True
-        self.raise_on_write_timeout = False
-
-    def set_write_timeouts(self) -> None:
-        """
-        Set the flag to write-timeout mode.
-        """
-        self.raise_on_read_timeout = False
-        self.raise_on_write_timeout = True
-
-
 class BaseSocketStream:
     """
     A socket stream with read/write operations. Abstracts away any asyncio-specific
@@ -73,7 +41,7 @@ class BaseSocketStream:
     ) -> "BaseSocketStream":
         raise NotImplementedError()  # pragma: no cover
 
-    async def read(self, n: int, timeout: Timeout, flag: typing.Any = None) -> bytes:
+    async def read(self, n: int, timeout: Timeout) -> bytes:
         raise NotImplementedError()  # pragma: no cover
 
     async def write(self, data: bytes, timeout: Timeout) -> None:
index 11737f723c03a3b7c0e45684fee2c815545014cb..7604e965ab14d7ca98d2e1489faa3a8d641a2139 100644 (file)
@@ -6,16 +6,10 @@ import trio
 
 from ..config import PoolLimits, Timeout
 from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
-from .base import (
-    BaseEvent,
-    BasePoolSemaphore,
-    BaseSocketStream,
-    ConcurrencyBackend,
-    TimeoutFlag,
-)
+from .base import BaseEvent, BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
 
 
-def _or_inf(value: typing.Optional[float]) -> float:
+def none_as_inf(value: typing.Optional[float]) -> float:
     return value if value is not None else float("inf")
 
 
@@ -30,18 +24,16 @@ class SocketStream(BaseSocketStream):
     async def start_tls(
         self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout
     ) -> "SocketStream":
-        connect_timeout = _or_inf(timeout.connect_timeout)
+        connect_timeout = none_as_inf(timeout.connect_timeout)
         ssl_stream = trio.SSLStream(
             self.stream, ssl_context=ssl_context, server_hostname=hostname
         )
 
-        with trio.move_on_after(connect_timeout) as cancel_scope:
+        with trio.move_on_after(connect_timeout):
             await ssl_stream.do_handshake()
+            return SocketStream(ssl_stream)
 
-        if cancel_scope.cancelled_caught:
-            raise ConnectTimeout()
-
-        return SocketStream(ssl_stream)
+        raise ConnectTimeout()
 
     def get_http_version(self) -> str:
         if not isinstance(self.stream, trio.SSLStream):
@@ -50,19 +42,26 @@ class SocketStream(BaseSocketStream):
         ident = self.stream.selected_alpn_protocol()
         return "HTTP/2" if ident == "h2" else "HTTP/1.1"
 
-    async def read(self, n: int, timeout: Timeout, flag: TimeoutFlag = None) -> bytes:
-        while True:
-            # Check our flag at the first possible moment, and use a fine
-            # grained retry loop if we're not yet in read-timeout mode.
-            should_raise = flag is None or flag.raise_on_read_timeout
-            read_timeout = _or_inf(timeout.read_timeout if should_raise else 0.01)
+    async def read(self, n: int, timeout: Timeout) -> bytes:
+        read_timeout = none_as_inf(timeout.read_timeout)
+
+        with trio.move_on_after(read_timeout):
+            async with self.read_lock:
+                return await self.stream.receive_some(max_bytes=n)
+
+        raise ReadTimeout()
+
+    async def write(self, data: bytes, timeout: Timeout) -> None:
+        if not data:
+            return
 
-            with trio.move_on_after(read_timeout):
-                async with self.read_lock:
-                    return await self.stream.receive_some(max_bytes=n)
+        write_timeout = none_as_inf(timeout.write_timeout)
 
-            if should_raise:
-                raise ReadTimeout() from None
+        with trio.move_on_after(write_timeout):
+            async with self.write_lock:
+                return await self.stream.send_all(data)
+
+        raise WriteTimeout()
 
     def is_connection_dropped(self) -> bool:
         # Adapted from: https://github.com/encode/httpx/pull/143#issuecomment-515202982
@@ -79,26 +78,6 @@ class SocketStream(BaseSocketStream):
         # See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
         return stream.socket.is_readable()
 
-    async def write(
-        self, data: bytes, timeout: Timeout, flag: TimeoutFlag = None
-    ) -> None:
-        if not data:
-            return
-
-        write_timeout = _or_inf(timeout.write_timeout)
-
-        while True:
-            with trio.move_on_after(write_timeout):
-                async with self.write_lock:
-                    await self.stream.send_all(data)
-                break
-            # We check our flag at the first possible moment, in order to
-            # allow us to suppress write timeouts, if we've since
-            # switched over to read-timeout mode.
-            should_raise = flag is None or flag.raise_on_write_timeout
-            if should_raise:
-                raise WriteTimeout() from None
-
     async def close(self) -> None:
         await self.stream.aclose()
 
@@ -123,7 +102,7 @@ class PoolSemaphore(BasePoolSemaphore):
         if self.semaphore is None:
             return
 
-        timeout = _or_inf(timeout)
+        timeout = none_as_inf(timeout)
 
         with trio.move_on_after(timeout):
             await self.semaphore.acquire()
@@ -146,18 +125,16 @@ class TrioBackend(ConcurrencyBackend):
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: Timeout,
     ) -> SocketStream:
-        connect_timeout = _or_inf(timeout.connect_timeout)
+        connect_timeout = none_as_inf(timeout.connect_timeout)
 
-        with trio.move_on_after(connect_timeout) as cancel_scope:
+        with trio.move_on_after(connect_timeout):
             stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
             if ssl_context is not None:
                 stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
                 await stream.do_handshake()
+            return SocketStream(stream=stream)
 
-        if cancel_scope.cancelled_caught:
-            raise ConnectTimeout()
-
-        return SocketStream(stream=stream)
+        raise ConnectTimeout()
 
     async def open_uds_stream(
         self,
@@ -166,18 +143,16 @@ class TrioBackend(ConcurrencyBackend):
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: Timeout,
     ) -> SocketStream:
-        connect_timeout = _or_inf(timeout.connect_timeout)
+        connect_timeout = none_as_inf(timeout.connect_timeout)
 
-        with trio.move_on_after(connect_timeout) as cancel_scope:
+        with trio.move_on_after(connect_timeout):
             stream: trio.SocketStream = await trio.open_unix_socket(path)
             if ssl_context is not None:
                 stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
                 await stream.do_handshake()
+            return SocketStream(stream=stream)
 
-        if cancel_scope.cancelled_caught:
-            raise ConnectTimeout()
-
-        return SocketStream(stream=stream)
+        raise ConnectTimeout()
 
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
index 44961f027f28f68052ee48e8a45d6bcc17446b68..903e7d147119154c4e167a192572997b5aadc1e8 100644 (file)
@@ -2,7 +2,7 @@ import typing
 
 import h11
 
-from ..concurrency.base import BaseSocketStream, TimeoutFlag
+from ..concurrency.base import BaseSocketStream
 from ..config import Timeout
 from ..exceptions import ConnectionClosed, ProtocolError
 from ..models import Request, Response
@@ -38,7 +38,6 @@ class HTTP11Connection:
         self.socket = socket
         self.on_release = on_release
         self.h11_state = h11.Connection(our_role=h11.CLIENT)
-        self.timeout_flag = TimeoutFlag()
 
     async def send(self, request: Request, timeout: Timeout = None) -> Response:
         timeout = Timeout() if timeout is None else timeout
@@ -102,9 +101,6 @@ class HTTP11Connection:
             # care about connection errors that occur when sending the body.
             # Ignore these, and defer to any exceptions on reading the response.
             self.h11_state.send_failed()
-        finally:
-            # Once we've sent the request, we enable read timeouts.
-            self.timeout_flag.set_read_timeouts()
 
     async def _send_event(self, event: H11Event, timeout: Timeout) -> None:
         """
@@ -122,9 +118,6 @@ class HTTP11Connection:
         """
         while True:
             event = await self._receive_event(timeout)
-            # As soon as we start seeing response events, we should enable
-            # read timeouts, if we haven't already.
-            self.timeout_flag.set_read_timeouts()
             if isinstance(event, h11.InformationalResponse):
                 continue
             else:
@@ -171,9 +164,7 @@ class HTTP11Connection:
 
             if event is h11.NEED_DATA:
                 try:
-                    data = await self.socket.read(
-                        self.READ_NUM_BYTES, timeout, flag=self.timeout_flag
-                    )
+                    data = await self.socket.read(self.READ_NUM_BYTES, timeout)
                 except OSError:  # pragma: nocover
                     data = b""
                 self.h11_state.receive_data(data)
@@ -194,7 +185,6 @@ class HTTP11Connection:
         ):
             # Get ready for another request/response cycle.
             self.h11_state.start_next_cycle()
-            self.timeout_flag.set_write_timeouts()
         else:
             await self.close()