]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Add OpenConnection base class (#616)
authorFlorimond Manca <florimond.manca@gmail.com>
Mon, 9 Dec 2019 10:34:02 +0000 (11:34 +0100)
committerGitHub <noreply@github.com>
Mon, 9 Dec 2019 10:34:02 +0000 (11:34 +0100)
* Add OpenConnection base class

* Move is_http2 property around

httpx/dispatch/base.py
httpx/dispatch/connection.py
httpx/dispatch/http11.py
httpx/dispatch/http2.py

index cf09b07924a56d5265d55c730d6e09f251da9106..b5a4da56fac497eeb0d9af95c21bd75d8c6b15a2 100644 (file)
@@ -58,3 +58,26 @@ class Dispatcher:
         traceback: TracebackType = None,
     ) -> None:
         await self.close()
+
+
+class OpenConnection:
+    """
+    Base class for connection classes that interact with a host via HTTP.
+    """
+
+    @property
+    def is_http2(self) -> bool:
+        raise NotImplementedError()  # pragma: no cover
+
+    async def send(self, request: Request, timeout: Timeout = None,) -> Response:
+        raise NotImplementedError()  # pragma: no cover
+
+    @property
+    def is_closed(self) -> bool:
+        raise NotImplementedError()  # pragma: no cover
+
+    def is_connection_dropped(self) -> bool:
+        raise NotImplementedError()  # pragma: no cover
+
+    async def close(self) -> None:
+        raise NotImplementedError()  # pragma: no cover
index 6d7790a4f1813cab6761199f2c510ffa87bd13ef..009895af2fde1d7b6036a878388c35f54937e4d0 100644 (file)
@@ -4,11 +4,11 @@ import typing
 
 import h11
 
-from ..concurrency.base import ConcurrencyBackend, lookup_backend
+from ..concurrency.base import BaseSocketStream, ConcurrencyBackend, lookup_backend
 from ..config import CertTypes, SSLConfig, Timeout, VerifyTypes
 from ..models import URL, Origin, Request, Response
 from ..utils import get_logger
-from .base import Dispatcher
+from .base import Dispatcher, OpenConnection
 from .http2 import HTTP2Connection
 from .http11 import HTTP11Connection
 
@@ -37,8 +37,7 @@ class HTTPConnection(Dispatcher):
         self.backend = lookup_backend(backend)
         self.release_func = release_func
         self.uds = uds
-        self.h11_connection = None  # type: typing.Optional[HTTP11Connection]
-        self.h2_connection = None  # type: typing.Optional[HTTP2Connection]
+        self.open_connection: typing.Optional[OpenConnection] = None
 
     async def send(
         self,
@@ -49,14 +48,11 @@ class HTTPConnection(Dispatcher):
     ) -> Response:
         timeout = Timeout() if timeout is None else timeout
 
-        if self.h11_connection is None and self.h2_connection is None:
+        if self.open_connection is None:
             await self.connect(verify=verify, cert=cert, timeout=timeout)
 
-        if self.h2_connection is not None:
-            response = await self.h2_connection.send(request, timeout=timeout)
-        else:
-            assert self.h11_connection is not None
-            response = await self.h11_connection.send(request, timeout=timeout)
+        assert self.open_connection is not None
+        response = await self.open_connection.send(request, timeout=timeout)
 
         return response
 
@@ -92,13 +88,7 @@ class HTTPConnection(Dispatcher):
         http_version = stream.get_http_version()
         logger.trace(f"connected http_version={http_version!r}")
 
