]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Add Message.section_count(). (#1024)
authorBrian Wellington <bwelling@xbill.org>
Wed, 20 Dec 2023 22:59:32 +0000 (14:59 -0800)
committerGitHub <noreply@github.com>
Wed, 20 Dec 2023 22:59:32 +0000 (14:59 -0800)
Adds a method to return a count of the number of records in each
section.

dns/message.py
tests/test_message.py

index 4e226160c4831eec64f1618b27540766396371ef..c186457b57826fe0acf9f692aa1a2b010b6d4ff5 100644 (file)
@@ -489,6 +489,34 @@ class Message:
             rrset = None
         return rrset
 
+    def section_count(self, section: SectionType) -> int:
+        """Returns the number of records in the specified section.
+
+        *section*, an ``int`` section number, a ``str`` section name, or one of
+        the section attributes of this message.  This specifies the
+        the section of the message to count.  For example::
+
+            my_message.section_count(my_message.answer)
+            my_message.section_count(dns.message.ANSWER)
+            my_message.section_count("ANSWER")
+        """
+
+        if isinstance(section, int):
+            section_number = section
+            section = self.section_from_number(section_number)
+        elif isinstance(section, str):
+            section_number = MessageSection.from_text(section)
+            section = self.section_from_number(section_number)
+        else:
+            section_number = self.section_number(section)
+        count = sum(max(1, len(rrs)) for rrs in section)
+        if section_number == MessageSection.ADDITIONAL:
+            if self.opt is not None:
+                count += 1
+            if self.tsig is not None:
+                count += 1
+        return count
+
     def _compute_opt_reserve(self) -> int:
         """Compute the size required for the OPT RR, padding excluded"""
         if not self.opt:
index 356b8a0c2cf9b5033b3965a3382cc087db67d66e..0ee53983de849f1f620c3b9f678c5def397dfdfa 100644 (file)
@@ -942,6 +942,38 @@ www.dnspython.org. 300 IN A 1.2.3.4
         self.assertEqual(r2.flags & dns.flags.TC, 0)
         self.assertEqual(len(r2.additional), 30)
 
+    def test_section_count(self):
+        a = dns.message.from_text(answer_text)
+        self.assertEqual(a.section_count(a.question), 1)
+        self.assertEqual(a.section_count(a.answer), 1)
+        self.assertEqual(a.section_count("authority"), 3)
+        self.assertEqual(a.section_count(dns.message.MessageSection.ADDITIONAL), 1)
+
+        a.use_edns()
+        a.use_tsig(dns.tsig.Key("foo.", b"abcd"))
+        self.assertEqual(a.section_count(dns.message.MessageSection.ADDITIONAL), 3)
+
+    def test_section_count_update(self):
+        update = dns.update.Update("example")
+        update.id = 1
+        # These each add 1 record to the prereq section
+        update.present("foo")
+        update.present("foo", "a")
+        update.present("bar", "a", "10.0.0.5")
+        update.absent("blaz2")
+        update.absent("blaz2", "a")
+        # This adds 3 records to the update section
+        update.replace("foo", 300, "a", "10.0.0.1", "10.0.0.2")
+        # These each add 1 record to the update section
+        update.add("bar", dns.rdataset.from_text(1, 1, 300, "10.0.0.3"))
+        update.delete("bar", "a", "10.0.0.4")
+        update.delete("blaz", "a")
+        update.delete("blaz2")
+
+        self.assertEqual(update.section_count(dns.update.UpdateSection.ZONE), 1)
+        self.assertEqual(update.section_count(dns.update.UpdateSection.PREREQ), 5)
+        self.assertEqual(update.section_count(dns.update.UpdateSection.UPDATE), 7)
+
 
 if __name__ == "__main__":
     unittest.main()