]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Allow per-request timeout/ssl config 9/head
authorTom Christie <tom@tomchristie.com>
Thu, 18 Apr 2019 09:36:30 +0000 (10:36 +0100)
committerTom Christie <tom@tomchristie.com>
Thu, 18 Apr 2019 09:36:30 +0000 (10:36 +0100)
httpcore/config.py
httpcore/pool.py
tests/test_config.py [new file with mode: 0644]

index 2db89342705c9effbe889da74ae17905d387282c..e2a18b4e754d1420ebbce01d1ea40ac27590e2db 100644 (file)
@@ -8,10 +8,30 @@ class SSLConfig:
     SSL Configuration.
     """
 
-    def __init__(self, *, cert: typing.Optional[str], verify: typing.Union[str, bool]):
+    def __init__(
+        self,
+        *,
+        cert: typing.Union[None, str, typing.Tuple[str, str]] = None,
+        verify: typing.Union[str, bool] = True,
+    ):
         self.cert = cert
         self.verify = verify
 
+    def __eq__(self, other: typing.Any) -> bool:
+        return (
+            isinstance(other, self.__class__)
+            and self.cert == other.cert
+            and self.verify == other.verify
+        )
+
+    def __hash__(self) -> int:
+        as_tuple = (self.cert, self.verify)
+        return hash(as_tuple)
+
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        return f"{class_name}(cert={self.cert}, verify={self.verify})"
+
 
 class TimeoutConfig:
     """
@@ -24,7 +44,7 @@ class TimeoutConfig:
         *,
         connect_timeout: float = None,
         read_timeout: float = None,
-        pool_timeout: float = None
+        pool_timeout: float = None,
     ):
         if timeout is not None:
             # Specified as a single timeout value
@@ -35,10 +55,29 @@ class TimeoutConfig:
             read_timeout = timeout
             pool_timeout = timeout
 
+        self.timeout = timeout
         self.connect_timeout = connect_timeout
         self.read_timeout = read_timeout
         self.pool_timeout = pool_timeout
 
