]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Pyright lint (#1147)
authorBob Halley <halley@dnspython.org>
Tue, 15 Oct 2024 19:53:10 +0000 (12:53 -0700)
committerGitHub <noreply@github.com>
Tue, 15 Oct 2024 19:53:10 +0000 (12:53 -0700)
Start of pyright linting (more to come in future work).

40 files changed:
Makefile
dns/_immutable_ctx.py
dns/asyncquery.py
dns/asyncresolver.py
dns/dnssec.py
dns/e164.py
dns/edns.py
dns/entropy.py
dns/enum.py
dns/grange.py
dns/ipv4.py
dns/ipv6.py
dns/message.py
dns/name.py
dns/opcode.py
dns/query.py
dns/rcode.py
dns/rdata.py
dns/rdataset.py
dns/rdtypes/dnskeybase.py
dns/rdtypes/dsbase.py
dns/rdtypes/euibase.py
dns/rdtypes/svcbbase.py
dns/rdtypes/tlsabase.py
dns/rdtypes/txtbase.py
dns/rdtypes/util.py
dns/renderer.py
dns/resolver.py
dns/reversename.py
dns/rrset.py
dns/transaction.py
dns/tsig.py
dns/tsigkeyring.py
dns/ttl.py
dns/update.py
dns/versioned.py
dns/xfr.py
dns/zone.py
dns/zonefile.py
pyproject.toml

index 11edb489a5be7137cccb366ddde1fa98f2d245ac..ecb0115e5a688b314aa2c249e6c62da5e0567bb2 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -38,6 +38,9 @@ check: test
 type:
        python -m mypy --install-types --non-interactive --disallow-incomplete-defs dns
 
+pyright:
+       pyright dns
+
 lint:
        pylint dns
 
index ae7a33bf3a5f92252a5191b23086fd62e431e785..b3d72deef2a60f0ffcfd66688388816a08893558 100644 (file)
@@ -41,7 +41,7 @@ def _immutable_init(f):
         finally:
             _in__init__.reset(previous)
 
-    nf.__signature__ = inspect.signature(f)
+    nf.__signature__ = inspect.signature(f)  # pyright: ignore
     return nf
 
 
index efad0fd7594ad4bdb47b87696931db98a174d4ea..883e8afc06ef10f576a1d5faa5054f1a97e641df 100644 (file)
@@ -36,6 +36,8 @@ import dns.rcode
 import dns.rdataclass
 import dns.rdatatype
 import dns.transaction
+import dns.tsig
+import dns.xfr
 from dns._asyncbackend import NullContext
 from dns.query import (
     BadResponse,
@@ -219,9 +221,9 @@ async def udp(
             dtuple = None
         cm = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, dtuple)
     async with cm as s:
-        await send_udp(s, wire, destination, expiration)
+        await send_udp(s, wire, destination, expiration)  # pyright: ignore
         (r, received_time, _) = await receive_udp(
-            s,
+            s,  # pyright: ignore
             destination,
             expiration,
             ignore_unexpected,
@@ -424,9 +426,14 @@ async def tcp(
             af, socket.SOCK_STREAM, 0, stuple, dtuple, timeout
         )
     async with cm as s:
-        await send_tcp(s, wire, expiration)
+        await send_tcp(s, wire, expiration)  # pyright: ignore
         (r, received_time) = await receive_tcp(
-            s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
+            s,  # pyright: ignore
+            expiration,
+            one_rr_per_rrset,
+            q.keyring,
+            q.mac,
+            ignore_trailing,
         )
         r.time = received_time - begin_time
         if not q.is_response(r):
@@ -469,7 +476,9 @@ async def tls(
         cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
     else:
         if ssl_context is None:
-            ssl_context = _make_dot_ssl_context(server_hostname, verify)
+            ssl_context = _make_dot_ssl_context(
+                server_hostname, verify
+            )  # pyright: ignore
         af = dns.inet.af_for_address(where)
         stuple = _source_tuple(af, source, source_port)
         dtuple = (where, port)
@@ -505,8 +514,8 @@ async def tls(
 
 
 def _maybe_get_resolver(
-    resolver: Optional["dns.asyncresolver.Resolver"],
-) -> "dns.asyncresolver.Resolver":
+    resolver: Optional["dns.asyncresolver.Resolver"],  # pyright: ignore
+) -> "dns.asyncresolver.Resolver":  # pyright: ignore
     # We need a separate method for this to avoid overriding the global
     # variable "dns" with the as-yet undefined local variable "dns"
     # in https().
@@ -532,7 +541,7 @@ async def https(
     post: bool = True,
     verify: Union[bool, str] = True,
     bootstrap_address: Optional[str] = None,
-    resolver: Optional["dns.asyncresolver.Resolver"] = None,
+    resolver: Optional["dns.asyncresolver.Resolver"] = None,  # pyright: ignore
     family: int = socket.AF_UNSPEC,
     http_version: HTTPVersion = HTTPVersion.DEFAULT,
 ) -> dns.message.Message:
@@ -552,13 +561,13 @@ async def https(
         af = dns.inet.af_for_address(where)
     except ValueError:
         af = None
+    # we bind url and then override as pyright can't figure out all paths bind.
+    url = where
     if af is not None and dns.inet.is_address(where):
         if af == socket.AF_INET:
             url = f"https://{where}:{port}{path}"
         elif af == socket.AF_INET6:
             url = f"https://[{where}]:{port}{path}"
-    else:
-        url = where
 
     extensions = {}
     if bootstrap_address is None:
@@ -577,8 +586,10 @@ async def https(
     ):
         if bootstrap_address is None:
             resolver = _maybe_get_resolver(resolver)
-            assert parsed.hostname is not None  # for mypy
-            answers = await resolver.resolve_name(parsed.hostname, family)
+            assert parsed.hostname is not None  # pyright: ignore
+            answers = await resolver.resolve_name(  # pyright: ignore
+                parsed.hostname, family  # pyright: ignore
+            )
             bootstrap_address = random.choice(list(answers.addresses()))
         return await _http3(
             q,
@@ -597,7 +608,7 @@ async def https(
     if not have_doh:
         raise NoDOH  # pragma: no cover
     # pylint: disable=possibly-used-before-assignment
-    if client and not isinstance(client, httpx.AsyncClient):
+    if client and not isinstance(client, httpx.AsyncClient):  # pyright: ignore
         raise ValueError("session parameter must be an httpx.AsyncClient")
     # pylint: enable=possibly-used-before-assignment
 
@@ -630,7 +641,9 @@ async def https(
             family=family,
         )
 
-        cm = httpx.AsyncClient(http1=h1, http2=h2, verify=verify, transport=transport)
+        cm = httpx.AsyncClient(  # pyright: ignore
+            http1=h1, http2=h2, verify=verify, transport=transport
+        )
 
     async with cm as the_client:
         # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
@@ -643,7 +656,7 @@ async def https(
                 }
             )
             response = await backend.wait_for(
-                the_client.post(
+                the_client.post(  # pyright: ignore
                     url,
                     headers=headers,
                     content=wire,
@@ -655,7 +668,7 @@ async def https(
             wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
             twire = wire.decode()  # httpx does a repr() if we give it bytes
             response = await backend.wait_for(
-                the_client.get(
+                the_client.get(  # pyright: ignore
                     url,
                     headers=headers,
                     params={"dns": twire},
@@ -785,9 +798,11 @@ async def quic(
             server_name=server_hostname,
         ) as the_manager:
             if not connection:
-                the_connection = the_manager.connect(where, port, source, source_port)
+                the_connection = the_manager.connect(  # pyright: ignore
+                    where, port, source, source_port
+                )
             (start, expiration) = _compute_times(timeout)
-            stream = await the_connection.make_stream(timeout)
+            stream = await the_connection.make_stream(timeout)  # pyright: ignore
             async with stream:
                 await stream.send(wire, True)
                 wire = await stream.receive(_remaining(expiration))
@@ -829,6 +844,7 @@ async def _inbound_xfr(
     with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
         done = False
         tsig_ctx = None
+        r: Optional[dns.message.Message] = None
         while not done:
             (_, mexpiration) = _compute_times(timeout)
             if mexpiration is None or (
@@ -837,11 +853,11 @@ async def _inbound_xfr(
                 mexpiration = expiration
             if is_udp:
                 timeout = _timeout(mexpiration)
-                (rwire, _) = await udp_sock.recvfrom(65535, timeout)
+                (rwire, _) = await udp_sock.recvfrom(65535, timeout)  # pyright: ignore
             else:
-                ldata = await _read_exactly(tcp_sock, 2, mexpiration)
+                ldata = await _read_exactly(tcp_sock, 2, mexpiration)  # pyright: ignore
                 (l,) = struct.unpack("!H", ldata)
-                rwire = await _read_exactly(tcp_sock, l, mexpiration)
+                rwire = await _read_exactly(tcp_sock, l, mexpiration)  # pyright: ignore
             r = dns.message.from_wire(
                 rwire,
                 keyring=query.keyring,
@@ -855,7 +871,7 @@ async def _inbound_xfr(
             done = inbound.process_message(r)
             yield r
             tsig_ctx = r.tsig_ctx
-        if query.keyring and not r.had_tsig:
+        if query.keyring and r is not None and not r.had_tsig:
             raise dns.exception.FormError("missing TSIG")
 
 
@@ -896,8 +912,13 @@ async def inbound_xfr(
         )
         async with s:
             try:
-                async for _ in _inbound_xfr(
-                    txn_manager, s, query, serial, timeout, expiration
+                async for _ in _inbound_xfr(  # pyright: ignore
+                    txn_manager,
+                    s,
+                    query,
+                    serial,
+                    timeout,
+                    expiration,  # pyright: ignore
                 ):
                     pass
                 return
@@ -909,5 +930,7 @@ async def inbound_xfr(
         af, socket.SOCK_STREAM, 0, stuple, dtuple, _timeout(expiration)
     )
     async with s:
-        async for _ in _inbound_xfr(txn_manager, s, query, serial, timeout, expiration):
+        async for _ in _inbound_xfr(  # pyright: ignore
+            txn_manager, s, query, serial, timeout, expiration  # pyright: ignore
+        ):
             pass
index 8f5e062a9ee5c1bf19acf363da7344b8d393e32a..1df89e6ca4d83836bb7a1ecbf0f3ac0d24ad9dea 100644 (file)
@@ -25,11 +25,14 @@ import dns._ddr
 import dns.asyncbackend
 import dns.asyncquery
 import dns.exception
+import dns.inet
 import dns.name
+import dns.nameserver
 import dns.query
 import dns.rdataclass
 import dns.rdatatype
 import dns.resolver  # lgtm[py/import-and-import-from]
+import dns.reversename
 
 # import some resolver symbols for brevity
 from dns.resolver import NXDOMAIN, NoAnswer, NoRootSOA, NotAbsolute
@@ -426,7 +429,7 @@ async def make_resolver_at(
         answers = await resolver.resolve_name(where, family)
         for address in answers.addresses():
             nameservers.append(dns.nameserver.Do53Nameserver(address, port))
-    res = dns.asyncresolver.Resolver(configure=False)
+    res = Resolver(configure=False)
     res.nameservers = nameservers
     return res
 
index b69d0a1262ee28e4325c017e48d5188fb8961bbe..76d728a5f1e08eb3e9bc647327792f9aa32d6466 100644 (file)
@@ -135,16 +135,16 @@ class Policy:
     def __init__(self):
         pass
 
-    def ok_to_sign(self, _: DNSKEY) -> bool:  # pragma: no cover
+    def ok_to_sign(self, key: DNSKEY) -> bool:  # pragma: no cover
         return False
 
-    def ok_to_validate(self, _: DNSKEY) -> bool:  # pragma: no cover
+    def ok_to_validate(self, key: DNSKEY) -> bool:  # pragma: no cover
         return False
 
-    def ok_to_create_ds(self, _: DSDigest) -> bool:  # pragma: no cover
+    def ok_to_create_ds(self, algorithm: DSDigest) -> bool:  # pragma: no cover
         return False
 
-    def ok_to_validate_ds(self, _: DSDigest) -> bool:  # pragma: no cover
+    def ok_to_validate_ds(self, algorithm: DSDigest) -> bool:  # pragma: no cover
         return False
 
 
@@ -587,7 +587,7 @@ def _sign(
         signature=b"",
     )
 
-    data = dns.dnssec._make_rrsig_signature_data(rrset, rrsig_template, origin)
+    data = _make_rrsig_signature_data(rrset, rrsig_template, origin)
 
     # pylint: disable=possibly-used-before-assignment
     if isinstance(private_key, GenericPrivateKey):
@@ -979,7 +979,7 @@ def default_rrset_signer(
         keys = zsks
 
     for private_key, dnskey in keys:
-        rrsig = dns.dnssec.sign(
+        rrsig = sign(
             rrset=rrset,
             private_key=private_key,
             dnskey=dnskey,
index 453736d40806838131569785f5eb2c65b8a2c310..dd9aebc8a5dc474b3524aa80373af2eb5de43dfc 100644 (file)
@@ -108,7 +108,7 @@ def query(
     for domain in domains:
         if isinstance(domain, str):
             domain = dns.name.from_text(domain)
-        qname = dns.e164.from_e164(number, domain)
+        qname = from_e164(number, domain)
         try:
             return resolver.resolve(qname, "NAPTR")
         except dns.resolver.NXDOMAIN as e:
index c36036864c5a2850b9fd5dcde8792a211c5acca6..8db1d2e0f83755a9d053798623a58dfd87264f48 100644 (file)
@@ -25,6 +25,9 @@ from typing import Any, Dict, Optional, Union
 
 import dns.enum
 import dns.inet
+import dns.ipv4
+import dns.ipv6
+import dns.name
 import dns.rdata
 import dns.wire
 
@@ -81,14 +84,14 @@ class Option:
     def to_text(self) -> str:
         raise NotImplementedError  # pragma: no cover
 
-    def to_generic(self) -> "dns.edns.GenericOption":
+    def to_generic(self) -> "GenericOption":
         """Creates a dns.edns.GenericOption equivalent of this rdata.
 
         Returns a ``dns.edns.GenericOption``.
         """
         wire = self.to_wire()
         assert wire is not None  # for mypy
-        return dns.edns.GenericOption(self.otype, wire)
+        return GenericOption(self.otype, wire)
 
     @classmethod
     def from_wire_parser(cls, otype: OptionType, parser: "dns.wire.Parser") -> "Option":
@@ -175,7 +178,7 @@ class GenericOption(Option):  # lgtm[py/missing-equals]
     def to_text(self) -> str:
         return "Generic %d" % self.otype
 
-    def to_generic(self) -> "dns.edns.GenericOption":
+    def to_generic(self) -> "GenericOption":
         return self
 
     @classmethod
@@ -444,7 +447,7 @@ class NSIDOption(Option):
 
 class CookieOption(Option):
     def __init__(self, client: bytes, server: bytes):
-        super().__init__(dns.edns.OptionType.COOKIE)
+        super().__init__(OptionType.COOKIE)
         self.client = client
         self.server = server
         if len(client) != 8:
index 4dcdc6272ca3a670b1616f4c95f2a18b1803bc82..45e79e3dde8fd4a0ce4244ae5773c9fd1099dc55 100644 (file)
@@ -20,7 +20,7 @@ import os
 import random
 import threading
 import time
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 
 class EntropyPool:
@@ -45,7 +45,7 @@ class EntropyPool:
             self.seeded = False
             self.seed_pid = 0
 
-    def _stir(self, entropy: bytes) -> None:
+    def _stir(self, entropy: Union[bytes, bytearray]) -> None:
         for c in entropy:
             if self.pool_index == self.hash_len:
                 self.pool_index = 0
@@ -53,7 +53,7 @@ class EntropyPool:
             self.pool[self.pool_index] ^= b
             self.pool_index += 1
 
-    def stir(self, entropy: bytes) -> None:
+    def stir(self, entropy: Union[bytes, bytearray]) -> None:
         with self.lock:
             self._stir(entropy)
 
index 71461f1776f3990311f656cb37f6aab68e0b9f71..d7f2618702435d6127eb9ec32060e87ca0fcb6b0 100644 (file)
@@ -16,7 +16,9 @@
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
 import enum
-from typing import Type, TypeVar, Union
+from typing import Any, Optional, Type, TypeVar, Union
+
+import dns.exception
 
 TIntEnum = TypeVar("TIntEnum", bound="IntEnum")
 
@@ -25,9 +27,9 @@ class IntEnum(enum.IntEnum):
     @classmethod
     def _missing_(cls, value):
         cls._check_value(value)
-        val = int.__new__(cls, value)
+        val = int.__new__(cls, value)  # pyright: ignore
         val._name_ = cls._extra_to_text(value, None) or f"{cls._prefix()}{value}"
-        val._value_ = value
+        val._value_ = value  # pyright: ignore
         return val
 
     @classmethod
@@ -53,10 +55,7 @@ class IntEnum(enum.IntEnum):
         if text.startswith(prefix) and text[len(prefix) :].isdigit():
             value = int(text[len(prefix) :])
             cls._check_value(value)
-            try:
-                return cls(value)
-            except ValueError:
-                return value
+            return cls(value)
         raise cls._unknown_exception_class()
 
     @classmethod
@@ -100,11 +99,11 @@ class IntEnum(enum.IntEnum):
         return cls.__name__.lower()
 
     @classmethod
-    def _prefix(cls):
+    def _prefix(cls) -> str:
         return ""
 
     @classmethod
-    def _extra_from_text(cls, text):  # pylint: disable=W0613
+    def _extra_from_text(cls, text: str) -> Optional[Any]:  # pylint: disable=W0613
         return None
 
     @classmethod
@@ -112,5 +111,5 @@ class IntEnum(enum.IntEnum):
         return current_text
 
     @classmethod
-    def _unknown_exception_class(cls):
+    def _unknown_exception_class(cls) -> Type[Exception]:
         return ValueError
index a967ca41c63ac99d237619d884da4f1b5d0bc21e..8d366dc8d6fa53a8cb5f5deac78d1ac56390c69a 100644 (file)
@@ -19,7 +19,7 @@
 
 from typing import Tuple
 
-import dns
+import dns.exception
 
 
 def from_text(text: str) -> Tuple[int, int, int]:
index 65ee69c0d7a4f6ce949edc13d2d2d866889ec5ab..21f529614fe90922e5bfcd0acf428834ae7ddd5b 100644 (file)
@@ -74,4 +74,4 @@ def canonicalize(text: Union[str, bytes]) -> str:
     """
     # Note that inet_aton() only accepts canonial form, but we still run through
     # inet_ntoa() to ensure the output is a str.
-    return dns.ipv4.inet_ntoa(dns.ipv4.inet_aton(text))
+    return inet_ntoa(inet_aton(text))
index 4dd1d1cade2ddc126234825a75eef66a7c574896..4f27b415a8f33cb7349192db286c519b7d76ccb7 100644 (file)
@@ -214,4 +214,4 @@ def canonicalize(text: Union[str, bytes]) -> str:
 
     Raises ``dns.exception.SyntaxError`` if the text is not valid.
     """
-    return dns.ipv6.inet_ntoa(dns.ipv6.inet_aton(text))
+    return inet_ntoa(inet_aton(text))
index e978a0a2e1d8ce681b4d353dc9d4dcb74c97006d..fc2a0e72102881d7bc33e295c9a916567d51a2cc 100644 (file)
@@ -35,9 +35,11 @@ import dns.rdata
 import dns.rdataclass
 import dns.rdatatype
 import dns.rdtypes.ANY.OPT
+import dns.rdtypes.ANY.SOA
 import dns.rdtypes.ANY.TSIG
 import dns.renderer
 import dns.rrset
+import dns.tokenizer
 import dns.tsig
 import dns.ttl
 import dns.wire
@@ -529,7 +531,8 @@ class Message:
         # worry about that for now.  We also don't worry if there is an existing padding
         # option, as it is unlikely and probably harmless, as the worst case is that we
         # may add another, and this seems to be legal.
-        for option in self.opt[0].options:
+        opt_rdata = cast(dns.rdtypes.ANY.OPT.OPT, self.opt[0])
+        for option in opt_rdata.options:
             wire = option.to_wire()
             # We add 4 here to account for the option type and length
             size += len(wire) + 4
@@ -753,21 +756,24 @@ class Message:
     @property
     def keyalgorithm(self) -> Optional[dns.name.Name]:
         if self.tsig:
-            return self.tsig[0].algorithm
+            rdata = cast(dns.rdtypes.ANY.TSIG.TSIG, self.tsig[0])
+            return rdata.algorithm
         else:
             return None
 
     @property
     def mac(self) -> Optional[bytes]:
         if self.tsig:
-            return self.tsig[0].mac
+            rdata = cast(dns.rdtypes.ANY.TSIG.TSIG, self.tsig[0])
+            return rdata.mac
         else:
             return None
 
     @property
     def tsig_error(self) -> Optional[int]:
         if self.tsig:
-            return self.tsig[0].error
+            rdata = cast(dns.rdtypes.ANY.TSIG.TSIG, self.tsig[0])
+            return rdata.error
         else:
             return None
 
@@ -857,14 +863,16 @@ class Message:
     @property
     def payload(self) -> int:
         if self.opt:
-            return self.opt[0].payload
+            rdata = cast(dns.rdtypes.ANY.OPT.OPT, self.opt[0])
+            return rdata.payload
         else:
             return 0
 
     @property
     def options(self) -> Tuple:
         if self.opt:
-            return self.opt[0].options
+            rdata = cast(dns.rdtypes.ANY.OPT.OPT, self.opt[0])
+            return rdata.options
         else:
             return ()
 
@@ -1051,7 +1059,8 @@ class QueryMessage(Message):
                     srrset = self.find_rrset(
                         self.authority, auname, question.rdclass, dns.rdatatype.SOA
                     )
-                    min_ttl = min(min_ttl, srrset.ttl, srrset[0].minimum)
+                    srdata = cast(dns.rdtypes.ANY.SOA.SOA, srrset[0])
+                    min_ttl = min(min_ttl, srrset.ttl, srdata.minimum)
                     break
                 except KeyError:
                     try:
@@ -1091,7 +1100,7 @@ def _message_factory_from_opcode(opcode):
         return QueryMessage
     elif opcode == dns.opcode.UPDATE:
         _maybe_import_update()
-        return dns.update.UpdateMessage
+        return dns.update.UpdateMessage  # pyright: ignore
     else:
         return Message
 
@@ -1195,7 +1204,10 @@ class _WireReader:
                 else:
                     with self.parser.restrict_to(rdlen):
                         rd = dns.rdata.from_wire_parser(
-                            rdclass, rdtype, self.parser, self.message.origin
+                            rdclass,  # pyright: ignore
+                            rdtype,
+                            self.parser,
+                            self.message.origin,
                         )
                     covers = rd.covers()
                 if self.message.xfr and rdtype == dns.rdatatype.SOA:
@@ -1203,12 +1215,13 @@ class _WireReader:
                 if rdtype == dns.rdatatype.OPT:
                     self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
                 elif rdtype == dns.rdatatype.TSIG:
+                    trd = cast(dns.rdtypes.ANY.TSIG.TSIG, rd)
                     if self.keyring is None or self.keyring is True:
                         raise UnknownTSIGKey("got signed message without keyring")
                     elif isinstance(self.keyring, dict):
                         key = self.keyring.get(absolute_name)
                         if isinstance(key, bytes):
-                            key = dns.tsig.Key(absolute_name, key, rd.algorithm)
+                            key = dns.tsig.Key(absolute_name, key, trd.algorithm)
                     elif callable(self.keyring):
                         key = self.keyring(self.message, absolute_name)
                     else:
@@ -1233,7 +1246,7 @@ class _WireReader:
                     rrset = self.message.find_rrset(
                         section,
                         name,
-                        rdclass,
+                        rdclass,  # pyright: ignore
                         rdtype,
                         covers,
                         deleting,
@@ -1414,14 +1427,14 @@ class _TextReader:
 
     def __init__(
         self,
-        text,
-        idna_codec,
-        one_rr_per_rrset=False,
-        origin=None,
-        relativize=True,
-        relativize_to=None,
+        text: str,
+        idna_codec: Optional[dns.name.IDNACodec],
+        one_rr_per_rrset: bool = False,
+        origin: Optional[dns.name.Name] = None,
+        relativize: bool = True,
+        relativize_to: Optional[dns.name.Name] = None,
     ):
-        self.message = None
+        self.message: Optional[Message] = None  # mypy: ignore
         self.tok = dns.tokenizer.Tokenizer(text, idna_codec=idna_codec)
         self.last_name = None
         self.one_rr_per_rrset = one_rr_per_rrset
@@ -1480,6 +1493,7 @@ class _TextReader:
     def _question_line(self, section_number):
         """Process one line from the text format question section."""
 
+        assert self.message is not None
         section = self.message.sections[section_number]
         token = self.tok.get(want_leading=True)
         if not token.is_whitespace():
@@ -1517,6 +1531,7 @@ class _TextReader:
         additional data sections.
         """
 
+        assert self.message is not None
         section = self.message.sections[section_number]
         # Name
         token = self.tok.get(want_leading=True)
@@ -1910,6 +1925,8 @@ def make_response(
                     pad = 468
         response.use_edns(0, 0, our_payload, query.payload, pad=pad)
     if query.had_tsig:
+        assert query.mac is not None
+        assert query.keyalgorithm is not None
         response.use_tsig(
             query.keyring,
             query.keyname,
index f79f0d0f6f16f95228ae111b85d6a9e368b98345..4861e11cf413671b9687083f96844d719d1aa629 100644 (file)
@@ -30,6 +30,11 @@ import dns.exception
 import dns.immutable
 import dns.wire
 
+# Dnspython will never access idna if the import fails, but pyright can't figure
+# that out, so...
+#
+# pyright: reportAttributeAccessIssue = false, reportPossiblyUnboundVariable = false
+
 if dns._features.have("idna"):
     import idna  # type: ignore
 
@@ -37,6 +42,7 @@ if dns._features.have("idna"):
 else:  # pragma: no cover
     have_idna_2008 = False
 
+
 CompressType = Dict["Name", int]
 
 
index 78b43d2cbd1404b57f683b2bfad7f726e99caffd..3fa610d040e83b29204130b7ae155b2f30ac0dcc 100644 (file)
@@ -17,6 +17,8 @@
 
 """DNS Opcodes."""
 
+from typing import Type
+
 import dns.enum
 import dns.exception
 
@@ -38,7 +40,7 @@ class Opcode(dns.enum.IntEnum):
         return 15
 
     @classmethod
-    def _unknown_exception_class(cls):
+    def _unknown_exception_class(cls) -> Type[Exception]:
         return UnknownOpcode
 
 
index 0d8a977abdd2351addfad50f379173f02fe7b80e..068729db3c0bb8357c1b3d4dde101af775ad1678 100644 (file)
@@ -38,6 +38,7 @@ import dns.message
 import dns.name
 import dns.quic
 import dns.rcode
+import dns.rdata
 import dns.rdataclass
 import dns.rdatatype
 import dns.serial
@@ -78,7 +79,7 @@ if _have_httpx:
             self._family = family
 
         def connect_tcp(
-            self, host, port, timeout, local_address, socket_options=None
+            self, host, port, timeout=None, local_address=None, socket_options=None
         ):  # pylint: disable=signature-differs
             addresses = []
             _, expiration = _compute_times(timeout)
@@ -98,6 +99,8 @@ if _have_httpx:
             for address in addresses:
                 af = dns.inet.af_for_address(address)
                 if local_address is not None or self._local_port != 0:
+                    if local_address is None:
+                        local_address = "0.0.0.0"
                     source = dns.inet.low_level_address_tuple(
                         (local_address, self._local_port), af
                     )
@@ -117,11 +120,11 @@ if _have_httpx:
             raise httpcore.ConnectError
 
         def connect_unix_socket(
-            self, path, timeout, socket_options=None
+            self, path, timeout=None, socket_options=None
         ):  # pylint: disable=signature-differs
             raise NotImplementedError
 
-    class _HTTPTransport(httpx.HTTPTransport):
+    class _HTTPTransport(httpx.HTTPTransport):  # pyright: ignore
         def __init__(
             self,
             *args,
@@ -144,6 +147,17 @@ if _have_httpx:
 else:
 
     class _HTTPTransport:  # type: ignore
+        def __init__(
+            self,
+            *args,
+            local_port=0,
+            bootstrap_address=None,
+            resolver=None,
+            family=socket.AF_UNSPEC,
+            **kwargs,
+        ):
+            pass
+
         def connect_tcp(self, host, port, timeout, local_address):
             raise NotImplementedError
 
@@ -151,7 +165,7 @@ else:
 have_doh = _have_httpx
 
 try:
-    import ssl
+    import ssl  # pyright: ignore
 except ImportError:  # pragma: no cover
 
     class ssl:  # type: ignore
@@ -163,11 +177,18 @@ except ImportError:  # pragma: no cover
         class WantWriteException(Exception):
             pass
 
+        class SSLWantReadError(Exception):
+            pass
+
+        class SSLWantWriteError(Exception):
+            pass
+
         class SSLContext:
             pass
 
         class SSLSocket:
-            pass
+            def pending(self) -> bool:
+                return False
 
         @classmethod
         def create_default_context(cls, *args, **kwargs):
@@ -226,7 +247,7 @@ def _wait_for(fd, readable, writable, _, expiration):
     if writable:
         events |= selectors.EVENT_WRITE
     if events:
-        sel.register(fd, events)
+        sel.register(fd, events)  # pyright: ignore
     if expiration is None:
         timeout = None
     else:
@@ -338,8 +359,8 @@ def _make_socket(af, type, source, ssl_context=None, server_hostname=None):
 
 
 def _maybe_get_resolver(
-    resolver: Optional["dns.resolver.Resolver"],
-) -> "dns.resolver.Resolver":
+    resolver: Optional["dns.resolver.Resolver"],  # pyright: ignore
+) -> "dns.resolver.Resolver":  # pyright: ignore
     # We need a separate method for this to avoid overriding the global
     # variable "dns" with the as-yet undefined local variable "dns"
     # in https().
@@ -381,7 +402,7 @@ def https(
     post: bool = True,
     bootstrap_address: Optional[str] = None,
     verify: Union[bool, str] = True,
-    resolver: Optional["dns.resolver.Resolver"] = None,
+    resolver: Optional["dns.resolver.Resolver"] = None,  # pyright: ignore
     family: int = socket.AF_UNSPEC,
     http_version: HTTPVersion = HTTPVersion.DEFAULT,
 ) -> dns.message.Message:
@@ -441,13 +462,13 @@ def https(
     (af, _, the_source) = _destination_and_source(
         where, port, source, source_port, False
     )
+    # we bind url and then override as pyright can't figure out all paths bind.
+    url = where
     if af is not None and dns.inet.is_address(where):
         if af == socket.AF_INET:
             url = f"https://{where}:{port}{path}"
         elif af == socket.AF_INET6:
             url = f"https://[{where}]:{port}{path}"
-    else:
-        url = where
 
     extensions = {}
     if bootstrap_address is None:
@@ -466,13 +487,13 @@ def https(
     ):
         if bootstrap_address is None:
             resolver = _maybe_get_resolver(resolver)
-            assert parsed.hostname is not None  # for mypy
-            answers = resolver.resolve_name(parsed.hostname, family)
+            assert parsed.hostname is not None  # pyright: ignore
+            answers = resolver.resolve_name(parsed.hostname, family)  # pyright: ignore
             bootstrap_address = random.choice(list(answers.addresses()))
         return _http3(
             q,
             bootstrap_address,
-            url,
+            url,  # pyright: ignore
             timeout,
             port,
             source,
@@ -485,7 +506,7 @@ def https(
 
     if not have_doh:
         raise NoDOH  # pragma: no cover
-    if session and not isinstance(session, httpx.Client):
+    if session and not isinstance(session, httpx.Client):  # pyright: ignore
         raise ValueError("session parameter must be an httpx.Client")
 
     wire = q.to_wire()
@@ -514,10 +535,12 @@ def https(
             local_port=local_port,
             bootstrap_address=bootstrap_address,
             resolver=resolver,
-            family=family,
+            family=family,  # pyright: ignore
         )
 
-        cm = httpx.Client(http1=h1, http2=h2, verify=verify, transport=transport)
+        cm = httpx.Client(  # pyright: ignore
+            http1=h1, http2=h2, verify=verify, transport=transport  # pyright: ignore
+        )
     with cm as session:
         # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
         # GET and POST examples
@@ -617,7 +640,7 @@ def _http3(
     q.id = 0
     wire = q.to_wire()
     manager = dns.quic.SyncQuicManager(
-        verify_mode=verify, server_name=hostname, h3=True
+        verify_mode=verify, server_name=hostname, h3=True  # pyright: ignore
     )
 
     with manager:
@@ -1162,7 +1185,7 @@ def tcp(
     with cm as s:
         if not sock:
             # pylint: disable=possibly-used-before-assignment
-            _connect(s, destination, expiration)
+            _connect(s, destination, expiration)  # pyright: ignore
         send_tcp(s, wire, expiration)
         (r, received_time) = receive_tcp(
             s, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
@@ -1385,14 +1408,18 @@ def quic(
         manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
         the_connection = connection
     else:
-        manager = dns.quic.SyncQuicManager(verify_mode=verify, server_name=hostname)
+        manager = dns.quic.SyncQuicManager(
+            verify_mode=verify, server_name=hostname  # pyright: ignore
+        )
         the_manager = manager  # for type checking happiness
 
     with manager:
         if not connection:
-            the_connection = the_manager.connect(where, port, source, source_port)
+            the_connection = the_manager.connect(  # pyright: ignore
+                where, port, source, source_port
+            )
         (start, expiration) = _compute_times(timeout)
-        with the_connection.make_stream(timeout) as stream:
+        with the_connection.make_stream(timeout) as stream:  # pyright: ignore
             stream.send(wire, True)
             wire = stream.receive(_remaining(expiration))
         finish = time.time()
@@ -1428,7 +1455,7 @@ def _inbound_xfr(
     query: dns.message.Message,
     serial: Optional[int],
     timeout: Optional[float],
-    expiration: float,
+    expiration: Optional[float],
 ) -> Any:
     """Given a socket, does the zone transfer."""
     rdtype = query.question[0].rdtype
@@ -1444,6 +1471,7 @@ def _inbound_xfr(
     with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
         done = False
         tsig_ctx = None
+        r: Optional[dns.message.Message] = None
         while not done:
             (_, mexpiration) = _compute_times(timeout)
             if mexpiration is None or (
@@ -1469,7 +1497,7 @@ def _inbound_xfr(
             done = inbound.process_message(r)
             yield r
             tsig_ctx = r.tsig_ctx
-        if query.keyring and not r.had_tsig:
+        if query.keyring and r is not None and not r.had_tsig:
             raise dns.exception.FormError("missing TSIG")
 
 
index 8e6386f828019b379bbe97a3950ce604c4778f7f..7bb8467e26ec50189ac668ae6d51e2792107888f 100644 (file)
@@ -17,7 +17,7 @@
 
 """DNS Result Codes."""
 
-from typing import Tuple
+from typing import Tuple, Type
 
 import dns.enum
 import dns.exception
@@ -72,7 +72,7 @@ class Rcode(dns.enum.IntEnum):
         return 4095
 
     @classmethod
-    def _unknown_exception_class(cls):
+    def _unknown_exception_class(cls) -> Type[Exception]:
         return UnknownRcode
 
 
index 0189f240966d47e4a8c31fd6673ecaa434b0641a..bcdac094e87f92e286b8c814c0136fc3f82d055d 100644 (file)
@@ -210,7 +210,7 @@ class Rdata:
 
     def _to_wire(
         self,
-        file: Optional[Any],
+        file: Any,
         compress: Optional[dns.name.CompressType] = None,
         origin: Optional[dns.name.Name] = None,
         canonicalize: bool = False,
@@ -241,16 +241,12 @@ class Rdata:
             self._to_wire(f, compress, origin, canonicalize)
             return f.getvalue()
 
-    def to_generic(
-        self, origin: Optional[dns.name.Name] = None
-    ) -> "dns.rdata.GenericRdata":
+    def to_generic(self, origin: Optional[dns.name.Name] = None) -> "GenericRdata":
         """Creates a dns.rdata.GenericRdata equivalent of this rdata.
 
         Returns a ``dns.rdata.GenericRdata``.
         """
-        return dns.rdata.GenericRdata(
-            self.rdclass, self.rdtype, self.to_wire(origin=origin)
-        )
+        return GenericRdata(self.rdclass, self.rdtype, self.to_wire(origin=origin))
 
     def to_digestable(self, origin: Optional[dns.name.Name] = None) -> bytes:
         """Convert rdata to a format suitable for digesting in hashes.  This
@@ -298,6 +294,9 @@ class Rdata:
             In the future, all ordering comparisons for rdata with
             relative names will be disallowed.
         """
+        # the next two lines are for type checkers, so they are bound
+        our = b""
+        their = b""
         try:
             our = self.to_digestable()
             our_relative = False
@@ -620,7 +619,7 @@ class GenericRdata(Rdata):
         relativize: bool = True,
         **kw: Dict[str, Any],
     ) -> str:
-        return r"\# %d " % len(self.data) + _hexify(self.data, **kw)
+        return r"\# %d " % len(self.data) + _hexify(self.data, **kw)  # pyright: ignore
 
     @classmethod
     def from_text(
@@ -639,9 +638,7 @@ class GenericRdata(Rdata):
     def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
         file.write(self.data)
 
-    def to_generic(
-        self, origin: Optional[dns.name.Name] = None
-    ) -> "dns.rdata.GenericRdata":
+    def to_generic(self, origin: Optional[dns.name.Name] = None) -> "GenericRdata":
         return self
 
     @classmethod
@@ -659,7 +656,7 @@ _dynamic_load_allowed = True
 def get_rdata_class(rdclass, rdtype, use_generic=True):
     cls = _rdata_classes.get((rdclass, rdtype))
     if not cls:
-        cls = _rdata_classes.get((dns.rdatatype.ANY, rdtype))
+        cls = _rdata_classes.get((dns.rdataclass.ANY, rdtype))
         if not cls and _dynamic_load_allowed:
             rdclass_text = dns.rdataclass.to_text(rdclass)
             rdtype_text = dns.rdatatype.to_text(rdtype)
@@ -758,6 +755,7 @@ def from_text(
     rdclass = dns.rdataclass.RdataClass.make(rdclass)
     rdtype = dns.rdatatype.RdataType.make(rdtype)
     cls = get_rdata_class(rdclass, rdtype)
+    assert cls is not None  # for type checkers
     with dns.exception.ExceptionWrapper(dns.exception.SyntaxError):
         rdata = None
         if cls != GenericRdata:
@@ -830,6 +828,7 @@ def from_wire_parser(
     rdclass = dns.rdataclass.RdataClass.make(rdclass)
     rdtype = dns.rdatatype.RdataType.make(rdtype)
     cls = get_rdata_class(rdclass, rdtype)
+    assert cls is not None  # for type checkers
     with dns.exception.ExceptionWrapper(dns.exception.FormError):
         return cls.from_wire_parser(rdclass, rdtype, parser, origin)
 
index 39cab2365aee8757615cfe0113c53857cf29d184..4b4bd757603474d0ee743574b3e6d62ae6b68504 100644 (file)
@@ -75,7 +75,7 @@ class Rdataset(dns.set.Set):
         self.ttl = ttl
 
     def _clone(self):
-        obj = super()._clone()
+        obj = cast(Rdataset, super()._clone())
         obj.rdclass = self.rdclass
         obj.rdtype = self.rdtype
         obj.covers = self.covers
@@ -97,7 +97,8 @@ class Rdataset(dns.set.Set):
         elif ttl < self.ttl:
             self.ttl = ttl
 
-    def add(  # pylint: disable=arguments-differ,arguments-renamed
+    # pylint: disable=arguments-differ,arguments-renamed
+    def add(  # pyright: ignore
         self, rd: dns.rdata.Rdata, ttl: Optional[int] = None
     ) -> None:
         """Add the specified rdata to the rdataset.
@@ -355,7 +356,7 @@ class Rdataset(dns.set.Set):
         if len(self) == 0:
             return []
         else:
-            return self[0]._processing_order(iter(self))
+            return self[0]._processing_order(iter(self))  # pyright: ignore
 
 
 @dns.immutable.immutable
@@ -410,22 +411,22 @@ class ImmutableRdataset(Rdataset):  # lgtm[py/missing-equals]
         raise TypeError("immutable")
 
     def __copy__(self):
-        return ImmutableRdataset(super().copy())
+        return ImmutableRdataset(super().copy())  # pyright: ignore
 
     def copy(self):
-        return ImmutableRdataset(super().copy())
+        return ImmutableRdataset(super().copy())  # pyright: ignore
 
     def union(self, other):
-        return ImmutableRdataset(super().union(other))
+        return ImmutableRdataset(super().union(other))  # pyright: ignore
 
     def intersection(self, other):
-        return ImmutableRdataset(super().intersection(other))
+        return ImmutableRdataset(super().intersection(other))  # pyright: ignore
 
     def difference(self, other):
-        return ImmutableRdataset(super().difference(other))
+        return ImmutableRdataset(super().difference(other))  # pyright: ignore
 
     def symmetric_difference(self, other):
-        return ImmutableRdataset(super().symmetric_difference(other))
+        return ImmutableRdataset(super().symmetric_difference(other))  # pyright: ignore
 
 
 def from_text_list(
index db300f8b15ad074f63062341e37fadb506953ba5..381fe770dd5c260088fd4232c743c554665a0216 100644 (file)
@@ -52,7 +52,7 @@ class DNSKEYBase(dns.rdata.Rdata):
             self.flags,
             self.protocol,
             self.algorithm,
-            dns.rdata._base64ify(self.key, **kw),
+            dns.rdata._base64ify(self.key, **kw),  # pyright: ignore
         )
 
     @classmethod
index cd21f026dc41b3f76f262058f3861856bf4499b3..a9269d22cf6356efe20e4946eda2345c74c433ed 100644 (file)
@@ -59,7 +59,9 @@ class DSBase(dns.rdata.Rdata):
             self.key_tag,
             self.algorithm,
             self.digest_type,
-            dns.rdata._hexify(self.digest, chunksize=chunksize, **kw),
+            dns.rdata._hexify(
+                self.digest, chunksize=chunksize, **kw  # pyright: ignore
+            ),
         )
 
     @classmethod
index a39c166b98fe2973fc64835c1209c12417535079..4eb82eb5e842d3b00dd880a10fbe16f40ca47c72 100644 (file)
@@ -16,6 +16,7 @@
 
 import binascii
 
+import dns.exception
 import dns.immutable
 import dns.rdata
 
@@ -27,7 +28,9 @@ class EUIBase(dns.rdata.Rdata):
     # see: rfc7043.txt
 
     __slots__ = ["eui"]
-    # define these in subclasses
+    # redefine these in subclasses
+    byte_len = 0
+    text_len = 0
     # byte_len = 6  # 0123456789ab (in hex)
     # text_len = byte_len * 3 - 1  # 01-23-45-67-89-ab
 
index a2b15b922abed495f739a6f9bafbba88f6d8bdab..bcde5cbbb94849cf155dd46cdf75aa966f907146 100644 (file)
@@ -3,6 +3,7 @@
 import base64
 import enum
 import struct
+from typing import Any, Dict
 
 import dns.enum
 import dns.exception
@@ -97,9 +98,9 @@ def _escapify(qstring):
     return text
 
 
-def _unescape(value):
+def _unescape(value: str) -> bytes:
     if value == "":
-        return value
+        return b""
     unescaped = b""
     l = len(value)
     i = 0
@@ -159,7 +160,7 @@ class Param:
     """Abstract base class for SVCB parameters"""
 
     @classmethod
-    def emptiness(cls):
+    def emptiness(cls) -> Emptiness:
         return Emptiness.NEVER
 
 
@@ -427,7 +428,7 @@ class OHTTPParam(Param):
         raise NotImplementedError  # pragma: no cover
 
 
-_class_for_key = {
+_class_for_key: Dict[ParamKey, Any] = {
     ParamKey.MANDATORY: MandatoryParam,
     ParamKey.ALPN: ALPNParam,
     ParamKey.NO_DEFAULT_ALPN: NoDefaultALPNParam,
@@ -571,10 +572,11 @@ class SVCBBase(dns.rdata.Rdata):
                 raise dns.exception.FormError("keys not in order")
             prior_key = key
             vlen = parser.get_uint16()
-            pcls = _class_for_key.get(key, GenericParam)
+            pkey = ParamKey.make(key)
+            pcls = _class_for_key.get(pkey, GenericParam)
             with parser.restrict_to(vlen):
                 value = pcls.from_wire_parser(parser, origin)
-            params[key] = value
+            params[pkey] = value
         return cls(rdclass, rdtype, priority, target, params)
 
     def _processing_priority(self):
index a059d2c4a40dfd96486bf24047931bcddcd1713d..44d8cc24acbcbad449eca0f6b31708a5c1fecad2 100644 (file)
@@ -45,7 +45,7 @@ class TLSABase(dns.rdata.Rdata):
             self.usage,
             self.selector,
             self.mtype,
-            dns.rdata._hexify(self.cert, chunksize=chunksize, **kw),
+            dns.rdata._hexify(self.cert, chunksize=chunksize, **kw),  # pyright: ignore
         )
 
     @classmethod
index 73db6d9e25dcadab2dacb57001b490ac91d40dba..6ecdd35fc4bb0c4e228d28834194652389aea120 100644 (file)
@@ -21,7 +21,10 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import dns.exception
 import dns.immutable
+import dns.name
 import dns.rdata
+import dns.rdataclass
+import dns.rdatatype
 import dns.renderer
 import dns.tokenizer
 
index 653a0bf2e7eb03829905927a1c21d584c250793a..ee6e8acadaf4ccbaa32ad1c6950ca54bff6ed94c 100644 (file)
 import collections
 import random
 import struct
-from typing import Any, List
+from typing import Any, Iterable, List, Optional, Tuple, Union
 
 import dns.exception
 import dns.ipv4
 import dns.ipv6
 import dns.name
 import dns.rdata
+import dns.rdatatype
+import dns.tokenizer
+import dns.wire
 
 
 class Gateway:
@@ -32,7 +35,7 @@ class Gateway:
 
     name = ""
 
-    def __init__(self, type, gateway=None):
+    def __init__(self, type: Any, gateway: Optional[Union[str, dns.name.Name]] = None):
         self.type = dns.rdata.Rdata._as_uint8(type)
         self.gateway = gateway
         self._check()
@@ -48,9 +51,11 @@ class Gateway:
             self.gateway = None
         elif self.type == 1:
             # check that it's OK
+            assert isinstance(self.gateway, str)
             dns.ipv4.inet_aton(self.gateway)
         elif self.type == 2:
             # check that it's OK
+            assert isinstance(self.gateway, str)
             dns.ipv6.inet_aton(self.gateway)
         elif self.type == 3:
             if not isinstance(self.gateway, dns.name.Name):
@@ -64,6 +69,7 @@ class Gateway:
         elif self.type in (1, 2):
             return self.gateway
         elif self.type == 3:
+            assert isinstance(self.gateway, dns.name.Name)
             return str(self.gateway.choose_relativity(origin, relativize))
         else:
             raise ValueError(self._invalid_type(self.type))  # pragma: no cover
@@ -87,10 +93,13 @@ class Gateway:
         if self.type == 0:
             pass
         elif self.type == 1:
+            assert isinstance(self.gateway, str)
             file.write(dns.ipv4.inet_aton(self.gateway))
         elif self.type == 2:
+            assert isinstance(self.gateway, str)
             file.write(dns.ipv6.inet_aton(self.gateway))
         elif self.type == 3:
+            assert isinstance(self.gateway, dns.name.Name)
             self.gateway.to_wire(file, None, origin, False)
         else:
             raise ValueError(self._invalid_type(self.type))  # pragma: no cover
@@ -117,8 +126,10 @@ class Bitmap:
 
     type_name = ""
 
-    def __init__(self, windows=None):
+    def __init__(self, windows: Optional[Iterable[Tuple[int, bytes]]] = None):
         last_window = -1
+        if windows is None:
+            windows = []
         self.windows = windows
         for window, bitmap in self.windows:
             if not isinstance(window, int):
@@ -140,7 +151,7 @@ class Bitmap:
             for i, byte in enumerate(bitmap):
                 for j in range(0, 8):
                     if byte & (0x80 >> j):
-                        rdtype = window * 256 + i * 8 + j
+                        rdtype = dns.rdatatype.RdataType.make(window * 256 + i * 8 + j)
                         bits.append(dns.rdatatype.to_text(rdtype))
             text += " " + " ".join(bits)
         return text
@@ -236,9 +247,10 @@ def weighted_processing_order(iterable):
                 if weight > r:
                     break
                 r -= weight
-            total -= weight
-            ordered.append(rdata)  # pylint: disable=undefined-loop-variable
-            del rdatas[n]  # pylint: disable=undefined-loop-variable
+            total -= weight  # pyright: ignore[reportPossiblyUnboundVariable]
+            # pylint: disable=undefined-loop-variable
+            ordered.append(rdata)  # pyright: ignore[reportPossiblyUnboundVariable]
+            del rdatas[n]  # pyright: ignore[reportPossiblyUnboundVariable]
         ordered.append(rdatas[0])
     return ordered
 
index a77481f67c8cdab5205ce8453550eb6f02ec566c..cc912b29d9531579508e818b0c3978bc37c89754 100644 (file)
@@ -23,9 +23,14 @@ import random
 import struct
 import time
 
+import dns.edns
 import dns.exception
+import dns.rdataclass
+import dns.rdatatype
 import dns.tsig
 
+# Note we can't import dns.message for cicularity reasons
+
 QUESTION = 0
 ANSWER = 1
 AUTHORITY = 2
@@ -214,7 +219,9 @@ class Renderer:
                 pad = b""
             options = list(opt_rdata.options)
             options.append(dns.edns.GenericOption(dns.edns.OptionType.PADDING, pad))
-            opt = dns.message.Message._make_opt(ttl, opt_rdata.rdclass, options)
+            opt = dns.message.Message._make_opt(  # pyright: ignore
+                ttl, opt_rdata.rdclass, options
+            )
             self.was_padded = True
         self.add_rrset(ADDITIONAL, opt)
 
@@ -224,7 +231,9 @@ class Renderer:
         # make sure the EDNS version in ednsflags agrees with edns
         ednsflags &= 0xFF00FFFF
         ednsflags |= edns << 16
-        opt = dns.message.Message._make_opt(ednsflags, payload, options)
+        opt = dns.message.Message._make_opt(  # pyright: ignore
+            ednsflags, payload, options
+        )
         self.add_opt(opt)
 
     def add_tsig(
@@ -246,7 +255,7 @@ class Renderer:
             key = secret
         else:
             key = dns.tsig.Key(keyname, secret, algorithm)
-        tsig = dns.message.Message._make_tsig(
+        tsig = dns.message.Message._make_tsig(  # pyright: ignore
             keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data
         )
         (tsig, _) = dns.tsig.sign(s, key, tsig[0], int(time.time()), request_mac)
@@ -278,7 +287,7 @@ class Renderer:
             key = secret
         else:
             key = dns.tsig.Key(keyname, secret, algorithm)
-        tsig = dns.message.Message._make_tsig(
+        tsig = dns.message.Message._make_tsig(  # pyright: ignore
             keyname, algorithm, 0, fudge, b"", id, tsig_error, other_data
         )
         (tsig, ctx) = dns.tsig.sign(
index af90dd8f1fcd3c9d77cb217d63310656bffd1e4f..1e23c0c86bd18edc24f95a808f19f56767c1338f 100644 (file)
@@ -24,7 +24,7 @@ import sys
 import threading
 import time
 import warnings
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast
 from urllib.parse import urlparse
 
 import dns._ddr
@@ -42,6 +42,7 @@ import dns.rcode
 import dns.rdata
 import dns.rdataclass
 import dns.rdatatype
+import dns.rdtypes.ANY.PTR
 import dns.rdtypes.svcbbase
 import dns.reversename
 import dns.tsig
@@ -63,7 +64,7 @@ class NXDOMAIN(dns.exception.DNSException):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def _check_kwargs(self, qnames, responses=None):
+    def _check_kwargs(self, qnames, responses=None):  # pyright: ignore
         if not isinstance(qnames, (list, tuple, set)):
             raise AttributeError("qnames must be a list, tuple or set")
         if len(qnames) == 0:
@@ -282,24 +283,25 @@ class Answer:
         self.expiration = time.time() + self.chaining_result.minimum_ttl
 
     def __getattr__(self, attr):  # pragma: no cover
-        if attr == "name":
-            return self.rrset.name
-        elif attr == "ttl":
-            return self.rrset.ttl
-        elif attr == "covers":
-            return self.rrset.covers
-        elif attr == "rdclass":
-            return self.rrset.rdclass
-        elif attr == "rdtype":
-            return self.rrset.rdtype
+        if self.rrset is not None:
+            if attr == "name":
+                return self.rrset.name
+            elif attr == "ttl":
+                return self.rrset.ttl
+            elif attr == "covers":
+                return self.rrset.covers
+            elif attr == "rdclass":
+                return self.rrset.rdclass
+            elif attr == "rdtype":
+                return self.rrset.rdtype
         else:
             raise AttributeError(attr)
 
     def __len__(self) -> int:
-        return self.rrset and len(self.rrset) or 0
+        return self.rrset is not None and len(self.rrset) or 0
 
     def __iter__(self) -> Iterator[Any]:
-        return self.rrset and iter(self.rrset) or iter(tuple())
+        return self.rrset is not None and iter(self.rrset) or iter(tuple())
 
     def __getitem__(self, i):
         if self.rrset is None:
@@ -1480,7 +1482,7 @@ class Resolver(BaseResolver):
         try:
             answer = self.resolve(name, raise_on_no_answer=False)
             canonical_name = answer.canonical_name
-        except dns.resolver.NXDOMAIN as e:
+        except NXDOMAIN as e:
             canonical_name = e.canonical_name
         return canonical_name
 
@@ -1655,7 +1657,7 @@ def zone_for_name(
     tcp: bool = False,
     resolver: Optional[Resolver] = None,
     lifetime: Optional[float] = None,
-) -> dns.name.Name:
+) -> dns.name.Name:  # pyright: ignore[reportReturnType]
     """Find the name of the zone which contains the specified name.
 
     *name*, an absolute ``dns.name.Name`` or ``str``, the query name.
@@ -1709,8 +1711,8 @@ def zone_for_name(
             if answer.rrset.name == name:
                 return name
             # otherwise we were CNAMEd or DNAMEd and need to look higher
-        except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer) as e:
-            if isinstance(e, dns.resolver.NXDOMAIN):
+        except (NXDOMAIN, NoAnswer) as e:
+            if isinstance(e, NXDOMAIN):
                 response = e.responses().get(name)
             else:
                 response = e.response()  # pylint: disable=no-value-for-parameter
@@ -1765,7 +1767,7 @@ def make_resolver_at(
     else:
         for address in resolver.resolve_name(where, family).addresses():
             nameservers.append(dns.nameserver.Do53Nameserver(address, port))
-    res = dns.resolver.Resolver(configure=False)
+    res = Resolver(configure=False)
     res.nameservers = nameservers
     return res
 
@@ -1816,12 +1818,12 @@ def resolve_at(
 # running process.
 #
 
-_protocols_for_socktype = {
+_protocols_for_socktype: Dict[Any, List[Any]] = {
     socket.SOCK_DGRAM: [socket.SOL_UDP],
     socket.SOCK_STREAM: [socket.SOL_TCP],
 }
 
-_resolver = None
+_resolver: Optional[Resolver] = None
 _original_getaddrinfo = socket.getaddrinfo
 _original_getnameinfo = socket.getnameinfo
 _original_getfqdn = socket.getfqdn
@@ -1870,10 +1872,11 @@ def _getaddrinfo(
         pass
     # Something needs resolution!
     try:
+        assert _resolver is not None
         answers = _resolver.resolve_name(host, family)
         addrs = answers.addresses_and_families()
         canonical_name = answers.canonical_name().to_text(True)
-    except dns.resolver.NXDOMAIN:
+    except NXDOMAIN:
         raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
     except Exception:
         # We raise EAI_AGAIN here as the failure may be temporary
@@ -1890,7 +1893,7 @@ def _getaddrinfo(
     except Exception:
         if flags & socket.AI_NUMERICSERV == 0:
             try:
-                port = socket.getservbyname(service)
+                port = socket.getservbyname(service)  # pyright: ignore
             except Exception:
                 pass
     if port is None:
@@ -1906,7 +1909,8 @@ def _getaddrinfo(
         cname = ""
     for addr, af in addrs:
         for socktype in socktypes:
-            for proto in _protocols_for_socktype[socktype]:
+            for sockproto in _protocols_for_socktype[socktype]:
+                proto = int(sockproto)
                 addr_tuple = dns.inet.low_level_address_tuple((addr, port), af)
                 tuples.append((af, socktype, proto, cname, addr_tuple))
     if len(tuples) == 0:
@@ -1934,9 +1938,12 @@ def _getnameinfo(sockaddr, flags=0):
     qname = dns.reversename.from_address(addr)
     if flags & socket.NI_NUMERICHOST == 0:
         try:
+            assert _resolver is not None
             answer = _resolver.resolve(qname, "PTR")
-            hostname = answer.rrset[0].target.to_text(True)
-        except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+            assert answer.rrset is not None
+            rdata = cast(dns.rdtypes.ANY.PTR.PTR, answer.rrset[0])
+            hostname = rdata.target.to_text(True)
+        except (NXDOMAIN, NoAnswer):
             if flags & socket.NI_NAMEREQD:
                 raise socket.gaierror(socket.EAI_NONAME, "Name or service not known")
             hostname = addr
index 8236c711f16f1e3b514f182a8254cb0e0ce45a68..dc5f33e354bc72427434961ca8f335f84107540e 100644 (file)
@@ -19,6 +19,7 @@
 
 import binascii
 
+import dns.exception
 import dns.ipv4
 import dns.ipv6
 import dns.name
index 6f39b108db9f3ea4a8955a38e326511394092363..2b0effaaa7c67bcc0557742982927d8ecd6b6b64 100644 (file)
 from typing import Any, Collection, Dict, Optional, Union, cast
 
 import dns.name
+import dns.rdata
 import dns.rdataclass
 import dns.rdataset
+import dns.rdatatype
 import dns.renderer
 
 
@@ -52,7 +54,7 @@ class RRset(dns.rdataset.Rdataset):
         self.deleting = deleting
 
     def _clone(self):
-        obj = super()._clone()
+        obj = cast(RRset, super()._clone())
         obj.name = self.name
         obj.deleting = self.deleting
         return obj
index aa2e1160336b6450525f1748bdc05c05e08acca4..bcdda9e0d906ab78aa64045681635191e870ca2f 100644 (file)
@@ -6,6 +6,7 @@ from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
 import dns.exception
 import dns.name
 import dns.node
+import dns.rdata
 import dns.rdataclass
 import dns.rdataset
 import dns.rdatatype
@@ -416,12 +417,12 @@ class Transaction:
                 raise TypeError(f"{method}: expected more arguments")
 
     def _add(self, replace, args):
+        if replace:
+            method = "replace()"
+        else:
+            method = "add()"
         try:
             args = collections.deque(args)
-            if replace:
-                method = "replace()"
-            else:
-                method = "add()"
             arg = args.popleft()
             if isinstance(arg, str):
                 arg = dns.name.from_text(arg, None)
@@ -438,6 +439,7 @@ class Transaction:
                 raise TypeError(
                     f"{method} requires a name or RRset as the first argument"
                 )
+            assert rdataset is not None  # for type checkers
             if rdataset.rdclass != self.manager.get_class():
                 raise ValueError(f"{method} has objects of wrong RdataClass")
             if rdataset.rdtype == dns.rdatatype.SOA:
@@ -460,12 +462,12 @@ class Transaction:
             raise TypeError(f"not enough parameters to {method}")
 
     def _delete(self, exact, args):
+        if exact:
+            method = "delete_exact()"
+        else:
+            method = "delete()"
         try:
             args = collections.deque(args)
-            if exact:
-                method = "delete_exact()"
-            else:
-                method = "delete()"
             arg = args.popleft()
             if isinstance(arg, str):
                 arg = dns.name.from_text(arg, None)
index 780852e8e35028f57e2e1e9cd2ebff8f877e0c22..18640a8171ee3edfa48f820f0249d6cbf2eb8de0 100644 (file)
@@ -21,11 +21,13 @@ import base64
 import hashlib
 import hmac
 import struct
+from typing import Union
 
 import dns.exception
 import dns.name
 import dns.rcode
 import dns.rdataclass
+import dns.rdatatype
 
 
 class BadTime(dns.exception.DNSException):
@@ -221,6 +223,7 @@ def _digest(wire, key, rdata, time=None, request_mac=None, ctx=None, multi=None)
         if request_mac:
             ctx.update(struct.pack("!H", len(request_mac)))
             ctx.update(request_mac)
+    assert ctx is not None  # for type checkers
     ctx.update(struct.pack("!H", rdata.original_id))
     ctx.update(wire[2:])
     if first:
@@ -325,7 +328,12 @@ def get_context(key):
 
 
 class Key:
-    def __init__(self, name, secret, algorithm=default_algorithm):
+    def __init__(
+        self,
+        name: Union[dns.name.Name, str],
+        secret: Union[bytes, str],
+        algorithm: Union[dns.name.Name, str] = default_algorithm,
+    ):
         if isinstance(name, str):
             name = dns.name.from_text(name)
         self.name = name
index 1010a79f8f3c1856b765fa11e01cb5b6e2f6ea64..5996295a2270d7818863d1ff03a7decf00070688 100644 (file)
@@ -24,14 +24,14 @@ import dns.name
 import dns.tsig
 
 
-def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, dns.tsig.Key]:
+def from_text(textring: Dict[str, Any]) -> Dict[dns.name.Name, Any]:
     """Convert a dictionary containing (textual DNS name, base64 secret)
     pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or
     a dictionary containing (textual DNS name, (algorithm, base64 secret))
     pairs into a binary keyring which has (dns.name.Name, dns.tsig.Key) pairs.
     @rtype: dict"""
 
-    keyring = {}
+    keyring: Dict[dns.name.Name, Any] = {}
     for name, value in textring.items():
         kname = dns.name.from_text(name)
         if isinstance(value, str):
index b9a99fe3c2246ba13e9a9d27b7f0c79cf2c1afce..06c11eeff0f0abd4fed4921a12f35c817a2a2109 100644 (file)
@@ -87,6 +87,6 @@ def make(value: Union[int, str]) -> int:
     if isinstance(value, int):
         return value
     elif isinstance(value, str):
-        return dns.ttl.from_text(value)
+        return from_text(value)
     else:
         raise ValueError("cannot convert value to TTL")
index bf1157acdfe7f4262afec600fd9a30691aa0f78d..cbf2079779eb3ce57a69644de69d89a0a9a1c628 100644 (file)
@@ -19,6 +19,8 @@
 
 from typing import Any, List, Optional, Union
 
+import dns.enum
+import dns.exception
 import dns.message
 import dns.name
 import dns.opcode
@@ -26,6 +28,7 @@ import dns.rdata
 import dns.rdataclass
 import dns.rdataset
 import dns.rdatatype
+import dns.rrset
 import dns.tsig
 
 
@@ -351,7 +354,7 @@ class UpdateMessage(dns.message.Message):  # lgtm[py/missing-equals]
         # Updates are always one_rr_per_rrset
         return True
 
-    def _parse_rr_header(self, section, name, rdclass, rdtype):
+    def _parse_rr_header(self, section, name, rdclass, rdtype):  # pyright: ignore
         deleting = None
         empty = False
         if section == UpdateSection.ZONE:
index fd78e674e6edbb0dc2dcab6bbc9515b4b2103520..6479ae47e0b137bcce48b055db6049c38172dc0f 100644 (file)
@@ -4,7 +4,7 @@
 
 import collections
 import threading
-from typing import Callable, Deque, Optional, Set, Union
+from typing import Callable, Deque, Optional, Set, Union, cast
 
 import dns.exception
 import dns.immutable
@@ -105,7 +105,10 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
                     n = v.nodes.get(oname)
                     if n:
                         rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
-                        if rds and rds[0].serial == serial:
+                        if rds is None:
+                            continue
+                        soa = cast(dns.rdtypes.ANY.SOA.SOA, rds[0])
+                        if rds and soa.serial == serial:
                             version = v
                             break
                 if version is None:
@@ -186,7 +189,7 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
         # Note our definition of least_kept also ensures we do not try to
         # delete the greatest version.
         if len(self._readers) > 0:
-            least_kept = min(txn.version.id for txn in self._readers)
+            least_kept = min(txn.version.id for txn in self._readers)  # pyright: ignore
         else:
             least_kept = self._versions[-1].id
         while self._versions[0].id < least_kept and self._pruning_policy(
@@ -201,8 +204,8 @@ class Zone(dns.zone.Zone):  # lgtm[py/missing-equals]
         if max_versions is not None and max_versions < 1:
             raise ValueError("max versions must be at least 1")
         if max_versions is None:
-
-            def policy(zone, _):  # pylint: disable=unused-argument
+            # pylint: disable=unused-argument
+            def policy(zone, _):  # pyright: ignore
                 return False
 
         else:
index 520aa32ddc32ea50b090ff4b08b9450709a5fb93..f1b875934b9353f05cc2a503038190c03a72c718 100644 (file)
 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, cast
 
+import dns.edns
 import dns.exception
 import dns.message
 import dns.name
 import dns.rcode
+import dns.rdata
 import dns.rdataset
 import dns.rdatatype
+import dns.rdtypes
+import dns.rdtypes.ANY
+import dns.rdtypes.ANY.SMIMEA
+import dns.rdtypes.ANY.SOA
+import dns.rdtypes.svcbbase
 import dns.serial
 import dns.transaction
 import dns.tsig
@@ -123,14 +130,16 @@ class Inbound:
             if rdataset.rdtype != dns.rdatatype.SOA:
                 raise dns.exception.FormError("first RRset is not an SOA")
             answer_index = 1
-            self.soa_rdataset = rdataset.copy()
+            self.soa_rdataset = rdataset.copy()  # pyright: ignore
             if self.rdtype == dns.rdatatype.IXFR:
-                if self.soa_rdataset[0].serial == self.serial:
+                assert self.soa_rdataset is not None
+                soa = cast(dns.rdtypes.ANY.SOA.SOA, self.soa_rdataset[0])
+                if soa.serial == self.serial:
                     #
                     # We're already up-to-date.
                     #
                     self.done = True
-                elif dns.serial.Serial(self.soa_rdataset[0].serial) < self.serial:
+                elif dns.serial.Serial(soa.serial) < self.serial:
                     # It went backwards!
                     raise SerialWentBackwards
                 else:
@@ -174,13 +183,11 @@ class Inbound:
                     #
                     # This is the final SOA
                     #
+                    soa = cast(dns.rdtypes.ANY.SOA.SOA, rdataset[0])
                     if self.expecting_SOA:
                         # We got an empty IXFR sequence!
                         raise dns.exception.FormError("empty IXFR sequence")
-                    if (
-                        self.rdtype == dns.rdatatype.IXFR
-                        and self.serial != rdataset[0].serial
-                    ):
+                    if self.rdtype == dns.rdatatype.IXFR and self.serial != soa.serial:
                         raise dns.exception.FormError("unexpected end of IXFR sequence")
                     self.txn.replace(name, rdataset)
                     self.txn.commit()
@@ -191,16 +198,17 @@ class Inbound:
                     # This is not the final SOA
                     #
                     self.expecting_SOA = False
+                    soa = cast(dns.rdtypes.ANY.SOA.SOA, rdataset[0])
                     if self.rdtype == dns.rdatatype.IXFR:
                         if self.delete_mode:
                             # This is the start of an IXFR deletion set
-                            if rdataset[0].serial != self.serial:
+                            if soa.serial != self.serial:
                                 raise dns.exception.FormError(
                                     "IXFR base serial mismatch"
                                 )
                         else:
                             # This is the start of an IXFR addition set
-                            self.serial = rdataset[0].serial
+                            self.serial = soa.serial
                             self.txn.replace(name, rdataset)
                     else:
                         # We saw a non-final SOA for the origin in an AXFR.
@@ -289,7 +297,8 @@ def make_query(
         with txn_manager.reader() as txn:
             rdataset = txn.get(origin, "SOA")
             if rdataset:
-                serial = rdataset[0].serial
+                soa = cast(dns.rdtypes.ANY.SOA.SOA, rdataset[0])
+                serial = soa.serial
                 rdtype = dns.rdatatype.IXFR
             else:
                 serial = None
@@ -337,7 +346,8 @@ def extract_serial_from_query(query: dns.message.Message) -> Optional[int]:
         return None
     elif question.rdtype != dns.rdatatype.IXFR:
         raise ValueError("query is not an AXFR or IXFR")
-    soa = query.find_rrset(
+    soa_rrset = query.find_rrset(
         query.authority, question.name, question.rdclass, dns.rdatatype.SOA
     )
-    return soa[0].serial
+    soa = cast(dns.rdtypes.ANY.SOA.SOA, soa_rrset[0])
+    return soa.serial
index 844919e41f1162d44c276a329223df0e05e4dabc..7cba657d48293576fd75e9ad73cc30b989f16ef1 100644 (file)
@@ -32,6 +32,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    cast,
 )
 
 import dns.exception
@@ -698,9 +699,9 @@ class Zone(dns.transaction.TransactionManager):
             for n in names:
                 l = self[n].to_text(
                     n,
-                    origin=self.origin,
-                    relativize=relativize,
-                    want_comments=want_comments,
+                    origin=self.origin,  # pyright: ignore
+                    relativize=relativize,  # pyright: ignore
+                    want_comments=want_comments,  # pyright: ignore
                 )
                 l_b = l.encode(file_enc)
 
@@ -786,14 +787,16 @@ class Zone(dns.transaction.TransactionManager):
                 # an SOA if there is no origin.
                 raise NoSOA
             origin_name = self.origin
-        soa: Optional[dns.rdataset.Rdataset]
+        soa_rds: Optional[dns.rdataset.Rdataset]
         if txn:
-            soa = txn.get(origin_name, dns.rdatatype.SOA)
+            soa_rds = txn.get(origin_name, dns.rdatatype.SOA)
         else:
-            soa = self.get_rdataset(origin_name, dns.rdatatype.SOA)
-        if soa is None:
+            soa_rds = self.get_rdataset(origin_name, dns.rdatatype.SOA)
+        if soa_rds is None:
             raise NoSOA
-        return soa[0]
+        else:
+            soa = cast(dns.rdtypes.ANY.SOA.SOA, soa_rds[0])
+            return soa
 
     def _compute_digest(
         self,
@@ -892,12 +895,12 @@ class Zone(dns.transaction.TransactionManager):
     def _end_write(self, txn):
         pass
 
-    def _commit_version(self, _, version, origin):
+    def _commit_version(self, txn, version, origin):
         self.nodes = version.nodes
         if self.origin is None:
             self.origin = origin
 
-    def _get_next_version_id(self):
+    def _get_next_version_id(self) -> int:
         # Versions are ephemeral and all have id 1
         return 1
 
@@ -1106,67 +1109,83 @@ class Transaction(dns.transaction.Transaction):
 
     def _setup_version(self):
         assert self.version is None
-        factory = self.manager.writable_version_factory
+        factory = self.manager.writable_version_factory  # pyright: ignore
         if factory is None:
             factory = WritableVersion
-        self.version = factory(self.zone, self.replacement)
+        self.version = factory(self.zone, self.replacement)  # pyright: ignore
 
     def _get_rdataset(self, name, rdtype, covers):
+        assert self.version is not None
         return self.version.get_rdataset(name, rdtype, covers)
 
     def _put_rdataset(self, name, rdataset):
         assert not self.read_only
+        assert self.version is not None
         self.version.put_rdataset(name, rdataset)
 
     def _delete_name(self, name):
         assert not self.read_only
+        assert self.version is not None
         self.version.delete_node(name)
 
     def _delete_rdataset(self, name, rdtype, covers):
         assert not self.read_only
+        assert self.version is not None
         self.version.delete_rdataset(name, rdtype, covers)
 
     def _name_exists(self, name):
+        assert self.version is not None
         return self.version.get_node(name) is not None
 
     def _changed(self):
         if self.read_only:
             return False
         else:
+            assert self.version is not None
             return len(self.version.changed) > 0
 
     def _end_transaction(self, commit):
+        assert self.zone is not None
+        assert self.version is not None
         if self.read_only:
-            self.zone._end_read(self)
+            self.zone._end_read(self)  # pyright: ignore
         elif commit and len(self.version.changed) > 0:
             if self.make_immutable:
-                factory = self.manager.immutable_version_factory
+                factory = self.manager.immutable_version_factory  # pyright: ignore
                 if factory is None:
                     factory = ImmutableVersion
                 version = factory(self.version)
             else:
                 version = self.version
-            self.zone._commit_version(self, version, self.version.origin)
+            self.zone._commit_version(  # pyright: ignore
+                self, version, self.version.origin
+            )
+
         else:
             # rollback
-            self.zone._end_write(self)
+            self.zone._end_write(self)  # pyright: ignore
 
     def _set_origin(self, origin):
+        assert self.version is not None
         if self.version.origin is None:
             self.version.origin = origin
 
     def _iterate_rdatasets(self):
+        assert self.version is not None
         for name, node in self.version.items():
             for rdataset in node:
                 yield (name, rdataset)
 
     def _iterate_names(self):
+        assert self.version is not None
         return self.version.keys()
 
     def _get_node(self, name):
+        assert self.version is not None
         return self.version.get_node(name)
 
     def _origin_information(self):
+        assert self.version is not None
         (absolute, relativize, effective) = self.manager.origin_information()
         if absolute is None and self.version.origin is not None:
             # No origin has been committed yet, but we've learned one as part of
@@ -1214,7 +1233,7 @@ def _from_text(
             reader.read()
         except dns.zonefile.UnknownOrigin:
             # for backwards compatibility
-            raise dns.zone.UnknownOrigin
+            raise UnknownOrigin
     # Now that we're done reading, do some basic checking of the zone.
     if check_origin:
         zone.check_origin()
index d74510b29f008c0ab03e16a5ac2bd211c81fde5f..af4778512a849b19efe893796e933c4704392a53 100644 (file)
@@ -19,7 +19,7 @@
 
 import re
 import sys
-from typing import Any, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Iterable, List, Optional, Set, Tuple, Union, cast
 
 import dns.exception
 import dns.grange
@@ -169,6 +169,9 @@ class Reader:
                     return
                 self.tok.unget(token)
             name = self.last_name
+            if name is None:
+                raise dns.exception.SyntaxError("the last used name is undefined")
+            assert self.zone_origin is not None
             if not name.is_subdomain(self.zone_origin):
                 self._eat_line()
                 return
@@ -257,11 +260,12 @@ class Reader:
             # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default
             # TTL from the SOA minttl if no $TTL statement is present before the
             # SOA is parsed.
-            self.default_ttl = rd.minimum
+            soa_rd = cast(dns.rdtypes.ANY.SOA.SOA, rd)
+            self.default_ttl = soa_rd.minimum
             self.default_ttl_known = True
             if ttl is None:
                 # if we didn't have a TTL on the SOA, set it!
-                ttl = rd.minimum
+                ttl = soa_rd.minimum
 
         # TTL check.  We had to wait until now to do this as the SOA RR's
         # own TTL can be inferred from its minimum.
@@ -356,6 +360,12 @@ class Reader:
                 ttl = self.default_ttl
             elif self.last_ttl_known:
                 ttl = self.last_ttl
+            else:
+                # We don't go to the extra "look at the SOA" level of effort for
+                # $GENERATE, because the user really ought to have defined a TTL
+                # somehow!
+                raise dns.exception.SyntaxError("Missing default TTL value")
+
         # Class
         try:
             rdclass = dns.rdataclass.from_text(token.value)
@@ -417,6 +427,7 @@ class Reader:
                 name, self.current_origin, self.tok.idna_codec
             )
             name = self.last_name
+            assert self.zone_origin is not None
             if not name.is_subdomain(self.zone_origin):
                 self._eat_line()
                 return
@@ -606,7 +617,7 @@ class RRsetsReaderTransaction(dns.transaction.Transaction):
                 )
                 rrset.update(rdataset)
                 rrsets.append(rrset)
-            self.manager.set_rrsets(rrsets)
+            self.manager.set_rrsets(rrsets)  # pyright: ignore
 
     def _set_origin(self, origin):
         pass
@@ -620,12 +631,15 @@ class RRsetsReaderTransaction(dns.transaction.Transaction):
 
 class RRSetsReaderManager(dns.transaction.TransactionManager):
     def __init__(
-        self, origin=dns.name.root, relativize=False, rdclass=dns.rdataclass.IN
+        self,
+        origin: Optional[dns.name.Name] = dns.name.root,
+        relativize: bool = False,
+        rdclass: dns.rdataclass.RdataClass = dns.rdataclass.IN,
     ):
         self.origin = origin
         self.relativize = relativize
         self.rdclass = rdclass
-        self.rrsets = []
+        self.rrsets: List[dns.rrset.RRset] = []
 
     def reader(self):  # pragma: no cover
         raise NotImplementedError
@@ -644,7 +658,7 @@ class RRSetsReaderManager(dns.transaction.TransactionManager):
             effective = self.origin
         return (self.origin, self.relativize, effective)
 
-    def set_rrsets(self, rrsets):
+    def set_rrsets(self, rrsets: List[dns.rrset.RRset]) -> None:
         self.rrsets = rrsets
 
 
index 2a4d045c684a52360a8b20c93150164f8cc57f99..75530baaaca826f4fa6c7c6f5f221edc78016a53 100644 (file)
@@ -116,3 +116,16 @@ ignore_missing_imports = true
 [[tool.mypy.overrides]]
 module = "wmi"
 ignore_missing_imports = true
+
+[tool.pyright]
+reportUnsupportedDunderAll = false
+exclude = [
+    "dns/_*_backend.py",
+    "dns/dnssecalgs/*.py",
+    "dns/quic/*.py",
+    "dns/rdtypes/ANY/*.py",
+    "dns/rdtypes/CH/*.py",
+    "dns/rdtypes/IN/*.py",
+    "examples/*.py",
+    "tests/*.py",
+] # (mostly) temporary!