]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Txns and txn managers have a single RdataClass
authorBob Halley <halley@dnspython.org>
Wed, 19 Aug 2020 23:46:30 +0000 (16:46 -0700)
committerBob Halley <halley@dnspython.org>
Wed, 19 Aug 2020 23:46:30 +0000 (16:46 -0700)
dns/transaction.py
dns/versioned.py
dns/zone.py
tests/test_transaction.py

index db32e9dc7ee376216afc4f096466a78981099a24..c6c2f0f7908a09c47bf930f1ea643d3db4512798 100644 (file)
@@ -44,6 +44,11 @@ class TransactionManager:
         """
         raise NotImplementedError  # pragma: no cover
 
+    def get_class(self):
+        """The class of the transaction manager.
+        """
+        raise NotImplementedError  # pragma: no cover
+
 
 class DeleteNotExact(dns.exception.DNSException):
     """Existing data did not match data specified by an exact delete."""
@@ -69,18 +74,17 @@ class Transaction:
     # This is the high level API
     #
 
-    def get(self, name, rdclass, rdtype, covers=dns.rdatatype.NONE):
-        """Return the rdataset associated with *name*, *rdclass*, *rdtype*,
-        and *covers*, or `None` if not found.
+    def get(self, name, rdtype, covers=dns.rdatatype.NONE):
+        """Return the rdataset associated with *name*, *rdtype*, and *covers*,
+        or `None` if not found.
 
         Note that the returned rdataset is immutable.
         """
         self._check_ended()
         if isinstance(name, str):
             name = dns.name.from_text(name, None)
-        rdclass = dns.rdataclass.RdataClass.make(rdclass)
         rdtype = dns.rdatatype.RdataType.make(rdtype)
-        rdataset = self._get_rdataset(name, rdclass, rdtype, covers)
+        rdataset = self._get_rdataset(name, rdtype, covers)
         if rdataset is not None and \
            not isinstance(rdataset, dns.rdataset.ImmutableRdataset):
             rdataset = dns.rdataset.ImmutableRdataset(rdataset)
@@ -178,8 +182,7 @@ class Transaction:
             name = dns.name.from_text(name, None)
         return self._name_exists(name)
 
