import lmdb
import os
import socket
+import struct
+
from dnsdisttests import DNSDistTest
@unittest.skipIf('SKIP_LMDB_TESTS' in os.environ, 'LMDB tests are disabled')
(receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertFalse(receivedQuery)
self.assertTrue(receivedResponse)
- self.assertEquals(expectedResponse, receivedResponse)
+ self.assertEqual(expectedResponse, receivedResponse)
def testLMDBQNamePlusTagLookup(self):
"""
(receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertFalse(receivedQuery)
self.assertTrue(receivedResponse)
- self.assertEquals(expectedResponse, receivedResponse)
+ self.assertEqual(expectedResponse, receivedResponse)
def testLMDBSuffixLookup(self):
"""
(receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertFalse(receivedQuery)
self.assertTrue(receivedResponse)
- self.assertEquals(expectedResponse, receivedResponse)
+ self.assertEqual(expectedResponse, receivedResponse)
def testLMDBQNamePlainText(self):
"""
(receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertFalse(receivedQuery)
self.assertTrue(receivedResponse)
- self.assertEquals(expectedResponse, receivedResponse)
+ self.assertEqual(expectedResponse, receivedResponse)
def testLMDBKeyValueStoreLookupRule(self):
"""
(receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
self.assertFalse(receivedQuery)
self.assertTrue(receivedResponse)
- self.assertEquals(expectedResponse, receivedResponse)
+ self.assertEqual(expectedResponse, receivedResponse)
+
+class TestLMDBIPInRange(DNSDistTest):
+
+ _lmdbFileName = '/tmp/test-lmdb-range-1-db'
+ _lmdbDBName = 'db-name'
+ _config_template = """
+ newServer{address="127.0.0.1:%d"}
+
+ kvs = newLMDBKVStore('%s', '%s')
+
+ -- KVS range lookups follow
+ -- does a range lookup in the LMDB database using the source IP as key
+ addAction(KeyValueStoreRangeLookupRule(kvs, KeyValueLookupKeySourceIP(32, 128, true)), SpoofAction('5.6.7.8'))
+
+ -- otherwise, spoof a different response
+ addAction(AllRule(), SpoofAction('9.9.9.9'))
+ """
+ _config_params = ['_testServerPort', '_lmdbFileName', '_lmdbDBName']
+
+ @classmethod
+ def setUpLMDB(cls):
+ env = lmdb.open(cls._lmdbFileName, map_size=1014*1024, max_dbs=1024, subdir=False)
+ db = env.open_db(key=cls._lmdbDBName.encode())
+ with env.begin(db=db, write=True) as txn:
+ txn.put(socket.inet_aton('127.255.255.255') + struct.pack("!H", 255), socket.inet_aton('127.0.0.0') + struct.pack("!H", 0) + b'this is the value of the source address tag')
+
+ @classmethod
+ def setUpClass(cls):
+
+ cls.setUpLMDB()
+ cls.startResponders()
+ cls.startDNSDist()
+ cls.setUpSockets()
+
+ print("Launching tests..")
+
+ def testLMDBSource(self):
+ """
+ LMDB range: Match on source address
+ """
+ name = 'source-ip.lmdb-range.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ # dnsdist set RA = RD for spoofed responses
+ query.flags &= ~dns.flags.RD
+ expectedResponse = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '5.6.7.8')
+ expectedResponse.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertFalse(receivedQuery)
+ self.assertTrue(receivedResponse)
+ self.assertEqual(expectedResponse, receivedResponse)
+
+class TestLMDBIPNotInRange(DNSDistTest):
+
+ _lmdbFileName = '/tmp/test-lmdb-range-2-db'
+ _lmdbDBName = 'db-name'
+ _config_template = """
+ newServer{address="127.0.0.1:%d"}
+
+ kvs = newLMDBKVStore('%s', '%s')
+
+ -- KVS range lookups follow
+ -- does a range lookup in the LMDB database using the source IP as key
+ addAction(KeyValueStoreRangeLookupRule(kvs, KeyValueLookupKeySourceIP(32, 128, true)), SpoofAction('5.6.7.8'))
+
+ -- otherwise, spoof a different response
+ addAction(AllRule(), SpoofAction('9.9.9.9'))
+ """
+ _config_params = ['_testServerPort', '_lmdbFileName', '_lmdbDBName']
+
+ @classmethod
+ def setUpLMDB(cls):
+ env = lmdb.open(cls._lmdbFileName, map_size=1014*1024, max_dbs=1024, subdir=False)
+ db = env.open_db(key=cls._lmdbDBName.encode())
+ with env.begin(db=db, write=True) as txn:
+ txn.put(socket.inet_aton('127.0.0.0') + struct.pack("!H", 255), socket.inet_aton('127.0.0.0') + struct.pack("!H", 0) + b'this is the value of the source address tag')
+
+ @classmethod
+ def setUpClass(cls):
+
+ cls.setUpLMDB()
+ cls.startResponders()
+ cls.startDNSDist()
+ cls.setUpSockets()
+
+ print("Launching tests..")
+
+ def testLMDBSource(self):
+ """
+ LMDB not in range: Match on source address
+ """
+ name = 'source-ip.lmdb-not-in-range.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ # dnsdist set RA = RD for spoofed responses
+ query.flags &= ~dns.flags.RD
+ expectedResponse = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '9.9.9.9')
+ expectedResponse.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertFalse(receivedQuery)
+ self.assertTrue(receivedResponse)
+ self.assertEqual(expectedResponse, receivedResponse)