import h11
-from ..backends.base import BaseSocketStream, ConcurrencyBackend, lookup_backend
+from ..backends.base import ConcurrencyBackend, lookup_backend
from ..config import CertTypes, SSLConfig, Timeout, VerifyTypes
from ..models import URL, Origin, Request, Response
from ..utils import get_logger
-from .base import Dispatcher, OpenConnection
+from .base import Dispatcher
from .http2 import HTTP2Connection
from .http11 import HTTP11Connection
self.backend = lookup_backend(backend)
self.release_func = release_func
self.uds = uds
- self.open_connection: typing.Optional[OpenConnection] = None
+ self.connection: typing.Union[None, HTTP11Connection, HTTP2Connection] = None
self.expires_at: typing.Optional[float] = None
async def send(self, request: Request, timeout: Timeout = None) -> Response:
timeout = Timeout() if timeout is None else timeout
- if self.open_connection is None:
- await self.connect(timeout=timeout)
+ if self.connection is None:
+ self.connection = await self.connect(timeout=timeout)
- assert self.open_connection is not None
- response = await self.open_connection.send(request, timeout=timeout)
+ return await self.connection.send(request, timeout=timeout)
- return response
-
- async def connect(self, timeout: Timeout) -> None:
+ async def connect(
+ self, timeout: Timeout
+ ) -> typing.Union[HTTP11Connection, HTTP2Connection]:
host = self.origin.host
port = self.origin.port
ssl_context = await self.get_ssl_context(self.ssl)
logger.trace(
f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}"
)
- stream = await self.backend.open_tcp_stream(
+ socket = await self.backend.open_tcp_stream(
host, port, ssl_context, timeout
)
else:
logger.trace(
f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}"
)
- stream = await self.backend.open_uds_stream(
+ socket = await self.backend.open_uds_stream(
self.uds, host, ssl_context, timeout
)
- http_version = stream.get_http_version()
+ http_version = socket.get_http_version()
logger.trace(f"connected http_version={http_version!r}")
- self.set_open_connection(http_version, socket=stream, on_release=on_release)
+ if http_version == "HTTP/2":
+ return HTTP2Connection(socket, self.backend, on_release=on_release)
+ return HTTP11Connection(socket, on_release=on_release)
async def tunnel_start_tls(
self, origin: Origin, proxy_url: URL, timeout: Timeout = None,
# First, check that we are in the correct state to start TLS, i.e. we've
# just agreed to switch protocols with the server via HTTP/1.1.
- assert isinstance(self.open_connection, HTTP11Connection)
- h11_connection = self.open_connection
+ assert isinstance(self.connection, HTTP11Connection)
+ h11_connection = self.connection
assert h11_connection is not None
assert h11_connection.h11_state.our_state == h11.SWITCHED_PROTOCOL
# HTTP request. Don't try to upgrade to TLS in this case.
pass
- self.set_open_connection(http_version, socket=socket, on_release=on_release)
-
- def set_open_connection(
- self,
- http_version: str,
- socket: BaseSocketStream,
- on_release: typing.Optional[typing.Callable],
- ) -> None:
if http_version == "HTTP/2":
- self.open_connection = HTTP2Connection(
+ self.connection = HTTP2Connection(
socket, self.backend, on_release=on_release
)
else:
- assert http_version == "HTTP/1.1"
- self.open_connection = HTTP11Connection(socket, on_release=on_release)
+ self.connection = HTTP11Connection(socket, on_release=on_release)
async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
if not self.origin.is_ssl:
async def close(self) -> None:
logger.trace("close_connection")
- if self.open_connection is not None:
- await self.open_connection.close()
+ if self.connection is not None:
+ await self.connection.close()
@property
def is_http2(self) -> bool:
- assert self.open_connection is not None
- return self.open_connection.is_http2
+ return self.connection is not None and self.connection.is_http2
@property
def is_closed(self) -> bool:
- assert self.open_connection is not None
- return self.open_connection.is_closed
+ return self.connection is not None and self.connection.is_closed
def is_connection_dropped(self) -> bool:
- assert self.open_connection is not None
- return self.open_connection.is_connection_dropped()
+ return self.connection is not None and self.connection.is_connection_dropped()
def __repr__(self) -> str:
class_name = self.__class__.__name__