From 9c91b29cb4505650b02cc45cb7389e232ecb9af1 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Wed, 12 Aug 2020 06:59:59 -0700 Subject: [PATCH] If we rollback a write, release the write txn and wake someone up. Don't allow pruning to prune any version >= the version of an active reader. (This isn't a bug fix as the reader was safe before, but this ensures that the reader can open a successor version if needed.) --- dns/versioned.py | 58 +++++++++++++++++++++++++++++++-------- tests/test_transaction.py | 44 ++++++++++++++++++++++++++++- 2 files changed, 89 insertions(+), 13 deletions(-) diff --git a/dns/versioned.py b/dns/versioned.py index ae921f12..e7534386 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -169,7 +169,8 @@ class ImmutableNode(Node): class Zone(dns.zone.Zone): __slots__ = ['_versions', '_versions_lock', '_write_txn', - '_write_waiters', '_write_event', '_pruning_policy'] + '_write_waiters', '_write_event', '_pruning_policy', + '_readers'] node_factory = Node @@ -200,7 +201,8 @@ class Zone(dns.zone.Zone): self._write_txn = None self._write_event = None self._write_waiters = collections.deque() - self._commit_version_unlocked(WritableVersion(self), origin) + self._readers = set() + self._commit_version_unlocked(None, WritableVersion(self), origin) def reader(self, id=None, serial=None): # pylint: disable=arguments-differ if id is not None and serial is not None: @@ -231,7 +233,9 @@ class Zone(dns.zone.Zone): raise KeyError('serial not found') else: version = self._versions[-1] - return Transaction(False, self, version) + txn = Transaction(False, self, version) + self._readers.add(txn) + return txn def writer(self, replacement=False): event = None @@ -291,7 +295,19 @@ class Zone(dns.zone.Zone): # pylint: enable=unused-argument def _prune_versions_unlocked(self): - while len(self._versions) > 1 and \ + assert len(self._versions) > 0 + # Don't ever prune a version greater than or equal to one that + # a reader has open. This pins versions in memory while the + # reader is open, and importantly lets the reader open a txn on + # a successor version (e.g. if generating an IXFR). + # + # Note our definition of least_kept also ensures we do not try to + # delete the greatest version. + if len(self._readers) > 0: + least_kept = min(txn.version.id for txn in self._readers) + else: + least_kept = self._versions[-1].id + while self._versions[0].id < least_kept and \ self._pruning_policy(self, self._versions[0]): self._versions.popleft() @@ -327,18 +343,33 @@ class Zone(dns.zone.Zone): self._pruning_policy = policy self._prune_versions_unlocked() - def _commit_version_unlocked(self, version, origin): + def _end_read(self, txn): + with self._version_lock: + self._readers.remove(txn) + self._prune_versions_unlocked() + + def _end_write_unlocked(self, txn): + assert self._write_txn == txn + self._write_txn = None + self._maybe_wakeup_one_waiter_unlocked() + + def _end_write(self, txn): + with self._version_lock: + self._end_write_unlocked(txn) + + def _commit_version_unlocked(self, txn, version, origin): self._versions.append(version) self._prune_versions_unlocked() self.nodes = version.nodes if self.origin is None: self.origin = origin - self._write_txn = None - self._maybe_wakeup_one_waiter_unlocked() + # txn can be None in __init__ when we make the empty version. + if txn is not None: + self._end_write_unlocked(txn) - def _commit_version(self, version, origin): + def _commit_version(self, txn, version, origin): with self._version_lock: - self._commit_version_unlocked(version, origin) + self._commit_version_unlocked(txn, version, origin) def find_node(self, name, create=False): if create: @@ -407,10 +438,13 @@ class Transaction(dns.transaction.Transaction): def _end_transaction(self, commit): if self.read_only: - return - if commit and len(self.version.changed) > 0: - self.zone._commit_version(ImmutableVersion(self.version), + self.zone._end_read(self) + elif commit and len(self.version.changed) > 0: + self.zone._commit_version(self, ImmutableVersion(self.version), self.version.origin) + else: + # rollback + self.zone._end_write(self) def _set_origin(self, origin): if self.version.origin is None: diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 888fbd59..d782f21e 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -294,6 +294,19 @@ def test_zone_add_and_delete(zone): assert not txn.name_exists(a99) assert txn.name_exists(a100) +def test_write_after_rollback(zone): + with pytest.raises(ExpectedException): + with zone.writer() as txn: + a99 = dns.name.from_text('a99', None) + rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99') + txn.add(a99, rds) + raise ExpectedException + with zone.writer() as txn: + a99 = dns.name.from_text('a99', None) + rds = dns.rdataset.from_text('in', 'a', 300, '10.99.99.99') + txn.add(a99, rds) + assert zone.get_rdataset('a99', 'a') == rds + def test_zone_get_deleted(zone): with zone.writer() as txn: print(zone.to_text()) @@ -415,7 +428,7 @@ def test_vzone_multiple_versions(vzone): # The ones that survived should be 3 and 1000 rdataset = vzone._versions[0].get_rdataset(dns.name.empty, dns.rdatatype.SOA, - dns.rdatatype.NONE) + dns.rdatatype.NONE) assert rdataset[0].serial == 3 rdataset = vzone._versions[1].get_rdataset(dns.name.empty, dns.rdatatype.SOA, @@ -424,6 +437,35 @@ def test_vzone_multiple_versions(vzone): with pytest.raises(ValueError): vzone.set_max_versions(0) +# for debugging if needed +def _dump(zone): + for v in zone._versions: + print('VERSION', v.id) + for (name, n) in v.nodes.items(): + for rdataset in n: + print(rdataset.to_text(name)) + +def test_vzone_open_txn_pins_versions(vzone): + assert len(vzone._versions) == 1 + vzone.set_max_versions(None) # unlimited! + with vzone.writer() as txn: + txn.set_serial() + with vzone.writer() as txn: + txn.set_serial() + with vzone.writer() as txn: + txn.set_serial() + with vzone.reader(id=2) as txn: + vzone.set_max_versions(1) + with vzone.reader(id=3) as txn: + rdataset = txn.get('@', 'in', 'soa') + assert rdataset[0].serial == 2 + assert len(vzone._versions) == 4 + assert len(vzone._versions) == 1 + rdataset = vzone.find_rdataset('@', 'soa') + assert vzone._versions[0].id == 5 + assert rdataset[0].serial == 4 + + try: import threading -- 2.47.3