]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
pytest: adjust dns_aging to handle some non-TXT records
authorDouglas Bagnall <douglas.bagnall@catalyst.net.nz>
Thu, 10 Jun 2021 23:29:15 +0000 (23:29 +0000)
committerAndrew Bartlett <abartlet@samba.org>
Sun, 20 Jun 2021 23:26:32 +0000 (23:26 +0000)
Signed-off-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
python/samba/tests/dns_aging.py

index deb43687c45f4b8696fc22f76abca593d36b266d..47d3ba7351f98633384888cc587727e9c1777f2f 100644 (file)
@@ -130,6 +130,12 @@ def txt_s_list(txt):
     return s_list
 
 
+def make_txt_record(txt):
+    r = dns.txt_record()
+    r.txt = txt_s_list(txt)
+    return r
+
+
 def copy_rec(rec):
     copy = dnsserver.DNS_RPC_RECORD()
     copy.wType = rec.wType
@@ -141,6 +147,15 @@ def copy_rec(rec):
     return copy
 
 
+def guess_wtype(data):
+    if isinstance(data, list):
+        data = make_txt_record(data)
+        return (data, dnsp.DNS_TYPE_TXT)
+    if ":" in data:
+        return (data, dnsp.DNS_TYPE_AAAA)
+    return (data, dnsp.DNS_TYPE_A)
+
+
 class TestDNSAging(DNSTest):
     """Probe DNS aging and scavenging, using LDAP and RPC to set and test
     the timestamps behind DNS's back."""
@@ -247,6 +262,70 @@ class TestDNSAging(DNSTest):
                 match = r
         return match
 
+    def get_unique_ip_record(self, name, addr, wtype=None):
+        """Get an A or AAAA record on name with the matching data."""
+        if wtype is None:
+            wtype = guess_wtype(addr)
+
+        recs = self.ldap_get_records(name)
+
+        # We need to use the internal dns_record_match because not all
+        # forms always match on strings (e.g. IPv6)
+        rec = dnsp.DnssrvRpcRecord()
+        rec.wType = wtype
+        rec.data = addr
+
+        match = None
+        for r in recs:
+            if dsdb_dns.records_match(r, rec):
+                self.assertIsNone(match)
+                match = r
+        return match
+
+    def dns_query(self, name, qtype=dns.DNS_QTYPE_ALL):
+        """make a query, which might help Windows notice LDAP changes"""
+        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
+        fullname = "%s.%s" % (name, self.zone)
+        q = self.make_name_question(fullname, qtype, dns.DNS_QCLASS_IN)
+        self.finish_name_packet(p, [q])
+        r, rp = self.dns_transaction_udp(p, host=SERVER_IP)
+
+        return r
+
+    def dns_update_non_text(self, name,
+                            data,
+                            wtype=None,
+                            qclass=dns.DNS_QCLASS_IN):
+        if wtype is None:
+            data, wtype = guess_wtype(data)
+
+        if qclass == dns.DNS_QCLASS_IN:
+            ttl = 123
+        else:
+            ttl = 0
+
+        fullname = "%s.%s" % (name, self.zone)
+        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
+        u = self.make_name_question(self.zone,
+                                    dns.DNS_QTYPE_SOA,
+                                    dns.DNS_QCLASS_IN)
+        self.finish_name_packet(p, [u])
+
+        r = dns.res_rec()
+        r.name = fullname
+        r.rr_type = wtype
+        r.rr_class = qclass
+        r.ttl = ttl
+        r.length = 0xffff
+        r.rdata = data
+
+        p.nscount = 1
+        p.nsrecs = [r]
+
+        (code, response) = self.dns_transaction_udp(p, host=SERVER_IP)
+        self.assert_dns_rcode_equals(code, dns.DNS_RCODE_OK)
+        return response
+
     def dns_update_record(self, name, txt, ttl=900):
         if isinstance(txt, str):
             txt = [txt]
@@ -362,21 +441,18 @@ class TestDNSAging(DNSTest):
             msg["dnsRecord"].set_flags(ldb.FLAG_MOD_ADD)
             self.samdb.add(msg)
 
-    def ldap_update_record(self, name, txt, **kwargs):
-        """Add the record that self.dns_update_record() would add, via ldap,
-        thus allowing us to set additional dnsRecord features like
-        dwTimestamp.
-        """
+    def ldap_update_core(self, name, wtype, data, **kwargs):
+        """This one is not TXT specific."""
         records = self.ldap_get_records(name)
 
         # default values
         rec = dnsp.DnssrvRpcRecord()
-        rec.wType = dnsp.DNS_TYPE_TXT
+        rec.wType = wtype
         rec.rank = dnsp.DNS_RANK_ZONE
         rec.dwTtlSeconds = 900
         rec.dwSerial = 110
         rec.dwTimeStamp = 0
-        rec.data = txt_s_list(txt)
+        rec.data = data
 
         # override defaults, as required
         for k, v in kwargs.items():
@@ -390,6 +466,17 @@ class TestDNSAging(DNSTest):
             records.append(rec)
 
         self.ldap_replace_records(name, records)
+        return rec
+
+    def ldap_update_record(self, name, txt, **kwargs):
+        """Add the record that self.dns_update_record() would add, via ldap,
+        thus allowing us to set additional dnsRecord features like
+        dwTimestamp.
+        """
+        rec = self.ldap_update_core(name,
+                                    dnsp.DNS_TYPE_TXT,
+                                    txt_s_list(txt),
+                                    **kwargs)
 
         recs = self.ldap_get_records(name)
         match = None
@@ -404,20 +491,42 @@ class TestDNSAging(DNSTest):
         self.assertEqual(match.dwTimeStamp, rec.dwTimeStamp)
         return match
 
-    def ldap_delete_record(self, name, txt):
+    def ldap_delete_record(self, name, data, wtype=dnsp.DNS_TYPE_TXT):
         rec = dnsp.DnssrvRpcRecord()
-        rec.wType = dnsp.DNS_TYPE_TXT
-        rec.data = txt_s_list(txt)
+        if wtype == dnsp.DNS_TYPE_TXT:
+            data = txt_s_list(data)
+
+        rec.wType = wtype
+        rec.data = data
         records = self.ldap_get_records(name)
         for i, r in enumerate(records[:]):
             if dsdb_dns.records_match(r, rec):
                 del records[i]
                 break
         else:
-            self.fail(f"record {txt} not found")
+            self.fail(f"record {data} not found")
 
         self.ldap_replace_records(name, records)
 
+    def add_ip_record(self, name, addr, wtype=None, **kwargs):
+        if wtype is None:
+            addr, wtype = guess_wtype(addr)
+        rec = self.ldap_update_core(name,
+                                    wtype,
+                                    addr,
+                                    **kwargs)
+
+        recs = self.ldap_get_records(name)
+        match = None
+        for r in recs:
+            if dsdb_dns.records_match(r, rec):
+                self.assertIsNone(match, f"duplicate records for {name}")
+                match = r
+        self.assertEqual(match.rank, rec.rank & 255)
+        self.assertEqual(match.dwTtlSeconds, rec.dwTtlSeconds)
+        self.assertEqual(match.dwTimeStamp, rec.dwTimeStamp)
+        return match
+
     def ldap_modify_timestamps(self, name, delta):
         records = self.ldap_get_records(name)
         for rec in records: