]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Encapsulate http_version into BaseStream 255/head
authorflorimondmanca <florimond.manca@gmail.com>
Wed, 21 Aug 2019 07:12:01 +0000 (09:12 +0200)
committerflorimondmanca <florimond.manca@gmail.com>
Wed, 21 Aug 2019 07:12:01 +0000 (09:12 +0200)
httpx/concurrency/asyncio.py
httpx/concurrency/base.py
httpx/dispatch/connection.py
tests/dispatch/utils.py

index 38449852539ff27f6d945b61b0a9526930956d42..34d68c38304b170001a3f5a1a43501e239da2497 100644 (file)
@@ -61,6 +61,18 @@ class Stream(BaseStream):
         self.stream_writer = stream_writer
         self.timeout = timeout
 
+    def get_http_version(self) -> str:
+        ssl_object = self.stream_writer.get_extra_info("ssl_object")
+
+        if ssl_object is None:
+            return "HTTP/1.1"
+
+        ident = ssl_object.selected_alpn_protocol()
+        if ident is None:
+            ident = ssl_object.selected_npn_protocol()
+
+        return "HTTP/2" if ident == "h2" else "HTTP/1.1"
+
     async def read(
         self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
     ) -> bytes:
@@ -169,7 +181,7 @@ class AsyncioBackend(ConcurrencyBackend):
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: TimeoutConfig,
-    ) -> typing.Tuple[BaseStream, str]:
+    ) -> BaseStream:
         try:
             stream_reader, stream_writer = await asyncio.wait_for(  # type: ignore
                 asyncio.open_connection(hostname, port, ssl=ssl_context),
@@ -178,20 +190,9 @@ class AsyncioBackend(ConcurrencyBackend):
         except asyncio.TimeoutError:
             raise ConnectTimeout()
 
-        ssl_object = stream_writer.get_extra_info("ssl_object")
-        if ssl_object is None:
-            ident = "http/1.1"
-        else:
-            ident = ssl_object.selected_alpn_protocol()
-            if ident is None:
-                ident = ssl_object.selected_npn_protocol()
-
-        stream = Stream(
+        return Stream(
             stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
         )
-        http_version = "HTTP/2" if ident == "h2" else "HTTP/1.1"
-
-        return stream, http_version
 
     async def run_in_threadpool(
         self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
index 95a07b9039145a851493ec9edce88d10ed041faf..45785df1c39fa31f45ab2cdd7b54e032fc045ed7 100644 (file)
@@ -44,6 +44,9 @@ class BaseStream:
     backends, or for stand-alone test cases.
     """
 
+    def get_http_version(self) -> str:
+        raise NotImplementedError()  # pragma: no cover
+
     async def read(
         self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
     ) -> bytes:
@@ -110,7 +113,7 @@ class ConcurrencyBackend:
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: TimeoutConfig,
-    ) -> typing.Tuple[BaseStream, str]:
+    ) -> BaseStream:
         raise NotImplementedError()  # pragma: no cover
 
     def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
index 7f0d14eeb20db5357396e49eb9188b5ab4c5e8f1..0e9819cb981cee7886f3b2bcc64b8774c2beb192 100644 (file)
@@ -79,9 +79,9 @@ class HTTPConnection(AsyncDispatcher):
         else:
             on_release = functools.partial(self.release_func, self)
 
-        stream, http_version = await self.backend.connect(
-            host, port, ssl_context, timeout
-        )
+        stream = await self.backend.connect(host, port, ssl_context, timeout)
+        http_version = stream.get_http_version()
+
         if http_version == "HTTP/2":
             self.h2_connection = HTTP2Connection(
                 stream, self.backend, on_release=on_release
@@ -96,7 +96,7 @@ class HTTPConnection(AsyncDispatcher):
         if not self.origin.is_ssl:
             return None
 
-        # Run the SSL loading in a threadpool, since it may makes disk accesses.
+        # Run the SSL loading in a threadpool, since it may make disk accesses.
         return await self.backend.run_in_threadpool(
             ssl.load_ssl_context, self.http_versions
         )
index 6e76269c044139ba71bd7078faa080a865cd00b0..3315135797af6b64822c1bb048ab604ea0aef75c 100644 (file)
@@ -20,9 +20,9 @@ class MockHTTP2Backend(AsyncioBackend):
         port: int,
         ssl_context: typing.Optional[ssl.SSLContext],
         timeout: TimeoutConfig,
-    ) -> typing.Tuple[BaseStream, str]:
+    ) -> BaseStream:
         self.server = MockHTTP2Server(self.app)
-        return self.server, "HTTP/2"
+        return self.server
 
 
 class MockHTTP2Server(BaseStream):
@@ -36,6 +36,9 @@ class MockHTTP2Server(BaseStream):
 
     # Stream interface
 
+    def get_http_version(self) -> str:
+        return "HTTP/2"
+
     async def read(self, n, timeout, flag=None) -> bytes:
         await asyncio.sleep(0)
         send, self.buffer = self.buffer[:n], self.buffer[n:]