]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Drop write_no_block from backends. (#594)
authorTom Christie <tom@tomchristie.com>
Thu, 5 Dec 2019 11:46:11 +0000 (11:46 +0000)
committerGitHub <noreply@github.com>
Thu, 5 Dec 2019 11:46:11 +0000 (11:46 +0000)
* Drop write_no_block

* Drop redundant code from Trio backend

httpx/concurrency/asyncio.py
httpx/concurrency/base.py
httpx/concurrency/trio.py
httpx/dispatch/http2.py

index b24e30562e8b4f02ffdc565db4159385538d72ae..a75971620b9fbada6eab7383e8c3868e3fbd7cfc 100644 (file)
@@ -162,9 +162,6 @@ class SocketStream(BaseSocketStream):
 
         return data
 
-    def write_no_block(self, data: bytes) -> None:
-        self.stream_writer.write(data)  # pragma: nocover
-
     async def write(
         self, data: bytes, timeout: Timeout = None, flag: TimeoutFlag = None
     ) -> None:
index 065996077c0e1c118848f4d79c1ede2c5488e60d..f32501a3a094f454ebfe5c2a22b3cc4060286914 100644 (file)
@@ -77,9 +77,6 @@ class BaseSocketStream:
     async def read(self, n: int, timeout: Timeout, flag: typing.Any = None) -> bytes:
         raise NotImplementedError()  # pragma: no cover
 
-    def write_no_block(self, data: bytes) -> None:
-        raise NotImplementedError()  # pragma: no cover
-
     async def write(self, data: bytes, timeout: Timeout) -> None:
         raise NotImplementedError()  # pragma: no cover
 
index d593950ecb14cb0af16b2c003c9273b294840413..4af4242e1936b81ce0c06bb0f6ec0ee018024046 100644 (file)
@@ -27,17 +27,12 @@ class SocketStream(BaseSocketStream):
     ) -> None:
         self.stream = stream
         self.timeout = timeout
-        self.write_buffer = b""
         self.read_lock = trio.Lock()
         self.write_lock = trio.Lock()
 
     async def start_tls(
         self, hostname: str, ssl_context: ssl.SSLContext, timeout: Timeout
     ) -> "SocketStream":
-        # Check that the write buffer is empty. We should never start a TLS stream
-        # while there is still pending data to write.
-        assert self.write_buffer == b""
-
         connect_timeout = _or_inf(timeout.connect_timeout)
         ssl_stream = trio.SSLStream(
             self.stream, ssl_context=ssl_context, server_hostname=hostname
@@ -92,23 +87,9 @@ class SocketStream(BaseSocketStream):
         # See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
         return stream.socket.is_readable()
 
-    def write_no_block(self, data: bytes) -> None:
-        self.write_buffer += data  # pragma: no cover
-
     async def write(
         self, data: bytes, timeout: Timeout = None, flag: TimeoutFlag = None
     ) -> None:
-        if self.write_buffer:
-            previous_data = self.write_buffer
-            # Reset before recursive call, otherwise we'll go through
-            # this branch indefinitely.
-            self.write_buffer = b""
-            try:
-                await self.write(previous_data, timeout=timeout, flag=flag)
-            except WriteTimeout:
-                self.writer_buffer = previous_data
-                raise
-
         if not data:
             return
 
index 8b73c6c66387b09ffa539da391954af97c071ef9..e664692476c8a998b82661dbe6c1a53d100a2cf3 100644 (file)
@@ -35,15 +35,29 @@ class HTTP2Connection:
         self.h2_state = h2.connection.H2Connection()
         self.events = {}  # type: typing.Dict[int, typing.List[h2.events.Event]]
         self.timeout_flags = {}  # type: typing.Dict[int, TimeoutFlag]
-        self.initialized = False
         self.window_update_received = {}  # type: typing.Dict[int, BaseEvent]
 
+        self.init_started = False
+
+    @property
+    def init_complete(self) -> BaseEvent:
+        # We do this lazily, to make sure backend autodetection always
+        # runs within an async context.
+        if not hasattr(self, "_initialization_complete"):
+            self._initialization_complete = self.backend.create_event()
+        return self._initialization_complete
+
     async def send(self, request: Request, timeout: Timeout = None) -> Response:
         timeout = Timeout() if timeout is None else timeout
 
-        # Start sending the request.
-        if not self.initialized:
-            self.initiate_connection()
+        if not self.init_started:
+            # The very first stream is responsible for initiating the connection.
+            self.init_started = True
+            await self.send_connection_init(timeout)
+            self.init_complete.set()
+        else:
+            # All other streams need to wait until the connection is established.
+            await self.init_complete.wait()
 
         stream_id = await self.send_headers(request, timeout)
 
@@ -69,7 +83,7 @@ class HTTP2Connection:
     async def close(self) -> None:
         await self.stream.close()
 
-    def initiate_connection(self) -> None:
+    async def send_connection_init(self, timeout: Timeout) -> None:
         # Need to set these manually here instead of manipulating via
         # __setitem__() otherwise the H2Connection will emit SettingsUpdate
         # frames in addition to sending the undesired defaults.
@@ -94,8 +108,7 @@ class HTTP2Connection:
 
         self.h2_state.initiate_connection()
         data_to_send = self.h2_state.data_to_send()
-        self.stream.write_no_block(data_to_send)
-        self.initialized = True
+        await self.stream.write(data_to_send, timeout)
 
     async def send_headers(self, request: Request, timeout: Timeout) -> int:
         stream_id = self.h2_state.get_next_available_stream_id()