]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Typing pass number 2, featuring typing of bools, adding a return type
authorBob Halley <halley@dnspython.org>
Wed, 9 Mar 2022 15:53:53 +0000 (07:53 -0800)
committerBob Halley <halley@dnspython.org>
Thu, 10 Mar 2022 16:41:20 +0000 (08:41 -0800)
of "-> None" to procedures, and various fixes for omissions, errors,
and new issues discovered by type checking previously unchecked things.

29 files changed:
dns/asyncquery.py
dns/asyncresolver.py
dns/dnssec.py
dns/e164.py
dns/edns.py
dns/entropy.py
dns/immutable.py
dns/ipv6.py
dns/message.py
dns/name.py
dns/node.py
dns/query.py
dns/rcode.py
dns/rdata.py
dns/rdataclass.py
dns/rdataset.py
dns/rdatatype.py
dns/resolver.py
dns/reversename.py
dns/rrset.py
dns/serial.py
dns/tokenizer.py
dns/transaction.py
dns/update.py
dns/versioned.py
dns/wire.py
dns/xfr.py
dns/zone.py
dns/zonefile.py

index 950624a14d28b07c48a382d8a3412f361aea3fe0..c785764dac9c2c49fabd11fd0a6b7548fefe67f4 100644 (file)
@@ -97,9 +97,9 @@ async def send_udp(sock: dns.asyncbackend.DatagramSocket,
 
 async def receive_udp(sock: dns.asyncbackend.DatagramSocket,
                       destination: Optional[Any]=None, expiration: Optional[float]=None,
-                      ignore_unexpected=False, one_rr_per_rrset=False,
+                      ignore_unexpected: bool=False, one_rr_per_rrset: bool=False,
                       keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'',
-                      ignore_trailing=False, raise_on_truncation=False) -> Any:
+                      ignore_trailing: bool=False, raise_on_truncation: bool=False) -> Any:
     """Read a DNS message from a UDP socket.
 
     *sock*, a ``dns.asyncbackend.DatagramSocket``.
@@ -121,10 +121,10 @@ async def receive_udp(sock: dns.asyncbackend.DatagramSocket,
                               raise_on_truncation=raise_on_truncation)
     return (r, received_time, from_address)
 
-async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
-              source: Optional[str]=None, source_port=0,
-              ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
-              raise_on_truncation=False, sock: Optional[dns.asyncbackend.DatagramSocket]=None,
+async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53,
+              source: Optional[str]=None, source_port: int=0,
+              ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False,
+              raise_on_truncation: bool=False, sock: Optional[dns.asyncbackend.DatagramSocket]=None,
               backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message:
     """Return the response obtained after sending a query via UDP.
 
@@ -174,9 +174,9 @@ async def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None,
         if not sock and s:
             await s.close()
 
-async def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
-                            source: Optional[str]=None, source_port=0,
-                            ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
+async def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53,
+                            source: Optional[str]=None, source_port: int=0,
+                            ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False,
                             udp_sock: Optional[dns.asyncbackend.DatagramSocket]=None,
                             tcp_sock: Optional[dns.asyncbackend.StreamSocket]=None,
                             backend: Optional[dns.asyncbackend.Backend]=None) -> Tuple[dns.message.Message, bool]:
@@ -252,9 +252,9 @@ async def _read_exactly(sock, count, expiration):
 
 
 async def receive_tcp(sock: dns.asyncbackend.StreamSocket,
-                      expiration: Optional[float]=None, one_rr_per_rrset=False,
+                      expiration: Optional[float]=None, one_rr_per_rrset: bool=False,
                       keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None,
-                      request_mac=b'', ignore_trailing=False) -> Tuple[dns.message.Message, float]:
+                      request_mac=b'', ignore_trailing: bool=False) -> Tuple[dns.message.Message, float]:
     """Read a DNS message from a TCP socket.
 
     *sock*, a ``dns.asyncbackend.StreamSocket``.
@@ -273,9 +273,9 @@ async def receive_tcp(sock: dns.asyncbackend.StreamSocket,
     return (r, received_time)
 
 
-async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
-              source: Optional[str]=None, source_port=0,
-              one_rr_per_rrset=False, ignore_trailing=False,
+async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53,
+              source: Optional[str]=None, source_port: int=0,
+              one_rr_per_rrset: bool=False, ignore_trailing: bool=False,
               sock: Optional[dns.asyncbackend.StreamSocket]=None,
               backend: Optional[dns.asyncbackend.Backend]=None) -> dns.message.Message:
     """Return the response obtained after sending a query via TCP.
@@ -328,8 +328,8 @@ async def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None,
             await s.close()
 
 async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None,
-              port=853, source: Optional[str]=None, source_port=0,
-              one_rr_per_rrset=False, ignore_trailing=False,
+              port: int=853, source: Optional[str]=None, source_port: int=0,
+              one_rr_per_rrset: bool=False, ignore_trailing: bool=False,
               sock: Optional[dns.asyncbackend.StreamSocket]=None,
               backend: Optional[dns.asyncbackend.Backend]=None,
               ssl_context: Optional[ssl.SSLContext]=None,
@@ -383,10 +383,10 @@ async def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None,
             await s.close()
 
 async def https(q: dns.message.Message, where: str, timeout: Optional[float]=None,
-                port=443, source: Optional[str]=None, source_port=0,
-                one_rr_per_rrset=False, ignore_trailing=False,
+                port: int=443, source: Optional[str]=None, source_port: int=0,
+                one_rr_per_rrset: bool=False, ignore_trailing: bool=False,
                 client: Optional[httpx.AsyncClient]=None,
-                path='/dns-query', post=True, verify=True):
+                path: str='/dns-query', post: bool=True, verify: bool=True) -> dns.message.Message:
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
     *client*, a ``httpx.AsyncClient``.  If provided, the client to use for
@@ -466,9 +466,9 @@ async def https(q: dns.message.Message, where: str, timeout: Optional[float]=Non
 
 async def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager,
                       query: Optional[dns.message.Message]=None,
-                      port=53, timeout: Optional[float]=None, lifetime: Optional[float]=None,
-                      source: Optional[str]=None, source_port=0, udp_mode=UDPMode.NEVER,
-                      backend: Optional[dns.asyncbackend.Backend]=None):
+                      port: int=53, timeout: Optional[float]=None, lifetime: Optional[float]=None,
+                      source: Optional[str]=None, source_port: int=0, udp_mode=UDPMode.NEVER,
+                      backend: Optional[dns.asyncbackend.Backend]=None) -> None:
     """Conduct an inbound transfer and apply it via a transaction from the
     txn_manager.
 
index 72ef0412c55bc84a7eaab1a6acaa51682d2d4b73..152b1a6387f0c5aeaec4055f2672067aaabb6db4 100644 (file)
@@ -26,6 +26,8 @@ import dns.asyncquery
 import dns.exception
 import dns.name
 import dns.query
+import dns.rdataclass
+import dns.rdatatype
 import dns.resolver  # lgtm[py/import-and-import-from]
 
 # import some resolver symbols for brevity
@@ -41,10 +43,10 @@ class Resolver(dns.resolver.BaseResolver):
     """Asynchronous DNS stub resolver."""
 
     async def resolve(self, qname: Union[dns.name.Name, str],
-                      rdtype=dns.rdatatype.A,
-                      rdclass=dns.rdataclass.IN,
-                      tcp=False, source: Optional[str]=None,
-                      raise_on_no_answer=True, source_port=0,
+                      rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
+                      rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+                      tcp: bool=False, source: Optional[str]=None,
+                      raise_on_no_answer: bool=True, source_port: int=0,
                       lifetime: Optional[float]=None, search: Optional[bool]=None,
                       backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer:
         """Query nameservers asynchronously to find the answer to the question.
@@ -167,7 +169,7 @@ def get_default_resolver() -> Resolver:
     return default_resolver
 
 
-def reset_default_resolver():
+def reset_default_resolver() -> None:
     """Re-initialize default asynchronous resolver.
 
     Note that the resolver configuration (i.e. /etc/resolv.conf on UNIX
@@ -179,10 +181,10 @@ def reset_default_resolver():
 
 
 async def resolve(qname: Union[dns.name.Name, str],
-                  rdtype=dns.rdatatype.A,
-                  rdclass=dns.rdataclass.IN,
-                  tcp=False, source: Optional[str]=None,
-                  raise_on_no_answer=True, source_port=0,
+                  rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
+                  rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+                  tcp: bool=False, source: Optional[str]=None,
+                  raise_on_no_answer: bool=True, source_port: int=0,
                   lifetime: Optional[float]=None, search: Optional[bool]=None,
                   backend: Optional[dns.asyncbackend.Backend]=None) -> dns.resolver.Answer:
     """Query nameservers asynchronously to find the answer to the question.
@@ -218,8 +220,9 @@ async def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
 
     return await get_default_resolver().canonical_name(name)
 
-async def zone_for_name(name: Union[dns.name.Name, str], rdclass=dns.rdataclass.IN,
-                        tcp=False, resolver: Optional[Resolver]=None,
+async def zone_for_name(name: Union[dns.name.Name, str],
+                        rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN,
+                        tcp: bool=False, resolver: Optional[Resolver]=None,
                         backend: Optional[dns.asyncbackend.Backend]=None) -> dns.name.Name:
     """Find the name of the zone which contains the specified name.
 
index 810d12de8640aae37aa8e80a744ff88e019c3128..598734b9e7da3f35485b9849b5910fc62e177763 100644 (file)
@@ -221,7 +221,7 @@ def _bytes_to_long(b: bytes) -> int:
     return int.from_bytes(b, 'big')
 
 
-def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any):
+def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any) -> None:
     keyptr: bytes
     if _is_rsa(key.algorithm):
         # we ignore because mypy is confused and thinks key.key is a str for unknown reasons.
@@ -304,7 +304,7 @@ def _validate_signature(sig: bytes, data: bytes, key: DNSKEY, chosen_hash: Any):
 def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
                     rrsig: RRSIG,
                     keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]],
-                    origin: Optional[dns.name.Name]=None, now: Optional[float]=None):
+                    origin: Optional[dns.name.Name]=None, now: Optional[float]=None) -> None:
     """Validate an RRset against a single signature rdata, throwing an
     exception if validation is not successful.
 
@@ -416,7 +416,7 @@ def _validate_rrsig(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdata
 def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
               rrsigset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rdataset]],
               keys: Dict[dns.name.Name, Union[dns.node.Node, dns.rdataset.Rdataset]],
-              origin: Optional[dns.name.Name]=None, now: Optional[float]=None):
+              origin: Optional[dns.name.Name]=None, now: Optional[float]=None) -> None:
     """Validate an RRset against a signature RRset, throwing an exception
     if none of the signatures validate.
 
@@ -476,7 +476,8 @@ def _validate(rrset: Union[dns.rrset.RRset, Tuple[dns.name.Name, dns.rdataset.Rd
     raise ValidationFailure("no RRSIGs validated")
 
 
-def nsec3_hash(domain, salt, iterations, algorithm):
+def nsec3_hash(domain: Union[dns.name.Name, str], salt: Optional[Union[str, bytes]],
+               iterations: int, algorithm: Union[int, str]) -> str:
     """
     Calculate the NSEC3 hash, according to
     https://tools.ietf.org/html/rfc5155#section-5
@@ -507,7 +508,6 @@ def nsec3_hash(domain, salt, iterations, algorithm):
     if algorithm != NSEC3Hash.SHA1:
         raise ValueError("Wrong hash algorithm (only SHA1 is supported)")
 
-    salt_encoded = salt
     if salt is None:
         salt_encoded = b''
     elif isinstance(salt, str):
@@ -515,10 +515,13 @@ def nsec3_hash(domain, salt, iterations, algorithm):
             salt_encoded = bytes.fromhex(salt)
         else:
             raise ValueError("Invalid salt length")
+    else:
+        salt_encoded = salt
 
     if not isinstance(domain, dns.name.Name):
         domain = dns.name.from_text(domain)
     domain_encoded = domain.canonicalize().to_wire()
+    assert domain_encoded is not None
 
     digest = hashlib.sha1(domain_encoded + salt_encoded).digest()
     for _ in range(iterations):
index 8c9a3ac58bd3031fbbb438447916a749e5150360..6e34ae5db87e729e4958a882aa11b5a8b9cdc4c6 100644 (file)
@@ -48,7 +48,7 @@ def from_e164(text: str, origin: Optional[dns.name.Name]=public_enum_domain) ->
 
 
 def to_e164(name: dns.name.Name, origin: Optional[dns.name.Name]=public_enum_domain,
-            want_plus_prefix=True) -> str:
+            want_plus_prefix: bool=True) -> str:
     """Convert an ENUM domain name into an E.164 number.
 
     Note that dnspython does not have any information about preferred
index 15c646de1deffec41793c90bc1d35bc107fc7a16..b47b6d24cacaf0f79a37ffc09b951af57255883c 100644 (file)
@@ -228,7 +228,7 @@ class ECSOption(Option):  # lgtm[py/missing-equals]
                                            self.scopelen)
 
     @staticmethod
-    def from_text(text) -> Option:
+    def from_text(text: str) -> Option:
         """Convert a string into a `dns.edns.ECSOption`
 
         *text*, a `str`, the text form of the option.
@@ -264,25 +264,25 @@ class ECSOption(Option):  # lgtm[py/missing-equals]
             raise ValueError('could not parse ECS from "{}"'.format(text))
         n_slashes = ecs_text.count('/')
         if n_slashes == 1:
