rdataset = self.rdatasets.get((name, rdtype, covers))
if rdataset is self._deleted_rdataset:
return None
- elif rdataset is None:
+ elif rdataset is None and not self.replacement:
rdataset = self.zone.get_rdataset(name, rdtype, covers)
return rdataset
def _end_transaction(self, commit):
if commit and self._changed():
+ if self.replacement:
+ self.zone.nodes = {}
for (name, rdtype, covers), rdataset in \
self.rdatasets.items():
if rdataset is self._deleted_rdataset:
def _iterate_rdatasets(self):
# Expensive but simple! Use a versioned zone for efficient txn
# iteration.
- rdatasets = {}
- for (name, rdataset) in self.zone.iterate_rdatasets():
- rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
- rdatasets.update(self.rdatasets)
+ if self.replacement:
+ rdatasets = self.rdatasets
+ else:
+ rdatasets = {}
+ for (name, rdataset) in self.zone.iterate_rdatasets():
+ rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ rdatasets.update(self.rdatasets)
for (name, _, _), rdataset in rdatasets.items():
yield (name, rdataset)
actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
assert actual == expected
+def test_iteration_in_replacement_txn(zone):
+ rds = dns.rdataset.from_text('in', 'a', 300, '1.2.3.4', '5.6.7.8')
+ expected = {}
+ expected[(dns.name.empty, rds.rdtype, rds.covers)] = rds
+ with zone.writer(True) as txn:
+ txn.replace(dns.name.empty, rds)
+ actual = {}
+ for (name, rdataset) in txn:
+ actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ assert actual == expected
+
+def test_replacement_commit(zone):
+ rds = dns.rdataset.from_text('in', 'a', 300, '1.2.3.4', '5.6.7.8')
+ expected = {}
+ expected[(dns.name.empty, rds.rdtype, rds.covers)] = rds
+ with zone.writer(True) as txn:
+ txn.replace(dns.name.empty, rds)
+ with zone.reader() as txn:
+ actual = {}
+ for (name, rdataset) in txn:
+ actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ assert actual == expected
+
+def test_replacement_get(zone):
+ with zone.writer(True) as txn:
+ rds = txn.get(dns.name.empty, 'soa')
+ assert rds is None
+
+
@pytest.fixture
def vzone():
return dns.zone.from_text(example_text, zone_factory=dns.versioned.Zone)