]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Drop OpenConnection (#700)
authorTom Christie <tom@tomchristie.com>
Tue, 31 Dec 2019 13:18:23 +0000 (13:18 +0000)
committerGitHub <noreply@github.com>
Tue, 31 Dec 2019 13:18:23 +0000 (13:18 +0000)
httpx/dispatch/base.py
httpx/dispatch/connection.py
httpx/dispatch/http11.py
httpx/dispatch/http2.py

index 69073b01ae9ce11577c8373f40a367364e9a10a2..f24eb68afefafa163126c7f408bd493f93d56f22 100644 (file)
@@ -50,26 +50,3 @@ class Dispatcher:
         traceback: TracebackType = None,
     ) -> None:
         await self.close()
-
-
-class OpenConnection:
-    """
-    Base class for connection classes that interact with a host via HTTP.
-    """
-
-    @property
-    def is_http2(self) -> bool:
-        raise NotImplementedError()  # pragma: no cover
-
-    async def send(self, request: Request, timeout: Timeout = None,) -> Response:
-        raise NotImplementedError()  # pragma: no cover
-
-    @property
-    def is_closed(self) -> bool:
-        raise NotImplementedError()  # pragma: no cover
-
-    def is_connection_dropped(self) -> bool:
-        raise NotImplementedError()  # pragma: no cover
-
-    async def close(self) -> None:
-        raise NotImplementedError()  # pragma: no cover
index 1d0425a43dba711449a22de60982064f7c378b33..43dbd963cfc5a4045710c7082c88dcc125671537 100644 (file)
@@ -4,11 +4,11 @@ import typing
 
 import h11
 
-from ..backends.base import BaseSocketStream, ConcurrencyBackend, lookup_backend
+from ..backends.base import ConcurrencyBackend, lookup_backend
 from ..config import CertTypes, SSLConfig, Timeout, VerifyTypes
 from ..models import URL, Origin, Request, Response
 from ..utils import get_logger
-from .base import Dispatcher, OpenConnection
+from .base import Dispatcher
 from .http2 import HTTP2Connection
 from .http11 import HTTP11Connection
 
@@ -37,21 +37,20 @@ class HTTPConnection(Dispatcher):
         self.backend = lookup_backend(backend)
         self.release_func = release_func
         self.uds = uds
-        self.open_connection: typing.Optional[OpenConnection] = None
+        self.connection: typing.Union[None, HTTP11Connection, HTTP2Connection] = None
         self.expires_at: typing.Optional[float] = None
 
     async def send(self, request: Request, timeout: Timeout = None) -> Response:
         timeout = Timeout() if timeout is None else timeout
 
-        if self.open_connection is None:
-            await self.connect(timeout=timeout)
+        if self.connection is None:
+            self.connection = await self.connect(timeout=timeout)
 
-        assert self.open_connection is not None
-        response = await self.open_connection.send(request, timeout=timeout)
+        return await self.connection.send(request, timeout=timeout)
 
-        return response
-
-    async def connect(self, timeout: Timeout) -> None:
+    async def connect(
+        self, timeout: Timeout
+    ) -> typing.Union[HTTP11Connection, HTTP2Connection]:
         host = self.origin.host
         port = self.origin.port
         ssl_context = await self.get_ssl_context(self.ssl)
@@ -65,21 +64,23 @@ class HTTPConnection(Dispatcher):
             logger.trace(
                 f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}"
             )
-            stream = await self.backend.open_tcp_stream(
+            socket = await self.backend.open_tcp_stream(
                 host, port, ssl_context, timeout
             )
         else:
             logger.trace(
                 f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}"
             )
