From 9224e4a0df26f83e44c9e102bedfaa8479c67cc4 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 6 Dec 2025 10:00:05 -0800 Subject: [PATCH] initial support for "ty" type checker --- Makefile | 9 ++++++++- dns/_asyncio_backend.py | 2 +- dns/_ddr.py | 2 +- dns/_trio_backend.py | 2 +- dns/message.py | 5 +++-- dns/query.py | 2 +- dns/rdataset.py | 2 +- dns/resolver.py | 25 +++++++++++++------------ dns/zone.py | 40 ++++++++++++++++++++++++++-------------- dns/zonefile.py | 3 ++- pyproject.toml | 1 + 11 files changed, 58 insertions(+), 35 deletions(-) diff --git a/Makefile b/Makefile index 98dc35b9..f2b9a6f6 100644 --- a/Makefile +++ b/Makefile @@ -35,12 +35,19 @@ test: check: test -type: +mypy: python -m mypy --disallow-incomplete-defs dns pyright: pyright dns +ty: + ty check dns + +type: + pyright dns + ty check dns + ruff: ruff check dns diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index d7ce1467..cc683985 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -204,7 +204,7 @@ if dns._features.have("doh"): resolver = dns.asyncresolver.Resolver() super().__init__(*args, **kwargs) - self._pool._network_backend = _NetworkBackend( + self._pool._network_backend = _NetworkBackend( # type: ignore resolver, local_port, bootstrap_address, family ) diff --git a/dns/_ddr.py b/dns/_ddr.py index bf5c11eb..ad4249b1 100644 --- a/dns/_ddr.py +++ b/dns/_ddr.py @@ -40,7 +40,7 @@ class _SVCBInfo: def make_tls_context(self): ssl = dns.query.ssl ctx = ssl.create_default_context() - ctx.minimum_version = ssl.TLSVersion.TLSv1_2 + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 # type: ignore return ctx def ddr_tls_check_sync(self, lifetime): diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index bde7e8ba..e057068a 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -181,7 +181,7 @@ if dns._features.have("doh"): resolver = dns.asyncresolver.Resolver() super().__init__(*args, **kwargs) - self._pool._network_backend = _NetworkBackend( + self._pool._network_backend = _NetworkBackend( # type: ignore resolver, local_port, bootstrap_address, family ) diff --git a/dns/message.py b/dns/message.py index 82d4d226..30a84233 100644 --- a/dns/message.py +++ b/dns/message.py @@ -1430,7 +1430,7 @@ class _TextReader: def __init__( self, - text: str, + text: Any, idna_codec: dns.name.IDNACodec | None, one_rr_per_rrset: bool = False, origin: dns.name.Name | None = None, @@ -1742,7 +1742,8 @@ def from_file( else: cm = contextlib.nullcontext(f) with cm as f: - return from_text(f, idna_codec, one_rr_per_rrset) + reader = _TextReader(f, idna_codec, one_rr_per_rrset) + return reader.read() assert False # for mypy lgtm[py/unreachable-statement] diff --git a/dns/query.py b/dns/query.py index 34864d99..9baa0ad8 100644 --- a/dns/query.py +++ b/dns/query.py @@ -143,7 +143,7 @@ if _have_httpx: resolver = dns.resolver.Resolver() super().__init__(*args, **kwargs) - self._pool._network_backend = _NetworkBackend( + self._pool._network_backend = _NetworkBackend( # type: ignore resolver, local_port, bootstrap_address, family ) diff --git a/dns/rdataset.py b/dns/rdataset.py index 1673aad7..19191fcc 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -369,7 +369,7 @@ class ImmutableRdataset(Rdataset): # lgtm[py/missing-equals] def update_ttl(self, ttl): raise TypeError("immutable") - def add(self, rd, ttl=None): + def add(self, rd, ttl=None): # type: ignore raise TypeError("immutable") def union_update(self, other): diff --git a/dns/resolver.py b/dns/resolver.py index f6d239ab..2b908a45 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -1005,6 +1005,7 @@ class BaseResolver: else: cm = contextlib.nullcontext(f) with cm as f: + assert f is not None for l in f: if len(l) == 0 or l[0] == "#" or l[0] == ";": continue @@ -2047,12 +2048,12 @@ def override_system_resolver(resolver: Resolver | None = None) -> None: resolver = get_default_resolver() global _resolver _resolver = resolver - socket.getaddrinfo = _getaddrinfo - socket.getnameinfo = _getnameinfo - socket.getfqdn = _getfqdn - socket.gethostbyname = _gethostbyname - socket.gethostbyname_ex = _gethostbyname_ex - socket.gethostbyaddr = _gethostbyaddr + socket.getaddrinfo = _getaddrinfo # type: ignore + socket.getnameinfo = _getnameinfo # type: ignore + socket.getfqdn = _getfqdn # type: ignore + socket.gethostbyname = _gethostbyname # type: ignore + socket.gethostbyname_ex = _gethostbyname_ex # type: ignore + socket.gethostbyaddr = _gethostbyaddr # type: ignore def restore_system_resolver() -> None: @@ -2060,9 +2061,9 @@ def restore_system_resolver() -> None: global _resolver _resolver = None - socket.getaddrinfo = _original_getaddrinfo - socket.getnameinfo = _original_getnameinfo - socket.getfqdn = _original_getfqdn - socket.gethostbyname = _original_gethostbyname - socket.gethostbyname_ex = _original_gethostbyname_ex - socket.gethostbyaddr = _original_gethostbyaddr + socket.getaddrinfo = _original_getaddrinfo # type: ignore + socket.getnameinfo = _original_getnameinfo # type: ignore + socket.getfqdn = _original_getfqdn # type: ignore + socket.gethostbyname = _original_gethostbyname # type: ignore + socket.gethostbyname_ex = _original_gethostbyname_ex # type: ignore + socket.gethostbyaddr = _original_gethostbyaddr # type: ignore diff --git a/dns/zone.py b/dns/zone.py index 02e1f209..f68e8e85 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -159,7 +159,7 @@ class Zone(dns.transaction.TransactionManager): raise ValueError("origin parameter must be convertible to a DNS name") if not origin.is_absolute(): raise ValueError("origin parameter must be an absolute name") - self.origin = origin + self.origin: dns.name.Name | None = origin self.rdclass = rdclass self.nodes: MutableMapping[dns.name.Name, dns.node.Node] = self.map_factory() self.relativize = relativize @@ -686,8 +686,8 @@ class Zone(dns.transaction.TransactionManager): f.write(l_b) f.write(nl_b) except TypeError: # textual mode - f.write(l) - f.write(nl) + f.write(l) # type: ignore + f.write(nl) # type: ignore if sorted: names = list(self.keys()) @@ -707,8 +707,8 @@ class Zone(dns.transaction.TransactionManager): f.write(l_b) f.write(nl_b) except TypeError: # textual mode - f.write(l) - f.write(nl) + f.write(l) # type: ignore + f.write(nl) # type: ignore def to_text( self, @@ -1104,15 +1104,21 @@ class ImmutableVersion(Version): class Transaction(dns.transaction.Transaction): - def __init__(self, zone, replacement, version=None, make_immutable=False): + def __init__( + self, + zone: Zone, + replacement: bool, + version: Version | None = None, + make_immutable: bool = False, + ): read_only = version is not None super().__init__(zone, replacement, read_only) self.version = version self.make_immutable = make_immutable @property - def zone(self): - return self.manager + def zone(self) -> Zone: + return cast(Zone, self.manager) def _setup_version(self): assert self.version is None @@ -1128,17 +1134,20 @@ class Transaction(dns.transaction.Transaction): def _put_rdataset(self, name, rdataset): assert not self.read_only assert self.version is not None - self.version.put_rdataset(name, rdataset) + version = cast(WritableVersion, self.version) + version.put_rdataset(name, rdataset) def _delete_name(self, name): assert not self.read_only assert self.version is not None - self.version.delete_node(name) + version = cast(WritableVersion, self.version) + version.delete_node(name) def _delete_rdataset(self, name, rdtype, covers): assert not self.read_only assert self.version is not None - self.version.delete_rdataset(name, rdtype, covers) + version = cast(WritableVersion, self.version) + version.delete_rdataset(name, rdtype, covers) def _name_exists(self, name): assert self.version is not None @@ -1149,14 +1158,15 @@ class Transaction(dns.transaction.Transaction): return False else: assert self.version is not None - return len(self.version.changed) > 0 + version = cast(WritableVersion, self.version) + return len(version.changed) > 0 def _end_transaction(self, commit): assert self.zone is not None assert self.version is not None if self.read_only: self.zone._end_read(self) # type: ignore - elif commit and len(self.version.changed) > 0: + elif commit and len(cast(WritableVersion, self.version).changed) > 0: if self.make_immutable: factory = self.manager.immutable_version_factory # type: ignore if factory is None: @@ -1191,7 +1201,9 @@ class Transaction(dns.transaction.Transaction): assert self.version is not None return self.version.get_node(name) - def _origin_information(self): + def _origin_information( + self, + ) -> Tuple[dns.name.Name | None, bool, dns.name.Name | None]: assert self.version is not None (absolute, relativize, effective) = self.manager.origin_information() if absolute is None and self.version.origin is not None: diff --git a/dns/zonefile.py b/dns/zonefile.py index 4654446c..782a935e 100644 --- a/dns/zonefile.py +++ b/dns/zonefile.py @@ -150,9 +150,10 @@ class Reader: raise dns.exception.SyntaxError return token - def _rr_line(self): + def _rr_line(self) -> None: """Process one line from a DNS zone file.""" token: dns.tokenizer.Token + name: dns.name.Name | None # Name if self.force_name is not None: name = self.force_name diff --git a/pyproject.toml b/pyproject.toml index 34ccdcc3..2ed8c3bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dev = [ "sphinx>=8.2.0 ; python_version >= '3.11'", "sphinx-rtd-theme>=3.0.0 ; python_full_version >= '3.11'", "twine>=6.1.0", + "ty>=0.0.1a32", "wheel>=0.45.0", ] dnssec = ["cryptography>=45"] -- 2.47.3