(receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
self.assertTrue(receivedProxyPayload)
self.assertTrue(receivedDNSData)
+
+class QueryCounter:
+
+ def __init__(self, name):
+ self.name = name
+ self.qcnt = 0
+
+ def __call__(self):
+ return self.qcnt
+
+ def create_cb(self):
+ def callback(request):
+ self.qcnt += 1
+ response = dns.message.make_response(request)
+ response.set_rcode(dns.rcode.REFUSED)
+ return response.to_wire()
+ return callback
+
+class TestRestartCount(DNSDistTest):
+
+ _queryCounts = {}
+
+ _testServer1Port = pickAvailablePort()
+ _testServer2Port = pickAvailablePort()
+ _testServer3Port = pickAvailablePort()
+ _testServer4Port = pickAvailablePort()
+ _serverPorts = [_testServer1Port, _testServer2Port, _testServer3Port, _testServer4Port]
+ _config_params = ['_testServer1Port', '_testServer2Port', '_testServer3Port', '_testServer4Port']
+ _config_template = """
+ MaxRestart = 2
+ s0 = newServer{name="s0", address="127.0.0.1:%s"}
+ s0:setUp()
+ s0:addPool("pool0")
+ s1 = newServer{name="s1", address="127.0.0.1:%s"}
+ s1:setUp()
+ s1:addPool("pool1")
+ s2 = newServer{name="s2", address="127.0.0.1:%s"}
+ s2:setUp()
+ s2:addPool("pool2")
+ s3 = newServer{name="s3", address="127.0.0.1:%s"}
+ s3:setUp()
+ s3:addPool("pool3")
+ function makeQueryRestartable(dq) dq:setRestartable() return DNSAction.None end
+ addAction(AllRule(), LuaAction(makeQueryRestartable))
+ function restartQuery(dr)
+ if dr:getRestartCount() < MaxRestart then
+ dr.pool = "pool"..tostring(dr:getRestartCount() + 1)
+ dr:restart()
+ else
+ return DNSResponseAction.ServFail
+ end
+ return DNSResponseAction.None
+ end
+ addResponseAction(RCodeRule(DNSRCode.REFUSED), LuaResponseAction(restartQuery))
+ addAction(AllRule(), PoolAction("pool0"))
+ """
+ @classmethod
+ def startResponders(cls):
+ print("Launching responders..")
+
+ for i, name in enumerate(['s0', 's1', 's2', 's3']):
+ cls._queryCounts[name] = QueryCounter(name)
+ cb = cls._queryCounts[name].create_cb()
+ responder = threading.Thread(name=name, target=cls.UDPResponder, args=[cls._serverPorts[i], cls._toResponderQueue, cls._fromResponderQueue, False, cb])
+ responder.daemon = True
+ responder.start()
+
+ def testDefault(self):
+
+ numberOfQueries = 100
+ name = 'restart.count.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ expectedResponse = dns.message.make_response(query)
+ expectedResponse.set_rcode(dns.rcode.SERVFAIL)
+
+ # send 100 queries
+ for _ in range(numberOfQueries):
+ (_, receivedResponse) = self.sendUDPQuery(query, response=None, useQueue=False)
+ self.assertTrue(receivedResponse)
+ self.assertEqual(expectedResponse, receivedResponse)
+
+ # if restart count is correct, s0/s1/s2 would get all the queies while s3 would get none
+ self.assertEqual(self._queryCounts['s0'](), numberOfQueries)
+ self.assertEqual(self._queryCounts['s1'](), numberOfQueries)
+ self.assertEqual(self._queryCounts['s2'](), numberOfQueries)
+ self.assertEqual(self._queryCounts['s3'](), 0)