import h11
-from ..concurrency.base import ConcurrencyBackend, lookup_backend
+from ..concurrency.base import BaseSocketStream, ConcurrencyBackend, lookup_backend
from ..config import CertTypes, SSLConfig, Timeout, VerifyTypes
from ..models import URL, Origin, Request, Response
from ..utils import get_logger
-from .base import Dispatcher
+from .base import Dispatcher, OpenConnection
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
self.backend = lookup_backend(backend)
self.release_func = release_func
self.uds = uds
- self.h11_connection = None # type: typing.Optional[HTTP11Connection]
- self.h2_connection = None # type: typing.Optional[HTTP2Connection]
+ self.open_connection: typing.Optional[OpenConnection] = None
async def send(
self,
) -> Response:
timeout = Timeout() if timeout is None else timeout
- if self.h11_connection is None and self.h2_connection is None:
+ if self.open_connection is None:
await self.connect(verify=verify, cert=cert, timeout=timeout)
- if self.h2_connection is not None:
- response = await self.h2_connection.send(request, timeout=timeout)
- else:
- assert self.h11_connection is not None
- response = await self.h11_connection.send(request, timeout=timeout)
+ assert self.open_connection is not None
+ response = await self.open_connection.send(request, timeout=timeout)
return response
http_version = stream.get_http_version()
logger.trace(f"connected http_version={http_version!r}")
- if http_version == "HTTP/2":
- self.h2_connection = HTTP2Connection(
- stream, backend=self.backend, on_release=on_release
- )
- else:
- assert http_version == "HTTP/1.1"
- self.h11_connection = HTTP11Connection(stream, on_release=on_release)
+ self.set_open_connection(http_version, socket=stream, on_release=on_release)
async def tunnel_start_tls(
self,
# First, check that we are in the correct state to start TLS, i.e. we've
# just agreed to switch protocols with the server via HTTP/1.1.
- h11_connection = self.h11_connection
+ assert isinstance(self.open_connection, HTTP11Connection)
+ h11_connection = self.open_connection
assert h11_connection is not None
assert h11_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
# HTTP request. Don't try to upgrade to TLS in this case.
pass
+ self.set_open_connection(http_version, socket=socket, on_release=on_release)
+
+ def set_open_connection(
+ self,
+ http_version: str,
+ socket: BaseSocketStream,
+ on_release: typing.Optional[typing.Callable],
+ ) -> None:
if http_version == "HTTP/2":
- self.h2_connection = HTTP2Connection(
+ self.open_connection = HTTP2Connection(
socket, self.backend, on_release=on_release
)
else:
assert http_version == "HTTP/1.1"
- self.h11_connection = HTTP11Connection(socket, on_release=on_release)
+ self.open_connection = HTTP11Connection(socket, on_release=on_release)
async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
if not self.origin.is_ssl:
async def close(self) -> None:
logger.trace("close_connection")
- if self.h2_connection is not None:
- await self.h2_connection.close()
- elif self.h11_connection is not None:
- await self.h11_connection.close()
+ if self.open_connection is not None:
+ await self.open_connection.close()
@property
def is_http2(self) -> bool:
- return self.h2_connection is not None
+ assert self.open_connection is not None
+ return self.open_connection.is_http2
@property
def is_closed(self) -> bool:
- if self.h2_connection is not None:
- return self.h2_connection.is_closed
- else:
- assert self.h11_connection is not None
- return self.h11_connection.is_closed
+ assert self.open_connection is not None
+ return self.open_connection.is_closed
def is_connection_dropped(self) -> bool:
- if self.h2_connection is not None:
- return self.h2_connection.is_connection_dropped()
- else:
- assert self.h11_connection is not None
- return self.h11_connection.is_connection_dropped()
+ assert self.open_connection is not None
+ return self.open_connection.is_connection_dropped()
def __repr__(self) -> str:
class_name = self.__class__.__name__