+    def __eq__(self, other: typing.Any) -> bool:
+        return (
+            isinstance(other, self.__class__)
+            and self.connect_timeout == other.connect_timeout
+            and self.read_timeout == other.read_timeout
+            and self.pool_timeout == other.pool_timeout
+        )
+
+    def __hash__(self) -> int:
+        as_tuple = (self.connect_timeout, self.read_timeout, self.pool_timeout)
+        return hash(as_tuple)
+
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        if self.timeout is not None:
+            return f"{class_name}(timeout={self.timeout})"
+        return f"{class_name}(connect_timeout={self.connect_timeout}, read_timeout={self.read_timeout}, pool_timeout={self.pool_timeout})"
+
 
 class PoolLimits:
     """
@@ -49,11 +88,28 @@ class PoolLimits:
         self,
         *,
         soft_limit: typing.Optional[int] = None,
-        hard_limit: typing.Optional[int] = None
+        hard_limit: typing.Optional[int] = None,
     ):
         self.soft_limit = soft_limit
         self.hard_limit = hard_limit
 
+    def __eq__(self, other: typing.Any) -> bool:
+        return (
+            isinstance(other, self.__class__)
+            and self.soft_limit == other.soft_limit
+            and self.hard_limit == other.hard_limit
+        )
+
+    def __hash__(self) -> int:
+        as_tuple = (self.soft_limit, self.hard_limit)
+        return hash(as_tuple)
+
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        return (
+            f"{class_name}(soft_limit={self.soft_limit}, hard_limit={self.hard_limit})"
+        )
+
 
 DEFAULT_SSL_CONFIG = SSLConfig(cert=None, verify=True)
 DEFAULT_TIMEOUT_CONFIG = TimeoutConfig(timeout=5.0)
index 39216c3dff7c82d7b06eebd5092079f09bb06328..74c194906f611e460a392b1a68131a027f426fd7 100644 (file)
@@ -18,21 +18,17 @@ from .connections import Connection
 from .datastructures import URL, Request, Response
 from .exceptions import PoolTimeout
 
-ConnectionKey = typing.Tuple[str, str, int]  # (scheme, host, port)
+ConnectionKey = typing.Tuple[str, str, int, SSLConfig, TimeoutConfig]
 
 
 class ConnectionSemaphore:
-    def __init__(self, max_connections: int = None, timeout: float = None):
-        self.timeout = timeout
+    def __init__(self, max_connections: int = None):
         if max_connections is not None:
             self.semaphore = asyncio.BoundedSemaphore(value=max_connections)
 
     async def acquire(self) -> None:
         if hasattr(self, "semaphore"):
-            try:
-                await asyncio.wait_for(self.semaphore.acquire(), self.timeout)
-            except asyncio.TimeoutError:
-                raise PoolTimeout()
+            await self.semaphore.acquire()
 
     def release(self) -> None:
         if hasattr(self, "semaphore"):
@@ -53,11 +49,11 @@ class ConnectionPool:
         self.is_closed = False
         self.num_active_connections = 0
         self.num_keepalive_connections = 0
-        self._connections = (
+        self._keepalive_connections = (
             {}
         )  # type: typing.Dict[ConnectionKey, typing.List[Connection]]
-        self._connection_semaphore = ConnectionSemaphore(
-            max_connections=self.limits.hard_limit, timeout=self.timeout.pool_timeout
+        self._max_connections = ConnectionSemaphore(
+            max_connections=self.limits.hard_limit
         )
 
     async def request(
@@ -68,11 +64,17 @@ class ConnectionPool:
         headers: typing.Sequence[typing.Tuple[bytes, bytes]] = (),
         body: typing.Union[bytes, typing.AsyncIterator[bytes]] = b"",
         stream: bool = False,
+        ssl: typing.Optional[SSLConfig] = None,
+        timeout: typing.Optional[TimeoutConfig] = None,
     ) -> Response:
+        if ssl is None:
+            ssl = self.ssl_config
+        if timeout is None:
+            timeout = self.timeout
+
         parsed_url = URL(url)
         request = Request(method, parsed_url, headers=headers, body=body)
-        ssl_context = await self.get_ssl_context(parsed_url)
-        connection = await self.acquire_connection(parsed_url, ssl=ssl_context)
+        connection = await self.acquire_connection(parsed_url, ssl=ssl, timeout=timeout)
         response = await connection.send(request)
         if not stream:
             try:
@@ -86,22 +88,28 @@ class ConnectionPool:
         return self.num_active_connections + self.num_keepalive_connections
 
     async def acquire_connection(
-        self, url: URL, *, ssl: typing.Optional[ssl.SSLContext] = None
+        self, url: URL, ssl: SSLConfig, timeout: TimeoutConfig
     ) -> Connection:
-        key = (url.scheme, url.hostname, url.port)
+        key = (url.scheme, url.hostname, url.port, ssl, timeout)
         try:
-            connection = self._connections[key].pop()
-            if not self._connections[key]:
-                del self._connections[key]
+            connection = self._keepalive_connections[key].pop()
+            if not self._keepalive_connections[key]:
+                del self._keepalive_connections[key]
             self.num_keepalive_connections -= 1
             self.num_active_connections += 1
 
         except (KeyError, IndexError):
-            await self._connection_semaphore.acquire()
+            ssl_context = await self.get_ssl_context(url, ssl)
+            try:
+                await asyncio.wait_for(
+                    self._max_connections.acquire(), timeout.pool_timeout
+                )
+            except asyncio.TimeoutError:
+                raise PoolTimeout()
             release = functools.partial(self.release_connection, key=key)
-            connection = Connection(timeout=self.timeout, on_release=release)
+            connection = Connection(timeout=timeout, on_release=release)
             self.num_active_connections += 1
-            await connection.open(url.hostname, url.port, ssl=ssl)
+            await connection.open(url.hostname, url.port, ssl=ssl_context)
 
         return connection
 
@@ -109,29 +117,31 @@ class ConnectionPool:
         self, connection: Connection, key: ConnectionKey
     ) -> None:
         if connection.is_closed:
-            self._connection_semaphore.release()
+            self._max_connections.release()
             self.num_active_connections -= 1
         elif (
             self.limits.soft_limit is not None
             and self.num_connections > self.limits.soft_limit
         ):
-            self._connection_semaphore.release()
+            self._max_connections.release()
             self.num_active_connections -= 1
             connection.close()
         else:
             self.num_active_connections -= 1
             self.num_keepalive_connections += 1
             try:
-                self._connections[key].append(connection)
+                self._keepalive_connections[key].append(connection)
             except KeyError:
-                self._connections[key] = [connection]
+                self._keepalive_connections[key] = [connection]
 
-    async def get_ssl_context(self, url: URL) -> typing.Optional[ssl.SSLContext]:
+    async def get_ssl_context(
+        self, url: URL, config: SSLConfig
+    ) -> typing.Optional[ssl.SSLContext]:
         if not url.is_secure:
             return None
 
         if not hasattr(self, "ssl_context"):
-            if not self.ssl_config.verify:
+            if not config.verify:
                 self.ssl_context = self.get_ssl_context_no_verify()
             else:
                 # Run the SSL loading in a threadpool, since it makes disk accesses.
@@ -153,21 +163,18 @@ class ConnectionPool:
         context.set_default_verify_paths()
         return context
 
-    def get_ssl_context_verify(self) -> ssl.SSLContext:
+    def get_ssl_context_verify(self, config: SSLConfig) -> ssl.SSLContext:
         """
         Return an SSL context for verified connections.
         """
-        cert = self.ssl_config.cert
-        verify = self.ssl_config.verify
-
-        if isinstance(verify, bool):
+        if isinstance(config.verify, bool):
             ca_bundle_path = DEFAULT_CA_BUNDLE_PATH
-        elif os.path.exists(verify):
-            ca_bundle_path = verify
+        elif os.path.exists(config.verify):
+            ca_bundle_path = config.verify
         else:
             raise IOError(
                 "Could not find a suitable TLS CA certificate bundle, "
-                "invalid path: {}".format(verify)
+                "invalid path: {}".format(config.verify)
             )
 
         context = ssl.create_default_context()
@@ -176,11 +183,11 @@ class ConnectionPool:
         elif os.path.isdir(ca_bundle_path):
             context.load_verify_locations(capath=ca_bundle_path)
 
-        if cert is not None:
-            if isinstance(cert, str):
-                context.load_cert_chain(certfile=cert)
+        if config.cert is not None:
+            if isinstance(config.cert, str):
+                context.load_cert_chain(certfile=config.cert)
             else:
-                context.load_cert_chain(certfile=cert[0], keyfile=cert[1])
+                context.load_cert_chain(certfile=config.cert[0], keyfile=config.cert[1])
 
         return context
 
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644 (file)
index 0000000..daf0e1e
--- /dev/null
@@ -0,0 +1,58 @@
+import httpcore
+
+
+def test_ssl_repr():
+    ssl = httpcore.SSLConfig(verify=False)
+    assert repr(ssl) == "SSLConfig(cert=None, verify=False)"
+
+
+def test_timeout_repr():
+    timeout = httpcore.TimeoutConfig(timeout=5.0)
+    assert repr(timeout) == "TimeoutConfig(timeout=5.0)"
+
+    timeout = httpcore.TimeoutConfig(read_timeout=5.0)
+    assert (
+        repr(timeout)
+        == "TimeoutConfig(connect_timeout=None, read_timeout=5.0, pool_timeout=None)"
+    )
+
+
+def test_limits_repr():
+    limits = httpcore.PoolLimits(hard_limit=100)
+    assert repr(limits) == "PoolLimits(soft_limit=None, hard_limit=100)"
+
+
+def test_ssl_eq():
+    ssl = httpcore.SSLConfig(verify=False)
+    assert ssl == httpcore.SSLConfig(verify=False)
+
+
+def test_timeout_eq():
+    timeout = httpcore.TimeoutConfig(timeout=5.0)
+    assert timeout == httpcore.TimeoutConfig(timeout=5.0)
+
+
+def test_limits_eq():
+    limits = httpcore.PoolLimits(hard_limit=100)
+    assert limits == httpcore.PoolLimits(hard_limit=100)
+
+
+def test_ssl_hash():
+    cache = {}
+    ssl = httpcore.SSLConfig(verify=False)
+    cache[ssl] = "example"
+    assert cache[httpcore.SSLConfig(verify=False)] == "example"
+
+
+def test_timeout_hash():
+    cache = {}
+    timeout = httpcore.TimeoutConfig(timeout=5.0)
+    cache[timeout] = "example"
+    assert cache[httpcore.TimeoutConfig(timeout=5.0)] == "example"
+
+
+def test_limits_hash():
+    cache = {}
+    limits = httpcore.PoolLimits(hard_limit=100)
+    cache[limits] = "example"
+    assert cache[httpcore.PoolLimits(hard_limit=100)] == "example"