]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add names iteration to transactions via iterate_names(). (#907)
authorBob Halley <halley@dnspython.org>
Tue, 14 Mar 2023 20:26:51 +0000 (13:26 -0700)
committerGitHub <noreply@github.com>
Tue, 14 Mar 2023 20:26:51 +0000 (13:26 -0700)
Also make rdataset iteration more obvious by adding an
explicit iterate_rdatasets() API.

dns/transaction.py
dns/zone.py
dns/zonefile.py
tests/test_transaction.py

index c4a9e1f6289324fa68544849ce0fe9145d5b7237..91ed7329c03e56a39e4e5a76955497101e059267 100644 (file)
@@ -1,6 +1,6 @@
 # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
 
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
 
 import collections
 
@@ -357,6 +357,27 @@ class Transaction:
         """
         self._check_delete_name.append(check)
 
+    def iterate_rdatasets(
+        self,
+    ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
+        """Iterate all the rdatasets in the transaction, returning
+        (`dns.name.Name`, `dns.rdataset.Rdataset`) tuples.
+
+        Note that as is usual with python iterators, adding or removing items
+        while iterating will invalidate the iterator and may raise `RuntimeError`
+        or fail to iterate over all entries."""
+        self._check_ended()
+        return self._iterate_rdatasets()
+
+    def iterate_names(self) -> Iterator[dns.name.Name]:
+        """Iterate all the names in the transaction.
+
+        Note that as is usual with python iterators, adding or removing names
+        while iterating will invalidate the iterator and may raise `RuntimeError`
+        or fail to iterate over all entries."""
+        self._check_ended()
+        return self._iterate_names()
+
     #
     # Helper methods
     #
@@ -610,6 +631,10 @@ class Transaction:
         """Return an iterator that yields (name, rdataset) tuples."""
         raise NotImplementedError  # pragma: no cover
 
+    def _iterate_names(self):
+        """Return an iterator that yields a name."""
+        raise NotImplementedError  # pragma: no cover
+
     def _get_node(self, name):
         """Return the node at *name*, if any.
 
index cc8268da346fe084daa774e6d38c605a7c4ef4aa..35724d7783431405a431ca0616ba4616196993dc 100644 (file)
@@ -565,7 +565,7 @@ class Zone(dns.transaction.TransactionManager):
 
         rdtype = dns.rdatatype.RdataType.make(rdtype)
         covers = dns.rdatatype.RdataType.make(covers)
-        for (name, node) in self.items():
+        for name, node in self.items():
             for rds in node:
                 if rdtype == dns.rdatatype.ANY or (
                     rds.rdtype == rdtype and rds.covers == covers
@@ -597,7 +597,7 @@ class Zone(dns.transaction.TransactionManager):
 
         rdtype = dns.rdatatype.RdataType.make(rdtype)
         covers = dns.rdatatype.RdataType.make(covers)
-        for (name, node) in self.items():
+        for name, node in self.items():
             for rds in node:
                 if rdtype == dns.rdatatype.ANY or (
                     rds.rdtype == rdtype and rds.covers == covers
@@ -795,7 +795,7 @@ class Zone(dns.transaction.TransactionManager):
             assert self.origin is not None
             origin_name = self.origin
         hasher = hashinfo()
-        for (name, node) in sorted(self.items()):
+        for name, node in sorted(self.items()):
             rrnamebuf = name.to_digestable(self.origin)
             for rdataset in sorted(node, key=lambda rds: (rds.rdtype, rds.covers)):
                 if name == origin_name and dns.rdatatype.ZONEMD in (
@@ -997,6 +997,9 @@ class Version:
             return None
         return node.get_rdataset(self.zone.rdclass, rdtype, covers)
 
+    def keys(self):
+        return self.nodes.keys()
+
     def items(self):
         return self.nodes.items()
 
@@ -1143,10 +1146,13 @@ class Transaction(dns.transaction.Transaction):
             self.version.origin = origin
 
     def _iterate_rdatasets(self):
-        for (name, node) in self.version.items():
+        for name, node in self.version.items():
             for rdataset in node:
                 yield (name, rdataset)
 
+    def _iterate_names(self):
+        return self.version.keys()
+
     def _get_node(self, name):
         return self.version.get_node(name)
 
index 1a53f5bcfca59906d7eb96b6dc1ff2cd50b4f596..fad78c3e6a1dc88add01e9588fdb45c62de871a6 100644 (file)
@@ -581,7 +581,7 @@ class RRsetsReaderTransaction(dns.transaction.Transaction):
             pass
 
     def _name_exists(self, name):
-        for (n, _, _) in self.rdatasets:
+        for n, _, _ in self.rdatasets:
             if n == name:
                 return True
         return False
@@ -606,6 +606,9 @@ class RRsetsReaderTransaction(dns.transaction.Transaction):
     def _iterate_rdatasets(self):
         raise NotImplementedError  # pragma: no cover
 
+    def _iterate_names(self):
+        raise NotImplementedError  # pragma: no cover
+
 
 class RRSetsReaderManager(dns.transaction.TransactionManager):
     def __init__(
index 8e2744abbabef01988321b766c266efe503087ac..80559bd6509f4c5b83675f947e40817fa91d9f5f 100644 (file)
@@ -499,12 +499,23 @@ def test_zone_ooz_name(zone):
 
 def test_zone_iteration(zone):
     expected = {}
-    for (name, rdataset) in zone.iterate_rdatasets():
+    for name, rdataset in zone.iterate_rdatasets():
         expected[(name, rdataset.rdtype, rdataset.covers)] = rdataset
     with zone.writer() as txn:
-        actual = {}
-        for (name, rdataset) in txn:
-            actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+        actual1 = {}
+        for name, rdataset in txn:
+            actual1[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+        actual2 = {}
+        for name, rdataset in txn.iterate_rdatasets():
+            actual2[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+    assert actual1 == expected
+    assert actual2 == expected
+
+
+def test_zone_name_iteration(zone):
+    expected = list(zone.keys())
+    with zone.writer() as txn:
+        actual = list(txn.iterate_names())
     assert actual == expected
 
 
@@ -515,7 +526,7 @@ def test_iteration_in_replacement_txn(zone):
     with zone.writer(True) as txn:
         txn.replace(dns.name.empty, rds)
         actual = {}
-        for (name, rdataset) in txn:
+        for name, rdataset in txn:
             actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
     assert actual == expected
 
@@ -528,7 +539,7 @@ def test_replacement_commit(zone):
         txn.replace(dns.name.empty, rds)
     with zone.reader() as txn:
         actual = {}
-        for (name, rdataset) in txn:
+        for name, rdataset in txn:
             actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
     assert actual == expected
 
@@ -592,7 +603,7 @@ def test_vzone_multiple_versions(vzone):
 def _dump(zone):
     for v in zone._versions:
         print("VERSION", v.id)
-        for (name, n) in v.nodes.items():
+        for name, n in v.nodes.items():
             for rdataset in n:
                 print(rdataset.to_text(name))