-            address, srclen = ecs_text.split('/')
-            scope = 0
+            address, tsrclen = ecs_text.split('/')
+            tscope = '0'
         elif n_slashes == 2:
-            address, srclen, scope = ecs_text.split('/')
+            address, tsrclen, tscope = ecs_text.split('/')
         else:
             raise ValueError('could not parse ECS from "{}"'.format(text))
         try:
-            scope = int(scope)
+            scope = int(tscope)
         except ValueError:
             raise ValueError('invalid scope ' +
-                             '"{}": scope must be an integer'.format(scope))
+                             '"{}": scope must be an integer'.format(tscope))
         try:
-            srclen = int(srclen)
+            srclen = int(tsrclen)
         except ValueError:
             raise ValueError('invalid srclen ' +
-                             '"{}": srclen must be an integer'.format(srclen))
+                             '"{}": srclen must be an integer'.format(tsrclen))
         return ECSOption(address, srclen, scope)
 
-    def to_wire(self, file=None) -> Optional[bytes]:
+    def to_wire(self, file: Optional[Any]=None) -> Optional[bytes]:
         value = (struct.pack('!HBB', self.family, self.srclen, self.scopelen) +
                  self.addrdata)
         if file:
@@ -442,7 +442,7 @@ def option_from_wire(otype: Union[OptionType, str], wire: bytes, current: int, o
     with parser.restrict_to(olen):
         return option_from_wire_parser(otype, parser)
 
-def register_type(implementation: Any, otype: OptionType):
+def register_type(implementation: Any, otype: OptionType) -> None:
     """Register the implementation of an option type.
 
     *implementation*, a ``class``, is a subclass of ``dns.edns.Option``.
index b5d34971ef4d0f726f036223f5760f4b6224ec85..7da2e04ad816a3eed89c87ac2352f855254bf9f9 100644 (file)
@@ -34,7 +34,7 @@ class EntropyPool:
     # leaving this code doesn't hurt anything as the library code
     # is used if present.
 
-    def __init__(self, seed=None):
+    def __init__(self, seed: Optional[bytes]=None):
         self.pool_index = 0
         self.digest: Optional[bytearray] = None
         self.next_byte = 0
@@ -43,14 +43,14 @@ class EntropyPool:
         self.hash_len = 20
         self.pool = bytearray(b'\0' * self.hash_len)
         if seed is not None:
-            self._stir(bytearray(seed))
+            self._stir(seed)
             self.seeded = True
             self.seed_pid = os.getpid()
         else:
             self.seeded = False
             self.seed_pid = 0
 
-    def _stir(self, entropy):
+    def _stir(self, entropy: bytes) -> None:
         for c in entropy:
             if self.pool_index == self.hash_len:
                 self.pool_index = 0
@@ -58,11 +58,11 @@ class EntropyPool:
             self.pool[self.pool_index] ^= b
             self.pool_index += 1
 
-    def stir(self, entropy):
+    def stir(self, entropy: bytes) -> None:
         with self.lock:
             self._stir(entropy)
 
-    def _maybe_seed(self):
+    def _maybe_seed(self) -> None:
         if not self.seeded or self.seed_pid != os.getpid():
             try:
                 seed = os.urandom(16)
index 20da7d902f5ada69574c72f668808cd288519fcc..8a4262109b52a1175ef23c435c41a111b1edf179 100644 (file)
@@ -1,5 +1,7 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
+from typing import Any
+
 import collections.abc
 
 from dns._immutable_ctx import immutable
@@ -7,7 +9,7 @@ from dns._immutable_ctx import immutable
 
 @immutable
 class Dict(collections.abc.Mapping):  # lgtm[py/missing-equals]
-    def __init__(self, dictionary, no_copy=False):
+    def __init__(self, dictionary: Any, no_copy: bool=False):
         """Make an immutable dictionary from the specified dictionary.
 
         If *no_copy* is `True`, then *dictionary* will be wrapped instead
@@ -39,7 +41,7 @@ class Dict(collections.abc.Mapping):  # lgtm[py/missing-equals]
         return iter(self._odict)
 
 
-def constify(o):
+def constify(o: Any) -> Any:
     """
     Convert mutable types to immutable types.
     """
index 1d5bffde96f528fdbd3c99593dc2e0d8cc83629a..9e6e8b6a3d12d045b7b0498e72a39ab93b6f9eae 100644 (file)
@@ -98,7 +98,7 @@ _v4_ending = re.compile(br'(.*):(\d+\.\d+\.\d+\.\d+)$')
 _colon_colon_start = re.compile(br'::.*')
 _colon_colon_end = re.compile(br'.*::$')
 
-def inet_aton(text: Union[str, bytes], ignore_scope=False) -> bytes:
+def inet_aton(text: Union[str, bytes], ignore_scope: bool=False) -> bytes:
     """Convert an IPv6 address in text form to binary form.
 
     *text*, a ``str``, the IPv6 address in textual form.
@@ -190,7 +190,7 @@ def inet_aton(text: Union[str, bytes], ignore_scope=False) -> bytes:
 
 _mapped_prefix = b'\x00' * 10 + b'\xff\xff'
 
-def is_mapped(address):
+def is_mapped(address: bytes) -> bool:
     """Is the specified address a mapped IPv4 address?
 
     *address*, a ``bytes`` is an IPv6 address in binary form.
index 7c92cdaf8c34a1728bced497ce5a1927ff49692d..a375c7e9fbe71bd8eaa09d1cc063fdd17c9b090c 100644 (file)
@@ -195,8 +195,8 @@ class Message:
     def __str__(self):
         return self.to_text()
 
-    def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True,
-                **kw):
+    def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True,
+                **kw) -> str:
         """Convert the message to text.
 
         The *origin*, *relativize*, and any other keyword
@@ -327,8 +327,8 @@ class Message:
                    rdtype: dns.rdatatype.RdataType,
                    covers = dns.rdatatype.NONE,
                    deleting: Optional[dns.rdataclass.RdataClass]=None,
-                   create=False,
-                   force_unique=False) -> dns.rrset.RRset:
+                   create: bool=False,
+                   force_unique: bool=False) -> dns.rrset.RRset:
         """Find the RRset with the given attributes in the specified section.
 
         *section*, an ``int`` section number, or one of the section
@@ -394,10 +394,10 @@ class Message:
                   name: dns.name.Name,
                   rdclass: dns.rdataclass.RdataClass,
                   rdtype: dns.rdatatype.RdataType,
-                  covers = dns.rdatatype.NONE,
+                  covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
                   deleting: Optional[dns.rdataclass.RdataClass]=None,
-                  create=False,
-                  force_unique=False) -> Optional[dns.rrset.RRset]:
+                  create: bool=False,
+                  force_unique: bool=False) -> Optional[dns.rrset.RRset]:
         """Get the RRset with the given attributes in the specified section.
 
         If the RRset is not found, None is returned.
@@ -439,8 +439,8 @@ class Message:
             rrset = None
         return rrset
 
-    def to_wire(self, origin: Optional[dns.name.Name]=None, max_size=0,
-                multi=False, tsig_ctx: Optional[Any]=None, **kw) -> bytes:
+    def to_wire(self, origin: Optional[dns.name.Name]=None, max_size: int=0,
+                multi: bool=False, tsig_ctx: Optional[Any]=None, **kw) -> bytes:
         """Return a string containing the message in DNS compressed wire
         format.
 
@@ -513,9 +513,10 @@ class Message:
                                          original_id, error, other)
         return dns.rrset.from_rdata(keyname, 0, tsig)
 
-    def use_tsig(self, keyring: Any, keyname: Optional[dns.name.Name]=None,
-                 fudge=300, original_id: Optional[int]=None, tsig_error=0,
-                 other_data=b'', algorithm=dns.tsig.default_algorithm):
+    def use_tsig(self, keyring: Any, keyname: Optional[Union[dns.name.Name, str]]=None,
+                 fudge: int=300, original_id: Optional[int]=None, tsig_error: int=0,
+                 other_data: bytes=b'',
+                 algorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) -> None:
         """When sending, a TSIG signature using the specified key
         should be added.
 
@@ -549,7 +550,7 @@ class Message:
 
         *other_data*, a ``bytes``, the TSIG other data.
 
-        *algorithm*, a ``dns.name.Name``, the TSIG algorithm to use.  This is
+        *algorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use.  This is
         only used if *keyring* is a ``dict``, and the key entry is a ``bytes``.
         """
 
@@ -610,9 +611,9 @@ class Message:
                                       options or ())
         return dns.rrset.from_rdata(dns.name.root, int(flags), opt)
 
-    def use_edns(self, edns=0, ednsflags=0, payload=DEFAULT_EDNS_PAYLOAD,
+    def use_edns(self, edns: Optional[Union[int, bool]]=0, ednsflags: int=0, payload: int=DEFAULT_EDNS_PAYLOAD,
                  request_payload: Optional[int]=None,
-                 options: Optional[List[dns.edns.Option]]=None):
+                 options: Optional[List[dns.edns.Option]]=None) -> None:
         """Configure EDNS behavior.
 
         *edns*, an ``int``, is the EDNS level to use.  Specifying
@@ -687,7 +688,7 @@ class Message:
         else:
             return ()
 
-    def want_dnssec(self, wanted=True):
+    def want_dnssec(self, wanted: bool=True) -> None:
         """Enable or disable 'DNSSEC desired' flag in requests.
 
         *wanted*, a ``bool``.  If ``True``, then DNSSEC data is
@@ -708,7 +709,7 @@ class Message:
         """
         return dns.rcode.from_flags(int(self.flags), int(self.ednsflags))
 
-    def set_rcode(self, rcode: dns.rcode.Rcode):
+    def set_rcode(self, rcode: dns.rcode.Rcode) -> None:
         """Set the rcode.
 
         *rcode*, a ``dns.rcode.Rcode``, is the rcode to set.
@@ -726,7 +727,7 @@ class Message:
         """
         return dns.opcode.from_flags(int(self.flags))
 
-    def set_opcode(self, opcode: dns.opcode.Opcode):
+    def set_opcode(self, opcode: dns.opcode.Opcode) -> None:
         """Set the opcode.
 
         *opcode*, a ``dns.opcode.Opcode``, is the opcode to set.
@@ -1067,17 +1068,18 @@ class _WireReader:
         return self.message
 
 
-def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
-              tsig_ctx=None, multi=False,
-              question_only=False, one_rr_per_rrset=False,
-              ignore_trailing=False, raise_on_truncation=False,
-              continue_on_error=False) -> Message:
+def from_wire(wire, keyring: Optional[Any]=None, request_mac: Optional[bytes]=b'',
+              xfr: bool=False, origin: Optional[dns.name.Name]=None,
+              tsig_ctx: Optional[Union[dns.tsig.HMACTSig, dns.tsig.GSSTSig]]=None,
+              multi: bool=False, question_only: bool=False, one_rr_per_rrset: bool=False,
+              ignore_trailing: bool=False, raise_on_truncation: bool=False,
+              continue_on_error: bool=False) -> Message:
     """Convert a DNS wire format message into a message object.
 
     *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the
     message is signed.
 
-    *request_mac*, a ``bytes``.  If the message is a response to a TSIG-signed
+    *request_mac*, a ``bytes`` or ``None``.  If the message is a response to a TSIG-signed
     request, *request_mac* should be set to the MAC of that request.
 
     *xfr*, a ``bool``, should be set to ``True`` if this message is part of a
@@ -1130,6 +1132,10 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
     Returns a ``dns.message.Message``.
     """
 
+    # We permit None for request_mac solely for backwards compatibility
+    if request_mac is None:
+        request_mac = b''
+
     def initialize_message(message):
         message.request_mac = request_mac
         message.xfr = xfr
@@ -1382,8 +1388,9 @@ class _TextReader:
         return self.message
 
 
-def from_text(text, idna_codec=None, one_rr_per_rrset=False,
-              origin=None, relativize=True, relativize_to=None) -> Message:
+def from_text(text, idna_codec: Optional[dns.name.IDNACodec]=None,
+              one_rr_per_rrset: bool=False, origin: Optional[dns.name.Name]=None,
+              relativize: bool=True, relativize_to: Optional[dns.name.Name]=None) -> Message:
     """Convert the text format message into a message object.
 
     The reader stops after reading the first blank line in the input to
@@ -1423,7 +1430,7 @@ def from_text(text, idna_codec=None, one_rr_per_rrset=False,
     return reader.read()
 
 
-def from_file(f, idna_codec=None, one_rr_per_rrset=False) -> Message:
+def from_file(f, idna_codec: Optional[dns.name.IDNACodec]=None, one_rr_per_rrset: bool=False) -> Message:
     """Read the next text format message from the specified file.
 
     Message blocks are separated by a single blank line.
@@ -1452,8 +1459,11 @@ def from_file(f, idna_codec=None, one_rr_per_rrset=False) -> Message:
     assert False  # for mypy  lgtm[py/unreachable-statement]
 
 
-def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
-               want_dnssec=False, ednsflags: Optional[int]=None, payload: Optional[int]=None,
+def make_query(qname: Union[dns.name.Name, str],
+               rdtype: Union[dns.rdatatype.RdataType, str],
+               rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+               use_edns: Optional[Union[int, bool]]=None,
+               want_dnssec: bool=False, ednsflags: Optional[int]=None, payload: Optional[int]=None,
                request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None,
                idna_codec: Optional[dns.name.IDNACodec]=None, id: Optional[int]=None,
                flags: int=dns.flags.RD) -> QueryMessage:
@@ -1509,11 +1519,11 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
 
     if isinstance(qname, str):
         qname = dns.name.from_text(qname, idna_codec=idna_codec)
