From: Tom Christie Date: Thu, 5 Dec 2019 11:46:11 +0000 (+0000) Subject: Drop write_no_block from backends. (#594) X-Git-Tag: 0.9.0~11 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e56e120175f4938d45c91c66bdd9b7f9d415e9bb;p=thirdparty%2Fhttpx.git Drop write_no_block from backends. (#594) * Drop write_no_block * Drop redundant code from Trio backend --- diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index b24e3056..a7597162 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -162,9 +162,6 @@ class SocketStream(BaseSocketStream): return data - def write_no_block(self, data: bytes) -> None: - self.stream_writer.write(data) # pragma: nocover - async def write( self, data: bytes, timeout: Timeout = None, flag: TimeoutFlag = None ) -> None: diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index 06599607..f32501a3 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -77,9 +77,6 @@ class BaseSocketStream: async def read(self, n: int, timeout: Timeout, flag: typing.Any = None) -> bytes: raise NotImplementedError() # pragma: no cover - def write_no_block(self, data: bytes) -> None: - raise NotImplementedError() # pragma: no cover - async def write(self, data: bytes, timeout: Timeout) -> None: raise NotImplementedError() # pragma: no cover diff --git a/httpx/concurrency/trio.py b/httpx/concurrency/trio.py index d593950e..4af4242e 100644 --- a/httpx/concurrency/trio.py +++ b/httpx/concurrency/trio.py @@ -27,17 +27,12 @@ class SocketStream(BaseSocketStream): ) -> None: self.stream = stream self.timeout = timeout - self.write_buffer = b"" self.read_lock = trio.Lock() self.write_lock = trio.Lock() async def start_tls( self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout ) -> "SocketStream": - # Check that the write buffer is empty. We should never start a TLS stream - # while there is still pending data to write. - assert self.write_buffer == b"" - connect_timeout = _or_inf(timeout.connect_timeout) ssl_stream = trio.SSLStream( self.stream, ssl_context=ssl_context, server_hostname=hostname @@ -92,23 +87,9 @@ class SocketStream(BaseSocketStream): # See: https://github.com/encode/httpx/pull/143#issuecomment-515181778 return stream.socket.is_readable() - def write_no_block(self, data: bytes) -> None: - self.write_buffer += data # pragma: no cover - async def write( self, data: bytes, timeout: Timeout = None, flag: TimeoutFlag = None ) -> None: - if self.write_buffer: - previous_data = self.write_buffer - # Reset before recursive call, otherwise we'll go through - # this branch indefinitely. - self.write_buffer = b"" - try: - await self.write(previous_data, timeout=timeout, flag=flag) - except WriteTimeout: - self.writer_buffer = previous_data - raise - if not data: return diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index 8b73c6c6..e6646924 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -35,15 +35,29 @@ class HTTP2Connection: self.h2_state = h2.connection.H2Connection() self.events = {} # type: typing.Dict[int, typing.List[h2.events.Event]] self.timeout_flags = {} # type: typing.Dict[int, TimeoutFlag] - self.initialized = False self.window_update_received = {} # type: typing.Dict[int, BaseEvent] + self.init_started = False + + @property + def init_complete(self) -> BaseEvent: + # We do this lazily, to make sure backend autodetection always + # runs within an async context. + if not hasattr(self, "_initialization_complete"): + self._initialization_complete = self.backend.create_event() + return self._initialization_complete + async def send(self, request: Request, timeout: Timeout = None) -> Response: timeout = Timeout() if timeout is None else timeout - # Start sending the request. - if not self.initialized: - self.initiate_connection() + if not self.init_started: + # The very first stream is responsible for initiating the connection. + self.init_started = True + await self.send_connection_init(timeout) + self.init_complete.set() + else: + # All other streams need to wait until the connection is established. + await self.init_complete.wait() stream_id = await self.send_headers(request, timeout) @@ -69,7 +83,7 @@ class HTTP2Connection: async def close(self) -> None: await self.stream.close() - def initiate_connection(self) -> None: + async def send_connection_init(self, timeout: Timeout) -> None: # Need to set these manually here instead of manipulating via # __setitem__() otherwise the H2Connection will emit SettingsUpdate # frames in addition to sending the undesired defaults. @@ -94,8 +108,7 @@ class HTTP2Connection: self.h2_state.initiate_connection() data_to_send = self.h2_state.data_to_send() - self.stream.write_no_block(data_to_send) - self.initialized = True + await self.stream.write(data_to_send, timeout) async def send_headers(self, request: Request, timeout: Timeout) -> int: stream_id = self.h2_state.get_next_available_stream_id()