]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Switch from an Event primitive to a Lock primitive (#693)
authorTom Christie <tom@tomchristie.com>
Sun, 29 Dec 2019 16:38:54 +0000 (16:38 +0000)
committerGitHub <noreply@github.com>
Sun, 29 Dec 2019 16:38:54 +0000 (16:38 +0000)
httpx/backends/asyncio.py
httpx/backends/auto.py
httpx/backends/base.py
httpx/backends/trio.py
httpx/dispatch/http2.py

index 0265f1b2f1ca9f815832d86a483e5de98f371ab8..1859e96a2da75a9ff2b1c53d369b1fabc09d1abe 100644 (file)
@@ -5,7 +5,7 @@ import typing
 
 from ..config import Timeout
 from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
-from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
+from .base import BaseLock, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
 
 SSL_MONKEY_PATCH_APPLIED = False
 
@@ -252,19 +252,19 @@ class AsyncioBackend(ConcurrencyBackend):
     def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
         return Semaphore(max_value, exc_class)
 
-    def create_event(self) -> BaseEvent:
-        return Event()
+    def create_lock(self) -> BaseLock:
+        return Lock()
 
 
-class Event(BaseEvent):
+class Lock(BaseLock):
     def __init__(self) -> None:
-        self._event = asyncio.Event()
+        self._lock = asyncio.Lock()
 
-    def set(self) -> None:
-        self._event.set()
+    def release(self) -> None:
+        self._lock.release()
 
-    async def wait(self) -> None:
-        await self._event.wait()
+    async def acquire(self) -> None:
+        await self._lock.acquire()
 
 
 class Semaphore(BaseSemaphore):
index 32fcf798e587b92cae5d97291984f060c27c01e9..cb03798177ce86b34188368a226b9c46ab7c33cf 100644 (file)
@@ -5,7 +5,7 @@ import sniffio
 
 from ..config import Timeout
 from .base import (
-    BaseEvent,
+    BaseLock,
     BaseSemaphore,
     BaseSocketStream,
     ConcurrencyBackend,
@@ -52,5 +52,5 @@ class AutoBackend(ConcurrencyBackend):
     def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
         return self.backend.create_semaphore(max_value, exc_class)
 
-    def create_event(self) -> BaseEvent:
-        return self.backend.create_event()
+    def create_lock(self) -> BaseLock:
+        return self.backend.create_lock()
index 16c55cc28bde7df031ceedfbe4c43afd6d16c874..e55b6363d353f7901e2457f8e98dee004a3f5a1a 100644 (file)
@@ -1,5 +1,6 @@
 import ssl
 import typing
+from types import TracebackType
 
 from ..config import Timeout
 
@@ -54,16 +55,27 @@ class BaseSocketStream:
         raise NotImplementedError()  # pragma: no cover
 
 
-class BaseEvent:
+class BaseLock:
     """
-    An abstract interface for Event classes.
+    An abstract interface for Lock classes.
     Abstracts away any asyncio-specific interfaces.
     """
 
-    def set(self) -> None:
+    async def __aenter__(self) -> None:
+        await self.acquire()
+
+    async def __aexit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: TracebackType = None,
+    ) -> None:
+        self.release()
+
+    def release(self) -> None:
         raise NotImplementedError()  # pragma: no cover
 
-    async def wait(self) -> None:
+    async def acquire(self) -> None:
         raise NotImplementedError()  # pragma: no cover
 
 
@@ -115,5 +127,5 @@ class ConcurrencyBackend:
     def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
         raise NotImplementedError()  # pragma: no cover
 
-    def create_event(self) -> BaseEvent:
+    def create_lock(self) -> BaseLock:
         raise NotImplementedError()  # pragma: no cover
index 8858ca426fcd3d00682008792ca0e28c4216be92..7c26dae376bf6240a38afe29528a73592e7c11e2 100644 (file)
@@ -6,7 +6,7 @@ import trio
 
 from ..config import Timeout
 from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
-from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
+from .base import BaseLock, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
 
 
 def none_as_inf(value: typing.Optional[float]) -> float:
@@ -139,8 +139,8 @@ class TrioBackend(ConcurrencyBackend):
     def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
         return Semaphore(max_value, exc_class)
 
-    def create_event(self) -> BaseEvent:
-        return Event()
+    def create_lock(self) -> BaseLock:
+        return Lock()
 
 
 class Semaphore(BaseSemaphore):
@@ -167,12 +167,12 @@ class Semaphore(BaseSemaphore):
         self.semaphore.release()
 
 
-class Event(BaseEvent):
+class Lock(BaseLock):
     def __init__(self) -> None:
-        self._event = trio.Event()
+        self._lock = trio.Lock()
 
-    def set(self) -> None:
-        self._event.set()
+    def release(self) -> None:
+        self._lock.release()
 
-    async def wait(self) -> None:
-        await self._event.wait()
+    async def acquire(self) -> None:
+        await self._lock.acquire()
index 285625dad6b3679316536736a10c3832f40f5ebe..0c303e462589300bc7fccf9a75c0a064845b86f4 100644 (file)
@@ -6,7 +6,7 @@ from h2.config import H2Configuration
 from h2.settings import SettingCodes, Settings
 
 from ..backends.base import (
-    BaseEvent,
+    BaseLock,
     BaseSocketStream,
     ConcurrencyBackend,
     lookup_backend,
@@ -39,32 +39,28 @@ class HTTP2Connection(OpenConnection):
         self.streams = {}  # type: typing.Dict[int, HTTP2Stream]
         self.events = {}  # type: typing.Dict[int, typing.List[h2.events.Event]]
 
-        self.init_started = False
+        self.sent_connection_init = False
 
     @property
     def is_http2(self) -> bool:
         return True
 
     @property
-    def init_complete(self) -> BaseEvent:
+    def init_lock(self) -> BaseLock:
         # We do this lazily, to make sure backend autodetection always
         # runs within an async context.
-        if not hasattr(self, "_initialization_complete"):
-            self._initialization_complete = self.backend.create_event()
-        return self._initialization_complete
+        if not hasattr(self, "_initialization_lock"):
+            self._initialization_lock = self.backend.create_lock()
+        return self._initialization_lock
 
     async def send(self, request: Request, timeout: Timeout = None) -> Response:
         timeout = Timeout() if timeout is None else timeout
 
-        if not self.init_started:
-            # The very first stream is responsible for initiating the connection.
-            self.init_started = True
-            await self.send_connection_init(timeout)
-            stream_id = self.state.get_next_available_stream_id()
-            self.init_complete.set()
-        else:
-            # All other streams need to wait until the connection is established.
-            await self.init_complete.wait()
+        async with self.init_lock:
+            if not self.sent_connection_init:
+                # The very first stream is responsible for initiating the connection.
+                await self.send_connection_init(timeout)
+                self.sent_connection_init = True
             stream_id = self.state.get_next_available_stream_id()
 
         stream = HTTP2Stream(stream_id=stream_id, connection=self)