-    rdtype = dns.rdatatype.RdataType.make(rdtype)
-    rdclass = dns.rdataclass.RdataClass.make(rdclass)
+    the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+    the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
     m = QueryMessage(id=id)
     m.flags = dns.flags.Flag(flags)
-    m.find_rrset(m.question, qname, rdclass, rdtype, create=True,
+    m.find_rrset(m.question, qname, the_rdclass, the_rdtype, create=True,
                  force_unique=True)
     # only pass keywords on to use_edns if they have been set to a
     # non-None value.  Setting a field will turn EDNS on if it hasn't
index 334f2b18688ae0b9d89aff4751927b3e38c39e2d..5fd10a29a0d12b2d1c01559f6ba3b1237b8869ea 100644 (file)
@@ -18,7 +18,7 @@
 """DNS Names.
 """
 
-from typing import Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import copy
 import struct
@@ -297,7 +297,7 @@ IDNA_2008_Strict = IDNA2008Codec(False, False, False, True)
 IDNA_2008_Transitional = IDNA2008Codec(True, True, False, False)
 IDNA_2008 = IDNA_2008_Practical
 
-def _validate_labels(labels: Tuple[bytes, ...]):
+def _validate_labels(labels: Tuple[bytes, ...]) -> None:
     """Check for empty labels in the middle of a label sequence,
     labels that are too long, and for too many labels.
 
@@ -555,7 +555,7 @@ class Name:
     def __str__(self):
         return self.to_text(False)
 
-    def to_text(self, omit_final_dot=False) -> str:
+    def to_text(self, omit_final_dot: bool=False) -> str:
         """Convert name to DNS text format.
 
         *omit_final_dot* is a ``bool``.  If True, don't emit the final
@@ -576,7 +576,7 @@ class Name:
         s = '.'.join(map(_escapify, l))
         return s
 
-    def to_unicode(self, omit_final_dot=False, idna_codec: Optional[IDNACodec]=None) -> str:
+    def to_unicode(self, omit_final_dot: bool=False, idna_codec: Optional[IDNACodec]=None) -> str:
         """Convert name to Unicode text format.
 
         IDN ACE labels are converted to Unicode.
@@ -627,8 +627,8 @@ class Name:
         assert digest is not None
         return digest
 
-    def to_wire(self, file=None, compress: Optional[CompressType]=None,
-                origin: Optional['Name']=None, canonicalize=False) -> Optional[bytes]:
+    def to_wire(self, file: Optional[Any]=None, compress: Optional[CompressType]=None,
+                origin: Optional['Name']=None, canonicalize: bool=False) -> Optional[bytes]:
         """Convert name to wire format, possibly compressing it.
 
         *file* is the file where the name is emitted (typically an
@@ -794,7 +794,7 @@ class Name:
         else:
             return self
 
-    def choose_relativity(self, origin: Optional['Name']=None, relativize=True) -> 'Name':
+    def choose_relativity(self, origin: Optional['Name']=None, relativize: bool=True) -> 'Name':
         """Return a name with the relativity desired by the caller.
 
         If *origin* is ``None``, then the name is returned.
index de017b432233ceed95df445cb2b0f5f2d73644f7..8727e42d9229a57787b5f45df94b328a4d532ef6 100644 (file)
@@ -164,7 +164,7 @@ class Node:
                       rdclass: dns.rdataclass.RdataClass,
                       rdtype: dns.rdatatype.RdataType,
                       covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
-                      create=False) -> dns.rdataset.Rdataset:
+                      create: bool=False) -> dns.rdataset.Rdataset:
         """Find an rdataset matching the specified properties in the
         current node.
 
@@ -203,7 +203,7 @@ class Node:
                      rdclass: dns.rdataclass.RdataClass,
                      rdtype: dns.rdatatype.RdataType,
                      covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
-                     create=False) -> Optional[dns.rdataset.Rdataset]:
+                     create: bool=False) -> Optional[dns.rdataset.Rdataset]:
         """Get an rdataset matching the specified properties in the
         current node.
 
@@ -237,7 +237,7 @@ class Node:
     def delete_rdataset(self,
                         rdclass: dns.rdataclass.RdataClass,
                         rdtype: dns.rdatatype.RdataType,
-                        covers: dns.rdatatype.RdataType=dns.rdatatype.NONE):
+                        covers: dns.rdatatype.RdataType=dns.rdatatype.NONE) -> None:
         """Delete the rdataset matching the specified properties in the
         current node.
 
@@ -254,7 +254,7 @@ class Node:
         if rds is not None:
             self.rdatasets.remove(rds)
 
-    def replace_rdataset(self, replacement: dns.rdataset.Rdataset):
+    def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
         """Replace an rdataset.
 
         It is not an error if there is no rdataset matching *replacement*.
@@ -312,22 +312,31 @@ class ImmutableNode(Node):
             [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
         )
 
-    def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
-                      create=False):
+    def find_rdataset(self,
+                      rdclass: dns.rdataclass.RdataClass,
+                      rdtype: dns.rdatatype.RdataType,
+                      covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
+                      create: bool=False) -> dns.rdataset.Rdataset:
         if create:
             raise TypeError("immutable")
         return super().find_rdataset(rdclass, rdtype, covers, False)
 
-    def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
-                     create=False):
+    def get_rdataset(self,
+                     rdclass: dns.rdataclass.RdataClass,
+                     rdtype: dns.rdatatype.RdataType,
+                     covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
+                     create: bool=False) -> Optional[dns.rdataset.Rdataset]:
         if create:
             raise TypeError("immutable")
         return super().get_rdataset(rdclass, rdtype, covers, False)
 
-    def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
+    def delete_rdataset(self,
+                        rdclass: dns.rdataclass.RdataClass,
+                        rdtype: dns.rdatatype.RdataType,
+                        covers: dns.rdatatype.RdataType=dns.rdatatype.NONE) -> None:
         raise TypeError("immutable")
 
-    def replace_rdataset(self, replacement):
+    def replace_rdataset(self, replacement) -> None:
         raise TypeError("immutable")
 
     def is_immutable(self) -> bool:
index 4757be8a26bde0a1fb83ef4fe3cff3687e51274a..1ba57790935862b0b58fcf4e9211b7b0e4259262 100644 (file)
@@ -258,10 +258,10 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
         raise
 
 def https(q: dns.message.Message, where: str, timeout: Optional[float]=None,
-          port=443, source: Optional[str]=None, source_port=0,
-          one_rr_per_rrset=False, ignore_trailing=False,
-          session: Optional[Any]=None, path='/dns-query', post=True,
-          bootstrap_address: Optional[str]=None, verify=True) -> dns.message.Message:
+          port: int=443, source: Optional[str]=None, source_port: int=0,
+          one_rr_per_rrset: bool=False, ignore_trailing: bool=False,
+          session: Optional[Any]=None, path: str='/dns-query', post: bool=True,
+          bootstrap_address: Optional[str]=None, verify: bool=True) -> dns.message.Message:
     """Return the response obtained after sending a query via DNS-over-HTTPS.
 
     *q*, a ``dns.message.Message``, the query to send.
@@ -465,9 +465,9 @@ def send_udp(sock: Any, what: Union[dns.message.Message, bytes], destination: An
 
 
 def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional[float]=None,
-                ignore_unexpected=False, one_rr_per_rrset=False,
-                keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'',
-                ignore_trailing=False, raise_on_truncation=False) -> Any:
+                ignore_unexpected: bool=False, one_rr_per_rrset: bool=False,
+                keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac: Optional[bytes]=b'',
+                ignore_trailing: bool=False, raise_on_truncation: bool=False) -> Any:
     """Read a DNS message from a UDP socket.
 
     *sock*, a ``socket``.
@@ -489,7 +489,7 @@ def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional
 
     *keyring*, a ``dict``, the keyring to use for TSIG.
 
-    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
+    *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).
 
     *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
     junk at end of the received message.
@@ -525,10 +525,10 @@ def receive_udp(sock: Any, destination: Optional[Any]=None, expiration: Optional
     else:
         return (r, received_time, from_address)
 
-def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
-        source: Optional[str]=None, source_port=0,
-        ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
-        raise_on_truncation=False, sock: Optional[Any]=None) -> dns.message.Message:
+def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53,
+        source: Optional[str]=None, source_port: int=0,
+        ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False,
+        raise_on_truncation: bool=False, sock: Optional[Any]=None) -> dns.message.Message:
     """Return the response obtained after sending a query via UDP.
 
     *q*, a ``dns.message.Message``, the query to send
@@ -587,9 +587,9 @@ def udp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=
         return r
     assert False  # help mypy figure out we can't get here  lgtm[py/unreachable-statement]
 
-def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
-                      source: Optional[str]=None, source_port=0,
-                      ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,
+def udp_with_fallback(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53,
+                      source: Optional[str]=None, source_port: int=0,
+                      ignore_unexpected: bool=False, one_rr_per_rrset: bool=False, ignore_trailing: bool=False,
                       udp_sock: Optional[Any]=None,
                       tcp_sock: Optional[Any]=None) -> Tuple[dns.message.Message, bool]:
     """Return the response to the query, trying UDP first and falling back
@@ -709,9 +709,10 @@ def send_tcp(sock: Any, what: Union[dns.message.Message, bytes],
     _net_write(sock, tcpmsg, expiration)
     return (len(tcpmsg), sent_time)
 
-def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset=False,
-                keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None, request_mac=b'',
-                ignore_trailing=False) -> Tuple[dns.message.Message, float]:
+def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset: bool=False,
+                keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None,
+                request_mac: Optional[bytes]=b'',
+                ignore_trailing: bool=False) -> Tuple[dns.message.Message, float]:
     """Read a DNS message from a TCP socket.
 
     *sock*, a ``socket``.
@@ -725,7 +726,7 @@ def receive_tcp(sock: Any, expiration: Optional[float]=None, one_rr_per_rrset=Fa
 
     *keyring*, a ``dict``, the keyring to use for TSIG.
 
-    *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
+    *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).
 
     *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
     junk at end of the received message.
@@ -757,9 +758,10 @@ def _connect(s, address, expiration):
         raise OSError(err, os.strerror(err))
 
 
-def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port=53,
-        source: Optional[str]=None, source_port=0,
-        one_rr_per_rrset=False, ignore_trailing=False, sock: Optional[Any]=None) -> dns.message.Message:
+def tcp(q: dns.message.Message, where: str, timeout: Optional[float]=None, port: int=53,
+        source: Optional[str]=None, source_port: int=0,
+        one_rr_per_rrset: bool=False, ignore_trailing: bool=False,
+        sock: Optional[Any]=None) -> dns.message.Message:
     """Return the response obtained after sending a query via TCP.
 
     *q*, a ``dns.message.Message``, the query to send
@@ -826,8 +828,8 @@ def _tls_handshake(s, expiration):
 
 
 def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None,
-        port=853, source: Optional[str]=None, source_port=0,
-        one_rr_per_rrset=False, ignore_trailing=False, sock: Optional[ssl.SSLSocket]=None,
+        port: int=853, source: Optional[str]=None, source_port: int=0,
+        one_rr_per_rrset: bool=False, ignore_trailing: bool=False, sock: Optional[ssl.SSLSocket]=None,
         ssl_context: Optional[ssl.SSLContext]=None,
         server_hostname: Optional[str]=None) -> dns.message.Message:
     """Return the response obtained after sending a query via TLS.
