From fdb13592c9a5f2ec9fa6b6196f6ed64f5dbb25ac Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Wed, 21 Aug 2019 09:12:01 +0200 Subject: [PATCH] Encapsulate http_version into BaseStream --- httpx/concurrency/asyncio.py | 27 ++++++++++++++------------- httpx/concurrency/base.py | 5 ++++- httpx/dispatch/connection.py | 8 ++++---- tests/dispatch/utils.py | 7 +++++-- 4 files changed, 27 insertions(+), 20 deletions(-) diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 38449852..34d68c38 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -61,6 +61,18 @@ class Stream(BaseStream): self.stream_writer = stream_writer self.timeout = timeout + def get_http_version(self) -> str: + ssl_object = self.stream_writer.get_extra_info("ssl_object") + + if ssl_object is None: + return "HTTP/1.1" + + ident = ssl_object.selected_alpn_protocol() + if ident is None: + ident = ssl_object.selected_npn_protocol() + + return "HTTP/2" if ident == "h2" else "HTTP/1.1" + async def read( self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None ) -> bytes: @@ -169,7 +181,7 @@ class AsyncioBackend(ConcurrencyBackend): port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> typing.Tuple[BaseStream, str]: + ) -> BaseStream: try: stream_reader, stream_writer = await asyncio.wait_for( # type: ignore asyncio.open_connection(hostname, port, ssl=ssl_context), @@ -178,20 +190,9 @@ class AsyncioBackend(ConcurrencyBackend): except asyncio.TimeoutError: raise ConnectTimeout() - ssl_object = stream_writer.get_extra_info("ssl_object") - if ssl_object is None: - ident = "http/1.1" - else: - ident = ssl_object.selected_alpn_protocol() - if ident is None: - ident = ssl_object.selected_npn_protocol() - - stream = Stream( + return Stream( stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout ) - http_version = "HTTP/2" if ident == "h2" else "HTTP/1.1" - - return stream, http_version async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index 95a07b90..45785df1 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -44,6 +44,9 @@ class BaseStream: backends, or for stand-alone test cases. """ + def get_http_version(self) -> str: + raise NotImplementedError() # pragma: no cover + async def read( self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None ) -> bytes: @@ -110,7 +113,7 @@ class ConcurrencyBackend: port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> typing.Tuple[BaseStream, str]: + ) -> BaseStream: raise NotImplementedError() # pragma: no cover def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index 7f0d14ee..0e9819cb 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -79,9 +79,9 @@ class HTTPConnection(AsyncDispatcher): else: on_release = functools.partial(self.release_func, self) - stream, http_version = await self.backend.connect( - host, port, ssl_context, timeout - ) + stream = await self.backend.connect(host, port, ssl_context, timeout) + http_version = stream.get_http_version() + if http_version == "HTTP/2": self.h2_connection = HTTP2Connection( stream, self.backend, on_release=on_release @@ -96,7 +96,7 @@ class HTTPConnection(AsyncDispatcher): if not self.origin.is_ssl: return None - # Run the SSL loading in a threadpool, since it may makes disk accesses. + # Run the SSL loading in a threadpool, since it may make disk accesses. return await self.backend.run_in_threadpool( ssl.load_ssl_context, self.http_versions ) diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index 6e76269c..33151357 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -20,9 +20,9 @@ class MockHTTP2Backend(AsyncioBackend): port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> typing.Tuple[BaseStream, str]: + ) -> BaseStream: self.server = MockHTTP2Server(self.app) - return self.server, "HTTP/2" + return self.server class MockHTTP2Server(BaseStream): @@ -36,6 +36,9 @@ class MockHTTP2Server(BaseStream): # Stream interface + def get_http_version(self) -> str: + return "HTTP/2" + async def read(self, n, timeout, flag=None) -> bytes: await asyncio.sleep(0) send, self.buffer = self.buffer[:n], self.buffer[n:] -- 2.47.3