From: Tom Christie Date: Mon, 22 Apr 2019 14:45:56 +0000 (+0100) Subject: Push SSL context loading into config.py X-Git-Tag: 0.2.0~1^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=25c652a4389241844a80540932fcd5fd173e8b16;p=thirdparty%2Fhttpx.git Push SSL context loading into config.py --- diff --git a/httpcore/config.py b/httpcore/config.py index e2a18b4e..f694fc9a 100644 --- a/httpcore/config.py +++ b/httpcore/config.py @@ -1,3 +1,4 @@ +import ssl import typing import certifi @@ -32,6 +33,58 @@ class SSLConfig: class_name = self.__class__.__name__ return f"{class_name}(cert={self.cert}, verify={self.verify})" + async def load_ssl_context(self) -> ssl.SSLContext: + if not hasattr(self, "ssl_context"): + if not self.verify: + self.ssl_context = self.load_ssl_context_no_verify() + else: + # Run the SSL loading in a threadpool, since it makes disk accesses. + loop = asyncio.get_event_loop() + self.ssl_context = await loop.run_in_executor( + None, self.load_ssl_context_verify + ) + + return self.ssl_context + + def load_ssl_context_no_verify(self) -> ssl.SSLContext: + """ + Return an SSL context for unverified connections. + """ + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_COMPRESSION + context.set_default_verify_paths() + return context + + def load_ssl_context_verify(self) -> ssl.SSLContext: + """ + Return an SSL context for verified connections. + """ + if isinstance(self.verify, bool): + ca_bundle_path = DEFAULT_CA_BUNDLE_PATH + elif os.path.exists(self.verify): + ca_bundle_path = self.verify + else: + raise IOError( + "Could not find a suitable TLS CA certificate bundle, " + "invalid path: {}".format(self.verify) + ) + + context = ssl.create_default_context() + if os.path.isfile(ca_bundle_path): + context.load_verify_locations(cafile=ca_bundle_path) + elif os.path.isdir(ca_bundle_path): + context.load_verify_locations(capath=ca_bundle_path) + + if self.cert is not None: + if isinstance(self.cert, str): + context.load_cert_chain(certfile=self.cert) + else: + context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1]) + + return context + class TimeoutConfig: """ diff --git a/httpcore/pool.py b/httpcore/pool.py index 74c19490..f09e01e0 100644 --- a/httpcore/pool.py +++ b/httpcore/pool.py @@ -99,7 +99,11 @@ class ConnectionPool: self.num_active_connections += 1 except (KeyError, IndexError): - ssl_context = await self.get_ssl_context(url, ssl) + if url.is_secure: + ssl_context = await ssl.load_ssl_context() + else: + ssl_context = None + try: await asyncio.wait_for( self._max_connections.acquire(), timeout.pool_timeout @@ -134,63 +138,6 @@ class ConnectionPool: except KeyError: self._keepalive_connections[key] = [connection] - async def get_ssl_context( - self, url: URL, config: SSLConfig - ) -> typing.Optional[ssl.SSLContext]: - if not url.is_secure: - return None - - if not hasattr(self, "ssl_context"): - if not config.verify: - self.ssl_context = self.get_ssl_context_no_verify() - else: - # Run the SSL loading in a threadpool, since it makes disk accesses. - loop = asyncio.get_event_loop() - self.ssl_context = await loop.run_in_executor( - None, self.get_ssl_context_verify - ) - - return self.ssl_context - - def get_ssl_context_no_verify(self) -> ssl.SSLContext: - """ - Return an SSL context for unverified connections. - """ - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.options |= ssl.OP_NO_SSLv2 - context.options |= ssl.OP_NO_SSLv3 - context.options |= ssl.OP_NO_COMPRESSION - context.set_default_verify_paths() - return context - - def get_ssl_context_verify(self, config: SSLConfig) -> ssl.SSLContext: - """ - Return an SSL context for verified connections. - """ - if isinstance(config.verify, bool): - ca_bundle_path = DEFAULT_CA_BUNDLE_PATH - elif os.path.exists(config.verify): - ca_bundle_path = config.verify - else: - raise IOError( - "Could not find a suitable TLS CA certificate bundle, " - "invalid path: {}".format(config.verify) - ) - - context = ssl.create_default_context() - if os.path.isfile(ca_bundle_path): - context.load_verify_locations(cafile=ca_bundle_path) - elif os.path.isdir(ca_bundle_path): - context.load_verify_locations(capath=ca_bundle_path) - - if config.cert is not None: - if isinstance(config.cert, str): - context.load_cert_chain(certfile=config.cert) - else: - context.load_cert_chain(certfile=config.cert[0], keyfile=config.cert[1]) - - return context - async def close(self) -> None: self.is_closed = True