@@ -908,10 +910,15 @@ def tls(q: dns.message.Message, where: str, timeout: Optional[float]=None,
         return r
     assert False  # help mypy figure out we can't get here  lgtm[py/unreachable-statement]
 
-def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
-        timeout=None, port=53, keyring=None, keyname=None, relativize=True,
-        lifetime=None, source=None, source_port=0, serial=0,
-        use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
+def xfr(where: str, zone: Union[dns.name.Name, str],
+        rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.AXFR,
+        rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
+        timeout: Optional[float]=None, port: int=53,
+        keyring: Optional[Dict[dns.name.Name, dns.tsig.Key]]=None,
+        keyname: Optional[Union[dns.name.Name, str]]=None, relativize: bool=True,
+        lifetime: Optional[float]=None, source: Optional[str]=None, source_port: int=0,
+        serial: int=0, use_udp: bool=False,
+        keyalgorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) -> Any:
     """Return a generator for the responses to a zone transfer.
 
     *where*, a ``str`` containing an IPv4 or IPv6 address,  where
@@ -1089,8 +1096,8 @@ class UDPMode(enum.IntEnum):
 
 def inbound_xfr(where: str, txn_manager: dns.transaction.TransactionManager,
                 query: Optional[dns.message.Message]=None,
-                port=53, timeout: Optional[float]=None, lifetime: Optional[float]=None,
-                source: Optional[str]=None, source_port=0, udp_mode=UDPMode.NEVER):
+                port: int=53, timeout: Optional[float]=None, lifetime: Optional[float]=None,
+                source: Optional[str]=None, source_port: int=0, udp_mode=UDPMode.NEVER):
     """Conduct an inbound transfer and apply it via a transaction from the
     txn_manager.
 
index 49fee69503663636c60d3e60fb3636196c46592c..16e1ed4bff8c9d2c536b099bae4a709724b673bf 100644 (file)
@@ -17,6 +17,8 @@
 
 """DNS Result Codes."""
 
+from typing import Tuple
+
 import dns.enum
 import dns.exception
 
@@ -77,20 +79,20 @@ class UnknownRcode(dns.exception.DNSException):
     """A DNS rcode is unknown."""
 
 
-def from_text(text):
+def from_text(text: str) -> Rcode:
     """Convert text into an rcode.
 
     *text*, a ``str``, the textual rcode or an integer in textual form.
 
     Raises ``dns.rcode.UnknownRcode`` if the rcode mnemonic is unknown.
 
-    Returns an ``int``.
+    Returns a ``dns.rcode.Rcode``.
     """
 
     return Rcode.from_text(text)
 
 
-def from_flags(flags, ednsflags):
+def from_flags(flags: int, ednsflags: int) -> Rcode:
     """Return the rcode value encoded by flags and ednsflags.
 
     *flags*, an ``int``, the DNS flags field.
@@ -99,17 +101,17 @@ def from_flags(flags, ednsflags):
 
     Raises ``ValueError`` if rcode is < 0 or > 4095
 
-    Returns an ``int``.
+    Returns a ``dns.rcode.Rcode``.
     """
 
     value = (flags & 0x000f) | ((ednsflags >> 20) & 0xff0)
-    return value
+    return Rcode.make(value)
 
 
-def to_flags(value):
+def to_flags(value: Rcode) -> Tuple[int, int]:
     """Return a (flags, ednsflags) tuple which encodes the rcode.
 
-    *value*, an ``int``, the rcode.
+    *value*, a ``dns.rcode.Rcode``, the rcode.
 
     Raises ``ValueError`` if rcode is < 0 or > 4095.
 
@@ -123,10 +125,10 @@ def to_flags(value):
     return (v, ev)
 
 
-def to_text(value, tsig=False):
+def to_text(value: Rcode, tsig: bool=False) -> str:
     """Convert rcode into text.
 
-    *value*, an ``int``, the rcode.
+    *value*, a ``dns.rcode.Rcode``, the rcode.
 
     Raises ``ValueError`` if rcode is < 0 or > 4095.
 
index 1e1992be8303e61a61d79f88f97229c6c282c9ea..24e5fde58b344e6d37b8abd82688830137141604 100644 (file)
@@ -191,7 +191,7 @@ class Rdata:
 
         return self.covers() << 16 | self.rdtype
 
-    def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw):
+    def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True, **kw):
         """Convert an rdata to text format.
 
         Returns a ``str``.
@@ -199,12 +199,12 @@ class Rdata:
 
         raise NotImplementedError  # pragma: no cover
 
-    def _to_wire(self, file, compress: Optional[dns.name.CompressType]=None,
-                 origin: Optional[dns.name.Name]=None, canonicalize=False):
+    def _to_wire(self, file: Optional[Any], compress: Optional[dns.name.CompressType]=None,
+                 origin: Optional[dns.name.Name]=None, canonicalize: bool=False) -> bytes:
         raise NotImplementedError  # pragma: no cover
 
-    def to_wire(self, file=None, compress=None, origin=None,
-                canonicalize=False) -> bytes:
+    def to_wire(self, file: Optional[Any]=None, compress: Optional[dns.name.CompressType]=None,
+                origin: Optional[dns.name.Name]=None, canonicalize: bool=False) -> bytes:
         """Convert an rdata to wire format.
 
         Returns a ``bytes`` or ``None``.
@@ -353,17 +353,17 @@ class Rdata:
     @classmethod
     def from_text(cls, rdclass: dns.rdataclass.RdataClass,
                   rdtype: dns.rdatatype.RdataType,
-                  tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None, relativize=True,
+                  tok: dns.tokenizer.Tokenizer, origin: Optional[dns.name.Name]=None, relativize: bool=True,
                   relativize_to: Optional[dns.name.Name]=None):
         raise NotImplementedError  # pragma: no cover
 
     @classmethod
     def from_wire_parser(cls, rdclass: dns.rdataclass.RdataClass,
                          rdtype: dns.rdatatype.RdataType,
-                         parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None):
+                         parser: dns.wire.Parser, origin: Optional[dns.name.Name]=None) -> 'Rdata':
         raise NotImplementedError  # pragma: no cover
 
-    def replace(self, **kwargs):
+    def replace(self, **kwargs) -> 'Rdata':
         """
         Create a new Rdata instance based on the instance replace was
         invoked on. It is possible to pass different parameters to
@@ -376,7 +376,7 @@ class Rdata:
         """
 
         # Get the constructor parameters.
-        parameters = inspect.signature(self.__init__).parameters
+        parameters = inspect.signature(self.__init__).parameters  # type: ignore
 
         # Ensure that all of the arguments correspond to valid fields.
         # Don't allow rdclass or rdtype to be changed, though.
@@ -615,7 +615,7 @@ def from_text(rdclass: Union[dns.rdataclass.RdataClass, str],
               rdtype: Union[dns.rdatatype.RdataType, str],
               tok: Union[dns.tokenizer.Tokenizer, str],
               origin: Optional[dns.name.Name]=None,
-              relativize=True, relativize_to: Optional[dns.name.Name]=None,
+              relativize: bool=True, relativize_to: Optional[dns.name.Name]=None,
               idna_codec: Optional[dns.name.IDNACodec]=None) -> Rdata:
     """Build an rdata object from text format.
 
@@ -769,8 +769,8 @@ class RdatatypeExists(dns.exception.DNSException):
         "already exists."
 
 
-def register_type(implementation, rdtype, rdtype_text, is_singleton=False,
-                  rdclass=dns.rdataclass.IN):
+def register_type(implementation: Any, rdtype: int, rdtype_text: str, is_singleton: bool=False,
+                  rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN):
     """Dynamically register a module to handle an rdatatype.
 
     *implementation*, a module implementing the type in the usual dnspython
@@ -787,14 +787,15 @@ def register_type(implementation, rdtype, rdtype_text, is_singleton=False,
     it applies to all classes.
     """
 
-    existing_cls = get_rdata_class(rdclass, rdtype)
-    if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype):
-        raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
+    the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+    existing_cls = get_rdata_class(rdclass, the_rdtype)
+    if existing_cls != GenericRdata or dns.rdatatype.is_metatype(the_rdtype):
+        raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
     try:
-        if dns.rdatatype.RdataType(rdtype).name != rdtype_text:
-            raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype)
+        if dns.rdatatype.RdataType(the_rdtype).name != rdtype_text:
+            raise RdatatypeExists(rdclass=rdclass, rdtype=the_rdtype)
     except ValueError:
         pass
-    _rdata_classes[(rdclass, rdtype)] = getattr(implementation,
-                                                rdtype_text.replace('-', '_'))
-    dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton)
+    _rdata_classes[(rdclass, the_rdtype)] = getattr(implementation,
+                                                    rdtype_text.replace('-', '_'))
+    dns.rdatatype.register_type(the_rdtype, rdtype_text, is_singleton)
index 41bba693b79934776709b1d526da38d6830798e7..286705486bebbd97581726949461a7f0d69c36d6 100644 (file)
@@ -56,7 +56,7 @@ class UnknownRdataclass(dns.exception.DNSException):
     """A DNS class is unknown."""
 
 
-def from_text(text):
+def from_text(text: str) -> RdataClass:
     """Convert text into a DNS rdata class value.
 
     The input text can be a defined DNS RR class mnemonic or
@@ -68,13 +68,13 @@ def from_text(text):
 
     Raises ``ValueError`` if the rdata class value is not >= 0 and <= 65535.
 
-    Returns an ``int``.
+    Returns a ``dns.rdataclass.RdataClass``.
     """
 
     return RdataClass.from_text(text)
 
 
-def to_text(value):
+def to_text(value: RdataClass) -> str:
     """Convert a DNS rdata class value to text.
 
     If the value has a known mnemonic, it will be used, otherwise the
@@ -88,12 +88,12 @@ def to_text(value):
     return RdataClass.to_text(value)
 
 
-def is_metaclass(rdclass):
+def is_metaclass(rdclass: RdataClass) -> bool:
     """True if the specified class is a metaclass.
 
     The currently defined metaclasses are ANY and NONE.
 
-    *rdclass* is an ``int``.
+    *rdclass* is a ``dns.rdataclass.RdataClass``.
     """
 
     if rdclass in _metaclasses:
index 33bee2f196b9d68ce2edc05d0b1fede71f793a94..b47057fdc482b0030acc1e959057c2339f7f0be3 100644 (file)
@@ -53,7 +53,7 @@ class Rdataset(dns.set.Set):
 
     def __init__(self, rdclass: dns.rdataclass.RdataClass,
                  rdtype: dns.rdatatype.RdataType,
-                 covers=dns.rdatatype.NONE, ttl=0):
+                 covers=dns.rdatatype.NONE, ttl: int=0):
         """Create a new rdataset of the specified class and type.
 
         *rdclass*, a ``dns.rdataclass.RdataClass``, the rdataclass.
@@ -79,7 +79,7 @@ class Rdataset(dns.set.Set):
         obj.ttl = self.ttl
         return obj
 
-    def update_ttl(self, ttl: int):
+    def update_ttl(self, ttl: int) -> None:
         """Perform TTL minimization.
 
         Set the TTL of the rdataset to be the lesser of the set's current
@@ -94,7 +94,7 @@ class Rdataset(dns.set.Set):
         elif ttl < self.ttl:
             self.ttl = ttl
 
-    def add(self, rd, ttl: Optional[int]=None):  # pylint: disable=arguments-differ
+    def add(self, rd: dns.rdata.Rdata, ttl: Optional[int]=None) -> None:  # pylint: disable=arguments-differ
         """Add the specified rdata to the rdataset.
 
         If the optional *ttl* parameter is supplied, then
@@ -184,7 +184,7 @@ class Rdataset(dns.set.Set):
 
     def to_text(self, name: Optional[dns.name.Name]=None,
                 origin: Optional[dns.name.Name]=None,
-                relativize=True,
+                relativize: bool=True,
                 override_rdclass: Optional[dns.rdataclass.RdataClass]=None,
                 want_comments=False, **kw) -> str:
         """Convert the rdataset into DNS zone file format.
@@ -254,7 +254,7 @@ class Rdataset(dns.set.Set):
                 compress: Optional[dns.name.CompressType]=None,
                 origin: Optional[dns.name.Name]=None,
                 override_rdclass: Optional[dns.rdataclass.RdataClass]=None,
-                want_shuffle=True) -> int:
+                want_shuffle: bool=True) -> int:
         """Convert the rdataset to wire format.
 
         *name*, a ``dns.name.Name`` is the owner name to use.
@@ -414,7 +414,7 @@ def from_text_list(rdclass: Union[dns.rdataclass.RdataClass, str],
                    ttl: int, text_rdatas: Collection[str],
                    idna_codec: Optional[dns.name.IDNACodec]=None,
                    origin: Optional[dns.name.Name]=None,
-                   relativize=True, relativize_to: Optional[dns.name.Name]=None) -> Rdataset:
+                   relativize: bool=True, relativize_to: Optional[dns.name.Name]=None) -> Rdataset:
     """Create an rdataset with the specified class, type, and TTL, and with
     the specified list of rdatas in text format.
 
index 80f8acaf16fa6e9d9ddafa58ef9a3e97775627d8..18185bca06f9a80d6a8736c8181637d5e7f4034b 100644 (file)
@@ -135,7 +135,7 @@ class UnknownRdatatype(dns.exception.DNSException):
     """DNS resource record type is unknown."""
 
 
-def from_text(text):
+def from_text(text: str) -> RdataType:
     """Convert text into a DNS rdata type value.
 
     The input text can be a defined DNS RR type mnemonic or
@@ -147,7 +147,7 @@ def from_text(text):
 
     Raises ``ValueError`` if the rdata type value is not >= 0 and <= 65535.
 
-    Returns an ``int``.
+    Returns a ``dns.rdatatype.RdataType``.
     """
 
     text = text.upper().replace('-', '_')
@@ -160,7 +160,7 @@ def from_text(text):
         raise
 
 
-def to_text(value):
+def to_text(value: RdataType) -> str:
     """Convert a DNS rdata type value to text.
 
     If the value has a known mnemonic, it will be used, otherwise the
@@ -179,10 +179,10 @@ def to_text(value):
     return text.replace('_', '-')
 
 
-def is_metatype(rdtype):
+def is_metatype(rdtype: RdataType) -> bool:
     """True if the specified type is a metatype.
 
-    *rdtype* is an ``int``.
+    *rdtype* is a ``dns.rdatatype.RdataType``.
 
     The currently defined metatypes are TKEY, TSIG, IXFR, AXFR, MAILA,
     MAILB, ANY, and OPT.
@@ -193,7 +193,7 @@ def is_metatype(rdtype):
     return (256 > rdtype >= 128) or rdtype in _metatypes
 
 
-def is_singleton(rdtype):
+def is_singleton(rdtype: RdataType) -> bool:
     """Is the specified type a singleton type?
 
     Singleton types can only have a single rdata in an rdataset, or a single
@@ -212,10 +212,10 @@ def is_singleton(rdtype):
     return False
 
 # pylint: disable=redefined-outer-name
-def register_type(rdtype, rdtype_text, is_singleton=False):
+def register_type(rdtype: RdataType, rdtype_text: str, is_singleton: bool=False):
     """Dynamically register an rdatatype.
 
-    *rdtype*, an ``int``, the rdatatype to register.
+    *rdtype*, a ``dns.rdatatype.RdataType``, the rdatatype to register.
 
     *rdtype_text*, a ``str``, the textual form of the rdatatype.
 
index 5f0f3628f6f850847757052f3c49d24f30686193..28769f4e98f973ac79a2dbe4f22438efc42bdbbc 100644 (file)
@@ -17,7 +17,7 @@
 
 """DNS stub resolver."""
 
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from urllib.parse import urlparse
 import contextlib
