From: Bob Halley Date: Thu, 13 Aug 2020 14:31:57 +0000 (-0700) Subject: Allow explicit commit/rollback. Prevent multiple txn end. Add txn.changed(). X-Git-Tag: v2.1.0rc1~83 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=677c3c03c91d5d758e2f7a225a7455d9b2b5da24;p=thirdparty%2Fdnspython.git Allow explicit commit/rollback. Prevent multiple txn end. Add txn.changed(). --- diff --git a/dns/transaction.py b/dns/transaction.py index 630f6303..aec06626 100644 --- a/dns/transaction.py +++ b/dns/transaction.py @@ -36,11 +36,16 @@ class ReadOnly(dns.exception.DNSException): """Tried to write to a read-only transaction.""" +class AlreadyEnded(dns.exception.DNSException): + """Tried to use an already-ended transaction.""" + + class Transaction: def __init__(self, replacement=False, read_only=False): self.replacement = replacement self.read_only = read_only + self._ended = False # # This is the high level API @@ -52,6 +57,7 @@ class Transaction: Note that the returned rdataset is immutable. """ + self._check_ended() if isinstance(name, str): name = dns.name.from_text(name, None) rdclass = dns.rdataclass.RdataClass.make(rdclass) @@ -77,6 +83,7 @@ class Transaction: - name, ttl, rdata... """ + self._check_ended() self._check_read_only() return self._add(False, args) @@ -97,6 +104,7 @@ class Transaction: a delete of the name followed by one or more calls to add() or replace(). """ + self._check_ended() self._check_read_only() return self._add(True, args) @@ -118,6 +126,7 @@ class Transaction: - name, rdata... """ + self._check_ended() self._check_read_only() return self._delete(False, args) @@ -140,11 +149,13 @@ class Transaction: are not in the existing set. """ + self._check_ended() self._check_read_only() return self._delete(True, args) def name_exists(self, name): """Does the specified name exist?""" + self._check_ended() if isinstance(name, str): name = dns.name.from_text(name, None) return self._name_exists(name) @@ -162,6 +173,7 @@ class Transaction: so large that it would cause the new serial to be less than the prior value. """ + self._check_ended() if value < 0: raise ValueError('negative update_serial() value') if isinstance(name, str): @@ -182,8 +194,45 @@ class Transaction: self.replace(name, new_rdataset) def __iter__(self): + self._check_ended() return self._iterate_rdatasets() + def changed(self): + """Has this transaction changed anything? + + For read-only transactions, the result is always `False`. + + For writable transactions, the result is `True` if at some time + during the life of the transaction, the content was changed. + """ + self._check_ended() + return self._changed() + + def commit(self): + """Commit the transaction. + + Normally transactions are used as context managers and commit + or rollback automatically, but it may be done explicitly if needed. + A ``dns.transaction.Ended`` exception will be raised if you try + to use a transaction after it has been committed or rolled back. + + Raises an exception if the commit fails (in which case the transaction + is also rolled back. + """ + self._end(True) + + def rollback(self): + """Rollback the transaction. + + Normally transactions are used as context managers and commit + or rollback automatically, but it may be done explicitly if needed. + A ``dns.transaction.AlreadyEnded`` exception will be raised if you try + to use a transaction after it has been committed or rolled back. + + Rollback cannot otherwise fail. + """ + self._end(False) + # # Helper methods # @@ -272,7 +321,8 @@ class Transaction: arg = dns.name.from_text(arg, None) if isinstance(arg, dns.name.Name): name = arg - if len(args) > 0 and isinstance(args[0], int): + if len(args) > 0 and (isinstance(args[0], int) or + isinstance(args[0], str)): # deleting by type and class rdclass = dns.rdataclass.RdataClass.make(args.popleft()) rdtype = dns.rdatatype.RdataType.make(args.popleft()) @@ -320,6 +370,19 @@ class Transaction: except IndexError: raise TypeError(f'not enough parameters to {method}') + def _check_ended(self): + if self._ended: + raise AlreadyEnded + + def _end(self, commit): + self._check_ended() + if self._ended: + raise AlreadyEnded + try: + self._end_transaction(commit) + finally: + self._ended = True + # # Transactions are context managers. # @@ -328,10 +391,11 @@ class Transaction: return self def __exit__(self, exc_type, exc_val, exc_tb): - if exc_type is None: - self._end_transaction(True) - else: - self._end_transaction(False) + if not self._ended: + if exc_type is None: + self.commit() + else: + self.rollback() return False # @@ -370,13 +434,18 @@ class Transaction: """ raise NotImplementedError # pragma: no cover + def _changed(self): + """Has this transaction changed anything?""" + raise NotImplementedError # pragma: no cover + def _end_transaction(self, commit): """End the transaction. *commit*, a bool. If ``True``, commit the transaction, otherwise roll it back. - Raises an exception if committing failed. + If committing adn the commit fails, then roll back and raise an + exception. """ raise NotImplementedError # pragma: no cover diff --git a/dns/versioned.py b/dns/versioned.py index e7534386..e070c1a9 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -90,6 +90,7 @@ class WritableVersion(Version): name = self._validate_name(name) if name in self.nodes: del self.nodes[name] + self.changed.add(name) return True return False @@ -436,6 +437,12 @@ class Transaction(dns.transaction.Transaction): def _name_exists(self, name): return self.version.get_node(name) is not None + def _changed(self): + if self.read_only: + return False + else: + return len(self.version.changed) > 0 + def _end_transaction(self, commit): if self.read_only: self.zone._end_read(self) diff --git a/dns/zone.py b/dns/zone.py index 2ca9bc21..e85603b4 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -722,8 +722,14 @@ class Transaction(dns.transaction.Transaction): return True return False + def _changed(self): + if self.read_only: + return False + else: + return len(self.rdatasets) > 0 + def _end_transaction(self, commit): - if commit and not self.read_only: + if commit and self._changed(): for (name, rdtype, covers), rdataset in \ self.rdatasets.items(): if rdataset is self._deleted_rdataset: diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 93705f2e..c9b6f5ce 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -56,6 +56,12 @@ class Transaction(dns.transaction.Transaction): return True return False + def _changed(self): + if self.read_only: + return False + else: + return len(self.rdatasets) > 0 + def _end_transaction(self, commit): if commit: self.db.rdatasets = self.rdatasets @@ -244,6 +250,93 @@ def test_zone_basic(zone): output = zone.to_text() assert output == example_text_output +def test_explicit_rollback_and_commit(zone): + with zone.writer() as txn: + assert not txn.changed() + txn.delete(dns.name.from_text('bar.foo', None)) + txn.rollback() + assert zone.get_node('bar.foo') is not None + with zone.writer() as txn: + assert not txn.changed() + txn.delete(dns.name.from_text('bar.foo', None)) + txn.commit() + assert zone.get_node('bar.foo') is None + with pytest.raises(dns.transaction.AlreadyEnded): + with zone.writer() as txn: + txn.rollback() + txn.delete(dns.name.from_text('bar.foo', None)) + with pytest.raises(dns.transaction.AlreadyEnded): + with zone.writer() as txn: + txn.rollback() + txn.add('bar.foo', 300, dns.rdata.from_text('in', 'txt', 'hi')) + with pytest.raises(dns.transaction.AlreadyEnded): + with zone.writer() as txn: + txn.rollback() + txn.replace('bar.foo', 300, dns.rdata.from_text('in', 'txt', 'hi')) + with pytest.raises(dns.transaction.AlreadyEnded): + with zone.reader() as txn: + txn.rollback() + txn.get('bar.foo', 'in', 'mx') + with pytest.raises(dns.transaction.AlreadyEnded): + with zone.writer() as txn: + txn.rollback() + txn.delete_exact('bar.foo') + with pytest.raises(dns.transaction.AlreadyEnded): + with zone.writer() as txn: + txn.rollback() + txn.name_exists('bar.foo') + with pytest.raises(dns.transaction.AlreadyEnded): + with zone.writer() as txn: + txn.rollback() + txn.update_serial() + with pytest.raises(dns.transaction.AlreadyEnded): + with zone.writer() as txn: + txn.rollback() + txn.changed() + with pytest.raises(dns.transaction.AlreadyEnded): + with zone.writer() as txn: + txn.rollback() + txn.rollback() + with pytest.raises(dns.transaction.AlreadyEnded): + with zone.writer() as txn: + txn.rollback() + txn.commit() + with pytest.raises(dns.transaction.AlreadyEnded): + with zone.writer() as txn: + txn.rollback() + for rdataset in txn: + print(rdataset) + +def test_zone_changed(zone): + # Read-only is not changed! + with zone.reader() as txn: + assert not txn.changed() + # delete an existing name + with zone.writer() as txn: + assert not txn.changed() + txn.delete(dns.name.from_text('bar.foo', None)) + assert txn.changed() + # delete a nonexistent name + with zone.writer() as txn: + assert not txn.changed() + txn.delete(dns.name.from_text('unknown.bar.foo', None)) + assert not txn.changed() + # delete a nonexistent rdataset from an extant node + with zone.writer() as txn: + assert not txn.changed() + txn.delete(dns.name.from_text('bar.foo', None), 'in', 'txt') + assert not txn.changed() + # add an rdataset to an extant Node + with zone.writer() as txn: + assert not txn.changed() + txn.add('bar.foo', 300, dns.rdata.from_text('in', 'txt', 'hi')) + assert txn.changed() + # add an rdataset to a nonexistent Node + with zone.writer() as txn: + assert not txn.changed() + txn.add('foo.foo', 300, dns.rdata.from_text('in', 'txt', 'hi')) + assert txn.changed() + def test_zone_base_layer(zone): with zone.writer() as txn: # Get a set from the zone layer