From: Bob Halley Date: Wed, 19 Aug 2020 23:46:30 +0000 (-0700) Subject: Txns and txn managers have a single RdataClass X-Git-Tag: v2.1.0rc1~62 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=264668a768904855e8c230860ebcb127c04e5256;p=thirdparty%2Fdnspython.git Txns and txn managers have a single RdataClass --- diff --git a/dns/transaction.py b/dns/transaction.py index db32e9dc..c6c2f0f7 100644 --- a/dns/transaction.py +++ b/dns/transaction.py @@ -44,6 +44,11 @@ class TransactionManager: """ raise NotImplementedError # pragma: no cover + def get_class(self): + """The class of the transaction manager. + """ + raise NotImplementedError # pragma: no cover + class DeleteNotExact(dns.exception.DNSException): """Existing data did not match data specified by an exact delete.""" @@ -69,18 +74,17 @@ class Transaction: # This is the high level API # - def get(self, name, rdclass, rdtype, covers=dns.rdatatype.NONE): - """Return the rdataset associated with *name*, *rdclass*, *rdtype*, - and *covers*, or `None` if not found. + def get(self, name, rdtype, covers=dns.rdatatype.NONE): + """Return the rdataset associated with *name*, *rdtype*, and *covers*, + or `None` if not found. Note that the returned rdataset is immutable. """ self._check_ended() if isinstance(name, str): name = dns.name.from_text(name, None) - rdclass = dns.rdataclass.RdataClass.make(rdclass) rdtype = dns.rdatatype.RdataType.make(rdtype) - rdataset = self._get_rdataset(name, rdclass, rdtype, covers) + rdataset = self._get_rdataset(name, rdtype, covers) if rdataset is not None and \ not isinstance(rdataset, dns.rdataset.ImmutableRdataset): rdataset = dns.rdataset.ImmutableRdataset(rdataset) @@ -178,8 +182,7 @@ class Transaction: name = dns.name.from_text(name, None) return self._name_exists(name) - def update_serial(self, value=1, relative=True, name=dns.name.empty, - rdclass=dns.rdataclass.IN): + def update_serial(self, value=1, relative=True, name=dns.name.empty): """Update the serial number. *value*, an `int`, is an increment if *relative* is `True`, or the @@ -196,7 +199,7 @@ class Transaction: raise ValueError('negative update_serial() value') if isinstance(name, str): name = dns.name.from_text(name, None) - rdataset = self._get_rdataset(name, rdclass, dns.rdatatype.SOA, + rdataset = self._get_rdataset(name, dns.rdatatype.SOA, dns.rdatatype.NONE) if rdataset is None or len(rdataset) == 0: raise KeyError @@ -311,10 +314,12 @@ class Transaction: else: raise TypeError(f'{method} requires a name or RRset ' + 'as the first argument') + if rdataset.rdclass != self.manager.get_class(): + raise ValueError(f'{method} has objects of wrong RdataClass') self._raise_if_not_empty(method, args) if not replace: - existing = self._get_rdataset(name, rdataset.rdclass, - rdataset.rdtype, rdataset.covers) + existing = self._get_rdataset(name, rdataset.rdtype, + rdataset.covers) if existing is not None: if isinstance(existing, dns.rdataset.ImmutableRdataset): trds = dns.rdataset.Rdataset(existing.rdclass, @@ -341,20 +346,19 @@ class Transaction: name = arg if len(args) > 0 and (isinstance(args[0], int) or isinstance(args[0], str)): - # deleting by type and class - rdclass = dns.rdataclass.RdataClass.make(args.popleft()) + # deleting by type and (optionally) covers rdtype = dns.rdatatype.RdataType.make(args.popleft()) if len(args) > 0: covers = dns.rdatatype.RdataType.make(args.popleft()) else: covers = dns.rdatatype.NONE self._raise_if_not_empty(method, args) - existing = self._get_rdataset(name, rdclass, rdtype, covers) + existing = self._get_rdataset(name, rdtype, covers) if existing is None: if exact: raise DeleteNotExact(f'{method}: missing rdataset') else: - self._delete_rdataset(name, rdclass, rdtype, covers) + self._delete_rdataset(name, rdtype, covers) return else: rdataset = self._rdataset_from_args(method, True, args) @@ -366,8 +370,11 @@ class Transaction: 'as the first argument') self._raise_if_not_empty(method, args) if rdataset: - existing = self._get_rdataset(name, rdataset.rdclass, - rdataset.rdtype, rdataset.covers) + if rdataset.rdclass != self.manager.get_class(): + raise ValueError(f'{method} has objects of wrong ' + 'RdataClass') + existing = self._get_rdataset(name, rdataset.rdtype, + rdataset.covers) if existing is not None: if exact: intersection = existing.intersection(rdataset) @@ -375,8 +382,8 @@ class Transaction: raise DeleteNotExact(f'{method}: missing rdatas') rdataset = existing.difference(rdataset) if len(rdataset) == 0: - self._delete_rdataset(name, rdataset.rdclass, - rdataset.rdtype, rdataset.covers) + self._delete_rdataset(name, rdataset.rdtype, + rdataset.covers) else: self._put_rdataset(name, rdataset) elif exact: @@ -421,9 +428,10 @@ class Transaction: # of Transaction. # - def _get_rdataset(self, name, rdclass, rdtype, covers): - """Return the rdataset associated with *name*, *rdclass*, *rdtype*, - and *covers*, or `None` if not found.""" + def _get_rdataset(self, name, rdtype, covers): + """Return the rdataset associated with *name*, *rdtype*, and *covers*, + or `None` if not found. + """ raise NotImplementedError # pragma: no cover def _put_rdataset(self, name, rdataset): @@ -437,9 +445,8 @@ class Transaction: """ raise NotImplementedError # pragma: no cover - def _delete_rdataset(self, name, rdclass, rdtype, covers): - """Delete all data associated with *name*, *rdclass*, *rdtype*, and - *covers*. + def _delete_rdataset(self, name, rdtype, covers): + """Delete all data associated with *name*, *rdtype*, and *covers*. It is not an error if the rdataset does not exist. """ diff --git a/dns/versioned.py b/dns/versioned.py index ff3d7020..9f0caa1c 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -416,24 +416,18 @@ class Transaction(dns.transaction.Transaction): assert self.version is None self.version = WritableVersion(self.zone, self.replacement) - def _get_rdataset(self, name, rdclass, rdtype, covers): - if rdclass != self.zone.rdclass: - raise ValueError(f'class {rdclass} != ' + - f'zone class {self.zone.rdclass}') + def _get_rdataset(self, name, rdtype, covers): return self.version.get_rdataset(name, rdtype, covers) def _put_rdataset(self, name, rdataset): assert not self.read_only - if rdataset.rdclass != self.zone.rdclass: - raise ValueError(f'rdataset class {rdataset.rdclass} != ' + - f'zone class {self.zone.rdclass}') self.version.put_rdataset(name, rdataset) def _delete_name(self, name): assert not self.read_only self.version.delete_node(name) - def _delete_rdataset(self, name, rdclass, rdtype, covers): + def _delete_rdataset(self, name, rdtype, covers): assert not self.read_only self.version.delete_rdataset(name, rdtype, covers) diff --git a/dns/zone.py b/dns/zone.py index 39ab1a60..427184eb 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -651,6 +651,9 @@ class Zone(dns.transaction.TransactionManager): def origin_information(self): return (self.origin, self.relativize) + def get_class(self): + return self.rdclass + class Transaction(dns.transaction.Transaction): @@ -665,10 +668,7 @@ class Transaction(dns.transaction.Transaction): def zone(self): return self.manager - def _get_rdataset(self, name, rdclass, rdtype, covers): - if rdclass != self.zone.rdclass: - raise ValueError(f'class {rdclass} != ' + - f'zone class {self.zone.rdclass}') + def _get_rdataset(self, name, rdtype, covers): rdataset = self.rdatasets.get((name, rdtype, covers)) if rdataset is self._deleted_rdataset: return None @@ -679,9 +679,6 @@ class Transaction(dns.transaction.Transaction): def _put_rdataset(self, name, rdataset): assert not self.read_only self.zone._validate_name(name) - if rdataset.rdclass != self.zone.rdclass: - raise ValueError(f'rdataset class {rdataset.rdclass} != ' + - f'zone class {self.zone.rdclass}') self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset def _delete_name(self, name): @@ -702,11 +699,8 @@ class Transaction(dns.transaction.Transaction): self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \ self._deleted_rdataset - def _delete_rdataset(self, name, rdclass, rdtype, covers): + def _delete_rdataset(self, name, rdtype, covers): assert not self.read_only - # The high-level code always does a _get_rdataset() before any - # situation where it would call _delete_rdataset(), so we don't - # need to check if rdclass != self.zone.rdclass. try: del self.rdatasets[(name, rdtype, covers)] except KeyError: diff --git a/tests/test_transaction.py b/tests/test_transaction.py index bf7b130b..7fb353cd 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -26,22 +26,26 @@ class DB(dns.transaction.TransactionManager): def origin_information(self): return (None, True) + def get_class(self): + return dns.rdataclass.IN + class Transaction(dns.transaction.Transaction): def __init__(self, db, replacement, read_only): - super().__init__(replacement) - self.db = db + super().__init__(db, replacement, read_only) self.rdatasets = {} - self.read_only = read_only if not replacement: self.rdatasets.update(db.rdatasets) - def _get_rdataset(self, name, rdclass, rdtype, covers): - return self.rdatasets.get((name, rdclass, rdtype, covers)) + @property + def db(self): + return self.manager + + def _get_rdataset(self, name, rdtype, covers): + return self.rdatasets.get((name, rdtype, covers)) def _put_rdataset(self, name, rdataset): - self.rdatasets[(name, rdataset.rdclass, rdataset.rdtype, - rdataset.covers)] = rdataset + self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset def _delete_name(self, name): remove = [] @@ -52,8 +56,8 @@ class Transaction(dns.transaction.Transaction): for key in remove: del self.rdatasets[key] - def _delete_rdataset(self, name, rdclass, rdtype, covers): - del self.rdatasets[(name, rdclass, rdtype, covers)] + def _delete_rdataset(self, name, rdtype, covers): + del self.rdatasets[(name, rdtype, covers)] def _name_exists(self, name): for key in self.rdatasets.keys(): @@ -78,7 +82,7 @@ class Transaction(dns.transaction.Transaction): def db(): db = DB() rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content') - db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] = rrset + db.rdatasets[(rrset.name, rrset.rdtype, 0)] = rrset return db def test_basic(db): @@ -88,7 +92,7 @@ def test_basic(db): '10.0.0.1', '10.0.0.2') txn.add(rrset) assert txn.name_exists(rrset.name) - assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \ + assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \ rrset # rollback with pytest.raises(Exception): @@ -97,17 +101,17 @@ def test_basic(db): '10.0.0.3', '10.0.0.4') txn.add(rrset2) raise Exception() - assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \ + assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \ rrset with db.writer() as txn: txn.delete(rrset.name) - assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \ + assert db.rdatasets.get((rrset.name, rrset.rdtype, 0)) \ is None def test_get(db): with db.writer() as txn: content = dns.name.from_text('content', None) - rdataset = txn.get(content, dns.rdataclass.IN, dns.rdatatype.TXT) + rdataset = txn.get(content, dns.rdatatype.TXT) assert rdataset is not None assert rdataset[0].strings == (b'content',) assert isinstance(rdataset, dns.rdataset.ImmutableRdataset) @@ -123,7 +127,7 @@ def test_add(db): expected = dns.rrset.from_text('foo', 300, 'in', 'a', '10.0.0.1', '10.0.0.2', '10.0.0.3', '10.0.0.4') - assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \ + assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \ expected def test_replacement(db): @@ -134,7 +138,7 @@ def test_replacement(db): rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a', '10.0.0.3', '10.0.0.4') txn.replace(rrset2) - assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \ + assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \ rrset2 def test_delete(db): @@ -144,11 +148,11 @@ def test_delete(db): content2 = dns.name.from_text('content2', None) txn.delete(content) assert not txn.name_exists(content) - txn.delete(content2, dns.rdataclass.IN, dns.rdatatype.TXT) + txn.delete(content2, dns.rdatatype.TXT) rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'new-content') txn.add(rrset) assert txn.name_exists(content) - txn.delete(content, dns.rdataclass.IN, dns.rdatatype.TXT) + txn.delete(content, dns.rdatatype.TXT) assert not txn.name_exists(content) rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'new-content') txn.delete(rrset) @@ -166,10 +170,10 @@ def test_delete_exact(db): with pytest.raises(dns.transaction.DeleteNotExact): txn.delete_exact(rrset.name) with pytest.raises(dns.transaction.DeleteNotExact): - txn.delete_exact(rrset.name, dns.rdataclass.IN, dns.rdatatype.TXT) + txn.delete_exact(rrset.name, dns.rdatatype.TXT) rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content') txn.delete_exact(rrset) - assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \ + assert db.rdatasets.get((rrset.name, rrset.rdtype, 0)) \ is None def test_parameter_forms(db): @@ -185,13 +189,13 @@ def test_parameter_forms(db): expected = dns.rrset.from_text('foo', 30, 'in', 'a', '10.0.0.1', '10.0.0.2', '10.0.0.3', '10.0.0.4') - assert db.rdatasets[(foo, rdataset.rdclass, rdataset.rdtype, 0)] == \ + assert db.rdatasets[(foo, rdataset.rdtype, 0)] == \ expected with db.writer() as txn: txn.delete(foo, rdataset) txn.delete(foo, rdata1) txn.delete(foo, rdata2) - assert db.rdatasets.get((foo, rdataset.rdclass, rdataset.rdtype, 0)) \ + assert db.rdatasets.get((foo, rdataset.rdtype, 0)) \ is None def test_bad_parameters(db): @@ -329,7 +333,7 @@ def test_zone_changed(zone): # delete a nonexistent rdataset from an extant node with zone.writer() as txn: assert not txn.changed() - txn.delete(dns.name.from_text('bar.foo', None), 'in', 'txt') + txn.delete(dns.name.from_text('bar.foo', None), 'txt') assert not txn.changed() # add an rdataset to an extant Node with zone.writer() as txn: @@ -345,8 +349,7 @@ def test_zone_changed(zone): def test_zone_base_layer(zone): with zone.writer() as txn: # Get a set from the zone layer - rdataset = txn.get(dns.name.empty, dns.rdataclass.IN, - dns.rdatatype.NS, dns.rdatatype.NONE) + rdataset = txn.get(dns.name.empty, dns.rdatatype.NS, dns.rdatatype.NONE) expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2') assert rdataset == expected @@ -357,8 +360,7 @@ def test_zone_transaction_layer(zone): txn.add(dns.name.empty, 3600, rd) # Get a set from the transaction layer expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2', 'ns3') - rdataset = txn.get(dns.name.empty, dns.rdataclass.IN, - dns.rdatatype.NS, dns.rdatatype.NONE) + rdataset = txn.get(dns.name.empty, dns.rdatatype.NS, dns.rdatatype.NONE) assert rdataset == expected assert txn.name_exists(dns.name.empty) ns1 = dns.name.from_text('ns1', None) @@ -373,14 +375,14 @@ def test_zone_add_and_delete(zone): a101 = dns.name.from_text('a101', None) rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99') txn.add(a99, rds) - txn.delete(a99, dns.rdataclass.IN, dns.rdatatype.A) - txn.delete(a100, dns.rdataclass.IN, dns.rdatatype.A) + txn.delete(a99, dns.rdatatype.A) + txn.delete(a100, dns.rdatatype.A) txn.delete(a101) assert not txn.name_exists(a99) assert not txn.name_exists(a100) assert not txn.name_exists(a101) ns1 = dns.name.from_text('ns1', None) - txn.delete(ns1, dns.rdataclass.IN, dns.rdatatype.A) + txn.delete(ns1, dns.rdatatype.A) assert not txn.name_exists(ns1) with zone.writer() as txn: txn.add(a99, rds) @@ -409,18 +411,15 @@ def test_zone_get_deleted(zone): with zone.writer() as txn: print(zone.to_text()) ns1 = dns.name.from_text('ns1', None) - assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is not None + assert txn.get(ns1, dns.rdatatype.A) is not None txn.delete(ns1) - assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is None + assert txn.get(ns1, dns.rdatatype.A) is None ns2 = dns.name.from_text('ns2', None) - txn.delete(ns2, dns.rdataclass.IN, dns.rdatatype.A) - assert txn.get(ns2, dns.rdataclass.IN, dns.rdatatype.A) is None + txn.delete(ns2, dns.rdatatype.A) + assert txn.get(ns2, dns.rdatatype.A) is None def test_zone_bad_class(zone): with zone.writer() as txn: - with pytest.raises(ValueError): - txn.get(dns.name.empty, dns.rdataclass.CH, - dns.rdatatype.NS, dns.rdatatype.NONE) rds = dns.rdataset.from_text('ch', 'ns', 300, 'ns1', 'ns2') with pytest.raises(ValueError): txn.add(dns.name.empty, rds) @@ -428,9 +427,6 @@ def test_zone_bad_class(zone): txn.replace(dns.name.empty, rds) with pytest.raises(ValueError): txn.delete(dns.name.empty, rds) - with pytest.raises(ValueError): - txn.delete(dns.name.empty, dns.rdataclass.CH, - dns.rdatatype.NS, dns.rdatatype.NONE) def test_update_serial(zone): # basic @@ -500,8 +496,7 @@ def vzone(): def test_vzone_read_only(vzone): with vzone.reader() as txn: - rdataset = txn.get(dns.name.empty, dns.rdataclass.IN, - dns.rdatatype.NS, dns.rdatatype.NONE) + rdataset = txn.get(dns.name.empty, dns.rdatatype.NS, dns.rdatatype.NONE) expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2') assert rdataset == expected with pytest.raises(dns.transaction.ReadOnly): @@ -521,11 +516,11 @@ def test_vzone_multiple_versions(vzone): assert len(vzone._versions) == 4 with vzone.reader(id=5) as txn: assert txn.version.id == 5 - rdataset = txn.get('@', 'in', 'soa') + rdataset = txn.get('@', 'soa') assert rdataset[0].serial == 1000 with vzone.reader(serial=1000) as txn: assert txn.version.id == 5 - rdataset = txn.get('@', 'in', 'soa') + rdataset = txn.get('@', 'soa') assert rdataset[0].serial == 1000 vzone.set_max_versions(2) assert len(vzone._versions) == 2 @@ -561,7 +556,7 @@ def test_vzone_open_txn_pins_versions(vzone): with vzone.reader(id=2) as txn: vzone.set_max_versions(1) with vzone.reader(id=3) as txn: - rdataset = txn.get('@', 'in', 'soa') + rdataset = txn.get('@', 'soa') assert rdataset[0].serial == 2 assert len(vzone._versions) == 4 assert len(vzone._versions) == 1