-            stream = await self.backend.open_uds_stream(
+            socket = await self.backend.open_uds_stream(
                 self.uds, host, ssl_context, timeout
             )
 
-        http_version = stream.get_http_version()
+        http_version = socket.get_http_version()
         logger.trace(f"connected http_version={http_version!r}")
 
-        self.set_open_connection(http_version, socket=stream, on_release=on_release)
+        if http_version == "HTTP/2":
+            return HTTP2Connection(socket, self.backend, on_release=on_release)
+        return HTTP11Connection(socket, on_release=on_release)
 
     async def tunnel_start_tls(
         self, origin: Origin, proxy_url: URL, timeout: Timeout = None,
@@ -91,8 +92,8 @@ class HTTPConnection(Dispatcher):
 
         # First, check that we are in the correct state to start TLS, i.e. we've
         # just agreed to switch protocols with the server via HTTP/1.1.
-        assert isinstance(self.open_connection, HTTP11Connection)
-        h11_connection = self.open_connection
+        assert isinstance(self.connection, HTTP11Connection)
+        h11_connection = self.connection
         assert h11_connection is not None
         assert h11_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
 
@@ -126,21 +127,12 @@ class HTTPConnection(Dispatcher):
             # HTTP request. Don't try to upgrade to TLS in this case.
             pass
 
-        self.set_open_connection(http_version, socket=socket, on_release=on_release)
-
-    def set_open_connection(
-        self,
-        http_version: str,
-        socket: BaseSocketStream,
-        on_release: typing.Optional[typing.Callable],
-    ) -> None:
         if http_version == "HTTP/2":
-            self.open_connection = HTTP2Connection(
+            self.connection = HTTP2Connection(
                 socket, self.backend, on_release=on_release
             )
         else:
-            assert http_version == "HTTP/1.1"
-            self.open_connection = HTTP11Connection(socket, on_release=on_release)
+            self.connection = HTTP11Connection(socket, on_release=on_release)
 
     async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
         if not self.origin.is_ssl:
@@ -151,22 +143,19 @@ class HTTPConnection(Dispatcher):
 
     async def close(self) -> None:
         logger.trace("close_connection")
-        if self.open_connection is not None:
-            await self.open_connection.close()
+        if self.connection is not None:
+            await self.connection.close()
 
     @property
     def is_http2(self) -> bool:
-        assert self.open_connection is not None
-        return self.open_connection.is_http2
+        return self.connection is not None and self.connection.is_http2
 
     @property
     def is_closed(self) -> bool:
-        assert self.open_connection is not None
-        return self.open_connection.is_closed
+        return self.connection is not None and self.connection.is_closed
 
     def is_connection_dropped(self) -> bool:
-        assert self.open_connection is not None
-        return self.open_connection.is_connection_dropped()
+        return self.connection is not None and self.connection.is_connection_dropped()
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
index 87a636985a1f0b8fa2bbfdca1b0027093e00f59e..2ed3384562ee8d14523f890a10d483269271e360 100644 (file)
@@ -8,7 +8,6 @@ from ..content_streams import AsyncIteratorStream
 from ..exceptions import ConnectionClosed, ProtocolError
 from ..models import Request, Response
 from ..utils import get_logger
-from .base import OpenConnection
 
 H11Event = typing.Union[
     h11.Request,
@@ -29,7 +28,7 @@ OnReleaseCallback = typing.Callable[[], typing.Awaitable[None]]
 logger = get_logger(__name__)
 
 
-class HTTP11Connection(OpenConnection):
+class HTTP11Connection:
     READ_NUM_BYTES = 4096
 
     def __init__(
index 0c303e462589300bc7fccf9a75c0a064845b86f4..a85d46b6ba4ace4f1738e2e1e42014b606ccd47c 100644 (file)
@@ -16,12 +16,11 @@ from ..content_streams import AsyncIteratorStream
 from ..exceptions import ProtocolError
 from ..models import Request, Response
 from ..utils import get_logger
-from .base import OpenConnection
 
 logger = get_logger(__name__)
 
 
-class HTTP2Connection(OpenConnection):
+class HTTP2Connection:
     READ_NUM_BYTES = 4096
     CONFIG = H2Configuration(validate_inbound_headers=False)