From: Tom Christie Date: Sun, 29 Dec 2019 16:38:54 +0000 (+0000) Subject: Switch from an Event primitive to a Lock primitive (#693) X-Git-Tag: 0.10.0~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1f9d0154df4821587d6e008fe54df54d3c02c51b;p=thirdparty%2Fhttpx.git Switch from an Event primitive to a Lock primitive (#693) --- diff --git a/httpx/backends/asyncio.py b/httpx/backends/asyncio.py index 0265f1b2..1859e96a 100644 --- a/httpx/backends/asyncio.py +++ b/httpx/backends/asyncio.py @@ -5,7 +5,7 @@ import typing from ..config import Timeout from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout -from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend +from .base import BaseLock, BaseSemaphore, BaseSocketStream, ConcurrencyBackend SSL_MONKEY_PATCH_APPLIED = False @@ -252,19 +252,19 @@ class AsyncioBackend(ConcurrencyBackend): def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore: return Semaphore(max_value, exc_class) - def create_event(self) -> BaseEvent: - return Event() + def create_lock(self) -> BaseLock: + return Lock() -class Event(BaseEvent): +class Lock(BaseLock): def __init__(self) -> None: - self._event = asyncio.Event() + self._lock = asyncio.Lock() - def set(self) -> None: - self._event.set() + def release(self) -> None: + self._lock.release() - async def wait(self) -> None: - await self._event.wait() + async def acquire(self) -> None: + await self._lock.acquire() class Semaphore(BaseSemaphore): diff --git a/httpx/backends/auto.py b/httpx/backends/auto.py index 32fcf798..cb037981 100644 --- a/httpx/backends/auto.py +++ b/httpx/backends/auto.py @@ -5,7 +5,7 @@ import sniffio from ..config import Timeout from .base import ( - BaseEvent, + BaseLock, BaseSemaphore, BaseSocketStream, ConcurrencyBackend, @@ -52,5 +52,5 @@ class AutoBackend(ConcurrencyBackend): def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore: return self.backend.create_semaphore(max_value, exc_class) - def create_event(self) -> BaseEvent: - return self.backend.create_event() + def create_lock(self) -> BaseLock: + return self.backend.create_lock() diff --git a/httpx/backends/base.py b/httpx/backends/base.py index 16c55cc2..e55b6363 100644 --- a/httpx/backends/base.py +++ b/httpx/backends/base.py @@ -1,5 +1,6 @@ import ssl import typing +from types import TracebackType from ..config import Timeout @@ -54,16 +55,27 @@ class BaseSocketStream: raise NotImplementedError() # pragma: no cover -class BaseEvent: +class BaseLock: """ - An abstract interface for Event classes. + An abstract interface for Lock classes. Abstracts away any asyncio-specific interfaces. """ - def set(self) -> None: + async def __aenter__(self) -> None: + await self.acquire() + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + self.release() + + def release(self) -> None: raise NotImplementedError() # pragma: no cover - async def wait(self) -> None: + async def acquire(self) -> None: raise NotImplementedError() # pragma: no cover @@ -115,5 +127,5 @@ class ConcurrencyBackend: def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore: raise NotImplementedError() # pragma: no cover - def create_event(self) -> BaseEvent: + def create_lock(self) -> BaseLock: raise NotImplementedError() # pragma: no cover diff --git a/httpx/backends/trio.py b/httpx/backends/trio.py index 8858ca42..7c26dae3 100644 --- a/httpx/backends/trio.py +++ b/httpx/backends/trio.py @@ -6,7 +6,7 @@ import trio from ..config import Timeout from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout -from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend +from .base import BaseLock, BaseSemaphore, BaseSocketStream, ConcurrencyBackend def none_as_inf(value: typing.Optional[float]) -> float: @@ -139,8 +139,8 @@ class TrioBackend(ConcurrencyBackend): def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore: return Semaphore(max_value, exc_class) - def create_event(self) -> BaseEvent: - return Event() + def create_lock(self) -> BaseLock: + return Lock() class Semaphore(BaseSemaphore): @@ -167,12 +167,12 @@ class Semaphore(BaseSemaphore): self.semaphore.release() -class Event(BaseEvent): +class Lock(BaseLock): def __init__(self) -> None: - self._event = trio.Event() + self._lock = trio.Lock() - def set(self) -> None: - self._event.set() + def release(self) -> None: + self._lock.release() - async def wait(self) -> None: - await self._event.wait() + async def acquire(self) -> None: + await self._lock.acquire() diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index 285625da..0c303e46 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -6,7 +6,7 @@ from h2.config import H2Configuration from h2.settings import SettingCodes, Settings from ..backends.base import ( - BaseEvent, + BaseLock, BaseSocketStream, ConcurrencyBackend, lookup_backend, @@ -39,32 +39,28 @@ class HTTP2Connection(OpenConnection): self.streams = {} # type: typing.Dict[int, HTTP2Stream] self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]] - self.init_started = False + self.sent_connection_init = False @property def is_http2(self) -> bool: return True @property - def init_complete(self) -> BaseEvent: + def init_lock(self) -> BaseLock: # We do this lazily, to make sure backend autodetection always # runs within an async context. - if not hasattr(self, "_initialization_complete"): - self._initialization_complete = self.backend.create_event() - return self._initialization_complete + if not hasattr(self, "_initialization_lock"): + self._initialization_lock = self.backend.create_lock() + return self._initialization_lock async def send(self, request: Request, timeout: Timeout = None) -> Response: timeout = Timeout() if timeout is None else timeout - if not self.init_started: - # The very first stream is responsible for initiating the connection. - self.init_started = True - await self.send_connection_init(timeout) - stream_id = self.state.get_next_available_stream_id() - self.init_complete.set() - else: - # All other streams need to wait until the connection is established. - await self.init_complete.wait() + async with self.init_lock: + if not self.sent_connection_init: + # The very first stream is responsible for initiating the connection. + await self.send_connection_init(timeout) + self.sent_connection_init = True stream_id = self.state.get_next_available_stream_id() stream = HTTP2Stream(stream_id=stream_id, connection=self)