From: Bob Halley Date: Fri, 25 Feb 2022 21:29:09 +0000 (-0800) Subject: Add integrated typing to much of dnspython. X-Git-Tag: v2.3.0rc1~119^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F785%2Fhead;p=thirdparty%2Fdnspython.git Add integrated typing to much of dnspython. --- diff --git a/Makefile b/Makefile index 76e70286..fe4e8bd9 100644 --- a/Makefile +++ b/Makefile @@ -56,7 +56,10 @@ potestlf: poetry run pytest --lf potype: - poetry run python -m mypy examples tests dns/*.py + poetry run python -m mypy dns/*.py + +potypetests: + poetry run python -m mypy --check-untyped-defs examples tests polint: poetry run pylint dns diff --git a/dns/__init__.py b/dns/__init__.py index 0473ca17..a620f975 100644 --- a/dns/__init__.py +++ b/dns/__init__.py @@ -22,6 +22,7 @@ __all__ = [ 'asyncquery', 'asyncresolver', 'dnssec', + 'dnssectypes', 'e164', 'edns', 'entropy', @@ -60,6 +61,7 @@ __all__ = [ 'wire', 'xfr', 'zone', + 'zonetypes', 'zonefile', ] diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py index 1f3a8287..674bf6ea 100644 --- a/dns/_asyncbackend.py +++ b/dns/_asyncbackend.py @@ -41,6 +41,9 @@ class Socket: # pragma: no cover class DatagramSocket(Socket): # pragma: no cover + def __init__(self, family: int): + self.family = family + async def sendto(self, what, destination, timeout): raise NotImplementedError @@ -67,3 +70,6 @@ class Backend: # pragma: no cover def datagram_connection_required(self): return False + + async def sleep(self, interval): + raise NotImplementedError diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index d737d13c..9d458da0 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -55,8 +55,8 @@ async def _maybe_wait_for(awaitable, timeout): class DatagramSocket(dns._asyncbackend.DatagramSocket): - def __init__(self, family, transport, protocol): - self.family = family + def __init__(self, family: int, transport, protocol): + super().__init__(family) self.transport = transport self.protocol = protocol diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py index 535eb84d..3f22b5d3 100644 --- a/dns/_curio_backend.py +++ b/dns/_curio_backend.py @@ -26,8 +26,8 @@ _lltuple = dns.inet.low_level_address_tuple class DatagramSocket(dns._asyncbackend.DatagramSocket): def __init__(self, socket): + super().__init__(socket.family) self.socket = socket - self.family = socket.family async def sendto(self, what, destination, timeout): async with _maybe_timeout(timeout): diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index 863d413e..8a337e9d 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -26,8 +26,8 @@ _lltuple = dns.inet.low_level_address_tuple class DatagramSocket(dns._asyncbackend.DatagramSocket): def __init__(self, socket): + super().__init__(socket.family) self.socket = socket - self.family = socket.family async def sendto(self, what, destination, timeout): with _maybe_timeout(timeout): diff --git a/dns/asyncbackend.py b/dns/asyncbackend.py index ad79a572..a8f794ac 100644 --- a/dns/asyncbackend.py +++ b/dns/asyncbackend.py @@ -1,5 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +from typing import Dict + import dns.exception # pylint: disable=unused-import @@ -10,7 +12,7 @@ from dns._asyncbackend import Socket, DatagramSocket, StreamSocket, Backend # n _default_backend = None -_backends = {} +_backends: Dict[str, Backend] = {} # Allow sniffio import to be disabled for testing purposes _no_sniffio = False @@ -19,7 +21,7 @@ class AsyncLibraryNotFoundError(dns.exception.DNSException): pass -def get_backend(name): +def get_backend(name: str) -> Backend: """Get the specified asynchronous backend. *name*, a ``str``, the name of the backend. Currently the "trio", @@ -46,7 +48,7 @@ def get_backend(name): return backend -def sniff(): +def sniff() -> str: """Attempt to determine the in-use asynchronous I/O library by using the ``sniffio`` module if it is available. @@ -71,13 +73,14 @@ def sniff(): except RuntimeError: raise AsyncLibraryNotFoundError('no async library detected') except AttributeError: # pragma: no cover - # we have to check current_task on 3.6 - if not asyncio.Task.current_task(): + # we have to check current_task on 3.6; we ignore for mypy + # purposes at it is otherwise unhappy on >= 3.7 + if not asyncio.Task.current_task(): # type: ignore raise AsyncLibraryNotFoundError('no async library detected') return 'asyncio' -def get_default_backend(): +def get_default_backend() -> Backend: """Get the default backend, initializing it if necessary. """ if _default_backend: @@ -86,7 +89,7 @@ def get_default_backend(): return set_default_backend(sniff()) -def set_default_backend(name): +def set_default_backend(name: str): """Set the default backend. It's not normally necessary to call this method, as diff --git a/dns/asyncbackend.pyi b/dns/asyncbackend.pyi deleted file mode 100644 index 1ec9d32b..00000000 --- a/dns/asyncbackend.pyi +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license - -class Backend: - ... - -def get_backend(name: str) -> Backend: - ... -def sniff() -> str: - ... -def get_default_backend() -> Backend: - ... -def set_default_backend(name: str) -> Backend: - ... diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 13f687fb..8c35d1aa 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -17,6 +17,8 @@ """Talk to a DNS server.""" +from typing import Any, Dict, Optional, Tuple, Union + import base64 import socket import struct @@ -31,6 +33,7 @@ import dns.message import dns.rcode import dns.rdataclass import dns.rdatatype +import dns.transaction from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \ UDPMode, _have_httpx, _have_http2, NoDOH @@ -67,7 +70,9 @@ def _timeout(expiration, now=None): return None -async def send_udp(sock, what, destination, expiration=None): +async def send_udp(sock: dns.asyncbackend.DatagramSocket, + what: Union[dns.message.Message, bytes], destination: Any, + expiration: Optional[float]=None) -> Tuple[int, float]: """Send a DNS message to the specified UDP socket. *sock*, a ``dns.asyncbackend.DatagramSocket``. @@ -91,10 +96,11 @@ async def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -async def receive_udp(sock, destination=None, expiration=None, +async def receive_udp(sock: dns.asyncbackend.DatagramSocket, + destination: Optional[Any]=None, expiration: Optional[float]=None, ignore_unexpected=False, one_rr_per_rrset=False, - keyring=None, request_mac=b'', ignore_trailing=False, - raise_on_truncation=False): + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'', + ignore_trailing=False, raise_on_truncation=False) -> Any: """Read a DNS message from a UDP socket. *sock*, a ``dns.asyncbackend.DatagramSocket``. @@ -116,10 +122,11 @@ async def receive_udp(sock, destination=None, expiration=None, raise_on_truncation=raise_on_truncation) return (r, received_time, from_address) -async def udp(q, where, timeout=None, port=53, source=None, source_port=0, - ignore_unexpected=False, one_rr_per_rrset=False, - ignore_trailing=False, raise_on_truncation=False, sock=None, - backend=None): +async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, + source: Optional[str]=None, source_port=0, + ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, + raise_on_truncation=False, sock: Optional[dns.asyncbackend.DatagramSocket]=None, + backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message: """Return the response obtained after sending a query via UDP. *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, @@ -152,6 +159,7 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, dtuple = None s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple) + assert s is not None await send_udp(s, wire, destination, expiration) (r, received_time, _) = await receive_udp(s, destination, expiration, ignore_unexpected, @@ -167,10 +175,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, if not sock and s: await s.close() -async def udp_with_fallback(q, where, timeout=None, port=53, source=None, - source_port=0, ignore_unexpected=False, - one_rr_per_rrset=False, ignore_trailing=False, - udp_sock=None, tcp_sock=None, backend=None): +async def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, + source: Optional[str]=None, source_port=0, + ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, + udp_sock: Optional[dns.asyncbackend.DatagramSocket]=None, + tcp_sock: Optional[dns.asyncbackend.StreamSocket]=None, + backend: Optional[dns.asyncbackend.Backend]=None) -> Tuple[dns.message.Message, bool]: """Return the response to the query, trying UDP first and falling back to TCP if UDP results in a truncated response. @@ -203,7 +213,9 @@ async def udp_with_fallback(q, where, timeout=None, port=53, source=None, return (response, True) -async def send_tcp(sock, what, expiration=None): +async def send_tcp(sock: dns.asyncbackend.StreamSocket, + what: Union[dns.message.Message, bytes], + expiration: Optional[float]=None) -> Tuple[int, float]: """Send a DNS message to the specified TCP socket. *sock*, a ``dns.asyncbackend.StreamSocket``. @@ -213,12 +225,14 @@ async def send_tcp(sock, what, expiration=None): """ if isinstance(what, dns.message.Message): - what = what.to_wire() - l = len(what) + wire = what.to_wire() + else: + wire = what + l = len(wire) # copying the wire into tcpmsg is inefficient, but lets us # avoid writev() or doing a short write that would get pushed # onto the net - tcpmsg = struct.pack("!H", l) + what + tcpmsg = struct.pack("!H", l) + wire sent_time = time.time() await sock.sendall(tcpmsg, _timeout(expiration, sent_time)) return (len(tcpmsg), sent_time) @@ -238,8 +252,10 @@ async def _read_exactly(sock, count, expiration): return s -async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, - keyring=None, request_mac=b'', ignore_trailing=False): +async def receive_tcp(sock: dns.asyncbackend.StreamSocket, + expiration: Optional[float]=None, one_rr_per_rrset=False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, + request_mac=b'', ignore_trailing=False) -> Tuple[dns.message.Message, float]: """Read a DNS message from a TCP socket. *sock*, a ``dns.asyncbackend.StreamSocket``. @@ -258,9 +274,11 @@ async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, return (r, received_time) -async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, sock=None, - backend=None): +async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, + source: Optional[str]=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, + sock: Optional[dns.asyncbackend.StreamSocket]=None, + backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message: """Return the response obtained after sending a query via TCP. *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the @@ -297,6 +315,7 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, backend = dns.asyncbackend.get_default_backend() s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout) + assert s is not None await send_tcp(s, wire, expiration) (r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset, q.keyring, q.mac, @@ -309,9 +328,13 @@ async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, if not sock and s: await s.close() -async def tls(q, where, timeout=None, port=853, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, sock=None, - backend=None, ssl_context=None, server_hostname=None): +async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, + port=853, source: Optional[str]=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, + sock: Optional[dns.asyncbackend.StreamSocket]=None, + backend: Optional[dns.asyncbackend.Backend]=None, + ssl_context: Optional[ssl.SSLContext]=None, + server_hostname: Optional[str]=None) -> dns.message.Message: """Return the response obtained after sending a query via TLS. *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket @@ -363,8 +386,10 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0, if not sock and s: await s.close() -async def https(q, where, timeout=None, port=443, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, client=None, +async def https(q: dns.message.Message, where: str, timeout: Optional[float]=None, + port=443, source: Optional[str]=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, + client: Optional[httpx.AsyncClient]=None, path='/dns-query', post=True, verify=True): """Return the response obtained after sending a query via DNS-over-HTTPS. @@ -419,18 +444,18 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0, timeout=timeout) else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") - wire = wire.decode() # httpx does a repr() if we give it bytes + twire = wire.decode() # httpx does a repr() if we give it bytes response = await client.get(url, headers=headers, timeout=timeout, - params={"dns": wire}) + params={"dns": twire}) finally: if client_to_close: - await client.aclose() + await client_to_close.aclose() # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH # status codes if response.status_code < 200 or response.status_code > 299: raise ValueError('{} responded with status code {}' - '\nResponse body: {}'.format(where, + '\nResponse body: {!r}'.format(where, response.status_code, response.content)) r = dns.message.from_wire(response.content, @@ -438,14 +463,16 @@ async def https(q, where, timeout=None, port=443, source=None, source_port=0, request_mac=q.request_mac, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing) - r.time = response.elapsed + r.time = response.elapsed.total_seconds() if not q.is_response(r): raise BadResponse return r -async def inbound_xfr(where, txn_manager, query=None, - port=53, timeout=None, lifetime=None, source=None, - source_port=0, udp_mode=UDPMode.NEVER, backend=None): +async def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager, + query: Optional[dns.message.Message]=None, + port=53, timeout: Optional[float]=None, lifetime: Optional[float]=None, + source: Optional[str]=None, source_port=0, udp_mode=UDPMode.NEVER, + backend: Optional[dns.asyncbackend.Backend]=None): """Conduct an inbound transfer and apply it via a transaction from the txn_manager. diff --git a/dns/asyncquery.pyi b/dns/asyncquery.pyi deleted file mode 100644 index a03434c2..00000000 --- a/dns/asyncquery.pyi +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Optional, Union, Dict, Generator, Any -from . import tsig, rdatatype, rdataclass, name, message, asyncbackend - -# If the ssl import works, then -# -# error: Name 'ssl' already defined (by an import) -# -# is expected and can be ignored. -try: - import ssl -except ImportError: - class ssl: # type: ignore - SSLContext : Dict = {} - -async def udp(q : message.Message, where : str, - timeout : Optional[float] = None, port=53, - source : Optional[str] = None, source_port : Optional[int] = 0, - ignore_unexpected : Optional[bool] = False, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[asyncbackend.DatagramSocket] = None, - backend : Optional[asyncbackend.Backend] = None) -> message.Message: - pass - -async def tcp(q : message.Message, where : str, timeout : float = None, port=53, - af : Optional[int] = None, source : Optional[str] = None, - source_port : Optional[int] = 0, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[asyncbackend.StreamSocket] = None, - backend : Optional[asyncbackend.Backend] = None) -> message.Message: - pass - -async def tls(q : message.Message, where : str, - timeout : Optional[float] = None, port=53, - source : Optional[str] = None, source_port : Optional[int] = 0, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[asyncbackend.StreamSocket] = None, - backend : Optional[asyncbackend.Backend] = None, - ssl_context: Optional[ssl.SSLContext] = None, - server_hostname: Optional[str] = None) -> message.Message: - pass diff --git a/dns/asyncresolver.py b/dns/asyncresolver.py index b4837567..72ef0412 100644 --- a/dns/asyncresolver.py +++ b/dns/asyncresolver.py @@ -17,11 +17,14 @@ """Asynchronous DNS stub resolver.""" +from typing import Optional, Union + import time import dns.asyncbackend import dns.asyncquery import dns.exception +import dns.name import dns.query import dns.resolver # lgtm[py/import-and-import-from] @@ -37,11 +40,13 @@ _tcp = dns.asyncquery.tcp class Resolver(dns.resolver.BaseResolver): """Asynchronous DNS stub resolver.""" - async def resolve(self, qname, rdtype=dns.rdatatype.A, + async def resolve(self, qname: Union[dns.name.Name, str], + rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime=None, search=None, - backend=None): + tcp=False, source: Optional[str]=None, + raise_on_no_answer=True, source_port=0, + lifetime: Optional[float]=None, search: Optional[bool]=None, + backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer: """Query nameservers asynchronously to find the answer to the question. *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, @@ -66,6 +71,7 @@ class Resolver(dns.resolver.BaseResolver): if answer is not None: # cache hit! return answer + assert request is not None # needed for type checking done = False while not done: (nameserver, port, tcp, backoff) = resolution.next_nameserver() @@ -101,7 +107,7 @@ class Resolver(dns.resolver.BaseResolver): if answer is not None: return answer - async def resolve_address(self, ipaddr, *args, **kwargs): + async def resolve_address(self, ipaddr: str, *args, **kwargs) -> dns.resolver.Answer: """Use an asynchronous resolver to run a reverse query for PTR records. @@ -116,15 +122,19 @@ class Resolver(dns.resolver.BaseResolver): function. """ - + # We make a modified kwargs for type checking happiness, as otherwise + # we get a legit warning about possibly having rdtype and rdclass + # in the kwargs more than once. + modified_kwargs = {} + modified_kwargs.update(kwargs) + modified_kwargs['rdtype'] = dns.rdatatype.PTR + modified_kwargs['rdclass'] = dns.rdataclass.IN return await self.resolve(dns.reversename.from_address(ipaddr), - rdtype=dns.rdatatype.PTR, - rdclass=dns.rdataclass.IN, - *args, **kwargs) + *args, **modified_kwargs) # pylint: disable=redefined-outer-name - async def canonical_name(self, name): + async def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. The canonical name is the name the resolver uses for queries @@ -149,10 +159,11 @@ class Resolver(dns.resolver.BaseResolver): default_resolver = None -def get_default_resolver(): +def get_default_resolver() -> Resolver: """Get the default asynchronous resolver, initializing it if necessary.""" if default_resolver is None: reset_default_resolver() + assert default_resolver is not None return default_resolver @@ -167,9 +178,13 @@ def reset_default_resolver(): default_resolver = Resolver() -async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime=None, search=None, backend=None): +async def resolve(qname: Union[dns.name.Name, str], + rdtype=dns.rdatatype.A, + rdclass=dns.rdataclass.IN, + tcp=False, source: Optional[str]=None, + raise_on_no_answer=True, source_port=0, + lifetime: Optional[float]=None, search: Optional[bool]=None, + backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer: """Query nameservers asynchronously to find the answer to the question. This is a convenience function that uses the default resolver @@ -185,7 +200,7 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, backend) -async def resolve_address(ipaddr, *args, **kwargs): +async def resolve_address(ipaddr: str, *args, **kwargs) -> dns.resolver.Answer: """Use a resolver to run a reverse query for PTR records. See :py:func:`dns.asyncresolver.Resolver.resolve_address` for more @@ -194,7 +209,7 @@ async def resolve_address(ipaddr, *args, **kwargs): return await get_default_resolver().resolve_address(ipaddr, *args, **kwargs) -async def canonical_name(name): +async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. See :py:func:`dns.resolver.Resolver.canonical_name` for more @@ -203,8 +218,9 @@ async def canonical_name(name): return await get_default_resolver().canonical_name(name) -async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, - resolver=None, backend=None): +async def zone_for_name(name: Union[dns.name.Name, str], rdclass=dns.rdataclass.IN, + tcp=False, resolver: Optional[Resolver]=None, + backend: Optional[dns.asyncbackend.Backend]=None) -> dns.name.Name: """Find the name of the zone which contains the specified name. See :py:func:`dns.resolver.Resolver.zone_for_name` for more @@ -221,6 +237,7 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, try: answer = await resolver.resolve(name, dns.rdatatype.SOA, rdclass, tcp, backend=backend) + assert answer.rrset is not None if answer.rrset.name == name: return name # otherwise we were CNAMEd or DNAMEd and need to look higher diff --git a/dns/asyncresolver.pyi b/dns/asyncresolver.pyi deleted file mode 100644 index 92759d29..00000000 --- a/dns/asyncresolver.pyi +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Union, Optional, List, Any, Dict -from . import exception, rdataclass, name, rdatatype, asyncbackend - -async def resolve(qname : str, rdtype : Union[int,str] = 0, - rdclass : Union[int,str] = 0, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime : Optional[float]=None, - search : Optional[bool]=None, - backend : Optional[asyncbackend.Backend]=None): - ... -async def resolve_address(self, ipaddr: str, - *args: Any, **kwargs: Optional[Dict]): - ... - -class Resolver: - def __init__(self, filename : Optional[str] = '/etc/resolv.conf', - configure : Optional[bool] = True): - self.nameservers : List[str] - async def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A, - rdclass : Union[int,str] = rdataclass.IN, - tcp : bool = False, source : Optional[str] = None, - raise_on_no_answer=True, source_port : int = 0, - lifetime : Optional[float]=None, - search : Optional[bool]=None, - backend : Optional[asyncbackend.Backend]=None): - ... diff --git a/dns/dnssec.py b/dns/dnssec.py index dee4e618..bb20005e 100644 --- a/dns/dnssec.py +++ b/dns/dnssec.py @@ -17,12 +17,15 @@ """Common DNSSEC-related functions and constants.""" +from typing import Any, cast, Dict, List, Optional, Tuple, Union + import hashlib import struct import time import base64 -import dns.enum +from dns.dnssectypes import * + import dns.exception import dns.name import dns.node @@ -30,6 +33,10 @@ import dns.rdataset import dns.rdata import dns.rdatatype import dns.rdataclass +import dns.rrset +from dns.rdtypes.ANY.DNSKEY import DNSKEY +from dns.rdtypes.ANY.DS import DS +from dns.rdtypes.ANY.RRSIG import RRSIG class UnsupportedAlgorithm(dns.exception.DNSException): @@ -40,31 +47,7 @@ class ValidationFailure(dns.exception.DNSException): """The DNSSEC signature is invalid.""" -class Algorithm(dns.enum.IntEnum): - RSAMD5 = 1 - DH = 2 - DSA = 3 - ECC = 4 - RSASHA1 = 5 - DSANSEC3SHA1 = 6 - RSASHA1NSEC3SHA1 = 7 - RSASHA256 = 8 - RSASHA512 = 10 - ECCGOST = 12 - ECDSAP256SHA256 = 13 - ECDSAP384SHA384 = 14 - ED25519 = 15 - ED448 = 16 - INDIRECT = 252 - PRIVATEDNS = 253 - PRIVATEOID = 254 - - @classmethod - def _maximum(cls): - return 255 - - -def algorithm_from_text(text): +def algorithm_from_text(text: str) -> Algorithm: """Convert text into a DNSSEC algorithm value. *text*, a ``str``, the text to convert to into an algorithm value. @@ -75,10 +58,10 @@ def algorithm_from_text(text): return Algorithm.from_text(text) -def algorithm_to_text(value): +def algorithm_to_text(value: Union[Algorithm, int]) -> str: """Convert a DNSSEC algorithm value to text - *value*, an ``int`` a DNSSEC algorithm. + *value*, a ``dns.dnssec.Algorithm``. Returns a ``str``, the name of a DNSSEC algorithm. """ @@ -86,7 +69,7 @@ def algorithm_to_text(value): return Algorithm.to_text(value) -def key_id(key): +def key_id(key: DNSKEY) -> int: """Return the key id (a 16-bit number) for the specified key. *key*, a ``dns.rdtypes.ANY.DNSKEY.DNSKEY`` @@ -107,19 +90,10 @@ def key_id(key): total += ((total >> 16) & 0xffff) return total & 0xffff -class DSDigest(dns.enum.IntEnum): - """DNSSEC Delegation Signer Digest Algorithm""" - - SHA1 = 1 - SHA256 = 2 - SHA384 = 4 - - @classmethod - def _maximum(cls): - return 255 - -def make_ds(name, key, algorithm, origin=None): +def make_ds(name: Union[dns.name.Name, str], key: dns.rdata.Rdata, + algorithm: Union[DSDigest, str], + origin: Optional[dns.name.Name]=None) -> DS: """Create a DS record for a DNSSEC key. *name*, a ``dns.name.Name`` or ``str``, the owner name of the DS record. @@ -143,7 +117,8 @@ def make_ds(name, key, algorithm, origin=None): algorithm = DSDigest[algorithm.upper()] except Exception: raise UnsupportedAlgorithm('unsupported algorithm "%s"' % algorithm) - + if not isinstance(key, DNSKEY): + raise ValueError('key is not a DNSKEY') if algorithm == DSDigest.SHA1: dshash = hashlib.sha1() elif algorithm == DSDigest.SHA256: @@ -155,17 +130,20 @@ def make_ds(name, key, algorithm, origin=None): if isinstance(name, str): name = dns.name.from_text(name, origin) - dshash.update(name.canonicalize().to_wire()) + wire = name.canonicalize().to_wire() + assert wire is not None + dshash.update(wire) dshash.update(key.to_wire(origin=origin)) digest = dshash.digest() dsrdata = struct.pack("!HBB", key_id(key), key.algorithm, algorithm) + \ digest - return dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0, - len(dsrdata)) + ds = dns.rdata.from_wire(dns.rdataclass.IN, dns.rdatatype.DS, dsrdata, 0, + len(dsrdata)) + return cast(DS, ds) -def _find_candidate_keys(keys, rrsig): +def _find_candidate_keys(keys, rrsig: RRSIG) -> Optional[List[DNSKEY]]: value = keys.get(rrsig.signer) if isinstance(value, dns.node.Node): rdataset = value.get_rdataset(dns.rdataclass.IN, dns.rdatatype.DNSKEY) @@ -173,54 +151,54 @@ def _find_candidate_keys(keys, rrsig): rdataset = value if rdataset is None: return None - return [rd for rd in rdataset if + return [cast(DNSKEY, rd) for rd in rdataset if rd.algorithm == rrsig.algorithm and key_id(rd) == rrsig.key_tag] -def _is_rsa(algorithm): +def _is_rsa(algorithm: int) -> bool: return algorithm in (Algorithm.RSAMD5, Algorithm.RSASHA1, Algorithm.RSASHA1NSEC3SHA1, Algorithm.RSASHA256, Algorithm.RSASHA512) -def _is_dsa(algorithm): +def _is_dsa(algorithm: int) -> bool: return algorithm in (Algorithm.DSA, Algorithm.DSANSEC3SHA1) -def _is_ecdsa(algorithm): +def _is_ecdsa(algorithm: int) -> bool: return algorithm in (Algorithm.ECDSAP256SHA256, Algorithm.ECDSAP384SHA384) -def _is_eddsa(algorithm): +def _is_eddsa(algorithm: int) -> bool: return algorithm in (Algorithm.ED25519, Algorithm.ED448) -def _is_gost(algorithm): +def _is_gost(algorithm: int) -> bool: return algorithm == Algorithm.ECCGOST -def _is_md5(algorithm): +def _is_md5(algorithm: int) -> bool: return algorithm == Algorithm.RSAMD5 -def _is_sha1(algorithm): +def _is_sha1(algorithm: int) -> bool: return algorithm in (Algorithm.DSA, Algorithm.RSASHA1, Algorithm.DSANSEC3SHA1, Algorithm.RSASHA1NSEC3SHA1) -def _is_sha256(algorithm): +def _is_sha256(algorithm: int) -> bool: return algorithm in (Algorithm.RSASHA256, Algorithm.ECDSAP256SHA256) -def _is_sha384(algorithm): +def _is_sha384(algorithm: int) -> bool: return algorithm == Algorithm.ECDSAP384SHA384 -def _is_sha512(algorithm): +def _is_sha512(algorithm: int) -> bool: return algorithm == Algorithm.RSASHA512 -def _make_hash(algorithm): +def _make_hash(algorithm: int) -> Any: if _is_md5(algorithm): return hashes.MD5() if _is_sha1(algorithm): @@ -239,12 +217,14 @@ def _make_hash(algorithm): raise ValidationFailure('unknown hash for algorithm %u' % algorithm) -def _bytes_to_long(b): +def _bytes_to_long(b: bytes) -> int: return int.from_bytes(b, 'big') -def _validate_signature(sig, data, key, chosen_hash): +def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any): + keyptr: bytes if _is_rsa(key.algorithm): + # we ignore because mypy is confused and thinks key.key is a str for unknown reasons. keyptr = key.key (bytes_,) = struct.unpack('!B', keyptr[0:1]) keyptr = keyptr[1:] @@ -254,12 +234,12 @@ def _validate_signature(sig, data, key, chosen_hash): rsa_e = keyptr[0:bytes_] rsa_n = keyptr[bytes_:] try: - public_key = rsa.RSAPublicNumbers( + rsa_public_key = rsa.RSAPublicNumbers( _bytes_to_long(rsa_e), _bytes_to_long(rsa_n)).public_key(default_backend()) except ValueError: raise ValidationFailure('invalid public key') - public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash) + rsa_public_key.verify(sig, data, padding.PKCS1v15(), chosen_hash) elif _is_dsa(key.algorithm): keyptr = key.key (t,) = struct.unpack('!B', keyptr[0:1]) @@ -273,7 +253,7 @@ def _validate_signature(sig, data, key, chosen_hash): keyptr = keyptr[octets:] dsa_y = keyptr[0:octets] try: - public_key = dsa.DSAPublicNumbers( + dsa_public_key = dsa.DSAPublicNumbers( _bytes_to_long(dsa_y), dsa.DSAParameterNumbers( _bytes_to_long(dsa_p), @@ -281,9 +261,10 @@ def _validate_signature(sig, data, key, chosen_hash): _bytes_to_long(dsa_g))).public_key(default_backend()) except ValueError: raise ValidationFailure('invalid public key') - public_key.verify(sig, data, chosen_hash) + dsa_public_key.verify(sig, data, chosen_hash) elif _is_ecdsa(key.algorithm): keyptr = key.key + curve: Any if key.algorithm == Algorithm.ECDSAP256SHA256: curve = ec.SECP256R1() octets = 32 @@ -293,24 +274,25 @@ def _validate_signature(sig, data, key, chosen_hash): ecdsa_x = keyptr[0:octets] ecdsa_y = keyptr[octets:octets * 2] try: - public_key = ec.EllipticCurvePublicNumbers( + ecdsa_public_key = ec.EllipticCurvePublicNumbers( curve=curve, x=_bytes_to_long(ecdsa_x), y=_bytes_to_long(ecdsa_y)).public_key(default_backend()) except ValueError: raise ValidationFailure('invalid public key') - public_key.verify(sig, data, ec.ECDSA(chosen_hash)) + ecdsa_public_key.verify(sig, data, ec.ECDSA(chosen_hash)) elif _is_eddsa(key.algorithm): keyptr = key.key + loader: Any if key.algorithm == Algorithm.ED25519: loader = ed25519.Ed25519PublicKey else: loader = ed448.Ed448PublicKey try: - public_key = loader.from_public_bytes(keyptr) + eddsa_public_key = loader.from_public_bytes(keyptr) except ValueError: raise ValidationFailure('invalid public key') - public_key.verify(sig, data) + eddsa_public_key.verify(sig, data) elif _is_gost(key.algorithm): raise UnsupportedAlgorithm( 'algorithm "%s" not supported by dnspython' % @@ -319,7 +301,10 @@ def _validate_signature(sig, data, key, chosen_hash): raise ValidationFailure('unknown algorithm %u' % key.algorithm) -def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): +def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + rrsig: RRSIG, + keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], + origin: Optional[dns.name.Name]=None, now: Optional[float]=None): """Validate an RRset against a single signature rdata, throwing an exception if validation is not successful. @@ -337,7 +322,7 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): *origin*, a ``dns.name.Name`` or ``None``, the origin to use for relative names. - *now*, an ``int`` or ``None``, the time, in seconds since the epoch, to + *now*, a ``float`` or ``None``, the time, in seconds since the epoch, to use as the current time when validating. If ``None``, the actual current time is used. @@ -394,7 +379,10 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): data += rrsig.signer.to_digestable(origin) # Derelativize the name before considering labels. - rrname = rrname.derelativize(origin) + if not rrname.is_absolute(): + if origin is None: + raise ValidationFailure('relative RR name without an origin specified') + rrname = rrname.derelativize(origin) if len(rrname) - 1 < rrsig.labels: raise ValidationFailure('owner name longer than RRSIG labels') @@ -425,7 +413,10 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): raise ValidationFailure('verify failure') -def _validate(rrset, rrsigset, keys, origin=None, now=None): +def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + rrsigset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], + keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], + origin: Optional[dns.name.Name]=None, now: Optional[float]=None): """Validate an RRset against a signature RRset, throwing an exception if none of the signatures validate. @@ -475,6 +466,8 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None): raise ValidationFailure("owner names do not match") for rrsig in rrsigrdataset: + if not isinstance(rrsig, RRSIG): + raise ValidationFailure('expected an RRSIG') try: _validate_rrsig(rrset, rrsig, keys, origin, now) return @@ -483,15 +476,6 @@ def _validate(rrset, rrsigset, keys, origin=None, now=None): raise ValidationFailure("no RRSIGs validated") -class NSEC3Hash(dns.enum.IntEnum): - """NSEC3 hash algorithm""" - - SHA1 = 1 - - @classmethod - def _maximum(cls): - return 255 - def nsec3_hash(domain, salt, iterations, algorithm): """ Calculate the NSEC3 hash, according to diff --git a/dns/dnssec.pyi b/dns/dnssec.pyi deleted file mode 100644 index e126f9b8..00000000 --- a/dns/dnssec.pyi +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Union, Dict, Tuple, Optional -from . import rdataset, rrset, exception, name, rdtypes, rdata, node -import dns.rdtypes.ANY.DS as DS -import dns.rdtypes.ANY.DNSKEY as DNSKEY - -_have_pyca : bool - -def validate_rrsig(rrset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsig : rdata.Rdata, keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin : Optional[name.Name] = None, now : Optional[int] = None) -> None: - ... - -def validate(rrset: Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], rrsigset : Union[Tuple[name.Name, rdataset.Rdataset], rrset.RRset], keys : Dict[name.Name, Union[node.Node, rdataset.Rdataset]], origin=None, now=None) -> None: - ... - -class ValidationFailure(exception.DNSException): - ... - -def make_ds(name : name.Name, key : DNSKEY.DNSKEY, algorithm : str, origin : Optional[name.Name] = None) -> DS.DS: - ... - -def nsec3_hash(domain: str, salt: Optional[Union[str, bytes]], iterations: int, algo: int) -> str: - ... diff --git a/dns/dnssectypes.py b/dns/dnssectypes.py new file mode 100644 index 00000000..2a747168 --- /dev/null +++ b/dns/dnssectypes.py @@ -0,0 +1,69 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Common DNSSEC-related types.""" + +# This is a separate file to avoid import circularity between dns.dnssec and +# the implementations of the DS and DNSKEY types. + +import dns.enum + + +class Algorithm(dns.enum.IntEnum): + RSAMD5 = 1 + DH = 2 + DSA = 3 + ECC = 4 + RSASHA1 = 5 + DSANSEC3SHA1 = 6 + RSASHA1NSEC3SHA1 = 7 + RSASHA256 = 8 + RSASHA512 = 10 + ECCGOST = 12 + ECDSAP256SHA256 = 13 + ECDSAP384SHA384 = 14 + ED25519 = 15 + ED448 = 16 + INDIRECT = 252 + PRIVATEDNS = 253 + PRIVATEOID = 254 + + @classmethod + def _maximum(cls): + return 255 + + +class DSDigest(dns.enum.IntEnum): + """DNSSEC Delegation Signer Digest Algorithm""" + + SHA1 = 1 + SHA256 = 2 + SHA384 = 4 + + @classmethod + def _maximum(cls): + return 255 + + +class NSEC3Hash(dns.enum.IntEnum): + """NSEC3 hash algorithm""" + + SHA1 = 1 + + @classmethod + def _maximum(cls): + return 255 diff --git a/dns/e164.py b/dns/e164.py index 83731b2c..8c9a3ac5 100644 --- a/dns/e164.py +++ b/dns/e164.py @@ -17,6 +17,8 @@ """DNS E.164 helpers.""" +from typing import Iterable, Optional, Union + import dns.exception import dns.name import dns.resolver @@ -25,7 +27,7 @@ import dns.resolver public_enum_domain = dns.name.from_text('e164.arpa.') -def from_e164(text, origin=public_enum_domain): +def from_e164(text: str, origin: Optional[dns.name.Name]=public_enum_domain) -> dns.name.Name: """Convert an E.164 number in textual form into a Name object whose value is the ENUM domain name for that number. @@ -45,7 +47,8 @@ def from_e164(text, origin=public_enum_domain): return dns.name.from_text('.'.join(parts), origin=origin) -def to_e164(name, origin=public_enum_domain, want_plus_prefix=True): +def to_e164(name: dns.name.Name, origin: Optional[dns.name.Name]=public_enum_domain, + want_plus_prefix=True) -> str: """Convert an ENUM domain name into an E.164 number. Note that dnspython does not have any information about preferred @@ -77,7 +80,8 @@ def to_e164(name, origin=public_enum_domain, want_plus_prefix=True): return text.decode() -def query(number, domains, resolver=None): +def query(number: str, domains: Iterable[Union[dns.name.Name, str]], + resolver: Optional[dns.resolver.Resolver]=None) -> dns.resolver.Answer: """Look for NAPTR RRs for the specified number in the specified domains. e.g. lookup('16505551212', ['e164.dnspython.org.', 'e164.arpa.']) diff --git a/dns/e164.pyi b/dns/e164.pyi deleted file mode 100644 index 37a99fed..00000000 --- a/dns/e164.pyi +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Optional, Iterable -from . import name, resolver -def from_e164(text : str, origin=name.Name(".")) -> name.Name: - ... - -def to_e164(name : name.Name, origin : Optional[name.Name] = None, want_plus_prefix=True) -> str: - ... - -def query(number : str, domains : Iterable[str], resolver : Optional[resolver.Resolver] = None) -> resolver.Answer: - ... diff --git a/dns/edns.py b/dns/edns.py index fa4e98b1..15c646de 100644 --- a/dns/edns.py +++ b/dns/edns.py @@ -17,6 +17,8 @@ """EDNS Options""" +from typing import Any, Dict, Optional, Union + import math import socket import struct @@ -24,6 +26,7 @@ import struct import dns.enum import dns.inet import dns.rdata +import dns.wire class OptionType(dns.enum.IntEnum): @@ -59,14 +62,14 @@ class Option: """Base class for all EDNS option types.""" - def __init__(self, otype): + def __init__(self, otype: Union[OptionType, str]): """Initialize an option. - *otype*, an ``int``, is the option type. + *otype*, a ``dns.edns.OptionType``, is the option type. """ self.otype = OptionType.make(otype) - def to_wire(self, file=None): + def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]: """Convert an option to wire format. Returns a ``bytes`` or ``None``. @@ -75,10 +78,10 @@ class Option: raise NotImplementedError # pragma: no cover @classmethod - def from_wire_parser(cls, otype, parser): + def from_wire_parser(cls, otype: OptionType, parser: 'dns.wire.Parser') -> 'Option': """Build an EDNS option object from wire format. - *otype*, an ``int``, is the option type. + *otype*, a ``dns.edns.OptionType``, is the option type. *parser*, a ``dns.wire.Parser``, the parser, which should be restructed to the option length. @@ -150,28 +153,29 @@ class GenericOption(Option): # lgtm[py/missing-equals] implementation. """ - def __init__(self, otype, data): + def __init__(self, otype: Union[OptionType, str], data: Union[bytes, str]): super().__init__(otype) self.data = dns.rdata.Rdata._as_bytes(data, True) - def to_wire(self, file=None): + def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]: if file: file.write(self.data) + return None else: return self.data - def to_text(self): + def to_text(self) -> str: return "Generic %d" % self.otype @classmethod - def from_wire_parser(cls, otype, parser): + def from_wire_parser(cls, otype: Union[OptionType, str], parser: 'dns.wire.Parser') -> Option: return cls(otype, parser.get_remaining()) class ECSOption(Option): # lgtm[py/missing-equals] """EDNS Client Subnet (ECS, RFC7871)""" - def __init__(self, address, srclen=None, scopelen=0): + def __init__(self, address: str, srclen: Optional[int]=None, scopelen=0): """*address*, a ``str``, is the client address information. *srclen*, an ``int``, the source prefix length, which is the @@ -202,6 +206,7 @@ class ECSOption(Option): # lgtm[py/missing-equals] else: # pragma: no cover (this will never happen) raise ValueError('Bad address family') + assert srclen is not None self.address = address self.srclen = srclen self.scopelen = scopelen @@ -218,12 +223,12 @@ class ECSOption(Option): # lgtm[py/missing-equals] ord(self.addrdata[-1:]) & (0xff << (8 - nbits))) self.addrdata = self.addrdata[:-1] + last - def to_text(self): + def to_text(self) -> str: return "ECS {}/{} scope/{}".format(self.address, self.srclen, self.scopelen) @staticmethod - def from_text(text): + def from_text(text) -> Option: """Convert a string into a `dns.edns.ECSOption` *text*, a `str`, the text form of the option. @@ -277,16 +282,17 @@ class ECSOption(Option): # lgtm[py/missing-equals] '"{}": srclen must be an integer'.format(srclen)) return ECSOption(address, srclen, scope) - def to_wire(self, file=None): + def to_wire(self, file=None) -> Optional[bytes]: value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) + self.addrdata) if file: file.write(value) + return None else: return value @classmethod - def from_wire_parser(cls, otype, parser): + def from_wire_parser(cls, otype: Union[OptionType, str], parser: 'dns.wire.Parser'): family, src, scope = parser.get_struct('!HBB') addrlen = int(math.ceil(src / 8.0)) prefix = parser.get_bytes(addrlen) @@ -337,7 +343,7 @@ class EDECode(dns.enum.IntEnum): class EDEOption(Option): # lgtm[py/missing-equals] """Extended DNS Error (EDE, RFC8914)""" - def __init__(self, code, text=None): + def __init__(self, code: Union[EDECode, str], text: Optional[str]=None): """*code*, a ``dns.edns.EDECode`` or ``str``, the info code of the extended error. @@ -350,28 +356,27 @@ class EDEOption(Option): # lgtm[py/missing-equals] self.code = EDECode.make(code) if text is not None and not isinstance(text, str): raise ValueError('text must be string or None') - - self.code = code self.text = text - def to_text(self): + def to_text(self) -> str: output = f'EDE {self.code}' if self.text is not None: output += f': {self.text}' return output - def to_wire(self, file=None): + def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]: value = struct.pack('!H', self.code) if self.text is not None: value += self.text.encode('utf8') if file: file.write(value) + return None else: return value @classmethod - def from_wire_parser(cls, otype, parser): + def from_wire_parser(cls, otype: Union[OptionType, str], parser) -> Option: code = parser.get_uint16() text = parser.get_remaining() @@ -385,13 +390,13 @@ class EDEOption(Option): # lgtm[py/missing-equals] return cls(code, text) -_type_to_class = { +_type_to_class: Dict[OptionType, Any] = { OptionType.ECS: ECSOption, OptionType.EDE: EDEOption, } -def get_option_class(otype): +def get_option_class(otype: OptionType) -> Any: """Return the class for the specified option type. The GenericOption class is used if a more specific class is not @@ -404,7 +409,7 @@ def get_option_class(otype): return cls -def option_from_wire_parser(otype, parser): +def option_from_wire_parser(otype: Union[OptionType, str], parser: 'dns.wire.Parser') -> Option: """Build an EDNS option object from wire format. *otype*, an ``int``, is the option type. @@ -414,12 +419,12 @@ def option_from_wire_parser(otype, parser): Returns an instance of a subclass of ``dns.edns.Option``. """ - cls = get_option_class(otype) - otype = OptionType.make(otype) + the_otype = OptionType.make(otype) + cls = get_option_class(the_otype) return cls.from_wire_parser(otype, parser) -def option_from_wire(otype, wire, current, olen): +def option_from_wire(otype: Union[OptionType, str], wire: bytes, current: int, olen: int) -> Option: """Build an EDNS option object from wire format. *otype*, an ``int``, is the option type. @@ -437,7 +442,7 @@ def option_from_wire(otype, wire, current, olen): with parser.restrict_to(olen): return option_from_wire_parser(otype, parser) -def register_type(implementation, otype): +def register_type(implementation: Any, otype: OptionType): """Register the implementation of an option type. *implementation*, a ``class``, is a subclass of ``dns.edns.Option``. diff --git a/dns/entropy.py b/dns/entropy.py index 086bba78..528a628b 100644 --- a/dns/entropy.py +++ b/dns/entropy.py @@ -15,6 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +from typing import Any, Optional + import os import hashlib import random @@ -34,7 +36,7 @@ class EntropyPool: def __init__(self, seed=None): self.pool_index = 0 - self.digest = None + self.digest: Optional[bytearray] = None self.next_byte = 0 self.lock = _threading.Lock() self.hash = hashlib.sha1() @@ -76,7 +78,7 @@ class EntropyPool: seed = bytearray(seed) self._stir(seed) - def random_8(self): + def random_8(self) -> int: with self.lock: self._maybe_seed() if self.digest is None or self.next_byte == self.hash_len: @@ -88,13 +90,13 @@ class EntropyPool: self.next_byte += 1 return value - def random_16(self): + def random_16(self) -> int: return self.random_8() * 256 + self.random_8() - def random_32(self): + def random_32(self) -> int: return self.random_16() * 65536 + self.random_16() - def random_between(self, first, last): + def random_between(self, first: int, last: int) -> int: size = last - first + 1 if size > 4294967296: raise ValueError('too big') @@ -111,18 +113,19 @@ class EntropyPool: pool = EntropyPool() +system_random: Optional[Any] try: system_random = random.SystemRandom() except Exception: # pragma: no cover system_random = None -def random_16(): +def random_16() -> int: if system_random is not None: return system_random.randrange(0, 65536) else: return pool.random_16() -def between(first, last): +def between(first: int, last: int) -> int: if system_random is not None: return system_random.randrange(first, last + 1) else: diff --git a/dns/entropy.pyi b/dns/entropy.pyi deleted file mode 100644 index 818f805a..00000000 --- a/dns/entropy.pyi +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Optional -from random import SystemRandom - -system_random : Optional[SystemRandom] - -def random_16() -> int: - pass - -def between(first: int, last: int) -> int: - pass diff --git a/dns/exception.py b/dns/exception.py index 53764588..550a1bcf 100644 --- a/dns/exception.py +++ b/dns/exception.py @@ -21,6 +21,10 @@ Dnspython modules may also define their own exceptions, which will always be subclasses of ``DNSException``. """ + +from typing import Dict, Optional, Set + + class DNSException(Exception): """Abstract base class shared by all dnspython exceptions. @@ -44,9 +48,9 @@ class DNSException(Exception): and ``fmt`` class variables to get nice parametrized messages. """ - msg = None # non-parametrized message - supp_kwargs = set() # accepted parameters for _fmt_kwargs (sanity check) - fmt = None # message parametrized with results from _fmt_kwargs + msg: Optional[str] = None # non-parametrized message + supp_kwargs: Set[str] = set() # accepted parameters for _fmt_kwargs (sanity check) + fmt: Optional[str] = None # message parametrized with results from _fmt_kwargs def __init__(self, *args, **kwargs): self._check_params(*args, **kwargs) @@ -128,6 +132,10 @@ class Timeout(DNSException): supp_kwargs = {'timeout'} fmt = "The DNS operation timed out after {timeout:.3f} seconds" + # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + class ExceptionWrapper: def __init__(self, exception_class): diff --git a/dns/exception.pyi b/dns/exception.pyi deleted file mode 100644 index dc571264..00000000 --- a/dns/exception.pyi +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Set, Optional, Dict - -class DNSException(Exception): - supp_kwargs : Set[str] - kwargs : Optional[Dict] - fmt : Optional[str] - -class SyntaxError(DNSException): ... -class FormError(DNSException): ... -class Timeout(DNSException): ... -class TooBig(DNSException): ... -class UnexpectedEnd(SyntaxError): ... diff --git a/dns/flags.py b/dns/flags.py index 96522879..6fe1afd3 100644 --- a/dns/flags.py +++ b/dns/flags.py @@ -17,6 +17,8 @@ """DNS Message Flags.""" +from typing import Any + import enum # Standard DNS flags @@ -45,7 +47,7 @@ class EDNSFlag(enum.IntFlag): DO = 0x8000 -def _from_text(text, enum_class): +def _from_text(text: str, enum_class: Any) -> int: flags = 0 tokens = text.split() for t in tokens: @@ -53,7 +55,7 @@ def _from_text(text, enum_class): return flags -def _to_text(flags, enum_class): +def _to_text(flags: int, enum_class: Any) -> str: text_flags = [] for k, v in enum_class.__members__.items(): if flags & v != 0: @@ -61,7 +63,7 @@ def _to_text(flags, enum_class): return ' '.join(text_flags) -def from_text(text): +def from_text(text: str) -> int: """Convert a space-separated list of flag text values into a flags value. @@ -71,7 +73,7 @@ def from_text(text): return _from_text(text, Flag) -def to_text(flags): +def to_text(flags: int) -> str: """Convert a flags value into a space-separated list of flag text values. @@ -81,7 +83,7 @@ def to_text(flags): return _to_text(flags, Flag) -def edns_from_text(text): +def edns_from_text(text: str) -> int: """Convert a space-separated list of EDNS flag text values into a EDNS flags value. @@ -91,7 +93,7 @@ def edns_from_text(text): return _from_text(text, EDNSFlag) -def edns_to_text(flags): +def edns_to_text(flags: int) -> str: """Convert an EDNS flags value into a space-separated list of EDNS flag text values. diff --git a/dns/grange.py b/dns/grange.py index 112ede47..ebb64d2d 100644 --- a/dns/grange.py +++ b/dns/grange.py @@ -17,9 +17,11 @@ """DNS GENERATE range conversion.""" +from typing import Tuple + import dns -def from_text(text): +def from_text(text: str) -> Tuple[int, int, int]: """Convert the text form of a range in a ``$GENERATE`` statement to an integer. diff --git a/dns/inet.py b/dns/inet.py index d3bdc64c..b3ed9995 100644 --- a/dns/inet.py +++ b/dns/inet.py @@ -17,6 +17,8 @@ """Generic Internet address helper functions.""" +from typing import Any, Optional, Tuple + import socket import dns.ipv4 @@ -30,7 +32,7 @@ AF_INET = socket.AF_INET AF_INET6 = socket.AF_INET6 -def inet_pton(family, text): +def inet_pton(family: int, text: str) -> bytes: """Convert the textual form of a network address into its binary form. *family* is an ``int``, the address family. @@ -51,7 +53,7 @@ def inet_pton(family, text): raise NotImplementedError -def inet_ntop(family, address): +def inet_ntop(family: int, address: bytes) -> str: """Convert the binary form of a network address into its textual form. *family* is an ``int``, the address family. @@ -72,7 +74,7 @@ def inet_ntop(family, address): raise NotImplementedError -def af_for_address(text): +def af_for_address(text: str) -> int: """Determine the address family of a textual-form network address. *text*, a ``str``, the textual address. @@ -94,7 +96,7 @@ def af_for_address(text): raise ValueError -def is_multicast(text): +def is_multicast(text: str) -> bool: """Is the textual-form network address a multicast address? *text*, a ``str``, the textual address. @@ -116,7 +118,7 @@ def is_multicast(text): raise ValueError -def is_address(text): +def is_address(text: str) -> bool: """Is the specified string an IPv4 or IPv6 address? *text*, a ``str``, the textual address. @@ -135,7 +137,7 @@ def is_address(text): return False -def low_level_address_tuple(high_tuple, af=None): +def low_level_address_tuple(high_tuple: Tuple[str, int], af: Optional[int]=None) -> Any: """Given a "high-level" address tuple, i.e. an (address, port) return the appropriate "low-level" address tuple suitable for use in socket calls. @@ -143,7 +145,6 @@ def low_level_address_tuple(high_tuple, af=None): If an *af* other than ``None`` is provided, it is assumed the address in the high-level tuple is valid and has that af. If af is ``None``, then af_for_address will be called. - """ address, port = high_tuple if af is None: diff --git a/dns/inet.pyi b/dns/inet.pyi deleted file mode 100644 index 6d9dcc70..00000000 --- a/dns/inet.pyi +++ /dev/null @@ -1,4 +0,0 @@ -from typing import Union -from socket import AddressFamily - -AF_INET6 : Union[int, AddressFamily] diff --git a/dns/ipv4.py b/dns/ipv4.py index e1f38d3d..fddad1b1 100644 --- a/dns/ipv4.py +++ b/dns/ipv4.py @@ -17,11 +17,13 @@ """IPv4 helper functions.""" +from typing import Union + import struct import dns.exception -def inet_ntoa(address): +def inet_ntoa(address: bytes) -> str: """Convert an IPv4 address in binary form to text form. *address*, a ``bytes``, the IPv4 address in binary form. @@ -34,17 +36,19 @@ def inet_ntoa(address): return ('%u.%u.%u.%u' % (address[0], address[1], address[2], address[3])) -def inet_aton(text): +def inet_aton(text: Union[str, bytes]) -> bytes: """Convert an IPv4 address in text form to binary form. - *text*, a ``str``, the IPv4 address in textual form. + *text*, a ``str`` or ``bytes``, the IPv4 address in textual form. Returns a ``bytes``. """ if not isinstance(text, bytes): - text = text.encode() - parts = text.split(b'.') + btext = text.encode() + else: + btext = text + parts = btext.split(b'.') if len(parts) != 4: raise dns.exception.SyntaxError for part in parts: diff --git a/dns/ipv6.py b/dns/ipv6.py index 0db6fcfa..1d5bffde 100644 --- a/dns/ipv6.py +++ b/dns/ipv6.py @@ -17,6 +17,8 @@ """IPv6 helper functions.""" +from typing import List, Union + import re import binascii @@ -25,7 +27,7 @@ import dns.ipv4 _leading_zero = re.compile(r'0+([0-9a-f]+)') -def inet_ntoa(address): +def inet_ntoa(address: bytes) -> str: """Convert an IPv6 address in binary form to text form. *address*, a ``bytes``, the IPv6 address in binary form. @@ -84,19 +86,19 @@ def inet_ntoa(address): prefix = '::' else: prefix = '::ffff:' - hex = prefix + dns.ipv4.inet_ntoa(address[12:]) + thex = prefix + dns.ipv4.inet_ntoa(address[12:]) else: - hex = ':'.join(chunks[:best_start]) + '::' + \ + thex = ':'.join(chunks[:best_start]) + '::' + \ ':'.join(chunks[best_start + best_len:]) else: - hex = ':'.join(chunks) - return hex + thex = ':'.join(chunks) + return thex _v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$') _colon_colon_start = re.compile(br'::.*') _colon_colon_end = re.compile(br'.*::$') -def inet_aton(text, ignore_scope=False): +def inet_aton(text: Union[str, bytes], ignore_scope=False) -> bytes: """Convert an IPv6 address in text form to binary form. *text*, a ``str``, the IPv6 address in textual form. @@ -111,53 +113,55 @@ def inet_aton(text, ignore_scope=False): # Our aim here is not something fast; we just want something that works. # if not isinstance(text, bytes): - text = text.encode() + btext = text.encode() + else: + btext = text if ignore_scope: - parts = text.split(b'%') + parts = btext.split(b'%') l = len(parts) if l == 2: - text = parts[0] + btext = parts[0] elif l > 2: raise dns.exception.SyntaxError - if text == b'': + if btext == b'': raise dns.exception.SyntaxError - elif text.endswith(b':') and not text.endswith(b'::'): + elif btext.endswith(b':') and not btext.endswith(b'::'): raise dns.exception.SyntaxError - elif text.startswith(b':') and not text.startswith(b'::'): + elif btext.startswith(b':') and not btext.startswith(b'::'): raise dns.exception.SyntaxError - elif text == b'::': - text = b'0::' + elif btext == b'::': + btext = b'0::' # # Get rid of the icky dot-quad syntax if we have it. # - m = _v4_ending.match(text) + m = _v4_ending.match(btext) if m is not None: b = dns.ipv4.inet_aton(m.group(2)) - text = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(), + btext = ("{}:{:02x}{:02x}:{:02x}{:02x}".format(m.group(1).decode(), b[0], b[1], b[2], b[3])).encode() # # Try to turn '::' into ':'; if no match try to # turn '::' into ':' # - m = _colon_colon_start.match(text) + m = _colon_colon_start.match(btext) if m is not None: - text = text[1:] + btext = btext[1:] else: - m = _colon_colon_end.match(text) + m = _colon_colon_end.match(btext) if m is not None: - text = text[:-1] + btext = btext[:-1] # # Now canonicalize into 8 chunks of 4 hex digits each # - chunks = text.split(b':') + chunks = btext.split(b':') l = len(chunks) if l > 8: raise dns.exception.SyntaxError seen_empty = False - canonical = [] + canonical: List[bytes] = [] for c in chunks: if c == b'': if seen_empty: @@ -174,13 +178,13 @@ def inet_aton(text, ignore_scope=False): canonical.append(c) if l < 8 and not seen_empty: raise dns.exception.SyntaxError - text = b''.join(canonical) + btext = b''.join(canonical) # # Finally we can go to binary. # try: - return binascii.unhexlify(text) + return binascii.unhexlify(btext) except (binascii.Error, TypeError): raise dns.exception.SyntaxError diff --git a/dns/message.py b/dns/message.py index c2751a90..46c0a684 100644 --- a/dns/message.py +++ b/dns/message.py @@ -17,6 +17,8 @@ """DNS Messages""" +from typing import Any, Dict, List, Optional, Tuple, Union + import contextlib import io import time @@ -73,6 +75,10 @@ class Truncated(dns.exception.DNSException): supp_kwargs = {'message'} + # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def message(self): """As much of the message as could be processed. @@ -109,7 +115,7 @@ class MessageSection(dns.enum.IntEnum): class MessageError: - def __init__(self, exception, offset): + def __init__(self, exception: Exception, offset: int): self.exception = exception self.offset = offset @@ -117,31 +123,38 @@ class MessageError: DEFAULT_EDNS_PAYLOAD = 1232 MAX_CHAIN = 16 +IndexKeyType = Tuple[int, dns.name.Name, dns.rdataclass.RdataClass, + dns.rdatatype.RdataType, Optional[dns.rdatatype.RdataType], + Optional[dns.rdataclass.RdataClass]] +IndexType = Dict[IndexKeyType, dns.rrset.RRset] +SectionType = Union[int, List[dns.rrset.RRset]] + class Message: """A DNS message.""" _section_enum = MessageSection - def __init__(self, id=None): + def __init__(self, id: Optional[int]=None): if id is None: self.id = dns.entropy.random_16() else: self.id = id self.flags = 0 - self.sections = [[], [], [], []] - self.opt = None + self.sections: List[List[dns.rrset.RRset]] = [[], [], [], []] + self.opt: Optional[dns.rrset.RRset] = None self.request_payload = 0 - self.keyring = None - self.tsig = None + self.keyring: Any = None + self.tsig: Optional[dns.rrset.RRset] = None self.request_mac = b'' self.xfr = False - self.origin = None - self.tsig_ctx = None - self.index = {} - self.errors = [] + self.origin: Optional[dns.name.Name] = None + self.tsig_ctx: Optional[Any] = None + self.index: IndexType = {} + self.errors: List[MessageError] = [] + self.time = 0.0 @property - def question(self): + def question(self) -> List[dns.rrset.RRset]: """ The question section.""" return self.sections[0] @@ -150,7 +163,7 @@ class Message: self.sections[0] = v @property - def answer(self): + def answer(self) -> List[dns.rrset.RRset]: """ The answer section.""" return self.sections[1] @@ -159,7 +172,7 @@ class Message: self.sections[1] = v @property - def authority(self): + def authority(self) -> List[dns.rrset.RRset]: """ The authority section.""" return self.sections[2] @@ -168,7 +181,7 @@ class Message: self.sections[2] = v @property - def additional(self): + def additional(self) -> List[dns.rrset.RRset]: """ The additional data section.""" return self.sections[3] @@ -182,7 +195,8 @@ class Message: def __str__(self): return self.to_text() - def to_text(self, origin=None, relativize=True, **kw): + def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, + **kw): """Convert the message to text. The *origin*, *relativize*, and any other keyword @@ -242,7 +256,7 @@ class Message: def __ne__(self, other): return not self.__eq__(other) - def is_response(self, other): + def is_response(self, other: 'Message') -> bool: """Is *other*, also a ``dns.message.Message``, a response to this message? @@ -275,7 +289,7 @@ class Message: return False return True - def section_number(self, section): + def section_number(self, section: List[dns.rrset.RRset]) -> int: """Return the "section number" of the specified section for use in indexing. @@ -291,7 +305,7 @@ class Message: return self._section_enum(i) raise ValueError('unknown section') - def section_from_number(self, number): + def section_from_number(self, number: int) -> List[dns.rrset.RRset]: """Return the section list associated with the specified section number. @@ -306,9 +320,15 @@ class Message: section = self._section_enum.make(number) return self.sections[section] - def find_rrset(self, section, name, rdclass, rdtype, - covers=dns.rdatatype.NONE, deleting=None, create=False, - force_unique=False): + def find_rrset(self, + section: SectionType, + name: dns.name.Name, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers = dns.rdatatype.NONE, + deleting: Optional[dns.rdataclass.RdataClass]=None, + create=False, + force_unique=False) -> dns.rrset.RRset: """Find the RRset with the given attributes in the specified section. *section*, an ``int`` section number, or one of the section @@ -346,9 +366,10 @@ class Message: if isinstance(section, int): section_number = section - section = self.section_from_number(section_number) + the_section = self.section_from_number(section_number) else: section_number = self.section_number(section) + the_section = section key = (section_number, name, rdclass, rdtype, covers, deleting) if not force_unique: if self.index is not None: @@ -356,21 +377,27 @@ class Message: if rrset is not None: return rrset else: - for rrset in section: + for rrset in the_section: if rrset.full_match(name, rdclass, rdtype, covers, deleting): return rrset if not create: raise KeyError rrset = dns.rrset.RRset(name, rdclass, rdtype, covers, deleting) - section.append(rrset) + the_section.append(rrset) if self.index is not None: self.index[key] = rrset return rrset - def get_rrset(self, section, name, rdclass, rdtype, - covers=dns.rdatatype.NONE, deleting=None, create=False, - force_unique=False): + def get_rrset(self, + section: SectionType, + name: dns.name.Name, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers = dns.rdatatype.NONE, + deleting: Optional[dns.rdataclass.RdataClass]=None, + create=False, + force_unique=False) -> Optional[dns.rrset.RRset]: """Get the RRset with the given attributes in the specified section. If the RRset is not found, None is returned. @@ -412,8 +439,8 @@ class Message: rrset = None return rrset - def to_wire(self, origin=None, max_size=0, multi=False, tsig_ctx=None, - **kw): + def to_wire(self, origin: Optional[dns.name.Name]=None, max_size=0, + multi=False, tsig_ctx: Optional[Any]=None, **kw) -> bytes: """Return a string containing the message in DNS compressed wire format. @@ -486,9 +513,9 @@ class Message: original_id, error, other) return dns.rrset.from_rdata(keyname, 0, tsig) - def use_tsig(self, keyring, keyname=None, fudge=300, - original_id=None, tsig_error=0, other_data=b'', - algorithm=dns.tsig.default_algorithm): + def use_tsig(self, keyring: Any, keyname: Optional[dns.name.Name]=None, + fudge=300, original_id: Optional[int]=None, tsig_error=0, + other_data=b'', algorithm=dns.tsig.default_algorithm): """When sending, a TSIG signature using the specified key should be added. @@ -546,35 +573,35 @@ class Message: b'', original_id, tsig_error, other_data) @property - def keyname(self): + def keyname(self) -> Optional[dns.name.Name]: if self.tsig: return self.tsig.name else: return None @property - def keyalgorithm(self): + def keyalgorithm(self) -> Optional[dns.name.Name]: if self.tsig: return self.tsig[0].algorithm else: return None @property - def mac(self): + def mac(self) -> Optional[bytes]: if self.tsig: return self.tsig[0].mac else: return None @property - def tsig_error(self): + def tsig_error(self) -> Optional[int]: if self.tsig: return self.tsig[0].error else: return None @property - def had_tsig(self): + def had_tsig(self) -> bool: return bool(self.tsig) @staticmethod @@ -584,7 +611,8 @@ class Message: return dns.rrset.from_rdata(dns.name.root, int(flags), opt) def use_edns(self, edns=0, ednsflags=0, payload=DEFAULT_EDNS_PAYLOAD, - request_payload=None, options=None): + request_payload: Optional[int]=None, + options: Optional[List[dns.edns.Option]]=None): """Configure EDNS behavior. *edns*, an ``int``, is the EDNS level to use. Specifying @@ -625,14 +653,14 @@ class Message: self.request_payload = request_payload @property - def edns(self): + def edns(self) -> int: if self.opt: return (self.ednsflags & 0xff0000) >> 16 else: return -1 @property - def ednsflags(self): + def ednsflags(self) -> int: if self.opt: return self.opt.ttl else: @@ -646,14 +674,14 @@ class Message: self.opt = self._make_opt(v) @property - def payload(self): + def payload(self) -> int: if self.opt: return self.opt[0].payload else: return 0 @property - def options(self): + def options(self) -> Tuple: if self.opt: return self.opt[0].options else: @@ -673,17 +701,17 @@ class Message: elif self.opt: self.ednsflags &= ~dns.flags.DO - def rcode(self): + def rcode(self) -> dns.rcode.Rcode: """Return the rcode. - Returns an ``int``. + Returns a ``dns.rcode.Rcode``. """ return dns.rcode.from_flags(int(self.flags), int(self.ednsflags)) - def set_rcode(self, rcode): + def set_rcode(self, rcode: dns.rcode.Rcode): """Set the rcode. - *rcode*, an ``int``, is the rcode to set. + *rcode*, a ``dns.rcode.Rcode``, is the rcode to set. """ (value, evalue) = dns.rcode.to_flags(rcode) self.flags &= 0xFFF0 @@ -691,17 +719,17 @@ class Message: self.ednsflags &= 0x00FFFFFF self.ednsflags |= evalue - def opcode(self): + def opcode(self) -> dns.opcode.Opcode: """Return the opcode. - Returns an ``int``. + Returns a ``dns.opcode.Opcode``. """ return dns.opcode.from_flags(int(self.flags)) - def set_opcode(self, opcode): + def set_opcode(self, opcode: dns.opcode.Opcode): """Set the opcode. - *opcode*, an ``int``, is the opcode to set. + *opcode*, a ``dns.opcode.Opcode``, is the opcode to set. """ self.flags &= 0x87FF self.flags |= dns.opcode.to_flags(opcode) @@ -738,7 +766,7 @@ class ChainingResult: exist. The ``canonical_name`` attribute is the canonical name after all - chaining has been applied (this is the name as ``rrset.name`` in cases + chaining has been applied (this is the same name as ``rrset.name`` in cases where rrset is not ``None``). The ``minimum_ttl`` attribute is the minimum TTL, i.e. the TTL to @@ -749,7 +777,8 @@ class ChainingResult: The ``cnames`` attribute is a list of all the CNAME RRSets followed to get to the canonical name. """ - def __init__(self, canonical_name, answer, minimum_ttl, cnames): + def __init__(self, canonical_name: dns.name.Name, answer: Optional[dns.rrset.RRset], + minimum_ttl: int, cnames: List[dns.rrset.RRset]): self.canonical_name = canonical_name self.answer = answer self.minimum_ttl = minimum_ttl @@ -757,7 +786,7 @@ class ChainingResult: class QueryMessage(Message): - def resolve_chaining(self): + def resolve_chaining(self) -> ChainingResult: """Follow the CNAME chain in the response to determine the answer RRset. @@ -831,7 +860,7 @@ class QueryMessage(Message): break return ChainingResult(qname, answer, min_ttl, cnames) - def canonical_name(self): + def canonical_name(self) -> dns.name.Name: """Return the canonical name of the first name in the question section. @@ -1042,7 +1071,7 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, tsig_ctx=None, multi=False, question_only=False, one_rr_per_rrset=False, ignore_trailing=False, raise_on_truncation=False, - continue_on_error=False): + continue_on_error=False) -> Message: """Convert a DNS wire format message into a message object. *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the @@ -1354,7 +1383,7 @@ class _TextReader: def from_text(text, idna_codec=None, one_rr_per_rrset=False, - origin=None, relativize=True, relativize_to=None): + origin=None, relativize=True, relativize_to=None) -> Message: """Convert the text format message into a message object. The reader stops after reading the first blank line in the input to @@ -1394,7 +1423,7 @@ def from_text(text, idna_codec=None, one_rr_per_rrset=False, return reader.read() -def from_file(f, idna_codec=None, one_rr_per_rrset=False): +def from_file(f, idna_codec=None, one_rr_per_rrset=False) -> Message: """Read the next text format message from the specified file. Message blocks are separated by a single blank line. @@ -1420,12 +1449,14 @@ def from_file(f, idna_codec=None, one_rr_per_rrset=False): if isinstance(f, str): f = stack.enter_context(open(f)) return from_text(f, idna_codec, one_rr_per_rrset) + assert False # for mypy def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, - want_dnssec=False, ednsflags=None, payload=None, - request_payload=None, options=None, idna_codec=None, - id=None, flags=dns.flags.RD): + want_dnssec=False, ednsflags: Optional[int]=None, payload: Optional[int]=None, + request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None, + idna_codec: Optional[dns.name.IDNACodec]=None, id: Optional[int]=None, + flags: int=dns.flags.RD) -> QueryMessage: """Make a query message. The query name, type, and class may all be specified either @@ -1487,7 +1518,7 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, # only pass keywords on to use_edns if they have been set to a # non-None value. Setting a field will turn EDNS on if it hasn't # been configured. - kwargs = {} + kwargs: Dict[str, Any] = {} if ednsflags is not None: kwargs['ednsflags'] = ednsflags if payload is not None: @@ -1505,7 +1536,7 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, def make_response(query, recursion_available=False, our_payload=8192, - fudge=300, tsig_error=0): + fudge=300, tsig_error=0) -> Message: """Make a message which is a response for the specified query. The message returned is really a response skeleton; it has all of the infrastructure required of a response, but none of the diff --git a/dns/message.pyi b/dns/message.pyi deleted file mode 100644 index 252a4118..00000000 --- a/dns/message.pyi +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Optional, Dict, List, Tuple, Union -from . import name, rrset, tsig, rdatatype, entropy, edns, rdataclass, rcode -import hmac - -class Message: - def to_wire(self, origin : Optional[name.Name]=None, max_size=0, **kw) -> bytes: - ... - def find_rrset(self, section : List[rrset.RRset], name : name.Name, rdclass : int, rdtype : int, - covers=rdatatype.NONE, deleting : Optional[int]=None, create=False, - force_unique=False) -> rrset.RRset: - ... - def __init__(self, id : Optional[int] =None) -> None: - self.id : int - self.flags = 0 - self.sections : List[List[rrset.RRset]] = [[], [], [], []] - self.opt : rrset.RRset = None - self.request_payload = 0 - self.keyring = None - self.tsig : rrset.RRset = None - self.request_mac = b'' - self.xfr = False - self.origin = None - self.tsig_ctx = None - self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {} - - def is_response(self, other : Message) -> bool: - ... - - def set_rcode(self, rcode : rcode.Rcode): - ... - -def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message: - ... - -def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None, - tsig_ctx : Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]] = None, multi=False, - question_only=False, one_rr_per_rrset=False, - ignore_trailing=False) -> Message: - ... -def make_response(query : Message, recursion_available=False, our_payload=8192, - fudge=300) -> Message: - ... - -def make_query(qname : Union[name.Name,str], rdtype : Union[str,int], rdclass : Union[int,str] =rdataclass.IN, use_edns : Optional[bool] = None, - want_dnssec=False, ednsflags : Optional[int] = None, payload : Optional[int] = None, - request_payload : Optional[int] = None, options : Optional[List[edns.Option]] = None) -> Message: - ... diff --git a/dns/name.py b/dns/name.py index 8905d70f..29078eed 100644 --- a/dns/name.py +++ b/dns/name.py @@ -18,6 +18,8 @@ """DNS Names. """ +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + import copy import struct @@ -28,22 +30,47 @@ try: except ImportError: # pragma: no cover have_idna_2008 = False +import dns.enum import dns.wire import dns.exception import dns.immutable -# fullcompare() result values -#: The compared names have no relationship to each other. -NAMERELN_NONE = 0 -#: the first name is a superdomain of the second. -NAMERELN_SUPERDOMAIN = 1 -#: The first name is a subdomain of the second. -NAMERELN_SUBDOMAIN = 2 -#: The compared names are equal. -NAMERELN_EQUAL = 3 -#: The compared names have a common ancestor. -NAMERELN_COMMONANCESTOR = 4 +CompressType = Dict['Name', int] + + +class NameRelation(dns.enum.IntEnum): + """Name relation result from fullcompare().""" + + # This is an IntEnum for backwards compatibility in case anyone + # has hardwired the constants. + + #: The compared names have no relationship to each other. + NONE = 0 + #: the first name is a superdomain of the second. + SUPERDOMAIN = 1 + #: The first name is a subdomain of the second. + SUBDOMAIN = 2 + #: The compared names are equal. + EQUAL = 3 + #: The compared names have a common ancestor. + COMMONANCESTOR = 4 + + @classmethod + def _maximum(cls): + return cls.COMMONANCESTOR + + @classmethod + def _short_name(cls): + return cls.__name__ + + +# Backwards compatibility +NAMERELN_NONE = NameRelation.NONE +NAMERELN_SUPERDOMAIN = NameRelation.SUPERDOMAIN +NAMERELN_SUBDOMAIN = NameRelation.SUBDOMAIN +NAMERELN_EQUAL = NameRelation.EQUAL +NAMERELN_COMMONANCESTOR = NameRelation.COMMONANCESTOR class EmptyLabel(dns.exception.SyntaxError): @@ -95,6 +122,42 @@ class IDNAException(dns.exception.DNSException): supp_kwargs = {'idna_exception'} fmt = "IDNA processing exception: {idna_exception}" + # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +_escaped = b'"().;\\@$' +_escaped_text = '"().;\\@$' + +def _escapify(label: Union[bytes, str]) -> str: + """Escape the characters in label which need it. + @returns: the escaped string + @rtype: string""" + if isinstance(label, bytes): + # Ordinary DNS label mode. Escape special characters and values + # < 0x20 or > 0x7f. + text = '' + for c in label: + if c in _escaped: + text += '\\' + chr(c) + elif c > 0x20 and c < 0x7F: + text += chr(c) + else: + text += '\\%03d' % c + return text + + # Unicode label mode. Escape only special characters and values < 0x20 + text = '' + for uc in label: + if uc in _escaped_text: + text += '\\' + uc + elif uc <= '\x20': + text += '\\%03d' % ord(uc) + else: + text += uc + return text + class IDNACodec: """Abstract base class for IDNA encoder/decoders.""" @@ -102,20 +165,22 @@ class IDNACodec: def __init__(self): pass - def is_idna(self, label): + def is_idna(self, label: bytes) -> bool: return label.lower().startswith(b'xn--') - def encode(self, label): + def encode(self, label: str) -> bytes: raise NotImplementedError # pragma: no cover - def decode(self, label): + def decode(self, label: bytes) -> str: # We do not apply any IDNA policy on decode. if self.is_idna(label): try: - label = label[4:].decode('punycode') + slabel = label[4:].decode('punycode') + return _escapify(slabel) except Exception as e: raise IDNAException(idna_exception=e) - return _escapify(label) + else: + return _escapify(label) class IDNA2003Codec(IDNACodec): @@ -132,7 +197,7 @@ class IDNA2003Codec(IDNACodec): super().__init__() self.strict_decode = strict_decode - def encode(self, label): + def encode(self, label: str) -> bytes: """Encode *label*.""" if label == '': @@ -142,7 +207,7 @@ class IDNA2003Codec(IDNACodec): except UnicodeError: raise LabelTooLong - def decode(self, label): + def decode(self, label: bytes) -> str: """Decode *label*.""" if not self.strict_decode: return super().decode(label) @@ -188,7 +253,7 @@ class IDNA2008Codec(IDNACodec): self.allow_pure_ascii = allow_pure_ascii self.strict_decode = strict_decode - def encode(self, label): + def encode(self, label: str) -> bytes: if label == '': return b'' if self.allow_pure_ascii and is_all_ascii(label): @@ -208,7 +273,7 @@ class IDNA2008Codec(IDNACodec): else: raise IDNAException(idna_exception=e) - def decode(self, label): + def decode(self, label: bytes) -> str: if not self.strict_decode: return super().decode(label) if label == b'': @@ -223,9 +288,6 @@ class IDNA2008Codec(IDNACodec): except (idna.IDNAError, UnicodeError) as e: raise IDNAException(idna_exception=e) -_escaped = b'"().;\\@$' -_escaped_text = '"().;\\@$' - IDNA_2003_Practical = IDNA2003Codec(False) IDNA_2003_Strict = IDNA2003Codec(True) IDNA_2003 = IDNA_2003_Practical @@ -235,35 +297,7 @@ IDNA_2008_Strict = IDNA2008Codec(False, False, False, True) IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False) IDNA_2008 = IDNA_2008_Practical -def _escapify(label): - """Escape the characters in label which need it. - @returns: the escaped string - @rtype: string""" - if isinstance(label, bytes): - # Ordinary DNS label mode. Escape special characters and values - # < 0x20 or > 0x7f. - text = '' - for c in label: - if c in _escaped: - text += '\\' + chr(c) - elif c > 0x20 and c < 0x7F: - text += chr(c) - else: - text += '\\%03d' % c - return text - - # Unicode label mode. Escape only special characters and values < 0x20 - text = '' - for c in label: - if c in _escaped_text: - text += '\\' + c - elif c <= '\x20': - text += '\\%03d' % ord(c) - else: - text += c - return text - -def _validate_labels(labels): +def _validate_labels(labels: Tuple[bytes, ...]): """Check for empty labels in the middle of a label sequence, labels that are too long, and for too many labels. @@ -293,7 +327,7 @@ def _validate_labels(labels): raise EmptyLabel -def _maybe_convert_to_binary(label): +def _maybe_convert_to_binary(label: Union[bytes, str]) -> bytes: """If label is ``str``, convert it to ``bytes``. If it is already ``bytes`` just return it. @@ -318,12 +352,12 @@ class Name: __slots__ = ['labels'] - def __init__(self, labels): + def __init__(self, labels: Iterable[Union[bytes, str]]): """*labels* is any iterable whose values are ``str`` or ``bytes``. """ - labels = [_maybe_convert_to_binary(x) for x in labels] - self.labels = tuple(labels) + blabels = [_maybe_convert_to_binary(x) for x in labels] + self.labels = tuple(blabels) _validate_labels(self.labels) def __copy__(self): @@ -340,7 +374,7 @@ class Name: super().__setattr__('labels', state['labels']) _validate_labels(self.labels) - def is_absolute(self): + def is_absolute(self) -> bool: """Is the most significant label of this name the root label? Returns a ``bool``. @@ -348,7 +382,7 @@ class Name: return len(self.labels) > 0 and self.labels[-1] == b'' - def is_wild(self): + def is_wild(self) -> bool: """Is this name wild? (I.e. Is the least significant label '*'?) Returns a ``bool``. @@ -356,7 +390,7 @@ class Name: return len(self.labels) > 0 and self.labels[0] == b'*' - def __hash__(self): + def __hash__(self) -> int: """Return a case-insensitive hash of the name. Returns an ``int``. @@ -368,14 +402,14 @@ class Name: h += (h << 3) + c return h - def fullcompare(self, other): + def fullcompare(self, other: 'Name') -> Tuple[NameRelation, int, int]: """Compare two names, returning a 3-tuple ``(relation, order, nlabels)``. *relation* describes the relation ship between the names, - and is one of: ``dns.name.NAMERELN_NONE``, - ``dns.name.NAMERELN_SUPERDOMAIN``, ``dns.name.NAMERELN_SUBDOMAIN``, - ``dns.name.NAMERELN_EQUAL``, or ``dns.name.NAMERELN_COMMONANCESTOR``. + and is one of: ``dns.name.NameRelation.NONE``, + ``dns.name.NameRelation.SUPERDOMAIN``, ``dns.name.NameRelation.SUBDOMAIN``, + ``dns.name.NameRelation.EQUAL``, or ``dns.name.NameRelation.COMMONANCESTOR``. *order* is < 0 if *self* < *other*, > 0 if *self* > *other*, and == 0 if *self* == *other*. A relative name is always less than an @@ -404,9 +438,9 @@ class Name: oabs = other.is_absolute() if sabs != oabs: if sabs: - return (NAMERELN_NONE, 1, 0) + return (NameRelation.NONE, 1, 0) else: - return (NAMERELN_NONE, -1, 0) + return (NameRelation.NONE, -1, 0) l1 = len(self.labels) l2 = len(other.labels) ldiff = l1 - l2 @@ -417,7 +451,7 @@ class Name: order = 0 nlabels = 0 - namereln = NAMERELN_NONE + namereln = NameRelation.NONE while l > 0: l -= 1 l1 -= 1 @@ -427,24 +461,24 @@ class Name: if label1 < label2: order = -1 if nlabels > 0: - namereln = NAMERELN_COMMONANCESTOR + namereln = NameRelation.COMMONANCESTOR return (namereln, order, nlabels) elif label1 > label2: order = 1 if nlabels > 0: - namereln = NAMERELN_COMMONANCESTOR + namereln = NameRelation.COMMONANCESTOR return (namereln, order, nlabels) nlabels += 1 order = ldiff if ldiff < 0: - namereln = NAMERELN_SUPERDOMAIN + namereln = NameRelation.SUPERDOMAIN elif ldiff > 0: - namereln = NAMERELN_SUBDOMAIN + namereln = NameRelation.SUBDOMAIN else: - namereln = NAMERELN_EQUAL + namereln = NameRelation.EQUAL return (namereln, order, nlabels) - def is_subdomain(self, other): + def is_subdomain(self, other: 'Name') -> bool: """Is self a subdomain of other? Note that the notion of subdomain includes equality, e.g. @@ -454,11 +488,11 @@ class Name: """ (nr, _, _) = self.fullcompare(other) - if nr == NAMERELN_SUBDOMAIN or nr == NAMERELN_EQUAL: + if nr == NameRelation.SUBDOMAIN or nr == NameRelation.EQUAL: return True return False - def is_superdomain(self, other): + def is_superdomain(self, other: 'Name') -> bool: """Is self a superdomain of other? Note that the notion of superdomain includes equality, e.g. @@ -468,11 +502,11 @@ class Name: """ (nr, _, _) = self.fullcompare(other) - if nr == NAMERELN_SUPERDOMAIN or nr == NAMERELN_EQUAL: + if nr == NameRelation.SUPERDOMAIN or nr == NameRelation.EQUAL: return True return False - def canonicalize(self): + def canonicalize(self) -> 'Name': """Return a name which is equal to the current name, but is in DNSSEC canonical form. """ @@ -521,7 +555,7 @@ class Name: def __str__(self): return self.to_text(False) - def to_text(self, omit_final_dot=False): + def to_text(self, omit_final_dot=False) -> str: """Convert name to DNS text format. *omit_final_dot* is a ``bool``. If True, don't emit the final @@ -542,7 +576,7 @@ class Name: s = '.'.join(map(_escapify, l)) return s - def to_unicode(self, omit_final_dot=False, idna_codec=None): + def to_unicode(self, omit_final_dot=False, idna_codec: Optional[IDNACodec]=None) -> str: """Convert name to Unicode text format. IDN ACE labels are converted to Unicode. @@ -572,7 +606,7 @@ class Name: idna_codec = IDNA_2003_Practical return '.'.join([idna_codec.decode(x) for x in l]) - def to_digestable(self, origin=None): + def to_digestable(self, origin: Optional['Name']=None) -> bytes: """Convert name to a format suitable for digesting in hashes. The name is canonicalized and converted to uncompressed wire @@ -589,10 +623,12 @@ class Name: Returns a ``bytes``. """ - return self.to_wire(origin=origin, canonicalize=True) + digest = self.to_wire(origin=origin, canonicalize=True) + assert digest is not None + return digest - def to_wire(self, file=None, compress=None, origin=None, - canonicalize=False): + def to_wire(self, file=None, compress: Optional[CompressType]=None, + origin: Optional['Name']=None, canonicalize=False) -> Optional[bytes]: """Convert name to wire format, possibly compressing it. *file* is the file where the name is emitted (typically an @@ -638,6 +674,7 @@ class Name: out += label return bytes(out) + labels: Iterable[bytes] if not self.is_absolute(): if origin is None or not origin.is_absolute(): raise NeedAbsoluteNameOrOrigin @@ -670,8 +707,9 @@ class Name: file.write(label.lower()) else: file.write(label) + return None - def __len__(self): + def __len__(self) -> int: """The length of the name (in labels). Returns an ``int``. @@ -688,7 +726,7 @@ class Name: def __sub__(self, other): return self.relativize(other) - def split(self, depth): + def split(self, depth: int) -> Tuple['Name', 'Name']: """Split a name into a prefix and suffix names at the specified depth. *depth* is an ``int`` specifying the number of labels in the suffix @@ -709,7 +747,7 @@ class Name: 'depth must be >= 0 and <= the length of the name') return (Name(self[: -depth]), Name(self[-depth:])) - def concatenate(self, other): + def concatenate(self, other: 'Name') -> 'Name': """Return a new name which is the concatenation of self and other. Raises ``dns.name.AbsoluteConcatenation`` if the name is @@ -724,7 +762,7 @@ class Name: labels.extend(list(other.labels)) return Name(labels) - def relativize(self, origin): + def relativize(self, origin: 'Name') -> 'Name': """If the name is a subdomain of *origin*, return a new name which is the name relative to origin. Otherwise return the name. @@ -740,7 +778,7 @@ class Name: else: return self - def derelativize(self, origin): + def derelativize(self, origin: 'Name') -> 'Name': """If the name is a relative name, return a new name which is the concatenation of the name and origin. Otherwise return the name. @@ -756,7 +794,7 @@ class Name: else: return self - def choose_relativity(self, origin=None, relativize=True): + def choose_relativity(self, origin: Optional['Name']=None, relativize=True) -> 'Name': """Return a name with the relativity desired by the caller. If *origin* is ``None``, then the name is returned. @@ -775,7 +813,7 @@ class Name: else: return self - def parent(self): + def parent(self) -> 'Name': """Return the parent of the name. For example, the parent of ``www.dnspython.org.`` is ``dnspython.org``. @@ -796,7 +834,7 @@ root = Name([b'']) #: The empty name. empty = Name([]) -def from_unicode(text, origin=root, idna_codec=None): +def from_unicode(text: str, origin: Optional[Name]=root, idna_codec: Optional[IDNACodec]=None) -> Name: """Convert unicode text into a Name object. Labels are encoded in IDN ACE form according to rules specified by @@ -870,16 +908,16 @@ def from_unicode(text, origin=root, idna_codec=None): labels.extend(list(origin.labels)) return Name(labels) -def is_all_ascii(text): +def is_all_ascii(text: str) -> bool: for c in text: if ord(c) > 0x7f: return False return True -def from_text(text, origin=root, idna_codec=None): +def from_text(text: Union[bytes, str], origin: Optional[Name]=root, idna_codec: Optional[IDNACodec]=None) -> Name: """Convert text into a Name object. - *text*, a ``str``, is the text to convert into a name. + *text*, a ``bytes`` or ``str``, is the text to convert into a name. *origin*, a ``dns.name.Name``, specifies the origin to append to non-absolute names. The default is the root name. @@ -958,8 +996,9 @@ def from_text(text, origin=root, idna_codec=None): labels.extend(list(origin.labels)) return Name(labels) +# we need 'dns.wire.Parser' quoted as dns.name and dns.wire depend on each other. -def from_wire_parser(parser): +def from_wire_parser(parser: 'dns.wire.Parser') -> Name: """Convert possibly compressed wire format into a Name. *parser* is a dns.wire.Parser. @@ -992,7 +1031,7 @@ def from_wire_parser(parser): return Name(labels) -def from_wire(message, current): +def from_wire(message: bytes, current: int) -> Tuple[Name, int]: """Convert possibly compressed wire format into a Name. *message* is a ``bytes`` containing an entire DNS message in DNS diff --git a/dns/name.pyi b/dns/name.pyi deleted file mode 100644 index c48d4bd1..00000000 --- a/dns/name.pyi +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Optional, Union, Tuple, Iterable, List - -have_idna_2008: bool - -class Name: - def is_subdomain(self, o : Name) -> bool: ... - def is_superdomain(self, o : Name) -> bool: ... - def __init__(self, labels : Iterable[Union[bytes,str]]) -> None: - self.labels : List[bytes] - def is_absolute(self) -> bool: ... - def is_wild(self) -> bool: ... - def fullcompare(self, other) -> Tuple[int,int,int]: ... - def canonicalize(self) -> Name: ... - def __eq__(self, other) -> bool: ... - def __ne__(self, other) -> bool: ... - def __lt__(self, other : Name) -> bool: ... - def __le__(self, other : Name) -> bool: ... - def __ge__(self, other : Name) -> bool: ... - def __gt__(self, other : Name) -> bool: ... - def to_text(self, omit_final_dot=False) -> str: ... - def to_unicode(self, omit_final_dot=False, idna_codec=None) -> str: ... - def to_digestable(self, origin=None) -> bytes: ... - def to_wire(self, file=None, compress=None, origin=None, - canonicalize=False) -> Optional[bytes]: ... - def __add__(self, other : Name) -> Name: ... - def __sub__(self, other : Name) -> Name: ... - def split(self, depth) -> List[Tuple[str,str]]: ... - def concatenate(self, other : Name) -> Name: ... - def relativize(self, origin) -> Name: ... - def derelativize(self, origin) -> Name: ... - def choose_relativity(self, origin : Optional[Name] = None, relativize=True) -> Name: ... - def parent(self) -> Name: ... - -class IDNACodec: - pass - -def from_text(text, origin : Optional[Name] = Name('.'), idna_codec : Optional[IDNACodec] = None) -> Name: - ... - -empty : Name diff --git a/dns/node.py b/dns/node.py index 63ce008b..a4c17f96 100644 --- a/dns/node.py +++ b/dns/node.py @@ -17,12 +17,17 @@ """DNS nodes. A node is a set of rdatasets.""" +from typing import List, Optional, Union + import enum import io import dns.immutable +import dns.name +import dns.rdataclass import dns.rdataset import dns.rdatatype +import dns.rrset import dns.renderer @@ -51,7 +56,7 @@ class NodeKind(enum.Enum): CNAME = 2 @classmethod - def classify(cls, rdtype, covers): + def classify(cls, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType) -> 'NodeKind': if _matches_type_or_its_signature(_cname_types, rdtype, covers): return NodeKind.CNAME elif _matches_type_or_its_signature(_neutral_types, rdtype, covers): @@ -60,7 +65,7 @@ class NodeKind(enum.Enum): return NodeKind.REGULAR @classmethod - def classify_rdataset(cls, rdataset): + def classify_rdataset(cls, rdataset: dns.rdataset.Rdataset) -> 'NodeKind': return cls.classify(rdataset.rdtype, rdataset.covers) @@ -85,15 +90,15 @@ class Node: def __init__(self): # the set of rdatasets, represented as a list. - self.rdatasets = [] + self.rdatasets: List[dns.rdataset.Rdataset] = [] - def to_text(self, name, **kw): + def to_text(self, name: dns.name.Name, **kw) -> str: """Convert a node to text format. Each rdataset at the node is printed. Any keyword arguments to this method are passed on to the rdataset's to_text() method. - *name*, a ``dns.name.Name`` or ``str``, the owner name of the + *name*, a ``dns.name.Name``, the owner name of the rdatasets. Returns a ``str``. @@ -155,16 +160,19 @@ class Node: # edit self.rdatasets. self.rdatasets.append(rdataset) - def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def find_rdataset(self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, + create=False) -> dns.rdataset.Rdataset: """Find an rdataset matching the specified properties in the current node. - *rdclass*, an ``int``, the class of the rdataset. + *rdclass*, a ``dns.rdataclass.RdataClass``, the class of the rdataset. - *rdtype*, an ``int``, the type of the rdataset. + *rdtype*, a ``dns.rdatatype.RdataType``, the type of the rdataset. - *covers*, an ``int`` or ``None``, the covered type. + *covers*, a ``dns.rdatatype.RdataType``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -191,8 +199,11 @@ class Node: self._append_rdataset(rds) return rds - def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def get_rdataset(self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, + create=False) -> Optional[dns.rdataset.Rdataset]: """Get an rdataset matching the specified properties in the current node. @@ -223,7 +234,10 @@ class Node: rds = None return rds - def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + def delete_rdataset(self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE): """Delete the rdataset matching the specified properties in the current node. @@ -240,7 +254,7 @@ class Node: if rds is not None: self.rdatasets.remove(rds) - def replace_rdataset(self, replacement): + def replace_rdataset(self, replacement: dns.rdataset.Rdataset): """Replace an rdataset. It is not an error if there is no rdataset matching *replacement*. @@ -265,7 +279,7 @@ class Node: replacement.covers) self._append_rdataset(replacement) - def classify(self): + def classify(self) -> NodeKind: """Classify a node. A node which contains a CNAME or RRSIG(CNAME) is a @@ -286,7 +300,7 @@ class Node: return kind return NodeKind.NEUTRAL - def is_immutable(self): + def is_immutable(self) -> bool: return False @@ -316,5 +330,5 @@ class ImmutableNode(Node): def replace_rdataset(self, replacement): raise TypeError("immutable") - def is_immutable(self): + def is_immutable(self) -> bool: return True diff --git a/dns/node.pyi b/dns/node.pyi deleted file mode 100644 index 0997edf9..00000000 --- a/dns/node.pyi +++ /dev/null @@ -1,17 +0,0 @@ -from typing import List, Optional, Union -from . import rdataset, rdatatype, name -class Node: - def __init__(self): - self.rdatasets : List[rdataset.Rdataset] - def to_text(self, name : Union[str,name.Name], **kw) -> str: - ... - def find_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE, - create=False) -> rdataset.Rdataset: - ... - def get_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE, - create=False) -> Optional[rdataset.Rdataset]: - ... - def delete_rdataset(self, rdclass : int, rdtype : int, covers=rdatatype.NONE): - ... - def replace_rdataset(self, replacement : rdataset.Rdataset) -> None: - ... diff --git a/dns/opcode.py b/dns/opcode.py index 5cf6143c..971b62c8 100644 --- a/dns/opcode.py +++ b/dns/opcode.py @@ -45,7 +45,7 @@ class UnknownOpcode(dns.exception.DNSException): """An DNS opcode is unknown.""" -def from_text(text): +def from_text(text: str) -> Opcode: """Convert text into an opcode. *text*, a ``str``, the textual opcode @@ -58,7 +58,7 @@ def from_text(text): return Opcode.from_text(text) -def from_flags(flags): +def from_flags(flags: int) -> Opcode: """Extract an opcode from DNS message flags. *flags*, an ``int``, the DNS flags. @@ -66,10 +66,10 @@ def from_flags(flags): Returns an ``int``. """ - return (flags & 0x7800) >> 11 + return Opcode((flags & 0x7800) >> 11) -def to_flags(value): +def to_flags(value: Opcode) -> int: """Convert an opcode to a value suitable for ORing into DNS message flags. @@ -81,7 +81,7 @@ def to_flags(value): return (value << 11) & 0x7800 -def to_text(value): +def to_text(value: Opcode) -> str: """Convert an opcode to text. *value*, an ``int`` the opcode value, @@ -94,7 +94,7 @@ def to_text(value): return Opcode.to_text(value) -def is_update(flags): +def is_update(flags: int) -> bool: """Is the opcode in flags UPDATE? *flags*, an ``int``, the DNS message flags. diff --git a/dns/query.py b/dns/query.py index 19894df6..e2dca20d 100644 --- a/dns/query.py +++ b/dns/query.py @@ -17,6 +17,8 @@ """Talk to a DNS server.""" +from typing import Any, Dict, Optional, Tuple, Union + import base64 import contextlib import enum @@ -37,6 +39,8 @@ import dns.rcode import dns.rdataclass import dns.rdatatype import dns.serial +import dns.transaction +import dns.tsig import dns.xfr try: @@ -74,6 +78,9 @@ except ImportError: # pragma: no cover class WantWriteException(Exception): pass + class SSLContext: + pass + class SSLSocket: pass @@ -149,9 +156,12 @@ if hasattr(selectors, 'PollSelector'): # Prefer poll() on platforms that support it because it has no # limits on the maximum value of a file descriptor (plus it will # be more efficient for high values). - _selector_class = selectors.PollSelector + # + # We ignore typing here as we can't say _selector_class is Any + # on python < 3.8 due to a bug. + _selector_class = selectors.PollSelector # type: ignore else: - _selector_class = selectors.SelectSelector # pragma: no cover + _selector_class = selectors.SelectSelector # type: ignore def _wait_for_readable(s, expiration): @@ -248,10 +258,11 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None): s.close() raise -def https(q, where, timeout=None, port=443, source=None, source_port=0, +def https(q: dns.message.Message, where: str, timeout: Optional[float]=None, + port=443, source: Optional[str]=None, source_port=0, one_rr_per_rrset=False, ignore_trailing=False, - session=None, path='/dns-query', post=True, - bootstrap_address=None, verify=True): + session: Optional[Any]=None, path='/dns-query', post=True, + bootstrap_address: Optional[str]=None, verify=True) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. *q*, a ``dns.message.Message``, the query to send. @@ -314,6 +325,8 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, elif bootstrap_address is not None: _httpx_ok = False split_url = urllib.parse.urlsplit(where) + if split_url.hostname is None: + raise ValueError('DoH URL has no hostname') headers['Host'] = split_url.hostname url = where.replace(split_url.hostname, bootstrap_address) if _have_requests: @@ -374,10 +387,10 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, else: wire = base64.urlsafe_b64encode(wire).rstrip(b"=") if _is_httpx: - wire = wire.decode() # httpx does a repr() if we give it bytes + twire = wire.decode() # httpx does a repr() if we give it bytes response = session.get(url, headers=headers, timeout=timeout, - params={"dns": wire}) + params={"dns": twire}) else: response = session.get(url, headers=headers, timeout=timeout, verify=verify, @@ -395,7 +408,7 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, request_mac=q.request_mac, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing) - r.time = response.elapsed + r.time = response.elapsed.total_seconds() if not q.is_response(r): raise BadResponse return r @@ -427,7 +440,8 @@ def _udp_send(sock, data, destination, expiration): _wait_for_writable(sock, expiration) -def send_udp(sock, what, destination, expiration=None): +def send_udp(sock: Any, what: Union[dns.message.Message, bytes], destination: Any, + expiration: Optional[float]=None) -> Tuple[int, float]: """Send a DNS message to the specified UDP socket. *sock*, a ``socket``. @@ -451,10 +465,10 @@ def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -def receive_udp(sock, destination=None, expiration=None, +def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional[float]=None, ignore_unexpected=False, one_rr_per_rrset=False, - keyring=None, request_mac=b'', ignore_trailing=False, - raise_on_truncation=False): + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'', + ignore_trailing=False, raise_on_truncation=False) -> Any: """Read a DNS message from a UDP socket. *sock*, a ``socket``. @@ -512,9 +526,10 @@ def receive_udp(sock, destination=None, expiration=None, else: return (r, received_time, from_address) -def udp(q, where, timeout=None, port=53, source=None, source_port=0, +def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, + source: Optional[str]=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, - raise_on_truncation=False, sock=None): + raise_on_truncation=False, sock: Optional[Any]=None) -> dns.message.Message: """Return the response obtained after sending a query via UDP. *q*, a ``dns.message.Message``, the query to send @@ -571,11 +586,13 @@ def udp(q, where, timeout=None, port=53, source=None, source_port=0, if not q.is_response(r): raise BadResponse return r + assert False # help mypy figure out we can't get here -def udp_with_fallback(q, where, timeout=None, port=53, source=None, - source_port=0, ignore_unexpected=False, - one_rr_per_rrset=False, ignore_trailing=False, - udp_sock=None, tcp_sock=None): +def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, + source: Optional[str]=None, source_port=0, + ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, + udp_sock: Optional[Any]=None, + tcp_sock: Optional[Any]=None) -> Tuple[dns.message.Message, bool]: """Return the response to the query, trying UDP first and falling back to TCP if UDP results in a truncated response. @@ -665,7 +682,8 @@ def _net_write(sock, data, expiration): _wait_for_readable(sock, expiration) -def send_tcp(sock, what, expiration=None): +def send_tcp(sock: Any, what: Union[dns.message.Message, bytes], + expiration: Optional[float]=None) -> Tuple[int, float]: """Send a DNS message to the specified TCP socket. *sock*, a ``socket``. @@ -680,18 +698,21 @@ def send_tcp(sock, what, expiration=None): """ if isinstance(what, dns.message.Message): - what = what.to_wire() - l = len(what) + wire = what.to_wire() + else: + wire = what + l = len(wire) # copying the wire into tcpmsg is inefficient, but lets us # avoid writev() or doing a short write that would get pushed # onto the net - tcpmsg = struct.pack("!H", l) + what + tcpmsg = struct.pack("!H", l) + wire sent_time = time.time() _net_write(sock, tcpmsg, expiration) return (len(tcpmsg), sent_time) -def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, - keyring=None, request_mac=b'', ignore_trailing=False): +def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset=False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'', + ignore_trailing=False) -> Tuple[dns.message.Message, float]: """Read a DNS message from a TCP socket. *sock*, a ``socket``. @@ -737,8 +758,9 @@ def _connect(s, address, expiration): raise OSError(err, os.strerror(err)) -def tcp(q, where, timeout=None, port=53, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, sock=None): +def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, + source: Optional[str]=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, sock: Optional[Any]=None) -> dns.message.Message: """Return the response obtained after sending a query via TCP. *q*, a ``dns.message.Message``, the query to send @@ -790,6 +812,7 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0, if not q.is_response(r): raise BadResponse return r + assert False # help mypy figure out we can't get here def _tls_handshake(s, expiration): @@ -803,9 +826,11 @@ def _tls_handshake(s, expiration): _wait_for_writable(s, expiration) -def tls(q, where, timeout=None, port=853, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, sock=None, - ssl_context=None, server_hostname=None): +def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, + port=853, source: Optional[str]=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, sock: Optional[ssl.SSLSocket]=None, + ssl_context: Optional[ssl.SSLContext]=None, + server_hostname: Optional[str]=None) -> dns.message.Message: """Return the response obtained after sending a query via TLS. *q*, a ``dns.message.Message``, the query to send @@ -885,7 +910,7 @@ def tls(q, where, timeout=None, port=853, source=None, source_port=0, if not q.is_response(r): raise BadResponse return r - + assert False # help mypy figure out we can't get here def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, timeout=None, port=53, keyring=None, keyname=None, relativize=True, @@ -1066,9 +1091,10 @@ class UDPMode(enum.IntEnum): ONLY = 2 -def inbound_xfr(where, txn_manager, query=None, - port=53, timeout=None, lifetime=None, source=None, - source_port=0, udp_mode=UDPMode.NEVER): +def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager, + query: Optional[dns.message.Message]=None, + port=53, timeout: Optional[float]=None, lifetime: Optional[float]=None, + source: Optional[str]=None, source_port=0, udp_mode=UDPMode.NEVER): """Conduct an inbound transfer and apply it via a transaction from the txn_manager. diff --git a/dns/query.pyi b/dns/query.pyi deleted file mode 100644 index a22e229f..00000000 --- a/dns/query.pyi +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Optional, Union, Dict, Generator, Any -from . import tsig, rdatatype, rdataclass, name, message -from requests.sessions import Session - -import socket - -# If the ssl import works, then -# -# error: Name 'ssl' already defined (by an import) -# -# is expected and can be ignored. -try: - import ssl -except ImportError: - class ssl: # type: ignore - SSLContext : Dict = {} - -have_doh: bool - -def https(q : message.Message, where: str, timeout : Optional[float] = None, - port : Optional[int] = 443, source : Optional[str] = None, - source_port : Optional[int] = 0, - session: Optional[Session] = None, - path : Optional[str] = '/dns-query', post : Optional[bool] = True, - bootstrap_address : Optional[str] = None, - verify : Optional[bool] = True) -> message.Message: - pass - -def tcp(q : message.Message, where : str, timeout : float = None, port=53, - af : Optional[int] = None, source : Optional[str] = None, - source_port : Optional[int] = 0, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[socket.socket] = None) -> message.Message: - pass - -def xfr(where : None, zone : Union[name.Name,str], rdtype=rdatatype.AXFR, - rdclass=rdataclass.IN, - timeout : Optional[float] = None, port=53, - keyring : Optional[Dict[name.Name, bytes]] = None, - keyname : Union[str,name.Name]= None, relativize=True, - lifetime : Optional[float] = None, - source : Optional[str] = None, source_port=0, serial=0, - use_udp : Optional[bool] = False, - keyalgorithm=tsig.default_algorithm) \ - -> Generator[Any,Any,message.Message]: - pass - -def udp(q : message.Message, where : str, timeout : Optional[float] = None, - port=53, source : Optional[str] = None, source_port : Optional[int] = 0, - ignore_unexpected : Optional[bool] = False, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[socket.socket] = None) -> message.Message: - pass - -def tls(q : message.Message, where : str, timeout : Optional[float] = None, - port=53, source : Optional[str] = None, source_port : Optional[int] = 0, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[socket.socket] = None, - ssl_context: Optional[ssl.SSLContext] = None, - server_hostname: Optional[str] = None) -> message.Message: - pass diff --git a/dns/rdata.py b/dns/rdata.py index 6b5b5c5a..1e1992be 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -17,6 +17,8 @@ """DNS rdata.""" +from typing import Any, Dict, Optional, Tuple, Union + from importlib import import_module import base64 import binascii @@ -137,7 +139,7 @@ class Rdata: self.rdclass = self._as_rdataclass(rdclass) self.rdtype = self._as_rdatatype(rdtype) - self.rdcomment = None + self.rdcomment: Optional[str] = None def _get_all_slots(self): return itertools.chain.from_iterable(getattr(cls, '__slots__', []) @@ -165,7 +167,7 @@ class Rdata: # it if needed. object.__setattr__(self, 'rdcomment', None) - def covers(self): + def covers(self) -> dns.rdatatype.RdataType: """Return the type a Rdata covers. DNS SIG/RRSIG rdatas apply to a specific type; this type is @@ -174,12 +176,12 @@ class Rdata: creating rdatasets, allowing the rdataset to contain only RRSIGs of a particular type, e.g. RRSIG(NS). - Returns an ``int``. + Returns a ``dns.rdatatype.RdataType``. """ return dns.rdatatype.NONE - def extended_rdatatype(self): + def extended_rdatatype(self) -> int: """Return a 32-bit type value, the least significant 16 bits of which are the ordinary DNS type, and the upper 16 bits of which are the "covered" type, if any. @@ -189,7 +191,7 @@ class Rdata: return self.covers() << 16 | self.rdtype - def to_text(self, origin=None, relativize=True, **kw): + def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw): """Convert an rdata to text format. Returns a ``str``. @@ -197,11 +199,12 @@ class Rdata: raise NotImplementedError # pragma: no cover - def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + def _to_wire(self, file, compress: Optional[dns.name.CompressType]=None, + origin: Optional[dns.name.Name]=None, canonicalize=False): raise NotImplementedError # pragma: no cover def to_wire(self, file=None, compress=None, origin=None, - canonicalize=False): + canonicalize=False) -> bytes: """Convert an rdata to wire format. Returns a ``bytes`` or ``None``. @@ -214,7 +217,7 @@ class Rdata: self._to_wire(f, compress, origin, canonicalize) return f.getvalue() - def to_generic(self, origin=None): + def to_generic(self, origin: Optional[dns.name.Name]=None) -> 'dns.rdata.GenericRdata': """Creates a dns.rdata.GenericRdata equivalent of this rdata. Returns a ``dns.rdata.GenericRdata``. @@ -222,7 +225,7 @@ class Rdata: return dns.rdata.GenericRdata(self.rdclass, self.rdtype, self.to_wire(origin=origin)) - def to_digestable(self, origin=None): + def to_digestable(self, origin: Optional[dns.name.Name]=None) -> bytes: """Convert rdata to a format suitable for digesting in hashes. This is also the DNSSEC canonical form. @@ -348,12 +351,16 @@ class Rdata: return hash(self.to_digestable(dns.name.root)) @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text(cls, rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None, relativize=True, + relativize_to: Optional[dns.name.Name]=None): raise NotImplementedError # pragma: no cover @classmethod - def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): + def from_wire_parser(cls, rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None): raise NotImplementedError # pragma: no cover def replace(self, **kwargs): @@ -408,18 +415,20 @@ class Rdata: return dns.rdatatype.RdataType.make(value) @classmethod - def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True): + def _as_bytes(cls, value, encode=False, max_length=None, empty_ok=True) -> bytes: if encode and isinstance(value, str): - value = value.encode() + bvalue = value.encode() elif isinstance(value, bytearray): - value = bytes(value) - elif not isinstance(value, bytes): + bvalue = bytes(value) + elif isinstance(value, bytes): + bvalue = value + else: raise ValueError('not bytes') - if max_length is not None and len(value) > max_length: + if max_length is not None and len(bvalue) > max_length: raise ValueError('too long') - if not empty_ok and len(value) == 0: + if not empty_ok and len(bvalue) == 0: raise ValueError('empty bytes not allowed') - return value + return bvalue @classmethod def _as_name(cls, value): @@ -571,7 +580,7 @@ class GenericRdata(Rdata): def from_wire_parser(cls, rdclass, rdtype, parser, origin=None): return cls(rdclass, rdtype, parser.get_remaining()) -_rdata_classes = {} +_rdata_classes: Dict[Tuple[dns.rdataclass.RdataClass, dns.rdatatype.RdataType], Any] = {} _module_prefix = 'dns.rdtypes' def get_rdata_class(rdclass, rdtype): @@ -602,8 +611,12 @@ def get_rdata_class(rdclass, rdtype): return cls -def from_text(rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None, idna_codec=None): +def from_text(rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + tok: Union[dns.tokenizer.Tokenizer, str], + origin: Optional[dns.name.Name]=None, + relativize=True, relativize_to: Optional[dns.name.Name]=None, + idna_codec: Optional[dns.name.IDNACodec]=None) -> Rdata: """Build an rdata object from text format. This function attempts to dynamically load a class which @@ -617,9 +630,9 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True, If *tok* is a ``str``, then a tokenizer is created and the string is used as its input. - *rdclass*, an ``int``, the rdataclass. + *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass. - *rdtype*, an ``int``, the rdatatype. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype. *tok*, a ``dns.tokenizer.Tokenizer`` or a ``str``. @@ -681,7 +694,9 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True, return rdata -def from_wire_parser(rdclass, rdtype, parser, origin=None): +def from_wire_parser(rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None) -> Rdata: """Build an rdata object from wire format This function attempts to dynamically load a class which @@ -692,9 +707,9 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None): Once a class is chosen, its from_wire() class method is called with the parameters to this function. - *rdclass*, an ``int``, the rdataclass. + *rdclass*, a ``dns.rdataclass.RdataClass`` or ``str``, the rdataclass. - *rdtype*, an ``int``, the rdatatype. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdatatype. *parser*, a ``dns.wire.Parser``, the parser, which should be restricted to the rdata length. @@ -712,7 +727,10 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None): return cls.from_wire_parser(rdclass, rdtype, parser, origin) -def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): +def from_wire(rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + wire: bytes, current: int, rdlen: int, + origin: Optional[dns.name.Name]=None) -> Rdata: """Build an rdata object from wire format This function attempts to dynamically load a class which diff --git a/dns/rdata.pyi b/dns/rdata.pyi deleted file mode 100644 index f394791f..00000000 --- a/dns/rdata.pyi +++ /dev/null @@ -1,19 +0,0 @@ -from typing import Dict, Tuple, Any, Optional, BinaryIO -from .name import Name, IDNACodec -class Rdata: - def __init__(self): - self.address : str - def to_wire(self, file : Optional[BinaryIO], compress : Optional[Dict[Name,int]], origin : Optional[Name], canonicalize : Optional[bool]) -> Optional[bytes]: - ... - @classmethod - def from_text(cls, rdclass : int, rdtype : int, tok, origin=None, relativize=True): - ... -_rdata_modules : Dict[Tuple[Any,Rdata],Any] - -def from_text(rdclass : int, rdtype : int, tok : Optional[str], origin : Optional[Name] = None, - relativize : bool = True, relativize_to : Optional[Name] = None, - idna_codec : Optional[IDNACodec] = None): - ... - -def from_wire(rdclass : int, rdtype : int, wire : bytes, current : int, rdlen : int, origin : Optional[Name] = None): - ... diff --git a/dns/rdataset.py b/dns/rdataset.py index e6e95480..218adba3 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -17,16 +17,20 @@ """DNS rdatasets (an rdataset is a set of rdatas of a given type and class)""" +from typing import Any, cast, Collection, Dict, List, Optional, Union + import io import random import struct import dns.exception import dns.immutable +import dns.name import dns.rdatatype import dns.rdataclass import dns.rdata import dns.set +import dns.ttl # define SimpleSet here for backwards compatibility SimpleSet = dns.set.Set @@ -47,22 +51,24 @@ class Rdataset(dns.set.Set): __slots__ = ['rdclass', 'rdtype', 'covers', 'ttl'] - def __init__(self, rdclass, rdtype, covers=dns.rdatatype.NONE, ttl=0): + def __init__(self, rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers=dns.rdatatype.NONE, ttl=0): """Create a new rdataset of the specified class and type. - *rdclass*, an ``int``, the rdataclass. + *rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass. - *rdtype*, an ``int``, the rdatatype. + *rdtype*, an ``dns.rdatatype.RdataType``, the rdatatype. - *covers*, an ``int``, the covered rdatatype. + *covers*, an ``dns.rdatatype.RdataType``, the covered rdatatype. *ttl*, an ``int``, the TTL. """ super().__init__() self.rdclass = rdclass - self.rdtype = rdtype - self.covers = covers + self.rdtype: dns.rdatatype.RdataType = rdtype + self.covers: dns.rdatatype.RdataType = covers self.ttl = ttl def _clone(self): @@ -73,7 +79,7 @@ class Rdataset(dns.set.Set): obj.ttl = self.ttl return obj - def update_ttl(self, ttl): + def update_ttl(self, ttl: int): """Perform TTL minimization. Set the TTL of the rdataset to be the lesser of the set's current @@ -88,7 +94,7 @@ class Rdataset(dns.set.Set): elif ttl < self.ttl: self.ttl = ttl - def add(self, rd, ttl=None): # pylint: disable=arguments-differ + def add(self, rd, ttl: Optional[int]=None): # pylint: disable=arguments-differ """Add the specified rdata to the rdataset. If the optional *ttl* parameter is supplied, then @@ -176,8 +182,11 @@ class Rdataset(dns.set.Set): def __ne__(self, other): return not self.__eq__(other) - def to_text(self, name=None, origin=None, relativize=True, - override_rdclass=None, want_comments=False, **kw): + def to_text(self, name: Optional[dns.name.Name]=None, + origin: Optional[dns.name.Name]=None, + relativize=True, + override_rdclass: Optional[dns.rdataclass.RdataClass]=None, + want_comments=False, **kw) -> str: """Convert the rdataset into DNS zone file format. See ``dns.name.Name.choose_relativity`` for more information @@ -241,8 +250,11 @@ class Rdataset(dns.set.Set): # return s.getvalue()[:-1] - def to_wire(self, name, file, compress=None, origin=None, - override_rdclass=None, want_shuffle=True): + def to_wire(self, name: dns.name.Name, file: Any, + compress: Optional[dns.name.CompressType]=None, + origin: Optional[dns.name.Name]=None, + override_rdclass: Optional[dns.rdataclass.RdataClass]=None, + want_shuffle=True) -> int: """Convert the rdataset to wire format. *name*, a ``dns.name.Name`` is the owner name to use. @@ -279,6 +291,7 @@ class Rdataset(dns.set.Set): file.write(stuff) return 1 else: + l: Union[Rdataset, List[dns.rdata.Rdata]] if want_shuffle: l = list(self) random.shuffle(l) @@ -299,7 +312,9 @@ class Rdataset(dns.set.Set): file.seek(0, io.SEEK_END) return len(self) - def match(self, rdclass, rdtype, covers): + def match(self, rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType) -> bool: """Returns ``True`` if this rdataset matches the specified class, type, and covers. """ @@ -309,7 +324,7 @@ class Rdataset(dns.set.Set): return True return False - def processing_order(self): + def processing_order(self) -> List[dns.rdata.Rdata]: """Return rdatas in a valid processing order according to the type's specification. For example, MX records are in preference order from lowest to highest preferences, with items of the same preference @@ -331,7 +346,7 @@ class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals] _clone_class = Rdataset - def __init__(self, rdataset): + def __init__(self, rdataset: Rdataset): """Create an immutable rdataset from the specified rdataset.""" super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers, @@ -394,8 +409,12 @@ class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals] return ImmutableRdataset(super().symmetric_difference(other)) -def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None, - origin=None, relativize=True, relativize_to=None): +def from_text_list(rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + ttl: int, text_rdatas: Collection[str], + idna_codec: Optional[dns.name.IDNACodec]=None, + origin: Optional[dns.name.Name]=None, + relativize=True, relativize_to: Optional[dns.name.Name]=None) -> Rdataset: """Create an rdataset with the specified class, type, and TTL, and with the specified list of rdatas in text format. @@ -414,9 +433,9 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None, Returns a ``dns.rdataset.Rdataset`` object. """ - rdclass = dns.rdataclass.RdataClass.make(rdclass) - rdtype = dns.rdatatype.RdataType.make(rdtype) - r = Rdataset(rdclass, rdtype) + the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + r = Rdataset(the_rdclass, the_rdtype) r.update_ttl(ttl) for t in text_rdatas: rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize, @@ -425,17 +444,19 @@ def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None, return r -def from_text(rdclass, rdtype, ttl, *text_rdatas): +def from_text(rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + ttl: int, *text_rdatas) -> Rdataset: """Create an rdataset with the specified class, type, and TTL, and with the specified rdatas in text format. Returns a ``dns.rdataset.Rdataset`` object. """ - return from_text_list(rdclass, rdtype, ttl, text_rdatas) + return from_text_list(rdclass, rdtype, ttl, cast(Collection[str], text_rdatas)) -def from_rdata_list(ttl, rdatas): +def from_rdata_list(ttl: int, rdatas: Collection[dns.rdata.Rdata]) -> Rdataset: """Create an rdataset with the specified TTL, and with the specified list of rdata objects. @@ -450,14 +471,15 @@ def from_rdata_list(ttl, rdatas): r = Rdataset(rd.rdclass, rd.rdtype) r.update_ttl(ttl) r.add(rd) + assert r is not None return r -def from_rdata(ttl, *rdatas): +def from_rdata(ttl: int, *rdatas) -> Rdataset: """Create an rdataset with the specified TTL, and with the specified rdata objects. Returns a ``dns.rdataset.Rdataset`` object. """ - return from_rdata_list(ttl, rdatas) + return from_rdata_list(ttl, cast(Collection[dns.rdata.Rdata], rdatas)) diff --git a/dns/rdataset.pyi b/dns/rdataset.pyi deleted file mode 100644 index a7bbf2d4..00000000 --- a/dns/rdataset.pyi +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Optional, Dict, List, Union -from io import BytesIO -from . import exception, name, set, rdatatype, rdata, rdataset - -class DifferingCovers(exception.DNSException): - """An attempt was made to add a DNS SIG/RRSIG whose covered type - is not the same as that of the other rdatas in the rdataset.""" - - -class IncompatibleTypes(exception.DNSException): - """An attempt was made to add DNS RR data of an incompatible type.""" - - -class Rdataset(set.Set): - def __init__(self, rdclass, rdtype, covers=rdatatype.NONE, ttl=0): - self.rdclass : int = rdclass - self.rdtype : int = rdtype - self.covers : int = covers - self.ttl : int = ttl - - def update_ttl(self, ttl : int) -> None: - ... - - def add(self, rd : rdata.Rdata, ttl : Optional[int] =None): - ... - - def union_update(self, other : Rdataset): - ... - - def intersection_update(self, other : Rdataset): - ... - - def update(self, other : Rdataset): - ... - - def to_text(self, name : Optional[name.Name] =None, origin : Optional[name.Name] =None, relativize=True, - override_rdclass : Optional[int] =None, **kw) -> bytes: - ... - - def to_wire(self, name : Optional[name.Name], file : BytesIO, compress : Optional[Dict[name.Name, int]] = None, origin : Optional[name.Name] = None, - override_rdclass : Optional[int] = None, want_shuffle=True) -> int: - ... - - def match(self, rdclass : int, rdtype : int, covers : int) -> bool: - ... - - -def from_text_list(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, text_rdatas : str, idna_codec : Optional[name.IDNACodec] = None) -> rdataset.Rdataset: - ... - -def from_text(rdclass : Union[int,str], rdtype : Union[int,str], ttl : int, *text_rdatas : str) -> rdataset.Rdataset: - ... - -def from_rdata_list(ttl : int, rdatas : List[rdata.Rdata]) -> rdataset.Rdataset: - ... - -def from_rdata(ttl : int, *rdatas : List[rdata.Rdata]) -> rdataset.Rdataset: - ... diff --git a/dns/rdatatype.py b/dns/rdatatype.py index 9499c7b9..80f8acaf 100644 --- a/dns/rdatatype.py +++ b/dns/rdatatype.py @@ -17,6 +17,8 @@ """DNS Rdata Types.""" +from typing import Dict + import dns.enum import dns.exception @@ -120,8 +122,8 @@ class RdataType(dns.enum.IntEnum): def _unknown_exception_class(cls): return UnknownRdatatype -_registered_by_text = {} -_registered_by_value = {} +_registered_by_text: Dict[str, RdataType] = {} +_registered_by_value: Dict[RdataType, str] = {} _metatypes = {RdataType.OPT} diff --git a/dns/rdtypes/ANY/CERT.py b/dns/rdtypes/ANY/CERT.py index f35ce3ad..f8990ebe 100644 --- a/dns/rdtypes/ANY/CERT.py +++ b/dns/rdtypes/ANY/CERT.py @@ -20,7 +20,7 @@ import base64 import dns.exception import dns.immutable -import dns.dnssec +import dns.dnssectypes import dns.rdata import dns.tokenizer @@ -85,7 +85,7 @@ class CERT(dns.rdata.Rdata): def to_text(self, origin=None, relativize=True, **kw): certificate_type = _ctype_to_text(self.certificate_type) return "%s %d %s %s" % (certificate_type, self.key_tag, - dns.dnssec.algorithm_to_text(self.algorithm), + dns.dnssectypes.Algorithm.to_text(self.algorithm), dns.rdata._base64ify(self.certificate, **kw)) @classmethod @@ -93,7 +93,7 @@ class CERT(dns.rdata.Rdata): relativize_to=None): certificate_type = _ctype_from_text(tok.get_string()) key_tag = tok.get_uint16() - algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) + algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string()) b64 = tok.concatenate_remaining_identifiers().encode() certificate = base64.b64decode(b64) return cls(rdclass, rdtype, certificate_type, key_tag, diff --git a/dns/rdtypes/ANY/RRSIG.py b/dns/rdtypes/ANY/RRSIG.py index d050ccc6..82650c0f 100644 --- a/dns/rdtypes/ANY/RRSIG.py +++ b/dns/rdtypes/ANY/RRSIG.py @@ -20,7 +20,7 @@ import calendar import struct import time -import dns.dnssec +import dns.dnssectypes import dns.immutable import dns.exception import dns.rdata @@ -65,7 +65,7 @@ class RRSIG(dns.rdata.Rdata): signature): super().__init__(rdclass, rdtype) self.type_covered = self._as_rdatatype(type_covered) - self.algorithm = dns.dnssec.Algorithm.make(algorithm) + self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) self.labels = self._as_uint8(labels) self.original_ttl = self._as_ttl(original_ttl) self.expiration = self._as_uint32(expiration) @@ -94,7 +94,7 @@ class RRSIG(dns.rdata.Rdata): def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, relativize_to=None): type_covered = dns.rdatatype.from_text(tok.get_string()) - algorithm = dns.dnssec.algorithm_from_text(tok.get_string()) + algorithm = dns.dnssectypes.Algorithm.from_text(tok.get_string()) labels = tok.get_int() original_ttl = tok.get_ttl() expiration = sigtime_to_posixtime(tok.get_string()) diff --git a/dns/rdtypes/ANY/TKEY.py b/dns/rdtypes/ANY/TKEY.py index 861fc4e3..59ffe039 100644 --- a/dns/rdtypes/ANY/TKEY.py +++ b/dns/rdtypes/ANY/TKEY.py @@ -18,7 +18,6 @@ import base64 import struct -import dns.dnssec import dns.immutable import dns.exception import dns.rdata diff --git a/dns/rdtypes/ANY/ZONEMD.py b/dns/rdtypes/ANY/ZONEMD.py index 035f7b32..75f99e5e 100644 --- a/dns/rdtypes/ANY/ZONEMD.py +++ b/dns/rdtypes/ANY/ZONEMD.py @@ -6,7 +6,7 @@ import binascii import dns.immutable import dns.rdata import dns.rdatatype -import dns.zone +import dns.zonetypes @dns.immutable.immutable @@ -21,8 +21,8 @@ class ZONEMD(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, serial, scheme, hash_algorithm, digest): super().__init__(rdclass, rdtype) self.serial = self._as_uint32(serial) - self.scheme = dns.zone.DigestScheme.make(scheme) - self.hash_algorithm = dns.zone.DigestHashAlgorithm.make(hash_algorithm) + self.scheme = dns.zonetypes.DigestScheme.make(scheme) + self.hash_algorithm = dns.zonetypes.DigestHashAlgorithm.make(hash_algorithm) self.digest = self._as_bytes(digest) if self.scheme == 0: # reserved, RFC 8976 Sec. 5.2 @@ -30,7 +30,7 @@ class ZONEMD(dns.rdata.Rdata): if self.hash_algorithm == 0: # reserved, RFC 8976 Sec. 5.3 raise ValueError('hash_algorithm 0 is reserved') - hasher = dns.zone._digest_hashers.get(self.hash_algorithm) + hasher = dns.zonetypes._digest_hashers.get(self.hash_algorithm) if hasher and hasher().digest_size != len(self.digest): raise ValueError('digest length inconsistent with hash algorithm') diff --git a/dns/rdtypes/dnskeybase.py b/dns/rdtypes/dnskeybase.py index 788bb2bf..832df2d7 100644 --- a/dns/rdtypes/dnskeybase.py +++ b/dns/rdtypes/dnskeybase.py @@ -21,7 +21,7 @@ import struct import dns.exception import dns.immutable -import dns.dnssec +import dns.dnssectypes import dns.rdata # wildcard import @@ -44,7 +44,7 @@ class DNSKEYBase(dns.rdata.Rdata): super().__init__(rdclass, rdtype) self.flags = self._as_uint16(flags) self.protocol = self._as_uint8(protocol) - self.algorithm = dns.dnssec.Algorithm.make(algorithm) + self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) self.key = self._as_bytes(key) def to_text(self, origin=None, relativize=True, **kw): diff --git a/dns/rdtypes/dnskeybase.pyi b/dns/rdtypes/dnskeybase.pyi deleted file mode 100644 index 1b999cfd..00000000 --- a/dns/rdtypes/dnskeybase.pyi +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Set, Any - -SEP : int -REVOKE : int -ZONE : int - -def flags_to_text_set(flags : int) -> Set[str]: - ... - -def flags_from_text_set(texts_set) -> int: - ... - -from .. import rdata - -class DNSKEYBase(rdata.Rdata): - def __init__(self, rdclass, rdtype, flags, protocol, algorithm, key): - self.flags : int - self.protocol : int - self.key : str - self.algorithm : int - - def to_text(self, origin : Any = None, relativize=True, **kw : Any): - ... - - @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): - ... - - def _to_wire(self, file, compress=None, origin=None, canonicalize=False): - ... - - @classmethod - def from_parser(cls, rdclass, rdtype, parser, origin=None): - ... - - def flags_to_text_set(self) -> Set[str]: - ... diff --git a/dns/rdtypes/dsbase.py b/dns/rdtypes/dsbase.py index 0c2e7471..3bf93acc 100644 --- a/dns/rdtypes/dsbase.py +++ b/dns/rdtypes/dsbase.py @@ -18,7 +18,7 @@ import struct import binascii -import dns.dnssec +import dns.dnssectypes import dns.immutable import dns.rdata import dns.rdatatype @@ -43,7 +43,7 @@ class DSBase(dns.rdata.Rdata): digest): super().__init__(rdclass, rdtype) self.key_tag = self._as_uint16(key_tag) - self.algorithm = dns.dnssec.Algorithm.make(algorithm) + self.algorithm = dns.dnssectypes.Algorithm.make(algorithm) self.digest_type = self._as_uint8(digest_type) self.digest = self._as_bytes(digest) try: diff --git a/dns/rdtypes/txtbase.py b/dns/rdtypes/txtbase.py index 68071ee0..7ad7914f 100644 --- a/dns/rdtypes/txtbase.py +++ b/dns/rdtypes/txtbase.py @@ -17,6 +17,8 @@ """TXT-like base class.""" +from typing import Iterable, Optional, Tuple, Union + import struct import dns.exception @@ -32,7 +34,7 @@ class TXTBase(dns.rdata.Rdata): __slots__ = ['strings'] - def __init__(self, rdclass, rdtype, strings): + def __init__(self, rdclass, rdtype, strings: Iterable[Union[bytes, str]]): """Initialize a TXT-like rdata. *rdclass*, an ``int`` is the rdataclass of the Rdata. @@ -42,10 +44,9 @@ class TXTBase(dns.rdata.Rdata): *strings*, a tuple of ``bytes`` """ super().__init__(rdclass, rdtype) - self.strings = self._as_tuple(strings, - lambda x: self._as_bytes(x, True, 255)) + self.strings: Tuple[bytes] = self._as_tuple(strings, lambda x: self._as_bytes(x, True, 255)) - def to_text(self, origin=None, relativize=True, **kw): + def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw): txt = '' prefix = '' for s in self.strings: @@ -54,8 +55,8 @@ class TXTBase(dns.rdata.Rdata): return txt @classmethod - def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True, - relativize_to=None): + def from_text(cls, rdclass, rdtype, tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None, + relativize=True, relativize_to: Optional[dns.name.Name]=None): strings = [] for token in tok.get_remaining(): token = token.unescape_to_bytes() diff --git a/dns/rdtypes/txtbase.pyi b/dns/rdtypes/txtbase.pyi deleted file mode 100644 index f8d5df98..00000000 --- a/dns/rdtypes/txtbase.pyi +++ /dev/null @@ -1,12 +0,0 @@ -import typing -from .. import rdata - -class TXTBase(rdata.Rdata): - strings: typing.Tuple[bytes, ...] - - def __init__(self, rdclass: int, rdtype: int, strings: typing.Iterable[bytes]) -> None: - ... - def to_text(self, origin: typing.Any, relativize: bool, **kw: typing.Any) -> str: - ... -class TXT(TXTBase): - ... diff --git a/dns/resolver.py b/dns/resolver.py index 332c82c0..42d228d9 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -16,6 +16,9 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. """DNS stub resolver.""" + +from typing import Any, Dict, List, Optional, Tuple, Union + from urllib.parse import urlparse import contextlib import socket @@ -52,6 +55,10 @@ class NXDOMAIN(dns.exception.DNSException): # pylint: disable=arguments-differ + # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def _check_kwargs(self, qnames, responses=None): if not isinstance(qnames, (list, tuple, set)): @@ -132,7 +139,10 @@ class YXDOMAIN(dns.exception.DNSException): """The DNS query name is too long after DNAME substitution.""" -def _errors_to_text(errors): +ErrorTuple = Tuple[str, bool, int, Exception, dns.message.Message] + + +def _errors_to_text(errors: List[ErrorTuple]) -> List[str]: """Turn a resolution errors trace into a list of text.""" texts = [] for err in errors: @@ -148,6 +158,10 @@ class LifetimeTimeout(dns.exception.Timeout): fmt = "%s after {timeout:.3f} seconds: {errors}" % msg[:-1] supp_kwargs = {'timeout', 'errors'} + # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def _fmt_kwargs(self, **kwargs): srv_msgs = _errors_to_text(kwargs['errors']) return super()._fmt_kwargs(timeout=kwargs['timeout'], @@ -166,6 +180,10 @@ class NoAnswer(dns.exception.DNSException): 'to the question: {query}' supp_kwargs = {'response'} + # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def _fmt_kwargs(self, **kwargs): return super()._fmt_kwargs(query=kwargs['response'].question) @@ -186,6 +204,10 @@ class NoNameservers(dns.exception.DNSException): fmt = "%s {query}: {errors}" % msg[:-1] supp_kwargs = {'request', 'errors'} + # We do this as otherwise mypy complains about unexpected keyword argument idna_exception + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def _fmt_kwargs(self, **kwargs): srv_msgs = _errors_to_text(kwargs['errors']) return super()._fmt_kwargs(query=kwargs['request'].question, @@ -222,8 +244,9 @@ class Answer: RRset's name might not be the query name. """ - def __init__(self, qname, rdtype, rdclass, response, nameserver=None, - port=None): + def __init__(self, qname: dns.name.Name, rdtype: dns.rdatatype.RdataType, + rdclass: dns.rdataclass.RdataClass, response: dns.message.QueryMessage, + nameserver: Optional[str]=None, port: Optional[int]=None): self.qname = qname self.rdtype = rdtype self.rdclass = rdclass @@ -280,7 +303,7 @@ class CacheStatistics: self.hits = 0 self.misses = 0 - def clone(self): + def clone(self) -> 'CacheStatistics': return CacheStatistics(self.hits, self.misses) @@ -304,7 +327,7 @@ class CacheBase: with self.lock: return self.statistics.misses - def get_statistics_snapshot(self): + def get_statistics_snapshot(self) -> CacheStatistics: """Return a consistent snapshot of all the statistics. If running with multiple threads, it's better to take a @@ -315,6 +338,9 @@ class CacheBase: return self.statistics.clone() +CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass] + + class Cache(CacheBase): """Simple thread-safe DNS answer cache.""" @@ -342,12 +368,12 @@ class Cache(CacheBase): now = time.time() self.next_cleaning = now + self.cleaning_interval - def get(self, key): + def get(self, key: CacheKey) -> Optional[Answer]: """Get the answer associated with *key*. Returns None if no answer is cached for the key. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the query name, rdtype, and rdclass respectively. Returns a ``dns.resolver.Answer`` or ``None``. @@ -362,10 +388,10 @@ class Cache(CacheBase): self.statistics.hits += 1 return v - def put(self, key, value): + def put(self, key: CacheKey, value: Answer): """Associate key and value in the cache. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the query name, rdtype, and rdclass respectively. *value*, a ``dns.resolver.Answer``, the answer. @@ -375,13 +401,13 @@ class Cache(CacheBase): self._maybe_clean() self.data[key] = value - def flush(self, key=None): + def flush(self, key: Optional[CacheKey]=None): """Flush the cache. If *key* is not ``None``, only that item is flushed. Otherwise the entire cache is flushed. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the query name, rdtype, and rdclass respectively. """ @@ -442,12 +468,12 @@ class LRUCache(CacheBase): max_size = 1 self.max_size = max_size - def get(self, key): + def get(self, key: CacheKey) -> Optional[Answer]: """Get the answer associated with *key*. Returns None if no answer is cached for the key. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the query name, rdtype, and rdclass respectively. Returns a ``dns.resolver.Answer`` or ``None``. @@ -470,7 +496,7 @@ class LRUCache(CacheBase): node.hits += 1 return node.value - def get_hits_for_key(self, key): + def get_hits_for_key(self, key: CacheKey) -> int: """Return the number of cache hits associated with the specified key.""" with self.lock: node = self.data.get(key) @@ -479,10 +505,10 @@ class LRUCache(CacheBase): else: return node.hits - def put(self, key, value): + def put(self, key: CacheKey, value: Answer): """Associate key and value in the cache. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the query name, rdtype, and rdclass respectively. *value*, a ``dns.resolver.Answer``, the answer. @@ -501,13 +527,13 @@ class LRUCache(CacheBase): node.link_after(self.sentinel) self.data[key] = node - def flush(self, key=None): + def flush(self, key: Optional[CacheKey]=None): """Flush the cache. If *key* is not ``None``, only that item is flushed. Otherwise the entire cache is flushed. - *key*, a ``(dns.name.Name, int, int)`` tuple whose values are the + *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the query name, rdtype, and rdclass respectively. """ @@ -537,8 +563,10 @@ class _Resolution: resolver data structures directly. """ - def __init__(self, resolver, qname, rdtype, rdclass, tcp, - raise_on_no_answer, search): + def __init__(self, resolver: 'BaseResolver', qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + rdclass: Union[dns.rdataclass.RdataClass, str], + tcp: bool, raise_on_no_answer: bool, search: Optional[bool]): if isinstance(qname, str): qname = dns.name.from_text(qname, None) rdtype = dns.rdatatype.RdataType.make(rdtype) @@ -554,21 +582,20 @@ class _Resolution: self.rdclass = rdclass self.tcp = tcp self.raise_on_no_answer = raise_on_no_answer - self.nxdomain_responses = {} - # + self.nxdomain_responses: Dict[dns.name.Name, Answer] = {} # Initialize other things to help analysis tools self.qname = dns.name.empty - self.nameservers = [] - self.current_nameservers = [] - self.errors = [] - self.nameserver = None + self.nameservers: List[str] = [] + self.current_nameservers: List[str] = [] + self.errors: List[ErrorTuple] = [] + self.nameserver: Optional[str] = None self.port = 0 self.tcp_attempt = False self.retry_with_tcp = False - self.request = None - self.backoff = 0 + self.request: Optional[dns.message.QueryMessage] = None + self.backoff = 0.0 - def next_request(self): + def next_request(self) -> Tuple[Optional[dns.message.QueryMessage], Optional[Answer]]: """Get the next request to send, and check the cache. Returns a (request, answer) tuple. At most one of request or @@ -732,6 +759,7 @@ class _Resolution: dns.rcode.to_text(rcode), response)) return (None, False) + class BaseResolver: """DNS stub resolver.""" @@ -765,10 +793,10 @@ class BaseResolver: dns.name.Name(dns.name.from_text(socket.gethostname())[1:]) if len(self.domain) == 0: self.domain = dns.name.root - self.nameservers = [] - self.nameserver_ports = {} + self.nameservers: List[str] = [] + self.nameserver_ports: Dict[str, int] = {} self.port = 53 - self.search = [] + self.search: List[dns.name.Name] = [] self.use_search_by_default = False self.timeout = 2.0 self.lifetime = 5.0 @@ -777,13 +805,13 @@ class BaseResolver: self.keyalgorithm = dns.tsig.default_algorithm self.edns = -1 self.ednsflags = 0 - self.ednsoptions = None + self.ednsoptions: Optional[List[dns.edns.Option]] = None self.payload = 0 self.cache = None self.flags = None self.retry_servfail = False self.rotate = False - self.ndots = None + self.ndots: Optional[int] = None def read_resolv_conf(self, f): """Process *f* as a file in the /etc/resolv.conf format. If f is @@ -862,7 +890,8 @@ class BaseResolver: except AttributeError: raise NotImplementedError - def _compute_timeout(self, start, lifetime=None, errors=None): + def _compute_timeout(self, start: float, lifetime: Optional[float]=None, + errors: Optional[List[ErrorTuple]]=None) -> float: lifetime = self.lifetime if lifetime is None else lifetime now = time.time() duration = now - start @@ -881,7 +910,7 @@ class BaseResolver: raise LifetimeTimeout(timeout=duration, errors=errors) return min(lifetime - duration, self.timeout) - def _get_qnames_to_try(self, qname, search): + def _get_qnames_to_try(self, qname: dns.name.Name, search: Optional[bool]) -> List[dns.name.Name]: # This is a separate method so we can unit test the search # rules without requiring the Internet. if search is None: @@ -960,7 +989,7 @@ class BaseResolver: self.payload = payload self.ednsoptions = options - def set_flags(self, flags): + def set_flags(self, flags: int): """Overrides the default flags with your own. *flags*, an ``int``, the message flags to use. @@ -969,11 +998,11 @@ class BaseResolver: self.flags = flags @property - def nameservers(self): + def nameservers(self) -> List[str]: return self._nameservers @nameservers.setter - def nameservers(self, nameservers): + def nameservers(self, nameservers: List[str]): """ *nameservers*, a ``list`` of nameservers. @@ -998,9 +1027,11 @@ class BaseResolver: class Resolver(BaseResolver): """DNS stub resolver.""" - def resolve(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, source_port=0, - lifetime=None, search=None): # pylint: disable=arguments-differ + def resolve(self, qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, + tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0, + lifetime: Optional[float]=None, search: Optional[bool]=None) -> Answer: # pylint: disable=arguments-differ """Query nameservers to find the answer to the question. The *qname*, *rdtype*, and *rdclass* parameters may be objects @@ -1064,6 +1095,7 @@ class Resolver(BaseResolver): if answer is not None: # cache hit! return answer + assert request is not None # needed for type checking done = False while not done: (nameserver, port, tcp, backoff) = resolution.next_nameserver() @@ -1101,9 +1133,11 @@ class Resolver(BaseResolver): if answer is not None: return answer - def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, source_port=0, - lifetime=None): # pragma: no cover + def query(self, qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, + tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0, + lifetime: Optional[float]=None) -> Answer: # pragma: no cover """Query nameservers to find the answer to the question. This method calls resolve() with ``search=True``, and is @@ -1117,7 +1151,7 @@ class Resolver(BaseResolver): raise_on_no_answer, source_port, lifetime, True) - def resolve_address(self, ipaddr, *args, **kwargs): + def resolve_address(self, ipaddr: str, *args, **kwargs) -> Answer: """Use a resolver to run a reverse query for PTR records. This utilizes the resolve() method to perform a PTR lookup on the @@ -1130,15 +1164,19 @@ class Resolver(BaseResolver): except for rdtype and rdclass are also supported by this function. """ - + # We make a modified kwargs for type checking happiness, as otherwise + # we get a legit warning about possibly having rdtype and rdclass + # in the kwargs more than once. + modified_kwargs = {} + modified_kwargs.update(kwargs) + modified_kwargs['rdtype'] = dns.rdatatype.PTR + modified_kwargs['rdclass'] = dns.rdataclass.IN return self.resolve(dns.reversename.from_address(ipaddr), - rdtype=dns.rdatatype.PTR, - rdclass=dns.rdataclass.IN, - *args, **kwargs) + *args, **modified_kwargs) # pylint: disable=redefined-outer-name - def canonical_name(self, name): + def canonical_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. The canonical name is the name the resolver uses for queries @@ -1163,13 +1201,14 @@ class Resolver(BaseResolver): #: The default resolver. -default_resolver = None +default_resolver: Optional[Resolver] = None -def get_default_resolver(): +def get_default_resolver() -> Resolver: """Get the default resolver, initializing it if necessary.""" if default_resolver is None: reset_default_resolver() + assert default_resolver is not None return default_resolver @@ -1184,9 +1223,12 @@ def reset_default_resolver(): default_resolver = Resolver() -def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime=None, search=None): +def resolve(qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, + tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0, + lifetime: Optional[float]=None, search: Optional[bool]=None) -> Answer: # pragma: no cover + """Query nameservers to find the answer to the question. This is a convenience function that uses the default resolver @@ -1200,9 +1242,11 @@ def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, raise_on_no_answer, source_port, lifetime, search) -def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime=None): # pragma: no cover +def query(qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, + tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0, + lifetime: Optional[float]=None) -> Answer: # pragma: no cover """Query nameservers to find the answer to the question. This method calls resolve() with ``search=True``, and is @@ -1217,7 +1261,7 @@ def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, True) -def resolve_address(ipaddr, *args, **kwargs): +def resolve_address(ipaddr: str, *args, **kwargs) -> Answer: """Use a resolver to run a reverse query for PTR records. See ``dns.resolver.Resolver.resolve_address`` for more information on the @@ -1227,7 +1271,7 @@ def resolve_address(ipaddr, *args, **kwargs): return get_default_resolver().resolve_address(ipaddr, *args, **kwargs) -def canonical_name(name): +def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: """Determine the canonical name of *name*. See ``dns.resolver.Resolver.canonical_name`` for more information on the @@ -1237,8 +1281,9 @@ def canonical_name(name): return get_default_resolver().canonical_name(name) -def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None, - lifetime=None): +def zone_for_name(name: Union[dns.name.Name, str], rdclass=dns.rdataclass.IN, + tcp=False, resolver: Optional[Resolver]=None, + lifetime: Optional[float]=None) -> dns.name.Name: """Find the name of the zone which contains the specified name. *name*, an absolute ``dns.name.Name`` or ``str``, the query name. @@ -1285,6 +1330,7 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None, rlifetime = None answer = resolver.resolve(name, dns.rdatatype.SOA, rdclass, tcp, lifetime=rlifetime) + assert answer.rrset is not None if answer.rrset.name == name: return name # otherwise we were CNAMEd or DNAMEd and need to look higher @@ -1544,7 +1590,7 @@ def _gethostbyaddr(ip): return (canonical, aliases, addresses) -def override_system_resolver(resolver=None): +def override_system_resolver(resolver: Optional[Resolver]=None): """Override the system resolver routines in the socket module with versions which use dnspython's resolver. diff --git a/dns/resolver.pyi b/dns/resolver.pyi deleted file mode 100644 index 348df4da..00000000 --- a/dns/resolver.pyi +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Union, Optional, List, Any, Dict -from . import exception, rdataclass, name, rdatatype - -import socket -_gethostbyname = socket.gethostbyname - -class NXDOMAIN(exception.DNSException): ... -class YXDOMAIN(exception.DNSException): ... -class NoAnswer(exception.DNSException): ... -class NoNameservers(exception.DNSException): ... -class NotAbsolute(exception.DNSException): ... -class NoRootSOA(exception.DNSException): ... -class NoMetaqueries(exception.DNSException): ... -class NoResolverConfiguration(exception.DNSException): ... -Timeout = exception.Timeout - -def resolve(qname : str, rdtype : Union[int,str] = 0, - rdclass : Union[int,str] = 0, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime : Optional[float]=None, - search : Optional[bool]=None): - ... -def query(qname : str, rdtype : Union[int,str] = 0, - rdclass : Union[int,str] = 0, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, lifetime : Optional[float]=None): - ... -def resolve_address(ipaddr: str, *args: Any, **kwargs: Optional[Dict]): - ... -class LRUCache: - def __init__(self, max_size=1000): - ... - def get(self, key): - ... - def put(self, key, val): - ... -class Answer: - def __init__(self, qname, rdtype, rdclass, response, - raise_on_no_answer=True): - ... -def zone_for_name(name, rdclass : int = rdataclass.IN, tcp=False, - resolver : Optional[Resolver] = None): - ... - -class Resolver: - def __init__(self, filename : Optional[str] = '/etc/resolv.conf', - configure : Optional[bool] = True): - self.nameservers : List[str] - def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A, - rdclass : Union[int,str] = rdataclass.IN, - tcp : bool = False, source : Optional[str] = None, - raise_on_no_answer=True, source_port : int = 0, - lifetime : Optional[float]=None, - search : Optional[bool]=None): - ... - def query(self, qname : str, rdtype : Union[int,str] = rdatatype.A, - rdclass : Union[int,str] = rdataclass.IN, - tcp : bool = False, source : Optional[str] = None, - raise_on_no_answer=True, source_port : int = 0, - lifetime : Optional[float]=None): - ... -default_resolver: typing.Optional[Resolver] -def reset_default_resolver() -> None: - ... -def get_default_resolver() -> Resolver: - ... diff --git a/dns/reversename.py b/dns/reversename.py index e0beb03d..4b70cf64 100644 --- a/dns/reversename.py +++ b/dns/reversename.py @@ -27,8 +27,8 @@ ipv4_reverse_domain = dns.name.from_text('in-addr.arpa.') ipv6_reverse_domain = dns.name.from_text('ip6.arpa.') -def from_address(text, v4_origin=ipv4_reverse_domain, - v6_origin=ipv6_reverse_domain): +def from_address(text: str, v4_origin=ipv4_reverse_domain, + v6_origin=ipv6_reverse_domain) -> dns.name.Name: """Convert an IPv4 or IPv6 address in textual form into a Name object whose value is the reverse-map domain name of the address. @@ -63,8 +63,8 @@ def from_address(text, v4_origin=ipv4_reverse_domain, return dns.name.from_text('.'.join(reversed(parts)), origin=origin) -def to_address(name, v4_origin=ipv4_reverse_domain, - v6_origin=ipv6_reverse_domain): +def to_address(name: dns.name.Name, v4_origin=ipv4_reverse_domain, + v6_origin=ipv6_reverse_domain) -> str: """Convert a reverse map domain name into textual address form. *name*, a ``dns.name.Name``, an IPv4 or IPv6 address in reverse-map name diff --git a/dns/reversename.pyi b/dns/reversename.pyi deleted file mode 100644 index 97f072ea..00000000 --- a/dns/reversename.pyi +++ /dev/null @@ -1,6 +0,0 @@ -from . import name -def from_address(text : str) -> name.Name: - ... - -def to_address(name : name.Name) -> str: - ... diff --git a/dns/rrset.py b/dns/rrset.py index a71d4573..37458571 100644 --- a/dns/rrset.py +++ b/dns/rrset.py @@ -17,6 +17,7 @@ """DNS RRsets (an RRset is a named rdataset)""" +from typing import cast, Collection, Optional, Union import dns.name import dns.rdataset @@ -37,8 +38,10 @@ class RRset(dns.rdataset.Rdataset): __slots__ = ['name', 'deleting'] - def __init__(self, name, rdclass, rdtype, covers=dns.rdatatype.NONE, - deleting=None): + def __init__(self, name: dns.name.Name, rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, + deleting: Optional[dns.rdataclass.RdataClass]=None): """Create a new RRset.""" super().__init__(rdclass, rdtype, covers) @@ -76,7 +79,7 @@ class RRset(dns.rdataset.Rdataset): return False return super().__eq__(other) - def match(self, *args, **kwargs): + def match(self, *args, **kwargs) -> bool: """Does this rrset match the specified attributes? Behaves as :py:func:`full_match()` if the first argument is a @@ -93,8 +96,9 @@ class RRset(dns.rdataset.Rdataset): else: return super().match(*args, **kwargs) - def full_match(self, name, rdclass, rdtype, covers, - deleting=None): + def full_match(self, name: dns.name.Name, rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType, + deleting: Optional[dns.rdataclass.RdataClass]=None) -> bool: """Returns ``True`` if this rrset matches the specified name, class, type, covers, and deletion state. """ @@ -106,7 +110,7 @@ class RRset(dns.rdataset.Rdataset): # pylint: disable=arguments-differ - def to_text(self, origin=None, relativize=True, **kw): + def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw) -> str: # type: ignore """Convert the RRset into DNS zone file format. See ``dns.name.Name.choose_relativity`` for more information @@ -126,8 +130,8 @@ class RRset(dns.rdataset.Rdataset): return super().to_text(self.name, origin, relativize, self.deleting, **kw) - def to_wire(self, file, compress=None, origin=None, - **kw): + def to_wire(self, file, compress: Optional[dns.name.CompressType]=None, # type: ignore + origin: Optional[dns.name.Name]=None, **kw) -> int: """Convert the RRset to wire format. All keyword arguments are passed to ``dns.rdataset.to_wire()``; see @@ -141,7 +145,7 @@ class RRset(dns.rdataset.Rdataset): # pylint: enable=arguments-differ - def to_rdataset(self): + def to_rdataset(self) -> dns.rdataset.Rdataset: """Convert an RRset into an Rdataset. Returns a ``dns.rdataset.Rdataset``. @@ -149,9 +153,13 @@ class RRset(dns.rdataset.Rdataset): return dns.rdataset.from_rdata_list(self.ttl, list(self)) -def from_text_list(name, ttl, rdclass, rdtype, text_rdatas, - idna_codec=None, origin=None, relativize=True, - relativize_to=None): +def from_text_list(name: Union[dns.name.Name, str], ttl: int, + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + text_rdatas: Collection[str], + idna_codec: Optional[dns.name.IDNACodec]=None, + origin: Optional[dns.name.Name]=None, relativize=True, + relativize_to: Optional[dns.name.Name]=None) -> RRset: """Create an RRset with the specified name, TTL, class, and type, and with the specified list of rdatas in text format. @@ -172,9 +180,9 @@ def from_text_list(name, ttl, rdclass, rdtype, text_rdatas, if isinstance(name, str): name = dns.name.from_text(name, None, idna_codec=idna_codec) - rdclass = dns.rdataclass.RdataClass.make(rdclass) - rdtype = dns.rdatatype.RdataType.make(rdtype) - r = RRset(name, rdclass, rdtype) + the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + r = RRset(name, the_rdclass, the_rdtype) r.update_ttl(ttl) for t in text_rdatas: rd = dns.rdata.from_text(r.rdclass, r.rdtype, t, origin, relativize, @@ -183,17 +191,23 @@ def from_text_list(name, ttl, rdclass, rdtype, text_rdatas, return r -def from_text(name, ttl, rdclass, rdtype, *text_rdatas): +def from_text(name: Union[dns.name.Name, str], ttl: int, + rdclass: Union[dns.rdataclass.RdataClass, str], + rdtype: Union[dns.rdatatype.RdataType, str], + *text_rdatas) -> RRset: """Create an RRset with the specified name, TTL, class, and type and with the specified rdatas in text format. Returns a ``dns.rrset.RRset`` object. """ - return from_text_list(name, ttl, rdclass, rdtype, text_rdatas) + return from_text_list(name, ttl, rdclass, rdtype, + cast(Collection[str], text_rdatas)) -def from_rdata_list(name, ttl, rdatas, idna_codec=None): +def from_rdata_list(name: Union[dns.name.Name, str], ttl:int, + rdatas: Collection[dns.rdata.Rdata], + idna_codec: Optional[dns.name.IDNACodec]=None) -> RRset: """Create an RRset with the specified name and TTL, and with the specified list of rdata objects. @@ -216,14 +230,15 @@ def from_rdata_list(name, ttl, rdatas, idna_codec=None): r = RRset(name, rd.rdclass, rd.rdtype) r.update_ttl(ttl) r.add(rd) + assert r is not None return r -def from_rdata(name, ttl, *rdatas): +def from_rdata(name: Union[dns.name.Name, str], ttl:int, *rdatas) -> RRset: """Create an RRset with the specified name and TTL, and with the specified rdata objects. Returns a ``dns.rrset.RRset`` object. """ - return from_rdata_list(name, ttl, rdatas) + return from_rdata_list(name, ttl, cast(Collection[dns.rdata.Rdata], rdatas)) diff --git a/dns/rrset.pyi b/dns/rrset.pyi deleted file mode 100644 index 0a81a2a0..00000000 --- a/dns/rrset.pyi +++ /dev/null @@ -1,10 +0,0 @@ -from typing import List, Optional -from . import rdataset, rdatatype - -class RRset(rdataset.Rdataset): - def __init__(self, name, rdclass : int , rdtype : int, covers=rdatatype.NONE, - deleting : Optional[int] =None) -> None: - self.name = name - self.deleting = deleting -def from_text(name : str, ttl : int, rdclass : str, rdtype : str, *text_rdatas : str): - ... diff --git a/dns/serial.py b/dns/serial.py index b0474151..138ffbf9 100644 --- a/dns/serial.py +++ b/dns/serial.py @@ -3,7 +3,7 @@ """Serial Number Arthimetic from RFC 1982""" class Serial: - def __init__(self, value, bits=32): + def __init__(self, value:int , bits=32): self.value = value % 2 ** bits self.bits = bits diff --git a/dns/tokenizer.py b/dns/tokenizer.py index cb6a6302..bb94ce94 100644 --- a/dns/tokenizer.py +++ b/dns/tokenizer.py @@ -17,6 +17,8 @@ """Tokenize DNS zone file format""" +from typing import Optional, List, Tuple + import io import sys @@ -48,7 +50,7 @@ class Token: has_escape: Does the token value contain escapes? """ - def __init__(self, ttype, value='', has_escape=False, comment=None): + def __init__(self, ttype: int, value='', has_escape=False, comment: Optional[str]=None): """Initialize a token instance.""" self.ttype = ttype @@ -56,28 +58,28 @@ class Token: self.has_escape = has_escape self.comment = comment - def is_eof(self): + def is_eof(self) -> bool: return self.ttype == EOF - def is_eol(self): + def is_eol(self) -> bool: return self.ttype == EOL - def is_whitespace(self): + def is_whitespace(self) -> bool: return self.ttype == WHITESPACE - def is_identifier(self): + def is_identifier(self) -> bool: return self.ttype == IDENTIFIER - def is_quoted_string(self): + def is_quoted_string(self) -> bool: return self.ttype == QUOTED_STRING - def is_comment(self): + def is_comment(self) -> bool: return self.ttype == COMMENT - def is_delimiter(self): # pragma: no cover (we don't return delimiters yet) + def is_delimiter(self) -> bool: # pragma: no cover (we don't return delimiters yet) return self.ttype == DELIMITER - def is_eol_or_eof(self): + def is_eol_or_eof(self) -> bool: return self.ttype == EOL or self.ttype == EOF def __eq__(self, other): @@ -95,7 +97,7 @@ class Token: def __str__(self): return '%d "%s"' % (self.ttype, self.value) - def unescape(self): + def unescape(self) -> 'Token': if not self.has_escape: return self unescaped = '' @@ -127,7 +129,7 @@ class Token: unescaped += c return Token(self.ttype, unescaped) - def unescape_to_bytes(self): + def unescape_to_bytes(self) -> 'Token': # We used to use unescape() for TXT-like records, but this # caused problems as we'd process DNS escapes into Unicode code # points instead of byte values, and then a to_text() of the @@ -223,7 +225,8 @@ class Tokenizer: encoder/decoder is used. """ - def __init__(self, f=sys.stdin, filename=None, idna_codec=None): + def __init__(self, f=sys.stdin, filename: Optional[str]=None, + idna_codec: Optional[dns.name.IDNACodec]=None): """Initialize a tokenizer instance. f: The file to tokenize. The default is sys.stdin. @@ -253,19 +256,21 @@ class Tokenizer: else: filename = '' self.file = f - self.ungotten_char = None - self.ungotten_token = None + self.ungotten_char: Optional[str] = None + self.ungotten_token: Optional[Token] = None self.multiline = 0 self.quoting = False self.eof = False self.delimiters = _DELIMITERS self.line_number = 1 + assert filename is not None self.filename = filename if idna_codec is None: - idna_codec = dns.name.IDNA_2003 - self.idna_codec = idna_codec + self.idna_codec: dns.name.IDNACodec = dns.name.IDNA_2003 + else: + self.idna_codec = idna_codec - def _get_char(self): + def _get_char(self) -> str: """Read a character from input. """ @@ -283,7 +288,7 @@ class Tokenizer: self.ungotten_char = None return c - def where(self): + def where(self) -> Tuple[str, int]: """Return the current location in the input. Returns a (string, int) tuple. The first item is the filename of @@ -328,7 +333,7 @@ class Tokenizer: return skipped skipped += 1 - def get(self, want_leading=False, want_comment=False): + def get(self, want_leading=False, want_comment=False) -> Token: """Get the next token. want_leading: If True, return a WHITESPACE token if the @@ -345,16 +350,16 @@ class Tokenizer: """ if self.ungotten_token is not None: - token = self.ungotten_token + utoken = self.ungotten_token self.ungotten_token = None - if token.is_whitespace(): + if utoken.is_whitespace(): if want_leading: - return token - elif token.is_comment(): + return utoken + elif utoken.is_comment(): if want_comment: - return token + return utoken else: - return token + return utoken skipped = self.skip_whitespace() if want_leading and skipped > 0: return Token(WHITESPACE, ' ') @@ -438,7 +443,7 @@ class Tokenizer: ttype = EOF return Token(ttype, token, has_escape) - def unget(self, token): + def unget(self, token: Token): """Unget a token. The unget buffer for tokens is only one token large; it is @@ -487,7 +492,7 @@ class Tokenizer: raise dns.exception.SyntaxError('expecting an integer') return int(token.value, base) - def get_uint8(self): + def get_uint8(self) -> int: """Read the next token and interpret it as an 8-bit unsigned integer. @@ -502,7 +507,7 @@ class Tokenizer: '%d is not an unsigned 8-bit integer' % value) return value - def get_uint16(self, base=10): + def get_uint16(self, base=10) -> int: """Read the next token and interpret it as a 16-bit unsigned integer. @@ -521,7 +526,7 @@ class Tokenizer: '%d is not an unsigned 16-bit integer' % value) return value - def get_uint32(self, base=10): + def get_uint32(self, base=10) -> int: """Read the next token and interpret it as a 32-bit unsigned integer. @@ -536,7 +541,7 @@ class Tokenizer: '%d is not an unsigned 32-bit integer' % value) return value - def get_uint48(self, base=10): + def get_uint48(self, base=10) -> int: """Read the next token and interpret it as a 48-bit unsigned integer. @@ -551,7 +556,7 @@ class Tokenizer: '%d is not an unsigned 48-bit integer' % value) return value - def get_string(self, max_length=None): + def get_string(self, max_length=None) -> str: """Read the next token and interpret it as a string. Raises dns.exception.SyntaxError if not a string. @@ -568,7 +573,7 @@ class Tokenizer: raise dns.exception.SyntaxError("string too long") return token.value - def get_identifier(self): + def get_identifier(self) -> str: """Read the next token, which should be an identifier. Raises dns.exception.SyntaxError if not an identifier. @@ -581,7 +586,7 @@ class Tokenizer: raise dns.exception.SyntaxError('expecting an identifier') return token.value - def get_remaining(self, max_tokens=None): + def get_remaining(self, max_tokens=None) -> List[Token]: """Return the remaining tokens on the line, until an EOL or EOF is seen. max_tokens: If not None, stop after this number of tokens. @@ -600,7 +605,7 @@ class Tokenizer: break return tokens - def concatenate_remaining_identifiers(self, allow_empty=False): + def concatenate_remaining_identifiers(self, allow_empty=False) -> str: """Read the remaining tokens on the line, which should be identifiers. Raises dns.exception.SyntaxError if there are no remaining tokens, @@ -625,7 +630,8 @@ class Tokenizer: raise dns.exception.SyntaxError('expecting another identifier') return s - def as_name(self, token, origin=None, relativize=False, relativize_to=None): + def as_name(self, token: Token, origin: Optional[dns.name.Name]=None, + relativize=False, relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name: """Try to interpret the token as a DNS name. Raises dns.exception.SyntaxError if not a name. @@ -637,7 +643,8 @@ class Tokenizer: name = dns.name.from_text(token.value, origin, self.idna_codec) return name.choose_relativity(relativize_to or origin, relativize) - def get_name(self, origin=None, relativize=False, relativize_to=None): + def get_name(self, origin: Optional[dns.name.Name]=None, relativize=False, + relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name: """Read the next token and interpret it as a DNS name. Raises dns.exception.SyntaxError if not a name. @@ -648,7 +655,7 @@ class Tokenizer: token = self.get() return self.as_name(token, origin, relativize, relativize_to) - def get_eol_as_token(self): + def get_eol_as_token(self) -> Token: """Read the next token and raise an exception if it isn't EOL or EOF. @@ -662,10 +669,10 @@ class Tokenizer: token.value)) return token - def get_eol(self): + def get_eol(self) -> str: return self.get_eol_as_token().value - def get_ttl(self): + def get_ttl(self) -> int: """Read the next token and interpret it as a DNS TTL. Raises dns.exception.SyntaxError or dns.ttl.BadTTL if not an diff --git a/dns/transaction.py b/dns/transaction.py index d7254924..ccb557ce 100644 --- a/dns/transaction.py +++ b/dns/transaction.py @@ -1,9 +1,12 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +from typing import Callable, List, Optional, Tuple, Union + import collections import dns.exception import dns.name +import dns.node import dns.rdataclass import dns.rdataset import dns.rdatatype @@ -13,11 +16,11 @@ import dns.ttl class TransactionManager: - def reader(self): + def reader(self) -> 'Transaction': """Begin a read-only transaction.""" raise NotImplementedError # pragma: no cover - def writer(self, replacement=False): + def writer(self, replacement=False) -> 'Transaction': """Begin a writable transaction. *replacement*, a ``bool``. If `True`, the content of the @@ -27,7 +30,7 @@ class TransactionManager: """ raise NotImplementedError # pragma: no cover - def origin_information(self): + def origin_information(self) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]: """Returns a tuple (absolute_origin, relativize, effective_origin) @@ -52,12 +55,12 @@ class TransactionManager: """ raise NotImplementedError # pragma: no cover - def get_class(self): + def get_class(self) -> dns.rdataclass.RdataClass: """The class of the transaction manager. """ raise NotImplementedError # pragma: no cover - def from_wire_origin(self): + def from_wire_origin(self) -> Optional[dns.name.Name]: """Origin to use in from_wire() calls. """ (absolute_origin, relativize, _) = self.origin_information() @@ -90,22 +93,33 @@ def _ensure_immutable_node(node): return dns.node.ImmutableNode(node) +CheckPutRdatasetType = Callable[['Transaction', dns.name.Name, dns.rdataset.Rdataset], None] +CheckDeleteRdatasetType = Callable[['Transaction', dns.name.Name, + dns.rdatatype.RdataType, dns.rdatatype.RdataType], None] +CheckDeleteNameType = Callable[['Transaction', dns.name.Name], None] + + class Transaction: - def __init__(self, manager, replacement=False, read_only=False): + def __init__(self, manager: TransactionManager, replacement=False, read_only=False): self.manager = manager self.replacement = replacement self.read_only = read_only self._ended = False - self._check_put_rdataset = [] - self._check_delete_rdataset = [] - self._check_delete_name = [] + self._check_put_rdataset: List[CheckPutRdatasetType]= [] + self._check_delete_rdataset: List[CheckDeleteRdatasetType] = [] + self._check_delete_name: List[CheckDeleteNameType] = [] # # This is the high level API # + # Note that we currently use non-immutable types in the return type signature to avoid + # covariance problems, e.g. if the caller has a List[Rdataset], mypy will be unhappy if we + # return an ImmutableRdataset. - def get(self, name, rdtype, covers=dns.rdatatype.NONE): + def get(self, name: Optional[Union[dns.name.Name,str]], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> dns.rdataset.Rdataset: """Return the rdataset associated with *name*, *rdtype*, and *covers*, or `None` if not found. @@ -115,10 +129,11 @@ class Transaction: if isinstance(name, str): name = dns.name.from_text(name, None) rdtype = dns.rdatatype.RdataType.make(rdtype) + covers = dns.rdatatype.RdataType.make(covers) rdataset = self._get_rdataset(name, rdtype, covers) return _ensure_immutable_rdataset(rdataset) - def get_node(self, name): + def get_node(self, name) -> dns.node.Node: """Return the node at *name*, if any. Returns an immutable node or ``None``. @@ -210,7 +225,7 @@ class Transaction: self._check_read_only() return self._delete(True, args) - def name_exists(self, name): + def name_exists(self, name: Union[dns.name.Name, str]) -> bool: """Does the specified name exist?""" self._check_ended() if isinstance(name, str): @@ -253,7 +268,7 @@ class Transaction: self._check_ended() return self._iterate_rdatasets() - def changed(self): + def changed(self) -> bool: """Has this transaction changed anything? For read-only transactions, the result is always `False`. @@ -289,7 +304,7 @@ class Transaction: """ self._end(False) - def check_put_rdataset(self, check): + def check_put_rdataset(self, check: CheckPutRdatasetType): """Call *check* before putting (storing) an rdataset. The function is called with the transaction, the name, and the rdataset. @@ -301,7 +316,7 @@ class Transaction: """ self._check_put_rdataset.append(check) - def check_delete_rdataset(self, check): + def check_delete_rdataset(self, check: CheckDeleteRdatasetType): """Call *check* before deleting an rdataset. The function is called with the transaction, the name, the rdatatype, @@ -314,7 +329,7 @@ class Transaction: """ self._check_delete_rdataset.append(check) - def check_delete_name(self, check): + def check_delete_name(self, check: CheckDeleteNameType): """Call *check* before putting (storing) an rdataset. The function is called with the transaction and the name. diff --git a/dns/tsigkeyring.py b/dns/tsigkeyring.py index 788581c9..06a1bd09 100644 --- a/dns/tsigkeyring.py +++ b/dns/tsigkeyring.py @@ -17,13 +17,15 @@ """A place to store TSIG keys.""" +from typing import Any, Dict, Union + import base64 import dns.name import dns.tsig -def from_text(textring): +def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]: """Convert a dictionary containing (textual DNS name, base64 secret) pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or a dictionary containing (textual DNS name, (algorithm, base64 secret)) @@ -32,16 +34,16 @@ def from_text(textring): keyring = {} for (name, value) in textring.items(): - name = dns.name.from_text(name) + kname = dns.name.from_text(name) if isinstance(value, str): - keyring[name] = dns.tsig.Key(name, value).secret + keyring[kname] = dns.tsig.Key(kname, value).secret else: (algorithm, secret) = value - keyring[name] = dns.tsig.Key(name, secret, algorithm) + keyring[kname] = dns.tsig.Key(kname, secret, algorithm) return keyring -def to_text(keyring): +def to_text(keyring: Dict[dns.name.Name, Any]) -> Dict[str, Any]: """Convert a dictionary containing (dns.name.Name, dns.tsig.Key) pairs into a text keyring which has (textual DNS name, (textual algorithm, base64 secret)) pairs, or a dictionary containing (dns.name.Name, bytes) @@ -52,14 +54,14 @@ def to_text(keyring): def b64encode(secret): return base64.encodebytes(secret).decode().rstrip() for (name, key) in keyring.items(): - name = name.to_text() + tname = name.to_text() if isinstance(key, bytes): - textring[name] = b64encode(key) + textring[tname] = b64encode(key) else: if isinstance(key.secret, bytes): text_secret = b64encode(key.secret) else: text_secret = str(key.secret) - textring[name] = (key.algorithm.to_text(), text_secret) + textring[tname] = (key.algorithm.to_text(), text_secret) return textring diff --git a/dns/tsigkeyring.pyi b/dns/tsigkeyring.pyi deleted file mode 100644 index b5d51e15..00000000 --- a/dns/tsigkeyring.pyi +++ /dev/null @@ -1,7 +0,0 @@ -from typing import Dict -from . import name - -def from_text(textring : Dict[str,str]) -> Dict[name.Name,bytes]: - ... -def to_text(keyring : Dict[name.Name,bytes]) -> Dict[str, str]: - ... diff --git a/dns/ttl.py b/dns/ttl.py index df92b2b6..9f5730e7 100644 --- a/dns/ttl.py +++ b/dns/ttl.py @@ -17,6 +17,8 @@ """DNS TTL conversion.""" +from typing import Union + import dns.exception # Technically TTLs are supposed to be between 0 and 2**31 - 1, with values @@ -31,7 +33,7 @@ class BadTTL(dns.exception.SyntaxError): """DNS TTL value is not well-formed.""" -def from_text(text): +def from_text(text: str) -> int: """Convert the text form of a TTL to an integer. The BIND 8 units syntax for TTLs (e.g. '1w6d4h3m10s') is supported. @@ -81,7 +83,7 @@ def from_text(text): return total -def make(value): +def make(value: Union[int, str]) -> int: if isinstance(value, int): return value elif isinstance(value, str): diff --git a/dns/update.py b/dns/update.py index 9a047553..5df0cc78 100644 --- a/dns/update.py +++ b/dns/update.py @@ -17,6 +17,7 @@ """DNS Dynamic Update Support""" +from typing import Any, Optional, Union import dns.message import dns.name @@ -41,11 +42,14 @@ class UpdateSection(dns.enum.IntEnum): class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] - _section_enum = UpdateSection + # ignore the mypy error here as we mean to use a different enum + _section_enum = UpdateSection # type: ignore - def __init__(self, zone=None, rdclass=dns.rdataclass.IN, keyring=None, - keyname=None, keyalgorithm=dns.tsig.default_algorithm, - id=None): + def __init__(self, zone: Optional[Union[dns.name.Name, str]]=None, + rdclass=dns.rdataclass.IN, + keyring: Optional[Any]=None, keyname: Optional[dns.name.Name]=None, + keyalgorithm=dns.tsig.default_algorithm, + id: Optional[int]=None): """Initialize a new DNS Update object. See the documentation of the Message class for a complete @@ -152,7 +156,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] self.origin) self._add_rr(name, ttl, rd, section=section) - def add(self, name, *args): + def add(self, name: Union[dns.name.Name, str], *args): """Add records. The first argument is always a name. The other @@ -167,7 +171,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] self._add(False, self.update, name, *args) - def delete(self, name, *args): + def delete(self, name: Union[dns.name.Name, str], *args): """Delete records. The first argument is always a name. The other @@ -187,31 +191,31 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] if len(args) == 0: self.find_rrset(self.update, name, dns.rdataclass.ANY, dns.rdatatype.ANY, dns.rdatatype.NONE, - dns.rdatatype.ANY, True, True) + dns.rdataclass.ANY, True, True) elif isinstance(args[0], dns.rdataset.Rdataset): for rds in args: for rd in rds: self._add_rr(name, 0, rd, dns.rdataclass.NONE) else: - args = list(args) - if isinstance(args[0], dns.rdata.Rdata): - for rd in args: + largs = list(args) + if isinstance(largs[0], dns.rdata.Rdata): + for rd in largs: self._add_rr(name, 0, rd, dns.rdataclass.NONE) else: - rdtype = dns.rdatatype.RdataType.make(args.pop(0)) - if len(args) == 0: + rdtype = dns.rdatatype.RdataType.make(largs.pop(0)) + if len(largs) == 0: self.find_rrset(self.update, name, self.zone_rdclass, rdtype, dns.rdatatype.NONE, dns.rdataclass.ANY, True, True) else: - for s in args: + for s in largs: rd = dns.rdata.from_text(self.zone_rdclass, rdtype, s, self.origin) self._add_rr(name, 0, rd, dns.rdataclass.NONE) - def replace(self, name, *args): + def replace(self, name: Union[dns.name.Name, str], *args): """Replace records. The first argument is always a name. The other @@ -229,7 +233,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] self._add(True, self.update, name, *args) - def present(self, name, *args): + def present(self, name: Union[dns.name.Name, str], *args): """Require that an owner name (and optionally an rdata type, or specific rdataset) exists as a prerequisite to the execution of the update. @@ -256,9 +260,11 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] len(args) > 1: if not isinstance(args[0], dns.rdataset.Rdataset): # Add a 0 TTL - args = list(args) - args.insert(0, 0) - self._add(False, self.prerequisite, name, *args) + largs = list(args) + largs.insert(0, 0) + self._add(False, self.prerequisite, name, *largs) + else: + self._add(False, self.prerequisite, name, *args) else: rdtype = dns.rdatatype.RdataType.make(args[0]) self.find_rrset(self.prerequisite, name, @@ -266,7 +272,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] dns.rdatatype.NONE, None, True, True) - def absent(self, name, rdtype=None): + def absent(self, name: Union[dns.name.Name, str], rdtype=None): """Require that an owner name (and optionally an rdata type) does not exist as a prerequisite to the execution of the update.""" diff --git a/dns/update.pyi b/dns/update.pyi deleted file mode 100644 index eeac0591..00000000 --- a/dns/update.pyi +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Optional,Dict,Union,Any - -from . import message, tsig, rdataclass, name - -class Update(message.Message): - def __init__(self, zone : Union[name.Name, str], rdclass : Union[int,str] = rdataclass.IN, keyring : Optional[Dict[name.Name,bytes]] = None, - keyname : Optional[name.Name] = None, keyalgorithm : Optional[name.Name] = tsig.default_algorithm) -> None: - self.id : int - def add(self, name : Union[str,name.Name], *args : Any): - ... - def delete(self, name, *args : Any): - ... - def replace(self, name : Union[str,name.Name], *args : Any): - ... - def present(self, name : Union[str,name.Name], *args : Any): - ... - def absent(self, name : Union[str,name.Name], rdtype=None): - """Require that an owner name (and optionally an rdata type) does - not exist as a prerequisite to the execution of the update.""" - def to_wire(self, origin : Optional[name.Name] = None, max_size=65535, **kw) -> bytes: - ... diff --git a/dns/versioned.py b/dns/versioned.py index a7e1204b..02316c82 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -2,6 +2,8 @@ """DNS Versioned Zones.""" +from typing import Callable, Deque, Optional, Set, Union + import collections try: import threading as _threading @@ -38,8 +40,8 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] node_factory = Node - def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True, - pruning_policy=None): + def __init__(self, origin: Optional[Union[dns.name.Name, str]], rdclass=dns.rdataclass.IN, relativize=True, + pruning_policy: Optional[Callable[['Zone', Version], Optional[bool]]]=None): """Initialize a versioned zone object. *origin* is the origin of the zone. It may be a ``dns.name.Name``, @@ -51,26 +53,26 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] *relativize*, a ``bool``, determine's whether domain names are relativized to the zone's origin. The default is ``True``. - *pruning policy*, a function taking a `Version` and returning - a `bool`, or `None`. Should the version be pruned? If `None`, + *pruning policy*, a function taking a ``Zone`` and a ``Version`` and returning + a ``bool``, or ``None``. Should the version be pruned? If ``None``, the default policy, which retains one version is used. """ super().__init__(origin, rdclass, relativize) - self._versions = collections.deque() + self._versions: Deque[Version] = collections.deque() self._version_lock = _threading.Lock() if pruning_policy is None: self._pruning_policy = self._default_pruning_policy else: self._pruning_policy = pruning_policy - self._write_txn = None - self._write_event = None - self._write_waiters = collections.deque() - self._readers = set() + self._write_txn: Optional[Transaction] = None + self._write_event: Optional[_threading.Event] = None + self._write_waiters: Deque[_threading.Event] = collections.deque() + self._readers: Set[Transaction] = set() self._commit_version_unlocked(None, WritableVersion(self, replacement=True), origin) - def reader(self, id=None, serial=None): # pylint: disable=arguments-differ + def reader(self, id: Optional[int]=None, serial: Optional[int]=None) -> Transaction: # pylint: disable=arguments-differ if id is not None and serial is not None: raise ValueError('cannot specify both id and serial') with self._version_lock: @@ -86,6 +88,7 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] if self.relativize: oname = dns.name.empty else: + assert self.origin is not None oname = self.origin version = None for v in reversed(self._versions): @@ -103,7 +106,7 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] self._readers.add(txn) return txn - def writer(self, replacement=False): + def writer(self, replacement=False) -> Transaction: event = None while True: with self._version_lock: @@ -178,21 +181,21 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] self._pruning_policy(self, self._versions[0]): self._versions.popleft() - def set_max_versions(self, max_versions): + def set_max_versions(self, max_versions: Optional[int]): """Set a pruning policy that retains up to the specified number of versions """ if max_versions is not None and max_versions < 1: raise ValueError('max versions must be at least 1') if max_versions is None: - def policy(*_): + def policy(zone, _): # pylint: disable=unused-argument return False else: def policy(zone, _): return len(zone._versions) > max_versions self.set_pruning_policy(policy) - def set_pruning_policy(self, policy): + def set_pruning_policy(self, policy: Optional[Callable[['Zone', Version], Optional[bool]]]): """Set the pruning policy for the zone. The *policy* function takes a `Version` and returns `True` if diff --git a/dns/wire.py b/dns/wire.py index 572e27e7..d3317a59 100644 --- a/dns/wire.py +++ b/dns/wire.py @@ -1,5 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +from typing import Optional, Tuple + import contextlib import struct @@ -7,7 +9,7 @@ import dns.exception import dns.name class Parser: - def __init__(self, wire, current=0): + def __init__(self, wire: bytes, current=0): self.wire = wire self.current = 0 self.end = len(self.wire) @@ -18,7 +20,8 @@ class Parser: def remaining(self): return self.end - self.current - def get_bytes(self, size): + def get_bytes(self, size=int) -> bytes: + assert size >= 0 if size > self.remaining(): raise dns.exception.FormError output = self.wire[self.current:self.current + size] @@ -26,35 +29,35 @@ class Parser: self.furthest = max(self.furthest, self.current) return output - def get_counted_bytes(self, length_size=1): + def get_counted_bytes(self, length_size=1) -> bytes: length = int.from_bytes(self.get_bytes(length_size), 'big') return self.get_bytes(length) - def get_remaining(self): + def get_remaining(self) -> bytes: return self.get_bytes(self.remaining()) - def get_uint8(self): + def get_uint8(self) -> int: return struct.unpack('!B', self.get_bytes(1))[0] - def get_uint16(self): + def get_uint16(self) -> int: return struct.unpack('!H', self.get_bytes(2))[0] - def get_uint32(self): + def get_uint32(self) -> int: return struct.unpack('!I', self.get_bytes(4))[0] - def get_uint48(self): + def get_uint48(self) -> int: return int.from_bytes(self.get_bytes(6), 'big') - def get_struct(self, format): + def get_struct(self, format: str) -> Tuple: return struct.unpack(format, self.get_bytes(struct.calcsize(format))) - def get_name(self, origin=None): + def get_name(self, origin: Optional['dns.name.Name']=None) -> 'dns.name.Name': name = dns.name.from_wire_parser(self) if origin: name = name.relativize(origin) return name - def seek(self, where): + def seek(self, where: int): # Note that seeking to the end is OK! (If you try to read # after such a seek, you'll get an exception as expected.) if where < 0 or where > self.end: @@ -62,7 +65,8 @@ class Parser: self.current = where @contextlib.contextmanager - def restrict_to(self, size): + def restrict_to(self, size: int): + assert size >= 0 if size > self.remaining(): raise dns.exception.FormError saved_end = self.end diff --git a/dns/xfr.py b/dns/xfr.py index 2ef1b0a7..618eac2f 100644 --- a/dns/xfr.py +++ b/dns/xfr.py @@ -15,12 +15,17 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +from typing import Any, List, Optional, Tuple + import dns.exception import dns.message import dns.name import dns.rcode import dns.serial +import dns.rdataset import dns.rdatatype +import dns.transaction +import dns.tsig import dns.zone @@ -46,8 +51,8 @@ class Inbound: State machine for zone transfers. """ - def __init__(self, txn_manager, rdtype=dns.rdatatype.AXFR, - serial=None, is_udp=False): + def __init__(self, txn_manager: dns.transaction.TransactionManager, rdtype=dns.rdatatype.AXFR, + serial: Optional[int]=None, is_udp=False): """Initialize an inbound zone transfer. *txn_manager* is a :py:class:`dns.transaction.TransactionManager`. @@ -61,7 +66,7 @@ class Inbound: XFR. """ self.txn_manager = txn_manager - self.txn = None + self.txn: Optional[dns.transaction.Transaction] = None self.rdtype = rdtype if rdtype == dns.rdatatype.IXFR: if serial is None: @@ -71,12 +76,12 @@ class Inbound: self.serial = serial self.is_udp = is_udp (_, _, self.origin) = txn_manager.origin_information() - self.soa_rdataset = None + self.soa_rdataset: Optional[dns.rdataset.Rdataset] = None self.done = False self.expecting_SOA = False self.delete_mode = False - def process_message(self, message): + def process_message(self, message: dns.message.Message) -> bool: """Process one message in the transfer. The message should have the same relativization as was specified when @@ -146,6 +151,7 @@ class Inbound: rdataset = rrset if self.done: raise dns.exception.FormError("answers after final SOA") + assert self.txn is not None # for mypy if rdataset.rdtype == dns.rdatatype.SOA and \ name == self.origin: # @@ -238,11 +244,11 @@ class Inbound: return False -def make_query(txn_manager, serial=0, - use_edns=None, ednsflags=None, payload=None, - request_payload=None, options=None, - keyring=None, keyname=None, - keyalgorithm=dns.tsig.default_algorithm): +def make_query(txn_manager: dns.transaction.TransactionManager, serial: Optional[int]=0, + use_edns=None, ednsflags: Optional[int]=None, payload: Optional[int]=None, + request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None, + keyring: Any=None, keyname: Optional[dns.name.Name]=None, + keyalgorithm=dns.tsig.default_algorithm) -> Tuple[dns.message.QueryMessage, Optional[int]]: """Make an AXFR or IXFR query. *txn_manager* is a ``dns.transaction.TransactionManager``, typically a @@ -263,6 +269,8 @@ def make_query(txn_manager, serial=0, Returns a `(query, serial)` tuple. """ (zone_origin, _, origin) = txn_manager.origin_information() + if zone_origin is None: + raise ValueError('no zone origin') if serial is None: rdtype = dns.rdatatype.AXFR elif not isinstance(serial, int): @@ -293,15 +301,17 @@ def make_query(txn_manager, serial=0, q.use_tsig(keyring, keyname, algorithm=keyalgorithm) return (q, serial) -def extract_serial_from_query(query): +def extract_serial_from_query(query: dns.message.Message) -> Optional[int]: """Extract the SOA serial number from query if it is an IXFR and return it, otherwise return None. *query* is a dns.message.QueryMessage that is an IXFR or AXFR request. Raises if the query is not an IXFR or AXFR, or if an IXFR doesn't have - an appropriate SOA RRset in the authority section.""" - + an appropriate SOA RRset in the authority section. + """ + if not isinstance(query, dns.message.QueryMessage): + raise ValueError('query not a QueryMessage') question = query.question[0] if question.rdtype == dns.rdatatype.AXFR: return None diff --git a/dns/zone.py b/dns/zone.py index 6a154ced..a9a40077 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -17,6 +17,8 @@ """DNS Zones.""" +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union + import contextlib import hashlib import io @@ -30,6 +32,7 @@ import dns.node import dns.rdataclass import dns.rdatatype import dns.rdata +import dns.rdataset import dns.rdtypes.ANY.SOA import dns.rdtypes.ANY.ZONEMD import dns.rrset @@ -38,6 +41,7 @@ import dns.transaction import dns.ttl import dns.grange import dns.zonefile +from dns.zonetypes import DigestScheme, DigestHashAlgorithm, _digest_hashers class BadZone(dns.exception.DNSException): @@ -80,33 +84,6 @@ class DigestVerificationFailure(dns.exception.DNSException): """The ZONEMD digest failed to verify.""" -class DigestScheme(dns.enum.IntEnum): - """ZONEMD Scheme""" - - SIMPLE = 1 - - @classmethod - def _maximum(cls): - return 255 - - -class DigestHashAlgorithm(dns.enum.IntEnum): - """ZONEMD Hash Algorithm""" - - SHA384 = 1 - SHA512 = 2 - - @classmethod - def _maximum(cls): - return 255 - - -_digest_hashers = { - DigestHashAlgorithm.SHA384: hashlib.sha384, - DigestHashAlgorithm.SHA512: hashlib.sha512, -} - - class Zone(dns.transaction.TransactionManager): """A DNS zone. @@ -123,7 +100,8 @@ class Zone(dns.transaction.TransactionManager): __slots__ = ['rdclass', 'origin', 'nodes', 'relativize'] - def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True): + def __init__(self, origin: Optional[Union[dns.name.Name, str]], + rdclass=dns.rdataclass.IN, relativize=True): """Initialize a zone object. *origin* is the origin of the zone. It may be a ``dns.name.Name``, @@ -146,7 +124,7 @@ class Zone(dns.transaction.TransactionManager): raise ValueError("origin parameter must be an absolute name") self.origin = origin self.rdclass = rdclass - self.nodes = {} + self.nodes: Dict[dns.name.Name, dns.node.Node] = {} self.relativize = relativize def __eq__(self, other): @@ -172,17 +150,27 @@ class Zone(dns.transaction.TransactionManager): return not self.__eq__(other) - def _validate_name(self, name): + def _validate_name(self, name: Union[dns.name.Name, str]) -> dns.name.Name: if isinstance(name, str): name = dns.name.from_text(name, None) elif not isinstance(name, dns.name.Name): raise KeyError("name parameter must be convertible to a DNS name") if name.is_absolute(): + if self.origin is None: + # This should probably never happen as other code (e.g. + # _rr_line) will notice the lack of an origin before us, but + # we check just in case! + raise KeyError('no zone origin is defined') if not name.is_subdomain(self.origin): raise KeyError( "name parameter must be a subdomain of the zone origin") if self.relativize: name = name.relativize(self.origin) + elif not self.relativize: + # We have a relative name in a non-relative zone, so derelativize. + if self.origin is None: + raise KeyError('no zone origin is defined') + name = name.derelativize(self.origin) return name def __getitem__(self, key): @@ -217,7 +205,7 @@ class Zone(dns.transaction.TransactionManager): key = self._validate_name(key) return key in self.nodes - def find_node(self, name, create=False): + def find_node(self, name: Union[dns.name.Name, str], create=False): """Find a node in the zone, possibly creating it. *name*: the name of the node to find. @@ -243,7 +231,7 @@ class Zone(dns.transaction.TransactionManager): self.nodes[name] = node return node - def get_node(self, name, create=False): + def get_node(self, name: Union[dns.name.Name, str], create=False): """Get a node in the zone, possibly creating it. This method is like ``find_node()``, except it returns None instead @@ -270,7 +258,7 @@ class Zone(dns.transaction.TransactionManager): node = None return node - def delete_node(self, name): + def delete_node(self, name: Union[dns.name.Name, str]): """Delete the specified node if it exists. *name*: the name of the node to find. @@ -285,8 +273,10 @@ class Zone(dns.transaction.TransactionManager): if name in self.nodes: del self.nodes[name] - def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, - create=False): + def find_rdataset(self, name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE, + create=False) -> dns.rdataset.Rdataset: """Look for an rdataset with the specified name and type in the zone, and return an rdataset encapsulating it. @@ -300,9 +290,9 @@ class Zone(dns.transaction.TransactionManager): name must be a subdomain of the zone's origin. If ``zone.relativize`` is ``True``, then the name will be relativized. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdatatype.RdataType`` or ``str`` the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -323,13 +313,12 @@ class Zone(dns.transaction.TransactionManager): name = self._validate_name(name) rdtype = dns.rdatatype.RdataType.make(rdtype) - if covers is not None: - covers = dns.rdatatype.RdataType.make(covers) + covers = dns.rdatatype.RdataType.make(covers) node = self.find_node(name, create) return node.find_rdataset(self.rdclass, rdtype, covers, create) def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, - create=False): + create=False) -> Optional[dns.rdataset.Rdataset]: """Look for an rdataset with the specified name and type in the zone. This method is like ``find_rdataset()``, except it returns None instead @@ -344,9 +333,9 @@ class Zone(dns.transaction.TransactionManager): name must be a subdomain of the zone's origin. If ``zone.relativize`` is ``True``, then the name will be relativized. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdatatype.RdataType`` or ``str``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -371,7 +360,9 @@ class Zone(dns.transaction.TransactionManager): rdataset = None return rdataset - def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE): + def delete_rdataset(self, name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE): """Delete the rdataset matching *rdtype* and *covers*, if it exists at the node specified by *name*. @@ -386,9 +377,9 @@ class Zone(dns.transaction.TransactionManager): name must be a subdomain of the zone's origin. If ``zone.relativize`` is ``True``, then the name will be relativized. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdatatype.RdataType`` or ``str`` or ``None``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -401,15 +392,15 @@ class Zone(dns.transaction.TransactionManager): name = self._validate_name(name) rdtype = dns.rdatatype.RdataType.make(rdtype) - if covers is not None: - covers = dns.rdatatype.RdataType.make(covers) + covers = dns.rdatatype.RdataType.make(covers) node = self.get_node(name) if node is not None: node.delete_rdataset(self.rdclass, rdtype, covers) if len(node) == 0: self.delete_node(name) - def replace_rdataset(self, name, replacement): + def replace_rdataset(self, name: Union[dns.name.Name, str], + replacement: dns.rdataset.Rdataset): """Replace an rdataset at name. It is not an error if there is no rdataset matching I{replacement}. @@ -433,7 +424,9 @@ class Zone(dns.transaction.TransactionManager): node = self.find_node(name, True) node.replace_rdataset(replacement) - def find_rrset(self, name, rdtype, covers=dns.rdatatype.NONE): + def find_rrset(self, name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> dns.rrset.RRset: """Look for an rdataset with the specified name and type in the zone, and return an RRset encapsulating it. @@ -451,9 +444,9 @@ class Zone(dns.transaction.TransactionManager): name must be a subdomain of the zone's origin. If ``zone.relativize`` is ``True``, then the name will be relativized. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdatatype.RdataType`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdatatype.RdataType`` or ``str``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -472,16 +465,17 @@ class Zone(dns.transaction.TransactionManager): Returns a ``dns.rrset.RRset`` or ``None``. """ - name = self._validate_name(name) - rdtype = dns.rdatatype.RdataType.make(rdtype) - if covers is not None: - covers = dns.rdatatype.RdataType.make(covers) - rdataset = self.nodes[name].find_rdataset(self.rdclass, rdtype, covers) - rrset = dns.rrset.RRset(name, self.rdclass, rdtype, covers) + vname = self._validate_name(name) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + the_covers = dns.rdatatype.RdataType.make(covers) + rdataset = self.nodes[vname].find_rdataset(self.rdclass, the_rdtype, the_covers) + rrset = dns.rrset.RRset(vname, self.rdclass, the_rdtype, the_covers) rrset.update(rdataset) return rrset - def get_rrset(self, name, rdtype, covers=dns.rdatatype.NONE): + def get_rrset(self, name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> Optional[dns.rrset.RRset]: """Look for an rdataset with the specified name and type in the zone, and return an RRset encapsulating it. @@ -498,9 +492,9 @@ class Zone(dns.transaction.TransactionManager): name must be a subdomain of the zone's origin. If ``zone.relativize`` is ``True``, then the name will be relativized. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -526,15 +520,15 @@ class Zone(dns.transaction.TransactionManager): return rrset def iterate_rdatasets(self, rdtype=dns.rdatatype.ANY, - covers=dns.rdatatype.NONE): + covers=dns.rdatatype.NONE) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]: """Return a generator which yields (name, rdataset) tuples for all rdatasets in the zone which have the specified *rdtype* and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default, then all rdatasets will be matched. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -546,8 +540,7 @@ class Zone(dns.transaction.TransactionManager): """ rdtype = dns.rdatatype.RdataType.make(rdtype) - if covers is not None: - covers = dns.rdatatype.RdataType.make(covers) + covers = dns.rdatatype.RdataType.make(covers) for (name, node) in self.items(): for rds in node: if rdtype == dns.rdatatype.ANY or \ @@ -555,15 +548,15 @@ class Zone(dns.transaction.TransactionManager): yield (name, rds) def iterate_rdatas(self, rdtype=dns.rdatatype.ANY, - covers=dns.rdatatype.NONE): + covers=dns.rdatatype.NONE) -> Iterator[Tuple[dns.name.Name, int, dns.rdata.Rdata]]: """Return a generator which yields (name, ttl, rdata) tuples for all rdatas in the zone which have the specified *rdtype* and *covers*. If *rdtype* is ``dns.rdatatype.ANY``, the default, then all rdatas will be matched. - *rdtype*, an ``int`` or ``str``, the rdata type desired. + *rdtype*, a ``dns.rdataset.Rdataset`` or ``str``, the rdata type desired. - *covers*, an ``int`` or ``str`` or ``None``, the covered type. + *covers*, a ``dns.rdataset.Rdataset`` or ``str``, the covered type. Usually this value is ``dns.rdatatype.NONE``, but if the rdtype is ``dns.rdatatype.SIG`` or ``dns.rdatatype.RRSIG``, then the covers value will be the rdata type the SIG/RRSIG @@ -575,8 +568,7 @@ class Zone(dns.transaction.TransactionManager): """ rdtype = dns.rdatatype.RdataType.make(rdtype) - if covers is not None: - covers = dns.rdatatype.RdataType.make(covers) + covers = dns.rdatatype.RdataType.make(covers) for (name, node) in self.items(): for rds in node: if rdtype == dns.rdatatype.ANY or \ @@ -584,7 +576,7 @@ class Zone(dns.transaction.TransactionManager): for rdata in rds: yield (name, rds.ttl, rdata) - def to_file(self, f, sorted=True, relativize=True, nl=None, + def to_file(self, f: Any, sorted=True, relativize=True, nl: Optional[str]=None, want_comments=False, want_origin=False): """Write a zone to a file. @@ -634,6 +626,7 @@ class Zone(dns.transaction.TransactionManager): nl = nl.decode() if want_origin: + assert self.origin is not None l = '$ORIGIN ' + self.origin.to_text() l_b = l.encode(file_enc) try: @@ -661,7 +654,7 @@ class Zone(dns.transaction.TransactionManager): f.write(l) f.write(nl) - def to_text(self, sorted=True, relativize=True, nl=None, + def to_text(self, sorted=True, relativize=True, nl: Optional[str]=None, want_comments=False, want_origin=False): """Return a zone's text as though it were written to a file. @@ -713,7 +706,7 @@ class Zone(dns.transaction.TransactionManager): if self.get_rdataset(name, dns.rdatatype.NS) is None: raise NoNS - def get_soa(self, txn=None): + def get_soa(self, txn: Optional[dns.transaction.Transaction]=None): """Get the zone SOA RR. Raises ``dns.zone.NoSOA`` if there is no SOA RRset. @@ -723,7 +716,12 @@ class Zone(dns.transaction.TransactionManager): if self.relativize: origin_name = dns.name.empty else: + if self.origin is None: + # get_soa() has been called very early, and there must not be + # an SOA if there is no origin. + raise NoSOA origin_name = self.origin + soa: Optional[dns.rdataset.Rdataset] if txn: soa = txn.get(origin_name, dns.rdatatype.SOA) else: @@ -732,7 +730,7 @@ class Zone(dns.transaction.TransactionManager): raise NoSOA return soa[0] - def _compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE): + def _compute_digest(self, hash_algorithm: DigestHashAlgorithm, scheme=DigestScheme.SIMPLE) -> bytes: hashinfo = _digest_hashers.get(hash_algorithm) if not hashinfo: raise UnsupportedDigestHashAlgorithm @@ -742,6 +740,7 @@ class Zone(dns.transaction.TransactionManager): if self.relativize: origin_name = dns.name.empty else: + assert self.origin is not None origin_name = self.origin hasher = hashinfo() for (name, node) in sorted(self.items()): @@ -760,11 +759,7 @@ class Zone(dns.transaction.TransactionManager): hasher.update(rrnamebuf + rrfixed + rrlen + rdata) return hasher.digest() - def compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE): - if self.relativize: - origin_name = dns.name.empty - else: - origin_name = self.origin + def compute_digest(self, hash_algorithm: DigestHashAlgorithm, scheme=DigestScheme.SIMPLE) -> dns.rdtypes.ANY.ZONEMD.ZONEMD: serial = self.get_soa().serial digest = self._compute_digest(hash_algorithm, scheme) return dns.rdtypes.ANY.ZONEMD.ZONEMD(self.rdclass, @@ -772,13 +767,15 @@ class Zone(dns.transaction.TransactionManager): serial, scheme, hash_algorithm, digest) - def verify_digest(self, zonemd=None): + def verify_digest(self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD]=None): + digests: Union[dns.rdataset.Rdataset, List[dns.rdtypes.ANY.ZONEMD.ZONEMD]] if zonemd: digests = [zonemd] else: - digests = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD) - if digests is None: + rds = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD) + if rds is None: raise NoDigest + digests = rds for digest in digests: try: computed = self._compute_digest(digest.hash_algorithm, @@ -791,16 +788,17 @@ class Zone(dns.transaction.TransactionManager): # TransactionManager methods - def reader(self): + def reader(self) -> 'Transaction': return Transaction(self, False, Version(self, 1, self.nodes, self.origin)) - def writer(self, replacement=False): + def writer(self, replacement=False) -> 'Transaction': txn = Transaction(self, replacement) txn._setup_version() return txn - def origin_information(self): + def origin_information(self) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]: + effective: Optional[dns.name.Name] if self.relativize: effective = dns.name.empty else: @@ -878,7 +876,9 @@ class ImmutableVersionedNode(VersionedNode): class Version: - def __init__(self, zone, id, nodes=None, origin=None): + def __init__(self, zone: Zone, id: int, + nodes: Optional[Dict[dns.name.Name, dns.node.Node]]=None, + origin: Optional[dns.name.Name]=None): self.zone = zone self.id = id if nodes is not None: @@ -887,7 +887,7 @@ class Version: self.nodes = {} self.origin = origin - def _validate_name(self, name): + def _validate_name(self, name: dns.name.Name): if name.is_absolute(): if self.origin is None: # This should probably never happen as other code (e.g. @@ -898,13 +898,19 @@ class Version: raise KeyError("name is not a subdomain of the zone origin") if self.zone.relativize: name = name.relativize(self.origin) + elif not self.zone.relativize: + # We have a relative name in a non-relative zone, so derelativize. + if self.origin is None: + raise KeyError('no zone origin is defined') + name = name.derelativize(self.origin) return name - def get_node(self, name): + def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]: name = self._validate_name(name) return self.nodes.get(name) - def get_rdataset(self, name, rdtype, covers): + def get_rdataset(self, name: dns.name.Name, rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType) -> Optional[dns.rdataset.Rdataset]: node = self.get_node(name) if node is None: return None @@ -915,7 +921,7 @@ class Version: class WritableVersion(Version): - def __init__(self, zone, replacement=False): + def __init__(self, zone: Zone, replacement=False): # The zone._versions_lock must be held by our caller in a versioned # zone. id = zone._get_next_version_id() @@ -929,9 +935,9 @@ class WritableVersion(Version): # We have to copy the zone origin as it may be None in the first # version, and we don't want to mutate the zone until we commit. self.origin = zone.origin - self.changed = set() + self.changed: Set[dns.name.Name] = set() - def _maybe_cow(self, name): + def _maybe_cow(self, name: dns.name.Name): name = self._validate_name(name) node = self.nodes.get(name) if node is None or name not in self.changed: @@ -941,7 +947,9 @@ class WritableVersion(Version): # code used new_node.id != self.id for the "do we need to CoW?" # test. Now we use the changed set as this works with both # regular zones and versioned zones. - new_node.id = self.id + # + # We ignore the mypy error as this is safe but it doesn't see it. + new_node.id = self.id # type: ignore if node is not None: # moo! copy on write! new_node.rdatasets.extend(node.rdatasets) @@ -951,17 +959,18 @@ class WritableVersion(Version): else: return node - def delete_node(self, name): + def delete_node(self, name: dns.name.Name): name = self._validate_name(name) if name in self.nodes: del self.nodes[name] self.changed.add(name) - def put_rdataset(self, name, rdataset): + def put_rdataset(self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset): node = self._maybe_cow(name) node.replace_rdataset(rdataset) - def delete_rdataset(self, name, rdtype, covers): + def delete_rdataset(self, name: dns.name.Name, rdtype:dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType): node = self._maybe_cow(name) node.delete_rdataset(self.zone.rdclass, rdtype, covers) if len(node) == 0: @@ -970,7 +979,7 @@ class WritableVersion(Version): @dns.immutable.immutable class ImmutableVersion(Version): - def __init__(self, version): + def __init__(self, version: WritableVersion): # We tell super() that it's a replacement as we don't want it # to copy the nodes, as we're about to do that with an # immutable Dict. @@ -985,7 +994,9 @@ class ImmutableVersion(Version): # it might not exist if we deleted it in the version if node: version.nodes[name] = ImmutableVersionedNode(node) - self.nodes = dns.immutable.Dict(version.nodes, True) + # We're changing the type of the nodes dictionary here on purpose, so + # we ignore the mypy error. + self.nodes = dns.immutable.Dict(version.nodes, True) # type: ignore class Transaction(dns.transaction.Transaction): @@ -1066,9 +1077,11 @@ class Transaction(dns.transaction.Transaction): return (absolute, relativize, effective) -def from_text(text, origin=None, rdclass=dns.rdataclass.IN, - relativize=True, zone_factory=Zone, filename=None, - allow_include=False, check_origin=True, idna_codec=None): +def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None, + rdclass=dns.rdataclass.IN, + relativize=True, zone_factory=Zone, filename: Optional[str]=None, + allow_include=False, check_origin=True, + idna_codec: Optional[dns.name.IDNACodec]=None) -> Zone: """Build a zone object from a zone file format string. *text*, a ``str``, the zone file format input. @@ -1077,7 +1090,7 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN, of the zone; if not specified, the first ``$ORIGIN`` statement in the zone file will determine the origin of the zone. - *rdclass*, an ``int``, the zone's rdata class; the default is class IN. + *rdclass*, a ``dns.rdataclass.RdataClass``, the zone's rdata class; the default is class IN. *relativize*, a ``bool``, determine's whether domain names are relativized to the zone's origin. The default is ``True``. @@ -1132,9 +1145,10 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN, return zone -def from_file(f, origin=None, rdclass=dns.rdataclass.IN, - relativize=True, zone_factory=Zone, filename=None, - allow_include=True, check_origin=True): +def from_file(f: Any, origin: Optional[Union[dns.name.Name, str]]=None, + rdclass=dns.rdataclass.IN, + relativize=True, zone_factory=Zone, filename: Optional[str]=None, + allow_include=True, check_origin=True) -> Zone: """Read a zone file and build a zone object. *f*, a file or ``str``. If *f* is a string, it is treated @@ -1184,6 +1198,7 @@ def from_file(f, origin=None, rdclass=dns.rdataclass.IN, f = stack.enter_context(open(f)) return from_text(f, origin, rdclass, relativize, zone_factory, filename, allow_include, check_origin) + assert False # make mypy happy def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True): diff --git a/dns/zone.pyi b/dns/zone.pyi deleted file mode 100644 index 272814fe..00000000 --- a/dns/zone.pyi +++ /dev/null @@ -1,55 +0,0 @@ -from typing import Generator, Optional, Union, Tuple, Iterable, Callable, Any, Iterator, TextIO, BinaryIO, Dict -from . import rdata, zone, rdataclass, name, rdataclass, message, rdatatype, exception, node, rdataset, rrset, rdatatype - -class BadZone(exception.DNSException): ... -class NoSOA(BadZone): ... -class NoNS(BadZone): ... -class UnknownOrigin(BadZone): ... - -class Zone: - def __getitem__(self, key : str) -> node.Node: - ... - def __init__(self, origin : Union[str,name.Name], rdclass : int = rdataclass.IN, relativize : bool = True) -> None: - self.nodes : Dict[str,node.Node] - self.origin = origin - def values(self): - return self.nodes.values() - def iterate_rdatas(self, rdtype : Union[int,str] = rdatatype.ANY, covers : Union[int,str] = None) -> Iterable[Tuple[name.Name, int, rdata.Rdata]]: - ... - def __iter__(self) -> Iterator[str]: - ... - def get_node(self, name : Union[name.Name,str], create=False) -> Optional[node.Node]: - ... - def find_rrset(self, name : Union[str,name.Name], rdtype : Union[int,str], covers=rdatatype.NONE) -> rrset.RRset: - ... - def find_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE, - create=False) -> rdataset.Rdataset: - ... - def get_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE, create=False) -> Optional[rdataset.Rdataset]: - ... - def get_rrset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> Optional[rrset.RRset]: - ... - def replace_rdataset(self, name : Union[str,name.Name], replacement : rdataset.Rdataset) -> None: - ... - def delete_rdataset(self, name : Union[str,name.Name], rdtype : Union[str,int], covers=rdatatype.NONE) -> None: - ... - def iterate_rdatasets(self, rdtype : Union[str,int] =rdatatype.ANY, - covers : Union[str,int] =rdatatype.NONE): - ... - def to_file(self, f : Union[TextIO, BinaryIO, str], sorted=True, relativize=True, nl : Optional[bytes] = None): - ... - def to_text(self, sorted=True, relativize=True, nl : Optional[str] = None) -> str: - ... - -def from_xfr(xfr : Generator[Any,Any,message.Message], zone_factory : Callable[..., zone.Zone] = zone.Zone, relativize=True, check_origin=True): - ... - -def from_text(text : str, origin : Optional[Union[str,name.Name]] = None, rdclass : int = rdataclass.IN, - relativize=True, zone_factory : Callable[...,zone.Zone] = zone.Zone, filename : Optional[str] = None, - allow_include=False, check_origin=True) -> zone.Zone: - ... - -def from_file(f, origin : Optional[Union[str,name.Name]] = None, rdclass=rdataclass.IN, - relativize=True, zone_factory : Callable[..., zone.Zone] = Zone, filename : Optional[str] = None, - allow_include=True, check_origin=True) -> zone.Zone: - ... diff --git a/dns/zonefile.py b/dns/zonefile.py index 53b40880..605131dc 100644 --- a/dns/zonefile.py +++ b/dns/zonefile.py @@ -17,6 +17,8 @@ """DNS Zones.""" +from typing import Any, List, Optional, Tuple, Union + import re import sys @@ -61,14 +63,27 @@ def _check_cname_and_other_data(txn, name, rdataset): # adding the rdataset is ok +SavedStateType = Tuple[dns.tokenizer.Tokenizer, + Optional[dns.name.Name], # current_origin + Optional[dns.name.Name], # last_name + Optional[str], # current_file + int, # last_ttl + bool, # last_ttl_known + int, # default_ttl + bool] # default_ttl_known + + class Reader: """Read a DNS zone file into a transaction.""" - def __init__(self, tok, rdclass, txn, allow_include=False, - allow_directives=True, force_name=None, - force_ttl=None, force_rdclass=None, force_rdtype=None, - default_ttl=None): + def __init__(self, tok: dns.tokenizer.Tokenizer, rdclass: dns.rdataclass.RdataClass, + txn: dns.transaction.Transaction, allow_include=False, + allow_directives=True, force_name: Optional[dns.name.Name]=None, + force_ttl: Optional[int]=None, + force_rdclass: Optional[dns.rdataclass.RdataClass]=None, + force_rdtype: Optional[dns.rdatatype.RdataType]=None, + default_ttl: Optional[int]=None): self.tok = tok (self.zone_origin, self.relativize, _) = \ txn.manager.origin_information() @@ -86,7 +101,7 @@ class Reader: self.last_name = self.current_origin self.zone_rdclass = rdclass self.txn = txn - self.saved_state = [] + self.saved_state: List[SavedStateType] = [] self.current_file = None self.allow_include = allow_include self.allow_directives = allow_directives @@ -548,10 +563,16 @@ class RRSetsReaderManager(dns.transaction.TransactionManager): self.rrsets = rrsets -def read_rrsets(text, name=None, ttl=None, rdclass=dns.rdataclass.IN, - default_rdclass=dns.rdataclass.IN, - rdtype=None, default_ttl=None, idna_codec=None, - origin=dns.name.root, relativize=False): +def read_rrsets(text: Any, + name: Optional[Union[dns.name.Name, str]]=None, + ttl: Optional[int]=None, + rdclass: Optional[Union[dns.rdataclass.RdataClass, str]]=dns.rdataclass.IN, + default_rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, + rdtype: Optional[Union[dns.rdatatype.RdataType, str]]=None, + default_ttl: Optional[Union[int, str]]=None, + idna_codec: Optional[dns.name.IDNACodec]=None, + origin: Optional[Union[dns.name.Name, str]]=dns.name.root, + relativize=False) -> List[dns.rrset.RRset]: """Read one or more rrsets from the specified text, possibly subject to restrictions. @@ -610,15 +631,19 @@ def read_rrsets(text, name=None, ttl=None, rdclass=dns.rdataclass.IN, if isinstance(default_ttl, str): default_ttl = dns.ttl.from_text(default_ttl) if rdclass is not None: - rdclass = dns.rdataclass.RdataClass.make(rdclass) - default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) + the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + else: + the_rdclass = None + the_default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass) if rdtype is not None: - rdtype = dns.rdatatype.RdataType.make(rdtype) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + else: + the_rdtype = None manager = RRSetsReaderManager(origin, relativize, default_rdclass) with manager.writer(True) as txn: tok = dns.tokenizer.Tokenizer(text, '', idna_codec=idna_codec) - reader = Reader(tok, default_rdclass, txn, allow_directives=False, - force_name=name, force_ttl=ttl, force_rdclass=rdclass, - force_rdtype=rdtype, default_ttl=default_ttl) + reader = Reader(tok, the_default_rdclass, txn, allow_directives=False, + force_name=name, force_ttl=ttl, force_rdclass=the_rdclass, + force_rdtype=the_rdtype, default_ttl=default_ttl) reader.read() return manager.rrsets diff --git a/dns/zonetypes.py b/dns/zonetypes.py new file mode 100644 index 00000000..195ee2ec --- /dev/null +++ b/dns/zonetypes.py @@ -0,0 +1,37 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""Common zone-related types.""" + +# This is a separate file to avoid import circularity between dns.zone and +# the implementation of the ZONEMD type. + +import hashlib + +import dns.enum + + +class DigestScheme(dns.enum.IntEnum): + """ZONEMD Scheme""" + + SIMPLE = 1 + + @classmethod + def _maximum(cls): + return 255 + + +class DigestHashAlgorithm(dns.enum.IntEnum): + """ZONEMD Hash Algorithm""" + + SHA384 = 1 + SHA512 = 2 + + @classmethod + def _maximum(cls): + return 255 + + +_digest_hashers = { + DigestHashAlgorithm.SHA384: hashlib.sha384, + DigestHashAlgorithm.SHA512: hashlib.sha512, +} diff --git a/doc/manual.rst b/doc/manual.rst index 19107cb5..348deefb 100644 --- a/doc/manual.rst +++ b/doc/manual.rst @@ -16,6 +16,5 @@ Dnspython Manual async exceptions utilities - typing threads examples diff --git a/doc/typing.rst b/doc/typing.rst deleted file mode 100644 index 1325f10e..00000000 --- a/doc/typing.rst +++ /dev/null @@ -1,10 +0,0 @@ -.. _typing: - -A Note on Typing ----------------- - -Dnspython has partial support for type annotations in separate .pyi -files. Type information will not be integrated into the main files -until major LTS versions of various Linux distributions containing 3.6 -are beyond their support times. Improvements to the .pyi files are -welcome during this time. diff --git a/mypy.ini b/mypy.ini index a0ba7e30..de66885a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,3 +2,9 @@ [mypy-requests_toolbelt.*] ignore_missing_imports = True + +[mypy-curio] +ignore_missing_imports = True + +[mypy-trio] +ignore_missing_imports = True diff --git a/tests/test_dnssec.py b/tests/test_dnssec.py index d4d76275..e8189925 100644 --- a/tests/test_dnssec.py +++ b/tests/test_dnssec.py @@ -15,6 +15,8 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +from typing import Any + import unittest import dns.dnssec @@ -22,6 +24,8 @@ import dns.name import dns.rdata import dns.rdataclass import dns.rdatatype +import dns.rdtypes.ANY.CDS +import dns.rdtypes.ANY.DS import dns.rrset # pylint: disable=line-too-long @@ -472,6 +476,7 @@ class DNSSECMakeDSTestCase(unittest.TestCase): self.assertEqual(good_ds, good_ds_mnemonic) def testMakeExampleSHA1DS(self): # type: () -> None + algorithm: Any for algorithm in ('SHA1', 'sha1', dns.dnssec.DSDigest.SHA1): ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) self.assertEqual(ds, example_ds_sha1) @@ -479,11 +484,13 @@ class DNSSECMakeDSTestCase(unittest.TestCase): self.assertEqual(ds, example_ds_sha1) def testMakeExampleSHA256DS(self): # type: () -> None + algorithm: Any for algorithm in ('SHA256', 'sha256', dns.dnssec.DSDigest.SHA256): ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) self.assertEqual(ds, example_ds_sha256) def testMakeExampleSHA384DS(self): # type: () -> None + algorithm: Any for algorithm in ('SHA384', 'sha384', dns.dnssec.DSDigest.SHA384): ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) self.assertEqual(ds, example_ds_sha384) @@ -493,6 +500,7 @@ class DNSSECMakeDSTestCase(unittest.TestCase): self.assertEqual(ds, good_ds) def testInvalidAlgorithm(self): # type: () -> None + algorithm: Any for algorithm in (10, 'shax'): with self.assertRaises(dns.dnssec.UnsupportedAlgorithm): ds = dns.dnssec.make_ds(abs_example, example_sep_key, algorithm) @@ -508,6 +516,7 @@ class DNSSECMakeDSTestCase(unittest.TestCase): for rdtype in digest_types: rd = dns.rdata.from_text(dns.rdataclass.IN, rdtype, f'18673 3 5 71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7') + assert isinstance(rd, dns.rdtypes.ANY.DS.DS) or isinstance(rd, dns.rdtypes.ANY.CDS.CDS) self.assertEqual(rd.digest_type, 5) self.assertEqual(rd.digest, bytes.fromhex('71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7')) diff --git a/tests/test_name.py b/tests/test_name.py index f91d7e65..45f83793 100644 --- a/tests/test_name.py +++ b/tests/test_name.py @@ -89,7 +89,7 @@ class NameTestCase(unittest.TestCase): try: dns.name.from_text(t) except Exception: - self.fail("good test '%s' raised an exception" % t) + self.fail("good test '%r' raised an exception" % t) for t in bad: caught = False try: @@ -97,7 +97,7 @@ class NameTestCase(unittest.TestCase): except Exception: caught = True if not caught: - self.fail("bad test '%s' did not raise an exception" % t) + self.fail("bad test '%r' did not raise an exception" % t) def testImmutable1(self): def bad(): @@ -106,7 +106,7 @@ class NameTestCase(unittest.TestCase): def testImmutable2(self): def bad(): - self.origin.labels[0] = 'foo' + self.origin.labels[0] = 'foo' # type: ignore self.assertRaises(TypeError, bad) def testAbs1(self): @@ -879,7 +879,7 @@ class NameTestCase(unittest.TestCase): def testReverseIPv6(self): e = dns.name.from_text('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.') - n = dns.reversename.from_address(b'::1') + n = dns.reversename.from_address('::1') self.assertEqual(e, n) def testReverseIPv6MappedIpv4(self): @@ -906,7 +906,7 @@ class NameTestCase(unittest.TestCase): def testReverseIPv6AlternateOrigin(self): e = dns.name.from_text('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.foo.bar.') origin = dns.name.from_text('foo.bar') - n = dns.reversename.from_address(b'::1', v6_origin=origin) + n = dns.reversename.from_address('::1', v6_origin=origin) self.assertEqual(e, n) def testForwardIPv4(self): @@ -980,12 +980,12 @@ class NameTestCase(unittest.TestCase): def testFromUnicodeNotString(self): def bad(): - dns.name.from_unicode(b'123') + dns.name.from_unicode(b'123') # type: ignore self.assertRaises(ValueError, bad) def testFromUnicodeBadOrigin(self): def bad(): - dns.name.from_unicode('example', 123) + dns.name.from_unicode('example', 123) # type: ignore self.assertRaises(ValueError, bad) def testFromUnicodeEmptyLabel(self): @@ -998,17 +998,17 @@ class NameTestCase(unittest.TestCase): def testFromTextNotString(self): def bad(): - dns.name.from_text(123) + dns.name.from_text(123) # type: ignore self.assertRaises(ValueError, bad) def testFromTextBadOrigin(self): def bad(): - dns.name.from_text('example', 123) + dns.name.from_text('example', 123) # type: ignore self.assertRaises(ValueError, bad) def testFromWireNotBytes(self): def bad(): - dns.name.from_wire(123, 0) + dns.name.from_wire(123, 0) # type: ignore self.assertRaises(ValueError, bad) def testBadPunycode(self): @@ -1035,7 +1035,7 @@ class NameTestCase(unittest.TestCase): c.encode('Königsgäßchen') with self.assertRaises(dns.name.NoIDNA2008): c = dns.name.IDNA2008Codec(strict_decode=True) - c.decode('xn--eckwd4c7c.xn--zckzah.') + c.decode(b'xn--eckwd4c7c.xn--zckzah.') dns.name.have_idna_2008 = True @unittest.skipUnless(dns.name.have_idna_2008, diff --git a/tests/test_processing_order.py b/tests/test_processing_order.py index 4be695a0..76754dde 100644 --- a/tests/test_processing_order.py +++ b/tests/test_processing_order.py @@ -1,6 +1,7 @@ import dns.rdata import dns.rdataset +import dns.rdtypes.IN.SRV def test_processing_order_shuffle(): @@ -42,6 +43,7 @@ def test_processing_order_priority_weighted(): for j in range(3): assert rds[j] in po assert rds[0] == po[0] + assert isinstance(po[1], dns.rdtypes.IN.SRV.SRV) if po[1].weight == 90: weight_90_count += 1 else: diff --git a/tests/test_zone.py b/tests/test_zone.py index 88b1e58e..473c7333 100644 --- a/tests/test_zone.py +++ b/tests/test_zone.py @@ -499,6 +499,13 @@ class ZoneTestCase(unittest.TestCase): rds = z.get_rdataset('@', 'loc') self.assertTrue(rds is None) + def testGetRdatasetWithRelativeNameFromAbsoluteZone(self): + z = dns.zone.from_text(example_text, 'example.', relativize=False) + rds = z.get_rdataset(dns.name.empty, 'soa') + self.assertIsNotNone(rds) + exrds = dns.rdataset.from_text('IN', 'SOA', 300, 'foo.example. bar.example. 1 2 3 4 5') + self.assertEqual(rds, exrds) + def testGetRRset1(self): z = dns.zone.from_text(example_text, 'example.', relativize=True) rrs = z.get_rrset('@', 'soa') @@ -1077,7 +1084,6 @@ class VersionedZoneTestCase(unittest.TestCase): self.assertTrue(soa.rdtype, dns.rdatatype.SOA) self.assertEqual(soa.serial, 1) - def testGetSoaEmptyZone(self): z = dns.zone.Zone('example.') with self.assertRaises(dns.zone.NoSOA):