import sys
import dns._asyncbackend
+import dns._features
import dns.exception
import dns.inet
return self.writer.get_extra_info("peercert")
-try:
+if dns._features.have("doh"):
import anyio
import httpcore
import httpcore._backends.anyio
resolver, local_port, bootstrap_address, family
)
-except ImportError:
+else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import importlib.metadata
+import itertools
+import string
+from typing import Dict, List, Tuple
+
+
+def _tuple_from_text(version: str) -> Tuple:
+ text_parts = version.split(".")
+ int_parts = []
+ for text_part in text_parts:
+ digit_prefix = "".join(
+ itertools.takewhile(lambda x: x in string.digits, text_part)
+ )
+ try:
+ int_parts.append(int(digit_prefix))
+ except Exception:
+ break
+ return tuple(int_parts)
+
+
+def _version_check(
+ requirement: str,
+) -> bool:
+ """Is the requirement fulfilled?
+
+ The requirement must be of the form
+
+ package>=version
+ """
+ package, minimum = requirement.split(">=")
+ try:
+ version = importlib.metadata.version(package)
+ except Exception:
+ return False
+ t_version = _tuple_from_text(version)
+ t_minimum = _tuple_from_text(minimum)
+ if t_version < t_minimum:
+ return False
+ return True
+
+
+_cache: Dict[str, bool] = {}
+
+
+def have(feature: str) -> bool:
+ """Is *feature* available?
+
+ This tests if all optional packages needed for the
+ feature are available and recent enough.
+
+ Returns ``True`` if the feature is available,
+ and ``False`` if it is not or if metadata is
+ missing.
+ """
+ value = _cache.get(feature)
+ if value is not None:
+ return value
+ requirements = _requirements.get(feature)
+ if requirements is None:
+ # we make a cache entry here for consistency not performance
+ _cache[feature] = False
+ return False
+ ok = True
+ for requirement in requirements:
+ if not _version_check(requirement):
+ ok = False
+ break
+ _cache[feature] = ok
+ return ok
+
+
+def force(feature: str, enabled: bool) -> None:
+ """Force the status of *feature* to be *enabled*.
+
+ This method is provided as a workaround for any cases
+ where importlib.metadata is ineffective, or for testing.
+ """
+ _cache[feature] = enabled
+
+
+_requirements: Dict[str, List[str]] = {
+ ### BEGIN generated requirements
+ "dnssec": ["cryptography>=42"],
+ "doh": ["httpcore>=1.0.0", "httpx>=0.26.0", "h2>=4.1.0"],
+ "doq": ["aioquic>=0.9.25"],
+ "idna": ["idna>=3.6"],
+ "trio": ["trio>=0.23"],
+ "wmi": ["wmi>=1.5.1"],
+ ### END generated requirements
+}
import trio.socket # type: ignore
import dns._asyncbackend
+import dns._features
import dns.exception
import dns.inet
+if not dns._features.have("trio"):
+ raise ImportError("trio not found or too old")
+
def _maybe_timeout(timeout):
if timeout is not None:
raise NotImplementedError
-try:
+if dns._features.have("doh"):
import httpcore
import httpcore._backends.trio
import httpx
resolver, local_port, bootstrap_address, family
)
-except ImportError:
+else:
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
NoDOQ,
UDPMode,
_compute_times,
- _have_http2,
_make_dot_ssl_context,
_matches_destination,
_remaining,
transport = backend.get_transport_class()(
local_address=local_address,
http1=True,
- http2=_have_http2,
+ http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
cm: contextlib.AbstractAsyncContextManager = NullContext(client)
else:
cm = httpx.AsyncClient(
- http1=True, http2=_have_http2, verify=verify, transport=transport
+ http1=True, http2=True, verify=verify, transport=transport
)
async with cm as the_client:
from datetime import datetime
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, cast
+import dns._features
import dns.exception
import dns.name
import dns.node
) # pragma: no cover
-try:
+if dns._features.have("dnssec"):
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric import dsa # pylint: disable=W0611
from cryptography.hazmat.primitives.asymmetric import ec # pylint: disable=W0611
get_algorithm_cls_from_dnskey,
)
from dns.dnssecalgs.base import GenericPrivateKey, GenericPublicKey
-except ImportError: # pragma: no cover
- validate = _need_pyca
- validate_rrsig = _need_pyca
- sign = _need_pyca
- make_dnskey = _need_pyca
- make_cdnskey = _need_pyca
- _have_pyca = False
-else:
+
validate = _validate # type: ignore
validate_rrsig = _validate_rrsig # type: ignore
sign = _sign
make_dnskey = _make_dnskey
make_cdnskey = _make_cdnskey
_have_pyca = True
+else: # pragma: no cover
+ validate = _need_pyca
+ validate_rrsig = _need_pyca
+ sign = _need_pyca
+ make_dnskey = _need_pyca
+ make_cdnskey = _need_pyca
+ _have_pyca = False
### BEGIN generated Algorithm constants
import dns.name
-try:
+if dns._features.have("dnssec"):
from dns.dnssecalgs.base import GenericPrivateKey
from dns.dnssecalgs.dsa import PrivateDSA, PrivateDSANSEC3SHA1
from dns.dnssecalgs.ecdsa import PrivateECDSAP256SHA256, PrivateECDSAP384SHA384
)
_have_cryptography = True
-except ImportError:
+else:
_have_cryptography = False
from dns.dnssectypes import Algorithm
import struct
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
-try:
+import dns._features
+
+if dns._features.have("idna"):
import idna # type: ignore
have_idna_2008 = True
-except ImportError: # pragma: no cover
+else: # pragma: no cover
have_idna_2008 = False
import dns.enum
@dns.immutable.immutable
class Name:
-
"""A DNS name.
The dns.name.Name class represents a DNS name as a tuple of
import time
from typing import Any, Dict, Optional, Tuple, Union
+import dns._features
import dns.exception
import dns.inet
import dns.message
return min(time.time() + timeout, expiration)
-_have_httpx = False
-_have_http2 = False
-try:
- import httpcore
+_have_httpx = dns._features.have("doh")
+if _have_httpx:
import httpcore._backends.sync
import httpx
_CoreNetworkBackend = httpcore.NetworkBackend
_CoreSyncStream = httpcore._backends.sync.SyncStream
- _have_httpx = True
- try:
- # See if http2 support is available.
- with httpx.Client(http2=True):
- _have_http2 = True
- except Exception:
- pass
-
class _NetworkBackend(_CoreNetworkBackend):
def __init__(self, resolver, local_port, bootstrap_address, family):
super().__init__()
resolver, local_port, bootstrap_address, family
)
-except ImportError: # pragma: no cover
+else:
class _HTTPTransport: # type: ignore
def connect_tcp(self, host, port, timeout, local_address):
transport = _HTTPTransport(
local_address=local_address,
http1=True,
- http2=_have_http2,
+ http2=True,
verify=verify,
local_port=local_port,
bootstrap_address=bootstrap_address,
if session:
cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
else:
- cm = httpx.Client(
- http1=True, http2=_have_http2, verify=verify, transport=transport
- )
+ cm = httpx.Client(http1=True, http2=True, verify=verify, transport=transport)
with cm as session:
# see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
# GET and POST examples
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-try:
+import dns._features
+
+if dns._features.have("doq"):
import aioquic.quic.configuration # type: ignore
import dns.asyncbackend
_async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}
- try:
+ if dns._features.have("trio"):
import trio
from dns.quic._trio import ( # pylint: disable=ungrouped-imports
return TrioQuicManager(context, *args, **kwargs)
_async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
- except ImportError:
- pass
def factories_for_backend(backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
return _async_factories[backend.name()]
-except ImportError:
+else: # pragma: no cover
have_quic = False
from typing import Any
import sys
+import dns._features
+
if sys.platform == "win32":
from typing import Any
except KeyError:
WindowsError = Exception
- try:
+ if dns._features.have("wmi"):
import threading
import pythoncom # pylint: disable=import-error
import wmi # pylint: disable=import-error
_have_wmi = True
- except Exception:
+ else:
_have_wmi = False
def _config_domain(domain):
self.async_run(run)
- @unittest.skipIf(not dns.query._have_httpx, "httpx not available")
- def testDOHGetRequestHttp1(self):
- async def run():
- saved_have_http2 = dns.query._have_http2
- try:
- dns.query._have_http2 = False
- nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
- q = dns.message.make_query("example.com.", dns.rdatatype.A)
- r = await dns.asyncquery.https(
- q, nameserver_url, post=False, timeout=4, family=family
- )
- self.assertTrue(q.is_response(r))
- finally:
- dns.query._have_http2 = saved_have_http2
-
- self.async_run(run)
-
@unittest.skipIf(not dns.query._have_httpx, "httpx not available")
def testDOHPostRequest(self):
async def run():
)
self.assertTrue(q.is_response(r))
- def test_get_request_http1(self):
- saved_have_http2 = dns.query._have_http2
- try:
- dns.query._have_http2 = False
- nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
- q = dns.message.make_query("example.com.", dns.rdatatype.A)
- r = dns.query.https(
- q,
- nameserver_url,
- session=self.session,
- post=False,
- timeout=4,
- family=family,
- )
- self.assertTrue(q.is_response(r))
- finally:
- dns.query._have_http2 = saved_have_http2
-
def test_post_request(self):
nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
q = dns.message.make_query("example.com.", dns.rdatatype.A)
--- /dev/null
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import pytest
+
+from dns._features import (
+ _cache,
+ _requirements,
+ _tuple_from_text,
+ _version_check,
+ force,
+ have,
+)
+
+try:
+ import cryptography
+
+ v = _tuple_from_text(cryptography.__version__)
+ have_cryptography = v >= (42, 0, 0)
+except ImportError:
+ have_cryptography = False
+
+
+def test_tuple_from_text():
+ assert _tuple_from_text("") == ()
+ assert _tuple_from_text("1") == (1,)
+ assert _tuple_from_text("1.2") == (1, 2)
+ assert _tuple_from_text("1.2rc1") == (1, 2)
+ assert _tuple_from_text("1.2.junk3") == (1, 2)
+
+
+@pytest.mark.skipif(
+ not have_cryptography, reason="cryptography not available or too old"
+)
+def test_version_check():
+ assert _version_check("cryptography>=42")
+ assert not _version_check("cryptography>=10000")
+ assert not _version_check("totallyboguspackagename>=10000")
+
+
+@pytest.mark.skipif(
+ not have_cryptography, reason="cryptography not available or too old"
+)
+def test_have():
+ # ensure cache is empty; we can't just assign as our local is shadowing the
+ # variable in dns._features
+ while len(_cache) > 0:
+ _cache.popitem()
+ assert have("dnssec")
+ assert _cache["dnssec"] == True
+ assert not have("bogusfeature")
+ assert _cache["bogusfeature"] == False
+ _requirements["unavailable"] = ["bogusmodule>=10000"]
+ try:
+ assert not have("unavailable")
+ finally:
+ del _requirements["unavailable"]
+
+
+def test_force():
+ while len(_cache) > 0:
+ _cache.popitem()
+ assert not have("bogusfeature")
+ assert _cache["bogusfeature"] == False
+ force("bogusfeature", True)
+ assert have("bogusfeature")
+ assert _cache["bogusfeature"] == True
+ force("bogusfeature", False)
+ assert not have("bogusfeature")
+ assert _cache["bogusfeature"] == False
+ _requirements["unavailable"] = ["bogusmodule>=10000"]
+ try:
+ assert not have("unavailable")
+ assert _cache["unavailable"] == False
+ force("unavailable", True)
+ assert _cache["unavailable"] == True
+ finally:
+ del _requirements["unavailable"]
--- /dev/null
+#!/usr/bin/env python3
+
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import os
+import tomllib
+
+with open("pyproject.toml", "rb") as pp:
+ pyproject = tomllib.load(pp)
+
+FEATURES = "dns/_features.py"
+NEW_FEATURES = FEATURES + ".new"
+skip = False
+with open(FEATURES, "r") as input:
+ with open(NEW_FEATURES, "w") as output:
+ for l in input.readlines():
+ l = l.rstrip()
+ if l.startswith(" ### BEGIN generated requirements"):
+ print(l, file=output)
+ for name, deps in pyproject["project"]["optional-dependencies"].items():
+ if name == "dev":
+ continue
+ print(
+ f" {repr(name)}: {repr(deps)},".replace("'", '"'),
+ file=output,
+ )
+ skip = True
+ elif l.startswith(" ### END generated requirements"):
+ skip = False
+ if not skip:
+ print(l, file=output)
+os.rename(NEW_FEATURES, FEATURES)