]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.recursor-dnssec/test_RPZ.py
Merge pull request #14200 from rgacogne/auth-enable-leak-detection-unit-tests
[thirdparty/pdns.git] / regression-tests.recursor-dnssec / test_RPZ.py
index 5d3121df37759613d10950e87d70a924b70de113..ca7292d39c750292665df1216081247b490ce307 100644 (file)
@@ -17,7 +17,7 @@ class RPZServer(object):
         self._targetSerial = 1
         self._serverPort = port
         listener = threading.Thread(name='RPZ Listener', target=self._listener, args=[])
-        listener.setDaemon(True)
+        listener.daemon = True
         listener.start()
 
     def getCurrentSerial(self):
@@ -185,7 +185,12 @@ class RPZServer(object):
                 break
 
             wire = answer.to_wire()
-            conn.send(struct.pack("!H", len(wire)))
+            lenprefix = struct.pack("!H", len(wire))
+
+            for b in lenprefix:
+                conn.send(bytes([b]))
+                time.sleep(0.5)
+
             conn.send(wire)
             self._currentSerial = serial
             break
@@ -208,7 +213,7 @@ class RPZServer(object):
                 thread = threading.Thread(name='RPZ Connection Handler',
                                       target=self._connectionHandler,
                                       args=[conn])
-                thread.setDaemon(True)
+                thread.daemon = True
                 thread.start()
 
             except socket.error as e:
@@ -248,7 +253,28 @@ api-key=%s
 log-rpz-changes=yes
 """ % (_confdir, _wsPort, _wsPassword, _apiKey)
 
-    def checkBlocked(self, name, shouldBeBlocked=True, adQuery=False, singleCheck=False):
+    def sendNotify(self):
+        notify = dns.message.make_query('zone.rpz', 'SOA', want_dnssec=False)
+        notify.set_opcode(4) # notify
+        res = self.sendUDPQuery(notify)
+        self.assertRcodeEqual(res, dns.rcode.NOERROR)
+        self.assertEqual(res.opcode(), 4)
+        self.assertEqual(res.question[0].to_text(), 'zone.rpz. IN SOA')
+
+    def assertAdditionalHasSOA(self, msg):
+        if not isinstance(msg, dns.message.Message):
+            raise TypeError("msg is not a dns.message.Message but a %s" % type(msg))
+
+        found = False
+        for rrset in msg.additional:
+            if rrset.rdtype == dns.rdatatype.SOA:
+                found = True
+                break
+
+        if not found:
+            raise AssertionError("No SOA record found in the authority section:\n%s" % msg.to_text())
+
+    def checkBlocked(self, name, shouldBeBlocked=True, adQuery=False, singleCheck=False, soa=False):
         query = dns.message.make_query(name, 'A', want_dnssec=True)
         query.flags |= dns.flags.CD
         if adQuery:
@@ -264,13 +290,15 @@ log-rpz-changes=yes
                 expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.42')
 
             self.assertRRsetInAnswer(res, expected)
+            if soa:
+                self.assertAdditionalHasSOA(res)
             if singleCheck:
                 break
 
     def checkNotBlocked(self, name, adQuery=False, singleCheck=False):
         self.checkBlocked(name, False, adQuery, singleCheck)
 
-    def checkCustom(self, qname, qtype, expected):
+    def checkCustom(self, qname, qtype, expected, soa=False):
         query = dns.message.make_query(qname, qtype, want_dnssec=True)
         query.flags |= dns.flags.CD
         for method in ("sendUDPQuery", "sendTCPQuery"):
@@ -278,8 +306,10 @@ log-rpz-changes=yes
             res = sender(query)
             self.assertRcodeEqual(res, dns.rcode.NOERROR)
             self.assertRRsetInAnswer(res, expected)
+            if soa:
+                self.assertAdditionalHasSOA(res)
 
-    def checkNoData(self, qname, qtype):
+    def checkNoData(self, qname, qtype, soa=False):
         query = dns.message.make_query(qname, qtype, want_dnssec=True)
         query.flags |= dns.flags.CD
         for method in ("sendUDPQuery", "sendTCPQuery"):
@@ -287,6 +317,8 @@ log-rpz-changes=yes
             res = sender(query)
             self.assertRcodeEqual(res, dns.rcode.NOERROR)
             self.assertEqual(len(res.answer), 0)
+            if soa:
+                self.assertAdditionalHasSOA(res)
 
     def checkNXD(self, qname, qtype='A'):
         query = dns.message.make_query(qname, qtype, want_dnssec=True)
@@ -298,7 +330,7 @@ log-rpz-changes=yes
             self.assertEqual(len(res.answer), 0)
             self.assertEqual(len(res.authority), 1)
 
-    def checkTruncated(self, qname, qtype='A'):
+    def checkTruncated(self, qname, qtype='A', soa=False):
         query = dns.message.make_query(qname, qtype, want_dnssec=True)
         query.flags |= dns.flags.CD
         res = self.sendUDPQuery(query)
@@ -306,7 +338,8 @@ log-rpz-changes=yes
         self.assertMessageHasFlags(res, ['QR', 'RA', 'RD', 'CD', 'TC'])
         self.assertEqual(len(res.answer), 0)
         self.assertEqual(len(res.authority), 0)
-        self.assertEqual(len(res.additional), 0)
+        if soa:
+            self.assertAdditionalHasSOA(res)
 
         res = self.sendTCPQuery(query)
         self.assertRcodeEqual(res, dns.rcode.NXDOMAIN)
@@ -328,7 +361,7 @@ log-rpz-changes=yes
         url = 'http://127.0.0.1:' + str(self._wsPort) + '/api/v1/servers/localhost/rpzstatistics'
         r = requests.get(url, headers=headers, timeout=self._wsTimeout)
         self.assertTrue(r)
-        self.assertEquals(r.status_code, 200)
+        self.assertEqual(r.status_code, 200)
         self.assertTrue(r.json())
         content = r.json()
         self.assertIn('zone.rpz.', content)
@@ -336,10 +369,10 @@ log-rpz-changes=yes
         for key in ['last_update', 'records', 'serial', 'transfers_failed', 'transfers_full', 'transfers_success']:
             self.assertIn(key, zone)
 
-        self.assertEquals(zone['serial'], serial)
-        self.assertEquals(zone['records'], recordsCount)
-        self.assertEquals(zone['transfers_full'], fullXFRCount)
-        self.assertEquals(zone['transfers_success'], totalXFRCount)
+        self.assertEqual(zone['serial'], serial)
+        self.assertEqual(zone['records'], recordsCount)
+        self.assertEqual(zone['transfers_full'], fullXFRCount)
+        self.assertEqual(zone['transfers_success'], totalXFRCount)
 
 rpzServerPort = 4250
 rpzServer = RPZServer(rpzServerPort)
@@ -352,7 +385,7 @@ class RPZXFRRecursorTest(RPZRecursorTest):
     global rpzServerPort
     _lua_config_file = """
     -- The first server is a bogus one, to test that we correctly fail over to the second one
