]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Test for recent-enough versions of optional packages. (#1041)
authorBob Halley <halley@dnspython.org>
Fri, 9 Feb 2024 02:01:41 +0000 (18:01 -0800)
committerGitHub <noreply@github.com>
Fri, 9 Feb 2024 02:01:41 +0000 (18:01 -0800)
14 files changed:
dns/_asyncio_backend.py
dns/_features.py [new file with mode: 0644]
dns/_trio_backend.py
dns/asyncquery.py
dns/dnssec.py
dns/dnssecalgs/__init__.py
dns/name.py
dns/query.py
dns/quic/__init__.py
dns/win32util.py
tests/test_async.py
tests/test_doh.py
tests/test_features.py [new file with mode: 0644]
util/generate-features [new file with mode: 0755]

index 7d4d1b54cbc4a90edc9c668ef547696441d8ad67..9d9ed3690c6c84aa88102df63481dec1bf51d3d4 100644 (file)
@@ -7,6 +7,7 @@ import socket
 import sys
 
 import dns._asyncbackend
+import dns._features
 import dns.exception
 import dns.inet
 
@@ -122,7 +123,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
         return self.writer.get_extra_info("peercert")
 
 
-try:
+if dns._features.have("doh"):
     import anyio
     import httpcore
     import httpcore._backends.anyio
@@ -206,7 +207,7 @@ try:
                 resolver, local_port, bootstrap_address, family
             )
 
-except ImportError:
+else:
     _HTTPTransport = dns._asyncbackend.NullTransport  # type: ignore
 
 
diff --git a/dns/_features.py b/dns/_features.py
new file mode 100644 (file)
index 0000000..df61e30
--- /dev/null
@@ -0,0 +1,92 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import importlib.metadata
+import itertools
+import string
+from typing import Dict, List, Tuple
+
+
+def _tuple_from_text(version: str) -> Tuple:
+    text_parts = version.split(".")
+    int_parts = []
+    for text_part in text_parts:
+        digit_prefix = "".join(
+            itertools.takewhile(lambda x: x in string.digits, text_part)
+        )
+        try:
+            int_parts.append(int(digit_prefix))
+        except Exception:
+            break
+    return tuple(int_parts)
+
+
+def _version_check(
+    requirement: str,
+) -> bool:
+    """Is the requirement fulfilled?
+
+    The requirement must be of the form
+
+        package>=version
+    """
+    package, minimum = requirement.split(">=")
+    try:
+        version = importlib.metadata.version(package)
+    except Exception:
+        return False
+    t_version = _tuple_from_text(version)
+    t_minimum = _tuple_from_text(minimum)
+    if t_version < t_minimum:
+        return False
+    return True
+
+
+_cache: Dict[str, bool] = {}
+
+
+def have(feature: str) -> bool:
+    """Is *feature* available?
+
+    This tests if all optional packages needed for the
+    feature are available and recent enough.
+
+    Returns ``True`` if the feature is available,
+    and ``False`` if it is not or if metadata is
+    missing.
+    """
+    value = _cache.get(feature)
+    if value is not None:
+        return value
+    requirements = _requirements.get(feature)
+    if requirements is None:
+        # we make a cache entry here for consistency not performance
+        _cache[feature] = False
+        return False
+    ok = True
+    for requirement in requirements:
+        if not _version_check(requirement):
+            ok = False
+            break
+    _cache[feature] = ok
+    return ok
+
+
+def force(feature: str, enabled: bool) -> None:
+    """Force the status of *feature* to be *enabled*.
+
+    This method is provided as a workaround for any cases
+    where importlib.metadata is ineffective, or for testing.
+    """
+    _cache[feature] = enabled
+
+
+_requirements: Dict[str, List[str]] = {
+    ### BEGIN generated requirements
+    "dnssec": ["cryptography>=42"],
+    "doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"],
+    "doq": ["aioquic>=0.9.25"],
+    "idna": ["idna>=3.6"],
+    "trio": ["trio>=0.23"],
+    "wmi": ["wmi>=1.5.1"],
+    ### END generated requirements
+}
index 4d9fb820445a6b46ba3cdb23e0311d70c6fdc026..398e3276923bb6ae91d14a12de5d089cd134d7bb 100644 (file)
@@ -8,9 +8,13 @@ import trio
 import trio.socket  # type: ignore
 
 import dns._asyncbackend
+import dns._features
 import dns.exception
 import dns.inet
 
+if not dns._features.have("trio"):
+    raise ImportError("trio not found or too old")
+
 
 def _maybe_timeout(timeout):
     if timeout is not None:
