]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
If we rollback a write, release the write txn and wake someone up.
authorBob Halley <halley@dnspython.org>
Wed, 12 Aug 2020 13:59:59 +0000 (06:59 -0700)
committerBob Halley <halley@dnspython.org>
Wed, 12 Aug 2020 13:59:59 +0000 (06:59 -0700)
Don't allow pruning to prune any version >= the version of an active reader.
(This isn't a bug fix as the reader was safe before, but this ensures that
the reader can open a successor version if needed.)

dns/versioned.py
tests/test_transaction.py

index ae921f12d68487e9be07cca0f2087a2e8e6e5cb2..e7534386ba0c665c8bfc717a3cd6757073605c3f 100644 (file)
@@ -169,7 +169,8 @@ class ImmutableNode(Node):
 class Zone(dns.zone.Zone):
 
     __slots__ = ['_versions', '_versions_lock', '_write_txn',
-                 '_write_waiters', '_write_event', '_pruning_policy']
+                 '_write_waiters', '_write_event', '_pruning_policy',
+                 '_readers']
 
     node_factory = Node
 
@@ -200,7 +201,8 @@ class Zone(dns.zone.Zone):
         self._write_txn = None
         self._write_event = None
         self._write_waiters = collections.deque()
-        self._commit_version_unlocked(WritableVersion(self), origin)
+        self._readers = set()
+        self._commit_version_unlocked(None, WritableVersion(self), origin)
 
     def reader(self, id=None, serial=None):  # pylint: disable=arguments-differ
         if id is not None and serial is not None:
@@ -231,7 +233,9 @@ class Zone(dns.zone.Zone):
                     raise KeyError('serial not found')
             else:
                 version = self._versions[-1]
-            return Transaction(False, self, version)
+            txn = Transaction(False, self, version)
+            self._readers.add(txn)
+            return txn
 
     def writer(self, replacement=False):
         event = None
@@ -291,7 +295,19 @@ class Zone(dns.zone.Zone):
     # pylint: enable=unused-argument
 
     def _prune_versions_unlocked(self):
-        while len(self._versions) > 1 and \
+        assert len(self._versions) > 0
+        # Don't ever prune a version greater than or equal to one that
+        # a reader has open.  This pins versions in memory while the
+        # reader is open, and importantly lets the reader open a txn on
+        # a successor version (e.g. if generating an IXFR).
+        #
+        # Note our definition of least_kept also ensures we do not try to
+        # delete the greatest version.
+        if len(self._readers) > 0:
+            least_kept = min(txn.version.id for txn in self._readers)
+        else:
+            least_kept = self._versions[-1].id
+        while self._versions[0].id < least_kept and \
               self._pruning_policy(self, self._versions[0]):
             self._versions.popleft()
 
@@ -327,18 +343,33 @@ class Zone(dns.zone.Zone):
             self._pruning_policy = policy
             self._prune_versions_unlocked()
 
-    def _commit_version_unlocked(self, version, origin):
+    def _end_read(self, txn):
+        with self._version_lock:
+            self._readers.remove(txn)
+            self._prune_versions_unlocked()
+
+    def _end_write_unlocked(self, txn):
+        assert self._write_txn == txn
+        self._write_txn = None
+        self._maybe_wakeup_one_waiter_unlocked()
+
+    def _end_write(self, txn):
+        with self._version_lock:
+            self._end_write_unlocked(txn)
+
+    def _commit_version_unlocked(self, txn, version, origin):
         self._versions.append(version)
         self._prune_versions_unlocked()
         self.nodes = version.nodes
         if self.origin is None:
             self.origin = origin
-        self._write_txn = None
-        self._maybe_wakeup_one_waiter_unlocked()
+        # txn can be None in __init__ when we make the empty version.
+        if txn is not None:
+            self._end_write_unlocked(txn)
 
-    def _commit_version(self, version, origin):
+    def _commit_version(self, txn, version, origin):
         with self._version_lock:
-            self._commit_version_unlocked(version, origin)
+            self._commit_version_unlocked(txn, version, origin)
 
     def find_node(self, name, create=False):
         if create:
@@ -407,10 +438,13 @@ class Transaction(dns.transaction.Transaction):
 
     def _end_transaction(self, commit):
         if self.read_only:
-            return
-        if commit and len(self.version.changed) > 0:
-            self.zone._commit_version(ImmutableVersion(self.version),
+            self.zone._end_read(self)
+        elif commit and len(self.version.changed) > 0:
+            self.zone._commit_version(self, ImmutableVersion(self.version),
                                       self.version.origin)
+        else:
+            # rollback
+            self.zone._end_write(self)
 
     def _set_origin(self, origin):
         if self.version.origin is None:
index 888fbd59c4c22a4775b9692b4ca6f0b128c334a5..d782f21e948c9dba2ca16d21adab609e2fd10a72 100644 (file)
@@ -294,6 +294,19 @@ def test_zone_add_and_delete(zone):
         assert not txn.name_exists(a99)
         assert txn.name_exists(a100)
 
+def test_write_after_rollback(zone):
+    with pytest.raises(ExpectedException):
+        with zone.writer() as txn:
+            a99 = dns.name.from_text('a99', None)
+            rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
+            txn.add(a99, rds)
+            raise ExpectedException
+    with zone.writer() as txn:
+        a99 = dns.name.from_text('a99', None)
+        rds = dns.rdataset.from_text('in', 'a', 300, '10.99.99.99')
+        txn.add(a99, rds)
+    assert zone.get_rdataset('a99', 'a') == rds
+
 def test_zone_get_deleted(zone):
     with zone.writer() as txn:
         print(zone.to_text())
@@ -415,7 +428,7 @@ def test_vzone_multiple_versions(vzone):
     # The ones that survived should be 3 and 1000
     rdataset = vzone._versions[0].get_rdataset(dns.name.empty,
                                                dns.rdatatype.SOA,
-                                              dns.rdatatype.NONE)
+                                               dns.rdatatype.NONE)
     assert rdataset[0].serial == 3
     rdataset = vzone._versions[1].get_rdataset(dns.name.empty,
                                                dns.rdatatype.SOA,
@@ -424,6 +437,35 @@ def test_vzone_multiple_versions(vzone):
     with pytest.raises(ValueError):
         vzone.set_max_versions(0)
 
+# for debugging if needed
+def _dump(zone):
+    for v in zone._versions:
+        print('VERSION', v.id)
+        for (name, n) in v.nodes.items():
+            for rdataset in n:
+                print(rdataset.to_text(name))
+
+def test_vzone_open_txn_pins_versions(vzone):
+    assert len(vzone._versions) == 1
+    vzone.set_max_versions(None)  # unlimited!
+    with vzone.writer() as txn:
+        txn.set_serial()
+    with vzone.writer() as txn:
+        txn.set_serial()
+    with vzone.writer() as txn:
+        txn.set_serial()
+    with vzone.reader(id=2) as txn:
+        vzone.set_max_versions(1)
+        with vzone.reader(id=3) as txn:
+            rdataset = txn.get('@', 'in', 'soa')
+            assert rdataset[0].serial == 2
+            assert len(vzone._versions) == 4
+    assert len(vzone._versions) == 1
+    rdataset = vzone.find_rdataset('@', 'soa')
+    assert vzone._versions[0].id == 5
+    assert rdataset[0].serial == 4
+
+
 try:
     import threading