-    rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1 })
+    rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1, includeSOA=true})
     """ % (rpzServerPort)
     _confdir = 'RPZXFR'
     _wsPort = 8042
@@ -366,6 +399,9 @@ webserver-port=%d
 webserver-address=127.0.0.1
 webserver-password=%s
 api-key=%s
+disable-packetcache
+allow-notify-from=127.0.0.0/8
+allow-notify-for=zone.rpz
 """ % (_confdir, _wsPort, _wsPassword, _apiKey)
     _xfrDone = 0
 
@@ -403,72 +439,81 @@ e 3600 IN A 192.0.2.42
         raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))
 
     def testRPZ(self):
+        # Fresh RPZ does not need a notify
+        self.waitForTCPSocket("127.0.0.1", self._wsPort)
         # first zone, only a should be blocked
         self.waitUntilCorrectSerialIsLoaded(1)
         self.checkRPZStats(1, 1, 1, self._xfrDone)
-        self.checkBlocked('a.example.')
+        self.checkBlocked('a.example.', soa=True)
         self.checkNotBlocked('b.example.')
         self.checkNotBlocked('c.example.')
 
         # second zone, a and b should be blocked
+        self.sendNotify()
         self.waitUntilCorrectSerialIsLoaded(2)
         self.checkRPZStats(2, 2, 1, self._xfrDone)
-        self.checkBlocked('a.example.')
-        self.checkBlocked('b.example.')
+        self.checkBlocked('a.example.', soa=True)
+        self.checkBlocked('b.example.', soa=True)
         self.checkNotBlocked('c.example.')
 
         # third zone, only b should be blocked
