]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Unify BaseReader and BaseWriter as BaseStream
authorflorimondmanca <florimond.manca@gmail.com>
Tue, 20 Aug 2019 20:53:47 +0000 (22:53 +0200)
committerflorimondmanca <florimond.manca@gmail.com>
Wed, 21 Aug 2019 07:02:54 +0000 (09:02 +0200)
httpx/__init__.py
httpx/concurrency/asyncio.py
httpx/concurrency/base.py
httpx/dispatch/connection.py
httpx/dispatch/http11.py
httpx/dispatch/http2.py
tests/dispatch/utils.py

index ebacaad87ebe4fc2f1c98b36e22d3d0e246bdd82..8b2dda0fcaeeeb504fbf3eca88ceeb0298b2f2b4 100644 (file)
@@ -5,8 +5,7 @@ from .concurrency.asyncio import AsyncioBackend
 from .concurrency.base import (
     BaseBackgroundManager,
     BasePoolSemaphore,
-    BaseReader,
-    BaseWriter,
+    BaseStream,
     ConcurrencyBackend,
 )
 from .config import (
@@ -105,8 +104,7 @@ __all__ = [
     "TooManyRedirects",
     "WriteTimeout",
     "AsyncDispatcher",
-    "BaseReader",
-    "BaseWriter",
+    "BaseStream",
     "ConcurrencyBackend",
     "Dispatcher",
     "URL",
index 8fa19e9fdc4e94b437c6a797e819aa987a519cae..38449852539ff27f6d945b61b0a9526930956d42 100644 (file)
@@ -1,5 +1,5 @@
 """
-The `Reader` and `Writer` classes here provide a lightweight layer over
+The `Stream` class here provides a lightweight layer over
 `asyncio.StreamReader` and `asyncio.StreamWriter`.
 
 Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
@@ -14,18 +14,17 @@ import ssl
 import typing
 from types import TracebackType
 
+from ..config import PoolLimits, TimeoutConfig
+from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
 from .base import (
     BaseBackgroundManager,
     BasePoolSemaphore,
     BaseEvent,
     BaseQueue,
-    BaseReader,
-    BaseWriter,
+    BaseStream,
     ConcurrencyBackend,
     TimeoutFlag,
 )
-from ..config import PoolLimits, TimeoutConfig
-from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
 
 SSL_MONKEY_PATCH_APPLIED = False
 
@@ -51,11 +50,15 @@ def ssl_monkey_patch() -> None:
     MonkeyPatch.write = _fixed_write
 
 
-class Reader(BaseReader):
+class Stream(BaseStream):
     def __init__(
-        self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
-    ) -> None:
+        self,
+        stream_reader: asyncio.StreamReader,
+        stream_writer: asyncio.StreamWriter,
+        timeout: TimeoutConfig,
+    ):
         self.stream_reader = stream_reader
+        self.stream_writer = stream_writer
         self.timeout = timeout
 
     async def read(
@@ -78,15 +81,6 @@ class Reader(BaseReader):
 
         return data
 
-    def is_connection_dropped(self) -> bool:
-        return self.stream_reader.at_eof()
-
-
-class Writer(BaseWriter):
-    def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
-        self.stream_writer = stream_writer
-        self.timeout = timeout
-
     def write_no_block(self, data: bytes) -> None:
         self.stream_writer.write(data)  # pragma: nocover
 
@@ -114,6 +108,9 @@ class Writer(BaseWriter):
                 if should_raise:
                     raise WriteTimeout() from None
 
+    def is_connection_dropped(self) -> bool:
+        return self.stream_reader.at_eof()
+
     async def close(self) -> None:
         self.stream_writer.close()
 
@@ -172,7 +169,7 @@ class AsyncioBackend(ConcurrencyBackend):
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: TimeoutConfig,
-    ) -> typing.Tuple[BaseReader, BaseWriter, str]:
+    ) -> typing.Tuple[BaseStream, str]:
         try:
             stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
                 asyncio.open_connection(hostname, port, ssl=ssl_context),
@@ -189,11 +186,12 @@ class AsyncioBackend(ConcurrencyBackend):
             if ident is None:
                 ident = ssl_object.selected_npn_protocol()
 
-        reader = Reader(stream_reader=stream_reader, timeout=timeout)
-        writer = Writer(stream_writer=stream_writer, timeout=timeout)
+        stream = Stream(
+            stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
+        )
         http_version = "HTTP/2" if ident == "h2" else "HTTP/1.1"
 
-        return reader, writer, http_version
+        return stream, http_version
 
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
index 077961d20757ce218c3067ae5c2fd1f2b6874bcd..95a07b9039145a851493ec9edce88d10ed041faf 100644 (file)
@@ -37,11 +37,11 @@ class TimeoutFlag:
         self.raise_on_write_timeout = True
 
 
-class BaseReader:
+class BaseStream:
     """
-    A stream reader. Abstracts away any asyncio-specific interfaces
-    into a more generic base class, that we can use with alternate
-    backend, or for stand-alone test cases.
+    A stream with read/write operations. Abstracts away any asyncio-specific
+    interfaces into a more generic base class, that we can use with alternate
+    backends, or for stand-alone test cases.
     """
 
     async def read(
@@ -49,17 +49,6 @@ class BaseReader:
     ) -> bytes:
         raise NotImplementedError()  # pragma: no cover
 
-    def is_connection_dropped(self) -> bool:
-        raise NotImplementedError()  # pragma: no cover
-
-
-class BaseWriter:
-    """
-    A stream writer. Abstracts away any asyncio-specific interfaces
-    into a more generic base class, that we can use with alternate
-    backend, or for stand-alone test cases.
-    """
-
     def write_no_block(self, data: bytes) -> None:
         raise NotImplementedError()  # pragma: no cover
 
@@ -69,6 +58,9 @@ class BaseWriter:
     async def close(self) -> None:
         raise NotImplementedError()  # pragma: no cover
 
+    def is_connection_dropped(self) -> bool:
+        raise NotImplementedError()  # pragma: no cover
+
 
 class BaseQueue:
     """
@@ -118,7 +110,7 @@ class ConcurrencyBackend:
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: TimeoutConfig,
-    ) -> typing.Tuple[BaseReader, BaseWriter, str]:
+    ) -> typing.Tuple[BaseStream, str]:
         raise NotImplementedError()  # pragma: no cover
 
     def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
