]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
selftest: merge DNSTest boilerplate
authorAndrew Bartlett <abartlet@samba.org>
Thu, 8 Jun 2017 22:00:09 +0000 (10:00 +1200)
committerAndrew Bartlett <abartlet@samba.org>
Sat, 10 Jun 2017 19:48:21 +0000 (21:48 +0200)
This will help unifying dns.py and dns_tkey.py to use common subclasses

The code was originally copied, but has since divereged.  This handles
that divergence.

Signed-off-by: Andrew Bartlett <abartlet@samba.org>
Reviewed-by: Garming Sam <garming@catalyst.net.nz>
python/samba/tests/dns.py
python/samba/tests/dns_tkey.py

index 50d7fb60c2f18a273e2769a19ff638ee928b9593..ac82d5051aa50ff4ae11afe5a70d7d77d6d6571a 100644 (file)
@@ -63,12 +63,8 @@ creds.set_krb_forwardable(credentials.NO_KRB_FORWARDABLE)
 class DNSTest(TestCase):
 
     def setUp(self):
-        global server, server_ip, lp, creds
         super(DNSTest, self).setUp()
-        self.server = server_name
-        self.server_ip = server_ip
-        self.lp = lp
-        self.creds = creds
+        self.timeout = None
 
     def errstr(self, errcode):
         "Return a readable error code"
@@ -84,30 +80,42 @@ class DNSTest(TestCase):
             "NXRRSET",
             "NOTAUTH",
             "NOTZONE",
+            "0x0B",
+            "0x0C",
+            "0x0D",
+            "0x0E",
+            "0x0F",
+            "BADSIG",
+            "BADKEY"
         ]
 
         return string_codes[errcode]
 
+    def assert_rcode_equals(self, rcode, expected):
+        "Helper function to check return code"
+        self.assertEquals(rcode, expected, "Expected RCODE %s, got %s" %
+                          (self.errstr(expected), self.errstr(rcode)))
 
     def assert_dns_rcode_equals(self, packet, rcode):
         "Helper function to check return code"
         p_errcode = packet.operation & 0x000F
         self.assertEquals(p_errcode, rcode, "Expected RCODE %s, got %s" %
-                            (self.errstr(rcode), self.errstr(p_errcode)))
+                          (self.errstr(rcode), self.errstr(p_errcode)))
 
     def assert_dns_opcode_equals(self, packet, opcode):
         "Helper function to check opcode"
         p_opcode = packet.operation & 0x7800
         self.assertEquals(p_opcode, opcode, "Expected OPCODE %s, got %s" %
-                            (opcode, p_opcode))
+                          (opcode, p_opcode))
 
     def make_name_packet(self, opcode, qid=None):
         "Helper creating a dns.name_packet"
         p = dns.name_packet()
         if qid is None:
-            p.id = random.randint(0x0, 0xffff)
+            p.id = random.randint(0x0, 0xff00)
         p.operation = opcode
         p.questions = []
+        p.additional = []
         return p
 
     def finish_name_packet(self, packet, questions):
@@ -135,10 +143,12 @@ class DNSTest(TestCase):
         "Helper to get dns domain"
         return self.creds.get_realm().lower()
 
-    def dns_transaction_udp(self, packet, host=server_ip,
-                            dump=False, timeout=timeout):
+    def dns_transaction_udp(self, packet, host,
+                            dump=False, timeout=None):
         "send a DNS query and read the reply"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -146,19 +156,22 @@ class DNSTest(TestCase):
             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
             s.settimeout(timeout)
             s.connect((host, 53))
-            s.send(send_packet, 0)
+            s.sendall(send_packet, 0)
             recv_packet = s.recv(2048, 0)
             if dump:
                 print self.hexdump(recv_packet)
-            return ndr.ndr_unpack(dns.name_packet, recv_packet)
+            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
+            return (response, recv_packet)
         finally:
             if s is not None:
                 s.close()
 
-    def dns_transaction_tcp(self, packet, host=server_ip,
-                            dump=False, timeout=timeout):
-        "send a DNS query and read the reply"
+    def dns_transaction_tcp(self, packet, host,
+                            dump=False, timeout=None):
+        "send a DNS query and read the reply, also return the raw packet"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -168,14 +181,22 @@ class DNSTest(TestCase):
             s.connect((host, 53))
             tcp_packet = struct.pack('!H', len(send_packet))
             tcp_packet += send_packet