-        if http_version == "HTTP/2":
-            self.h2_connection = HTTP2Connection(
-                stream, backend=self.backend, on_release=on_release
-            )
-        else:
-            assert http_version == "HTTP/1.1"
-            self.h11_connection = HTTP11Connection(stream, on_release=on_release)
+        self.set_open_connection(http_version, socket=stream, on_release=on_release)
 
     async def tunnel_start_tls(
         self,
@@ -115,7 +105,8 @@ class HTTPConnection(Dispatcher):
 
         # First, check that we are in the correct state to start TLS, i.e. we've
         # just agreed to switch protocols with the server via HTTP/1.1.
-        h11_connection = self.h11_connection
+        assert isinstance(self.open_connection, HTTP11Connection)
+        h11_connection = self.open_connection
         assert h11_connection is not None
         assert h11_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
 
@@ -150,13 +141,21 @@ class HTTPConnection(Dispatcher):
             # HTTP request. Don't try to upgrade to TLS in this case.
             pass
 
+        self.set_open_connection(http_version, socket=socket, on_release=on_release)
+
+    def set_open_connection(
+        self,
+        http_version: str,
+        socket: BaseSocketStream,
+        on_release: typing.Optional[typing.Callable],
+    ) -> None:
         if http_version == "HTTP/2":
-            self.h2_connection = HTTP2Connection(
+            self.open_connection = HTTP2Connection(
                 socket, self.backend, on_release=on_release
             )
         else:
             assert http_version == "HTTP/1.1"
-            self.h11_connection = HTTP11Connection(socket, on_release=on_release)
+            self.open_connection = HTTP11Connection(socket, on_release=on_release)
 
     async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
         if not self.origin.is_ssl:
@@ -167,29 +166,22 @@ class HTTPConnection(Dispatcher):
 
     async def close(self) -> None:
         logger.trace("close_connection")
-        if self.h2_connection is not None:
-            await self.h2_connection.close()
-        elif self.h11_connection is not None:
-            await self.h11_connection.close()
+        if self.open_connection is not None:
+            await self.open_connection.close()
 
     @property
     def is_http2(self) -> bool:
-        return self.h2_connection is not None
+        assert self.open_connection is not None
+        return self.open_connection.is_http2
 
     @property
     def is_closed(self) -> bool:
-        if self.h2_connection is not None:
-            return self.h2_connection.is_closed
-        else:
-            assert self.h11_connection is not None
-            return self.h11_connection.is_closed
+        assert self.open_connection is not None
+        return self.open_connection.is_closed
 
     def is_connection_dropped(self) -> bool:
-        if self.h2_connection is not None:
-            return self.h2_connection.is_connection_dropped()
-        else:
-            assert self.h11_connection is not None
-            return self.h11_connection.is_connection_dropped()
+        assert self.open_connection is not None
+        return self.open_connection.is_connection_dropped()
 
     def __repr__(self) -> str:
         class_name = self.__class__.__name__
index 903e7d147119154c4e167a192572997b5aadc1e8..4367b13f9db5335fab73ad3d7adcb52849400865 100644 (file)
@@ -7,6 +7,7 @@ from ..config import Timeout
 from ..exceptions import ConnectionClosed, ProtocolError
 from ..models import Request, Response
 from ..utils import get_logger
+from .base import OpenConnection
 
 H11Event = typing.Union[
     h11.Request,
@@ -27,7 +28,7 @@ OnReleaseCallback = typing.Callable[[], typing.Awaitable[None]]
 logger = get_logger(__name__)
 
 
-class HTTP11Connection:
+class HTTP11Connection(OpenConnection):
     READ_NUM_BYTES = 4096
 
     def __init__(
@@ -39,6 +40,10 @@ class HTTP11Connection:
         self.on_release = on_release
         self.h11_state = h11.Connection(our_role=h11.CLIENT)
 
+    @property
+    def is_http2(self) -> bool:
+        return False
+
     async def send(self, request: Request, timeout: Timeout = None) -> Response:
         timeout = Timeout() if timeout is None else timeout
 
index 6021483feb85a2ea3d2d128d8adbbaf53b070b5f..3439b6135bfdc41a9d29daa07cf7d5a04a12beea 100644 (file)
@@ -14,11 +14,12 @@ from ..config import Timeout
 from ..exceptions import ProtocolError
 from ..models import Request, Response
 from ..utils import get_logger
+from .base import OpenConnection
 
 logger = get_logger(__name__)
 
 
-class HTTP2Connection:
+class HTTP2Connection(OpenConnection):
     READ_NUM_BYTES = 4096
 
     def __init__(
@@ -37,6 +38,10 @@ class HTTP2Connection:
 
         self.init_started = False
 
+    @property
+    def is_http2(self) -> bool:
+        return True
+
     @property
     def init_complete(self) -> BaseEvent:
         # We do this lazily, to make sure backend autodetection always