From: Brian Wellington Date: Wed, 20 Dec 2023 22:59:32 +0000 (-0800) Subject: Add Message.section_count(). (#1024) X-Git-Tag: v2.5.0rc1~8 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=63aa46c1a69912afb3cf814b6ca5c7c04f5ddf98;p=thirdparty%2Fdnspython.git Add Message.section_count(). (#1024) Adds a method to return a count of the number of records in each section. --- diff --git a/dns/message.py b/dns/message.py index 4e226160..c186457b 100644 --- a/dns/message.py +++ b/dns/message.py @@ -489,6 +489,34 @@ class Message: rrset = None return rrset + def section_count(self, section: SectionType) -> int: + """Returns the number of records in the specified section. + + *section*, an ``int`` section number, a ``str`` section name, or one of + the section attributes of this message. This specifies the + the section of the message to count. For example:: + + my_message.section_count(my_message.answer) + my_message.section_count(dns.message.ANSWER) + my_message.section_count("ANSWER") + """ + + if isinstance(section, int): + section_number = section + section = self.section_from_number(section_number) + elif isinstance(section, str): + section_number = MessageSection.from_text(section) + section = self.section_from_number(section_number) + else: + section_number = self.section_number(section) + count = sum(max(1, len(rrs)) for rrs in section) + if section_number == MessageSection.ADDITIONAL: + if self.opt is not None: + count += 1 + if self.tsig is not None: + count += 1 + return count + def _compute_opt_reserve(self) -> int: """Compute the size required for the OPT RR, padding excluded""" if not self.opt: diff --git a/tests/test_message.py b/tests/test_message.py index 356b8a0c..0ee53983 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -942,6 +942,38 @@ www.dnspython.org. 300 IN A 1.2.3.4 self.assertEqual(r2.flags & dns.flags.TC, 0) self.assertEqual(len(r2.additional), 30) + def test_section_count(self): + a = dns.message.from_text(answer_text) + self.assertEqual(a.section_count(a.question), 1) + self.assertEqual(a.section_count(a.answer), 1) + self.assertEqual(a.section_count("authority"), 3) + self.assertEqual(a.section_count(dns.message.MessageSection.ADDITIONAL), 1) + + a.use_edns() + a.use_tsig(dns.tsig.Key("foo.", b"abcd")) + self.assertEqual(a.section_count(dns.message.MessageSection.ADDITIONAL), 3) + + def test_section_count_update(self): + update = dns.update.Update("example") + update.id = 1 + # These each add 1 record to the prereq section + update.present("foo") + update.present("foo", "a") + update.present("bar", "a", "10.0.0.5") + update.absent("blaz2") + update.absent("blaz2", "a") + # This adds 3 records to the update section + update.replace("foo", 300, "a", "10.0.0.1", "10.0.0.2") + # These each add 1 record to the update section + update.add("bar", dns.rdataset.from_text(1, 1, 300, "10.0.0.3")) + update.delete("bar", "a", "10.0.0.4") + update.delete("blaz", "a") + update.delete("blaz2") + + self.assertEqual(update.section_count(dns.update.UpdateSection.ZONE), 1) + self.assertEqual(update.section_count(dns.update.UpdateSection.PREREQ), 5) + self.assertEqual(update.section_count(dns.update.UpdateSection.UPDATE), 7) + if __name__ == "__main__": unittest.main()