-    def update_serial(self, value=1, relative=True, name=dns.name.empty,
-                      rdclass=dns.rdataclass.IN):
+    def update_serial(self, value=1, relative=True, name=dns.name.empty):
         """Update the serial number.
 
         *value*, an `int`, is an increment if *relative* is `True`, or the
@@ -196,7 +199,7 @@ class Transaction:
             raise ValueError('negative update_serial() value')
         if isinstance(name, str):
             name = dns.name.from_text(name, None)
-        rdataset = self._get_rdataset(name, rdclass, dns.rdatatype.SOA,
+        rdataset = self._get_rdataset(name, dns.rdatatype.SOA,
                                       dns.rdatatype.NONE)
         if rdataset is None or len(rdataset) == 0:
             raise KeyError
@@ -311,10 +314,12 @@ class Transaction:
             else:
                 raise TypeError(f'{method} requires a name or RRset ' +
                                 'as the first argument')
+            if rdataset.rdclass != self.manager.get_class():
+                raise ValueError(f'{method} has objects of wrong RdataClass')
             self._raise_if_not_empty(method, args)
             if not replace:
-                existing = self._get_rdataset(name, rdataset.rdclass,
-                                              rdataset.rdtype, rdataset.covers)
+                existing = self._get_rdataset(name, rdataset.rdtype,
+                                              rdataset.covers)
                 if existing is not None:
                     if isinstance(existing, dns.rdataset.ImmutableRdataset):
                         trds = dns.rdataset.Rdataset(existing.rdclass,
@@ -341,20 +346,19 @@ class Transaction:
                 name = arg
                 if len(args) > 0 and (isinstance(args[0], int) or
                                       isinstance(args[0], str)):
-                    # deleting by type and class
-                    rdclass = dns.rdataclass.RdataClass.make(args.popleft())
+                    # deleting by type and (optionally) covers
                     rdtype = dns.rdatatype.RdataType.make(args.popleft())
                     if len(args) > 0:
                         covers = dns.rdatatype.RdataType.make(args.popleft())
                     else:
                         covers = dns.rdatatype.NONE
                     self._raise_if_not_empty(method, args)
-                    existing = self._get_rdataset(name, rdclass, rdtype, covers)
+                    existing = self._get_rdataset(name, rdtype, covers)
                     if existing is None:
                         if exact:
                             raise DeleteNotExact(f'{method}: missing rdataset')
                     else:
-                        self._delete_rdataset(name, rdclass, rdtype, covers)
+                        self._delete_rdataset(name, rdtype, covers)
                     return
                 else:
                     rdataset = self._rdataset_from_args(method, True, args)
@@ -366,8 +370,11 @@ class Transaction:
                                 'as the first argument')
             self._raise_if_not_empty(method, args)
             if rdataset:
-                existing = self._get_rdataset(name, rdataset.rdclass,
-                                              rdataset.rdtype, rdataset.covers)
+                if rdataset.rdclass != self.manager.get_class():
+                    raise ValueError(f'{method} has objects of wrong '
+                                     'RdataClass')
+                existing = self._get_rdataset(name, rdataset.rdtype,
+                                              rdataset.covers)
                 if existing is not None:
                     if exact:
                         intersection = existing.intersection(rdataset)
@@ -375,8 +382,8 @@ class Transaction:
                             raise DeleteNotExact(f'{method}: missing rdatas')
                     rdataset = existing.difference(rdataset)
                     if len(rdataset) == 0:
-                        self._delete_rdataset(name, rdataset.rdclass,
-                                              rdataset.rdtype, rdataset.covers)
+                        self._delete_rdataset(name, rdataset.rdtype,
+                                              rdataset.covers)
                     else:
                         self._put_rdataset(name, rdataset)
                 elif exact:
@@ -421,9 +428,10 @@ class Transaction:
     # of Transaction.
     #
 
-    def _get_rdataset(self, name, rdclass, rdtype, covers):
-        """Return the rdataset associated with *name*, *rdclass*, *rdtype*,
-        and *covers*, or `None` if not found."""
+    def _get_rdataset(self, name, rdtype, covers):
+        """Return the rdataset associated with *name*, *rdtype*, and *covers*,
+        or `None` if not found.
+        """
         raise NotImplementedError  # pragma: no cover
 
     def _put_rdataset(self, name, rdataset):
@@ -437,9 +445,8 @@ class Transaction:
         """
         raise NotImplementedError  # pragma: no cover
 
