]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Allow explicit commit/rollback. Prevent multiple txn end. Add txn.changed().
authorBob Halley <halley@dnspython.org>
Thu, 13 Aug 2020 14:31:57 +0000 (07:31 -0700)
committerBob Halley <halley@dnspython.org>
Thu, 13 Aug 2020 14:31:57 +0000 (07:31 -0700)
dns/transaction.py
dns/versioned.py
dns/zone.py
tests/test_transaction.py

index 630f6303fd559f9fd46a102747917b20d5f51004..aec06626221e8866b6921a4fc16857b1fcbf6986 100644 (file)
@@ -36,11 +36,16 @@ class ReadOnly(dns.exception.DNSException):
     """Tried to write to a read-only transaction."""
 
 
+class AlreadyEnded(dns.exception.DNSException):
+    """Tried to use an already-ended transaction."""
+
+
 class Transaction:
 
     def __init__(self, replacement=False, read_only=False):
         self.replacement = replacement
         self.read_only = read_only
+        self._ended = False
 
     #
     # This is the high level API
@@ -52,6 +57,7 @@ class Transaction:
 
         Note that the returned rdataset is immutable.
         """
+        self._check_ended()
         if isinstance(name, str):
             name = dns.name.from_text(name, None)
         rdclass = dns.rdataclass.RdataClass.make(rdclass)
@@ -77,6 +83,7 @@ class Transaction:
 
             - name, ttl, rdata...
         """
+        self._check_ended()
         self._check_read_only()
         return self._add(False, args)
 
@@ -97,6 +104,7 @@ class Transaction:
         a delete of the name followed by one or more calls to add() or
         replace().
         """
+        self._check_ended()
         self._check_read_only()
         return self._add(True, args)
 
@@ -118,6 +126,7 @@ class Transaction:
 
             - name, rdata...
         """
+        self._check_ended()
         self._check_read_only()
         return self._delete(False, args)
 
@@ -140,11 +149,13 @@ class Transaction:
         are not in the existing set.
 
         """
+        self._check_ended()
         self._check_read_only()
         return self._delete(True, args)
 
     def name_exists(self, name):
         """Does the specified name exist?"""
+        self._check_ended()
         if isinstance(name, str):
             name = dns.name.from_text(name, None)
         return self._name_exists(name)
@@ -162,6 +173,7 @@ class Transaction:
         so large that it would cause the new serial to be less than the
         prior value.
         """
+        self._check_ended()
         if value < 0:
             raise ValueError('negative update_serial() value')
         if isinstance(name, str):
@@ -182,8 +194,45 @@ class Transaction:
         self.replace(name, new_rdataset)
 
     def __iter__(self):
+        self._check_ended()
         return self._iterate_rdatasets()
 
+    def changed(self):
+        """Has this transaction changed anything?
+
+        For read-only transactions, the result is always `False`.
+
+        For writable transactions, the result is `True` if at some time
+        during the life of the transaction, the content was changed.
+        """
+        self._check_ended()
+        return self._changed()
+
+    def commit(self):
+        """Commit the transaction.
+
+        Normally transactions are used as context managers and commit
+        or rollback automatically, but it may be done explicitly if needed.
+        A ``dns.transaction.Ended`` exception will be raised if you try
+        to use a transaction after it has been committed or rolled back.
+
+        Raises an exception if the commit fails (in which case the transaction
+        is also rolled back.
+        """
+        self._end(True)
+
+    def rollback(self):
+        """Rollback the transaction.
+
+        Normally transactions are used as context managers and commit
+        or rollback automatically, but it may be done explicitly if needed.
+        A ``dns.transaction.AlreadyEnded`` exception will be raised if you try
+        to use a transaction after it has been committed or rolled back.
+
+        Rollback cannot otherwise fail.
+        """
+        self._end(False)
+
     #
     # Helper methods
     #
@@ -272,7 +321,8 @@ class Transaction:
                 arg = dns.name.from_text(arg, None)
             if isinstance(arg, dns.name.Name):
                 name = arg
-                if len(args) > 0 and isinstance(args[0], int):
+                if len(args) > 0 and (isinstance(args[0], int) or
+                                      isinstance(args[0], str)):
                     # deleting by type and class
                     rdclass = dns.rdataclass.RdataClass.make(args.popleft())
                     rdtype = dns.rdatatype.RdataType.make(args.popleft())
@@ -320,6 +370,19 @@ class Transaction:
         except IndexError:
             raise TypeError(f'not enough parameters to {method}')
 
+    def _check_ended(self):
+        if self._ended:
+            raise AlreadyEnded
+
+    def _end(self, commit):
+        self._check_ended()
+        if self._ended:
+            raise AlreadyEnded
+        try:
+            self._end_transaction(commit)
+        finally:
+            self._ended = True
+
     #
     # Transactions are context managers.
     #
@@ -328,10 +391,11 @@ class Transaction:
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        if exc_type is None:
-            self._end_transaction(True)
-        else:
-            self._end_transaction(False)
+        if not self._ended:
+            if exc_type is None:
+                self.commit()
+            else:
+                self.rollback()
         return False
 
     #
@@ -370,13 +434,18 @@ class Transaction:
         """
         raise NotImplementedError  # pragma: no cover
 
+    def _changed(self):
+        """Has this transaction changed anything?"""
+        raise NotImplementedError  # pragma: no cover
+
     def _end_transaction(self, commit):
         """End the transaction.
 
         *commit*, a bool.  If ``True``, commit the transaction, otherwise
         roll it back.
 