@@ -32,6 +32,7 @@ except ImportError:  # pragma: no cover
     import dummy_threading as _threading    # type: ignore
 
 import dns.exception
+import dns.edns
 import dns.flags
 import dns.inet
 import dns.ipv4
@@ -139,7 +140,7 @@ class YXDOMAIN(dns.exception.DNSException):
     """The DNS query name is too long after DNAME substitution."""
 
 
-ErrorTuple = Tuple[str, bool, int, Exception, dns.message.Message]
+ErrorTuple = Tuple[Optional[str], bool, int, Union[Exception, str], Optional[dns.message.Message]]
 
 
 def _errors_to_text(errors: List[ErrorTuple]) -> List[str]:
@@ -312,17 +313,17 @@ class CacheBase:
         self.lock = _threading.Lock()
         self.statistics = CacheStatistics()
 
-    def reset_statistics(self):
+    def reset_statistics(self) -> None:
         """Reset all statistics to zero."""
         with self.lock:
             self.statistics.reset()
 
-    def hits(self):
+    def hits(self) -> int:
         """How many hits has the cache had?"""
         with self.lock:
             return self.statistics.hits
 
-    def misses(self):
+    def misses(self) -> int:
         """How many misses has the cache had?"""
         with self.lock:
             return self.statistics.misses
