]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Fix replacement txn bugs in non-versioned zones [#732].
authorBob Halley <halley@dnspython.org>
Sun, 28 Nov 2021 22:26:21 +0000 (14:26 -0800)
committerBob Halley <halley@dnspython.org>
Sun, 28 Nov 2021 22:26:21 +0000 (14:26 -0800)
dns/zone.py
tests/test_transaction.py

index d15492870b886e7f16549ca94087b9e8ff678842..2f99b1b7d508525fbcfc3ded68750363e9f4b946 100644 (file)
@@ -805,7 +805,7 @@ class Transaction(dns.transaction.Transaction):
         rdataset = self.rdatasets.get((name, rdtype, covers))
         if rdataset is self._deleted_rdataset:
             return None
-        elif rdataset is None:
+        elif rdataset is None and not self.replacement:
             rdataset = self.zone.get_rdataset(name, rdtype, covers)
         return rdataset
 
@@ -863,6 +863,8 @@ class Transaction(dns.transaction.Transaction):
 
     def _end_transaction(self, commit):
         if commit and self._changed():
+            if self.replacement:
+                self.zone.nodes = {}
             for (name, rdtype, covers), rdataset in \
                 self.rdatasets.items():
                 if rdataset is self._deleted_rdataset:
@@ -877,10 +879,13 @@ class Transaction(dns.transaction.Transaction):
     def _iterate_rdatasets(self):
         # Expensive but simple!  Use a versioned zone for efficient txn
         # iteration.
-        rdatasets = {}
-        for (name, rdataset) in self.zone.iterate_rdatasets():
-            rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
-        rdatasets.update(self.rdatasets)
+        if self.replacement:
+            rdatasets = self.rdatasets
+        else:
+            rdatasets = {}
+            for (name, rdataset) in self.zone.iterate_rdatasets():
+                rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+            rdatasets.update(self.rdatasets)
         for (name, _, _), rdataset in rdatasets.items():
             yield (name, rdataset)
 
index 85aa9868bfee6be4ffc7a08633ecb0f664beceb0..ce533c51061dfc62519613d753795e0c54946778 100644 (file)
@@ -497,6 +497,35 @@ def test_zone_iteration(zone):
             actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
     assert actual == expected
 
+def test_iteration_in_replacement_txn(zone):
+    rds = dns.rdataset.from_text('in', 'a', 300, '1.2.3.4', '5.6.7.8')
+    expected = {}
+    expected[(dns.name.empty, rds.rdtype, rds.covers)] = rds
+    with zone.writer(True) as txn:
+        txn.replace(dns.name.empty, rds)
+        actual = {}
+        for (name, rdataset) in txn:
+            actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+    assert actual == expected
+
+def test_replacement_commit(zone):
+    rds = dns.rdataset.from_text('in', 'a', 300, '1.2.3.4', '5.6.7.8')
+    expected = {}
+    expected[(dns.name.empty, rds.rdtype, rds.covers)] = rds
+    with zone.writer(True) as txn:
+        txn.replace(dns.name.empty, rds)
+    with zone.reader() as txn:
+        actual = {}
+        for (name, rdataset) in txn:
+            actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+    assert actual == expected
+
+def test_replacement_get(zone):
+    with zone.writer(True) as txn:
+        rds = txn.get(dns.name.empty, 'soa')
+        assert rds is None
+
+
 @pytest.fixture
 def vzone():
     return dns.zone.from_text(example_text, zone_factory=dns.versioned.Zone)