-        Raises an exception if committing failed.
+        If committing adn the commit fails, then roll back and raise an
+        exception.
         """
         raise NotImplementedError  # pragma: no cover
 
index e7534386ba0c665c8bfc717a3cd6757073605c3f..e070c1a926663e733fc851cad4f2e0dea7111f90 100644 (file)
@@ -90,6 +90,7 @@ class WritableVersion(Version):
         name = self._validate_name(name)
         if name in self.nodes:
             del self.nodes[name]
+            self.changed.add(name)
             return True
         return False
 
@@ -436,6 +437,12 @@ class Transaction(dns.transaction.Transaction):
     def _name_exists(self, name):
         return self.version.get_node(name) is not None
 
+    def _changed(self):
+        if self.read_only:
+            return False
+        else:
+            return len(self.version.changed) > 0
+
     def _end_transaction(self, commit):
         if self.read_only:
             self.zone._end_read(self)
index 2ca9bc212ffe6462de5ef098155333b8410df226..e85603b4959ec7db642ae3ace58ac84e19a766d1 100644 (file)
@@ -722,8 +722,14 @@ class Transaction(dns.transaction.Transaction):
             return True
         return False
 
+    def _changed(self):
+        if self.read_only:
+            return False
+        else:
+            return len(self.rdatasets) > 0
+
     def _end_transaction(self, commit):
-        if commit and not self.read_only:
+        if commit and self._changed():
             for (name, rdtype, covers), rdataset in \
                 self.rdatasets.items():
                 if rdataset is self._deleted_rdataset:
index 93705f2eadf5dfe6f1ba05914a0f2798ce38ba22..c9b6f5ce2af1c8be218e2532ec1fc346b0b9f764 100644 (file)
@@ -56,6 +56,12 @@ class Transaction(dns.transaction.Transaction):
                 return True
         return False
 
+    def _changed(self):
+        if self.read_only:
+            return False
+        else:
+            return len(self.rdatasets) > 0
+
     def _end_transaction(self, commit):
         if commit:
             self.db.rdatasets = self.rdatasets
@@ -244,6 +250,93 @@ def test_zone_basic(zone):
     output = zone.to_text()
     assert output == example_text_output
 
+def test_explicit_rollback_and_commit(zone):
+    with zone.writer() as txn:
+        assert not txn.changed()
+        txn.delete(dns.name.from_text('bar.foo', None))
+        txn.rollback()
+    assert zone.get_node('bar.foo') is not None
+    with zone.writer() as txn:
+        assert not txn.changed()
+        txn.delete(dns.name.from_text('bar.foo', None))
+        txn.commit()
+    assert zone.get_node('bar.foo') is None
+    with pytest.raises(dns.transaction.AlreadyEnded):
+        with zone.writer() as txn:
+            txn.rollback()
+            txn.delete(dns.name.from_text('bar.foo', None))
+    with pytest.raises(dns.transaction.AlreadyEnded):
+        with zone.writer() as txn:
+            txn.rollback()
+            txn.add('bar.foo', 300, dns.rdata.from_text('in', 'txt', 'hi'))
+    with pytest.raises(dns.transaction.AlreadyEnded):
+        with zone.writer() as txn:
+            txn.rollback()
+            txn.replace('bar.foo', 300, dns.rdata.from_text('in', 'txt', 'hi'))
+    with pytest.raises(dns.transaction.AlreadyEnded):
+        with zone.reader() as txn:
+            txn.rollback()
+            txn.get('bar.foo', 'in', 'mx')
+    with pytest.raises(dns.transaction.AlreadyEnded):
+        with zone.writer() as txn:
+            txn.rollback()
+            txn.delete_exact('bar.foo')
+    with pytest.raises(dns.transaction.AlreadyEnded):
+        with zone.writer() as txn:
+            txn.rollback()
+            txn.name_exists('bar.foo')
+    with pytest.raises(dns.transaction.AlreadyEnded):
+        with zone.writer() as txn:
+            txn.rollback()
+            txn.update_serial()
+    with pytest.raises(dns.transaction.AlreadyEnded):
+        with zone.writer() as txn:
+            txn.rollback()
+            txn.changed()
+    with pytest.raises(dns.transaction.AlreadyEnded):
+        with zone.writer() as txn:
+            txn.rollback()
+            txn.rollback()
+    with pytest.raises(dns.transaction.AlreadyEnded):
+        with zone.writer() as txn:
+            txn.rollback()
+            txn.commit()
+    with pytest.raises(dns.transaction.AlreadyEnded):
+        with zone.writer() as txn:
+            txn.rollback()
+            for rdataset in txn:
+                print(rdataset)
+
+def test_zone_changed(zone):
+    # Read-only is not changed!
+    with zone.reader() as txn:
+        assert not txn.changed()
+    # delete an existing name
+    with zone.writer() as txn:
+        assert not txn.changed()
+        txn.delete(dns.name.from_text('bar.foo', None))
+        assert txn.changed()
+    # delete a nonexistent name
+    with zone.writer() as txn:
+        assert not txn.changed()
+        txn.delete(dns.name.from_text('unknown.bar.foo', None))
+        assert not txn.changed()
+    # delete a nonexistent rdataset from an extant node
+    with zone.writer() as txn:
+        assert not txn.changed()
+        txn.delete(dns.name.from_text('bar.foo', None), 'in', 'txt')
+        assert not txn.changed()
+    # add an rdataset to an extant Node
+    with zone.writer() as txn:
+        assert not txn.changed()
+        txn.add('bar.foo', 300, dns.rdata.from_text('in', 'txt', 'hi'))
+        assert txn.changed()
+    # add an rdataset to a nonexistent Node
+    with zone.writer() as txn:
+        assert not txn.changed()
+        txn.add('foo.foo', 300, dns.rdata.from_text('in', 'txt', 'hi'))
+        assert txn.changed()
+
 def test_zone_base_layer(zone):
     with zone.writer() as txn:
         # Get a set from the zone layer