]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_LMDB.py
Merge pull request #14001 from rgacogne/ddist-ffi-policy-no-server
[thirdparty/pdns.git] / regression-tests.dnsdist / test_LMDB.py
index 0fb1143e826d57136d0e033a495a9ec9b03d8b05..5e308d7d32bceb737829a713e38647520f602604 100644 (file)
@@ -4,6 +4,8 @@ import dns
 import lmdb
 import os
 import socket
+import struct
+
 from dnsdisttests import DNSDistTest
 
 @unittest.skipIf('SKIP_LMDB_TESTS' in os.environ, 'LMDB tests are disabled')
@@ -192,3 +194,119 @@ class TestLMDB(DNSDistTest):
             self.assertFalse(receivedQuery)
             self.assertTrue(receivedResponse)
             self.assertEqual(expectedResponse, receivedResponse)
+
+class TestLMDBIPInRange(DNSDistTest):
+
+    _lmdbFileName = '/tmp/test-lmdb-range-1-db'
+    _lmdbDBName = 'db-name'
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+
+    kvs = newLMDBKVStore('%s', '%s')
+
+    -- KVS range lookups follow
+    -- does a range lookup in the LMDB database using the source IP as key
+    addAction(KeyValueStoreRangeLookupRule(kvs, KeyValueLookupKeySourceIP(32, 128, true)), SpoofAction('5.6.7.8'))
+
+    -- otherwise, spoof a different response
+    addAction(AllRule(), SpoofAction('9.9.9.9'))
+    """
+    _config_params = ['_testServerPort', '_lmdbFileName', '_lmdbDBName']
+
+    @classmethod
+    def setUpLMDB(cls):
+        env = lmdb.open(cls._lmdbFileName, map_size=1014*1024, max_dbs=1024, subdir=False)
+        db = env.open_db(key=cls._lmdbDBName.encode())
+        with env.begin(db=db, write=True) as txn:
+            txn.put(socket.inet_aton('127.255.255.255') + struct.pack("!H", 255), socket.inet_aton('127.0.0.0') + struct.pack("!H", 0) + b'this is the value of the source address tag')
+
+    @classmethod
+    def setUpClass(cls):
+
+        cls.setUpLMDB()
+        cls.startResponders()
+        cls.startDNSDist()
+        cls.setUpSockets()
+
+        print("Launching tests..")
+
+    def testLMDBSource(self):
+        """
+        LMDB range: Match on source address
+        """
+        name = 'source-ip.lmdb-range.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '5.6.7.8')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertFalse(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)
+
+class TestLMDBIPNotInRange(DNSDistTest):
+
+    _lmdbFileName = '/tmp/test-lmdb-range-2-db'
+    _lmdbDBName = 'db-name'
+    _config_template = """
+    newServer{address="127.0.0.1:%d"}
+
+    kvs = newLMDBKVStore('%s', '%s')
+
+    -- KVS range lookups follow
+    -- does a range lookup in the LMDB database using the source IP as key
+    addAction(KeyValueStoreRangeLookupRule(kvs, KeyValueLookupKeySourceIP(32, 128, true)), SpoofAction('5.6.7.8'))
+
+    -- otherwise, spoof a different response
+    addAction(AllRule(), SpoofAction('9.9.9.9'))
+    """
+    _config_params = ['_testServerPort', '_lmdbFileName', '_lmdbDBName']
+
+    @classmethod
+    def setUpLMDB(cls):
+        env = lmdb.open(cls._lmdbFileName, map_size=1014*1024, max_dbs=1024, subdir=False)
+        db = env.open_db(key=cls._lmdbDBName.encode())
+        with env.begin(db=db, write=True) as txn:
+            txn.put(socket.inet_aton('127.0.0.0') + struct.pack("!H", 255), socket.inet_aton('127.0.0.0') + struct.pack("!H", 0) + b'this is the value of the source address tag')
+
+    @classmethod
+    def setUpClass(cls):
+
+        cls.setUpLMDB()
+        cls.startResponders()
+        cls.startDNSDist()
+        cls.setUpSockets()
+
+        print("Launching tests..")
+
+    def testLMDBSource(self):
+        """
+        LMDB not in range: Match on source address
+        """
+        name = 'source-ip.lmdb-not-in-range.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '9.9.9.9')
+        expectedResponse.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+            self.assertFalse(receivedQuery)
+            self.assertTrue(receivedResponse)
+            self.assertEqual(expectedResponse, receivedResponse)