]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
cache statistics
authorBob Halley <halley@dnspython.org>
Sun, 19 Jul 2020 14:20:41 +0000 (07:20 -0700)
committerBob Halley <halley@dnspython.org>
Sun, 19 Jul 2020 14:20:41 +0000 (07:20 -0700)
dns/resolver.py
tests/test_resolver.py

index 513841e6881402ab8d3d48399f86468ca3a8d072..4d2b72b0358c69464cbca1d0d55b15d78e4076e4 100644 (file)
@@ -283,7 +283,54 @@ class Answer:
         del self.rrset[i]
 
 
-class Cache:
+class CacheStats:
+    """Cache Statistics
+    """
+
+    def __init__(self, hits=0, misses=0):
+        self.hits = hits
+        self.misses = misses
+
+    def reset(self):
+        self.hits = 0
+        self.misses = 0
+
+    def clone(self):
+        return CacheStats(self.hits, self.misses)
+
+
+class CacheBase:
+    def __init__(self):
+        self.lock = _threading.Lock()
+        self.statistics = CacheStats()
+
+    def reset_statistics(self):
+        """Reset all statistics to zero."""
+        with self.lock:
+            self.statistics.reset()
+
+    def hits(self):
+        """How many hits has the cache had?"""
+        with self.lock:
+            return self.statistics.hits
+
+    def misses(self):
+        """How many misses has the cache had?"""
+        with self.lock:
+            return self.statistics.misses
+
+    def get_statistics_snapshot(self):
+        """Return a consistent snapshot of all the statistics.
+
+        If running with multiple threads, it's better to take a
+        snapshot than to call statistics methods such as hits() and
+        misses() individually.
+        """
+        with self.lock:
+            return self.statistics.clone()
+
+
+class Cache(CacheBase):
     """Simple thread-safe DNS answer cache."""
 
     def __init__(self, cleaning_interval=300.0):
@@ -291,10 +338,10 @@ class Cache:
         periodic cleanings.
         """
 
+        super().__init__()
         self.data = {}
         self.cleaning_interval = cleaning_interval
         self.next_cleaning = time.time() + self.cleaning_interval
-        self.lock = _threading.Lock()
 
     def _maybe_clean(self):
         """Clean the cache if it's time to do so."""
@@ -325,7 +372,9 @@ class Cache:
             self._maybe_clean()
             v = self.data.get(key)
             if v is None or v.expiration <= time.time():
+                self.statistics.misses += 1
                 return None
+            self.statistics.hits += 1
             return v
 
     def put(self, key, value):
@@ -366,6 +415,7 @@ class LRUCacheNode:
     def __init__(self, key, value):
         self.key = key
         self.value = value
+        self.hits = 0
         self.prev = self
         self.next = self
 
@@ -380,7 +430,7 @@ class LRUCacheNode:
         self.prev.next = self.next
 
 
-class LRUCache:
+class LRUCache(CacheBase):
     """Thread-safe, bounded, least-recently-used DNS answer cache.
 
     This cache is better than the simple cache (above) if you're
@@ -395,12 +445,12 @@ class LRUCache:
         it must be greater than 0.
         """
 
+        super().__init__()
         self.data = {}
         self.set_max_size(max_size)
         self.sentinel = LRUCacheNode(None, None)
         self.sentinel.prev = self.sentinel
         self.sentinel.next = self.sentinel
-        self.lock = _threading.Lock()
 
     def set_max_size(self, max_size):
         if max_size < 1:
@@ -421,16 +471,29 @@ class LRUCache:
         with self.lock:
             node = self.data.get(key)
             if node is None:
+                self.statistics.misses += 1
                 return None
             # Unlink because we're either going to move the node to the front
             # of the LRU list or we're going to free it.
             node.unlink()
             if node.value.expiration <= time.time():
                 del self.data[node.key]
+                self.statistics.misses += 1
                 return None
             node.link_after(self.sentinel)
+            self.statistics.hits += 1
+            node.hits += 1
             return node.value
 
+    def get_hits_for_key(self, key):
+        """Return the number of cache hits associated with the specified key."""
+        with self.lock:
+            node = self.data.get(key)
+            if node is None or node.value.expiration <= time.time():
+                return 0
+            else:
+                return node.hits
+
     def put(self, key, value):
         """Associate key and value in the cache.
 
index cadf2245c1f0592baf5ee983641b3e3dd7769607..b63ec1963ccb5c81f4feaa531f52902cc5435208 100644 (file)
@@ -347,6 +347,41 @@ class BaseResolverTests(unittest.TestCase):
         self.assertFalse(on_lru_list(cache, key, answer1))
         self.assertTrue(on_lru_list(cache, key, answer2))
 
+    def test_cache_stats(self):
+        caches = [dns.resolver.Cache(), dns.resolver.LRUCache(4)]
+        key1 = (dns.name.from_text('key1.'), dns.rdatatype.A, dns.rdataclass.IN)
+        key2 = (dns.name.from_text('key2.'), dns.rdatatype.A, dns.rdataclass.IN)
+        for cache in caches:
+            answer1 = FakeAnswer(time.time() + 10)
+            answer2 = FakeAnswer(10)  # expired!
+            a = cache.get(key1)
+            self.assertIsNone(a)
+            self.assertEqual(cache.hits(), 0)
+            self.assertEqual(cache.misses(), 1)
+            if isinstance(cache, dns.resolver.LRUCache):
+                self.assertEqual(cache.get_hits_for_key(key1), 0)
+            cache.put(key1, answer1)
+            a = cache.get(key1)
+            self.assertIs(a, answer1)
+            self.assertEqual(cache.hits(), 1)
+            self.assertEqual(cache.misses(), 1)
+            if isinstance(cache, dns.resolver.LRUCache):
+                self.assertEqual(cache.get_hits_for_key(key1), 1)
+            cache.put(key2, answer2)
+            a = cache.get(key2)
+            self.assertIsNone(a)
+            self.assertEqual(cache.hits(), 1)
+            self.assertEqual(cache.misses(), 2)
+            if isinstance(cache, dns.resolver.LRUCache):
+                self.assertEqual(cache.get_hits_for_key(key2), 0)
+            stats = cache.get_statistics_snapshot()
+            self.assertEqual(stats.hits, 1)
+            self.assertEqual(stats.misses, 2)
+            cache.reset_statistics()
+            stats = cache.get_statistics_snapshot()
+            self.assertEqual(stats.hits, 0)
+            self.assertEqual(stats.misses, 0)
+
     def testEmptyAnswerSection(self):
         # TODO: dangling_cname_0_message_text was the only sample message
         #       with an empty answer section. Other than that it doesn't