]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
hide versions
authorBob Halley <halley@dnspython.org>
Wed, 12 Aug 2020 01:06:34 +0000 (18:06 -0700)
committerBob Halley <halley@dnspython.org>
Wed, 12 Aug 2020 01:06:34 +0000 (18:06 -0700)
dns/versioned.py
tests/test_transaction.py

index cc2f71480a95085fdb99c9c7da4db0985a105a0e..ae921f12d68487e9be07cca0f2087a2e8e6e5cb2 100644 (file)
@@ -54,8 +54,9 @@ class Version:
 
 class WritableVersion(Version):
     def __init__(self, zone, replacement=False):
-        if len(zone.versions) > 0:
-            id = zone.versions[-1].id + 1
+        # The zone._versions_lock must be held by our caller.
+        if len(zone._versions) > 0:
+            id = zone._versions[-1].id + 1
         else:
             id = 1
         super().__init__(zone, id)
@@ -167,8 +168,8 @@ class ImmutableNode(Node):
 
 class Zone(dns.zone.Zone):
 
-    __slots__ = ['versions', '_write_txn', '_write_waiters', '_write_event',
-                 '_pruning_policy']
+    __slots__ = ['_versions', '_versions_lock', '_write_txn',
+                 '_write_waiters', '_write_event', '_pruning_policy']
 
     node_factory = Node
 
@@ -190,8 +191,8 @@ class Zone(dns.zone.Zone):
         the default policy, which retains one version is used.
         """
         super().__init__(origin, rdclass, relativize)
-        self.versions = collections.deque()
-        self.version_lock = _threading.Lock()
+        self._versions = collections.deque()
+        self._version_lock = _threading.Lock()
         if pruning_policy is None:
             self._pruning_policy = self._default_pruning_policy
         else:
@@ -204,10 +205,10 @@ class Zone(dns.zone.Zone):
     def reader(self, id=None, serial=None):  # pylint: disable=arguments-differ
         if id is not None and serial is not None:
             raise ValueError('cannot specify both id and serial')
-        with self.version_lock:
+        with self._version_lock:
             if id is not None:
                 version = None
-                for v in reversed(self.versions):
+                for v in reversed(self._versions):
                     if v.id == id:
                         version = v
                         break
@@ -219,7 +220,7 @@ class Zone(dns.zone.Zone):
                 else:
                     oname = self.origin
                 version = None
-                for v in reversed(self.versions):
+                for v in reversed(self._versions):
                     n = v.nodes.get(oname)
                     if n:
                         rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA)
@@ -229,13 +230,13 @@ class Zone(dns.zone.Zone):
                 if version is None:
                     raise KeyError('serial not found')
             else:
-                version = self.versions[-1]
+                version = self._versions[-1]
             return Transaction(False, self, version)
 
     def writer(self, replacement=False):
         event = None
         while True:
-            with self.version_lock:
+            with self._version_lock:
                 # Checking event == self._write_event ensures that either
                 # no one was waiting before we got lucky and found no write
                 # txn, or we were the one who was waiting and got woken up.
@@ -270,7 +271,7 @@ class Zone(dns.zone.Zone):
             # try:
             #     event.wait()
             # except trio.Cancelled:
-            #     with self.version_lock:
+            #     with self._version_lock:
             #         self._maybe_wakeup_one_waiter_unlocked()
             #     raise
             #
@@ -290,9 +291,9 @@ class Zone(dns.zone.Zone):
     # pylint: enable=unused-argument
 
     def _prune_versions_unlocked(self):
-        while len(self.versions) > 1 and \
-              self._pruning_policy(self, self.versions[0]):
-            self.versions.popleft()
+        while len(self._versions) > 1 and \
+              self._pruning_policy(self, self._versions[0]):
+            self._versions.popleft()
 
     def set_max_versions(self, max_versions):
         """Set a pruning policy that retains up to the specified number
@@ -305,7 +306,7 @@ class Zone(dns.zone.Zone):
                 return False
         else:
             def policy(zone, _):
-                return len(zone.versions) > max_versions
+                return len(zone._versions) > max_versions
         self.set_pruning_policy(policy)
 
     def set_pruning_policy(self, policy):
@@ -322,12 +323,12 @@ class Zone(dns.zone.Zone):
         """
         if policy is None:
             policy = self._default_pruning_policy
-        with self.version_lock:
+        with self._version_lock:
             self._pruning_policy = policy
             self._prune_versions_unlocked()
 
     def _commit_version_unlocked(self, version, origin):
-        self.versions.append(version)
+        self._versions.append(version)
         self._prune_versions_unlocked()
         self.nodes = version.nodes
         if self.origin is None:
@@ -336,7 +337,7 @@ class Zone(dns.zone.Zone):
         self._maybe_wakeup_one_waiter_unlocked()
 
     def _commit_version(self, version, origin):
-        with self.version_lock:
+        with self._version_lock:
             self._commit_version_unlocked(version, origin)
 
     def find_node(self, name, create=False):
index 64705ed4352ea02c12f512a3b13c1ddba1c4acd1..888fbd59c4c22a4775b9692b4ca6f0b128c334a5 100644 (file)
@@ -391,7 +391,7 @@ def test_vzone_read_only(vzone):
             txn.replace(dns.name.empty, expected)
 
 def test_vzone_multiple_versions(vzone):
-    assert len(vzone.versions) == 1
+    assert len(vzone._versions) == 1
     vzone.set_max_versions(None)  # unlimited!
     with vzone.writer() as txn:
         txn.set_serial()
@@ -401,7 +401,7 @@ def test_vzone_multiple_versions(vzone):
         txn.set_serial(increment=0, value=1000)
     rdataset = vzone.find_rdataset('@', 'soa')
     assert rdataset[0].serial == 1000
-    assert len(vzone.versions) == 4
+    assert len(vzone._versions) == 4
     with vzone.reader(id=5) as txn:
         assert txn.version.id == 5
         rdataset = txn.get('@', 'in', 'soa')
@@ -411,13 +411,15 @@ def test_vzone_multiple_versions(vzone):
         rdataset = txn.get('@', 'in', 'soa')
         assert rdataset[0].serial == 1000
     vzone.set_max_versions(2)
-    assert len(vzone.versions) == 2
+    assert len(vzone._versions) == 2
     # The ones that survived should be 3 and 1000
-    rdataset = vzone.versions[0].get_rdataset(dns.name.empty, dns.rdatatype.SOA,
+    rdataset = vzone._versions[0].get_rdataset(dns.name.empty,
+                                               dns.rdatatype.SOA,
                                               dns.rdatatype.NONE)
     assert rdataset[0].serial == 3
-    rdataset = vzone.versions[1].get_rdataset(dns.name.empty, dns.rdatatype.SOA,
-                                              dns.rdatatype.NONE)
+    rdataset = vzone._versions[1].get_rdataset(dns.name.empty,
+                                               dns.rdatatype.SOA,
+                                               dns.rdatatype.NONE)
     assert rdataset[0].serial == 1000
     with pytest.raises(ValueError):
         vzone.set_max_versions(0)