from ..config import Timeout
from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
-from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
+from .base import BaseLock, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
SSL_MONKEY_PATCH_APPLIED = False
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
return Semaphore(max_value, exc_class)
- def create_event(self) -> BaseEvent:
- return Event()
+ def create_lock(self) -> BaseLock:
+ return Lock()
-class Event(BaseEvent):
+class Lock(BaseLock):
def __init__(self) -> None:
- self._event = asyncio.Event()
+ self._lock = asyncio.Lock()
- def set(self) -> None:
- self._event.set()
+ def release(self) -> None:
+ self._lock.release()
- async def wait(self) -> None:
- await self._event.wait()
+ async def acquire(self) -> None:
+ await self._lock.acquire()
class Semaphore(BaseSemaphore):
from ..config import Timeout
from .base import (
- BaseEvent,
+ BaseLock,
BaseSemaphore,
BaseSocketStream,
ConcurrencyBackend,
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
return self.backend.create_semaphore(max_value, exc_class)
- def create_event(self) -> BaseEvent:
- return self.backend.create_event()
+ def create_lock(self) -> BaseLock:
+ return self.backend.create_lock()
import ssl
import typing
+from types import TracebackType
from ..config import Timeout
raise NotImplementedError() # pragma: no cover
-class BaseEvent:
+class BaseLock:
"""
- An abstract interface for Event classes.
+ An abstract interface for Lock classes.
Abstracts away any asyncio-specific interfaces.
"""
- def set(self) -> None:
+ async def __aenter__(self) -> None:
+ await self.acquire()
+
+ async def __aexit__(
+ self,
+ exc_type: typing.Type[BaseException] = None,
+ exc_value: BaseException = None,
+ traceback: TracebackType = None,
+ ) -> None:
+ self.release()
+
+ def release(self) -> None:
raise NotImplementedError() # pragma: no cover
- async def wait(self) -> None:
+ async def acquire(self) -> None:
raise NotImplementedError() # pragma: no cover
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
raise NotImplementedError() # pragma: no cover
- def create_event(self) -> BaseEvent:
+ def create_lock(self) -> BaseLock:
raise NotImplementedError() # pragma: no cover
from ..config import Timeout
from ..exceptions import ConnectTimeout, ReadTimeout, WriteTimeout
-from .base import BaseEvent, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
+from .base import BaseLock, BaseSemaphore, BaseSocketStream, ConcurrencyBackend
def none_as_inf(value: typing.Optional[float]) -> float:
def create_semaphore(self, max_value: int, exc_class: type) -> BaseSemaphore:
return Semaphore(max_value, exc_class)
- def create_event(self) -> BaseEvent:
- return Event()
+ def create_lock(self) -> BaseLock:
+ return Lock()
class Semaphore(BaseSemaphore):
self.semaphore.release()
-class Event(BaseEvent):
+class Lock(BaseLock):
def __init__(self) -> None:
- self._event = trio.Event()
+ self._lock = trio.Lock()
- def set(self) -> None:
- self._event.set()
+ def release(self) -> None:
+ self._lock.release()
- async def wait(self) -> None:
- await self._event.wait()
+ async def acquire(self) -> None:
+ await self._lock.acquire()
from h2.settings import SettingCodes, Settings
from ..backends.base import (
- BaseEvent,
+ BaseLock,
BaseSocketStream,
ConcurrencyBackend,
lookup_backend,
self.streams = {} # type: typing.Dict[int, HTTP2Stream]
self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]]
- self.init_started = False
+ self.sent_connection_init = False
@property
def is_http2(self) -> bool:
return True
@property
- def init_complete(self) -> BaseEvent:
+ def init_lock(self) -> BaseLock:
# We do this lazily, to make sure backend autodetection always
# runs within an async context.
- if not hasattr(self, "_initialization_complete"):
- self._initialization_complete = self.backend.create_event()
- return self._initialization_complete
+ if not hasattr(self, "_initialization_lock"):
+ self._initialization_lock = self.backend.create_lock()
+ return self._initialization_lock
async def send(self, request: Request, timeout: Timeout = None) -> Response:
timeout = Timeout() if timeout is None else timeout
- if not self.init_started:
- # The very first stream is responsible for initiating the connection.
- self.init_started = True
- await self.send_connection_init(timeout)
- stream_id = self.state.get_next_available_stream_id()
- self.init_complete.set()
- else:
- # All other streams need to wait until the connection is established.
- await self.init_complete.wait()
+ async with self.init_lock:
+ if not self.sent_connection_init:
+ # The very first stream is responsible for initiating the connection.
+ await self.send_connection_init(timeout)
+ self.sent_connection_init = True
stream_id = self.state.get_next_available_stream_id()
stream = HTTP2Stream(stream_id=stream_id, connection=self)