self.assertRcodeEqual(res, dns.rcode.NOERROR)
self.assertEqual(res.answer, response.answer)
- def _getCounter(self):
+ def _getCounter(self, tcp=False):
"""
Helper function for shared/non-shared testing
"""
query = dns.message.make_query(name, 'TXT')
responses = []
+ sender = self.sendTCPQuery if tcp else self.sendUDPQuery
+
for i in range(50):
- res = self.sendUDPQuery(query)
+ res = sender(query)
responses.append(res.answer[0][0])
return(responses)
Test non-shared behaviour
"""
- res = set(self._getCounter())
+ resUDP = set(self._getCounter(tcp=False))
+ resTCP = set(self._getCounter(tcp=True))
- self.assertEqual(len(res), 1)
+ self.assertEqual(len(resUDP), 1)
+ self.assertEqual(len(resTCP), 1)
class TestLuaRecordsShared(TestLuaRecords):
_config_template = """
Test shared behaviour
"""
- res = set(self._getCounter())
+ resUDP = set(self._getCounter(tcp=False))
+ resTCP = set(self._getCounter(tcp=True))
- self.assertEqual(len(res), 50)
+ self.assertEqual(len(resUDP), 50)
+ self.assertEqual(len(resTCP), 50)
if __name__ == '__main__':
unittest.main()