From 42c0e06c8cc303931bff5e31e15a36e90100a8fc Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Apr 2019 10:36:30 +0100 Subject: [PATCH] Allow per-request timeout/ssl config --- httpcore/config.py | 62 +++++++++++++++++++++++++++++++-- httpcore/pool.py | 83 ++++++++++++++++++++++++-------------------- tests/test_config.py | 58 +++++++++++++++++++++++++++++++ 3 files changed, 162 insertions(+), 41 deletions(-) create mode 100644 tests/test_config.py diff --git a/httpcore/config.py b/httpcore/config.py index 2db89342..e2a18b4e 100644 --- a/httpcore/config.py +++ b/httpcore/config.py @@ -8,10 +8,30 @@ class SSLConfig: SSL Configuration. """ - def __init__(self, *, cert: typing.Optional[str], verify: typing.Union[str, bool]): + def __init__( + self, + *, + cert: typing.Union[None, str, typing.Tuple[str, str]] = None, + verify: typing.Union[str, bool] = True, + ): self.cert = cert self.verify = verify + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.cert == other.cert + and self.verify == other.verify + ) + + def __hash__(self) -> int: + as_tuple = (self.cert, self.verify) + return hash(as_tuple) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}(cert={self.cert}, verify={self.verify})" + class TimeoutConfig: """ @@ -24,7 +44,7 @@ class TimeoutConfig: *, connect_timeout: float = None, read_timeout: float = None, - pool_timeout: float = None + pool_timeout: float = None, ): if timeout is not None: # Specified as a single timeout value @@ -35,10 +55,29 @@ class TimeoutConfig: read_timeout = timeout pool_timeout = timeout + self.timeout = timeout self.connect_timeout = connect_timeout self.read_timeout = read_timeout self.pool_timeout = pool_timeout + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.connect_timeout == other.connect_timeout + and self.read_timeout == other.read_timeout + and self.pool_timeout == other.pool_timeout + ) + + def __hash__(self) -> int: + as_tuple = (self.connect_timeout, self.read_timeout, self.pool_timeout) + return hash(as_tuple) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + if self.timeout is not None: + return f"{class_name}(timeout={self.timeout})" + return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, pool_timeout={self.pool_timeout})" + class PoolLimits: """ @@ -49,11 +88,28 @@ class PoolLimits: self, *, soft_limit: typing.Optional[int] = None, - hard_limit: typing.Optional[int] = None + hard_limit: typing.Optional[int] = None, ): self.soft_limit = soft_limit self.hard_limit = hard_limit + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, self.__class__) + and self.soft_limit == other.soft_limit + and self.hard_limit == other.hard_limit + ) + + def __hash__(self) -> int: + as_tuple = (self.soft_limit, self.hard_limit) + return hash(as_tuple) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return ( + f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit})" + ) + DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True) DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0) diff --git a/httpcore/pool.py b/httpcore/pool.py index 39216c3d..74c19490 100644 --- a/httpcore/pool.py +++ b/httpcore/pool.py @@ -18,21 +18,17 @@ from .connections import Connection from .datastructures import URL, Request, Response from .exceptions import PoolTimeout -ConnectionKey = typing.Tuple[str, str, int] # (scheme, host, port) +ConnectionKey = typing.Tuple[str, str, int, SSLConfig, TimeoutConfig] class ConnectionSemaphore: - def __init__(self, max_connections: int = None, timeout: float = None): - self.timeout = timeout + def __init__(self, max_connections: int = None): if max_connections is not None: self.semaphore = asyncio.BoundedSemaphore(value=max_connections) async def acquire(self) -> None: if hasattr(self, "semaphore"): - try: - await asyncio.wait_for(self.semaphore.acquire(), self.timeout) - except asyncio.TimeoutError: - raise PoolTimeout() + await self.semaphore.acquire() def release(self) -> None: if hasattr(self, "semaphore"): @@ -53,11 +49,11 @@ class ConnectionPool: self.is_closed = False self.num_active_connections = 0 self.num_keepalive_connections = 0 - self._connections = ( + self._keepalive_connections = ( {} ) # type: typing.Dict[ConnectionKey, typing.List[Connection]] - self._connection_semaphore = ConnectionSemaphore( - max_connections=self.limits.hard_limit, timeout=self.timeout.pool_timeout + self._max_connections = ConnectionSemaphore( + max_connections=self.limits.hard_limit ) async def request( @@ -68,11 +64,17 @@ class ConnectionPool: headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (), body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"", stream: bool = False, + ssl: typing.Optional[SSLConfig] = None, + timeout: typing.Optional[TimeoutConfig] = None, ) -> Response: + if ssl is None: + ssl = self.ssl_config + if timeout is None: + timeout = self.timeout + parsed_url = URL(url) request = Request(method, parsed_url, headers=headers, body=body) - ssl_context = await self.get_ssl_context(parsed_url) - connection = await self.acquire_connection(parsed_url, ssl=ssl_context) + connection = await self.acquire_connection(parsed_url, ssl=ssl, timeout=timeout) response = await connection.send(request) if not stream: try: @@ -86,22 +88,28 @@ class ConnectionPool: return self.num_active_connections + self.num_keepalive_connections async def acquire_connection( - self, url: URL, *, ssl: typing.Optional[ssl.SSLContext] = None + self, url: URL, ssl: SSLConfig, timeout: TimeoutConfig ) -> Connection: - key = (url.scheme, url.hostname, url.port) + key = (url.scheme, url.hostname, url.port, ssl, timeout) try: - connection = self._connections[key].pop() - if not self._connections[key]: - del self._connections[key] + connection = self._keepalive_connections[key].pop() + if not self._keepalive_connections[key]: + del self._keepalive_connections[key] self.num_keepalive_connections -= 1 self.num_active_connections += 1 except (KeyError, IndexError): - await self._connection_semaphore.acquire() + ssl_context = await self.get_ssl_context(url, ssl) + try: + await asyncio.wait_for( + self._max_connections.acquire(), timeout.pool_timeout + ) + except asyncio.TimeoutError: + raise PoolTimeout() release = functools.partial(self.release_connection, key=key) - connection = Connection(timeout=self.timeout, on_release=release) + connection = Connection(timeout=timeout, on_release=release) self.num_active_connections += 1 - await connection.open(url.hostname, url.port, ssl=ssl) + await connection.open(url.hostname, url.port, ssl=ssl_context) return connection @@ -109,29 +117,31 @@ class ConnectionPool: self, connection: Connection, key: ConnectionKey ) -> None: if connection.is_closed: - self._connection_semaphore.release() + self._max_connections.release() self.num_active_connections -= 1 elif ( self.limits.soft_limit is not None and self.num_connections > self.limits.soft_limit ): - self._connection_semaphore.release() + self._max_connections.release() self.num_active_connections -= 1 connection.close() else: self.num_active_connections -= 1 self.num_keepalive_connections += 1 try: - self._connections[key].append(connection) + self._keepalive_connections[key].append(connection) except KeyError: - self._connections[key] = [connection] + self._keepalive_connections[key] = [connection] - async def get_ssl_context(self, url: URL) -> typing.Optional[ssl.SSLContext]: + async def get_ssl_context( + self, url: URL, config: SSLConfig + ) -> typing.Optional[ssl.SSLContext]: if not url.is_secure: return None if not hasattr(self, "ssl_context"): - if not self.ssl_config.verify: + if not config.verify: self.ssl_context = self.get_ssl_context_no_verify() else: # Run the SSL loading in a threadpool, since it makes disk accesses. @@ -153,21 +163,18 @@ class ConnectionPool: context.set_default_verify_paths() return context - def get_ssl_context_verify(self) -> ssl.SSLContext: + def get_ssl_context_verify(self, config: SSLConfig) -> ssl.SSLContext: """ Return an SSL context for verified connections. """ - cert = self.ssl_config.cert - verify = self.ssl_config.verify - - if isinstance(verify, bool): + if isinstance(config.verify, bool): ca_bundle_path = DEFAULT_CA_BUNDLE_PATH - elif os.path.exists(verify): - ca_bundle_path = verify + elif os.path.exists(config.verify): + ca_bundle_path = config.verify else: raise IOError( "Could not find a suitable TLS CA certificate bundle, " - "invalid path: {}".format(verify) + "invalid path: {}".format(config.verify) ) context = ssl.create_default_context() @@ -176,11 +183,11 @@ class ConnectionPool: elif os.path.isdir(ca_bundle_path): context.load_verify_locations(capath=ca_bundle_path) - if cert is not None: - if isinstance(cert, str): - context.load_cert_chain(certfile=cert) + if config.cert is not None: + if isinstance(config.cert, str): + context.load_cert_chain(certfile=config.cert) else: - context.load_cert_chain(certfile=cert[0], keyfile=cert[1]) + context.load_cert_chain(certfile=config.cert[0], keyfile=config.cert[1]) return context diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 00000000..daf0e1ec --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,58 @@ +import httpcore + + +def test_ssl_repr(): + ssl = httpcore.SSLConfig(verify=False) + assert repr(ssl) == "SSLConfig(cert=None, verify=False)" + + +def test_timeout_repr(): + timeout = httpcore.TimeoutConfig(timeout=5.0) + assert repr(timeout) == "TimeoutConfig(timeout=5.0)" + + timeout = httpcore.TimeoutConfig(read_timeout=5.0) + assert ( + repr(timeout) + == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, pool_timeout=None)" + ) + + +def test_limits_repr(): + limits = httpcore.PoolLimits(hard_limit=100) + assert repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100)" + + +def test_ssl_eq(): + ssl = httpcore.SSLConfig(verify=False) + assert ssl == httpcore.SSLConfig(verify=False) + + +def test_timeout_eq(): + timeout = httpcore.TimeoutConfig(timeout=5.0) + assert timeout == httpcore.TimeoutConfig(timeout=5.0) + + +def test_limits_eq(): + limits = httpcore.PoolLimits(hard_limit=100) + assert limits == httpcore.PoolLimits(hard_limit=100) + + +def test_ssl_hash(): + cache = {} + ssl = httpcore.SSLConfig(verify=False) + cache[ssl] = "example" + assert cache[httpcore.SSLConfig(verify=False)] == "example" + + +def test_timeout_hash(): + cache = {} + timeout = httpcore.TimeoutConfig(timeout=5.0) + cache[timeout] = "example" + assert cache[httpcore.TimeoutConfig(timeout=5.0)] == "example" + + +def test_limits_hash(): + cache = {} + limits = httpcore.PoolLimits(hard_limit=100) + cache[limits] = "example" + assert cache[httpcore.PoolLimits(hard_limit=100)] == "example" -- 2.47.3