git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Handle early connection closes (#103)
author: Tom Christie <tom@tomchristie.com>
Mon, 24 Jun 2019 15:53:42 +0000 (16:53 +0100)
committer: GitHub <noreply@github.com>
Mon, 24 Jun 2019 15:53:42 +0000 (16:53 +0100)
* Refactoring h11 implementation

* More h11 refactoring

* Support early connection closes on H11 connections

* Tweak comment

* Refactor concurrent read/writes

* Drop WriteTimeout masking

* Linting

* Use concurrent read/writes for HTTP2

* Push background sending into ConcurrencyBackend

http3/client.py
http3/concurrency.py
http3/dispatch/connection.py
http3/dispatch/http11.py
http3/dispatch/http2.py
http3/interfaces.py
tests/client/test_client.py
tests/dispatch/utils.py

diff --git a/http3/client.py b/http3/client.py
index 6c0557b88828e5848ee38dd8cd2aedecd073bb36..8bf20de97c592724015f861df92a8c67d6eca52c 100644 (file)
@@ -81,7 +81,7 @@ class BaseClient:
             async_dispatch = dispatch
 
         if base_url is None:
-            self.base_url = URL('', allow_relative=True)
+            self.base_url = URL("", allow_relative=True)
         else:
             self.base_url = URL(base_url)
 
diff --git a/http3/concurrency.py b/http3/concurrency.py
index 664cb294484f03594cf0c04cbc0a052daedb8b4e..fd6af36833390624af477e5f4b054bbb6e103053 100644 (file)
@@ -12,10 +12,12 @@ import asyncio
 import functools
 import ssl
 import typing
+from types import TracebackType
 
 from .config import DEFAULT_TIMEOUT_CONFIG, PoolLimits, TimeoutConfig
 from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
 from .interfaces import (
+    BaseBackgroundManager,
     BasePoolSemaphore,
     BaseReader,
     BaseWriter,
@@ -41,7 +43,7 @@ def ssl_monkey_patch() -> None:
     _write = MonkeyPatch.write
 
     def _fixed_write(self, data: bytes) -> None:  # type: ignore
-        if not self._loop.is_closed():
+        if self._loop and not self._loop.is_closed():
             _write(self, data)
 
     MonkeyPatch.write = _fixed_write
@@ -193,3 +195,29 @@ class AsyncioBackend(ConcurrencyBackend):
 
     def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
         return PoolSemaphore(limits)
+
+    def background_manager(
+        self, coroutine: typing.Callable, args: typing.Any
+    ) -> "BackgroundManager":
+        return BackgroundManager(coroutine, args)
+
+
+class BackgroundManager(BaseBackgroundManager):
+    def __init__(self, coroutine: typing.Callable, args: typing.Any) -> None:
+        self.coroutine = coroutine
+        self.args = args
+
+    async def __aenter__(self) -> "BackgroundManager":
+        loop = asyncio.get_event_loop()
+        self.task = loop.create_task(self.coroutine(*self.args))
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        await self.task
+        if exc_type is None:
+            self.task.result()
diff --git a/http3/dispatch/connection.py b/http3/dispatch/connection.py
index d644bcba73b32094bebbdda0da8bf1396ad88cd9..68dea5c4d1aba5f538542bed061e6ae2e72bce89 100644 (file)
@@ -82,10 +82,12 @@ class HTTPConnection(AsyncDispatcher):
             host, port, ssl_context, timeout
         )
         if protocol == Protocol.HTTP_2:
-            self.h2_connection = HTTP2Connection(reader, writer, on_release=on_release)
+            self.h2_connection = HTTP2Connection(
+                reader, writer, self.backend, on_release=on_release
+            )
         else:
             self.h11_connection = HTTP11Connection(
-                reader, writer, on_release=on_release
+                reader, writer, self.backend, on_release=on_release
             )
 
     async def close(self) -> None:
diff --git a/http3/dispatch/http11.py b/http3/dispatch/http11.py
index 6a45a04937a0645f9322a2cc2a28a0bf6e278456..1f632d8eb29e1bfb5a00df98bebe409c2976c5b3 100644 (file)
@@ -4,7 +4,7 @@ import h11
 
 from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
 from ..exceptions import ConnectTimeout, ReadTimeout
-from ..interfaces import BaseReader, BaseWriter
+from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
 from ..models import AsyncRequest, AsyncResponse
 
 H11Event = typing.Union[
@@ -30,10 +30,12 @@ class HTTP11Connection:
         self,
         reader: BaseReader,
         writer: BaseWriter,
+        backend: ConcurrencyBackend,
         on_release: typing.Optional[OnReleaseCallback] = None,
     ):
         self.reader = reader
         self.writer = writer
+        self.backend = backend
         self.on_release = on_release
         self.h11_state = h11.Connection(our_role=h11.CLIENT)
 
@@ -42,34 +44,11 @@ class HTTP11Connection:
     ) -> AsyncResponse:
         timeout = None if timeout is None else TimeoutConfig(timeout)
 
