]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_LMDB.py
Merge pull request #14001 from rgacogne/ddist-ffi-policy-no-server
[thirdparty/pdns.git] / regression-tests.dnsdist / test_LMDB.py
index 7529ddd45666306a7513e77d58073ea3ca235b4a..5e308d7d32bceb737829a713e38647520f602604 100644 (file)
@@ -2,9 +2,13 @@
 import unittest
 import dns
 import lmdb
+import os
 import socket
+import struct
+
 from dnsdisttests import DNSDistTest
 
+@unittest.skipIf('SKIP_LMDB_TESTS' in os.environ, 'LMDB tests are disabled')
 class TestLMDB(DNSDistTest):
 
     _lmdbFileName = '/tmp/test-lmdb-db'
@@ -97,7 +101,7 @@ class TestLMDB(DNSDistTest):
             (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertFalse(receivedQuery)
             self.assertTrue(receivedResponse)
-            self.assertEquals(expectedResponse, receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
 
     def testLMDBQNamePlusTagLookup(self):
         """
@@ -120,7 +124,7 @@ class TestLMDB(DNSDistTest):
             (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertFalse(receivedQuery)
             self.assertTrue(receivedResponse)
-            self.assertEquals(expectedResponse, receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
 
     def testLMDBSuffixLookup(self):
         """
@@ -143,7 +147,7 @@ class TestLMDB(DNSDistTest):
             (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertFalse(receivedQuery)
             self.assertTrue(receivedResponse)
-            self.assertEquals(expectedResponse, receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
 
     def testLMDBQNamePlainText(self):
         """
@@ -166,7 +170,7 @@ class TestLMDB(DNSDistTest):
             (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertFalse(receivedQuery)
             self.assertTrue(receivedResponse)
-            self.assertEquals(expectedResponse, receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
 
     def testLMDBKeyValueStoreLookupRule(self):
         """
@@ -189,4 +193,120 @@ class TestLMDB(DNSDistTest):
             (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertFalse(receivedQuery)
             self.assertTrue(receivedResponse)
-            self.assertEquals(expectedResponse, receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+class TestLMDBIPInRange(DNSDistTest):
+
+    _lmdbFileName = '/tmp/test-lmdb-range-1-db'
+    _lmdbDBName = 'db-name'
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+
+    kvs = newLMDBKVStore('%s', '%s')
+
+    -- KVS range lookups follow
+    -- does a range lookup in the LMDB database using the source IP as key
+    addAction(KeyValueStoreRangeLookupRule(kvs, KeyValueLookupKeySourceIP(32, 128, true)), SpoofAction('5.6.7.8'))
+
+    -- otherwise, spoof a different response
+    addAction(AllRule(), SpoofAction('9.9.9.9'))
+    """
+    _config_params = ['_testServerPort', '_lmdbFileName', '_lmdbDBName']
+
+    @classmethod
+    def setUpLMDB(cls):
+        env = lmdb.open(cls._lmdbFileName, map_size=1014*1024, max_dbs=1024, subdir=False)
+        db = env.open_db(key=cls._lmdbDBName.encode())
+        with env.begin(db=db, write=True) as txn:
+            txn.put(socket.inet_aton('127.255.255.255') + struct.pack("!H", 255), socket.inet_aton('127.0.0.0') + struct.pack("!H", 0) + b'this is the value of the source address tag')
+
+    @classmethod
+    def setUpClass(cls):
+
+        cls.setUpLMDB()
+        cls.startResponders()
+        cls.startDNSDist()
+        cls.setUpSockets()
+
+        print("Launching tests..")
+
+    def testLMDBSource(self):
+        """
+        LMDB range: Match on source address
+        """
+        name = 'source-ip.lmdb-range.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '5.6.7.8')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertFalse(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+class TestLMDBIPNotInRange(DNSDistTest):
+
+    _lmdbFileName = '/tmp/test-lmdb-range-2-db'
+    _lmdbDBName = 'db-name'
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+
+    kvs = newLMDBKVStore('%s', '%s')
+
+    -- KVS range lookups follow
+    -- does a range lookup in the LMDB database using the source IP as key
+    addAction(KeyValueStoreRangeLookupRule(kvs, KeyValueLookupKeySourceIP(32, 128, true)), SpoofAction('5.6.7.8'))
+
+    -- otherwise, spoof a different response
+    addAction(AllRule(), SpoofAction('9.9.9.9'))
+    """
+    _config_params = ['_testServerPort', '_lmdbFileName', '_lmdbDBName']
+
+    @classmethod
+    def setUpLMDB(cls):
+        env = lmdb.open(cls._lmdbFileName, map_size=1014*1024, max_dbs=1024, subdir=False)
+        db = env.open_db(key=cls._lmdbDBName.encode())
+        with env.begin(db=db, write=True) as txn:
+            txn.put(socket.inet_aton('127.0.0.0') + struct.pack("!H", 255), socket.inet_aton('127.0.0.0') + struct.pack("!H", 0) + b'this is the value of the source address tag')
+
+    @classmethod
+    def setUpClass(cls):
+
+        cls.setUpLMDB()
+        cls.startResponders()
+        cls.startDNSDist()
+        cls.setUpSockets()
+
+        print("Launching tests..")
+
+    def testLMDBSource(self):
+        """
+        LMDB not in range: Match on source address
+        """
+        name = 'source-ip.lmdb-not-in-range.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '9.9.9.9')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertFalse(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)