-            s.send(tcp_packet, 0)
+            s.sendall(tcp_packet)
+
             recv_packet = s.recv(0xffff + 2, 0)
             if dump:
                 print self.hexdump(recv_packet)
-            return ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
+            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])
+
         finally:
-                if s is not None:
-                    s.close()
+            if s is not None:
+                s.close()
+
+        # unpacking and packing again should produce same bytestream
+        my_packet = ndr.ndr_pack(response)
+        self.assertEquals(my_packet, recv_packet[2:])
+
+        return (response, recv_packet[2:])
 
     def make_txt_update(self, prefix, txt_array):
         p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
@@ -210,12 +231,21 @@ class DNSTest(TestCase):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assertEquals(response.ancount, 1)
         self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
 
+
 class TestSimpleQueries(DNSTest):
+    def setUp(self):
+        super(TestSimpleQueries, self).setUp()
+        global server, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
 
     def test_one_a_query(self):
         "create a query packet containing one query record"
@@ -228,7 +258,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -246,7 +276,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -264,7 +294,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_tcp(p)
+        (response, response_packet) = self.dns_transaction_tcp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -282,7 +312,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
@@ -296,7 +326,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
@@ -316,7 +346,7 @@ class TestSimpleQueries(DNSTest):
 
         self.finish_name_packet(p, questions)
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -336,7 +366,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
 
         num_answers = 1
         dc_ipv6 = os.getenv('SERVER_IPV6')
@@ -362,7 +392,7 @@ class TestSimpleQueries(DNSTest):
 
         self.finish_name_packet(p, questions)
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -381,7 +411,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         # We don't get SOA records for single hosts
@@ -399,7 +429,7 @@ class TestSimpleQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -407,6 +437,14 @@ class TestSimpleQueries(DNSTest):
 
 
 class TestDNSUpdates(DNSTest):
+    def setUp(self):
+        super(TestDNSUpdates, self).setUp()
+        global server, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
 
     def test_two_updates(self):
         "create two update requests"
@@ -423,7 +461,7 @@ class TestDNSUpdates(DNSTest):
 
         self.finish_name_packet(p, updates)
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -442,7 +480,7 @@ class TestDNSUpdates(DNSTest):
         updates.append(u)
 
         self.finish_name_packet(p, updates)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NOTIMP)
 
     def test_update_prereq_with_non_null_ttl(self):
@@ -469,7 +507,7 @@ class TestDNSUpdates(DNSTest):
         p.answers = prereqs
 
         try:
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_FORMERR)
         except socket.timeout:
             # Windows chooses not to respond to incorrectly formatted queries.
@@ -501,7 +539,7 @@ class TestDNSUpdates(DNSTest):
         p.ancount = len(prereqs)
         p.answers = prereqs
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)
 
     def test_update_prereq_nonexisting_name(self):
@@ -527,14 +565,14 @@ class TestDNSUpdates(DNSTest):
         p.ancount = len(prereqs)
         p.answers = prereqs
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXRRSET)
 
     def test_update_add_txt_record(self):
         "test adding records works"
         prefix, txt = 'textrec', ['"This is a test"']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
 
@@ -566,7 +604,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now check the record is around
@@ -576,7 +614,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now delete the record
@@ -602,7 +640,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # And finally check it's gone
@@ -613,7 +651,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
 
     def test_readd_record(self):
@@ -644,7 +682,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now check the record is around
@@ -654,7 +692,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now delete the record
@@ -680,7 +718,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # check it's gone
@@ -691,7 +729,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
 
         # recreate the record
@@ -717,7 +755,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         # Now check the record is around
@@ -727,7 +765,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
     def test_update_add_mx_record(self):
@@ -756,7 +794,7 @@ class TestDNSUpdates(DNSTest):
         p.nscount = len(updates)
         p.nsrecs = updates
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
         p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