+        self.sendNotify()
         self.waitUntilCorrectSerialIsLoaded(3)
         self.checkRPZStats(3, 1, 1, self._xfrDone)
         self.checkNotBlocked('a.example.')
-        self.checkBlocked('b.example.')
+        self.checkBlocked('b.example.', soa=True)
         self.checkNotBlocked('c.example.')
 
         # fourth zone, only c should be blocked
+        self.sendNotify()
         self.waitUntilCorrectSerialIsLoaded(4)
         self.checkRPZStats(4, 1, 1, self._xfrDone)
         self.checkNotBlocked('a.example.')
         self.checkNotBlocked('b.example.')
-        self.checkBlocked('c.example.')
+        self.checkBlocked('c.example.', soa=True)
 
         # fifth zone, we should get a full AXFR this time, and only d should be blocked
+        self.sendNotify()
         self.waitUntilCorrectSerialIsLoaded(5)
         self.checkRPZStats(5, 3, 2, self._xfrDone)
         self.checkNotBlocked('a.example.')
         self.checkNotBlocked('b.example.')
         self.checkNotBlocked('c.example.')
-        self.checkBlocked('d.example.')
+        self.checkBlocked('d.example.', soa=True)
 
         # sixth zone, only e should be blocked, f is a local data record
+        self.sendNotify()
         self.waitUntilCorrectSerialIsLoaded(6)
         self.checkRPZStats(6, 2, 2, self._xfrDone)
         self.checkNotBlocked('a.example.')
         self.checkNotBlocked('b.example.')
         self.checkNotBlocked('c.example.')
         self.checkNotBlocked('d.example.')
-        self.checkCustom('e.example.', 'A', dns.rrset.from_text('e.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.1', '192.0.2.2'))
+        self.checkCustom('e.example.', 'A', dns.rrset.from_text('e.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.1', '192.0.2.2'), soa=True)
         self.checkCustom('e.example.', 'MX', dns.rrset.from_text('e.example.', 0, dns.rdataclass.IN, 'MX', '10 mx.example.'))
-        self.checkNoData('e.example.', 'AAAA')
-        self.checkCustom('f.example.', 'A', dns.rrset.from_text('f.example.', 0, dns.rdataclass.IN, 'CNAME', 'e.example.'))
+        self.checkNoData('e.example.', 'AAAA', soa=True)
+        self.checkCustom('f.example.', 'A', dns.rrset.from_text('f.example.', 0, dns.rdataclass.IN, 'CNAME', 'e.example.'), soa=True)
 
         # seventh zone, e should only have one A
+        self.sendNotify()
         self.waitUntilCorrectSerialIsLoaded(7)
         self.checkRPZStats(7, 4, 2, self._xfrDone)
         self.checkNotBlocked('a.example.')
         self.checkNotBlocked('b.example.')
         self.checkNotBlocked('c.example.')
         self.checkNotBlocked('d.example.')
-        self.checkCustom('e.example.', 'A', dns.rrset.from_text('e.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.2'))
-        self.checkCustom('e.example.', 'MX', dns.rrset.from_text('e.example.', 0, dns.rdataclass.IN, 'MX', '10 mx.example.'))
-        self.checkNoData('e.example.', 'AAAA')
-        self.checkCustom('f.example.', 'A', dns.rrset.from_text('f.example.', 0, dns.rdataclass.IN, 'CNAME', 'e.example.'))
+        self.checkCustom('e.example.', 'A', dns.rrset.from_text('e.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.2'), soa=True)
+        self.checkCustom('e.example.', 'MX', dns.rrset.from_text('e.example.', 0, dns.rdataclass.IN, 'MX', '10 mx.example.'), soa=True)
+        self.checkNoData('e.example.', 'AAAA', soa=True)
+        self.checkCustom('f.example.', 'A', dns.rrset.from_text('f.example.', 0, dns.rdataclass.IN, 'CNAME', 'e.example.'), soa=True)
         # check that the policy is disabled for AD=1 queries
         self.checkNotBlocked('e.example.', True)
         # check non-custom policies
-        self.checkTruncated('tc.example.')
+        self.checkTruncated('tc.example.', soa=True)
         self.checkDropped('drop.example.')
 
         # eighth zone, all entries should be gone
