]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Increase branch coverage of the resolver.
authorBob Halley <halley@dnspython.org>
Wed, 22 Jul 2020 14:25:07 +0000 (07:25 -0700)
committerBob Halley <halley@dnspython.org>
Wed, 22 Jul 2020 14:25:07 +0000 (07:25 -0700)
Speed up test resolver by around 6 seconds by mocking the clock.

Improve the cleaning test to verify that it really was cleaning that
removed the entry, and not get detecting the expiration.

tests/test_resolver.py

index 1eb8b1eb39ce716e3b4a35de193f8a1073795e01..3399e4a29b948de9b6ebce699760746d5b41c6b0 100644 (file)
@@ -160,6 +160,38 @@ class FakeAnswer(object):
         self.expiration = expiration
 
 
+class FakeTime:
+    # Mock the clock!
+    def __init__(self, now=None, want_fake=True):
+        if now is None:
+            now = time.time()
+        self.now = now
+        self.saved_time = time.time
+        self.want_fake = want_fake
+
+    def __enter__(self):
+        if self.want_fake:
+            time.time = self.time
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self.want_fake:
+            time.time = self.saved_time
+        return False
+
+    def time(self):
+        if self.want_fake:
+            return self.now
+        else:
+            return time.time()
+
+    def sleep(self, offset):
+        if self.want_fake:
+            self.now += offset
+        else:
+            time.sleep(offset)
+
+
 class BaseResolverTests(unittest.TestCase):
 
     def testRead(self):
@@ -211,26 +243,45 @@ class BaseResolverTests(unittest.TestCase):
             r.read_resolv_conf(f)
 
     def testCacheExpiration(self):
-        message = dns.message.from_text(message_text)
-        name = dns.name.from_text('example.')
-        answer = dns.resolver.Answer(name, dns.rdatatype.A, dns.rdataclass.IN,
-                                     message)
-        cache = dns.resolver.Cache()
-        cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer)
-        time.sleep(2)
-        self.assertTrue(cache.get((name, dns.rdatatype.A, dns.rdataclass.IN))
-                        is None)
+        with FakeTime() as fake_time:
+            message = dns.message.from_text(message_text)
+            name = dns.name.from_text('example.')
+            answer = dns.resolver.Answer(name, dns.rdatatype.A,
+                                         dns.rdataclass.IN, message)
+            cache = dns.resolver.Cache()
+            cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer)
+            fake_time.sleep(2)
+            self.assertTrue(cache.get((name, dns.rdatatype.A,
+                                       dns.rdataclass.IN))
+                            is None)
 
     def testCacheCleaning(self):
-        message = dns.message.from_text(message_text)
-        name = dns.name.from_text('example.')
-        answer = dns.resolver.Answer(name, dns.rdatatype.A, dns.rdataclass.IN,
-                                     message)
-        cache = dns.resolver.Cache(cleaning_interval=1.0)
-        cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer)
-        time.sleep(2)
-        self.assertTrue(cache.get((name, dns.rdatatype.A, dns.rdataclass.IN))
-                        is None)
+        with FakeTime() as fake_time:
+            message = dns.message.from_text(message_text)
+            name = dns.name.from_text('example.')
+            answer = dns.resolver.Answer(name, dns.rdatatype.A,
+                                         dns.rdataclass.IN, message)
+            cache = dns.resolver.Cache(cleaning_interval=1.0)
+            cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer)
+            fake_time.sleep(2)
+            cache._maybe_clean()
+            self.assertTrue(cache.data.get((name, dns.rdatatype.A,
+                                            dns.rdataclass.IN))
+                            is None)
+
+    def testCacheNonCleaning(self):
+        with FakeTime() as fake_time:
+            message = dns.message.from_text(message_text)
+            name = dns.name.from_text('example.')
+            answer = dns.resolver.Answer(name, dns.rdatatype.A,
+                                         dns.rdataclass.IN, message)
+            # override TTL as we're testing non-cleaning
+            answer.expiration = fake_time.time() + 100
+            cache = dns.resolver.Cache(cleaning_interval=1.0)
+            cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer)
+            fake_time.sleep(1.1)
+            self.assertEqual(cache.get((name, dns.rdatatype.A,
+                                        dns.rdataclass.IN)), answer)
 
     def testIndexErrorOnEmptyRRsetAccess(self):
         def bad():
@@ -252,6 +303,14 @@ class BaseResolverTests(unittest.TestCase):
             del answer[0]
         self.assertRaises(IndexError, bad)
 
+    def testRRsetDelete(self):
+        message = dns.message.from_text(message_text)
+        name = dns.name.from_text('example.')
+        answer = dns.resolver.Answer(name, dns.rdatatype.A,
+                                     dns.rdataclass.IN, message)
+        del answer[0]
+        self.assertEqual(len(answer), 0)
+
     def testLRUReplace(self):
         cache = dns.resolver.LRUCache(4)
         for i in range(0, 5):
@@ -293,21 +352,23 @@ class BaseResolverTests(unittest.TestCase):
                                 is None)
 
     def testLRUExpiration(self):
-        cache = dns.resolver.LRUCache(4)
-        for i in range(0, 4):
-            name = dns.name.from_text('example%d.' % i)
-            answer = FakeAnswer(time.time() + 1)
-            cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer)
-        time.sleep(2)
-        for i in range(0, 4):
-            name = dns.name.from_text('example%d.' % i)
-            self.assertTrue(cache.get((name, dns.rdatatype.A,
-                                       dns.rdataclass.IN))
-                            is None)
+        with FakeTime() as fake_time:
+            cache = dns.resolver.LRUCache(4)
+            for i in range(0, 4):
+                name = dns.name.from_text('example%d.' % i)
+                answer = FakeAnswer(time.time() + 1)
+                cache.put((name, dns.rdatatype.A, dns.rdataclass.IN), answer)
+            fake_time.sleep(2)
+            for i in range(0, 4):
+                name = dns.name.from_text('example%d.' % i)
+                self.assertTrue(cache.get((name, dns.rdatatype.A,
+                                           dns.rdataclass.IN))
+                                is None)
 
     def test_cache_flush(self):
         name1 = dns.name.from_text('name1')
         name2 = dns.name.from_text('name2')
+        name3 = dns.name.from_text('name3')
         basic_cache = dns.resolver.Cache()
         lru_cache = dns.resolver.LRUCache(100)
         for cache in [basic_cache, lru_cache]:
@@ -319,6 +380,12 @@ class BaseResolverTests(unittest.TestCase):
             self.assertTrue(canswer is answer1)
             canswer = cache.get((name2, dns.rdatatype.A, dns.rdataclass.IN))
             self.assertTrue(canswer is answer2)
+            # explicit flush of nonexistent key, just to exercise the branch
+            cache.flush((name3, dns.rdatatype.A, dns.rdataclass.IN))
+            canswer = cache.get((name1, dns.rdatatype.A, dns.rdataclass.IN))
+            self.assertTrue(canswer is answer1)
+            canswer = cache.get((name2, dns.rdatatype.A, dns.rdataclass.IN))
+            self.assertTrue(canswer is answer2)
             # explicit flush
             cache.flush((name1, dns.rdatatype.A, dns.rdataclass.IN))
             canswer = cache.get((name1, dns.rdatatype.A, dns.rdataclass.IN))