From 20e329d623706dcaf3ab1088043e8b6f1568c0b1 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sun, 19 Jul 2020 07:20:41 -0700 Subject: [PATCH] cache statistics --- dns/resolver.py | 71 +++++++++++++++++++++++++++++++++++++++--- tests/test_resolver.py | 35 +++++++++++++++++++++ 2 files changed, 102 insertions(+), 4 deletions(-) diff --git a/dns/resolver.py b/dns/resolver.py index 513841e6..4d2b72b0 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -283,7 +283,54 @@ class Answer: del self.rrset[i] -class Cache: +class CacheStats: + """Cache Statistics + """ + + def __init__(self, hits=0, misses=0): + self.hits = hits + self.misses = misses + + def reset(self): + self.hits = 0 + self.misses = 0 + + def clone(self): + return CacheStats(self.hits, self.misses) + + +class CacheBase: + def __init__(self): + self.lock = _threading.Lock() + self.statistics = CacheStats() + + def reset_statistics(self): + """Reset all statistics to zero.""" + with self.lock: + self.statistics.reset() + + def hits(self): + """How many hits has the cache had?""" + with self.lock: + return self.statistics.hits + + def misses(self): + """How many misses has the cache had?""" + with self.lock: + return self.statistics.misses + + def get_statistics_snapshot(self): + """Return a consistent snapshot of all the statistics. + + If running with multiple threads, it's better to take a + snapshot than to call statistics methods such as hits() and + misses() individually. + """ + with self.lock: + return self.statistics.clone() + + +class Cache(CacheBase): """Simple thread-safe DNS answer cache.""" def __init__(self, cleaning_interval=300.0): @@ -291,10 +338,10 @@ class Cache: periodic cleanings. """ + super().__init__() self.data = {} self.cleaning_interval = cleaning_interval self.next_cleaning = time.time() + self.cleaning_interval - self.lock = _threading.Lock() def _maybe_clean(self): """Clean the cache if it's time to do so.""" @@ -325,7 +372,9 @@ class Cache: self._maybe_clean() v = self.data.get(key) if v is None or v.expiration <= time.time(): + self.statistics.misses += 1 return None + self.statistics.hits += 1 return v def put(self, key, value): @@ -366,6 +415,7 @@ class LRUCacheNode: def __init__(self, key, value): self.key = key self.value = value + self.hits = 0 self.prev = self self.next = self @@ -380,7 +430,7 @@ class LRUCacheNode: self.prev.next = self.next -class LRUCache: +class LRUCache(CacheBase): """Thread-safe, bounded, least-recently-used DNS answer cache. This cache is better than the simple cache (above) if you're @@ -395,12 +445,12 @@ class LRUCache: it must be greater than 0. """ + super().__init__() self.data = {} self.set_max_size(max_size) self.sentinel = LRUCacheNode(None, None) self.sentinel.prev = self.sentinel self.sentinel.next = self.sentinel - self.lock = _threading.Lock() def set_max_size(self, max_size): if max_size < 1: @@ -421,16 +471,29 @@ class LRUCache: with self.lock: node = self.data.get(key) if node is None: + self.statistics.misses += 1 return None # Unlink because we're either going to move the node to the front # of the LRU list or we're going to free it. node.unlink() if node.value.expiration <= time.time(): del self.data[node.key] + self.statistics.misses += 1 return None node.link_after(self.sentinel) + self.statistics.hits += 1 + node.hits += 1 return node.value + def get_hits_for_key(self, key): + """Return the number of cache hits associated with the specified key.""" + with self.lock: + node = self.data.get(key) + if node is None or node.value.expiration <= time.time(): + return 0 + else: + return node.hits + def put(self, key, value): """Associate key and value in the cache. diff --git a/tests/test_resolver.py b/tests/test_resolver.py index cadf2245..b63ec196 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -347,6 +347,41 @@ class BaseResolverTests(unittest.TestCase): self.assertFalse(on_lru_list(cache, key, answer1)) self.assertTrue(on_lru_list(cache, key, answer2)) + def test_cache_stats(self): + caches = [dns.resolver.Cache(), dns.resolver.LRUCache(4)] + key1 = (dns.name.from_text('key1.'), dns.rdatatype.A, dns.rdataclass.IN) + key2 = (dns.name.from_text('key2.'), dns.rdatatype.A, dns.rdataclass.IN) + for cache in caches: + answer1 = FakeAnswer(time.time() + 10) + answer2 = FakeAnswer(10) # expired! + a = cache.get(key1) + self.assertIsNone(a) + self.assertEqual(cache.hits(), 0) + self.assertEqual(cache.misses(), 1) + if isinstance(cache, dns.resolver.LRUCache): + self.assertEqual(cache.get_hits_for_key(key1), 0) + cache.put(key1, answer1) + a = cache.get(key1) + self.assertIs(a, answer1) + self.assertEqual(cache.hits(), 1) + self.assertEqual(cache.misses(), 1) + if isinstance(cache, dns.resolver.LRUCache): + self.assertEqual(cache.get_hits_for_key(key1), 1) + cache.put(key2, answer2) + a = cache.get(key2) + self.assertIsNone(a) + self.assertEqual(cache.hits(), 1) + self.assertEqual(cache.misses(), 2) + if isinstance(cache, dns.resolver.LRUCache): + self.assertEqual(cache.get_hits_for_key(key2), 0) + stats = cache.get_statistics_snapshot() + self.assertEqual(stats.hits, 1) + self.assertEqual(stats.misses, 2) + cache.reset_statistics() + stats = cache.get_statistics_snapshot() + self.assertEqual(stats.hits, 0) + self.assertEqual(stats.misses, 0) + def testEmptyAnswerSection(self): # TODO: dangling_cname_0_message_text was the only sample message # with an empty answer section. Other than that it doesn't -- 2.47.3