@@ -344,17 +345,17 @@ CacheKey = Tuple[dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataCla
 class Cache(CacheBase):
     """Simple thread-safe DNS answer cache."""
 
-    def __init__(self, cleaning_interval=300.0):
+    def __init__(self, cleaning_interval: float=300.0):
         """*cleaning_interval*, a ``float`` is the number of seconds between
         periodic cleanings.
         """
 
         super().__init__()
-        self.data = {}
+        self.data: Dict[CacheKey, Answer] = {}
         self.cleaning_interval = cleaning_interval
-        self.next_cleaning = time.time() + self.cleaning_interval
+        self.next_cleaning: float = time.time() + self.cleaning_interval
 
-    def _maybe_clean(self):
+    def _maybe_clean(self) -> None:
         """Clean the cache if it's time to do so."""
 
         now = time.time()
@@ -388,7 +389,7 @@ class Cache(CacheBase):
             self.statistics.hits += 1
             return v
 
-    def put(self, key: CacheKey, value: Answer):
+    def put(self, key: CacheKey, value: Answer) -> None:
         """Associate key and value in the cache.
 
         *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
@@ -401,7 +402,7 @@ class Cache(CacheBase):
             self._maybe_clean()
             self.data[key] = value
 
-    def flush(self, key: Optional[CacheKey]=None):
+    def flush(self, key: Optional[CacheKey]=None) -> None:
         """Flush the cache.
 
         If *key* is not ``None``, only that item is flushed.  Otherwise
@@ -451,19 +452,19 @@ class LRUCache(CacheBase):
     for a new one.
     """
 
-    def __init__(self, max_size=100000):
+    def __init__(self, max_size: int=100000):
         """*max_size*, an ``int``, is the maximum number of nodes to cache;
         it must be greater than 0.
         """
 
         super().__init__()
-        self.data = {}
+        self.data: Dict[CacheKey, LRUCacheNode] = {}
         self.set_max_size(max_size)
-        self.sentinel = LRUCacheNode(None, None)
+        self.sentinel: LRUCacheNode = LRUCacheNode(None, None)
         self.sentinel.prev = self.sentinel
         self.sentinel.next = self.sentinel
 
-    def set_max_size(self, max_size):
+    def set_max_size(self, max_size: int) -> None:
         if max_size < 1:
             max_size = 1
         self.max_size = max_size
@@ -505,7 +506,7 @@ class LRUCache(CacheBase):
             else:
                 return node.hits
 
-    def put(self, key: CacheKey, value: Answer):
+    def put(self, key: CacheKey, value: Answer) -> None:
         """Associate key and value in the cache.
 
         *key*, a ``(dns.name.Name, dns.rdatatype.RdataType, dns.rdataclass.RdataClass)`` tuple whose values are the
@@ -520,14 +521,14 @@ class LRUCache(CacheBase):
                 node.unlink()
                 del self.data[node.key]
             while len(self.data) >= self.max_size:
-                node = self.sentinel.prev
-                node.unlink()
-                del self.data[node.key]
+                gnode = self.sentinel.prev
+                gnode.unlink()
+                del self.data[gnode.key]
             node = LRUCacheNode(key, value)
             node.link_after(self.sentinel)
             self.data[key] = node
 
-    def flush(self, key: Optional[CacheKey]=None):
+    def flush(self, key: Optional[CacheKey]=None) -> None:
         """Flush the cache.
 
         If *key* is not ``None``, only that item is flushed.  Otherwise
@@ -544,11 +545,11 @@ class LRUCache(CacheBase):
                     node.unlink()
                     del self.data[node.key]
             else:
-                node = self.sentinel.next
-                while node != self.sentinel:
-                    next = node.next
-                    node.unlink()
-                    node = next
+                gnode = self.sentinel.next
+                while gnode != self.sentinel:
+                    next = gnode.next
+                    gnode.unlink()
+                    gnode = next
                 self.data = {}
 
 class _Resolution:
@@ -569,20 +570,20 @@ class _Resolution:
                  tcp: bool, raise_on_no_answer: bool, search: Optional[bool]):
         if isinstance(qname, str):
             qname = dns.name.from_text(qname, None)
-        rdtype = dns.rdatatype.RdataType.make(rdtype)
-        if dns.rdatatype.is_metatype(rdtype):
+        the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+        if dns.rdatatype.is_metatype(the_rdtype):
             raise NoMetaqueries
-        rdclass = dns.rdataclass.RdataClass.make(rdclass)
-        if dns.rdataclass.is_metaclass(rdclass):
+        the_rdclass = dns.rdataclass.RdataClass.make(rdclass)
+        if dns.rdataclass.is_metaclass(the_rdclass):
             raise NoMetaqueries
         self.resolver = resolver
         self.qnames_to_try = resolver._get_qnames_to_try(qname, search)
         self.qnames = self.qnames_to_try[:]
-        self.rdtype = rdtype
-        self.rdclass = rdclass
+        self.rdtype = the_rdtype
+        self.rdclass = the_rdclass
         self.tcp = tcp
         self.raise_on_no_answer = raise_on_no_answer
-        self.nxdomain_responses: Dict[dns.name.Name, Answer] = {}
+        self.nxdomain_responses: Dict[dns.name.Name, dns.message.QueryMessage] = {}
         # Initialize other things to help analysis tools
         self.qname = dns.name.empty
         self.nameservers: List[str] = []
@@ -660,14 +661,14 @@ class _Resolution:
         raise NXDOMAIN(qnames=self.qnames_to_try,
                        responses=self.nxdomain_responses)
 
-    def next_nameserver(self):
+    def next_nameserver(self) -> Tuple[str, int, bool, float]:
         if self.retry_with_tcp:
             assert self.nameserver is not None
             self.tcp_attempt = True
             self.retry_with_tcp = False
             return (self.nameserver, self.port, True, 0)
 
-        backoff = 0
+        backoff = 0.0
         if not self.current_nameservers:
             if len(self.nameservers) == 0:
                 # Out of things to try!
@@ -682,10 +683,12 @@ class _Resolution:
         self.tcp_attempt = self.tcp
         return (self.nameserver, self.port, self.tcp_attempt, backoff)
 
-    def query_result(self, response, ex):
+    def query_result(self, response: Optional[dns.message.Message],
+                     ex: Optional[Exception]) -> Tuple[Optional[Answer], bool]:
         #
         # returns an (answer: Answer, end_loop: bool) tuple.
         #
+        assert self.nameserver is not None
         if ex:
             # Exception during I/O or from_wire()
             assert response is None
@@ -706,6 +709,7 @@ class _Resolution:
             return (None, False)
         # We got an answer!
         assert response is not None
+        assert isinstance(response, dns.message.QueryMessage)
         rcode = response.rcode()
         if rcode == dns.rcode.NOERROR:
             try:
@@ -767,7 +771,7 @@ class BaseResolver:
     #
     # pylint: disable=attribute-defined-outside-init
 
-    def __init__(self, filename='/etc/resolv.conf', configure=True):
+    def __init__(self, filename: str='/etc/resolv.conf', configure: bool=True):
         """*filename*, a ``str`` or file object, specifying a file
         in standard /etc/resolv.conf format.  This parameter is meaningful
         only when *configure* is true and the platform is POSIX.
@@ -813,7 +817,7 @@ class BaseResolver:
         self.rotate = False
         self.ndots: Optional[int] = None
 
-    def read_resolv_conf(self, f):
+    def read_resolv_conf(self, f: Any) -> None:
         """Process *f* as a file in the /etc/resolv.conf format.  If f is
         a ``str``, it is used as the name of the file to open; otherwise it
         is treated as the file itself.
@@ -879,10 +883,10 @@ class BaseResolver:
         if len(self.nameservers) == 0:
             raise NoResolverConfiguration('no nameservers')
 
-    def read_registry(self):
+    def read_registry(self) -> None:
         """Extract resolver configuration from the Windows registry."""
         try:
-            info = dns.win32util.get_dns_info()
+            info = dns.win32util.get_dns_info()  # type: ignore
             if info.domain is not None:
                 self.domain = info.domain
             self.nameservers = info.nameservers
@@ -949,8 +953,8 @@ class BaseResolver:
                 qnames_to_try.append(abs_qname)
         return qnames_to_try
 
-    def use_tsig(self, keyring, keyname=None,
-                 algorithm=dns.tsig.default_algorithm):
+    def use_tsig(self, keyring: Any, keyname: Optional[Union[dns.name.Name, str]]=None,
+                 algorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) -> None:
         """Add a TSIG signature to each query.
 
         The parameters are passed to ``dns.message.Message.use_tsig()``;
@@ -961,8 +965,9 @@ class BaseResolver:
         self.keyname = keyname
         self.keyalgorithm = algorithm
 
-    def use_edns(self, edns=0, ednsflags=0,
-                 payload=dns.message.DEFAULT_EDNS_PAYLOAD, options=None):
+    def use_edns(self, edns: Optional[Union[int, bool]]=0, ednsflags: int=0,
+                 payload: int=dns.message.DEFAULT_EDNS_PAYLOAD,
+                 options: Optional[List[dns.edns.Option]]=None) -> None:
         """Configure EDNS behavior.
 
         *edns*, an ``int``, is the EDNS level to use.  Specifying
@@ -989,7 +994,7 @@ class BaseResolver:
         self.payload = payload
         self.ednsoptions = options
 
-    def set_flags(self, flags: int):
+    def set_flags(self, flags: int) -> None:
         """Overrides the default flags with your own.
 
         *flags*, an ``int``, the message flags to use.
@@ -1030,7 +1035,7 @@ class Resolver(BaseResolver):
     def resolve(self, qname: Union[dns.name.Name, str],
                 rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
                 rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
-                tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+                tcp: bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0,
                 lifetime: Optional[float]=None, search: Optional[bool]=None) -> Answer: # pylint: disable=arguments-differ
         """Query nameservers to find the answer to the question.
 
@@ -1136,7 +1141,7 @@ class Resolver(BaseResolver):
     def query(self, qname: Union[dns.name.Name, str],
               rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
               rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
-              tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+              tcp: bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0,
               lifetime: Optional[float]=None) -> Answer:  # pragma: no cover
         """Query nameservers to find the answer to the question.
 
@@ -1226,7 +1231,7 @@ def reset_default_resolver():
 def resolve(qname: Union[dns.name.Name, str],
             rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
             rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
-            tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+            tcp: bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0,
             lifetime: Optional[float]=None, search: Optional[bool]=None) -> Answer:  # pragma: no cover
 
     """Query nameservers to find the answer to the question.
@@ -1245,7 +1250,7 @@ def resolve(qname: Union[dns.name.Name, str],
 def query(qname: Union[dns.name.Name, str],
           rdtype: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.A,
           rdclass: Union[dns.rdataclass.RdataClass, str]=dns.rdataclass.IN,
-          tcp=False, source: Optional[str]=None, raise_on_no_answer=True, source_port=0,
+          tcp: bool=False, source: Optional[str]=None, raise_on_no_answer: bool=True, source_port: int=0,
           lifetime: Optional[float]=None) -> Answer:  # pragma: no cover
     """Query nameservers to find the answer to the question.
 
@@ -1282,7 +1287,7 @@ def canonical_name(name: Union[dns.name.Name, str]) -> dns.name.Name:
 
 
 def zone_for_name(name: Union[dns.name.Name, str], rdclass=dns.rdataclass.IN,
-                  tcp=False, resolver: Optional[Resolver]=None,
+                  tcp: bool=False, resolver: Optional[Resolver]=None,
                   lifetime: Optional[float]=None) -> dns.name.Name:
     """Find the name of the zone which contains the specified name.
 
index 4b70cf6495ad677c78c6ca5bf2bafe541f77b4a9..c25e77df65673cfa9ad4b69d72196a5769d38805 100644 (file)
@@ -27,8 +27,8 @@ ipv4_reverse_domain = dns.name.from_text('in-addr.arpa.')
 ipv6_reverse_domain = dns.name.from_text('ip6.arpa.')
 
 
-def from_address(text: str, v4_origin=ipv4_reverse_domain,
-                 v6_origin=ipv6_reverse_domain) -> dns.name.Name:
+def from_address(text: str, v4_origin: dns.name.Name=ipv4_reverse_domain,
+                 v6_origin: dns.name.Name=ipv6_reverse_domain) -> dns.name.Name:
     """Convert an IPv4 or IPv6 address in textual form into a Name object whose
     value is the reverse-map domain name of the address.
 
@@ -63,8 +63,8 @@ def from_address(text: str, v4_origin=ipv4_reverse_domain,
     return dns.name.from_text('.'.join(reversed(parts)), origin=origin)
 
 
-def to_address(name: dns.name.Name, v4_origin=ipv4_reverse_domain,
-               v6_origin=ipv6_reverse_domain) -> str:
+def to_address(name: dns.name.Name, v4_origin: dns.name.Name=ipv4_reverse_domain,
+               v6_origin: dns.name.Name=ipv6_reverse_domain) -> str:
     """Convert a reverse map domain name into textual address form.
 
     *name*, a ``dns.name.Name``, an IPv4 or IPv6 address in reverse-map name
index 3745857145cbfdb91862c8089ff40eb5c25afd1e..e14433eedbf55138de9e6bfce2f3e8fc1ddea61c 100644 (file)
@@ -17,7 +17,7 @@
 
 """DNS RRsets (an RRset is a named rdataset)"""
 
-from typing import cast, Collection, Optional, Union
+from typing import Any, cast, Collection, Optional, Union
 
 import dns.name
 import dns.rdataset
@@ -110,7 +110,7 @@ class RRset(dns.rdataset.Rdataset):
 
     # pylint: disable=arguments-differ
 
-    def to_text(self, origin: Optional[dns.name.Name]=None, relativize=True, **kw) -> str:  # type: ignore
+    def to_text(self, origin: Optional[dns.name.Name]=None, relativize: bool=True, **kw) -> str:  # type: ignore
         """Convert the RRset into DNS zone file format.
 
         See ``dns.name.Name.choose_relativity`` for more information
@@ -130,7 +130,7 @@ class RRset(dns.rdataset.Rdataset):
         return super().to_text(self.name, origin, relativize,
                                self.deleting, **kw)
 
-    def to_wire(self, file, compress: Optional[dns.name.CompressType]=None,  # type: ignore
+    def to_wire(self, file: Any, compress: Optional[dns.name.CompressType]=None,  # type: ignore
                 origin: Optional[dns.name.Name]=None, **kw) -> int:
         """Convert the RRset to wire format.
 
@@ -158,7 +158,7 @@ def from_text_list(name: Union[dns.name.Name, str], ttl: int,
                    rdtype: Union[dns.rdatatype.RdataType, str],
                    text_rdatas: Collection[str],
                    idna_codec: Optional[dns.name.IDNACodec]=None,
-                   origin: Optional[dns.name.Name]=None, relativize=True,
+                   origin: Optional[dns.name.Name]=None, relativize: bool=True,
                    relativize_to: Optional[dns.name.Name]=None) -> RRset:
     """Create an RRset with the specified name, TTL, class, and type, and with
     the specified list of rdatas in text format.
@@ -205,7 +205,7 @@ def from_text(name: Union[dns.name.Name, str], ttl: int,
                           cast(Collection[str], text_rdatas))
 
 
-def from_rdata_list(name: Union[dns.name.Name, str], ttl:int,
+def from_rdata_list(name: Union[dns.name.Name, str], ttl: int,
                     rdatas: Collection[dns.rdata.Rdata],
                     idna_codec: Optional[dns.name.IDNACodec]=None) -> RRset:
     """Create an RRset with the specified name and TTL, and with
index 138ffbf966929f4103da27af5399127b1308eb98..b4d264cbecc9e9133468b7d8860c1d241dbb5e89 100644 (file)
@@ -3,7 +3,7 @@
 """Serial Number Arthimetic from RFC 1982"""
 
 class Serial:
-    def __init__(self, value:int , bits=32):
+    def __init__(self, value: int, bits: int=32):
         self.value = value % 2 ** bits
         self.bits = bits
 
index bb94ce94c3666199ce59237a13edc45f01da4738..331bee3ce779508a7d93b0d62ecf121630d1abc5 100644 (file)
@@ -17,7 +17,7 @@
 
 """Tokenize DNS zone file format"""
 
-from typing import Optional, List, Tuple
+from typing import Any, Optional, List, Tuple, Union
 
 import io
 import sys
@@ -50,7 +50,8 @@ class Token:
     has_escape: Does the token value contain escapes?
     """
 
-    def __init__(self, ttype: int, value='', has_escape=False, comment: Optional[str]=None):
+    def __init__(self, ttype: int, value: Any='', has_escape: bool=False,
+                 comment: Optional[str]=None):
         """Initialize a token instance."""
 
         self.ttype = ttype
@@ -225,7 +226,7 @@ class Tokenizer:
     encoder/decoder is used.
     """
 
-    def __init__(self, f=sys.stdin, filename: Optional[str]=None,
+    def __init__(self, f: Any=sys.stdin, filename: Optional[str]=None,
                  idna_codec: Optional[dns.name.IDNACodec]=None):
         """Initialize a tokenizer instance.
 
@@ -297,7 +298,7 @@ class Tokenizer:
 
         return (self.filename, self.line_number)
 
-    def _unget_char(self, c):
+    def _unget_char(self, c: str) -> None:
         """Unget a character.
 
         The unget buffer for characters is only one character large; it is
@@ -313,7 +314,7 @@ class Tokenizer:
             raise UngetBufferFull  # pragma: no cover
         self.ungotten_char = c
 
-    def skip_whitespace(self):
+    def skip_whitespace(self) -> int:
         """Consume input until a non-whitespace character is encountered.
 
         The non-whitespace character is then ungotten, and the number of
@@ -333,7 +334,7 @@ class Tokenizer:
                     return skipped
             skipped += 1
 
-    def get(self, want_leading=False, want_comment=False) -> Token:
+    def get(self, want_leading: bool=False, want_comment: bool=False) -> Token:
         """Get the next token.
 
         want_leading: If True, return a WHITESPACE token if the
@@ -477,7 +478,7 @@ class Tokenizer:
 
     # Helpers
 
-    def get_int(self, base=10):
+    def get_int(self, base: int=10) -> int:
         """Read the next token and interpret it as an unsigned integer.
 
         Raises dns.exception.SyntaxError if not an unsigned integer.
@@ -507,7 +508,7 @@ class Tokenizer:
                 '%d is not an unsigned 8-bit integer' % value)
         return value
 
-    def get_uint16(self, base=10) -> int:
+    def get_uint16(self, base: int=10) -> int:
         """Read the next token and interpret it as a 16-bit unsigned
         integer.
 
@@ -526,7 +527,7 @@ class Tokenizer:
                     '%d is not an unsigned 16-bit integer' % value)
         return value
 
-    def get_uint32(self, base=10) -> int:
+    def get_uint32(self, base: int=10) -> int:
         """Read the next token and interpret it as a 32-bit unsigned
         integer.
 
@@ -541,7 +542,7 @@ class Tokenizer:
                 '%d is not an unsigned 32-bit integer' % value)
         return value
 
-    def get_uint48(self, base=10) -> int:
+    def get_uint48(self, base: int=10) -> int:
         """Read the next token and interpret it as a 48-bit unsigned
         integer.
 
@@ -556,7 +557,7 @@ class Tokenizer:
                 '%d is not an unsigned 48-bit integer' % value)
         return value
 
-    def get_string(self, max_length=None) -> str:
+    def get_string(self, max_length: Optional[int]=None) -> str:
         """Read the next token and interpret it as a string.
 
         Raises dns.exception.SyntaxError if not a string.
@@ -586,7 +587,7 @@ class Tokenizer:
             raise dns.exception.SyntaxError('expecting an identifier')
         return token.value
 
-    def get_remaining(self, max_tokens=None) -> List[Token]:
+    def get_remaining(self, max_tokens: Optional[int]=None) -> List[Token]:
         """Return the remaining tokens on the line, until an EOL or EOF is seen.
 
         max_tokens: If not None, stop after this number of tokens.
@@ -605,7 +606,7 @@ class Tokenizer:
                 break
         return tokens
 
-    def concatenate_remaining_identifiers(self, allow_empty=False) -> str:
+    def concatenate_remaining_identifiers(self, allow_empty: bool=False) -> str:
         """Read the remaining tokens on the line, which should be identifiers.
 
         Raises dns.exception.SyntaxError if there are no remaining tokens,
@@ -631,7 +632,7 @@ class Tokenizer:
         return s
 
     def as_name(self, token: Token, origin: Optional[dns.name.Name]=None,
-                relativize=False, relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name:
+                relativize: bool=False, relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name:
         """Try to interpret the token as a DNS name.
 
         Raises dns.exception.SyntaxError if not a name.
@@ -643,7 +644,7 @@ class Tokenizer:
         name = dns.name.from_text(token.value, origin, self.idna_codec)
         return name.choose_relativity(relativize_to or origin, relativize)
 
-    def get_name(self, origin: Optional[dns.name.Name]=None, relativize=False,
+    def get_name(self, origin: Optional[dns.name.Name]=None, relativize: bool=False,
                  relativize_to: Optional[dns.name.Name]=None) -> dns.name.Name:
         """Read the next token and interpret it as a DNS name.
 
index ccb557cec3f3eeaa4bc02c557f4182d508bc9591..f48d83efcfa7d174e520cb827c6d02be9ce5b349 100644 (file)
@@ -20,7 +20,7 @@ class TransactionManager:
         """Begin a read-only transaction."""
         raise NotImplementedError  # pragma: no cover
 
-    def writer(self, replacement=False) -> 'Transaction':
+    def writer(self, replacement: bool=False) -> 'Transaction':
         """Begin a writable transaction.
 
         *replacement*, a ``bool``.  If `True`, the content of the
@@ -101,7 +101,7 @@ CheckDeleteNameType = Callable[['Transaction', dns.name.Name], None]
 
 class Transaction:
 
-    def __init__(self, manager: TransactionManager, replacement=False, read_only=False):
+    def __init__(self, manager: TransactionManager, replacement: bool=False, read_only: bool=False):
         self.manager = manager
         self.replacement = replacement
         self.read_only = read_only
@@ -133,18 +133,18 @@ class Transaction:
         rdataset = self._get_rdataset(name, rdtype, covers)
         return _ensure_immutable_rdataset(rdataset)
 
-    def get_node(self, name) -> dns.node.Node:
+    def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]:
         """Return the node at *name*, if any.
 
         Returns an immutable node or ``None``.
         """
         return _ensure_immutable_node(self._get_node(name))
 
-    def _check_read_only(self):
+    def _check_read_only(self) -> None:
         if self.read_only:
             raise ReadOnly
 
-    def add(self, *args):
+    def add(self, *args) -> None:
         """Add records.
 
         The arguments may be:
@@ -157,9 +157,9 @@ class Transaction:
         """
         self._check_ended()
         self._check_read_only()
-        return self._add(False, args)
+        self._add(False, args)
 
-    def replace(self, *args):
+    def replace(self, *args) -> None:
         """Replace the existing rdataset at the name with the specified
         rdataset, or add the specified rdataset if there was no existing
         rdataset.
@@ -178,9 +178,9 @@ class Transaction:
         """
         self._check_ended()
         self._check_read_only()
-        return self._add(True, args)
+        self._add(True, args)
 
-    def delete(self, *args):
+    def delete(self, *args) -> None:
         """Delete records.
 
         It is not an error if some of the records are not in the existing
@@ -200,9 +200,9 @@ class Transaction:
         """
         self._check_ended()
         self._check_read_only()
-        return self._delete(False, args)
+        self._delete(False, args)
 
-    def delete_exact(self, *args):
+    def delete_exact(self, *args) -> None:
         """Delete records.
 
         The arguments may be:
@@ -223,7 +223,7 @@ class Transaction:
         """
         self._check_ended()
         self._check_read_only()
-        return self._delete(True, args)
+        self._delete(True, args)
 
     def name_exists(self, name: Union[dns.name.Name, str]) -> bool:
         """Does the specified name exist?"""
@@ -232,7 +232,7 @@ class Transaction:
             name = dns.name.from_text(name, None)
         return self._name_exists(name)
 
-    def update_serial(self, value=1, relative=True, name=dns.name.empty):
+    def update_serial(self, value: int=1, relative: bool=True, name: dns.name.Name=dns.name.empty) -> None:
         """Update the serial number.
 
         *value*, an `int`, is an increment if *relative* is `True`, or the
@@ -279,7 +279,7 @@ class Transaction:
         self._check_ended()
         return self._changed()
 
-    def commit(self):
+    def commit(self) -> None:
         """Commit the transaction.
 
         Normally transactions are used as context managers and commit
@@ -292,7 +292,7 @@ class Transaction:
         """
         self._end(True)
 
-    def rollback(self):
+    def rollback(self) -> None:
         """Rollback the transaction.
 
         Normally transactions are used as context managers and commit
@@ -304,7 +304,7 @@ class Transaction:
         """
         self._end(False)
 
-    def check_put_rdataset(self, check: CheckPutRdatasetType):
+    def check_put_rdataset(self, check: CheckPutRdatasetType) -> None:
         """Call *check* before putting (storing) an rdataset.
 
         The function is called with the transaction, the name, and the rdataset.
@@ -316,7 +316,7 @@ class Transaction:
         """
         self._check_put_rdataset.append(check)
 
-    def check_delete_rdataset(self, check: CheckDeleteRdatasetType):
+    def check_delete_rdataset(self, check: CheckDeleteRdatasetType) -> None:
         """Call *check* before deleting an rdataset.
 
         The function is called with the transaction, the name, the rdatatype,
index 5df0cc783a0a8250108a7d0cc675e10fb1154221..9e9b113b96f33e0843afddfca74ef272334392f5 100644 (file)
 
 """DNS Dynamic Update Support"""
 
-from typing import Any, Optional, Union
+from typing import Any, List, Optional, Union
 
 import dns.message
 import dns.name
 import dns.opcode
 import dns.rdata
 import dns.rdataclass
+import dns.rdatatype
 import dns.rdataset
 import dns.tsig
 
@@ -48,7 +49,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
     def __init__(self, zone: Optional[Union[dns.name.Name, str]]=None,
                  rdclass=dns.rdataclass.IN,
                  keyring: Optional[Any]=None, keyname: Optional[dns.name.Name]=None,
-                 keyalgorithm=dns.tsig.default_algorithm,
+                 keyalgorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm,
                  id: Optional[int]=None):
         """Initialize a new DNS Update object.
 
@@ -79,7 +80,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
             self.use_tsig(keyring, keyname, algorithm=keyalgorithm)
 
     @property
-    def zone(self):
+    def zone(self) -> List[dns.rrset.RRset]:
         """The zone section."""
         return self.sections[0]
 
@@ -88,7 +89,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
         self.sections[0] = v
 
     @property
-    def prerequisite(self):
+    def prerequisite(self) -> List[dns.rrset.RRset]:
         """The prerequisite section."""
         return self.sections[1]
 
@@ -97,7 +98,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
         self.sections[1] = v
 
     @property
-    def update(self):
+    def update(self) -> List[dns.rrset.RRset]:
         """The update section."""
         return self.sections[2]
 
@@ -156,7 +157,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
                                              self.origin)
                     self._add_rr(name, ttl, rd, section=section)
 
-    def add(self, name: Union[dns.name.Name, str], *args):
+    def add(self, name: Union[dns.name.Name, str], *args) -> None:
         """Add records.
 
         The first argument is always a name.  The other
@@ -171,7 +172,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
 
         self._add(False, self.update, name, *args)
 
-    def delete(self, name: Union[dns.name.Name, str], *args):
+    def delete(self, name: Union[dns.name.Name, str], *args) -> None:
         """Delete records.
 
         The first argument is always a name.  The other
@@ -215,7 +216,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
                                                  self.origin)
                         self._add_rr(name, 0, rd, dns.rdataclass.NONE)
 
-    def replace(self, name: Union[dns.name.Name, str], *args):
+    def replace(self, name: Union[dns.name.Name, str], *args) -> None:
         """Replace records.
 
         The first argument is always a name.  The other
@@ -233,7 +234,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
 
         self._add(True, self.update, name, *args)
 
-    def present(self, name: Union[dns.name.Name, str], *args):
+    def present(self, name: Union[dns.name.Name, str], *args) -> None:
         """Require that an owner name (and optionally an rdata type,
         or specific rdataset) exists as a prerequisite to the
         execution of the update.
@@ -272,7 +273,8 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
                             dns.rdatatype.NONE, None,
                             True, True)
 
