import unittest
import dns
import lmdb
+import os
import socket
from dnsdisttests import DNSDistTest
+@unittest.skipIf('SKIP_LMDB_TESTS' in os.environ, 'LMDB tests are disabled')
class TestLMDB(DNSDistTest):
_lmdbFileName = '/tmp/test-lmdb-db'
kvs = newLMDBKVStore('%s', '%s')
-- KVS lookups follow
+ -- if the qname is 'kvs-rule.lmdb.tests.powerdns.com.', does a lookup in the LMDB database using the qname as key, and spoof an answer if it matches
+ addAction(AndRule{QNameRule('kvs-rule.lmdb.tests.powerdns.com.'), KeyValueStoreLookupRule(kvs, KeyValueLookupKeyQName(false))}, SpoofAction('13.14.15.16'))
+
-- does a lookup in the LMDB database using the source IP as key, and store the result into the 'kvs-sourceip-result' tag
addAction(AllRule(), KeyValueStoreLookupAction(kvs, KeyValueLookupKeySourceIP(), 'kvs-sourceip-result'))
txn.put(socket.inet_aton('127.0.0.1'), b'this is the value of the source address tag')
txn.put(b'this is the value of the qname tag', b'this is the value of the second tag')
txn.put(b'\x06suffix\x04lmdb\x05tests\x08powerdns\x03com\x00', b'this is the value of the suffix tag')
- txn.put(b'qname-plaintext.lmdb.tests.powerdns.com.', b'this is the value of the plaintext tag')
+ txn.put(b'qname-plaintext.lmdb.tests.powerdns.com', b'this is the value of the plaintext tag')
+ txn.put(b'kvs-rule.lmdb.tests.powerdns.com', b'the value does not matter')
@classmethod
def setUpClass(cls):
self.assertFalse(receivedQuery)
self.assertTrue(receivedResponse)
self.assertEquals(expectedResponse, receivedResponse)
+
+ def testLMDBKeyValueStoreLookupRule(self):
+ """
+ LMDB: KeyValueStoreLookupRule
+ """
+ name = 'kvs-rule.lmdb.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ # dnsdist set RA = RD for spoofed responses
+ query.flags &= ~dns.flags.RD
+ expectedResponse = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '13.14.15.16')
+ expectedResponse.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response=None, useQueue=False)
+ self.assertFalse(receivedQuery)
+ self.assertTrue(receivedResponse)
+ self.assertEquals(expectedResponse, receivedResponse)