from .auth import BasicAuth, DigestAuth
from .client import Client
from .concurrency.asyncio import AsyncioBackend
-from .concurrency.base import (
- BaseBackgroundManager,
- BasePoolSemaphore,
- BaseSocketStream,
- ConcurrencyBackend,
-)
+from .concurrency.base import BasePoolSemaphore, BaseSocketStream, ConcurrencyBackend
from .config import (
USER_AGENT,
CertTypes,
"VerifyTypes",
"HTTPConnection",
"BasePoolSemaphore",
- "BaseBackgroundManager",
"ConnectionPool",
"HTTPProxy",
"HTTPProxyMode",
import ssl
import sys
import typing
-from types import TracebackType
from ..config import PoolLimits, Timeout
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
- BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
BaseSocketStream,
finally:
self._loop = loop
+ async def fork(
+ self,
+ coroutine1: typing.Callable,
+ args1: typing.Sequence,
+ coroutine2: typing.Callable,
+ args2: typing.Sequence,
+ ) -> None:
+ task1 = self.loop.create_task(coroutine1(*args1))
+ task2 = self.loop.create_task(coroutine2(*args2))
+
+ try:
+ await asyncio.gather(task1, task2)
+ finally:
+ pending: typing.Set[asyncio.Future[typing.Any]] # Please mypy.
+ _, pending = await asyncio.wait({task1, task2}, timeout=0)
+ for task in pending:
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return PoolSemaphore(limits)
def create_event(self) -> BaseEvent:
return typing.cast(BaseEvent, asyncio.Event())
-
- def background_manager(
- self, coroutine: typing.Callable, *args: typing.Any
- ) -> "BackgroundManager":
- return BackgroundManager(coroutine, args)
-
-
-class BackgroundManager(BaseBackgroundManager):
- def __init__(self, coroutine: typing.Callable, args: typing.Any) -> None:
- self.coroutine = coroutine
- self.args = args
-
- async def __aenter__(self) -> "BackgroundManager":
- loop = asyncio.get_event_loop()
- self.task = loop.create_task(self.coroutine(*self.args))
- return self
-
- async def __aexit__(
- self,
- exc_type: typing.Type[BaseException] = None,
- exc_value: BaseException = None,
- traceback: TracebackType = None,
- ) -> None:
- await self.task
- if exc_type is None:
- self.task.result()
from ..config import PoolLimits, Timeout
from .base import (
- BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
BaseSocketStream,
def create_event(self) -> BaseEvent:
return self.backend.create_event()
-
- def background_manager(
- self, coroutine: typing.Callable, *args: typing.Any
- ) -> BaseBackgroundManager:
- return self.backend.background_manager(coroutine, *args)
import ssl
import typing
-from types import TracebackType
from ..config import PoolLimits, Timeout
def create_event(self) -> BaseEvent:
raise NotImplementedError() # pragma: no cover
- def background_manager(
- self, coroutine: typing.Callable, *args: typing.Any
- ) -> "BaseBackgroundManager":
- raise NotImplementedError() # pragma: no cover
-
-
-class BaseBackgroundManager:
- async def __aenter__(self) -> "BaseBackgroundManager":
- raise NotImplementedError() # pragma: no cover
-
- async def __aexit__(
+ async def fork(
self,
- exc_type: typing.Type[BaseException] = None,
- exc_value: BaseException = None,
- traceback: TracebackType = None,
+ coroutine1: typing.Callable,
+ args1: typing.Sequence,
+ coroutine2: typing.Callable,
+ args2: typing.Sequence,
) -> None:
- raise NotImplementedError() # pragma: no cover
+ """
+ Run two coroutines concurrently.
+
+ This should start 'coroutine1' with '*args1' and 'coroutine2' with '*args2',
+ and wait for them to finish.
- async def close(self, exception: BaseException = None) -> None:
- if exception is None:
- await self.__aexit__(None, None, None)
- else:
- traceback = exception.__traceback__ # type: ignore
- await self.__aexit__(type(exception), exception, traceback)
+ In case one of the coroutines raises an exception, cancel the other one then
+ raise. If the other coroutine had also raised an exception, ignore it (for now).
+ """
+ raise NotImplementedError() # pragma: no cover
import functools
import ssl
import typing
-from types import TracebackType
import trio
from ..config import PoolLimits, Timeout
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
- BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
BaseSocketStream,
functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
)
+ async def fork(
+ self,
+ coroutine1: typing.Callable,
+ args1: typing.Sequence,
+ coroutine2: typing.Callable,
+ args2: typing.Sequence,
+ ) -> None:
+ try:
+ async with trio.open_nursery() as nursery:
+ nursery.start_soon(coroutine1, *args1)
+ nursery.start_soon(coroutine2, *args2)
+ except trio.MultiError as exc:
+ # NOTE: asyncio doesn't handle multi-errors yet, so we must align on its
+ # behavior here, and need to arbitrarily decide which exception to raise.
+ # We may want to add an 'httpx.MultiError', manually add support
+ # for this situation in the asyncio backend, and re-raise
+ # an 'httpx.MultiError' from trio's here.
+ raise exc.exceptions[0]
+
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return PoolSemaphore(limits)
def create_event(self) -> BaseEvent:
return Event()
- def background_manager(
- self, coroutine: typing.Callable, *args: typing.Any
- ) -> "BackgroundManager":
- return BackgroundManager(coroutine, *args)
-
class Event(BaseEvent):
def __init__(self) -> None:
# trio.Event.clear() was deprecated in Trio 0.12.
# https://github.com/python-trio/trio/issues/637
self._event = trio.Event()
-
-
-class BackgroundManager(BaseBackgroundManager):
- def __init__(self, coroutine: typing.Callable, *args: typing.Any) -> None:
- self.coroutine = coroutine
- self.args = args
- self.nursery_manager = trio.open_nursery()
- self.nursery: typing.Optional[trio.Nursery] = None
-
- async def __aenter__(self) -> "BackgroundManager":
- self.nursery = await self.nursery_manager.__aenter__()
- self.nursery.start_soon(self.coroutine, *self.args)
- return self
-
- async def __aexit__(
- self,
- exc_type: typing.Type[BaseException] = None,
- exc_value: BaseException = None,
- traceback: TracebackType = None,
- ) -> None:
- assert self.nursery is not None
- await self.nursery_manager.__aexit__(exc_type, exc_value, traceback)
self.timeout_flags[stream_id] = TimeoutFlag()
self.window_update_received[stream_id] = self.backend.create_event()
- task, args = self.send_request_data, [stream_id, request.stream(), timeout]
- async with self.backend.background_manager(task, *args):
+ status_code: typing.Optional[int] = None
+ headers: typing.Optional[list] = None
+
+ async def receive_response(stream_id: int, timeout: Timeout) -> None:
+ nonlocal status_code, headers
status_code, headers = await self.receive_response(stream_id, timeout)
+
+ await self.backend.fork(
+ self.send_request_data,
+ [stream_id, request.stream(), timeout],
+ receive_response,
+ [stream_id, timeout],
+ )
+
+ assert status_code is not None
+ assert headers is not None
+
content = self.body_iter(stream_id, timeout)
on_close = functools.partial(self.response_closed, stream_id=stream_id)
from httpx import AsyncioBackend, SSLConfig, Timeout
from httpx.concurrency.trio import TrioBackend
-from tests.concurrency import run_concurrently
+from tests.concurrency import run_concurrently, sleep
def get_asyncio_cipher(stream):
)
finally:
await stream.close()
+
+
+async def test_fork(backend):
+ ok_counter = 0
+
+ async def ok(delay: int) -> None:
+ nonlocal ok_counter
+ await sleep(backend, delay)
+ ok_counter += 1
+
+ async def fail(message: str, delay: int) -> None:
+ await sleep(backend, delay)
+ raise RuntimeError(message)
+
+ await backend.fork(ok, [0], ok, [0])
+ assert ok_counter == 2
+
+ with pytest.raises(RuntimeError, match="Oops"):
+ await backend.fork(ok, [0], fail, ["Oops", 0.01])
+
+ assert ok_counter == 3
+
+ with pytest.raises(RuntimeError, match="Oops"):
+ await backend.fork(ok, [0.01], fail, ["Oops", 0])
+
+ assert ok_counter == 3
+
+ with pytest.raises(RuntimeError, match="Oops"):
+ await backend.fork(fail, ["Oops", 0.01], ok, [0])
+
+ assert ok_counter == 4
+
+ with pytest.raises(RuntimeError, match="Oops"):
+ await backend.fork(fail, ["Oops", 0], ok, [0.01])
+
+ assert ok_counter == 4
+
+ with pytest.raises(RuntimeError, match="My bad"):
+ await backend.fork(fail, ["My bad", 0], fail, ["Oops", 0.01])
+
+ with pytest.raises(RuntimeError, match="Oops"):
+ await backend.fork(fail, ["My bad", 0.01], fail, ["Oops", 0])
+
+ # No 'match', since we can't know which will be raised first.
+ with pytest.raises(RuntimeError):
+ await backend.fork(fail, ["My bad", 0], fail, ["Oops", 0])