+        self.sendNotify()
         self.waitUntilCorrectSerialIsLoaded(8)
         self.checkRPZStats(8, 0, 3, self._xfrDone)
         self.checkNotBlocked('a.example.')
@@ -483,7 +528,9 @@ e 3600 IN A 192.0.2.42
         # 9th zone is a duplicate, it might get skipped
         global rpzServer
         rpzServer.moveToSerial(9)
+        self.sendNotify()
         time.sleep(3)
+        self.sendNotify()
         self.waitUntilCorrectSerialIsLoaded(10)
         self.checkRPZStats(10, 1, 4, self._xfrDone)
         self.checkNotBlocked('a.example.')
@@ -491,13 +538,15 @@ e 3600 IN A 192.0.2.42
         self.checkNotBlocked('c.example.')
         self.checkNotBlocked('d.example.')
         self.checkNotBlocked('e.example.')
-        self.checkBlocked('f.example.')
+        self.checkBlocked('f.example.', soa=True)
         self.checkNXD('tc.example.')
         self.checkNXD('drop.example.')
 
         # the next update will update the zone twice
         rpzServer.moveToSerial(11)
+        self.sendNotify()
         time.sleep(3)
+        self.sendNotify()
         self.waitUntilCorrectSerialIsLoaded(12)
         self.checkRPZStats(12, 1, 4, self._xfrDone)
         self.checkNotBlocked('a.example.')
@@ -506,7 +555,7 @@ e 3600 IN A 192.0.2.42
         self.checkNotBlocked('d.example.')
         self.checkNotBlocked('e.example.')
         self.checkNXD('f.example.')
-        self.checkBlocked('g.example.')
+        self.checkBlocked('g.example.', soa=True)
         self.checkNXD('tc.example.')
         self.checkNXD('drop.example.')
 