@@ -95,7 +99,7 @@ class StreamSocket(dns._asyncbackend.StreamSocket):
             raise NotImplementedError
 
 
-try:
+if dns._features.have("doh"):
     import httpcore
     import httpcore._backends.trio
     import httpx
@@ -177,7 +181,7 @@ try:
                 resolver, local_port, bootstrap_address, family
             )
 
-except ImportError:
+else:
     _HTTPTransport = dns._asyncbackend.NullTransport  # type: ignore
 
 
index 7e6b389929edbc7334bf6e28aed1d521b055af5d..35a355bb6cbf4081ea2379a3b3da346719139a8f 100644 (file)
@@ -41,7 +41,6 @@ from dns.query import (
     NoDOQ,
     UDPMode,
     _compute_times,
-    _have_http2,
     _make_dot_ssl_context,
     _matches_destination,
     _remaining,
@@ -534,7 +533,7 @@ async def https(
     transport = backend.get_transport_class()(
         local_address=local_address,
         http1=True,
-        http2=_have_http2,
+        http2=True,
         verify=verify,
         local_port=local_port,
         bootstrap_address=bootstrap_address,
@@ -546,7 +545,7 @@ async def https(
         cm: contextlib.AbstractAsyncContextManager = NullContext(client)
     else:
         cm = httpx.AsyncClient(
-            http1=True, http2=_have_http2, verify=verify, transport=transport
+            http1=True, http2=True, verify=verify, transport=transport
         )
 
     async with cm as the_client:
index 2949f61977db17937e0a1895bb7a446f3d6a0af2..e49c3b795b5108486700e79e6c4ea7c534adcebf 100644 (file)
@@ -27,6 +27,7 @@ import time
 from datetime import datetime
 from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast
 
+import dns._features
 import dns.exception
 import dns.name
 import dns.node
@@ -1169,7 +1170,7 @@ def _need_pyca(*args, **kwargs):
     )  # pragma: no cover
 
 
-try:
+if dns._features.have("dnssec"):
     from cryptography.exceptions import InvalidSignature
     from cryptography.hazmat.primitives.asymmetric import dsa  # pylint: disable=W0611
     from cryptography.hazmat.primitives.asymmetric import ec  # pylint: disable=W0611
@@ -1184,20 +1185,20 @@ try:
         get_algorithm_cls_from_dnskey,
     )
     from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
-except ImportError:  # pragma: no cover
-    validate = _need_pyca
-    validate_rrsig = _need_pyca
-    sign = _need_pyca
-    make_dnskey = _need_pyca
-    make_cdnskey = _need_pyca
-    _have_pyca = False
-else:
+
     validate = _validate  # type: ignore
     validate_rrsig = _validate_rrsig  # type: ignore
     sign = _sign
     make_dnskey = _make_dnskey
     make_cdnskey = _make_cdnskey
     _have_pyca = True
+else:  # pragma: no cover
+    validate = _need_pyca
+    validate_rrsig = _need_pyca
+    sign = _need_pyca
+    make_dnskey = _need_pyca
+    make_cdnskey = _need_pyca
+    _have_pyca = False
 
 ### BEGIN generated Algorithm constants
 
index d1ffd51907548778953ee656f85473a8774a840a..377193ffb21bea7ade4cae07cc6a810c3baedf69 100644 (file)
@@ -2,7 +2,7 @@ from typing import Dict, Optional, Tuple, Type, Union
 
 import dns.name
 
-try:
+if dns._features.have("dnssec"):
     from dns.dnssecalgs.base import GenericPrivateKey
     from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
     from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
@@ -16,7 +16,7 @@ try:
     )
 
     _have_cryptography = True
-except ImportError:
+else:
     _have_cryptography = False
 
 from dns.dnssectypes import Algorithm
index 2e44763ca8144a9b8caf626b01776cb7ea75e38e..933c1612246dfbc2a2935f22c0a0015053e55e9f 100644 (file)
@@ -24,11 +24,13 @@ import functools
 import struct
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
 
-try:
+import dns._features
+
+if dns._features.have("idna"):
     import idna  # type: ignore
 
     have_idna_2008 = True
-except ImportError:  # pragma: no cover
+else:  # pragma: no cover
     have_idna_2008 = False
 
 import dns.enum
@@ -355,7 +357,6 @@ def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes:
 
 @dns.immutable.immutable
 class Name:
-
     """A DNS name.
 
     The dns.name.Name class represents a DNS name as a tuple of
index 9ffdbf4fa6d712fcc8d2347cdb518a61fb49290a..d4bd6b92ceb97927d00b4a1ef9e9ce8150c1dbd9 100644 (file)
@@ -29,6 +29,7 @@ import struct
 import time
 from typing import Any, Dict, Optional, Tuple, Union
 
+import dns._features
 import dns.exception
 import dns.inet
 import dns.message
@@ -58,24 +59,14 @@ def _expiration_for_this_attempt(timeout, expiration):
     return min(time.time() + timeout, expiration)
 
 
-_have_httpx = False
-_have_http2 = False
-try:
-    import httpcore
+_have_httpx = dns._features.have("doh")
+if _have_httpx:
     import httpcore._backends.sync
     import httpx
 
     _CoreNetworkBackend = httpcore.NetworkBackend
     _CoreSyncStream = httpcore._backends.sync.SyncStream
 
-    _have_httpx = True
-    try:
-        # See if http2 support is available.
-        with httpx.Client(http2=True):
-            _have_http2 = True
-    except Exception:
-        pass
-
     class _NetworkBackend(_CoreNetworkBackend):
         def __init__(self, resolver, local_port, bootstrap_address, family):
             super().__init__()
@@ -148,7 +139,7 @@ try:
                 resolver, local_port, bootstrap_address, family
             )
 
-except ImportError:  # pragma: no cover
+else:
 
     class _HTTPTransport:  # type: ignore
         def connect_tcp(self, host, port, timeout, local_address):
@@ -462,7 +453,7 @@ def https(
     transport = _HTTPTransport(
         local_address=local_address,
         http1=True,
-        http2=_have_http2,
+        http2=True,
         verify=verify,
         local_port=local_port,
         bootstrap_address=bootstrap_address,
@@ -473,9 +464,7 @@ def https(
     if session:
         cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
     else:
-        cm = httpx.Client(
-            http1=True, http2=_have_http2, verify=verify, transport=transport
-        )
+        cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
     with cm as session:
         # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
         # GET and POST examples
index d9eee7b50f5d4e1255a44b3665a2cccce893b09b..15803d933d6f1bdbdb750c0a57d9ad01b3450dc6 100644 (file)
@@ -1,6 +1,8 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
-try:
+import dns._features
+
+if dns._features.have("doq"):
     import aioquic.quic.configuration  # type: ignore
 
     import dns.asyncbackend
@@ -31,7 +33,7 @@ try:
 
     _async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}
 
-    try:
+    if dns._features.have("trio"):
         import trio
 
         from dns.quic._trio import (  # pylint: disable=ungrouped-imports
@@ -47,15 +49,13 @@ try:
             return TrioQuicManager(context, *args, **kwargs)
 
         _async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
-    except ImportError:
-        pass
 
     def factories_for_backend(backend=None):
         if backend is None:
             backend = dns.asyncbackend.get_default_backend()
         return _async_factories[backend.name()]
 
-except ImportError:
+else:  # pragma: no cover
     have_quic = False
 
     from typing import Any
index 3e67c6bb20378ddb9c4b6c6d283ec131e1e5c202..aaa7e93e328f7b25ee2271d6ebb9accf0616f78f 100644 (file)
@@ -1,5 +1,7 @@
 import sys
 
+import dns._features
+
 if sys.platform == "win32":
     from typing import Any
 
@@ -15,14 +17,14 @@ if sys.platform == "win32":
     except KeyError:
         WindowsError = Exception
 
-    try:
+    if dns._features.have("wmi"):
         import threading
 
         import pythoncom  # pylint: disable=import-error
         import wmi  # pylint: disable=import-error
 
         _have_wmi = True
-    except Exception:
+    else:
         _have_wmi = False
 
     def _config_domain(domain):
index 7c432ff57458ad6b8efc9a01216630e7222104f2..4ea2301586a00039ad810e3ae6b3cf1c81ada777 100644 (file)
@@ -541,23 +541,6 @@ class AsyncTests(unittest.TestCase):
 
         self.async_run(run)
 
-    @unittest.skipIf(not dns.query._have_httpx, "httpx not available")
-    def testDOHGetRequestHttp1(self):
-        async def run():
-            saved_have_http2 = dns.query._have_http2
-            try:
-                dns.query._have_http2 = False
-                nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
-                q = dns.message.make_query("example.com.", dns.rdatatype.A)
-                r = await dns.asyncquery.https(
-                    q, nameserver_url, post=False, timeout=4, family=family
-                )
-                self.assertTrue(q.is_response(r))
-            finally:
-                dns.query._have_http2 = saved_have_http2
-
-        self.async_run(run)
-
     @unittest.skipIf(not dns.query._have_httpx, "httpx not available")
     def testDOHPostRequest(self):
         async def run():
index a2d9bad53bfe7f623ad2c84bcf584a9fdd13653f..0a5908f95b18d1241b36cf4328f53feeec859623 100644 (file)
@@ -95,24 +95,6 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase):
         )
         self.assertTrue(q.is_response(r))
 
-    def test_get_request_http1(self):
-        saved_have_http2 = dns.query._have_http2
-        try:
-            dns.query._have_http2 = False
-            nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
-            q = dns.message.make_query("example.com.", dns.rdatatype.A)
-            r = dns.query.https(
-                q,
-                nameserver_url,
-                session=self.session,
-                post=False,
-                timeout=4,
-                family=family,
-            )
-            self.assertTrue(q.is_response(r))
-        finally:
-            dns.query._have_http2 = saved_have_http2
-
     def test_post_request(self):
         nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
         q = dns.message.make_query("example.com.", dns.rdatatype.A)
diff --git a/tests/test_features.py b/tests/test_features.py
new file mode 100644 (file)
index 0000000..766f65f
--- /dev/null
@@ -0,0 +1,77 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import pytest
+
+from dns._features import (
+    _cache,
+    _requirements,
+    _tuple_from_text,
+    _version_check,
+    force,
+    have,
+)
+
+try:
+    import cryptography
+
+    v = _tuple_from_text(cryptography.__version__)
+    have_cryptography = v >= (42, 0, 0)
+except ImportError:
+    have_cryptography = False
+
+
+def test_tuple_from_text():
+    assert _tuple_from_text("") == ()
+    assert _tuple_from_text("1") == (1,)
+    assert _tuple_from_text("1.2") == (1, 2)
+    assert _tuple_from_text("1.2rc1") == (1, 2)
+    assert _tuple_from_text("1.2.junk3") == (1, 2)
+
+
+@pytest.mark.skipif(
+    not have_cryptography, reason="cryptography not available or too old"
+)
+def test_version_check():
+    assert _version_check("cryptography>=42")
+    assert not _version_check("cryptography>=10000")
+    assert not _version_check("totallyboguspackagename>=10000")
+
+
+@pytest.mark.skipif(
+    not have_cryptography, reason="cryptography not available or too old"
+)
+def test_have():
+    # ensure cache is empty; we can't just assign as our local is shadowing the
+    # variable in dns._features
+    while len(_cache) > 0:
+        _cache.popitem()
+    assert have("dnssec")
+    assert _cache["dnssec"] == True
+    assert not have("bogusfeature")
+    assert _cache["bogusfeature"] == False
+    _requirements["unavailable"] = ["bogusmodule>=10000"]
+    try:
+        assert not have("unavailable")
+    finally:
+        del _requirements["unavailable"]
+
+
+def test_force():
+    while len(_cache) > 0:
+        _cache.popitem()
+    assert not have("bogusfeature")
+    assert _cache["bogusfeature"] == False
+    force("bogusfeature", True)
+    assert have("bogusfeature")
+    assert _cache["bogusfeature"] == True
+    force("bogusfeature", False)
+    assert not have("bogusfeature")
+    assert _cache["bogusfeature"] == False
+    _requirements["unavailable"] = ["bogusmodule>=10000"]
+    try:
+        assert not have("unavailable")
+        assert _cache["unavailable"] == False
+        force("unavailable", True)
+        assert _cache["unavailable"] == True
+    finally:
+        del _requirements["unavailable"]
diff --git a/util/generate-features b/util/generate-features
new file mode 100755 (executable)
index 0000000..22eb875
--- /dev/null
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import os
+import tomllib
+
+with open("pyproject.toml", "rb") as pp:
+    pyproject = tomllib.load(pp)
+
+FEATURES = "dns/_features.py"
+NEW_FEATURES = FEATURES + ".new"
+skip = False
+with open(FEATURES, "r") as input:
+    with open(NEW_FEATURES, "w") as output:
+        for l in input.readlines():
+            l = l.rstrip()
+            if l.startswith("    ### BEGIN generated requirements"):
+                print(l, file=output)
+                for name, deps in pyproject["project"]["optional-dependencies"].items():
+                    if name == "dev":
+                        continue
+                    print(
+                        f"    {repr(name)}: {repr(deps)},".replace("'", '"'),
+                        file=output,
+                    )
+                skip = True
+            elif l.startswith("    ### END generated requirements"):
+                skip = False
+            if not skip:
+                print(l, file=output)
+os.rename(NEW_FEATURES, FEATURES)