-    def absent(self, name: Union[dns.name.Name, str], rdtype=None):
+    def absent(self, name: Union[dns.name.Name, str],
+               rdtype: Union[dns.rdatatype.RdataType, str]=None) -> None:
         """Require that an owner name (and optionally an rdata type) does
         not exist as a prerequisite to the execution of the update."""
 
@@ -284,9 +286,9 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
                             dns.rdatatype.NONE, None,
                             True, True)
         else:
-            rdtype = dns.rdatatype.RdataType.make(rdtype)
+            the_rdtype = dns.rdatatype.RdataType.make(rdtype)
             self.find_rrset(self.prerequisite, name,
-                            dns.rdataclass.NONE, rdtype,
+                            dns.rdataclass.NONE, the_rdtype,
                             dns.rdatatype.NONE, None,
                             True, True)
 
index 02316c822a799e90de89b8f014f3fcfe26653ce1..9ed9cef6a25a5ad7e4e9503f0ca42b90fb3cc444 100644 (file)
@@ -13,7 +13,9 @@ except ImportError:  # pragma: no cover
 import dns.exception
 import dns.immutable
 import dns.name
+import dns.node
 import dns.rdataclass
+import dns.rdataset
 import dns.rdatatype
 import dns.rdtypes.ANY.SOA
 import dns.zone
@@ -40,7 +42,8 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
 
     node_factory = Node
 
-    def __init__(self, origin: Optional[Union[dns.name.Name, str]], rdclass=dns.rdataclass.IN, relativize=True,
+    def __init__(self, origin: Optional[Union[dns.name.Name, str]],
+                 rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, relativize: bool=True,
                  pruning_policy: Optional[Callable[['Zone', Version], Optional[bool]]]=None):
         """Initialize a versioned zone object.
 
@@ -106,7 +109,7 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
             self._readers.add(txn)
             return txn
 
-    def writer(self, replacement=False) -> Transaction:
+    def writer(self, replacement: bool=False) -> Transaction:
         event = None
         while True:
             with self._version_lock:
@@ -181,7 +184,7 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
               self._pruning_policy(self, self._versions[0]):
             self._versions.popleft()
 
-    def set_max_versions(self, max_versions: Optional[int]):
+    def set_max_versions(self, max_versions: Optional[int]) -> None:
         """Set a pruning policy that retains up to the specified number
         of versions
         """
@@ -195,7 +198,7 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
                 return len(zone._versions) > max_versions
         self.set_pruning_policy(policy)
 
-    def set_pruning_policy(self, policy: Optional[Callable[['Zone', Version], Optional[bool]]]):
+    def set_pruning_policy(self, policy: Optional[Callable[['Zone', Version], Optional[bool]]]) -> None:
         """Set the pruning policy for the zone.
 
         The *policy* function takes a `Version` and returns `True` if
@@ -248,30 +251,39 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
             id = 1
         return id
 
-    def find_node(self, name, create=False):
+    def find_node(self, name: Union[dns.name.Name, str], create: bool=False) -> dns.node.Node:
         if create:
             raise UseTransaction
         return super().find_node(name)
 
-    def delete_node(self, name):
+    def delete_node(self, name: Union[dns.name.Name, str]) -> None:
         raise UseTransaction
 
-    def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
-                      create=False):
+    def find_rdataset(self, name: Union[dns.name.Name, str],
+                      rdtype: Union[dns.rdatatype.RdataType, str],
+                      covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE,
+                      create: bool=False) -> dns.rdataset.Rdataset:
         if create:
             raise UseTransaction
         rdataset = super().find_rdataset(name, rdtype, covers)
         return dns.rdataset.ImmutableRdataset(rdataset)
 
-    def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
-                     create=False):
+    def get_rdataset(self, name: Union[dns.name.Name, str],
+                     rdtype: Union[dns.rdatatype.RdataType, str],
+                     covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE,
+                     create: bool=False) -> Optional[dns.rdataset.Rdataset]:
         if create:
             raise UseTransaction
         rdataset = super().get_rdataset(name, rdtype, covers)
-        return dns.rdataset.ImmutableRdataset(rdataset)
+        if rdataset is not None:
+            return dns.rdataset.ImmutableRdataset(rdataset)
+        else:
+            return None
 
-    def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE):
+    def delete_rdataset(self, name: Union[dns.name.Name, str],
+                        rdtype: Union[dns.rdatatype.RdataType, str],
+                        covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> None:
         raise UseTransaction
 
-    def replace_rdataset(self, name, replacement):
+    def replace_rdataset(self, name: Union[dns.name.Name, str], replacement: dns.rdataset.Rdataset) -> None:
         raise UseTransaction
index d3317a59349bb35c16b5eaa2e4e060fc72d0575a..87814eea6a2446ae8f1505054ff0108aef8d16e7 100644 (file)
@@ -9,7 +9,7 @@ import dns.exception
 import dns.name
 
 class Parser:
-    def __init__(self, wire: bytes, current=0):
+    def __init__(self, wire: bytes, current: int=0):
         self.wire = wire
         self.current = 0
         self.end = len(self.wire)
@@ -17,10 +17,10 @@ class Parser:
             self.seek(current)
         self.furthest = current
 
-    def remaining(self):
+    def remaining(self) -> int:
         return self.end - self.current
 
-    def get_bytes(self, size=int) -> bytes:
+    def get_bytes(self, sizeint) -> bytes:
         assert size >= 0
         if size > self.remaining():
             raise dns.exception.FormError
@@ -29,7 +29,7 @@ class Parser:
         self.furthest = max(self.furthest, self.current)
         return output
 
-    def get_counted_bytes(self, length_size=1) -> bytes:
+    def get_counted_bytes(self, length_size: int=1) -> bytes:
         length = int.from_bytes(self.get_bytes(length_size), 'big')
         return self.get_bytes(length)
 
@@ -57,7 +57,7 @@ class Parser:
             name = name.relativize(origin)
         return name
 
-    def seek(self, where: int):
+    def seek(self, where: int) -> None:
         # Note that seeking to the end is OK!  (If you try to read
         # after such a seek, you'll get an exception as expected.)
         if where < 0 or where > self.end:
index 618eac2f2d6b8b4459c8c6fae6f17bf2fd7cd3b8..a360deba65fd9f94750df01f7a0f46a7544a0ab9 100644 (file)
@@ -15,7 +15,7 @@
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, Union
 
 import dns.exception
 import dns.message
@@ -51,8 +51,9 @@ class Inbound:
     State machine for zone transfers.
     """
 
-    def __init__(self, txn_manager: dns.transaction.TransactionManager, rdtype=dns.rdatatype.AXFR,
-                 serial: Optional[int]=None, is_udp=False):
+    def __init__(self, txn_manager: dns.transaction.TransactionManager,
+                 rdtype: dns.rdatatype.RdataType=dns.rdatatype.AXFR,
+                 serial: Optional[int]=None, is_udp: bool=False):
         """Initialize an inbound zone transfer.
 
         *txn_manager* is a :py:class:`dns.transaction.TransactionManager`.
@@ -245,10 +246,11 @@ class Inbound:
 
 
 def make_query(txn_manager: dns.transaction.TransactionManager, serial: Optional[int]=0,
-               use_edns=None, ednsflags: Optional[int]=None, payload: Optional[int]=None,
+               use_edns: Optional[Union[int, bool]]=None, ednsflags: Optional[int]=None, payload: Optional[int]=None,
                request_payload: Optional[int]=None, options: Optional[List[dns.edns.Option]]=None,
                keyring: Any=None, keyname: Optional[dns.name.Name]=None,
-               keyalgorithm=dns.tsig.default_algorithm) -> Tuple[dns.message.QueryMessage, Optional[int]]:
+               keyalgorithm: Union[dns.name.Name, str]=dns.tsig.default_algorithm) \
+                   -> Tuple[dns.message.QueryMessage, Optional[int]]:
     """Make an AXFR or IXFR query.
 
     *txn_manager* is a ``dns.transaction.TransactionManager``, typically a
index a1fe07a90f60e767991878edabdc8347704e5d10..91fb697014e2315d7b02c8b910be7353feb8dfde 100644 (file)
@@ -100,7 +100,7 @@ class Zone(dns.transaction.TransactionManager):
     __slots__ = ['rdclass', 'origin', 'nodes', 'relativize']
 
     def __init__(self, origin: Optional[Union[dns.name.Name, str]],
-                 rdclass=dns.rdataclass.IN, relativize=True):
+                 rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN, relativize: bool=True):
         """Initialize a zone object.
 
         *origin* is the origin of the zone.  It may be a ``dns.name.Name``,
@@ -204,7 +204,7 @@ class Zone(dns.transaction.TransactionManager):
         key = self._validate_name(key)
         return key in self.nodes
 
-    def find_node(self, name: Union[dns.name.Name, str], create=False):
+    def find_node(self, name: Union[dns.name.Name, str], create: bool=False) -> dns.node.Node:
         """Find a node in the zone, possibly creating it.
 
         *name*: the name of the node to find.
@@ -230,7 +230,7 @@ class Zone(dns.transaction.TransactionManager):
             self.nodes[name] = node
         return node
 
-    def get_node(self, name: Union[dns.name.Name, str], create=False):
+    def get_node(self, name: Union[dns.name.Name, str], create: bool=False) -> Optional[dns.node.Node]:
         """Get a node in the zone, possibly creating it.
 
         This method is like ``find_node()``, except it returns None instead
@@ -257,7 +257,7 @@ class Zone(dns.transaction.TransactionManager):
             node = None
         return node
 
-    def delete_node(self, name: Union[dns.name.Name, str]):
+    def delete_node(self, name: Union[dns.name.Name, str]) -> None:
         """Delete the specified node if it exists.
 
         *name*: the name of the node to find.
@@ -275,7 +275,7 @@ class Zone(dns.transaction.TransactionManager):
     def find_rdataset(self, name: Union[dns.name.Name, str],
                       rdtype: Union[dns.rdatatype.RdataType, str],
                       covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE,
-                      create=False) -> dns.rdataset.Rdataset:
+                      create: bool=False) -> dns.rdataset.Rdataset:
         """Look for an rdataset with the specified name and type in the zone,
         and return an rdataset encapsulating it.
 