@@ -767,7 +805,7 @@ class TestDNSUpdates(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assertEqual(response.ancount, 1)
         ans = response.answers[0]
@@ -795,22 +833,22 @@ class TestComplexQueries(DNSTest):
         updates = [r]
         p.nscount = 1
         p.nsrecs = updates
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
     def setUp(self):
         super(TestComplexQueries, self).setUp()
 
-    def tearDown(self):
-        super(TestComplexQueries, self).tearDown()
+        global server, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
 
     def test_one_a_query(self):
         "create a query packet containing one query record"
 
-        name = "cname_test.%s" % self.get_dns_domain()
-        rdata = "%s.%s" % (self.server, self.get_dns_domain())
-        self.make_dns_update(name, rdata, dns.DNS_QTYPE_CNAME)
-
         try:
 
             # Create the record
@@ -828,7 +866,7 @@ class TestComplexQueries(DNSTest):
             questions.append(q)
 
             self.finish_name_packet(p, questions)
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
             self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
             self.assertEquals(response.ancount, 2)
@@ -862,7 +900,7 @@ class TestComplexQueries(DNSTest):
             p.nscount = len(updates)
             p.nsrecs = updates
 
-            response = self.dns_transaction_udp(p)
+            (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
             self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
 
     def test_cname_two_chain(self):
@@ -880,7 +918,7 @@ class TestComplexQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 3)
@@ -921,7 +959,7 @@ class TestComplexQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
 
@@ -938,6 +976,14 @@ class TestComplexQueries(DNSTest):
         self.assertEquals(response.answers[1].rdata, name0)
 
 class TestInvalidQueries(DNSTest):
+    def setUp(self):
+        super(TestInvalidQueries, self).setUp()
+        global server, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
 
     def test_one_a_query(self):
         "send 0 bytes follows by create a query packet containing one query record"
@@ -960,7 +1006,7 @@ class TestInvalidQueries(DNSTest):
         questions.append(q)
 
         self.finish_name_packet(p, questions)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
@@ -1006,6 +1052,13 @@ class TestInvalidQueries(DNSTest):
 class TestZones(DNSTest):
     def setUp(self):
         super(TestZones, self).setUp()
+        global server, server_ip, lp, creds, timeout
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
+        self.timeout = timeout
+
         self.zone = "test.lan"
         self.rpc_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server_ip),
                                             self.lp, self.creds)
@@ -1056,21 +1109,21 @@ class TestZones(DNSTest):
         questions.append(q)
         self.finish_name_packet(p, questions)
 
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         # Windows returns OK while BIND logically seems to return NXDOMAIN
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
 
         self.create_zone(zone)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 1)
         self.assertEquals(response.answers[0].rr_type, dns.DNS_QTYPE_SOA)
 
         self.delete_zone(zone)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_NXDOMAIN)
         self.assert_dns_opcode_equals(response, dns.DNS_OPCODE_QUERY)
         self.assertEquals(response.ancount, 0)
@@ -1078,6 +1131,11 @@ class TestZones(DNSTest):
 class TestRPCRoundtrip(DNSTest):
     def setUp(self):
         super(TestRPCRoundtrip, self).setUp()
+        global server, server_ip, lp, creds
+        self.server = server_name
+        self.server_ip = server_ip
+        self.lp = lp
+        self.creds = creds
         self.rpc_conn = dnsserver.dnsserver("ncacn_ip_tcp:%s[sign]" % (self.server_ip),
                                             self.lp, self.creds)
 
@@ -1110,7 +1168,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'pad1textrec', ['"This is a test"', '', '']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1120,7 +1178,7 @@ class TestRPCRoundtrip(DNSTest):
 
         prefix, txt = 'pad2textrec', ['"This is a test"', '', '', 'more text']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1130,7 +1188,7 @@ class TestRPCRoundtrip(DNSTest):
 
         prefix, txt = 'pad3textrec', ['', '', '"This is a test"']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1209,7 +1267,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'nulltextrec', ['NULL\x00BYTE']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, ['NULL'])
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1219,7 +1277,7 @@ class TestRPCRoundtrip(DNSTest):
 
         prefix, txt = 'nulltextrec2', ['NULL\x00BYTE', 'NULL\x00BYTE']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, ['NULL', 'NULL'])
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1254,7 +1312,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'hextextrec', ['HIGH\xFFBYTE']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1289,7 +1347,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding records works"
         prefix, txt = 'slashtextrec', ['Th\\=is=is a test']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1329,7 +1387,7 @@ class TestRPCRoundtrip(DNSTest):
         prefix, txt = 'textrec2', ['"This is a test"',
                                    '"and this is a test, too"']
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
@@ -1368,7 +1426,7 @@ class TestRPCRoundtrip(DNSTest):
         "test adding two txt records works"
         prefix, txt = 'emptytextrec', []
         p = self.make_txt_update(prefix, txt)
