From 8b9b6166821b193c3150d1bbffe11ae8ed8a2664 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Wed, 9 Mar 2022 07:53:53 -0800 Subject: [PATCH] Typing pass number 2, featuring typing of bools, adding a return type of "-> None" to procedures, and various fixes for omissions, errors, and new issues discovered by type checking previously unchecked things. --- dns/asyncquery.py | 44 +++++++++--------- dns/asyncresolver.py | 25 ++++++----- dns/dnssec.py | 13 +++--- dns/e164.py | 2 +- dns/edns.py | 20 ++++----- dns/entropy.py | 10 ++--- dns/immutable.py | 6 ++- dns/ipv6.py | 4 +- dns/message.py | 74 +++++++++++++++++------------- dns/name.py | 14 +++--- dns/node.py | 29 +++++++----- dns/query.py | 67 +++++++++++++++------------- dns/rcode.py | 20 +++++---- dns/rdata.py | 41 ++++++++--------- dns/rdataclass.py | 10 ++--- dns/rdataset.py | 12 ++--- dns/rdatatype.py | 16 +++---- dns/resolver.py | 103 ++++++++++++++++++++++-------------------- dns/reversename.py | 8 ++-- dns/rrset.py | 10 ++--- dns/serial.py | 2 +- dns/tokenizer.py | 31 ++++++------- dns/transaction.py | 34 +++++++------- dns/update.py | 26 ++++++----- dns/versioned.py | 38 ++++++++++------ dns/wire.py | 10 ++--- dns/xfr.py | 12 ++--- dns/zone.py | 104 ++++++++++++++++++++++++------------------- dns/zonefile.py | 18 ++++---- 29 files changed, 436 insertions(+), 367 deletions(-) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 950624a1..c785764d 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -97,9 +97,9 @@ async def send_udp(sock: dns.asyncbackend.DatagramSocket, async def receive_udp(sock: dns.asyncbackend.DatagramSocket, destination: Optional[Any]=None, expiration: Optional[float]=None, - ignore_unexpected=False, one_rr_per_rrset=False, + ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'', - ignore_trailing=False, raise_on_truncation=False) -> Any: + ignore_trailing: bool=False, raise_on_truncation: bool=False) -> Any: """Read a DNS message from a UDP socket. *sock*, a ``dns.asyncbackend.DatagramSocket``. @@ -121,10 +121,10 @@ async def receive_udp(sock: dns.asyncbackend.DatagramSocket, raise_on_truncation=raise_on_truncation) return (r, received_time, from_address) -async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, - source: Optional[str]=None, source_port=0, - ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, - raise_on_truncation=False, sock: Optional[dns.asyncbackend.DatagramSocket]=None, +async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, + source: Optional[str]=None, source_port: int=0, + ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False, + raise_on_truncation: bool=False, sock: Optional[dns.asyncbackend.DatagramSocket]=None, backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message: """Return the response obtained after sending a query via UDP. @@ -174,9 +174,9 @@ async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, if not sock and s: await s.close() -async def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, - source: Optional[str]=None, source_port=0, - ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, +async def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, + source: Optional[str]=None, source_port: int=0, + ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False, udp_sock: Optional[dns.asyncbackend.DatagramSocket]=None, tcp_sock: Optional[dns.asyncbackend.StreamSocket]=None, backend: Optional[dns.asyncbackend.Backend]=None) -> Tuple[dns.message.Message, bool]: @@ -252,9 +252,9 @@ async def _read_exactly(sock, count, expiration): async def receive_tcp(sock: dns.asyncbackend.StreamSocket, - expiration: Optional[float]=None, one_rr_per_rrset=False, + expiration: Optional[float]=None, one_rr_per_rrset: bool=False, keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, - request_mac=b'', ignore_trailing=False) -> Tuple[dns.message.Message, float]: + request_mac=b'', ignore_trailing: bool=False) -> Tuple[dns.message.Message, float]: """Read a DNS message from a TCP socket. *sock*, a ``dns.asyncbackend.StreamSocket``. @@ -273,9 +273,9 @@ async def receive_tcp(sock: dns.asyncbackend.StreamSocket, return (r, received_time) -async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, - source: Optional[str]=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, +async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, + source: Optional[str]=None, source_port: int=0, + one_rr_per_rrset: bool=False, ignore_trailing: bool=False, sock: Optional[dns.asyncbackend.StreamSocket]=None, backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message: """Return the response obtained after sending a query via TCP. @@ -328,8 +328,8 @@ async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, await s.close() async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, - port=853, source: Optional[str]=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, + port: int=853, source: Optional[str]=None, source_port: int=0, + one_rr_per_rrset: bool=False, ignore_trailing: bool=False, sock: Optional[dns.asyncbackend.StreamSocket]=None, backend: Optional[dns.asyncbackend.Backend]=None, ssl_context: Optional[ssl.SSLContext]=None, @@ -383,10 +383,10 @@ async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, await s.close() async def https(q: dns.message.Message, where: str, timeout: Optional[float]=None, - port=443, source: Optional[str]=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, + port: int=443, source: Optional[str]=None, source_port: int=0, + one_rr_per_rrset: bool=False, ignore_trailing: bool=False, client: Optional[httpx.AsyncClient]=None, - path='/dns-query', post=True, verify=True): + path: str='/dns-query', post: bool=True, verify: bool=True) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. *client*, a ``httpx.AsyncClient``. If provided, the client to use for @@ -466,9 +466,9 @@ async def https(q: dns.message.Message, where: str, timeout: Optional[float]=Non async def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager, query: Optional[dns.message.Message]=None, - port=53, timeout: Optional[float]=None, lifetime: Optional[float]=None, - source: Optional[str]=None, source_port=0, udp_mode=UDPMode.NEVER, - backend: Optional[dns.asyncbackend.Backend]=None): + port: int=53, timeout: Optional[float]=None, lifetime: Optional[float]=None, + source: Optional[str]=None, source_port: int=0, udp_mode=UDPMode.NEVER, + backend: Optional[dns.asyncbackend.Backend]=None) -> None: """Conduct an inbound transfer and apply it via a transaction from the txn_manager. diff --git a/dns/asyncresolver.py b/dns/asyncresolver.py index 72ef0412..152b1a63 100644 --- a/dns/asyncresolver.py +++ b/dns/asyncresolver.py @@ -26,6 +26,8 @@ import dns.asyncquery import dns.exception import dns.name import dns.query +import dns.rdataclass +import dns.rdatatype import dns.resolver # lgtm[py/import-and-import-from] # import some resolver symbols for brevity @@ -41,10 +43,10 @@ class Resolver(dns.resolver.BaseResolver): """Asynchronous DNS stub resolver.""" async def resolve(self, qname: Union[dns.name.Name, str], - rdtype=dns.rdatatype.A, - rdclass=dns.rdataclass.IN, - tcp=False, source: Optional[str]=None, - raise_on_no_answer=True, source_port=0, + rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, + tcp: bool=False, source: Optional[str]=None, + raise_on_no_answer: bool=True, source_port: int=0, lifetime: Optional[float]=None, search: Optional[bool]=None, backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer: """Query nameservers asynchronously to find the answer to the question. @@ -167,7 +169,7 @@ def get_default_resolver() -> Resolver: return default_resolver -def reset_default_resolver(): +def reset_default_resolver() -> None: """Re-initialize default asynchronous resolver. Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX @@ -179,10 +181,10 @@ def reset_default_resolver(): async def resolve(qname: Union[dns.name.Name, str], - rdtype=dns.rdatatype.A, - rdclass=dns.rdataclass.IN, - tcp=False, source: Optional[str]=None, - raise_on_no_answer=True, source_port=0, + rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, + rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, + tcp: bool=False, source: Optional[str]=None, + raise_on_no_answer: bool=True, source_port: int=0, lifetime: Optional[float]=None, search: Optional[bool]=None, backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer: """Query nameservers asynchronously to find the answer to the question. @@ -218,8 +220,9 @@ async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: return await get_default_resolver().canonical_name(name) -async def zone_for_name(name: Union[dns.name.Name, str], rdclass=dns.rdataclass.IN, - tcp=False, resolver: Optional[Resolver]=None, +async def zone_for_name(name: Union[dns.name.Name, str], + rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, + tcp: bool=False, resolver: Optional[Resolver]=None, backend: Optional[dns.asyncbackend.Backend]=None) -> dns.name.Name: """Find the name of the zone which contains the specified name. diff --git a/dns/dnssec.py b/dns/dnssec.py index 810d12de..598734b9 100644 --- a/dns/dnssec.py +++ b/dns/dnssec.py @@ -221,7 +221,7 @@ def _bytes_to_long(b: bytes) -> int: return int.from_bytes(b, 'big') -def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any): +def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None: keyptr: bytes if _is_rsa(key.algorithm): # we ignore because mypy is confused and thinks key.key is a str for unknown reasons. @@ -304,7 +304,7 @@ def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any): def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], rrsig: RRSIG, keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], - origin: Optional[dns.name.Name]=None, now: Optional[float]=None): + origin: Optional[dns.name.Name]=None, now: Optional[float]=None) -> None: """Validate an RRset against a single signature rdata, throwing an exception if validation is not successful. @@ -416,7 +416,7 @@ def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdata def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], rrsigset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]], keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]], - origin: Optional[dns.name.Name]=None, now: Optional[float]=None): + origin: Optional[dns.name.Name]=None, now: Optional[float]=None) -> None: """Validate an RRset against a signature RRset, throwing an exception if none of the signatures validate. @@ -476,7 +476,8 @@ def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rd raise ValidationFailure("no RRSIGs validated") -def nsec3_hash(domain, salt, iterations, algorithm): +def nsec3_hash(domain: Union[dns.name.Name, str], salt: Optional[Union[str, bytes]], + iterations: int, algorithm: Union[int, str]) -> str: """ Calculate the NSEC3 hash, according to https://tools.ietf.org/html/rfc5155#section-5 @@ -507,7 +508,6 @@ def nsec3_hash(domain, salt, iterations, algorithm): if algorithm != NSEC3Hash.SHA1: raise ValueError("Wrong hash algorithm (only SHA1 is supported)") - salt_encoded = salt if salt is None: salt_encoded = b'' elif isinstance(salt, str): @@ -515,10 +515,13 @@ def nsec3_hash(domain, salt, iterations, algorithm): salt_encoded = bytes.fromhex(salt) else: raise ValueError("Invalid salt length") + else: + salt_encoded = salt if not isinstance(domain, dns.name.Name): domain = dns.name.from_text(domain) domain_encoded = domain.canonicalize().to_wire() + assert domain_encoded is not None digest = hashlib.sha1(domain_encoded + salt_encoded).digest() for _ in range(iterations): diff --git a/dns/e164.py b/dns/e164.py index 8c9a3ac5..6e34ae5d 100644 --- a/dns/e164.py +++ b/dns/e164.py @@ -48,7 +48,7 @@ def from_e164(text: str, origin: Optional[dns.name.Name]=public_enum_domain) -> def to_e164(name: dns.name.Name, origin: Optional[dns.name.Name]=public_enum_domain, - want_plus_prefix=True) -> str: + want_plus_prefix: bool=True) -> str: """Convert an ENUM domain name into an E.164 number. Note that dnspython does not have any information about preferred diff --git a/dns/edns.py b/dns/edns.py index 15c646de..b47b6d24 100644 --- a/dns/edns.py +++ b/dns/edns.py @@ -228,7 +228,7 @@ class ECSOption(Option): # lgtm[py/missing-equals] self.scopelen) @staticmethod - def from_text(text) -> Option: + def from_text(text: str) -> Option: """Convert a string into a `dns.edns.ECSOption` *text*, a `str`, the text form of the option. @@ -264,25 +264,25 @@ class ECSOption(Option): # lgtm[py/missing-equals] raise ValueError('could not parse ECS from "{}"'.format(text)) n_slashes = ecs_text.count('/') if n_slashes == 1: - address, srclen = ecs_text.split('/') - scope = 0 + address, tsrclen = ecs_text.split('/') + tscope = '0' elif n_slashes == 2: - address, srclen, scope = ecs_text.split('/') + address, tsrclen, tscope = ecs_text.split('/') else: raise ValueError('could not parse ECS from "{}"'.format(text)) try: - scope = int(scope) + scope = int(tscope) except ValueError: raise ValueError('invalid scope ' + - '"{}": scope must be an integer'.format(scope)) + '"{}": scope must be an integer'.format(tscope)) try: - srclen = int(srclen) + srclen = int(tsrclen) except ValueError: raise ValueError('invalid srclen ' + - '"{}": srclen must be an integer'.format(srclen)) + '"{}": srclen must be an integer'.format(tsrclen)) return ECSOption(address, srclen, scope) - def to_wire(self, file=None) -> Optional[bytes]: + def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]: value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) + self.addrdata) if file: @@ -442,7 +442,7 @@ def option_from_wire(otype: Union[OptionType, str], wire: bytes, current: int, o with parser.restrict_to(olen): return option_from_wire_parser(otype, parser) -def register_type(implementation: Any, otype: OptionType): +def register_type(implementation: Any, otype: OptionType) -> None: """Register the implementation of an option type. *implementation*, a ``class``, is a subclass of ``dns.edns.Option``. diff --git a/dns/entropy.py b/dns/entropy.py index b5d34971..7da2e04a 100644 --- a/dns/entropy.py +++ b/dns/entropy.py @@ -34,7 +34,7 @@ class EntropyPool: # leaving this code doesn't hurt anything as the library code # is used if present. - def __init__(self, seed=None): + def __init__(self, seed: Optional[bytes]=None): self.pool_index = 0 self.digest: Optional[bytearray] = None self.next_byte = 0 @@ -43,14 +43,14 @@ class EntropyPool: self.hash_len = 20 self.pool = bytearray(b'\0' * self.hash_len) if seed is not None: - self._stir(bytearray(seed)) + self._stir(seed) self.seeded = True self.seed_pid = os.getpid() else: self.seeded = False self.seed_pid = 0 - def _stir(self, entropy): + def _stir(self, entropy: bytes) -> None: for c in entropy: if self.pool_index == self.hash_len: self.pool_index = 0 @@ -58,11 +58,11 @@ class EntropyPool: self.pool[self.pool_index] ^= b self.pool_index += 1 - def stir(self, entropy): + def stir(self, entropy: bytes) -> None: with self.lock: self._stir(entropy) - def _maybe_seed(self): + def _maybe_seed(self) -> None: if not self.seeded or self.seed_pid != os.getpid(): try: seed = os.urandom(16) diff --git a/dns/immutable.py b/dns/immutable.py index 20da7d90..8a426210 100644 --- a/dns/immutable.py +++ b/dns/immutable.py @@ -1,5 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license +from typing import Any + import collections.abc from dns._immutable_ctx import immutable @@ -7,7 +9,7 @@ from dns._immutable_ctx import immutable @immutable class Dict(collections.abc.Mapping): # lgtm[py/missing-equals] - def __init__(self, dictionary, no_copy=False): + def __init__(self, dictionary: Any, no_copy: bool=False): """Make an immutable dictionary from the specified dictionary. If *no_copy* is `True`, then *dictionary* will be wrapped instead @@ -39,7 +41,7 @@ class Dict(collections.abc.Mapping): # lgtm[py/missing-equals] return iter(self._odict) -def constify(o): +def constify(o: Any) -> Any: """ Convert mutable types to immutable types. """ diff --git a/dns/ipv6.py b/dns/ipv6.py index 1d5bffde..9e6e8b6a 100644 --- a/dns/ipv6.py +++ b/dns/ipv6.py @@ -98,7 +98,7 @@ _v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$') _colon_colon_start = re.compile(br'::.*') _colon_colon_end = re.compile(br'.*::$') -def inet_aton(text: Union[str, bytes], ignore_scope=False) -> bytes: +def inet_aton(text: Union[str, bytes], ignore_scope: bool=False) -> bytes: """Convert an IPv6 address in text form to binary form. *text*, a ``str``, the IPv6 address in textual form. @@ -190,7 +190,7 @@ def inet_aton(text: Union[str, bytes], ignore_scope=False) -> bytes: _mapped_prefix = b'\x00' * 10 + b'\xff\xff' -def is_mapped(address): +def is_mapped(address: bytes) -> bool: """Is the specified address a mapped IPv4 address? *address*, a ``bytes`` is an IPv6 address in binary form. diff --git a/dns/message.py b/dns/message.py index 7c92cdaf..a375c7e9 100644 --- a/dns/message.py +++ b/dns/message.py @@ -195,8 +195,8 @@ class Message: def __str__(self): return self.to_text() - def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, - **kw): + def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True, + **kw) -> str: """Convert the message to text. The *origin*, *relativize*, and any other keyword @@ -327,8 +327,8 @@ class Message: rdtype: dns.rdatatype.RdataType, covers = dns.rdatatype.NONE, deleting: Optional[dns.rdataclass.RdataClass]=None, - create=False, - force_unique=False) -> dns.rrset.RRset: + create: bool=False, + force_unique: bool=False) -> dns.rrset.RRset: """Find the RRset with the given attributes in the specified section. *section*, an ``int`` section number, or one of the section @@ -394,10 +394,10 @@ class Message: name: dns.name.Name, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, - covers = dns.rdatatype.NONE, + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, deleting: Optional[dns.rdataclass.RdataClass]=None, - create=False, - force_unique=False) -> Optional[dns.rrset.RRset]: + create: bool=False, + force_unique: bool=False) -> Optional[dns.rrset.RRset]: """Get the RRset with the given attributes in the specified section. If the RRset is not found, None is returned. @@ -439,8 +439,8 @@ class Message: rrset = None return rrset - def to_wire(self, origin: Optional[dns.name.Name]=None, max_size=0, - multi=False, tsig_ctx: Optional[Any]=None, **kw) -> bytes: + def to_wire(self, origin: Optional[dns.name.Name]=None, max_size: int=0, + multi: bool=False, tsig_ctx: Optional[Any]=None, **kw) -> bytes: """Return a string containing the message in DNS compressed wire format. @@ -513,9 +513,10 @@ class Message: original_id, error, other) return dns.rrset.from_rdata(keyname, 0, tsig) - def use_tsig(self, keyring: Any, keyname: Optional[dns.name.Name]=None, - fudge=300, original_id: Optional[int]=None, tsig_error=0, - other_data=b'', algorithm=dns.tsig.default_algorithm): + def use_tsig(self, keyring: Any, keyname: Optional[Union[dns.name.Name, str]]=None, + fudge: int=300, original_id: Optional[int]=None, tsig_error: int=0, + other_data: bytes=b'', + algorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) -> None: """When sending, a TSIG signature using the specified key should be added. @@ -549,7 +550,7 @@ class Message: *other_data*, a ``bytes``, the TSIG other data. - *algorithm*, a ``dns.name.Name``, the TSIG algorithm to use. This is + *algorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use. This is only used if *keyring* is a ``dict``, and the key entry is a ``bytes``. """ @@ -610,9 +611,9 @@ class Message: options or ()) return dns.rrset.from_rdata(dns.name.root, int(flags), opt) - def use_edns(self, edns=0, ednsflags=0, payload=DEFAULT_EDNS_PAYLOAD, + def use_edns(self, edns: Optional[Union[int, bool]]=0, ednsflags: int=0, payload: int=DEFAULT_EDNS_PAYLOAD, request_payload: Optional[int]=None, - options: Optional[List[dns.edns.Option]]=None): + options: Optional[List[dns.edns.Option]]=None) -> None: """Configure EDNS behavior. *edns*, an ``int``, is the EDNS level to use. Specifying @@ -687,7 +688,7 @@ class Message: else: return () - def want_dnssec(self, wanted=True): + def want_dnssec(self, wanted: bool=True) -> None: """Enable or disable 'DNSSEC desired' flag in requests. *wanted*, a ``bool``. If ``True``, then DNSSEC data is @@ -708,7 +709,7 @@ class Message: """ return dns.rcode.from_flags(int(self.flags), int(self.ednsflags)) - def set_rcode(self, rcode: dns.rcode.Rcode): + def set_rcode(self, rcode: dns.rcode.Rcode) -> None: """Set the rcode. *rcode*, a ``dns.rcode.Rcode``, is the rcode to set. @@ -726,7 +727,7 @@ class Message: """ return dns.opcode.from_flags(int(self.flags)) - def set_opcode(self, opcode: dns.opcode.Opcode): + def set_opcode(self, opcode: dns.opcode.Opcode) -> None: """Set the opcode. *opcode*, a ``dns.opcode.Opcode``, is the opcode to set. @@ -1067,17 +1068,18 @@ class _WireReader: return self.message -def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, - tsig_ctx=None, multi=False, - question_only=False, one_rr_per_rrset=False, - ignore_trailing=False, raise_on_truncation=False, - continue_on_error=False) -> Message: +def from_wire(wire, keyring: Optional[Any]=None, request_mac: Optional[bytes]=b'', + xfr: bool=False, origin: Optional[dns.name.Name]=None, + tsig_ctx: Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]]=None, + multi: bool=False, question_only: bool=False, one_rr_per_rrset: bool=False, + ignore_trailing: bool=False, raise_on_truncation: bool=False, + continue_on_error: bool=False) -> Message: """Convert a DNS wire format message into a message object. *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the message is signed. - *request_mac*, a ``bytes``. If the message is a response to a TSIG-signed + *request_mac*, a ``bytes`` or ``None``. If the message is a response to a TSIG-signed request, *request_mac* should be set to the MAC of that request. *xfr*, a ``bool``, should be set to ``True`` if this message is part of a @@ -1130,6 +1132,10 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, Returns a ``dns.message.Message``. """ + # We permit None for request_mac solely for backwards compatibility + if request_mac is None: + request_mac = b'' + def initialize_message(message): message.request_mac = request_mac message.xfr = xfr @@ -1382,8 +1388,9 @@ class _TextReader: return self.message -def from_text(text, idna_codec=None, one_rr_per_rrset=False, - origin=None, relativize=True, relativize_to=None) -> Message: +def from_text(text, idna_codec: Optional[dns.name.IDNACodec]=None, + one_rr_per_rrset: bool=False, origin: Optional[dns.name.Name]=None, + relativize: bool=True, relativize_to: Optional[dns.name.Name]=None) -> Message: """Convert the text format message into a message object. The reader stops after reading the first blank line in the input to @@ -1423,7 +1430,7 @@ def from_text(text, idna_codec=None, one_rr_per_rrset=False, return reader.read() -def from_file(f, idna_codec=None, one_rr_per_rrset=False) -> Message: +def from_file(f, idna_codec: Optional[dns.name.IDNACodec]=None, one_rr_per_rrset: bool=False) -> Message: """Read the next text format message from the specified file. Message blocks are separated by a single blank line. @@ -1452,8 +1459,11 @@ def from_file(f, idna_codec=None, one_rr_per_rrset=False) -> Message: assert False # for mypy lgtm[py/unreachable-statement] -def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, - want_dnssec=False, ednsflags: Optional[int]=None, payload: Optional[int]=None, +def make_query(qname: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, + use_edns: Optional[Union[int, bool]]=None, + want_dnssec: bool=False, ednsflags: Optional[int]=None, payload: Optional[int]=None, request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None, idna_codec: Optional[dns.name.IDNACodec]=None, id: Optional[int]=None, flags: int=dns.flags.RD) -> QueryMessage: @@ -1509,11 +1519,11 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, if isinstance(qname, str): qname = dns.name.from_text(qname, idna_codec=idna_codec) - rdtype = dns.rdatatype.RdataType.make(rdtype) - rdclass = dns.rdataclass.RdataClass.make(rdclass) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + the_rdclass = dns.rdataclass.RdataClass.make(rdclass) m = QueryMessage(id=id) m.flags = dns.flags.Flag(flags) - m.find_rrset(m.question, qname, rdclass, rdtype, create=True, + m.find_rrset(m.question, qname, the_rdclass, the_rdtype, create=True, force_unique=True) # only pass keywords on to use_edns if they have been set to a # non-None value. Setting a field will turn EDNS on if it hasn't diff --git a/dns/name.py b/dns/name.py index 334f2b18..5fd10a29 100644 --- a/dns/name.py +++ b/dns/name.py @@ -18,7 +18,7 @@ """DNS Names. """ -from typing import Dict, Iterable, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Tuple, Union import copy import struct @@ -297,7 +297,7 @@ IDNA_2008_Strict = IDNA2008Codec(False, False, False, True) IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False) IDNA_2008 = IDNA_2008_Practical -def _validate_labels(labels: Tuple[bytes, ...]): +def _validate_labels(labels: Tuple[bytes, ...]) -> None: """Check for empty labels in the middle of a label sequence, labels that are too long, and for too many labels. @@ -555,7 +555,7 @@ class Name: def __str__(self): return self.to_text(False) - def to_text(self, omit_final_dot=False) -> str: + def to_text(self, omit_final_dot: bool=False) -> str: """Convert name to DNS text format. *omit_final_dot* is a ``bool``. If True, don't emit the final @@ -576,7 +576,7 @@ class Name: s = '.'.join(map(_escapify, l)) return s - def to_unicode(self, omit_final_dot=False, idna_codec: Optional[IDNACodec]=None) -> str: + def to_unicode(self, omit_final_dot: bool=False, idna_codec: Optional[IDNACodec]=None) -> str: """Convert name to Unicode text format. IDN ACE labels are converted to Unicode. @@ -627,8 +627,8 @@ class Name: assert digest is not None return digest - def to_wire(self, file=None, compress: Optional[CompressType]=None, - origin: Optional['Name']=None, canonicalize=False) -> Optional[bytes]: + def to_wire(self, file: Optional[Any]=None, compress: Optional[CompressType]=None, + origin: Optional['Name']=None, canonicalize: bool=False) -> Optional[bytes]: """Convert name to wire format, possibly compressing it. *file* is the file where the name is emitted (typically an @@ -794,7 +794,7 @@ class Name: else: return self - def choose_relativity(self, origin: Optional['Name']=None, relativize=True) -> 'Name': + def choose_relativity(self, origin: Optional['Name']=None, relativize: bool=True) -> 'Name': """Return a name with the relativity desired by the caller. If *origin* is ``None``, then the name is returned. diff --git a/dns/node.py b/dns/node.py index de017b43..8727e42d 100644 --- a/dns/node.py +++ b/dns/node.py @@ -164,7 +164,7 @@ class Node: rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, - create=False) -> dns.rdataset.Rdataset: + create: bool=False) -> dns.rdataset.Rdataset: """Find an rdataset matching the specified properties in the current node. @@ -203,7 +203,7 @@ class Node: rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, - create=False) -> Optional[dns.rdataset.Rdataset]: + create: bool=False) -> Optional[dns.rdataset.Rdataset]: """Get an rdataset matching the specified properties in the current node. @@ -237,7 +237,7 @@ class Node: def delete_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType=dns.rdatatype.NONE): + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE) -> None: """Delete the rdataset matching the specified properties in the current node. @@ -254,7 +254,7 @@ class Node: if rds is not None: self.rdatasets.remove(rds) - def replace_rdataset(self, replacement: dns.rdataset.Rdataset): + def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None: """Replace an rdataset. It is not an error if there is no rdataset matching *replacement*. @@ -312,22 +312,31 @@ class ImmutableNode(Node): [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] ) - def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def find_rdataset(self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, + create: bool=False) -> dns.rdataset.Rdataset: if create: raise TypeError("immutable") return super().find_rdataset(rdclass, rdtype, covers, False) - def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def get_rdataset(self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, + create: bool=False) -> Optional[dns.rdataset.Rdataset]: if create: raise TypeError("immutable") return super().get_rdataset(rdclass, rdtype, covers, False) - def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + def delete_rdataset(self, + rdclass: dns.rdataclass.RdataClass, + rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE) -> None: raise TypeError("immutable") - def replace_rdataset(self, replacement): + def replace_rdataset(self, replacement) -> None: raise TypeError("immutable") def is_immutable(self) -> bool: diff --git a/dns/query.py b/dns/query.py index 4757be8a..1ba57790 100644 --- a/dns/query.py +++ b/dns/query.py @@ -258,10 +258,10 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None): raise def https(q: dns.message.Message, where: str, timeout: Optional[float]=None, - port=443, source: Optional[str]=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, - session: Optional[Any]=None, path='/dns-query', post=True, - bootstrap_address: Optional[str]=None, verify=True) -> dns.message.Message: + port: int=443, source: Optional[str]=None, source_port: int=0, + one_rr_per_rrset: bool=False, ignore_trailing: bool=False, + session: Optional[Any]=None, path: str='/dns-query', post: bool=True, + bootstrap_address: Optional[str]=None, verify: bool=True) -> dns.message.Message: """Return the response obtained after sending a query via DNS-over-HTTPS. *q*, a ``dns.message.Message``, the query to send. @@ -465,9 +465,9 @@ def send_udp(sock: Any, what: Union[dns.message.Message, bytes], destination: An def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional[float]=None, - ignore_unexpected=False, one_rr_per_rrset=False, - keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'', - ignore_trailing=False, raise_on_truncation=False) -> Any: + ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac: Optional[bytes]=b'', + ignore_trailing: bool=False, raise_on_truncation: bool=False) -> Any: """Read a DNS message from a UDP socket. *sock*, a ``socket``. @@ -489,7 +489,7 @@ def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional *keyring*, a ``dict``, the keyring to use for TSIG. - *request_mac*, a ``bytes``, the MAC of the request (for TSIG). + *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG). *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. @@ -525,10 +525,10 @@ def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional else: return (r, received_time, from_address) -def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, - source: Optional[str]=None, source_port=0, - ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, - raise_on_truncation=False, sock: Optional[Any]=None) -> dns.message.Message: +def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, + source: Optional[str]=None, source_port: int=0, + ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False, + raise_on_truncation: bool=False, sock: Optional[Any]=None) -> dns.message.Message: """Return the response obtained after sending a query via UDP. *q*, a ``dns.message.Message``, the query to send @@ -587,9 +587,9 @@ def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port= return r assert False # help mypy figure out we can't get here lgtm[py/unreachable-statement] -def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, - source: Optional[str]=None, source_port=0, - ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, +def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, + source: Optional[str]=None, source_port: int=0, + ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False, udp_sock: Optional[Any]=None, tcp_sock: Optional[Any]=None) -> Tuple[dns.message.Message, bool]: """Return the response to the query, trying UDP first and falling back @@ -709,9 +709,10 @@ def send_tcp(sock: Any, what: Union[dns.message.Message, bytes], _net_write(sock, tcpmsg, expiration) return (len(tcpmsg), sent_time) -def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset=False, - keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'', - ignore_trailing=False) -> Tuple[dns.message.Message, float]: +def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset: bool=False, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, + request_mac: Optional[bytes]=b'', + ignore_trailing: bool=False) -> Tuple[dns.message.Message, float]: """Read a DNS message from a TCP socket. *sock*, a ``socket``. @@ -725,7 +726,7 @@ def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset=Fa *keyring*, a ``dict``, the keyring to use for TSIG. - *request_mac*, a ``bytes``, the MAC of the request (for TSIG). + *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG). *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. @@ -757,9 +758,10 @@ def _connect(s, address, expiration): raise OSError(err, os.strerror(err)) -def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53, - source: Optional[str]=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, sock: Optional[Any]=None) -> dns.message.Message: +def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53, + source: Optional[str]=None, source_port: int=0, + one_rr_per_rrset: bool=False, ignore_trailing: bool=False, + sock: Optional[Any]=None) -> dns.message.Message: """Return the response obtained after sending a query via TCP. *q*, a ``dns.message.Message``, the query to send @@ -826,8 +828,8 @@ def _tls_handshake(s, expiration): def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, - port=853, source: Optional[str]=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, sock: Optional[ssl.SSLSocket]=None, + port: int=853, source: Optional[str]=None, source_port: int=0, + one_rr_per_rrset: bool=False, ignore_trailing: bool=False, sock: Optional[ssl.SSLSocket]=None, ssl_context: Optional[ssl.SSLContext]=None, server_hostname: Optional[str]=None) -> dns.message.Message: """Return the response obtained after sending a query via TLS. @@ -908,10 +910,15 @@ def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None, return r assert False # help mypy figure out we can't get here lgtm[py/unreachable-statement] -def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, - timeout=None, port=53, keyring=None, keyname=None, relativize=True, - lifetime=None, source=None, source_port=0, serial=0, - use_udp=False, keyalgorithm=dns.tsig.default_algorithm): +def xfr(where: str, zone: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.AXFR, + rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, + timeout: Optional[float]=None, port: int=53, + keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, + keyname: Optional[Union[dns.name.Name, str]]=None, relativize: bool=True, + lifetime: Optional[float]=None, source: Optional[str]=None, source_port: int=0, + serial: int=0, use_udp: bool=False, + keyalgorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) -> Any: """Return a generator for the responses to a zone transfer. *where*, a ``str`` containing an IPv4 or IPv6 address, where @@ -1089,8 +1096,8 @@ class UDPMode(enum.IntEnum): def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager, query: Optional[dns.message.Message]=None, - port=53, timeout: Optional[float]=None, lifetime: Optional[float]=None, - source: Optional[str]=None, source_port=0, udp_mode=UDPMode.NEVER): + port: int=53, timeout: Optional[float]=None, lifetime: Optional[float]=None, + source: Optional[str]=None, source_port: int=0, udp_mode=UDPMode.NEVER): """Conduct an inbound transfer and apply it via a transaction from the txn_manager. diff --git a/dns/rcode.py b/dns/rcode.py index 49fee695..16e1ed4b 100644 --- a/dns/rcode.py +++ b/dns/rcode.py @@ -17,6 +17,8 @@ """DNS Result Codes.""" +from typing import Tuple + import dns.enum import dns.exception @@ -77,20 +79,20 @@ class UnknownRcode(dns.exception.DNSException): """A DNS rcode is unknown.""" -def from_text(text): +def from_text(text: str) -> Rcode: """Convert text into an rcode. *text*, a ``str``, the textual rcode or an integer in textual form. Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown. - Returns an ``int``. + Returns a ``dns.rcode.Rcode``. """ return Rcode.from_text(text) -def from_flags(flags, ednsflags): +def from_flags(flags: int, ednsflags: int) -> Rcode: """Return the rcode value encoded by flags and ednsflags. *flags*, an ``int``, the DNS flags field. @@ -99,17 +101,17 @@ def from_flags(flags, ednsflags): Raises ``ValueError`` if rcode is < 0 or > 4095 - Returns an ``int``. + Returns a ``dns.rcode.Rcode``. """ value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0) - return value + return Rcode.make(value) -def to_flags(value): +def to_flags(value: Rcode) -> Tuple[int, int]: """Return a (flags, ednsflags) tuple which encodes the rcode. - *value*, an ``int``, the rcode. + *value*, a ``dns.rcode.Rcode``, the rcode. Raises ``ValueError`` if rcode is < 0 or > 4095. @@ -123,10 +125,10 @@ def to_flags(value): return (v, ev) -def to_text(value, tsig=False): +def to_text(value: Rcode, tsig: bool=False) -> str: """Convert rcode into text. - *value*, an ``int``, the rcode. + *value*, a ``dns.rcode.Rcode``, the rcode. Raises ``ValueError`` if rcode is < 0 or > 4095. diff --git a/dns/rdata.py b/dns/rdata.py index 1e1992be..24e5fde5 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -191,7 +191,7 @@ class Rdata: return self.covers() << 16 | self.rdtype - def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw): + def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True, **kw): """Convert an rdata to text format. Returns a ``str``. @@ -199,12 +199,12 @@ class Rdata: raise NotImplementedError # pragma: no cover - def _to_wire(self, file, compress: Optional[dns.name.CompressType]=None, - origin: Optional[dns.name.Name]=None, canonicalize=False): + def _to_wire(self, file: Optional[Any], compress: Optional[dns.name.CompressType]=None, + origin: Optional[dns.name.Name]=None, canonicalize: bool=False) -> bytes: raise NotImplementedError # pragma: no cover - def to_wire(self, file=None, compress=None, origin=None, - canonicalize=False) -> bytes: + def to_wire(self, file: Optional[Any]=None, compress: Optional[dns.name.CompressType]=None, + origin: Optional[dns.name.Name]=None, canonicalize: bool=False) -> bytes: """Convert an rdata to wire format. Returns a ``bytes`` or ``None``. @@ -353,17 +353,17 @@ class Rdata: @classmethod def from_text(cls, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, - tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None, relativize=True, + tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None, relativize: bool=True, relativize_to: Optional[dns.name.Name]=None): raise NotImplementedError # pragma: no cover @classmethod def from_wire_parser(cls, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, - parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None): + parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None) -> 'Rdata': raise NotImplementedError # pragma: no cover - def replace(self, **kwargs): + def replace(self, **kwargs) -> 'Rdata': """ Create a new Rdata instance based on the instance replace was invoked on. It is possible to pass different parameters to @@ -376,7 +376,7 @@ class Rdata: """ # Get the constructor parameters. - parameters = inspect.signature(self.__init__).parameters + parameters = inspect.signature(self.__init__).parameters # type: ignore # Ensure that all of the arguments correspond to valid fields. # Don't allow rdclass or rdtype to be changed, though. @@ -615,7 +615,7 @@ def from_text(rdclass: Union[dns.rdataclass.RdataClass, str], rdtype: Union[dns.rdatatype.RdataType, str], tok: Union[dns.tokenizer.Tokenizer, str], origin: Optional[dns.name.Name]=None, - relativize=True, relativize_to: Optional[dns.name.Name]=None, + relativize: bool=True, relativize_to: Optional[dns.name.Name]=None, idna_codec: Optional[dns.name.IDNACodec]=None) -> Rdata: """Build an rdata object from text format. @@ -769,8 +769,8 @@ class RdatatypeExists(dns.exception.DNSException): "already exists." -def register_type(implementation, rdtype, rdtype_text, is_singleton=False, - rdclass=dns.rdataclass.IN): +def register_type(implementation: Any, rdtype: int, rdtype_text: str, is_singleton: bool=False, + rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN): """Dynamically register a module to handle an rdatatype. *implementation*, a module implementing the type in the usual dnspython @@ -787,14 +787,15 @@ def register_type(implementation, rdtype, rdtype_text, is_singleton=False, it applies to all classes. """ - existing_cls = get_rdata_class(rdclass, rdtype) - if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype): - raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + existing_cls = get_rdata_class(rdclass, the_rdtype) + if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype): + raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) try: - if dns.rdatatype.RdataType(rdtype).name != rdtype_text: - raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) + if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text: + raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype) except ValueError: pass - _rdata_classes[(rdclass, rdtype)] = getattr(implementation, - rdtype_text.replace('-', '_')) - dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton) + _rdata_classes[(rdclass, the_rdtype)] = getattr(implementation, + rdtype_text.replace('-', '_')) + dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton) diff --git a/dns/rdataclass.py b/dns/rdataclass.py index 41bba693..28670548 100644 --- a/dns/rdataclass.py +++ b/dns/rdataclass.py @@ -56,7 +56,7 @@ class UnknownRdataclass(dns.exception.DNSException): """A DNS class is unknown.""" -def from_text(text): +def from_text(text: str) -> RdataClass: """Convert text into a DNS rdata class value. The input text can be a defined DNS RR class mnemonic or @@ -68,13 +68,13 @@ def from_text(text): Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535. - Returns an ``int``. + Returns a ``dns.rdataclass.RdataClass``. """ return RdataClass.from_text(text) -def to_text(value): +def to_text(value: RdataClass) -> str: """Convert a DNS rdata class value to text. If the value has a known mnemonic, it will be used, otherwise the @@ -88,12 +88,12 @@ def to_text(value): return RdataClass.to_text(value) -def is_metaclass(rdclass): +def is_metaclass(rdclass: RdataClass) -> bool: """True if the specified class is a metaclass. The currently defined metaclasses are ANY and NONE. - *rdclass* is an ``int``. + *rdclass* is a ``dns.rdataclass.RdataClass``. """ if rdclass in _metaclasses: diff --git a/dns/rdataset.py b/dns/rdataset.py index 33bee2f1..b47057fd 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -53,7 +53,7 @@ class Rdataset(dns.set.Set): def __init__(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, - covers=dns.rdatatype.NONE, ttl=0): + covers=dns.rdatatype.NONE, ttl: int=0): """Create a new rdataset of the specified class and type. *rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass. @@ -79,7 +79,7 @@ class Rdataset(dns.set.Set): obj.ttl = self.ttl return obj - def update_ttl(self, ttl: int): + def update_ttl(self, ttl: int) -> None: """Perform TTL minimization. Set the TTL of the rdataset to be the lesser of the set's current @@ -94,7 +94,7 @@ class Rdataset(dns.set.Set): elif ttl < self.ttl: self.ttl = ttl - def add(self, rd, ttl: Optional[int]=None): # pylint: disable=arguments-differ + def add(self, rd: dns.rdata.Rdata, ttl: Optional[int]=None) -> None: # pylint: disable=arguments-differ """Add the specified rdata to the rdataset. If the optional *ttl* parameter is supplied, then @@ -184,7 +184,7 @@ class Rdataset(dns.set.Set): def to_text(self, name: Optional[dns.name.Name]=None, origin: Optional[dns.name.Name]=None, - relativize=True, + relativize: bool=True, override_rdclass: Optional[dns.rdataclass.RdataClass]=None, want_comments=False, **kw) -> str: """Convert the rdataset into DNS zone file format. @@ -254,7 +254,7 @@ class Rdataset(dns.set.Set): compress: Optional[dns.name.CompressType]=None, origin: Optional[dns.name.Name]=None, override_rdclass: Optional[dns.rdataclass.RdataClass]=None, - want_shuffle=True) -> int: + want_shuffle: bool=True) -> int: """Convert the rdataset to wire format. *name*, a ``dns.name.Name`` is the owner name to use. @@ -414,7 +414,7 @@ def from_text_list(rdclass: Union[dns.rdataclass.RdataClass, str], ttl: int, text_rdatas: Collection[str], idna_codec: Optional[dns.name.IDNACodec]=None, origin: Optional[dns.name.Name]=None, - relativize=True, relativize_to: Optional[dns.name.Name]=None) -> Rdataset: + relativize: bool=True, relativize_to: Optional[dns.name.Name]=None) -> Rdataset: """Create an rdataset with the specified class, type, and TTL, and with the specified list of rdatas in text format. diff --git a/dns/rdatatype.py b/dns/rdatatype.py index 80f8acaf..18185bca 100644 --- a/dns/rdatatype.py +++ b/dns/rdatatype.py @@ -135,7 +135,7 @@ class UnknownRdatatype(dns.exception.DNSException): """DNS resource record type is unknown.""" -def from_text(text): +def from_text(text: str) -> RdataType: """Convert text into a DNS rdata type value. The input text can be a defined DNS RR type mnemonic or @@ -147,7 +147,7 @@ def from_text(text): Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535. - Returns an ``int``. + Returns a ``dns.rdatatype.RdataType``. """ text = text.upper().replace('-', '_') @@ -160,7 +160,7 @@ def from_text(text): raise -def to_text(value): +def to_text(value: RdataType) -> str: """Convert a DNS rdata type value to text. If the value has a known mnemonic, it will be used, otherwise the @@ -179,10 +179,10 @@ def to_text(value): return text.replace('_', '-') -def is_metatype(rdtype): +def is_metatype(rdtype: RdataType) -> bool: """True if the specified type is a metatype. - *rdtype* is an ``int``. + *rdtype* is a ``dns.rdatatype.RdataType``. The currently defined metatypes are TKEY, TSIG, IXFR, AXFR, MAILA, MAILB, ANY, and OPT. @@ -193,7 +193,7 @@ def is_metatype(rdtype): return (256 > rdtype >= 128) or rdtype in _metatypes -def is_singleton(rdtype): +def is_singleton(rdtype: RdataType) -> bool: """Is the specified type a singleton type? Singleton types can only have a single rdata in an rdataset, or a single @@ -212,10 +212,10 @@ def is_singleton(rdtype): return False # pylint: disable=redefined-outer-name -def register_type(rdtype, rdtype_text, is_singleton=False): +def register_type(rdtype: RdataType, rdtype_text: str, is_singleton: bool=False): """Dynamically register an rdatatype. - *rdtype*, an ``int``, the rdatatype to register. + *rdtype*, a ``dns.rdatatype.RdataType``, the rdatatype to register. *rdtype_text*, a ``str``, the textual form of the rdatatype. diff --git a/dns/resolver.py b/dns/resolver.py index 5f0f3628..28769f4e 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -17,7 +17,7 @@ """DNS stub resolver.""" -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse import contextlib @@ -32,6 +32,7 @@ except ImportError: # pragma: no cover import dummy_threading as _threading # type: ignore import dns.exception +import dns.edns import dns.flags import dns.inet import dns.ipv4 @@ -139,7 +140,7 @@ class YXDOMAIN(dns.exception.DNSException): """The DNS query name is too long after DNAME substitution.""" -ErrorTuple = Tuple[str, bool, int, Exception, dns.message.Message] +ErrorTuple = Tuple[Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message]] def _errors_to_text(errors: List[ErrorTuple]) -> List[str]: @@ -312,17 +313,17 @@ class CacheBase: self.lock = _threading.Lock() self.statistics = CacheStatistics() - def reset_statistics(self): + def reset_statistics(self) -> None: """Reset all statistics to zero.""" with self.lock: self.statistics.reset() - def hits(self): + def hits(self) -> int: """How many hits has the cache had?""" with self.lock: return self.statistics.hits - def misses(self): + def misses(self) -> int: """How many misses has the cache had?""" with self.lock: return self.statistics.misses @@ -344,17 +345,17 @@ CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataCla class Cache(CacheBase): """Simple thread-safe DNS answer cache.""" - def __init__(self, cleaning_interval=300.0): + def __init__(self, cleaning_interval: float=300.0): """*cleaning_interval*, a ``float`` is the number of seconds between periodic cleanings. """ super().__init__() - self.data = {} + self.data: Dict[CacheKey, Answer] = {} self.cleaning_interval = cleaning_interval - self.next_cleaning = time.time() + self.cleaning_interval + self.next_cleaning: float = time.time() + self.cleaning_interval - def _maybe_clean(self): + def _maybe_clean(self) -> None: """Clean the cache if it's time to do so.""" now = time.time() @@ -388,7 +389,7 @@ class Cache(CacheBase): self.statistics.hits += 1 return v - def put(self, key: CacheKey, value: Answer): + def put(self, key: CacheKey, value: Answer) -> None: """Associate key and value in the cache. *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the @@ -401,7 +402,7 @@ class Cache(CacheBase): self._maybe_clean() self.data[key] = value - def flush(self, key: Optional[CacheKey]=None): + def flush(self, key: Optional[CacheKey]=None) -> None: """Flush the cache. If *key* is not ``None``, only that item is flushed. Otherwise @@ -451,19 +452,19 @@ class LRUCache(CacheBase): for a new one. """ - def __init__(self, max_size=100000): + def __init__(self, max_size: int=100000): """*max_size*, an ``int``, is the maximum number of nodes to cache; it must be greater than 0. """ super().__init__() - self.data = {} + self.data: Dict[CacheKey, LRUCacheNode] = {} self.set_max_size(max_size) - self.sentinel = LRUCacheNode(None, None) + self.sentinel: LRUCacheNode = LRUCacheNode(None, None) self.sentinel.prev = self.sentinel self.sentinel.next = self.sentinel - def set_max_size(self, max_size): + def set_max_size(self, max_size: int) -> None: if max_size < 1: max_size = 1 self.max_size = max_size @@ -505,7 +506,7 @@ class LRUCache(CacheBase): else: return node.hits - def put(self, key: CacheKey, value: Answer): + def put(self, key: CacheKey, value: Answer) -> None: """Associate key and value in the cache. *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the @@ -520,14 +521,14 @@ class LRUCache(CacheBase): node.unlink() del self.data[node.key] while len(self.data) >= self.max_size: - node = self.sentinel.prev - node.unlink() - del self.data[node.key] + gnode = self.sentinel.prev + gnode.unlink() + del self.data[gnode.key] node = LRUCacheNode(key, value) node.link_after(self.sentinel) self.data[key] = node - def flush(self, key: Optional[CacheKey]=None): + def flush(self, key: Optional[CacheKey]=None) -> None: """Flush the cache. If *key* is not ``None``, only that item is flushed. Otherwise @@ -544,11 +545,11 @@ class LRUCache(CacheBase): node.unlink() del self.data[node.key] else: - node = self.sentinel.next - while node != self.sentinel: - next = node.next - node.unlink() - node = next + gnode = self.sentinel.next + while gnode != self.sentinel: + next = gnode.next + gnode.unlink() + gnode = next self.data = {} class _Resolution: @@ -569,20 +570,20 @@ class _Resolution: tcp: bool, raise_on_no_answer: bool, search: Optional[bool]): if isinstance(qname, str): qname = dns.name.from_text(qname, None) - rdtype = dns.rdatatype.RdataType.make(rdtype) - if dns.rdatatype.is_metatype(rdtype): + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + if dns.rdatatype.is_metatype(the_rdtype): raise NoMetaqueries - rdclass = dns.rdataclass.RdataClass.make(rdclass) - if dns.rdataclass.is_metaclass(rdclass): + the_rdclass = dns.rdataclass.RdataClass.make(rdclass) + if dns.rdataclass.is_metaclass(the_rdclass): raise NoMetaqueries self.resolver = resolver self.qnames_to_try = resolver._get_qnames_to_try(qname, search) self.qnames = self.qnames_to_try[:] - self.rdtype = rdtype - self.rdclass = rdclass + self.rdtype = the_rdtype + self.rdclass = the_rdclass self.tcp = tcp self.raise_on_no_answer = raise_on_no_answer - self.nxdomain_responses: Dict[dns.name.Name, Answer] = {} + self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {} # Initialize other things to help analysis tools self.qname = dns.name.empty self.nameservers: List[str] = [] @@ -660,14 +661,14 @@ class _Resolution: raise NXDOMAIN(qnames=self.qnames_to_try, responses=self.nxdomain_responses) - def next_nameserver(self): + def next_nameserver(self) -> Tuple[str, int, bool, float]: if self.retry_with_tcp: assert self.nameserver is not None self.tcp_attempt = True self.retry_with_tcp = False return (self.nameserver, self.port, True, 0) - backoff = 0 + backoff = 0.0 if not self.current_nameservers: if len(self.nameservers) == 0: # Out of things to try! @@ -682,10 +683,12 @@ class _Resolution: self.tcp_attempt = self.tcp return (self.nameserver, self.port, self.tcp_attempt, backoff) - def query_result(self, response, ex): + def query_result(self, response: Optional[dns.message.Message], + ex: Optional[Exception]) -> Tuple[Optional[Answer], bool]: # # returns an (answer: Answer, end_loop: bool) tuple. # + assert self.nameserver is not None if ex: # Exception during I/O or from_wire() assert response is None @@ -706,6 +709,7 @@ class _Resolution: return (None, False) # We got an answer! assert response is not None + assert isinstance(response, dns.message.QueryMessage) rcode = response.rcode() if rcode == dns.rcode.NOERROR: try: @@ -767,7 +771,7 @@ class BaseResolver: # # pylint: disable=attribute-defined-outside-init - def __init__(self, filename='/etc/resolv.conf', configure=True): + def __init__(self, filename: str='/etc/resolv.conf', configure: bool=True): """*filename*, a ``str`` or file object, specifying a file in standard /etc/resolv.conf format. This parameter is meaningful only when *configure* is true and the platform is POSIX. @@ -813,7 +817,7 @@ class BaseResolver: self.rotate = False self.ndots: Optional[int] = None - def read_resolv_conf(self, f): + def read_resolv_conf(self, f: Any) -> None: """Process *f* as a file in the /etc/resolv.conf format. If f is a ``str``, it is used as the name of the file to open; otherwise it is treated as the file itself. @@ -879,10 +883,10 @@ class BaseResolver: if len(self.nameservers) == 0: raise NoResolverConfiguration('no nameservers') - def read_registry(self): + def read_registry(self) -> None: """Extract resolver configuration from the Windows registry.""" try: - info = dns.win32util.get_dns_info() + info = dns.win32util.get_dns_info() # type: ignore if info.domain is not None: self.domain = info.domain self.nameservers = info.nameservers @@ -949,8 +953,8 @@ class BaseResolver: qnames_to_try.append(abs_qname) return qnames_to_try - def use_tsig(self, keyring, keyname=None, - algorithm=dns.tsig.default_algorithm): + def use_tsig(self, keyring: Any, keyname: Optional[Union[dns.name.Name, str]]=None, + algorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) -> None: """Add a TSIG signature to each query. The parameters are passed to ``dns.message.Message.use_tsig()``; @@ -961,8 +965,9 @@ class BaseResolver: self.keyname = keyname self.keyalgorithm = algorithm - def use_edns(self, edns=0, ednsflags=0, - payload=dns.message.DEFAULT_EDNS_PAYLOAD, options=None): + def use_edns(self, edns: Optional[Union[int, bool]]=0, ednsflags: int=0, + payload: int=dns.message.DEFAULT_EDNS_PAYLOAD, + options: Optional[List[dns.edns.Option]]=None) -> None: """Configure EDNS behavior. *edns*, an ``int``, is the EDNS level to use. Specifying @@ -989,7 +994,7 @@ class BaseResolver: self.payload = payload self.ednsoptions = options - def set_flags(self, flags: int): + def set_flags(self, flags: int) -> None: """Overrides the default flags with your own. *flags*, an ``int``, the message flags to use. @@ -1030,7 +1035,7 @@ class Resolver(BaseResolver): def resolve(self, qname: Union[dns.name.Name, str], rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0, + tcp: bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0, lifetime: Optional[float]=None, search: Optional[bool]=None) -> Answer: # pylint: disable=arguments-differ """Query nameservers to find the answer to the question. @@ -1136,7 +1141,7 @@ class Resolver(BaseResolver): def query(self, qname: Union[dns.name.Name, str], rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0, + tcp: bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0, lifetime: Optional[float]=None) -> Answer: # pragma: no cover """Query nameservers to find the answer to the question. @@ -1226,7 +1231,7 @@ def reset_default_resolver(): def resolve(qname: Union[dns.name.Name, str], rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0, + tcp: bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0, lifetime: Optional[float]=None, search: Optional[bool]=None) -> Answer: # pragma: no cover """Query nameservers to find the answer to the question. @@ -1245,7 +1250,7 @@ def resolve(qname: Union[dns.name.Name, str], def query(qname: Union[dns.name.Name, str], rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A, rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN, - tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0, + tcp: bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0, lifetime: Optional[float]=None) -> Answer: # pragma: no cover """Query nameservers to find the answer to the question. @@ -1282,7 +1287,7 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name: def zone_for_name(name: Union[dns.name.Name, str], rdclass=dns.rdataclass.IN, - tcp=False, resolver: Optional[Resolver]=None, + tcp: bool=False, resolver: Optional[Resolver]=None, lifetime: Optional[float]=None) -> dns.name.Name: """Find the name of the zone which contains the specified name. diff --git a/dns/reversename.py b/dns/reversename.py index 4b70cf64..c25e77df 100644 --- a/dns/reversename.py +++ b/dns/reversename.py @@ -27,8 +27,8 @@ ipv4_reverse_domain = dns.name.from_text('in-addr.arpa.') ipv6_reverse_domain = dns.name.from_text('ip6.arpa.') -def from_address(text: str, v4_origin=ipv4_reverse_domain, - v6_origin=ipv6_reverse_domain) -> dns.name.Name: +def from_address(text: str, v4_origin: dns.name.Name=ipv4_reverse_domain, + v6_origin: dns.name.Name=ipv6_reverse_domain) -> dns.name.Name: """Convert an IPv4 or IPv6 address in textual form into a Name object whose value is the reverse-map domain name of the address. @@ -63,8 +63,8 @@ def from_address(text: str, v4_origin=ipv4_reverse_domain, return dns.name.from_text('.'.join(reversed(parts)), origin=origin) -def to_address(name: dns.name.Name, v4_origin=ipv4_reverse_domain, - v6_origin=ipv6_reverse_domain) -> str: +def to_address(name: dns.name.Name, v4_origin: dns.name.Name=ipv4_reverse_domain, + v6_origin: dns.name.Name=ipv6_reverse_domain) -> str: """Convert a reverse map domain name into textual address form. *name*, a ``dns.name.Name``, an IPv4 or IPv6 address in reverse-map name diff --git a/dns/rrset.py b/dns/rrset.py index 37458571..e14433ee 100644 --- a/dns/rrset.py +++ b/dns/rrset.py @@ -17,7 +17,7 @@ """DNS RRsets (an RRset is a named rdataset)""" -from typing import cast, Collection, Optional, Union +from typing import Any, cast, Collection, Optional, Union import dns.name import dns.rdataset @@ -110,7 +110,7 @@ class RRset(dns.rdataset.Rdataset): # pylint: disable=arguments-differ - def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw) -> str: # type: ignore + def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True, **kw) -> str: # type: ignore """Convert the RRset into DNS zone file format. See ``dns.name.Name.choose_relativity`` for more information @@ -130,7 +130,7 @@ class RRset(dns.rdataset.Rdataset): return super().to_text(self.name, origin, relativize, self.deleting, **kw) - def to_wire(self, file, compress: Optional[dns.name.CompressType]=None, # type: ignore + def to_wire(self, file: Any, compress: Optional[dns.name.CompressType]=None, # type: ignore origin: Optional[dns.name.Name]=None, **kw) -> int: """Convert the RRset to wire format. @@ -158,7 +158,7 @@ def from_text_list(name: Union[dns.name.Name, str], ttl: int, rdtype: Union[dns.rdatatype.RdataType, str], text_rdatas: Collection[str], idna_codec: Optional[dns.name.IDNACodec]=None, - origin: Optional[dns.name.Name]=None, relativize=True, + origin: Optional[dns.name.Name]=None, relativize: bool=True, relativize_to: Optional[dns.name.Name]=None) -> RRset: """Create an RRset with the specified name, TTL, class, and type, and with the specified list of rdatas in text format. @@ -205,7 +205,7 @@ def from_text(name: Union[dns.name.Name, str], ttl: int, cast(Collection[str], text_rdatas)) -def from_rdata_list(name: Union[dns.name.Name, str], ttl:int, +def from_rdata_list(name: Union[dns.name.Name, str], ttl: int, rdatas: Collection[dns.rdata.Rdata], idna_codec: Optional[dns.name.IDNACodec]=None) -> RRset: """Create an RRset with the specified name and TTL, and with diff --git a/dns/serial.py b/dns/serial.py index 138ffbf9..b4d264cb 100644 --- a/dns/serial.py +++ b/dns/serial.py @@ -3,7 +3,7 @@ """Serial Number Arthimetic from RFC 1982""" class Serial: - def __init__(self, value:int , bits=32): + def __init__(self, value: int, bits: int=32): self.value = value % 2 ** bits self.bits = bits diff --git a/dns/tokenizer.py b/dns/tokenizer.py index bb94ce94..331bee3c 100644 --- a/dns/tokenizer.py +++ b/dns/tokenizer.py @@ -17,7 +17,7 @@ """Tokenize DNS zone file format""" -from typing import Optional, List, Tuple +from typing import Any, Optional, List, Tuple, Union import io import sys @@ -50,7 +50,8 @@ class Token: has_escape: Does the token value contain escapes? """ - def __init__(self, ttype: int, value='', has_escape=False, comment: Optional[str]=None): + def __init__(self, ttype: int, value: Any='', has_escape: bool=False, + comment: Optional[str]=None): """Initialize a token instance.""" self.ttype = ttype @@ -225,7 +226,7 @@ class Tokenizer: encoder/decoder is used. """ - def __init__(self, f=sys.stdin, filename: Optional[str]=None, + def __init__(self, f: Any=sys.stdin, filename: Optional[str]=None, idna_codec: Optional[dns.name.IDNACodec]=None): """Initialize a tokenizer instance. @@ -297,7 +298,7 @@ class Tokenizer: return (self.filename, self.line_number) - def _unget_char(self, c): + def _unget_char(self, c: str) -> None: """Unget a character. The unget buffer for characters is only one character large; it is @@ -313,7 +314,7 @@ class Tokenizer: raise UngetBufferFull # pragma: no cover self.ungotten_char = c - def skip_whitespace(self): + def skip_whitespace(self) -> int: """Consume input until a non-whitespace character is encountered. The non-whitespace character is then ungotten, and the number of @@ -333,7 +334,7 @@ class Tokenizer: return skipped skipped += 1 - def get(self, want_leading=False, want_comment=False) -> Token: + def get(self, want_leading: bool=False, want_comment: bool=False) -> Token: """Get the next token. want_leading: If True, return a WHITESPACE token if the @@ -477,7 +478,7 @@ class Tokenizer: # Helpers - def get_int(self, base=10): + def get_int(self, base: int=10) -> int: """Read the next token and interpret it as an unsigned integer. Raises dns.exception.SyntaxError if not an unsigned integer. @@ -507,7 +508,7 @@ class Tokenizer: '%d is not an unsigned 8-bit integer' % value) return value - def get_uint16(self, base=10) -> int: + def get_uint16(self, base: int=10) -> int: """Read the next token and interpret it as a 16-bit unsigned integer. @@ -526,7 +527,7 @@ class Tokenizer: '%d is not an unsigned 16-bit integer' % value) return value - def get_uint32(self, base=10) -> int: + def get_uint32(self, base: int=10) -> int: """Read the next token and interpret it as a 32-bit unsigned integer. @@ -541,7 +542,7 @@ class Tokenizer: '%d is not an unsigned 32-bit integer' % value) return value - def get_uint48(self, base=10) -> int: + def get_uint48(self, base: int=10) -> int: """Read the next token and interpret it as a 48-bit unsigned integer. @@ -556,7 +557,7 @@ class Tokenizer: '%d is not an unsigned 48-bit integer' % value) return value - def get_string(self, max_length=None) -> str: + def get_string(self, max_length: Optional[int]=None) -> str: """Read the next token and interpret it as a string. Raises dns.exception.SyntaxError if not a string. @@ -586,7 +587,7 @@ class Tokenizer: raise dns.exception.SyntaxError('expecting an identifier') return token.value - def get_remaining(self, max_tokens=None) -> List[Token]: + def get_remaining(self, max_tokens: Optional[int]=None) -> List[Token]: """Return the remaining tokens on the line, until an EOL or EOF is seen. max_tokens: If not None, stop after this number of tokens. @@ -605,7 +606,7 @@ class Tokenizer: break return tokens - def concatenate_remaining_identifiers(self, allow_empty=False) -> str: + def concatenate_remaining_identifiers(self, allow_empty: bool=False) -> str: """Read the remaining tokens on the line, which should be identifiers. Raises dns.exception.SyntaxError if there are no remaining tokens, @@ -631,7 +632,7 @@ class Tokenizer: return s def as_name(self, token: Token, origin: Optional[dns.name.Name]=None, - relativize=False, relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name: + relativize: bool=False, relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name: """Try to interpret the token as a DNS name. Raises dns.exception.SyntaxError if not a name. @@ -643,7 +644,7 @@ class Tokenizer: name = dns.name.from_text(token.value, origin, self.idna_codec) return name.choose_relativity(relativize_to or origin, relativize) - def get_name(self, origin: Optional[dns.name.Name]=None, relativize=False, + def get_name(self, origin: Optional[dns.name.Name]=None, relativize: bool=False, relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name: """Read the next token and interpret it as a DNS name. diff --git a/dns/transaction.py b/dns/transaction.py index ccb557ce..f48d83ef 100644 --- a/dns/transaction.py +++ b/dns/transaction.py @@ -20,7 +20,7 @@ class TransactionManager: """Begin a read-only transaction.""" raise NotImplementedError # pragma: no cover - def writer(self, replacement=False) -> 'Transaction': + def writer(self, replacement: bool=False) -> 'Transaction': """Begin a writable transaction. *replacement*, a ``bool``. If `True`, the content of the @@ -101,7 +101,7 @@ CheckDeleteNameType = Callable[['Transaction', dns.name.Name], None] class Transaction: - def __init__(self, manager: TransactionManager, replacement=False, read_only=False): + def __init__(self, manager: TransactionManager, replacement: bool=False, read_only: bool=False): self.manager = manager self.replacement = replacement self.read_only = read_only @@ -133,18 +133,18 @@ class Transaction: rdataset = self._get_rdataset(name, rdtype, covers) return _ensure_immutable_rdataset(rdataset) - def get_node(self, name) -> dns.node.Node: + def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]: """Return the node at *name*, if any. Returns an immutable node or ``None``. """ return _ensure_immutable_node(self._get_node(name)) - def _check_read_only(self): + def _check_read_only(self) -> None: if self.read_only: raise ReadOnly - def add(self, *args): + def add(self, *args) -> None: """Add records. The arguments may be: @@ -157,9 +157,9 @@ class Transaction: """ self._check_ended() self._check_read_only() - return self._add(False, args) + self._add(False, args) - def replace(self, *args): + def replace(self, *args) -> None: """Replace the existing rdataset at the name with the specified rdataset, or add the specified rdataset if there was no existing rdataset. @@ -178,9 +178,9 @@ class Transaction: """ self._check_ended() self._check_read_only() - return self._add(True, args) + self._add(True, args) - def delete(self, *args): + def delete(self, *args) -> None: """Delete records. It is not an error if some of the records are not in the existing @@ -200,9 +200,9 @@ class Transaction: """ self._check_ended() self._check_read_only() - return self._delete(False, args) + self._delete(False, args) - def delete_exact(self, *args): + def delete_exact(self, *args) -> None: """Delete records. The arguments may be: @@ -223,7 +223,7 @@ class Transaction: """ self._check_ended() self._check_read_only() - return self._delete(True, args) + self._delete(True, args) def name_exists(self, name: Union[dns.name.Name, str]) -> bool: """Does the specified name exist?""" @@ -232,7 +232,7 @@ class Transaction: name = dns.name.from_text(name, None) return self._name_exists(name) - def update_serial(self, value=1, relative=True, name=dns.name.empty): + def update_serial(self, value: int=1, relative: bool=True, name: dns.name.Name=dns.name.empty) -> None: """Update the serial number. *value*, an `int`, is an increment if *relative* is `True`, or the @@ -279,7 +279,7 @@ class Transaction: self._check_ended() return self._changed() - def commit(self): + def commit(self) -> None: """Commit the transaction. Normally transactions are used as context managers and commit @@ -292,7 +292,7 @@ class Transaction: """ self._end(True) - def rollback(self): + def rollback(self) -> None: """Rollback the transaction. Normally transactions are used as context managers and commit @@ -304,7 +304,7 @@ class Transaction: """ self._end(False) - def check_put_rdataset(self, check: CheckPutRdatasetType): + def check_put_rdataset(self, check: CheckPutRdatasetType) -> None: """Call *check* before putting (storing) an rdataset. The function is called with the transaction, the name, and the rdataset. @@ -316,7 +316,7 @@ class Transaction: """ self._check_put_rdataset.append(check) - def check_delete_rdataset(self, check: CheckDeleteRdatasetType): + def check_delete_rdataset(self, check: CheckDeleteRdatasetType) -> None: """Call *check* before deleting an rdataset. The function is called with the transaction, the name, the rdatatype, diff --git a/dns/update.py b/dns/update.py index 5df0cc78..9e9b113b 100644 --- a/dns/update.py +++ b/dns/update.py @@ -17,13 +17,14 @@ """DNS Dynamic Update Support""" -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import dns.message import dns.name import dns.opcode import dns.rdata import dns.rdataclass +import dns.rdatatype import dns.rdataset import dns.tsig @@ -48,7 +49,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] def __init__(self, zone: Optional[Union[dns.name.Name, str]]=None, rdclass=dns.rdataclass.IN, keyring: Optional[Any]=None, keyname: Optional[dns.name.Name]=None, - keyalgorithm=dns.tsig.default_algorithm, + keyalgorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm, id: Optional[int]=None): """Initialize a new DNS Update object. @@ -79,7 +80,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] self.use_tsig(keyring, keyname, algorithm=keyalgorithm) @property - def zone(self): + def zone(self) -> List[dns.rrset.RRset]: """The zone section.""" return self.sections[0] @@ -88,7 +89,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] self.sections[0] = v @property - def prerequisite(self): + def prerequisite(self) -> List[dns.rrset.RRset]: """The prerequisite section.""" return self.sections[1] @@ -97,7 +98,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] self.sections[1] = v @property - def update(self): + def update(self) -> List[dns.rrset.RRset]: """The update section.""" return self.sections[2] @@ -156,7 +157,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] self.origin) self._add_rr(name, ttl, rd, section=section) - def add(self, name: Union[dns.name.Name, str], *args): + def add(self, name: Union[dns.name.Name, str], *args) -> None: """Add records. The first argument is always a name. The other @@ -171,7 +172,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] self._add(False, self.update, name, *args) - def delete(self, name: Union[dns.name.Name, str], *args): + def delete(self, name: Union[dns.name.Name, str], *args) -> None: """Delete records. The first argument is always a name. The other @@ -215,7 +216,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] self.origin) self._add_rr(name, 0, rd, dns.rdataclass.NONE) - def replace(self, name: Union[dns.name.Name, str], *args): + def replace(self, name: Union[dns.name.Name, str], *args) -> None: """Replace records. The first argument is always a name. The other @@ -233,7 +234,7 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] self._add(True, self.update, name, *args) - def present(self, name: Union[dns.name.Name, str], *args): + def present(self, name: Union[dns.name.Name, str], *args) -> None: """Require that an owner name (and optionally an rdata type, or specific rdataset) exists as a prerequisite to the execution of the update. @@ -272,7 +273,8 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] dns.rdatatype.NONE, None, True, True) - def absent(self, name: Union[dns.name.Name, str], rdtype=None): + def absent(self, name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str]=None) -> None: """Require that an owner name (and optionally an rdata type) does not exist as a prerequisite to the execution of the update.""" @@ -284,9 +286,9 @@ class UpdateMessage(dns.message.Message): # lgtm[py/missing-equals] dns.rdatatype.NONE, None, True, True) else: - rdtype = dns.rdatatype.RdataType.make(rdtype) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) self.find_rrset(self.prerequisite, name, - dns.rdataclass.NONE, rdtype, + dns.rdataclass.NONE, the_rdtype, dns.rdatatype.NONE, None, True, True) diff --git a/dns/versioned.py b/dns/versioned.py index 02316c82..9ed9cef6 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -13,7 +13,9 @@ except ImportError: # pragma: no cover import dns.exception import dns.immutable import dns.name +import dns.node import dns.rdataclass +import dns.rdataset import dns.rdatatype import dns.rdtypes.ANY.SOA import dns.zone @@ -40,7 +42,8 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] node_factory = Node - def __init__(self, origin: Optional[Union[dns.name.Name, str]], rdclass=dns.rdataclass.IN, relativize=True, + def __init__(self, origin: Optional[Union[dns.name.Name, str]], + rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, relativize: bool=True, pruning_policy: Optional[Callable[['Zone', Version], Optional[bool]]]=None): """Initialize a versioned zone object. @@ -106,7 +109,7 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] self._readers.add(txn) return txn - def writer(self, replacement=False) -> Transaction: + def writer(self, replacement: bool=False) -> Transaction: event = None while True: with self._version_lock: @@ -181,7 +184,7 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] self._pruning_policy(self, self._versions[0]): self._versions.popleft() - def set_max_versions(self, max_versions: Optional[int]): + def set_max_versions(self, max_versions: Optional[int]) -> None: """Set a pruning policy that retains up to the specified number of versions """ @@ -195,7 +198,7 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] return len(zone._versions) > max_versions self.set_pruning_policy(policy) - def set_pruning_policy(self, policy: Optional[Callable[['Zone', Version], Optional[bool]]]): + def set_pruning_policy(self, policy: Optional[Callable[['Zone', Version], Optional[bool]]]) -> None: """Set the pruning policy for the zone. The *policy* function takes a `Version` and returns `True` if @@ -248,30 +251,39 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] id = 1 return id - def find_node(self, name, create=False): + def find_node(self, name: Union[dns.name.Name, str], create: bool=False) -> dns.node.Node: if create: raise UseTransaction return super().find_node(name) - def delete_node(self, name): + def delete_node(self, name: Union[dns.name.Name, str]) -> None: raise UseTransaction - def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, - create=False): + def find_rdataset(self, name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE, + create: bool=False) -> dns.rdataset.Rdataset: if create: raise UseTransaction rdataset = super().find_rdataset(name, rdtype, covers) return dns.rdataset.ImmutableRdataset(rdataset) - def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, - create=False): + def get_rdataset(self, name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE, + create: bool=False) -> Optional[dns.rdataset.Rdataset]: if create: raise UseTransaction rdataset = super().get_rdataset(name, rdtype, covers) - return dns.rdataset.ImmutableRdataset(rdataset) + if rdataset is not None: + return dns.rdataset.ImmutableRdataset(rdataset) + else: + return None - def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE): + def delete_rdataset(self, name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> None: raise UseTransaction - def replace_rdataset(self, name, replacement): + def replace_rdataset(self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset) -> None: raise UseTransaction diff --git a/dns/wire.py b/dns/wire.py index d3317a59..87814eea 100644 --- a/dns/wire.py +++ b/dns/wire.py @@ -9,7 +9,7 @@ import dns.exception import dns.name class Parser: - def __init__(self, wire: bytes, current=0): + def __init__(self, wire: bytes, current: int=0): self.wire = wire self.current = 0 self.end = len(self.wire) @@ -17,10 +17,10 @@ class Parser: self.seek(current) self.furthest = current - def remaining(self): + def remaining(self) -> int: return self.end - self.current - def get_bytes(self, size=int) -> bytes: + def get_bytes(self, size: int) -> bytes: assert size >= 0 if size > self.remaining(): raise dns.exception.FormError @@ -29,7 +29,7 @@ class Parser: self.furthest = max(self.furthest, self.current) return output - def get_counted_bytes(self, length_size=1) -> bytes: + def get_counted_bytes(self, length_size: int=1) -> bytes: length = int.from_bytes(self.get_bytes(length_size), 'big') return self.get_bytes(length) @@ -57,7 +57,7 @@ class Parser: name = name.relativize(origin) return name - def seek(self, where: int): + def seek(self, where: int) -> None: # Note that seeking to the end is OK! (If you try to read # after such a seek, you'll get an exception as expected.) if where < 0 or where > self.end: diff --git a/dns/xfr.py b/dns/xfr.py index 618eac2f..a360deba 100644 --- a/dns/xfr.py +++ b/dns/xfr.py @@ -15,7 +15,7 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union import dns.exception import dns.message @@ -51,8 +51,9 @@ class Inbound: State machine for zone transfers. """ - def __init__(self, txn_manager: dns.transaction.TransactionManager, rdtype=dns.rdatatype.AXFR, - serial: Optional[int]=None, is_udp=False): + def __init__(self, txn_manager: dns.transaction.TransactionManager, + rdtype: dns.rdatatype.RdataType=dns.rdatatype.AXFR, + serial: Optional[int]=None, is_udp: bool=False): """Initialize an inbound zone transfer. *txn_manager* is a :py:class:`dns.transaction.TransactionManager`. @@ -245,10 +246,11 @@ class Inbound: def make_query(txn_manager: dns.transaction.TransactionManager, serial: Optional[int]=0, - use_edns=None, ednsflags: Optional[int]=None, payload: Optional[int]=None, + use_edns: Optional[Union[int, bool]]=None, ednsflags: Optional[int]=None, payload: Optional[int]=None, request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None, keyring: Any=None, keyname: Optional[dns.name.Name]=None, - keyalgorithm=dns.tsig.default_algorithm) -> Tuple[dns.message.QueryMessage, Optional[int]]: + keyalgorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) \ + -> Tuple[dns.message.QueryMessage, Optional[int]]: """Make an AXFR or IXFR query. *txn_manager* is a ``dns.transaction.TransactionManager``, typically a diff --git a/dns/zone.py b/dns/zone.py index a1fe07a9..91fb6970 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -100,7 +100,7 @@ class Zone(dns.transaction.TransactionManager): __slots__ = ['rdclass', 'origin', 'nodes', 'relativize'] def __init__(self, origin: Optional[Union[dns.name.Name, str]], - rdclass=dns.rdataclass.IN, relativize=True): + rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, relativize: bool=True): """Initialize a zone object. *origin* is the origin of the zone. It may be a ``dns.name.Name``, @@ -204,7 +204,7 @@ class Zone(dns.transaction.TransactionManager): key = self._validate_name(key) return key in self.nodes - def find_node(self, name: Union[dns.name.Name, str], create=False): + def find_node(self, name: Union[dns.name.Name, str], create: bool=False) -> dns.node.Node: """Find a node in the zone, possibly creating it. *name*: the name of the node to find. @@ -230,7 +230,7 @@ class Zone(dns.transaction.TransactionManager): self.nodes[name] = node return node - def get_node(self, name: Union[dns.name.Name, str], create=False): + def get_node(self, name: Union[dns.name.Name, str], create: bool=False) -> Optional[dns.node.Node]: """Get a node in the zone, possibly creating it. This method is like ``find_node()``, except it returns None instead @@ -257,7 +257,7 @@ class Zone(dns.transaction.TransactionManager): node = None return node - def delete_node(self, name: Union[dns.name.Name, str]): + def delete_node(self, name: Union[dns.name.Name, str]) -> None: """Delete the specified node if it exists. *name*: the name of the node to find. @@ -275,7 +275,7 @@ class Zone(dns.transaction.TransactionManager): def find_rdataset(self, name: Union[dns.name.Name, str], rdtype: Union[dns.rdatatype.RdataType, str], covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE, - create=False) -> dns.rdataset.Rdataset: + create: bool=False) -> dns.rdataset.Rdataset: """Look for an rdataset with the specified name and type in the zone, and return an rdataset encapsulating it. @@ -310,14 +310,16 @@ class Zone(dns.transaction.TransactionManager): Returns a ``dns.rdataset.Rdataset``. """ - name = self._validate_name(name) - rdtype = dns.rdatatype.RdataType.make(rdtype) - covers = dns.rdatatype.RdataType.make(covers) - node = self.find_node(name, create) - return node.find_rdataset(self.rdclass, rdtype, covers, create) + the_name = self._validate_name(name) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + the_covers = dns.rdatatype.RdataType.make(covers) + node = self.find_node(the_name, create) + return node.find_rdataset(self.rdclass, the_rdtype, the_covers, create) - def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, - create=False) -> Optional[dns.rdataset.Rdataset]: + def get_rdataset(self, name: Union[dns.name.Name, str], + rdtype: Union[dns.rdatatype.RdataType, str], + covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE, + create: bool=False) -> Optional[dns.rdataset.Rdataset]: """Look for an rdataset with the specified name and type in the zone. This method is like ``find_rdataset()``, except it returns None instead @@ -361,7 +363,7 @@ class Zone(dns.transaction.TransactionManager): def delete_rdataset(self, name: Union[dns.name.Name, str], rdtype: Union[dns.rdatatype.RdataType, str], - covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE): + covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> None: """Delete the rdataset matching *rdtype* and *covers*, if it exists at the node specified by *name*. @@ -389,17 +391,17 @@ class Zone(dns.transaction.TransactionManager): RRSIG rdataset. """ - name = self._validate_name(name) - rdtype = dns.rdatatype.RdataType.make(rdtype) - covers = dns.rdatatype.RdataType.make(covers) - node = self.get_node(name) + the_name = self._validate_name(name) + the_rdtype = dns.rdatatype.RdataType.make(rdtype) + the_covers = dns.rdatatype.RdataType.make(covers) + node = self.get_node(the_name) if node is not None: - node.delete_rdataset(self.rdclass, rdtype, covers) + node.delete_rdataset(self.rdclass, the_rdtype, the_covers) if len(node) == 0: - self.delete_node(name) + self.delete_node(the_name) def replace_rdataset(self, name: Union[dns.name.Name, str], - replacement: dns.rdataset.Rdataset): + replacement: dns.rdataset.Rdataset) -> None: """Replace an rdataset at name. It is not an error if there is no rdataset matching I{replacement}. @@ -575,8 +577,8 @@ class Zone(dns.transaction.TransactionManager): for rdata in rds: yield (name, rds.ttl, rdata) - def to_file(self, f: Any, sorted=True, relativize=True, nl: Optional[str]=None, - want_comments=False, want_origin=False): + def to_file(self, f: Any, sorted: bool=True, relativize: bool=True, nl: Optional[str]=None, + want_comments: bool=False, want_origin: bool=False): """Write a zone to a file. *f*, a file or `str`. If *f* is a string, it is treated @@ -653,8 +655,8 @@ class Zone(dns.transaction.TransactionManager): f.write(l) f.write(nl) - def to_text(self, sorted=True, relativize=True, nl: Optional[str]=None, - want_comments=False, want_origin=False): + def to_text(self, sorted: bool=True, relativize: bool=True, nl: Optional[str]=None, + want_comments: bool=False, want_origin: bool=False): """Return a zone's text as though it were written to a file. *sorted*, a ``bool``. If True, the default, then the file @@ -687,7 +689,7 @@ class Zone(dns.transaction.TransactionManager): temp_buffer.close() return return_value - def check_origin(self): + def check_origin(self) -> None: """Do some simple checking of the zone's origin. Raises ``dns.zone.NoSOA`` if there is no SOA RRset. @@ -699,6 +701,7 @@ class Zone(dns.transaction.TransactionManager): if self.relativize: name = dns.name.empty else: + assert self.origin is not None name = self.origin if self.get_rdataset(name, dns.rdatatype.SOA) is None: raise NoSOA @@ -758,7 +761,8 @@ class Zone(dns.transaction.TransactionManager): hasher.update(rrnamebuf + rrfixed + rrlen + rdata) return hasher.digest() - def compute_digest(self, hash_algorithm: DigestHashAlgorithm, scheme=DigestScheme.SIMPLE) -> dns.rdtypes.ANY.ZONEMD.ZONEMD: + def compute_digest(self, hash_algorithm: DigestHashAlgorithm, + scheme=DigestScheme.SIMPLE) -> dns.rdtypes.ANY.ZONEMD.ZONEMD: serial = self.get_soa().serial digest = self._compute_digest(hash_algorithm, scheme) return dns.rdtypes.ANY.ZONEMD.ZONEMD(self.rdclass, @@ -766,11 +770,12 @@ class Zone(dns.transaction.TransactionManager): serial, scheme, hash_algorithm, digest) - def verify_digest(self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD]=None): + def verify_digest(self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD]=None) -> None: digests: Union[dns.rdataset.Rdataset, List[dns.rdtypes.ANY.ZONEMD.ZONEMD]] if zonemd: digests = [zonemd] else: + assert self.origin is not None rds = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD) if rds is None: raise NoDigest @@ -791,7 +796,7 @@ class Zone(dns.transaction.TransactionManager): return Transaction(self, False, Version(self, 1, self.nodes, self.origin)) - def writer(self, replacement=False) -> 'Transaction': + def writer(self, replacement: bool=False) -> 'Transaction': txn = Transaction(self, replacement) txn._setup_version() return txn @@ -852,25 +857,28 @@ class ImmutableVersionedNode(VersionedNode): [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] ) - def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def find_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, + create: bool=False) -> dns.rdataset.Rdataset: if create: raise TypeError("immutable") return super().find_rdataset(rdclass, rdtype, covers, False) - def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, - create=False): + def get_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE, + create: bool=False) -> Optional[dns.rdataset.Rdataset]: if create: raise TypeError("immutable") return super().get_rdataset(rdclass, rdtype, covers, False) - def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + def delete_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType, + covers: dns.rdatatype.RdataType=dns.rdatatype.NONE) -> None: raise TypeError("immutable") - def replace_rdataset(self, replacement): + def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None: raise TypeError("immutable") - def is_immutable(self): + def is_immutable(self) -> bool: return True @@ -920,7 +928,7 @@ class Version: class WritableVersion(Version): - def __init__(self, zone: Zone, replacement=False): + def __init__(self, zone: Zone, replacement: bool=False): # The zone._versions_lock must be held by our caller in a versioned # zone. id = zone._get_next_version_id() @@ -958,18 +966,18 @@ class WritableVersion(Version): else: return node - def delete_node(self, name: dns.name.Name): + def delete_node(self, name: dns.name.Name) -> None: name = self._validate_name(name) if name in self.nodes: del self.nodes[name] self.changed.add(name) - def put_rdataset(self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset): + def put_rdataset(self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset) -> None: node = self._maybe_cow(name) node.replace_rdataset(rdataset) def delete_rdataset(self, name: dns.name.Name, rdtype:dns.rdatatype.RdataType, - covers: dns.rdatatype.RdataType): + covers: dns.rdatatype.RdataType) -> None: node = self._maybe_cow(name) node.delete_rdataset(self.zone.rdclass, rdtype, covers) if len(node) == 0: @@ -1077,9 +1085,9 @@ class Transaction(dns.transaction.Transaction): def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None, - rdclass=dns.rdataclass.IN, - relativize=True, zone_factory=Zone, filename: Optional[str]=None, - allow_include=False, check_origin=True, + rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, + relativize: bool=True, zone_factory: Any=Zone, filename: Optional[str]=None, + allow_include: bool=False, check_origin: bool=True, idna_codec: Optional[dns.name.IDNACodec]=None) -> Zone: """Build a zone object from a zone file format string. @@ -1145,9 +1153,9 @@ def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None, def from_file(f: Any, origin: Optional[Union[dns.name.Name, str]]=None, - rdclass=dns.rdataclass.IN, - relativize=True, zone_factory=Zone, filename: Optional[str]=None, - allow_include=True, check_origin=True) -> Zone: + rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, + relativize: bool=True, zone_factory: Any=Zone, filename: Optional[str]=None, + allow_include: bool=True, check_origin: bool=True) -> Zone: """Read a zone file and build a zone object. *f*, a file or ``str``. If *f* is a string, it is treated @@ -1200,7 +1208,7 @@ def from_file(f: Any, origin: Optional[Union[dns.name.Name, str]]=None, assert False # make mypy happy lgtm[py/unreachable-statement] -def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True): +def from_xfr(xfr: Any, zone_factory=Zone, relativize: bool=True, check_origin: bool=True) -> Zone: """Convert the output of a zone transfer generator into a zone object. *xfr*, a generator of ``dns.message.Message`` objects, typically @@ -1221,6 +1229,8 @@ def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True): Raises ``KeyError`` if there is no origin node. + Raises ``ValueError`` if no messages are yielded by the generator. + Returns a subclass of ``dns.zone.Zone``. """ @@ -1243,6 +1253,8 @@ def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True): zrds.update_ttl(rrset.ttl) for rd in rrset: zrds.add(rd) + if z is None: + raise ValueError('empty transfer') if check_origin: z.check_origin() return z diff --git a/dns/zonefile.py b/dns/zonefile.py index 605131dc..479f0d63 100644 --- a/dns/zonefile.py +++ b/dns/zonefile.py @@ -66,7 +66,7 @@ def _check_cname_and_other_data(txn, name, rdataset): SavedStateType = Tuple[dns.tokenizer.Tokenizer, Optional[dns.name.Name], # current_origin Optional[dns.name.Name], # last_name - Optional[str], # current_file + Optional[Any], # current_file int, # last_ttl bool, # last_ttl_known int, # default_ttl @@ -78,8 +78,8 @@ class Reader: """Read a DNS zone file into a transaction.""" def __init__(self, tok: dns.tokenizer.Tokenizer, rdclass: dns.rdataclass.RdataClass, - txn: dns.transaction.Transaction, allow_include=False, - allow_directives=True, force_name: Optional[dns.name.Name]=None, + txn: dns.transaction.Transaction, allow_include: bool=False, + allow_directives: bool=True, force_name: Optional[dns.name.Name]=None, force_ttl: Optional[int]=None, force_rdclass: Optional[dns.rdataclass.RdataClass]=None, force_rdtype: Optional[dns.rdatatype.RdataType]=None, @@ -102,7 +102,7 @@ class Reader: self.zone_rdclass = rdclass self.txn = txn self.saved_state: List[SavedStateType] = [] - self.current_file = None + self.current_file: Optional[Any] = None self.allow_include = allow_include self.allow_directives = allow_directives self.force_name = force_name @@ -385,7 +385,7 @@ class Reader: self.txn.add(name, ttl, rd) - def read(self): + def read(self) -> None: """Read a DNS zone file and build a zone object. @raises dns.zone.NoSOA: No SOA RR was found at the zone origin @@ -433,11 +433,9 @@ class Reader: token = self.tok.get() filename = token.value token = self.tok.get() + new_origin: Optional[dns.name.Name] if token.is_identifier(): - new_origin =\ - dns.name.from_text(token.value, - self.current_origin, - self.tok.idna_codec) + new_origin = dns.name.from_text(token.value, self.current_origin, self.tok.idna_codec) self.tok.get_eol() elif not token.is_eol_or_eof(): raise dns.exception.SyntaxError( @@ -572,7 +570,7 @@ def read_rrsets(text: Any, default_ttl: Optional[Union[int, str]]=None, idna_codec: Optional[dns.name.IDNACodec]=None, origin: Optional[Union[dns.name.Name, str]]=dns.name.root, - relativize=False) -> List[dns.rrset.RRset]: + relativize: bool=False) -> List[dns.rrset.RRset]: """Read one or more rrsets from the specified text, possibly subject to restrictions. -- 2.47.3