From: Bob Halley Date: Fri, 9 Feb 2024 02:01:41 +0000 (-0800) Subject: Test for recent-enough versions of optional packages. (#1041) X-Git-Tag: v2.6.0rc1~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3c6a7976a5746852f841544ef92edaf7bd12450b;p=thirdparty%2Fdnspython.git Test for recent-enough versions of optional packages. (#1041) --- diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index 7d4d1b54..9d9ed369 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -7,6 +7,7 @@ import socket import sys import dns._asyncbackend +import dns._features import dns.exception import dns.inet @@ -122,7 +123,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket): return self.writer.get_extra_info("peercert") -try: +if dns._features.have("doh"): import anyio import httpcore import httpcore._backends.anyio @@ -206,7 +207,7 @@ try: resolver, local_port, bootstrap_address, family ) -except ImportError: +else: _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore diff --git a/dns/_features.py b/dns/_features.py new file mode 100644 index 00000000..df61e300 --- /dev/null +++ b/dns/_features.py @@ -0,0 +1,92 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import importlib.metadata +import itertools +import string +from typing import Dict, List, Tuple + + +def _tuple_from_text(version: str) -> Tuple: + text_parts = version.split(".") + int_parts = [] + for text_part in text_parts: + digit_prefix = "".join( + itertools.takewhile(lambda x: x in string.digits, text_part) + ) + try: + int_parts.append(int(digit_prefix)) + except Exception: + break + return tuple(int_parts) + + +def _version_check( + requirement: str, +) -> bool: + """Is the requirement fulfilled? + + The requirement must be of the form + + package>=version + """ + package, minimum = requirement.split(">=") + try: + version = importlib.metadata.version(package) + except Exception: + return False + t_version = _tuple_from_text(version) + t_minimum = _tuple_from_text(minimum) + if t_version < t_minimum: + return False + return True + + +_cache: Dict[str, bool] = {} + + +def have(feature: str) -> bool: + """Is *feature* available? + + This tests if all optional packages needed for the + feature are available and recent enough. + + Returns ``True`` if the feature is available, + and ``False`` if it is not or if metadata is + missing. + """ + value = _cache.get(feature) + if value is not None: + return value + requirements = _requirements.get(feature) + if requirements is None: + # we make a cache entry here for consistency not performance + _cache[feature] = False + return False + ok = True + for requirement in requirements: + if not _version_check(requirement): + ok = False + break + _cache[feature] = ok + return ok + + +def force(feature: str, enabled: bool) -> None: + """Force the status of *feature* to be *enabled*. + + This method is provided as a workaround for any cases + where importlib.metadata is ineffective, or for testing. + """ + _cache[feature] = enabled + + +_requirements: Dict[str, List[str]] = { + ### BEGIN generated requirements + "dnssec": ["cryptography>=42"], + "doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"], + "doq": ["aioquic>=0.9.25"], + "idna": ["idna>=3.6"], + "trio": ["trio>=0.23"], + "wmi": ["wmi>=1.5.1"], + ### END generated requirements +} diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index 4d9fb820..398e3276 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -8,9 +8,13 @@ import trio import trio.socket # type: ignore import dns._asyncbackend +import dns._features import dns.exception import dns.inet +if not dns._features.have("trio"): + raise ImportError("trio not found or too old") + def _maybe_timeout(timeout): if timeout is not None: @@ -95,7 +99,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket): raise NotImplementedError -try: +if dns._features.have("doh"): import httpcore import httpcore._backends.trio import httpx @@ -177,7 +181,7 @@ try: resolver, local_port, bootstrap_address, family ) -except ImportError: +else: _HTTPTransport = dns._asyncbackend.NullTransport # type: ignore diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 7e6b3899..35a355bb 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -41,7 +41,6 @@ from dns.query import ( NoDOQ, UDPMode, _compute_times, - _have_http2, _make_dot_ssl_context, _matches_destination, _remaining, @@ -534,7 +533,7 @@ async def https( transport = backend.get_transport_class()( local_address=local_address, http1=True, - http2=_have_http2, + http2=True, verify=verify, local_port=local_port, bootstrap_address=bootstrap_address, @@ -546,7 +545,7 @@ async def https( cm: contextlib.AbstractAsyncContextManager = NullContext(client) else: cm = httpx.AsyncClient( - http1=True, http2=_have_http2, verify=verify, transport=transport + http1=True, http2=True, verify=verify, transport=transport ) async with cm as the_client: diff --git a/dns/dnssec.py b/dns/dnssec.py index 2949f619..e49c3b79 100644 --- a/dns/dnssec.py +++ b/dns/dnssec.py @@ -27,6 +27,7 @@ import time from datetime import datetime from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast +import dns._features import dns.exception import dns.name import dns.node @@ -1169,7 +1170,7 @@ def _need_pyca(*args, **kwargs): ) # pragma: no cover -try: +if dns._features.have("dnssec"): from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611 from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611 @@ -1184,20 +1185,20 @@ try: get_algorithm_cls_from_dnskey, ) from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey -except ImportError: # pragma: no cover - validate = _need_pyca - validate_rrsig = _need_pyca - sign = _need_pyca - make_dnskey = _need_pyca - make_cdnskey = _need_pyca - _have_pyca = False -else: + validate = _validate # type: ignore validate_rrsig = _validate_rrsig # type: ignore sign = _sign make_dnskey = _make_dnskey make_cdnskey = _make_cdnskey _have_pyca = True +else: # pragma: no cover + validate = _need_pyca + validate_rrsig = _need_pyca + sign = _need_pyca + make_dnskey = _need_pyca + make_cdnskey = _need_pyca + _have_pyca = False ### BEGIN generated Algorithm constants diff --git a/dns/dnssecalgs/__init__.py b/dns/dnssecalgs/__init__.py index d1ffd519..377193ff 100644 --- a/dns/dnssecalgs/__init__.py +++ b/dns/dnssecalgs/__init__.py @@ -2,7 +2,7 @@ from typing import Dict, Optional, Tuple, Type, Union import dns.name -try: +if dns._features.have("dnssec"): from dns.dnssecalgs.base import GenericPrivateKey from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1 from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384 @@ -16,7 +16,7 @@ try: ) _have_cryptography = True -except ImportError: +else: _have_cryptography = False from dns.dnssectypes import Algorithm diff --git a/dns/name.py b/dns/name.py index 2e44763c..933c1612 100644 --- a/dns/name.py +++ b/dns/name.py @@ -24,11 +24,13 @@ import functools import struct from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union -try: +import dns._features + +if dns._features.have("idna"): import idna # type: ignore have_idna_2008 = True -except ImportError: # pragma: no cover +else: # pragma: no cover have_idna_2008 = False import dns.enum @@ -355,7 +357,6 @@ def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes: @dns.immutable.immutable class Name: - """A DNS name. The dns.name.Name class represents a DNS name as a tuple of diff --git a/dns/query.py b/dns/query.py index 9ffdbf4f..d4bd6b92 100644 --- a/dns/query.py +++ b/dns/query.py @@ -29,6 +29,7 @@ import struct import time from typing import Any, Dict, Optional, Tuple, Union +import dns._features import dns.exception import dns.inet import dns.message @@ -58,24 +59,14 @@ def _expiration_for_this_attempt(timeout, expiration): return min(time.time() + timeout, expiration) -_have_httpx = False -_have_http2 = False -try: - import httpcore +_have_httpx = dns._features.have("doh") +if _have_httpx: import httpcore._backends.sync import httpx _CoreNetworkBackend = httpcore.NetworkBackend _CoreSyncStream = httpcore._backends.sync.SyncStream - _have_httpx = True - try: - # See if http2 support is available. - with httpx.Client(http2=True): - _have_http2 = True - except Exception: - pass - class _NetworkBackend(_CoreNetworkBackend): def __init__(self, resolver, local_port, bootstrap_address, family): super().__init__() @@ -148,7 +139,7 @@ try: resolver, local_port, bootstrap_address, family ) -except ImportError: # pragma: no cover +else: class _HTTPTransport: # type: ignore def connect_tcp(self, host, port, timeout, local_address): @@ -462,7 +453,7 @@ def https( transport = _HTTPTransport( local_address=local_address, http1=True, - http2=_have_http2, + http2=True, verify=verify, local_port=local_port, bootstrap_address=bootstrap_address, @@ -473,9 +464,7 @@ def https( if session: cm: contextlib.AbstractContextManager = contextlib.nullcontext(session) else: - cm = httpx.Client( - http1=True, http2=_have_http2, verify=verify, transport=transport - ) + cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport) with cm as session: # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH # GET and POST examples diff --git a/dns/quic/__init__.py b/dns/quic/__init__.py index d9eee7b5..15803d93 100644 --- a/dns/quic/__init__.py +++ b/dns/quic/__init__.py @@ -1,6 +1,8 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license -try: +import dns._features + +if dns._features.have("doq"): import aioquic.quic.configuration # type: ignore import dns.asyncbackend @@ -31,7 +33,7 @@ try: _async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)} - try: + if dns._features.have("trio"): import trio from dns.quic._trio import ( # pylint: disable=ungrouped-imports @@ -47,15 +49,13 @@ try: return TrioQuicManager(context, *args, **kwargs) _async_factories["trio"] = (_trio_context_factory, _trio_manager_factory) - except ImportError: - pass def factories_for_backend(backend=None): if backend is None: backend = dns.asyncbackend.get_default_backend() return _async_factories[backend.name()] -except ImportError: +else: # pragma: no cover have_quic = False from typing import Any diff --git a/dns/win32util.py b/dns/win32util.py index 3e67c6bb..aaa7e93e 100644 --- a/dns/win32util.py +++ b/dns/win32util.py @@ -1,5 +1,7 @@ import sys +import dns._features + if sys.platform == "win32": from typing import Any @@ -15,14 +17,14 @@ if sys.platform == "win32": except KeyError: WindowsError = Exception - try: + if dns._features.have("wmi"): import threading import pythoncom # pylint: disable=import-error import wmi # pylint: disable=import-error _have_wmi = True - except Exception: + else: _have_wmi = False def _config_domain(domain): diff --git a/tests/test_async.py b/tests/test_async.py index 7c432ff5..4ea23015 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -541,23 +541,6 @@ class AsyncTests(unittest.TestCase): self.async_run(run) - @unittest.skipIf(not dns.query._have_httpx, "httpx not available") - def testDOHGetRequestHttp1(self): - async def run(): - saved_have_http2 = dns.query._have_http2 - try: - dns.query._have_http2 = False - nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query("example.com.", dns.rdatatype.A) - r = await dns.asyncquery.https( - q, nameserver_url, post=False, timeout=4, family=family - ) - self.assertTrue(q.is_response(r)) - finally: - dns.query._have_http2 = saved_have_http2 - - self.async_run(run) - @unittest.skipIf(not dns.query._have_httpx, "httpx not available") def testDOHPostRequest(self): async def run(): diff --git a/tests/test_doh.py b/tests/test_doh.py index a2d9bad5..0a5908f9 100644 --- a/tests/test_doh.py +++ b/tests/test_doh.py @@ -95,24 +95,6 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): ) self.assertTrue(q.is_response(r)) - def test_get_request_http1(self): - saved_have_http2 = dns.query._have_http2 - try: - dns.query._have_http2 = False - nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) - q = dns.message.make_query("example.com.", dns.rdatatype.A) - r = dns.query.https( - q, - nameserver_url, - session=self.session, - post=False, - timeout=4, - family=family, - ) - self.assertTrue(q.is_response(r)) - finally: - dns.query._have_http2 = saved_have_http2 - def test_post_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query("example.com.", dns.rdatatype.A) diff --git a/tests/test_features.py b/tests/test_features.py new file mode 100644 index 00000000..766f65f7 --- /dev/null +++ b/tests/test_features.py @@ -0,0 +1,77 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import pytest + +from dns._features import ( + _cache, + _requirements, + _tuple_from_text, + _version_check, + force, + have, +) + +try: + import cryptography + + v = _tuple_from_text(cryptography.__version__) + have_cryptography = v >= (42, 0, 0) +except ImportError: + have_cryptography = False + + +def test_tuple_from_text(): + assert _tuple_from_text("") == () + assert _tuple_from_text("1") == (1,) + assert _tuple_from_text("1.2") == (1, 2) + assert _tuple_from_text("1.2rc1") == (1, 2) + assert _tuple_from_text("1.2.junk3") == (1, 2) + + +@pytest.mark.skipif( + not have_cryptography, reason="cryptography not available or too old" +) +def test_version_check(): + assert _version_check("cryptography>=42") + assert not _version_check("cryptography>=10000") + assert not _version_check("totallyboguspackagename>=10000") + + +@pytest.mark.skipif( + not have_cryptography, reason="cryptography not available or too old" +) +def test_have(): + # ensure cache is empty; we can't just assign as our local is shadowing the + # variable in dns._features + while len(_cache) > 0: + _cache.popitem() + assert have("dnssec") + assert _cache["dnssec"] == True + assert not have("bogusfeature") + assert _cache["bogusfeature"] == False + _requirements["unavailable"] = ["bogusmodule>=10000"] + try: + assert not have("unavailable") + finally: + del _requirements["unavailable"] + + +def test_force(): + while len(_cache) > 0: + _cache.popitem() + assert not have("bogusfeature") + assert _cache["bogusfeature"] == False + force("bogusfeature", True) + assert have("bogusfeature") + assert _cache["bogusfeature"] == True + force("bogusfeature", False) + assert not have("bogusfeature") + assert _cache["bogusfeature"] == False + _requirements["unavailable"] = ["bogusmodule>=10000"] + try: + assert not have("unavailable") + assert _cache["unavailable"] == False + force("unavailable", True) + assert _cache["unavailable"] == True + finally: + del _requirements["unavailable"] diff --git a/util/generate-features b/util/generate-features new file mode 100755 index 00000000..22eb8759 --- /dev/null +++ b/util/generate-features @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 + +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import os +import tomllib + +with open("pyproject.toml", "rb") as pp: + pyproject = tomllib.load(pp) + +FEATURES = "dns/_features.py" +NEW_FEATURES = FEATURES + ".new" +skip = False +with open(FEATURES, "r") as input: + with open(NEW_FEATURES, "w") as output: + for l in input.readlines(): + l = l.rstrip() + if l.startswith(" ### BEGIN generated requirements"): + print(l, file=output) + for name, deps in pyproject["project"]["optional-dependencies"].items(): + if name == "dev": + continue + print( + f" {repr(name)}: {repr(deps)},".replace("'", '"'), + file=output, + ) + skip = True + elif l.startswith(" ### END generated requirements"): + skip = False + if not skip: + print(l, file=output) +os.rename(NEW_FEATURES, FEATURES)