-        #  Start sending the request.
-        method = request.method.encode("ascii")
-        target = request.url.full_path.encode("ascii")
-        headers = request.headers.raw
-        if "Host" not in request.headers:
-            host = request.url.authority.encode("ascii")
-            headers = [(b"host", host)] + headers
-        event = h11.Request(method=method, target=target, headers=headers)
-        await self._send_event(event, timeout)
-
-        # Send the request body.
-        async for data in request.stream():
-            event = h11.Data(data=data)
-            await self._send_event(event, timeout)
-
-        # Finalize sending the request.
-        event = h11.EndOfMessage()
-        await self._send_event(event, timeout)
-
-        # Start getting the response.
-        event = await self._receive_event(timeout)
-        if isinstance(event, h11.InformationalResponse):
-            event = await self._receive_event(timeout)
-
-        assert isinstance(event, h11.Response)
-        status_code = event.status_code
-        headers = event.headers
-        content = self._body_iter(timeout)
+        await self._send_request(request, timeout)
+        task, args = self._send_request_data, [request.stream(), timeout]
+        async with self.backend.background_manager(task, args=args):
+            status_code, headers = await self._receive_response(timeout)
+        content = self._receive_response_data(timeout)
 
         return AsyncResponse(
             status_code=status_code,
@@ -82,30 +61,100 @@ class HTTP11Connection:
 
     async def close(self) -> None:
         event = h11.ConnectionClosed()
-        self.h11_state.send(event)
+        try:
+            self.h11_state.send(event)
+        except h11.LocalProtocolError as exc:  # pragma: no cover
+            # Premature client disconnect
+            pass
         await self.writer.close()
 
-    async def _body_iter(
+    async def _send_request(
+        self, request: AsyncRequest, timeout: TimeoutConfig = None
+    ) -> None:
+        """
+        Send the request method, URL, and headers to the network.
+        """
+        method = request.method.encode("ascii")
+        target = request.url.full_path.encode("ascii")
+        headers = request.headers.raw
+        if "Host" not in request.headers:
+            host = request.url.authority.encode("ascii")
+            headers = [(b"host", host)] + headers
+        event = h11.Request(method=method, target=target, headers=headers)
+        await self._send_event(event, timeout)
+
+    async def _send_request_data(
+        self, data: typing.AsyncIterator[bytes], timeout: TimeoutConfig = None
+    ) -> None:
+        """
+        Send the request body to the network.
+        """
+        try:
+            # Send the request body.
+            async for chunk in data:
+                event = h11.Data(data=chunk)
+                await self._send_event(event, timeout)
+
+            # Finalize sending the request.
+            event = h11.EndOfMessage()
+            await self._send_event(event, timeout)
+        except OSError:  # pragma: nocover
+            # Once we've sent the initial part of the request we don't actually
+            # care about connection errors that occur when sending the body.
+            # Ignore these, and defer to any exceptions on reading the response.
+            self.h11_state.send_failed()
+
+    async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> None:
+        """
+        Send a single `h11` event to the network, waiting for the data to
+        drain before returning.
+        """
+        bytes_to_send = self.h11_state.send(event)
+        await self.writer.write(bytes_to_send, timeout)
+
+    async def _receive_response(
+        self, timeout: TimeoutConfig = None
+    ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
+        """
+        Read the response status and headers from the network.
+        """
+        while True:
+            event = await self._receive_event(timeout)
+            if isinstance(event, h11.InformationalResponse):
+                continue
+            else:
+                assert isinstance(event, h11.Response)
+                break
+        return (event.status_code, event.headers)
+
+    async def _receive_response_data(
         self, timeout: TimeoutConfig = None
     ) -> typing.AsyncIterator[bytes]:
-        event = await self._receive_event(timeout)
-        while isinstance(event, h11.Data):
-            yield event.data
+        """
+        Read the response data from the network.
+        """
+        while True:
             event = await self._receive_event(timeout)
-        assert isinstance(event, h11.EndOfMessage)
-
-    async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> None:
-        data = self.h11_state.send(event)
-        await self.writer.write(data, timeout)
+            if isinstance(event, h11.Data):
+                yield event.data
+            else:
+                assert isinstance(event, h11.EndOfMessage)
+                break
 
     async def _receive_event(self, timeout: TimeoutConfig = None) -> H11Event:
-        event = self.h11_state.next_event()
-
-        while event is h11.NEED_DATA:
-            data = await self.reader.read(self.READ_NUM_BYTES, timeout)
-            self.h11_state.receive_data(data)
+        """
+        Read a single `h11` event, reading more data from the network if needed.
+        """
+        while True:
             event = self.h11_state.next_event()
-
+            if event is h11.NEED_DATA:
+                try:
+                    data = await self.reader.read(self.READ_NUM_BYTES, timeout)
+                except OSError:  # pragma: nocover
+                    data = b""
+                self.h11_state.receive_data(data)
+            else:
+                break
         return event
 
     async def response_closed(self) -> None:
diff --git a/http3/dispatch/http2.py b/http3/dispatch/http2.py
index c6b28121149b0c26d0fa65334752e83ac0a461d4..ae42b27309fd2e5d83b3e4d27893fd6934d7639e 100644 (file)
@@ -6,7 +6,7 @@ import h2.events
 
 from ..config import DEFAULT_TIMEOUT_CONFIG, TimeoutConfig, TimeoutTypes
 from ..exceptions import ConnectTimeout, ReadTimeout
-from ..interfaces import BaseReader, BaseWriter
+from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
 from ..models import AsyncRequest, AsyncResponse
 
 
@@ -14,10 +14,15 @@ class HTTP2Connection:
     READ_NUM_BYTES = 4096
 
     def __init__(
-        self, reader: BaseReader, writer: BaseWriter, on_release: typing.Callable = None
+        self,
+        reader: BaseReader,
+        writer: BaseWriter,
+        backend: ConcurrencyBackend,
+        on_release: typing.Callable = None,
     ):
         self.reader = reader
         self.writer = writer
+        self.backend = backend
         self.on_release = on_release
         self.h2_state = h2.connection.H2Connection()
         self.events = {}  # type: typing.Dict[int, typing.List[h2.events.Event]]
@@ -35,27 +40,9 @@ class HTTP2Connection:
         stream_id = await self.send_headers(request, timeout)
         self.events[stream_id] = []
 
-        # Send the request body.
-        async for data in request.stream():
-            await self.send_data(stream_id, data, timeout)
-
-        # Finalize sending the request.
-        await self.end_stream(stream_id, timeout)
-
-        # Start getting the response.
-        while True:
-            event = await self.receive_event(stream_id, timeout)
-            if isinstance(event, h2.events.ResponseReceived):
-                break
-
-        status_code = 200
-        headers = []
-        for k, v in event.headers:
-            if k == b":status":
-                status_code = int(v.decode("ascii", errors="ignore"))
-            elif not k.startswith(b":"):
-                headers.append((k, v))
-
+        task, args = self.send_request_data, [stream_id, request.stream(), timeout]
+        async with self.backend.background_manager(task, args=args):
+            status_code, headers = await self.receive_response(stream_id, timeout)
         content = self.body_iter(stream_id, timeout)
         on_close = functools.partial(self.response_closed, stream_id=stream_id)
 
@@ -92,13 +79,23 @@ class HTTP2Connection:
         await self.writer.write(data_to_send, timeout)
         return stream_id
 
+    async def send_request_data(
+        self,
+        stream_id: int,
+        stream: typing.AsyncIterator[bytes],
+        timeout: TimeoutConfig = None,
+    ) -> None:
+        async for data in stream:
+            await self.send_data(stream_id, data, timeout)
+        await self.end_stream(stream_id, timeout)
+
     async def send_data(
         self, stream_id: int, data: bytes, timeout: TimeoutConfig = None
     ) -> None:
         flow_control = self.h2_state.local_flow_control_window(stream_id)
         chunk_size = min(len(data), flow_control)
         for idx in range(0, len(data), chunk_size):
-            chunk = data[idx:idx+chunk_size]
+            chunk = data[idx : idx + chunk_size]
             self.h2_state.send_data(stream_id, chunk)
             data_to_send = self.h2_state.data_to_send()
             await self.writer.write(data_to_send, timeout)
@@ -108,6 +105,26 @@ class HTTP2Connection:
         data_to_send = self.h2_state.data_to_send()
         await self.writer.write(data_to_send, timeout)
 
+    async def receive_response(
+        self, stream_id: int, timeout: TimeoutConfig = None
+    ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]:
+        """
+        Read the response status and headers from the network.
+        """
+        while True:
+            event = await self.receive_event(stream_id, timeout)
+            if isinstance(event, h2.events.ResponseReceived):
+                break
+
+        status_code = 200
+        headers = []
+        for k, v in event.headers:
+            if k == b":status":
+                status_code = int(v.decode("ascii", errors="ignore"))
+            elif not k.startswith(b":"):
+                headers.append((k, v))
+        return (status_code, headers)
+
     async def body_iter(
         self, stream_id: int, timeout: TimeoutConfig = None
     ) -> typing.AsyncIterator[bytes]:
diff --git a/http3/interfaces.py b/http3/interfaces.py
index 13d118cfcdc9a3f11f4ce64e198b9f29fefda3aa..231263978381d1d468dc3c1e8040fef74557ba3b 100644 (file)
@@ -207,3 +207,21 @@ class ConcurrencyBackend:
                 yield self.run(async_iterator.__anext__)
             except StopAsyncIteration:
                 break
+
+    def background_manager(
+        self, coroutine: typing.Callable, args: typing.Any
+    ) -> "BaseBackgroundManager":
+        raise NotImplementedError()  # pragma: no cover
+
+
+class BaseBackgroundManager:
+    async def __aenter__(self) -> "BaseBackgroundManager":
+        raise NotImplementedError()  # pragma: no cover
+
+    async def __aexit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        raise NotImplementedError()  # pragma: no cover
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
index 0ab2c83fb9cddaa884876f3f760021475bd710a4..19642414f006ecb86ffd348e95539bb42e09a02c 100644 (file)
@@ -147,6 +147,6 @@ def test_delete(server):
 def test_base_url(server):
     base_url = "http://127.0.0.1:8000/"
     with http3.Client(base_url=base_url) as http:
-        response = http.get('/')
+        response = http.get("/")
     assert response.status_code == 200
     assert str(response.url) == base_url
diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py
index 5a0203e135b58104d1a7ab0e6aa1f65ebdcd2c67..cdb7c03161dfb7df3f2b7266946206f096074c15 100644 (file)
@@ -1,3 +1,4 @@
+import asyncio
 import ssl
 import typing
 
@@ -45,6 +46,7 @@ class MockHTTP2Server(BaseReader, BaseWriter):
     # BaseReader interface
 
     async def read(self, n, timeout) -> bytes:
+        await asyncio.sleep(0)
         send, self.buffer = self.buffer[:n], self.buffer[n:]
         return send