class Zone(dns.zone.Zone):
__slots__ = ['_versions', '_versions_lock', '_write_txn',
- '_write_waiters', '_write_event', '_pruning_policy']
+ '_write_waiters', '_write_event', '_pruning_policy',
+ '_readers']
node_factory = Node
self._write_txn = None
self._write_event = None
self._write_waiters = collections.deque()
- self._commit_version_unlocked(WritableVersion(self), origin)
+ self._readers = set()
+ self._commit_version_unlocked(None, WritableVersion(self), origin)
def reader(self, id=None, serial=None): # pylint: disable=arguments-differ
if id is not None and serial is not None:
raise KeyError('serial not found')
else:
version = self._versions[-1]
- return Transaction(False, self, version)
+ txn = Transaction(False, self, version)
+ self._readers.add(txn)
+ return txn
def writer(self, replacement=False):
event = None
# pylint: enable=unused-argument
def _prune_versions_unlocked(self):
- while len(self._versions) > 1 and \
+ assert len(self._versions) > 0
+ # Don't ever prune a version greater than or equal to one that
+ # a reader has open. This pins versions in memory while the
+ # reader is open, and importantly lets the reader open a txn on
+ # a successor version (e.g. if generating an IXFR).
+ #
+ # Note our definition of least_kept also ensures we do not try to
+ # delete the greatest version.
+ if len(self._readers) > 0:
+ least_kept = min(txn.version.id for txn in self._readers)
+ else:
+ least_kept = self._versions[-1].id
+ while self._versions[0].id < least_kept and \
self._pruning_policy(self, self._versions[0]):
self._versions.popleft()
self._pruning_policy = policy
self._prune_versions_unlocked()
- def _commit_version_unlocked(self, version, origin):
+ def _end_read(self, txn):
+ with self._version_lock:
+ self._readers.remove(txn)
+ self._prune_versions_unlocked()
+
+ def _end_write_unlocked(self, txn):
+ assert self._write_txn == txn
+ self._write_txn = None
+ self._maybe_wakeup_one_waiter_unlocked()
+
+ def _end_write(self, txn):
+ with self._version_lock:
+ self._end_write_unlocked(txn)
+
+ def _commit_version_unlocked(self, txn, version, origin):
self._versions.append(version)
self._prune_versions_unlocked()
self.nodes = version.nodes
if self.origin is None:
self.origin = origin
- self._write_txn = None
- self._maybe_wakeup_one_waiter_unlocked()
+ # txn can be None in __init__ when we make the empty version.
+ if txn is not None:
+ self._end_write_unlocked(txn)
- def _commit_version(self, version, origin):
+ def _commit_version(self, txn, version, origin):
with self._version_lock:
- self._commit_version_unlocked(version, origin)
+ self._commit_version_unlocked(txn, version, origin)
def find_node(self, name, create=False):
if create:
def _end_transaction(self, commit):
if self.read_only:
- return
- if commit and len(self.version.changed) > 0:
- self.zone._commit_version(ImmutableVersion(self.version),
+ self.zone._end_read(self)
+ elif commit and len(self.version.changed) > 0:
+ self.zone._commit_version(self, ImmutableVersion(self.version),
self.version.origin)
+ else:
+ # rollback
+ self.zone._end_write(self)
def _set_origin(self, origin):
if self.version.origin is None:
assert not txn.name_exists(a99)
assert txn.name_exists(a100)
+def test_write_after_rollback(zone):
+ with pytest.raises(ExpectedException):
+ with zone.writer() as txn:
+ a99 = dns.name.from_text('a99', None)
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
+ txn.add(a99, rds)
+ raise ExpectedException
+ with zone.writer() as txn:
+ a99 = dns.name.from_text('a99', None)
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.99.99.99')
+ txn.add(a99, rds)
+ assert zone.get_rdataset('a99', 'a') == rds
+
def test_zone_get_deleted(zone):
with zone.writer() as txn:
print(zone.to_text())
# The ones that survived should be 3 and 1000
rdataset = vzone._versions[0].get_rdataset(dns.name.empty,
dns.rdatatype.SOA,
- dns.rdatatype.NONE)
+ dns.rdatatype.NONE)
assert rdataset[0].serial == 3
rdataset = vzone._versions[1].get_rdataset(dns.name.empty,
dns.rdatatype.SOA,
with pytest.raises(ValueError):
vzone.set_max_versions(0)
+# for debugging if needed
+def _dump(zone):
+ for v in zone._versions:
+ print('VERSION', v.id)
+ for (name, n) in v.nodes.items():
+ for rdataset in n:
+ print(rdataset.to_text(name))
+
+def test_vzone_open_txn_pins_versions(vzone):
+ assert len(vzone._versions) == 1
+ vzone.set_max_versions(None) # unlimited!
+ with vzone.writer() as txn:
+ txn.set_serial()
+ with vzone.writer() as txn:
+ txn.set_serial()
+ with vzone.writer() as txn:
+ txn.set_serial()
+ with vzone.reader(id=2) as txn:
+ vzone.set_max_versions(1)
+ with vzone.reader(id=3) as txn:
+ rdataset = txn.get('@', 'in', 'soa')
+ assert rdataset[0].serial == 2
+ assert len(vzone._versions) == 4
+ assert len(vzone._versions) == 1
+ rdataset = vzone.find_rdataset('@', 'soa')
+ assert vzone._versions[0].id == 5
+ assert rdataset[0].serial == 4
+
+
try:
import threading