@@ -310,14 +310,16 @@ class Zone(dns.transaction.TransactionManager):
         Returns a ``dns.rdataset.Rdataset``.
         """
 
-        name = self._validate_name(name)
-        rdtype = dns.rdatatype.RdataType.make(rdtype)
-        covers = dns.rdatatype.RdataType.make(covers)
-        node = self.find_node(name, create)
-        return node.find_rdataset(self.rdclass, rdtype, covers, create)
+        the_name = self._validate_name(name)
+        the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+        the_covers = dns.rdatatype.RdataType.make(covers)
+        node = self.find_node(the_name, create)
+        return node.find_rdataset(self.rdclass, the_rdtype, the_covers, create)
 
-    def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
-                     create=False) -> Optional[dns.rdataset.Rdataset]:
+    def get_rdataset(self, name: Union[dns.name.Name, str],
+                     rdtype: Union[dns.rdatatype.RdataType, str],
+                     covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE,
+                     create: bool=False) -> Optional[dns.rdataset.Rdataset]:
         """Look for an rdataset with the specified name and type in the zone.
 
         This method is like ``find_rdataset()``, except it returns None instead
@@ -361,7 +363,7 @@ class Zone(dns.transaction.TransactionManager):
 
     def delete_rdataset(self, name: Union[dns.name.Name, str],
                         rdtype: Union[dns.rdatatype.RdataType, str],
-                        covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE):
+                        covers: Union[dns.rdatatype.RdataType, str]=dns.rdatatype.NONE) -> None:
         """Delete the rdataset matching *rdtype* and *covers*, if it
         exists at the node specified by *name*.
 
@@ -389,17 +391,17 @@ class Zone(dns.transaction.TransactionManager):
         RRSIG rdataset.
         """
 
-        name = self._validate_name(name)
-        rdtype = dns.rdatatype.RdataType.make(rdtype)
-        covers = dns.rdatatype.RdataType.make(covers)
-        node = self.get_node(name)
+        the_name = self._validate_name(name)
+        the_rdtype = dns.rdatatype.RdataType.make(rdtype)
+        the_covers = dns.rdatatype.RdataType.make(covers)
+        node = self.get_node(the_name)
         if node is not None:
-            node.delete_rdataset(self.rdclass, rdtype, covers)
+            node.delete_rdataset(self.rdclass, the_rdtype, the_covers)
             if len(node) == 0:
-                self.delete_node(name)
+                self.delete_node(the_name)
 
     def replace_rdataset(self, name: Union[dns.name.Name, str],
-                         replacement: dns.rdataset.Rdataset):
+                         replacement: dns.rdataset.Rdataset) -> None:
         """Replace an rdataset at name.
 
         It is not an error if there is no rdataset matching I{replacement}.
@@ -575,8 +577,8 @@ class Zone(dns.transaction.TransactionManager):
                     for rdata in rds:
                         yield (name, rds.ttl, rdata)
 
-    def to_file(self, f: Any, sorted=True, relativize=True, nl: Optional[str]=None,
-                want_comments=False, want_origin=False):
+    def to_file(self, f: Any, sorted: bool=True, relativize: bool=True, nl: Optional[str]=None,
+                want_comments: bool=False, want_origin: bool=False):
         """Write a zone to a file.
 
         *f*, a file or `str`.  If *f* is a string, it is treated
@@ -653,8 +655,8 @@ class Zone(dns.transaction.TransactionManager):
                     f.write(l)
                     f.write(nl)
 
-    def to_text(self, sorted=True, relativize=True, nl: Optional[str]=None,
-                want_comments=False, want_origin=False):
+    def to_text(self, sorted: bool=True, relativize: bool=True, nl: Optional[str]=None,
+                want_comments: bool=False, want_origin: bool=False):
         """Return a zone's text as though it were written to a file.
 
         *sorted*, a ``bool``.  If True, the default, then the file
@@ -687,7 +689,7 @@ class Zone(dns.transaction.TransactionManager):
         temp_buffer.close()
         return return_value
 
-    def check_origin(self):
+    def check_origin(self) -> None:
         """Do some simple checking of the zone's origin.
 
         Raises ``dns.zone.NoSOA`` if there is no SOA RRset.
@@ -699,6 +701,7 @@ class Zone(dns.transaction.TransactionManager):
         if self.relativize:
             name = dns.name.empty
         else:
+            assert self.origin is not None
             name = self.origin
         if self.get_rdataset(name, dns.rdatatype.SOA) is None:
             raise NoSOA
@@ -758,7 +761,8 @@ class Zone(dns.transaction.TransactionManager):
                     hasher.update(rrnamebuf + rrfixed + rrlen + rdata)
         return hasher.digest()
 
-    def compute_digest(self, hash_algorithm: DigestHashAlgorithm, scheme=DigestScheme.SIMPLE) -> dns.rdtypes.ANY.ZONEMD.ZONEMD:
+    def compute_digest(self, hash_algorithm: DigestHashAlgorithm,
+                       scheme=DigestScheme.SIMPLE) -> dns.rdtypes.ANY.ZONEMD.ZONEMD:
         serial = self.get_soa().serial
         digest = self._compute_digest(hash_algorithm, scheme)
         return dns.rdtypes.ANY.ZONEMD.ZONEMD(self.rdclass,
@@ -766,11 +770,12 @@ class Zone(dns.transaction.TransactionManager):
                                              serial, scheme, hash_algorithm,
                                              digest)
 
-    def verify_digest(self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD]=None):
+    def verify_digest(self, zonemd: Optional[dns.rdtypes.ANY.ZONEMD.ZONEMD]=None) -> None:
         digests: Union[dns.rdataset.Rdataset, List[dns.rdtypes.ANY.ZONEMD.ZONEMD]]
         if zonemd:
             digests = [zonemd]
         else:
+            assert self.origin is not None
             rds = self.get_rdataset(self.origin, dns.rdatatype.ZONEMD)
             if rds is None:
                 raise NoDigest
@@ -791,7 +796,7 @@ class Zone(dns.transaction.TransactionManager):
         return Transaction(self, False,
                            Version(self, 1, self.nodes, self.origin))
 
-    def writer(self, replacement=False) -> 'Transaction':
+    def writer(self, replacement: bool=False) -> 'Transaction':
         txn = Transaction(self, replacement)
         txn._setup_version()
         return txn
@@ -852,25 +857,28 @@ class ImmutableVersionedNode(VersionedNode):
             [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
         )
 
-    def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
-                      create=False):
+    def find_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType,
+                      covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
+                      create: bool=False) -> dns.rdataset.Rdataset:
         if create:
             raise TypeError("immutable")
         return super().find_rdataset(rdclass, rdtype, covers, False)
 
-    def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
-                     create=False):
+    def get_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType,
+                     covers: dns.rdatatype.RdataType=dns.rdatatype.NONE,
+                     create: bool=False) -> Optional[dns.rdataset.Rdataset]:
         if create:
             raise TypeError("immutable")
         return super().get_rdataset(rdclass, rdtype, covers, False)
 
-    def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
+    def delete_rdataset(self, rdclass: dns.rdataclass.RdataClass, rdtype: dns.rdatatype.RdataType,
+                        covers: dns.rdatatype.RdataType=dns.rdatatype.NONE) -> None:
         raise TypeError("immutable")
 
-    def replace_rdataset(self, replacement):
+    def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
         raise TypeError("immutable")
 
-    def is_immutable(self):
+    def is_immutable(self) -> bool:
         return True
 
 
@@ -920,7 +928,7 @@ class Version:
 
 
 class WritableVersion(Version):
-    def __init__(self, zone: Zone, replacement=False):
+    def __init__(self, zone: Zone, replacement: bool=False):
         # The zone._versions_lock must be held by our caller in a versioned
         # zone.
         id = zone._get_next_version_id()
@@ -958,18 +966,18 @@ class WritableVersion(Version):
         else:
             return node
 
-    def delete_node(self, name: dns.name.Name):
+    def delete_node(self, name: dns.name.Name) -> None:
         name = self._validate_name(name)
         if name in self.nodes:
             del self.nodes[name]
             self.changed.add(name)
 
-    def put_rdataset(self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset):
+    def put_rdataset(self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset) -> None:
         node = self._maybe_cow(name)
         node.replace_rdataset(rdataset)
 
     def delete_rdataset(self, name: dns.name.Name, rdtype:dns.rdatatype.RdataType,
-                        covers: dns.rdatatype.RdataType):
+                        covers: dns.rdatatype.RdataType) -> None:
         node = self._maybe_cow(name)
         node.delete_rdataset(self.zone.rdclass, rdtype, covers)
         if len(node) == 0:
@@ -1077,9 +1085,9 @@ class Transaction(dns.transaction.Transaction):
 
 
 def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None,
-              rdclass=dns.rdataclass.IN,
-              relativize=True, zone_factory=Zone, filename: Optional[str]=None,
-              allow_include=False, check_origin=True,
+              rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN,
+              relativize: bool=True, zone_factory: Any=Zone, filename: Optional[str]=None,
+              allow_include: bool=False, check_origin: bool=True,
               idna_codec: Optional[dns.name.IDNACodec]=None) -> Zone:
     """Build a zone object from a zone file format string.
 
@@ -1145,9 +1153,9 @@ def from_text(text: str, origin: Optional[Union[dns.name.Name, str]]=None,
 
 
 def from_file(f: Any, origin: Optional[Union[dns.name.Name, str]]=None,
-              rdclass=dns.rdataclass.IN,
-              relativize=True, zone_factory=Zone, filename: Optional[str]=None,
-              allow_include=True, check_origin=True) -> Zone:
+              rdclass: dns.rdataclass.RdataClass=dns.rdataclass.IN,
+              relativize: bool=True, zone_factory: Any=Zone, filename: Optional[str]=None,
+              allow_include: bool=True, check_origin: bool=True) -> Zone:
     """Read a zone file and build a zone object.
 
     *f*, a file or ``str``.  If *f* is a string, it is treated
@@ -1200,7 +1208,7 @@ def from_file(f: Any, origin: Optional[Union[dns.name.Name, str]]=None,
     assert False  # make mypy happy  lgtm[py/unreachable-statement]
 
 
-def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True):
+def from_xfr(xfr: Any, zone_factory=Zone, relativize: bool=True, check_origin: bool=True) -> Zone:
     """Convert the output of a zone transfer generator into a zone object.
 
     *xfr*, a generator of ``dns.message.Message`` objects, typically
@@ -1221,6 +1229,8 @@ def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True):
 
     Raises ``KeyError`` if there is no origin node.
 
+    Raises ``ValueError`` if no messages are yielded by the generator.
+
     Returns a subclass of ``dns.zone.Zone``.
     """
 
@@ -1243,6 +1253,8 @@ def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True):
             zrds.update_ttl(rrset.ttl)
             for rd in rrset:
                 zrds.add(rd)
+    if z is None:
+        raise ValueError('empty transfer')
     if check_origin:
         z.check_origin()
     return z
index 605131dcce6615fac6b5f87b2c829d0cdad27db1..479f0d63ae20b6c8388eaff850a35052581615df 100644 (file)
@@ -66,7 +66,7 @@ def _check_cname_and_other_data(txn, name, rdataset):
 SavedStateType = Tuple[dns.tokenizer.Tokenizer,
                        Optional[dns.name.Name],   # current_origin
                        Optional[dns.name.Name],   # last_name
-                       Optional[str],             # current_file
+                       Optional[Any],             # current_file
                        int,                       # last_ttl
                        bool,                      # last_ttl_known
                        int,                       # default_ttl
@@ -78,8 +78,8 @@ class Reader:
     """Read a DNS zone file into a transaction."""
 
     def __init__(self, tok: dns.tokenizer.Tokenizer, rdclass: dns.rdataclass.RdataClass,
-                 txn: dns.transaction.Transaction, allow_include=False,
-                 allow_directives=True, force_name: Optional[dns.name.Name]=None,
+                 txn: dns.transaction.Transaction, allow_include: bool=False,
+                 allow_directives: bool=True, force_name: Optional[dns.name.Name]=None,
                  force_ttl: Optional[int]=None,
                  force_rdclass: Optional[dns.rdataclass.RdataClass]=None,
                  force_rdtype: Optional[dns.rdatatype.RdataType]=None,
@@ -102,7 +102,7 @@ class Reader:
         self.zone_rdclass = rdclass
         self.txn = txn
         self.saved_state: List[SavedStateType] = []
-        self.current_file = None
+        self.current_file: Optional[Any] = None
         self.allow_include = allow_include
         self.allow_directives = allow_directives
         self.force_name = force_name
@@ -385,7 +385,7 @@ class Reader:
 
             self.txn.add(name, ttl, rd)
 
-    def read(self):
+    def read(self) -> None:
         """Read a DNS zone file and build a zone object.
 
         @raises dns.zone.NoSOA: No SOA RR was found at the zone origin
@@ -433,11 +433,9 @@ class Reader:
                         token = self.tok.get()
                         filename = token.value
                         token = self.tok.get()
+                        new_origin: Optional[dns.name.Name]
                         if token.is_identifier():
-                            new_origin =\
-                                dns.name.from_text(token.value,
-                                                   self.current_origin,
-                                                   self.tok.idna_codec)
+                            new_origin = dns.name.from_text(token.value, self.current_origin, self.tok.idna_codec)
                             self.tok.get_eol()
                         elif not token.is_eol_or_eof():
                             raise dns.exception.SyntaxError(
@@ -572,7 +570,7 @@ def read_rrsets(text: Any,
                 default_ttl: Optional[Union[int, str]]=None,
                 idna_codec: Optional[dns.name.IDNACodec]=None,
                 origin: Optional[Union[dns.name.Name, str]]=dns.name.root,
-                relativize=False) -> List[dns.rrset.RRset]:
+                relativize: bool=False) -> List[dns.rrset.RRset]:
     """Read one or more rrsets from the specified text, possibly subject
     to restrictions.