SSL Configuration.
"""
- def __init__(self, *, cert: typing.Optional[str], verify: typing.Union[str, bool]):
+ def __init__(
+ self,
+ *,
+ cert: typing.Union[None, str, typing.Tuple[str, str]] = None,
+ verify: typing.Union[str, bool] = True,
+ ):
self.cert = cert
self.verify = verify
+ def __eq__(self, other: typing.Any) -> bool:
+ return (
+ isinstance(other, self.__class__)
+ and self.cert == other.cert
+ and self.verify == other.verify
+ )
+
+ def __hash__(self) -> int:
+ as_tuple = (self.cert, self.verify)
+ return hash(as_tuple)
+
+ def __repr__(self) -> str:
+ class_name = self.__class__.__name__
+ return f"{class_name}(cert={self.cert}, verify={self.verify})"
+
class TimeoutConfig:
"""
*,
connect_timeout: float = None,
read_timeout: float = None,
- pool_timeout: float = None
+ pool_timeout: float = None,
):
if timeout is not None:
# Specified as a single timeout value
read_timeout = timeout
pool_timeout = timeout
+ self.timeout = timeout
self.connect_timeout = connect_timeout
self.read_timeout = read_timeout
self.pool_timeout = pool_timeout
+ def __eq__(self, other: typing.Any) -> bool:
+ return (
+ isinstance(other, self.__class__)
+ and self.connect_timeout == other.connect_timeout
+ and self.read_timeout == other.read_timeout
+ and self.pool_timeout == other.pool_timeout
+ )
+
+ def __hash__(self) -> int:
+ as_tuple = (self.connect_timeout, self.read_timeout, self.pool_timeout)
+ return hash(as_tuple)
+
+ def __repr__(self) -> str:
+ class_name = self.__class__.__name__
+ if self.timeout is not None:
+ return f"{class_name}(timeout={self.timeout})"
+ return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, pool_timeout={self.pool_timeout})"
+
class PoolLimits:
"""
self,
*,
soft_limit: typing.Optional[int] = None,
- hard_limit: typing.Optional[int] = None
+ hard_limit: typing.Optional[int] = None,
):
self.soft_limit = soft_limit
self.hard_limit = hard_limit
+ def __eq__(self, other: typing.Any) -> bool:
+ return (
+ isinstance(other, self.__class__)
+ and self.soft_limit == other.soft_limit
+ and self.hard_limit == other.hard_limit
+ )
+
+ def __hash__(self) -> int:
+ as_tuple = (self.soft_limit, self.hard_limit)
+ return hash(as_tuple)
+
+ def __repr__(self) -> str:
+ class_name = self.__class__.__name__
+ return (
+ f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit})"
+ )
+
DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
from .datastructures import URL, Request, Response
from .exceptions import PoolTimeout
-ConnectionKey = typing.Tuple[str, str, int] # (scheme, host, port)
+ConnectionKey = typing.Tuple[str, str, int, SSLConfig, TimeoutConfig]
class ConnectionSemaphore:
- def __init__(self, max_connections: int = None, timeout: float = None):
- self.timeout = timeout
+ def __init__(self, max_connections: int = None):
if max_connections is not None:
self.semaphore = asyncio.BoundedSemaphore(value=max_connections)
async def acquire(self) -> None:
if hasattr(self, "semaphore"):
- try:
- await asyncio.wait_for(self.semaphore.acquire(), self.timeout)
- except asyncio.TimeoutError:
- raise PoolTimeout()
+ await self.semaphore.acquire()
def release(self) -> None:
if hasattr(self, "semaphore"):
self.is_closed = False
self.num_active_connections = 0
self.num_keepalive_connections = 0
- self._connections = (
+ self._keepalive_connections = (
{}
) # type: typing.Dict[ConnectionKey, typing.List[Connection]]
- self._connection_semaphore = ConnectionSemaphore(
- max_connections=self.limits.hard_limit, timeout=self.timeout.pool_timeout
+ self._max_connections = ConnectionSemaphore(
+ max_connections=self.limits.hard_limit
)
async def request(
headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
stream: bool = False,
+ ssl: typing.Optional[SSLConfig] = None,
+ timeout: typing.Optional[TimeoutConfig] = None,
) -> Response:
+ if ssl is None:
+ ssl = self.ssl_config
+ if timeout is None:
+ timeout = self.timeout
+
parsed_url = URL(url)
request = Request(method, parsed_url, headers=headers, body=body)
- ssl_context = await self.get_ssl_context(parsed_url)
- connection = await self.acquire_connection(parsed_url, ssl=ssl_context)
+ connection = await self.acquire_connection(parsed_url, ssl=ssl, timeout=timeout)
response = await connection.send(request)
if not stream:
try:
return self.num_active_connections + self.num_keepalive_connections
async def acquire_connection(
- self, url: URL, *, ssl: typing.Optional[ssl.SSLContext] = None
+ self, url: URL, ssl: SSLConfig, timeout: TimeoutConfig
) -> Connection:
- key = (url.scheme, url.hostname, url.port)
+ key = (url.scheme, url.hostname, url.port, ssl, timeout)
try:
- connection = self._connections[key].pop()
- if not self._connections[key]:
- del self._connections[key]
+ connection = self._keepalive_connections[key].pop()
+ if not self._keepalive_connections[key]:
+ del self._keepalive_connections[key]
self.num_keepalive_connections -= 1
self.num_active_connections += 1
except (KeyError, IndexError):
- await self._connection_semaphore.acquire()
+ ssl_context = await self.get_ssl_context(url, ssl)
+ try:
+ await asyncio.wait_for(
+ self._max_connections.acquire(), timeout.pool_timeout
+ )
+ except asyncio.TimeoutError:
+ raise PoolTimeout()
release = functools.partial(self.release_connection, key=key)
- connection = Connection(timeout=self.timeout, on_release=release)
+ connection = Connection(timeout=timeout, on_release=release)
self.num_active_connections += 1
- await connection.open(url.hostname, url.port, ssl=ssl)
+ await connection.open(url.hostname, url.port, ssl=ssl_context)
return connection
self, connection: Connection, key: ConnectionKey
) -> None:
if connection.is_closed:
- self._connection_semaphore.release()
+ self._max_connections.release()
self.num_active_connections -= 1
elif (
self.limits.soft_limit is not None
and self.num_connections > self.limits.soft_limit
):
- self._connection_semaphore.release()
+ self._max_connections.release()
self.num_active_connections -= 1
connection.close()
else:
self.num_active_connections -= 1
self.num_keepalive_connections += 1
try:
- self._connections[key].append(connection)
+ self._keepalive_connections[key].append(connection)
except KeyError:
- self._connections[key] = [connection]
+ self._keepalive_connections[key] = [connection]
- async def get_ssl_context(self, url: URL) -> typing.Optional[ssl.SSLContext]:
+ async def get_ssl_context(
+ self, url: URL, config: SSLConfig
+ ) -> typing.Optional[ssl.SSLContext]:
if not url.is_secure:
return None
if not hasattr(self, "ssl_context"):
- if not self.ssl_config.verify:
+ if not config.verify:
self.ssl_context = self.get_ssl_context_no_verify()
else:
# Run the SSL loading in a threadpool, since it makes disk accesses.
context.set_default_verify_paths()
return context
- def get_ssl_context_verify(self) -> ssl.SSLContext:
+ def get_ssl_context_verify(self, config: SSLConfig) -> ssl.SSLContext:
"""
Return an SSL context for verified connections.
"""
- cert = self.ssl_config.cert
- verify = self.ssl_config.verify
-
- if isinstance(verify, bool):
+ if isinstance(config.verify, bool):
ca_bundle_path = DEFAULT_CA_BUNDLE_PATH
- elif os.path.exists(verify):
- ca_bundle_path = verify
+ elif os.path.exists(config.verify):
+ ca_bundle_path = config.verify
else:
raise IOError(
"Could not find a suitable TLS CA certificate bundle, "
- "invalid path: {}".format(verify)
+ "invalid path: {}".format(config.verify)
)
context = ssl.create_default_context()
elif os.path.isdir(ca_bundle_path):
context.load_verify_locations(capath=ca_bundle_path)
- if cert is not None:
- if isinstance(cert, str):
- context.load_cert_chain(certfile=cert)
+ if config.cert is not None:
+ if isinstance(config.cert, str):
+ context.load_cert_chain(certfile=config.cert)
else:
- context.load_cert_chain(certfile=cert[0], keyfile=cert[1])
+ context.load_cert_chain(certfile=config.cert[0], keyfile=config.cert[1])
return context
--- /dev/null
+import httpcore
+
+
+def test_ssl_repr():
+ ssl = httpcore.SSLConfig(verify=False)
+ assert repr(ssl) == "SSLConfig(cert=None, verify=False)"
+
+
+def test_timeout_repr():
+ timeout = httpcore.TimeoutConfig(timeout=5.0)
+ assert repr(timeout) == "TimeoutConfig(timeout=5.0)"
+
+ timeout = httpcore.TimeoutConfig(read_timeout=5.0)
+ assert (
+ repr(timeout)
+ == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, pool_timeout=None)"
+ )
+
+
+def test_limits_repr():
+ limits = httpcore.PoolLimits(hard_limit=100)
+ assert repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100)"
+
+
+def test_ssl_eq():
+ ssl = httpcore.SSLConfig(verify=False)
+ assert ssl == httpcore.SSLConfig(verify=False)
+
+
+def test_timeout_eq():
+ timeout = httpcore.TimeoutConfig(timeout=5.0)
+ assert timeout == httpcore.TimeoutConfig(timeout=5.0)
+
+
+def test_limits_eq():
+ limits = httpcore.PoolLimits(hard_limit=100)
+ assert limits == httpcore.PoolLimits(hard_limit=100)
+
+
+def test_ssl_hash():
+ cache = {}
+ ssl = httpcore.SSLConfig(verify=False)
+ cache[ssl] = "example"
+ assert cache[httpcore.SSLConfig(verify=False)] == "example"
+
+
+def test_timeout_hash():
+ cache = {}
+ timeout = httpcore.TimeoutConfig(timeout=5.0)
+ cache[timeout] = "example"
+ assert cache[httpcore.TimeoutConfig(timeout=5.0)] == "example"
+
+
+def test_limits_hash():
+ cache = {}
+ limits = httpcore.PoolLimits(hard_limit=100)
+ cache[limits] = "example"
+ assert cache[httpcore.PoolLimits(hard_limit=100)] == "example"