-        response = self.dns_transaction_udp(p)
+        (response, response_packet) = self.dns_transaction_udp(p, host=server_ip)
         self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
         self.check_query_txt(prefix, txt)
         self.assertIsNotNone(dns_record_match(self.rpc_conn, self.server_ip,
index f424e079eb8360a1d8506e1884af9ac2f224bf5f..21f5d560dd29fa540080b80c774c223732d34106 100644 (file)
@@ -29,6 +29,7 @@ from samba import credentials
 from samba.dcerpc import dns, dnsp
 from samba.tests.subunitrun import SubunitOptions, TestProgram
 from samba import gensec, tests
+from samba.tests import TestCase
 
 parser = optparse.OptionParser("dns.py <server name> <server ip> [options]")
 sambaopts = options.SambaOptions(parser)
@@ -56,21 +57,11 @@ server_name = args[0]
 server_ip = args[1]
 
 
-class DNSTest(tests.TestCase):
+class DNSTest(TestCase):
+
     def setUp(self):
         super(DNSTest, self).setUp()
-        self.server = server_name
-        self.server_ip = server_ip
-        self.settings = {}
-        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
-        self.settings["target_hostname"] = self.server
-
-        self.creds = credentials.Credentials()
-        self.creds.guess(self.lp_ctx)
-        self.creds.set_username(tests.env_get_var_value('USERNAME'))
-        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
-        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
-        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
+        self.timeout = None
 
     def errstr(self, errcode):
         "Return a readable error code"
@@ -150,9 +141,11 @@ class DNSTest(tests.TestCase):
         return self.creds.get_realm().lower()
 
     def dns_transaction_udp(self, packet, host,
-                            dump=False, timeout=timeout):
+                            dump=False, timeout=None):
         "send a DNS query and read the reply"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -171,9 +164,11 @@ class DNSTest(tests.TestCase):
                 s.close()
 
     def dns_transaction_tcp(self, packet, host,
-                            dump=False, timeout=timeout):
+                            dump=False, timeout=None):
         "send a DNS query and read the reply, also return the raw packet"
         s = None
+        if timeout is None:
+            timeout = self.timeout
         try:
             send_packet = ndr.ndr_pack(packet)
             if dump:
@@ -200,6 +195,61 @@ class DNSTest(tests.TestCase):
 
         return (response, recv_packet[2:])
 
+    def make_txt_update(self, prefix, txt_array):
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        updates = []
+
+        name = self.get_dns_domain()
+        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
+        updates.append(u)
+        self.finish_name_packet(p, updates)
+
+        updates = []
+        r = dns.res_rec()
+        r.name = "%s.%s" % (prefix, self.get_dns_domain())
+        r.rr_type = dns.DNS_QTYPE_TXT
+        r.rr_class = dns.DNS_QCLASS_IN
+        r.ttl = 900
+        r.length = 0xffff
+        rdata = self.make_txt_record(txt_array)
+        r.rdata = rdata
+        updates.append(r)
+        p.nscount = len(updates)
+        p.nsrecs = updates
+
+        return p
+
+    def check_query_txt(self, prefix, txt_array):
+        name = "%s.%s" % (prefix, self.get_dns_domain())
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        questions = []
+
+        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
+        questions.append(q)
+
+        self.finish_name_packet(p, questions)
+        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
+        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
+        self.assertEquals(response.ancount, 1)
+        self.assertEquals(response.answers[0].rdata.txt.str, txt_array)
+
+
+class DNSTKeyTest(DNSTest):
+    def setUp(self):
+        super(DNSTKeyTest, self).setUp()
+        self.server = server_name
+        self.server_ip = server_ip
+        self.settings = {}
+        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
+        self.settings["target_hostname"] = self.server
+
+        self.creds = credentials.Credentials()
+        self.creds.guess(self.lp_ctx)
+        self.creds.set_username(tests.env_get_var_value('USERNAME'))
+        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
+        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
+        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()
+
     def tkey_trans(self):
         "Do a TKEY transaction and establish a gensec context"
 
@@ -410,7 +460,7 @@ class DNSTest(tests.TestCase):
         return p
 
 
-class TestDNSUpdates(DNSTest):
+class TestDNSUpdates(DNSTKeyTest):
     def test_tkey(self):
         "test DNS TKEY handshake"