-    def _delete_rdataset(self, name, rdclass, rdtype, covers):
-        """Delete all data associated with *name*, *rdclass*, *rdtype*, and
-        *covers*.
+    def _delete_rdataset(self, name, rdtype, covers):
+        """Delete all data associated with *name*, *rdtype*, and *covers*.
 
         It is not an error if the rdataset does not exist.
         """
index ff3d7020d3c3f6abcf43a6f4a757e03521441f21..9f0caa1c71f815a0f0babe2e8230afb5bd0d7e17 100644 (file)
@@ -416,24 +416,18 @@ class Transaction(dns.transaction.Transaction):
         assert self.version is None
         self.version = WritableVersion(self.zone, self.replacement)
 
-    def _get_rdataset(self, name, rdclass, rdtype, covers):
-        if rdclass != self.zone.rdclass:
-            raise ValueError(f'class {rdclass} != ' +
-                             f'zone class {self.zone.rdclass}')
+    def _get_rdataset(self, name, rdtype, covers):
         return self.version.get_rdataset(name, rdtype, covers)
 
     def _put_rdataset(self, name, rdataset):
         assert not self.read_only
-        if rdataset.rdclass != self.zone.rdclass:
-            raise ValueError(f'rdataset class {rdataset.rdclass} != ' +
-                             f'zone class {self.zone.rdclass}')
         self.version.put_rdataset(name, rdataset)
 
     def _delete_name(self, name):
         assert not self.read_only
         self.version.delete_node(name)
 
-    def _delete_rdataset(self, name, rdclass, rdtype, covers):
+    def _delete_rdataset(self, name, rdtype, covers):
         assert not self.read_only
         self.version.delete_rdataset(name, rdtype, covers)
 
index 39ab1a60346767aa7ef681dc29c1cdec4a7e695c..427184ebad6fda527a71fbb735357fd1ba8740c3 100644 (file)
@@ -651,6 +651,9 @@ class Zone(dns.transaction.TransactionManager):
     def origin_information(self):
         return (self.origin, self.relativize)
 
+    def get_class(self):
+        return self.rdclass
+
 
 class Transaction(dns.transaction.Transaction):
 
@@ -665,10 +668,7 @@ class Transaction(dns.transaction.Transaction):
     def zone(self):
         return self.manager
 
-    def _get_rdataset(self, name, rdclass, rdtype, covers):
-        if rdclass != self.zone.rdclass:
-            raise ValueError(f'class {rdclass} != ' +
-                             f'zone class {self.zone.rdclass}')
+    def _get_rdataset(self, name, rdtype, covers):
         rdataset = self.rdatasets.get((name, rdtype, covers))
         if rdataset is self._deleted_rdataset:
             return None
@@ -679,9 +679,6 @@ class Transaction(dns.transaction.Transaction):
     def _put_rdataset(self, name, rdataset):
         assert not self.read_only
         self.zone._validate_name(name)
-        if rdataset.rdclass != self.zone.rdclass:
-            raise ValueError(f'rdataset class {rdataset.rdclass} != ' +
-                             f'zone class {self.zone.rdclass}')
         self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
 
     def _delete_name(self, name):
@@ -702,11 +699,8 @@ class Transaction(dns.transaction.Transaction):
                 self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \
                     self._deleted_rdataset
 
-    def _delete_rdataset(self, name, rdclass, rdtype, covers):
+    def _delete_rdataset(self, name, rdtype, covers):
         assert not self.read_only
-        # The high-level code always does a _get_rdataset() before any
-        # situation where it would call _delete_rdataset(), so we don't
-        # need to check if rdclass != self.zone.rdclass.
         try:
             del self.rdatasets[(name, rdtype, covers)]
         except KeyError:
index bf7b130b0e480d2eb078fcfb514dd26c3680184e..7fb353cd31171e23695834d089cf019395e0d5ae 100644 (file)
@@ -26,22 +26,26 @@ class DB(dns.transaction.TransactionManager):
     def origin_information(self):
         return (None, True)
 
+    def get_class(self):
+        return dns.rdataclass.IN
+
 
 class Transaction(dns.transaction.Transaction):
     def __init__(self, db, replacement, read_only):
-        super().__init__(replacement)
-        self.db = db
+        super().__init__(db, replacement, read_only)
         self.rdatasets = {}
-        self.read_only = read_only
         if not replacement:
             self.rdatasets.update(db.rdatasets)
 
-    def _get_rdataset(self, name, rdclass, rdtype, covers):
-        return self.rdatasets.get((name, rdclass, rdtype, covers))
+    @property
+    def db(self):
+        return self.manager
+
+    def _get_rdataset(self, name, rdtype, covers):
+        return self.rdatasets.get((name, rdtype, covers))
 
     def _put_rdataset(self, name, rdataset):
-        self.rdatasets[(name, rdataset.rdclass, rdataset.rdtype,
-                        rdataset.covers)] = rdataset
+        self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
 
     def _delete_name(self, name):
         remove = []
@@ -52,8 +56,8 @@ class Transaction(dns.transaction.Transaction):
             for key in remove:
                 del self.rdatasets[key]
 
-    def _delete_rdataset(self, name, rdclass, rdtype, covers):
-        del self.rdatasets[(name, rdclass, rdtype, covers)]
+    def _delete_rdataset(self, name, rdtype, covers):
+        del self.rdatasets[(name, rdtype, covers)]
 
     def _name_exists(self, name):
         for key in self.rdatasets.keys():
@@ -78,7 +82,7 @@ class Transaction(dns.transaction.Transaction):
 def db():
     db = DB()
     rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content')
-    db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] = rrset
+    db.rdatasets[(rrset.name, rrset.rdtype, 0)] = rrset
     return db
 
 def test_basic(db):
@@ -88,7 +92,7 @@ def test_basic(db):
                                     '10.0.0.1', '10.0.0.2')
         txn.add(rrset)
         assert txn.name_exists(rrset.name)
-    assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+    assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \
         rrset
     # rollback
     with pytest.raises(Exception):
@@ -97,17 +101,17 @@ def test_basic(db):
                                          '10.0.0.3', '10.0.0.4')
             txn.add(rrset2)
             raise Exception()
-    assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+    assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \
         rrset
     with db.writer() as txn:
         txn.delete(rrset.name)
-    assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \
+    assert db.rdatasets.get((rrset.name, rrset.rdtype, 0)) \
         is None
 
 def test_get(db):
     with db.writer() as txn:
         content = dns.name.from_text('content', None)
-        rdataset = txn.get(content, dns.rdataclass.IN, dns.rdatatype.TXT)
+        rdataset = txn.get(content, dns.rdatatype.TXT)
         assert rdataset is not None
         assert rdataset[0].strings == (b'content',)
         assert isinstance(rdataset, dns.rdataset.ImmutableRdataset)
@@ -123,7 +127,7 @@ def test_add(db):
     expected = dns.rrset.from_text('foo', 300, 'in', 'a',
                                    '10.0.0.1', '10.0.0.2',
                                    '10.0.0.3', '10.0.0.4')
-    assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+    assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \
         expected
 
 def test_replacement(db):
@@ -134,7 +138,7 @@ def test_replacement(db):
         rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a',
                                      '10.0.0.3', '10.0.0.4')
         txn.replace(rrset2)
-    assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+    assert db.rdatasets[(rrset.name, rrset.rdtype, 0)] == \
         rrset2
 
 def test_delete(db):
@@ -144,11 +148,11 @@ def test_delete(db):
         content2 = dns.name.from_text('content2', None)
         txn.delete(content)
         assert not txn.name_exists(content)
-        txn.delete(content2, dns.rdataclass.IN, dns.rdatatype.TXT)
+        txn.delete(content2, dns.rdatatype.TXT)
         rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'new-content')
         txn.add(rrset)
         assert txn.name_exists(content)
-        txn.delete(content, dns.rdataclass.IN, dns.rdatatype.TXT)
+        txn.delete(content, dns.rdatatype.TXT)
         assert not txn.name_exists(content)
         rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'new-content')
         txn.delete(rrset)
@@ -166,10 +170,10 @@ def test_delete_exact(db):
         with pytest.raises(dns.transaction.DeleteNotExact):
             txn.delete_exact(rrset.name)
         with pytest.raises(dns.transaction.DeleteNotExact):
-            txn.delete_exact(rrset.name, dns.rdataclass.IN, dns.rdatatype.TXT)
+            txn.delete_exact(rrset.name, dns.rdatatype.TXT)
         rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content')
         txn.delete_exact(rrset)
-    assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \
+    assert db.rdatasets.get((rrset.name, rrset.rdtype, 0)) \
         is None
 
 def test_parameter_forms(db):
@@ -185,13 +189,13 @@ def test_parameter_forms(db):
     expected = dns.rrset.from_text('foo', 30, 'in', 'a',
                                    '10.0.0.1', '10.0.0.2',
                                    '10.0.0.3', '10.0.0.4')
-    assert db.rdatasets[(foo, rdataset.rdclass, rdataset.rdtype, 0)] == \
+    assert db.rdatasets[(foo, rdataset.rdtype, 0)] == \
         expected
     with db.writer() as txn:
         txn.delete(foo, rdataset)
         txn.delete(foo, rdata1)
         txn.delete(foo, rdata2)
-    assert db.rdatasets.get((foo, rdataset.rdclass, rdataset.rdtype, 0)) \
+    assert db.rdatasets.get((foo, rdataset.rdtype, 0)) \
         is None
 
 def test_bad_parameters(db):
@@ -329,7 +333,7 @@ def test_zone_changed(zone):
     # delete a nonexistent rdataset from an extant node
     with zone.writer() as txn:
         assert not txn.changed()
-        txn.delete(dns.name.from_text('bar.foo', None), 'in', 'txt')
+        txn.delete(dns.name.from_text('bar.foo', None), 'txt')
         assert not txn.changed()
     # add an rdataset to an extant Node
     with zone.writer() as txn:
@@ -345,8 +349,7 @@ def test_zone_changed(zone):
 def test_zone_base_layer(zone):
     with zone.writer() as txn:
         # Get a set from the zone layer
-        rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
-                           dns.rdatatype.NS, dns.rdatatype.NONE)
+        rdataset = txn.get(dns.name.empty, dns.rdatatype.NS, dns.rdatatype.NONE)
         expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2')
         assert rdataset == expected
 
@@ -357,8 +360,7 @@ def test_zone_transaction_layer(zone):
         txn.add(dns.name.empty, 3600, rd)
         # Get a set from the transaction layer
         expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2', 'ns3')
-        rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
-                           dns.rdatatype.NS, dns.rdatatype.NONE)
+        rdataset = txn.get(dns.name.empty, dns.rdatatype.NS, dns.rdatatype.NONE)
         assert rdataset == expected
         assert txn.name_exists(dns.name.empty)
         ns1 = dns.name.from_text('ns1', None)
@@ -373,14 +375,14 @@ def test_zone_add_and_delete(zone):
         a101 = dns.name.from_text('a101', None)
         rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
         txn.add(a99, rds)
-        txn.delete(a99, dns.rdataclass.IN, dns.rdatatype.A)
-        txn.delete(a100, dns.rdataclass.IN, dns.rdatatype.A)
+        txn.delete(a99, dns.rdatatype.A)
+        txn.delete(a100, dns.rdatatype.A)
         txn.delete(a101)
         assert not txn.name_exists(a99)
         assert not txn.name_exists(a100)
         assert not txn.name_exists(a101)
         ns1 = dns.name.from_text('ns1', None)
-        txn.delete(ns1, dns.rdataclass.IN, dns.rdatatype.A)
+        txn.delete(ns1, dns.rdatatype.A)
         assert not txn.name_exists(ns1)
     with zone.writer() as txn:
         txn.add(a99, rds)
@@ -409,18 +411,15 @@ def test_zone_get_deleted(zone):
     with zone.writer() as txn:
         print(zone.to_text())
         ns1 = dns.name.from_text('ns1', None)
-        assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is not None
+        assert txn.get(ns1, dns.rdatatype.A) is not None
         txn.delete(ns1)
-        assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is None
+        assert txn.get(ns1, dns.rdatatype.A) is None
         ns2 = dns.name.from_text('ns2', None)
-        txn.delete(ns2, dns.rdataclass.IN, dns.rdatatype.A)
-        assert txn.get(ns2, dns.rdataclass.IN, dns.rdatatype.A) is None
+        txn.delete(ns2, dns.rdatatype.A)
+        assert txn.get(ns2, dns.rdatatype.A) is None
 
 def test_zone_bad_class(zone):
     with zone.writer() as txn:
-        with pytest.raises(ValueError):
-            txn.get(dns.name.empty, dns.rdataclass.CH,
-                    dns.rdatatype.NS, dns.rdatatype.NONE)
         rds = dns.rdataset.from_text('ch', 'ns', 300, 'ns1', 'ns2')
         with pytest.raises(ValueError):
             txn.add(dns.name.empty, rds)
@@ -428,9 +427,6 @@ def test_zone_bad_class(zone):
             txn.replace(dns.name.empty, rds)
         with pytest.raises(ValueError):
             txn.delete(dns.name.empty, rds)
-        with pytest.raises(ValueError):
-            txn.delete(dns.name.empty, dns.rdataclass.CH,
-                       dns.rdatatype.NS, dns.rdatatype.NONE)
 
 def test_update_serial(zone):
     # basic
@@ -500,8 +496,7 @@ def vzone():
 
 def test_vzone_read_only(vzone):
     with vzone.reader() as txn:
-        rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
-                           dns.rdatatype.NS, dns.rdatatype.NONE)
+        rdataset = txn.get(dns.name.empty, dns.rdatatype.NS, dns.rdatatype.NONE)
         expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2')
         assert rdataset == expected
         with pytest.raises(dns.transaction.ReadOnly):
@@ -521,11 +516,11 @@ def test_vzone_multiple_versions(vzone):
     assert len(vzone._versions) == 4
     with vzone.reader(id=5) as txn:
         assert txn.version.id == 5
-        rdataset = txn.get('@', 'in', 'soa')
+        rdataset = txn.get('@', 'soa')
         assert rdataset[0].serial == 1000
     with vzone.reader(serial=1000) as txn:
         assert txn.version.id == 5
-        rdataset = txn.get('@', 'in', 'soa')
+        rdataset = txn.get('@', 'soa')
         assert rdataset[0].serial == 1000
     vzone.set_max_versions(2)
     assert len(vzone._versions) == 2
@@ -561,7 +556,7 @@ def test_vzone_open_txn_pins_versions(vzone):
     with vzone.reader(id=2) as txn:
         vzone.set_max_versions(1)
         with vzone.reader(id=3) as txn:
-            rdataset = txn.get('@', 'in', 'soa')
+            rdataset = txn.get('@', 'soa')
             assert rdataset[0].serial == 2
             assert len(vzone._versions) == 4
     assert len(vzone._versions) == 1