]> git.ipfire.org Git - thirdparty/httpx.git/commitdiff
Push SSL context loading into config.py
authorTom Christie <tom@tomchristie.com>
Mon, 22 Apr 2019 14:45:56 +0000 (15:45 +0100)
committerTom Christie <tom@tomchristie.com>
Mon, 22 Apr 2019 14:45:56 +0000 (15:45 +0100)
httpcore/config.py
httpcore/pool.py

index e2a18b4e754d1420ebbce01d1ea40ac27590e2db..f694fc9a5e289168260a928a7e4488e0cd95b5db 100644 (file)
@@ -1,3 +1,4 @@
+import ssl
 import typing
 
 import certifi
@@ -32,6 +33,58 @@ class SSLConfig:
         class_name = self.__class__.__name__
         return f"{class_name}(cert={self.cert}, verify={self.verify})"
 
+    async def load_ssl_context(self) -> ssl.SSLContext:
+        if not hasattr(self, "ssl_context"):
+            if not self.verify:
+                self.ssl_context = self.load_ssl_context_no_verify()
+            else:
+                # Run the SSL loading in a threadpool, since it makes disk accesses.
+                loop = asyncio.get_event_loop()
+                self.ssl_context = await loop.run_in_executor(
+                    None, self.load_ssl_context_verify
+                )
+
+        return self.ssl_context
+
+    def load_ssl_context_no_verify(self) -> ssl.SSLContext:
+        """
+        Return an SSL context for unverified connections.
+        """
+        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+        context.options |= ssl.OP_NO_SSLv2
+        context.options |= ssl.OP_NO_SSLv3
+        context.options |= ssl.OP_NO_COMPRESSION
+        context.set_default_verify_paths()
+        return context
+
+    def load_ssl_context_verify(self) -> ssl.SSLContext:
+        """
+        Return an SSL context for verified connections.
+        """
+        if isinstance(self.verify, bool):
+            ca_bundle_path = DEFAULT_CA_BUNDLE_PATH
+        elif os.path.exists(self.verify):
+            ca_bundle_path = self.verify
+        else:
+            raise IOError(
+                "Could not find a suitable TLS CA certificate bundle, "
+                "invalid path: {}".format(self.verify)
+            )
+
+        context = ssl.create_default_context()
+        if os.path.isfile(ca_bundle_path):
+            context.load_verify_locations(cafile=ca_bundle_path)
+        elif os.path.isdir(ca_bundle_path):
+            context.load_verify_locations(capath=ca_bundle_path)
+
+        if self.cert is not None:
+            if isinstance(self.cert, str):
+                context.load_cert_chain(certfile=self.cert)
+            else:
+                context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
+
+        return context
+
 
 class TimeoutConfig:
     """
index 74c194906f611e460a392b1a68131a027f426fd7..f09e01e04019a6cce3b9fa4be60f8a34f4f46e76 100644 (file)
@@ -99,7 +99,11 @@ class ConnectionPool:
             self.num_active_connections += 1
 
         except (KeyError, IndexError):
-            ssl_context = await self.get_ssl_context(url, ssl)
+            if url.is_secure:
+                ssl_context = await ssl.load_ssl_context()
+            else:
+                ssl_context = None
+
             try:
                 await asyncio.wait_for(
                     self._max_connections.acquire(), timeout.pool_timeout
@@ -134,63 +138,6 @@ class ConnectionPool:
             except KeyError:
                 self._keepalive_connections[key] = [connection]
 
-    async def get_ssl_context(
-        self, url: URL, config: SSLConfig
-    ) -> typing.Optional[ssl.SSLContext]:
-        if not url.is_secure:
-            return None
-
-        if not hasattr(self, "ssl_context"):
-            if not config.verify:
-                self.ssl_context = self.get_ssl_context_no_verify()
-            else:
-                # Run the SSL loading in a threadpool, since it makes disk accesses.
-                loop = asyncio.get_event_loop()
-                self.ssl_context = await loop.run_in_executor(
-                    None, self.get_ssl_context_verify
-                )
-
-        return self.ssl_context
-
-    def get_ssl_context_no_verify(self) -> ssl.SSLContext:
-        """
-        Return an SSL context for unverified connections.
-        """
-        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
-        context.options |= ssl.OP_NO_SSLv2
-        context.options |= ssl.OP_NO_SSLv3
-        context.options |= ssl.OP_NO_COMPRESSION
-        context.set_default_verify_paths()
-        return context
-
-    def get_ssl_context_verify(self, config: SSLConfig) -> ssl.SSLContext:
-        """
-        Return an SSL context for verified connections.
-        """
-        if isinstance(config.verify, bool):
-            ca_bundle_path = DEFAULT_CA_BUNDLE_PATH
-        elif os.path.exists(config.verify):
-            ca_bundle_path = config.verify
-        else:
-            raise IOError(
-                "Could not find a suitable TLS CA certificate bundle, "
-                "invalid path: {}".format(config.verify)
-            )
-
-        context = ssl.create_default_context()
-        if os.path.isfile(ca_bundle_path):
-            context.load_verify_locations(cafile=ca_bundle_path)
-        elif os.path.isdir(ca_bundle_path):
-            context.load_verify_locations(capath=ca_bundle_path)
-
-        if config.cert is not None:
-            if isinstance(config.cert, str):
-                context.load_cert_chain(certfile=config.cert)
-            else:
-                context.load_cert_chain(certfile=config.cert[0], keyfile=config.cert[1])
-
-        return context
-
     async def close(self) -> None:
         self.is_closed = True