]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Drop BackgroundManager in favor of fork(func1, func2) (#603)
authorFlorimond Manca <florimond.manca@gmail.com>
Fri, 6 Dec 2019 10:49:24 +0000 (11:49 +0100)
committerTom Christie <tom@tomchristie.com>
Fri, 6 Dec 2019 10:49:24 +0000 (10:49 +0000)
* Drop BackgroundManager in favor of fork(func1, func2)

* Please mypy

httpx/__init__.py
httpx/concurrency/asyncio.py
httpx/concurrency/auto.py
httpx/concurrency/base.py
httpx/concurrency/trio.py
httpx/dispatch/http2.py
tests/test_concurrency.py

index 686359a7d36de172316130a00b07c45f417a4a4e..b6cd6df2b73b41c04694fe78330db8c22d20a3d6 100644 (file)
@@ -3,12 +3,7 @@ from .api import delete, get, head, options, patch, post, put, request, stream
 from .auth import BasicAuth, DigestAuth
 from .client import Client
 from .concurrency.asyncio import AsyncioBackend
-from .concurrency.base import (
-    BaseBackgroundManager,
-    BasePoolSemaphore,
-    BaseSocketStream,
-    ConcurrencyBackend,
-)
+from .concurrency.base import BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
 from .config import (
     USER_AGENT,
     CertTypes,
@@ -89,7 +84,6 @@ __all__ = [
     "VerifyTypes",
     "HTTPConnection",
     "BasePoolSemaphore",
-    "BaseBackgroundManager",
     "ConnectionPool",
     "HTTPProxy",
     "HTTPProxyMode",
index a75971620b9fbada6eab7383e8c3868e3fbd7cfc..eee4ec32d5e8b68a4d2a2de46db3b5b2873e29f9 100644 (file)
@@ -3,12 +3,10 @@ import functools
 import ssl
 import sys
 import typing
-from types import TracebackType
 
 from ..config import PoolLimits, Timeout
 from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
 from .base import (
-    BaseBackgroundManager,
     BaseEvent,
     BasePoolSemaphore,
     BaseSocketStream,
@@ -317,34 +315,30 @@ class AsyncioBackend(ConcurrencyBackend):
         finally:
             self._loop = loop
 
+    async def fork(
+        self,
+        coroutine1: typing.Callable,
+        args1: typing.Sequence,
+        coroutine2: typing.Callable,
+        args2: typing.Sequence,
+    ) -> None:
+        task1 = self.loop.create_task(coroutine1(*args1))
+        task2 = self.loop.create_task(coroutine2(*args2))
+
+        try:
+            await asyncio.gather(task1, task2)
+        finally:
+            pending: typing.Set[asyncio.Future[typing.Any]]  # Please mypy.
+            _, pending = await asyncio.wait({task1, task2}, timeout=0)
+            for task in pending:
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+
     def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
         return PoolSemaphore(limits)
 
     def create_event(self) -> BaseEvent:
         return typing.cast(BaseEvent, asyncio.Event())
-
-    def background_manager(
-        self, coroutine: typing.Callable, *args: typing.Any
-    ) -> "BackgroundManager":
-        return BackgroundManager(coroutine, args)
-
-
-class BackgroundManager(BaseBackgroundManager):
-    def __init__(self, coroutine: typing.Callable, args: typing.Any) -> None:
-        self.coroutine = coroutine
-        self.args = args
-
-    async def __aenter__(self) -> "BackgroundManager":
-        loop = asyncio.get_event_loop()
-        self.task = loop.create_task(self.coroutine(*self.args))
-        return self
-
-    async def __aexit__(
-        self,
-        exc_type: typing.Type[BaseException] = None,
-        exc_value: BaseException = None,
-        traceback: TracebackType = None,
-    ) -> None:
-        await self.task
-        if exc_type is None:
-            self.task.result()
index 3dd31a8bbf6ee84444b8c0d4898792b2d7e1ea64..3b57e5674dde029d01019ad3da115722f5865ae0 100644 (file)
@@ -5,7 +5,6 @@ import sniffio
 
 from ..config import PoolLimits, Timeout
 from .base import (
-    BaseBackgroundManager,
     BaseEvent,
     BasePoolSemaphore,
     BaseSocketStream,
@@ -52,8 +51,3 @@ class AutoBackend(ConcurrencyBackend):
 
     def create_event(self) -> BaseEvent:
         return self.backend.create_event()
-
-    def background_manager(
-        self, coroutine: typing.Callable, *args: typing.Any
-    ) -> BaseBackgroundManager:
-        return self.backend.background_manager(coroutine, *args)
index f32501a3a094f454ebfe5c2a22b3cc4060286914..ff5f72f30dccf653f90c8002f3ee7682599f9ea2 100644 (file)
@@ -1,6 +1,5 @@
 import ssl
 import typing
-from types import TracebackType
 
 from ..config import PoolLimits, Timeout
 
@@ -154,27 +153,20 @@ class ConcurrencyBackend:
     def create_event(self) -> BaseEvent:
         raise NotImplementedError()  # pragma: no cover
 
-    def background_manager(
-        self, coroutine: typing.Callable, *args: typing.Any
-    ) -> "BaseBackgroundManager":
-        raise NotImplementedError()  # pragma: no cover
-
-
-class BaseBackgroundManager:
-    async def __aenter__(self) -> "BaseBackgroundManager":
-        raise NotImplementedError()  # pragma: no cover
-
-    async def __aexit__(
+    async def fork(
         self,
-        exc_type: typing.Type[BaseException] = None,
-        exc_value: BaseException = None,
-        traceback: TracebackType = None,
+        coroutine1: typing.Callable,
+        args1: typing.Sequence,
+        coroutine2: typing.Callable,
+        args2: typing.Sequence,
     ) -> None:
-        raise NotImplementedError()  # pragma: no cover
+        """
+        Run two coroutines concurrently.
+
+        This should start 'coroutine1' with '*args1' and 'coroutine2' with '*args2',
+        and wait for them to finish.
 
-    async def close(self, exception: BaseException = None) -> None:
-        if exception is None:
-            await self.__aexit__(None, None, None)
-        else:
-            traceback = exception.__traceback__  # type: ignore
-            await self.__aexit__(type(exception), exception, traceback)
+        In case one of the coroutines raises an exception, cancel the other one then
+        raise. If the other coroutine had also raised an exception, ignore it (for now).
+        """
+        raise NotImplementedError()  # pragma: no cover
index 4af4242e1936b81ce0c06bb0f6ec0ee018024046..f1fc7c4286f681c1a499fe2a07236007055d86bc 100644 (file)
@@ -1,14 +1,12 @@
 import functools
 import ssl
 import typing
-from types import TracebackType
 
 import trio
 
 from ..config import PoolLimits, Timeout
 from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
 from .base import (
-    BaseBackgroundManager,
     BaseEvent,
     BasePoolSemaphore,
     BaseSocketStream,
@@ -204,17 +202,31 @@ class TrioBackend(ConcurrencyBackend):
             functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
         )
 
+    async def fork(
+        self,
+        coroutine1: typing.Callable,
+        args1: typing.Sequence,
+        coroutine2: typing.Callable,
+        args2: typing.Sequence,
+    ) -> None:
+        try:
+            async with trio.open_nursery() as nursery:
+                nursery.start_soon(coroutine1, *args1)
+                nursery.start_soon(coroutine2, *args2)
+        except trio.MultiError as exc:
+            # NOTE: asyncio doesn't handle multi-errors yet, so we must align on its
+            # behavior here, and need to arbitrarily decide which exception to raise.
+            # We may want to add an 'httpx.MultiError', manually add support
+            # for this situation in the asyncio backend, and re-raise
+            # an 'httpx.MultiError' from trio's here.
+            raise exc.exceptions[0]
+
     def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
         return PoolSemaphore(limits)
 
     def create_event(self) -> BaseEvent:
         return Event()
 
-    def background_manager(
-        self, coroutine: typing.Callable, *args: typing.Any
-    ) -> "BackgroundManager":
-        return BackgroundManager(coroutine, *args)
-
 
 class Event(BaseEvent):
     def __init__(self) -> None:
@@ -233,25 +245,3 @@ class Event(BaseEvent):
         # trio.Event.clear() was deprecated in Trio 0.12.
         # https://github.com/python-trio/trio/issues/637
         self._event = trio.Event()
-
-
-class BackgroundManager(BaseBackgroundManager):
-    def __init__(self, coroutine: typing.Callable, *args: typing.Any) -> None:
-        self.coroutine = coroutine
-        self.args = args
-        self.nursery_manager = trio.open_nursery()
-        self.nursery: typing.Optional[trio.Nursery] = None
-
-    async def __aenter__(self) -> "BackgroundManager":
-        self.nursery = await self.nursery_manager.__aenter__()
-        self.nursery.start_soon(self.coroutine, *self.args)
-        return self
-
-    async def __aexit__(
-        self,
-        exc_type: typing.Type[BaseException] = None,
-        exc_value: BaseException = None,
-        traceback: TracebackType = None,
-    ) -> None:
-        assert self.nursery is not None
-        await self.nursery_manager.__aexit__(exc_type, exc_value, traceback)
index 09368c884e1bb459c3e8a098ee3bfc81336bbfa5..471ba9b7e012bb03f122da3fdcd51d3d71cf7bec 100644 (file)
@@ -65,9 +65,23 @@ class HTTP2Connection:
         self.timeout_flags[stream_id] = TimeoutFlag()
         self.window_update_received[stream_id] = self.backend.create_event()
 
-        task, args = self.send_request_data, [stream_id, request.stream(), timeout]
-        async with self.backend.background_manager(task, *args):
+        status_code: typing.Optional[int] = None
+        headers: typing.Optional[list] = None
+
+        async def receive_response(stream_id: int, timeout: Timeout) -> None:
+            nonlocal status_code, headers
             status_code, headers = await self.receive_response(stream_id, timeout)
+
+        await self.backend.fork(
+            self.send_request_data,
+            [stream_id, request.stream(), timeout],
+            receive_response,
+            [stream_id, timeout],
+        )
+
+        assert status_code is not None
+        assert headers is not None
+
         content = self.body_iter(stream_id, timeout)
         on_close = functools.partial(self.response_closed, stream_id=stream_id)
 
index 4fecaa6dcb9b89fc524abe42e87ebf34d1e68ddc..0898dd4bb7370c3cc1b1982d5c29f33fe559f9f8 100644 (file)
@@ -3,7 +3,7 @@ import trio
 
 from httpx import AsyncioBackend, SSLConfig, Timeout
 from httpx.concurrency.trio import TrioBackend
-from tests.concurrency import run_concurrently
+from tests.concurrency import run_concurrently, sleep
 
 
 def get_asyncio_cipher(stream):
@@ -110,3 +110,49 @@ async def test_concurrent_read(server, backend):
         )
     finally:
         await stream.close()
+
+
+async def test_fork(backend):
+    ok_counter = 0
+
+    async def ok(delay: int) -> None:
+        nonlocal ok_counter
+        await sleep(backend, delay)
+        ok_counter += 1
+
+    async def fail(message: str, delay: int) -> None:
+        await sleep(backend, delay)
+        raise RuntimeError(message)
+
+    await backend.fork(ok, [0], ok, [0])
+    assert ok_counter == 2
+
+    with pytest.raises(RuntimeError, match="Oops"):
+        await backend.fork(ok, [0], fail, ["Oops", 0.01])
+
+    assert ok_counter == 3
+
+    with pytest.raises(RuntimeError, match="Oops"):
+        await backend.fork(ok, [0.01], fail, ["Oops", 0])
+
+    assert ok_counter == 3
+
+    with pytest.raises(RuntimeError, match="Oops"):
+        await backend.fork(fail, ["Oops", 0.01], ok, [0])
+
+    assert ok_counter == 4
+
+    with pytest.raises(RuntimeError, match="Oops"):
+        await backend.fork(fail, ["Oops", 0], ok, [0.01])
+
+    assert ok_counter == 4
+
+    with pytest.raises(RuntimeError, match="My bad"):
+        await backend.fork(fail, ["My bad", 0], fail, ["Oops", 0.01])
+
+    with pytest.raises(RuntimeError, match="Oops"):
+        await backend.fork(fail, ["My bad", 0.01], fail, ["Oops", 0])
+
+    # No 'match', since we can't know which will be raised first.
+    with pytest.raises(RuntimeError):
+        await backend.fork(fail, ["My bad", 0], fail, ["Oops", 0])