@@ -517,7 +566,7 @@ class RPZFileRecursorTest(RPZRecursorTest):
 
     _confdir = 'RPZFile'
     _lua_config_file = """
-    rpzFile('configs/%s/zone.rpz', { policyName="zone.rpz." })
+    rpzFile('configs/%s/zone.rpz', { policyName="zone.rpz.", includeSOA=true })
     """ % (_confdir)
     _config_template = """
 auth-zones=example=configs/%s/example.zone
@@ -553,7 +602,7 @@ tc.example.zone.rpz. 60 IN CNAME rpz-tcp-only.
     def testRPZ(self):
         self.checkCustom('a.example.', 'A', dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.42', '192.0.2.43'))
         self.checkCustom('a.example.', 'TXT', dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'TXT', '"some text"'))
-        self.checkBlocked('z.example.')
+        self.checkBlocked('z.example.', soa=True)
         self.checkNotBlocked('b.example.')
         self.checkNotBlocked('c.example.')
         self.checkNotBlocked('d.example.')
@@ -561,7 +610,7 @@ tc.example.zone.rpz. 60 IN CNAME rpz-tcp-only.
         # check that the policy is disabled for AD=1 queries
         self.checkNotBlocked('z.example.', True)
         # check non-custom policies
-        self.checkTruncated('tc.example.')
+        self.checkTruncated('tc.example.', soa=True)
         self.checkDropped('drop.example.')
 
 class RPZFileDefaultPolRecursorTest(RPZRecursorTest):
@@ -677,7 +726,7 @@ class RPZSimpleAuthServer(object):
     def __init__(self, port):
         self._serverPort = port
         listener = threading.Thread(name='RPZ Simple Auth Listener', target=self._listener, args=[])
-        listener.setDaemon(True)
+        listener.daemon = True
         listener.start()
 
     def _getAnswer(self, message):
@@ -953,7 +1002,7 @@ class RPZCNameChainCustomTest(RPZRecursorTest):
                 sender = getattr(self, method)
                 res = sender(query)
                 self.assertRcodeEqual(res, dns.rcode.NXDOMAIN)
-                self.assertEquals(len(res.answer), 0)
+                self.assertEqual(len(res.answer), 0)
 
     def testRPZChainNODATA(self):
         # we should match the A at the end of the CNAME chain and
@@ -967,7 +1016,7 @@ class RPZCNameChainCustomTest(RPZRecursorTest):
                 sender = getattr(self, method)
                 res = sender(query)
                 self.assertRcodeEqual(res, dns.rcode.NOERROR)
-                self.assertEquals(len(res.answer), 0)
+                self.assertEqual(len(res.answer), 0)
 
     def testRPZChainCustom(self):
         # we should match the A at the end of the CNAME chain and
@@ -982,6 +1031,86 @@ class RPZCNameChainCustomTest(RPZRecursorTest):
                 res = sender(query)
                 self.assertRcodeEqual(res, dns.rcode.NOERROR)
                 # the original CNAME record is signed
-                self.assertEquals(len(res.answer), 3)
+                self.assertEqual(len(res.answer), 3)
                 self.assertRRsetInAnswer(res, dns.rrset.from_text('cname-custom-a.example.', 0, dns.rdataclass.IN, 'CNAME', 'cname-custom-a-target.example.'))
                 self.assertRRsetInAnswer(res, dns.rrset.from_text('cname-custom-a-target.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.103'))
+
+class RPZFileModByLuaRecursorTest(RPZRecursorTest):
+    """
+    This test makes sure that we correctly load RPZ zones from a file while being modified by Lua callbacks
+    """
+
+    _confdir = 'RPZFileModByLua'
+    _lua_dns_script_file = """
+    function preresolve(dq)
+      if dq.qname:equal('zmod.example.') then
+        dq.appliedPolicy.policyKind = pdns.policykinds.Drop
+        return true
+      end
+      return false
+    end
+    function nxdomain(dq)
+      if dq.qname:equal('nxmod.example.') then
+        dq.appliedPolicy.policyKind = pdns.policykinds.Drop
+        return true
+      end
+      return false
+    end
+    function nodata(dq)
+      print("NODATA")
+      if dq.qname:equal('nodatamod.example.') then
+        dq.appliedPolicy.policyKind = pdns.policykinds.Drop
+        return true
+      end
+      return false
+    end
+    """
+    _lua_config_file = """
+    rpzFile('configs/%s/zone.rpz', { policyName="zone.rpz." })
+    """ % (_confdir)
+    _config_template = """
+auth-zones=example=configs/%s/example.zone
+""" % (_confdir)
+
+    @classmethod
+    def generateRecursorConfig(cls, confdir):
+        authzonepath = os.path.join(confdir, 'example.zone')
+        with open(authzonepath, 'w') as authzone:
+            authzone.write("""$ORIGIN example.
+@ 3600 IN SOA {soa}
+a 3600 IN A 192.0.2.42
+b 3600 IN A 192.0.2.42
+c 3600 IN A 192.0.2.42
+d 3600 IN A 192.0.2.42
+e 3600 IN A 192.0.2.42
+z 3600 IN A 192.0.2.42
+""".format(soa=cls._SOA))
+
+        rpzFilePath = os.path.join(confdir, 'zone.rpz')
+        with open(rpzFilePath, 'w') as rpzZone:
+            rpzZone.write("""$ORIGIN zone.rpz.
+@ 3600 IN SOA {soa}
+a.example.zone.rpz. 60 IN A 192.0.2.42
+a.example.zone.rpz. 60 IN A 192.0.2.43
+a.example.zone.rpz. 60 IN TXT "some text"
+drop.example.zone.rpz. 60 IN CNAME rpz-drop.
+zmod.example.zone.rpz. 60 IN A 192.0.2.1
+tc.example.zone.rpz. 60 IN CNAME rpz-tcp-only.
+nxmod.exmaple.zone.rpz. 60 in CNAME .
+nodatamod.example.zone.rpz. 60 in CNAME *.
+""".format(soa=cls._SOA))
+        super(RPZFileModByLuaRecursorTest, cls).generateRecursorConfig(confdir)
+
+    def testRPZ(self):
+        self.checkCustom('a.example.', 'A', dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'A', '192.0.2.42', '192.0.2.43'))
+        self.checkCustom('a.example.', 'TXT', dns.rrset.from_text('a.example.', 0, dns.rdataclass.IN, 'TXT', '"some text"'))
+        self.checkDropped('zmod.example.')
+        self.checkDropped('nxmod.example.')
+        self.checkDropped('nodatamod.example.')
+        self.checkNotBlocked('b.example.')
+        self.checkNotBlocked('c.example.')
+        self.checkNotBlocked('d.example.')
+        self.checkNotBlocked('e.example.')
+        # check non-custom policies
+        self.checkTruncated('tc.example.')
+        self.checkDropped('drop.example.')