index 36f96f8065ee979b4043dc85a923f8c858a352b3..7f0d14eeb20db5357396e49eb9188b5ab4c5e8f1 100644 (file)
@@ -79,17 +79,17 @@ class HTTPConnection(AsyncDispatcher):
         else:
             on_release = functools.partial(self.release_func, self)
 
-        reader, writer, http_version = await self.backend.connect(
+        stream, http_version = await self.backend.connect(
             host, port, ssl_context, timeout
         )
         if http_version == "HTTP/2":
             self.h2_connection = HTTP2Connection(
-                reader, writer, self.backend, on_release=on_release
+                stream, self.backend, on_release=on_release
             )
         else:
             assert http_version == "HTTP/1.1"
             self.h11_connection = HTTP11Connection(
-                reader, writer, self.backend, on_release=on_release
+                stream, self.backend, on_release=on_release
             )
 
     async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
index 554591f8040b8f16fa91af1d2ad90000bb38eefc..236c81e2bbe68560c2675526d280f88463ab1efc 100644 (file)
@@ -2,7 +2,7 @@ import typing
 
 import h11
 
-from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
+from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
 from ..config import TimeoutConfig, TimeoutTypes
 from ..models import AsyncRequest, AsyncResponse
 
@@ -27,13 +27,11 @@ class HTTP11Connection:
 
     def __init__(
         self,
-        reader: BaseReader,
-        writer: BaseWriter,
+        stream: BaseStream,
         backend: ConcurrencyBackend,
         on_release: typing.Optional[OnReleaseCallback] = None,
     ):
-        self.reader = reader
-        self.writer = writer
+        self.stream = stream
         self.backend = backend
         self.on_release = on_release
         self.h11_state = h11.Connection(our_role=h11.CLIENT)
@@ -67,7 +65,7 @@ class HTTP11Connection:
         except h11.LocalProtocolError:  # pragma: no cover
             # Premature client disconnect
             pass
-        await self.writer.close()
+        await self.stream.close()
 
     async def _send_request(
         self, request: AsyncRequest, timeout: TimeoutConfig = None
@@ -111,7 +109,7 @@ class HTTP11Connection:
         drain before returning.
         """
         bytes_to_send = self.h11_state.send(event)
-        await self.writer.write(bytes_to_send, timeout)
+        await self.stream.write(bytes_to_send, timeout)
 
     async def _receive_response(
         self, timeout: TimeoutConfig = None
@@ -154,7 +152,7 @@ class HTTP11Connection:
             event = self.h11_state.next_event()
             if event is h11.NEED_DATA:
                 try:
-                    data = await self.reader.read(
+                    data = await self.stream.read(
                         self.READ_NUM_BYTES, timeout, flag=self.timeout_flag
                     )
                 except OSError:  # pragma: nocover
@@ -184,4 +182,4 @@ class HTTP11Connection:
         return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)
 
     def is_connection_dropped(self) -> bool:
-        return self.reader.is_connection_dropped()
+        return self.stream.is_connection_dropped()
index bf258e3a7a4e54c52dd4ec0fbdf81eb2cf50c609..0a698f35f4133a716ed83a40a8621b4fc1169af3 100644 (file)
@@ -4,7 +4,7 @@ import typing
 import h2.connection
 import h2.events
 
-from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
+from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
 from ..config import TimeoutConfig, TimeoutTypes
 from ..models import AsyncRequest, AsyncResponse
 
@@ -14,13 +14,11 @@ class HTTP2Connection:
 
     def __init__(
         self,
-        reader: BaseReader,
-        writer: BaseWriter,
+        stream: BaseStream,
         backend: ConcurrencyBackend,
         on_release: typing.Callable = None,
     ):
-        self.reader = reader
-        self.writer = writer
+        self.stream = stream
         self.backend = backend
         self.on_release = on_release
         self.h2_state = h2.connection.H2Connection()
@@ -58,12 +56,12 @@ class HTTP2Connection:
         )
 
     async def close(self) -> None:
-        await self.writer.close()
+        await self.stream.close()
 
     def initiate_connection(self) -> None:
         self.h2_state.initiate_connection()
         data_to_send = self.h2_state.data_to_send()
-        self.writer.write_no_block(data_to_send)
+        self.stream.write_no_block(data_to_send)
         self.initialized = True
 
     async def send_headers(
@@ -78,7 +76,7 @@ class HTTP2Connection:
         ] + [(k, v) for k, v in request.headers.raw if k != b"host"]
         self.h2_state.send_headers(stream_id, headers)
         data_to_send = self.h2_state.data_to_send()
-        await self.writer.write(data_to_send, timeout)
+        await self.stream.write(data_to_send, timeout)
         return stream_id
 
     async def send_request_data(
@@ -104,12 +102,12 @@ class HTTP2Connection:
             chunk = data[idx : idx + chunk_size]
             self.h2_state.send_data(stream_id, chunk)
             data_to_send = self.h2_state.data_to_send()
-            await self.writer.write(data_to_send, timeout)
+            await self.stream.write(data_to_send, timeout)
 
     async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None:
         self.h2_state.end_stream(stream_id)
         data_to_send = self.h2_state.data_to_send()
-        await self.writer.write(data_to_send, timeout)
+        await self.stream.write(data_to_send, timeout)
 
     async def receive_response(
         self, stream_id: int, timeout: TimeoutConfig = None
@@ -150,14 +148,14 @@ class HTTP2Connection:
     ) -> h2.events.Event:
         while not self.events[stream_id]:
             flag = self.timeout_flags[stream_id]
-            data = await self.reader.read(self.READ_NUM_BYTES, timeout, flag=flag)
+            data = await self.stream.read(self.READ_NUM_BYTES, timeout, flag=flag)
             events = self.h2_state.receive_data(data)
             for event in events:
                 if getattr(event, "stream_id", 0):
                     self.events[event.stream_id].append(event)
 
             data_to_send = self.h2_state.data_to_send()
-            await self.writer.write(data_to_send, timeout)
+            await self.stream.write(data_to_send, timeout)
 
         return self.events[stream_id].pop(0)
 
@@ -173,4 +171,4 @@ class HTTP2Connection:
         return False
 
     def is_connection_dropped(self) -> bool:
-        return self.reader.is_connection_dropped()
+        return self.stream.is_connection_dropped()
index b5aac850378292812d1fba508ec8ba371cbe2cc5..6e76269c044139ba71bd7078faa080a865cd00b0 100644 (file)
@@ -6,7 +6,7 @@ import h2.config
 import h2.connection
 import h2.events
 
-from httpx import AsyncioBackend, BaseReader, BaseWriter, Request, TimeoutConfig
+from httpx import AsyncioBackend, BaseStream, Request, TimeoutConfig
 
 
 class MockHTTP2Backend(AsyncioBackend):
@@ -20,16 +20,12 @@ class MockHTTP2Backend(AsyncioBackend):
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: TimeoutConfig,
-    ) -> typing.Tuple[BaseReader, BaseWriter, str]:
+    ) -> typing.Tuple[BaseStream, str]:
         self.server = MockHTTP2Server(self.app)
-        return self.server, self.server, "HTTP/2"
+        return self.server, "HTTP/2"
 
 
-class MockHTTP2Server(BaseReader, BaseWriter):
-    """
-    This class exposes Reader and Writer style interfaces.
-    """
-
+class MockHTTP2Server(BaseStream):
     def __init__(self, app):
         config = h2.config.H2Configuration(client_side=False)
         self.conn = h2.connection.H2Connection(config=config)
@@ -38,15 +34,13 @@ class MockHTTP2Server(BaseReader, BaseWriter):
         self.requests = {}
         self.close_connection = False
 
-    # BaseReader interface
+    # Stream interface
 
     async def read(self, n, timeout, flag=None) -> bytes:
         await asyncio.sleep(0)
         send, self.buffer = self.buffer[:n], self.buffer[n:]
         return send
 
-    # BaseWriter interface
-
     def write_no_block(self, data: bytes) -> None:
         events = self.conn.receive_data(data)
         self.buffer += self.conn.data_to_send()