From: Bob Halley Date: Tue, 29 Jul 2025 13:25:39 +0000 (-0700) Subject: typing fixes X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9331e3d79dded11fa49d5c402793bf0dd14f590a;p=thirdparty%2Fdnspython.git typing fixes --- diff --git a/dns/btreezone.py b/dns/btreezone.py index 53c1d9f0..cd8d4b65 100644 --- a/dns/btreezone.py +++ b/dns/btreezone.py @@ -34,8 +34,13 @@ class NodeFlags(enum.IntFlag): class Node(dns.node.Node): __slots__ = ["flags", "id"] - def __init__(self, flags: NodeFlags = 0): # type: ignore + def __init__(self, flags: Optional[NodeFlags] = None): super().__init__() + if flags is None: + # We allow optional flags rather than a default + # as pyright doesn't like assigning a literal 0 + # to flags. + flags = NodeFlags(0) self.flags = flags self.id = 0 @@ -226,8 +231,12 @@ class WritableVersion(dns.zone.WritableVersion): @dns.immutable.immutable class ImmutableVersion(dns.zone.Version): - def __init__(self, version: dns.zone.WritableVersion): - assert isinstance(version, WritableVersion) + def __init__(self, version: dns.zone.Version): + if not isinstance(version, WritableVersion): + raise ValueError( + "a dns.btreezone.ImmutableVersion requires a " + "dns.btreezone.WritableVersion" + ) super().__init__(version.zone, True) self.id = version.id self.origin = version.origin @@ -243,13 +252,14 @@ class ImmutableVersion(dns.zone.Version): class Zone(dns.versioned.Zone): - node_factory: Callable[[], dns.node.Node] = Node # type: ignore - map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = ( - dns.btree.BTreeDict[dns.name.Name, Node] # type: ignore + node_factory: Callable[[], dns.node.Node] = Node + map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = cast( + Callable[[], MutableMapping[dns.name.Name, dns.node.Node]], + dns.btree.BTreeDict[dns.name.Name, Node], ) writable_version_factory: Optional[ - Callable[[], dns.zone.WritableVersion] - ] = WritableVersion # type: ignore + Callable[[dns.zone.Zone, bool], dns.zone.Version] + ] = WritableVersion immutable_version_factory: Optional[ - Callable[[], dns.zone.ImmutableVersion] - ] = ImmutableVersion # type: ignore + Callable[[dns.zone.Version], dns.zone.Version] + ] = ImmutableVersion diff --git a/dns/versioned.py b/dns/versioned.py index 6479ae47..260eea1b 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -41,7 +41,7 @@ class Zone(dns.zone.Zone): # lgtm[py/missing-equals] "_readers", ] - node_factory = Node + node_factory: Callable[[], dns.node.Node] = Node def __init__( self, diff --git a/dns/zone.py b/dns/zone.py index b1a52f63..cfb89ec4 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -131,8 +131,10 @@ class Zone(dns.transaction.TransactionManager): node_factory: Callable[[], dns.node.Node] = dns.node.Node map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = dict - writable_version_factory: Optional[Callable[[], "WritableVersion"]] = None - immutable_version_factory: Optional[Callable[[], "ImmutableVersion"]] = None + # We only require the version types as "Version" to allow for flexibility, as + # only the version protocol matters + writable_version_factory: Optional[Callable[["Zone", bool], "Version"]] = None + immutable_version_factory: Optional[Callable[["Version"], "Version"]] = None __slots__ = ["rdclass", "origin", "nodes", "relativize"] @@ -1074,7 +1076,11 @@ class WritableVersion(Version): @dns.immutable.immutable class ImmutableVersion(Version): - def __init__(self, version: WritableVersion): + def __init__(self, version: Version): + if not isinstance(version, WritableVersion): + raise ValueError( + "a dns.zone.ImmutableVersion requires a dns.zone.WritableVersion" + ) # We tell super() that it's a replacement as we don't want it # to copy the nodes, as we're about to do that with an # immutable Dict.