From: Bob Halley Date: Tue, 1 Sep 2020 13:08:43 +0000 (-0700) Subject: check for TTL type errors in rdataset/rrset from_text; allow text-form TTLs there. X-Git-Tag: v2.1.0rc1~27 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d33bc6bcb6572e23ff3bfb7091a7a1046bac7e60;p=thirdparty%2Fdnspython.git check for TTL type errors in rdataset/rrset from_text; allow text-form TTLs there. --- diff --git a/dns/rdataset.py b/dns/rdataset.py index 2a42e424..2e3b4d4a 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -80,9 +80,9 @@ class Rdataset(dns.set.Set): TTL or the specified TTL. If the set contains no rdatas, set the TTL to the specified TTL. - *ttl*, an ``int``. + *ttl*, an ``int`` or ``str``. """ - + ttl = dns.ttl.make(ttl) if len(self) == 0: self.ttl = ttl elif ttl < self.ttl: diff --git a/dns/ttl.py b/dns/ttl.py index 1a2aaeb7..8ea52135 100644 --- a/dns/ttl.py +++ b/dns/ttl.py @@ -73,3 +73,12 @@ def from_text(text): if total < 0 or total > MAX_TTL: raise BadTTL("TTL should be between 0 and 2^31 - 1 (inclusive)") return total + + +def make(value): + if isinstance(value, int): + return value + elif isinstance(value, str): + return dns.ttl.from_text(value) + else: + raise ValueError('cannot convert value to TTL') diff --git a/tests/test_rdata.py b/tests/test_rdata.py index 40714037..66ed67cc 100644 --- a/tests/test_rdata.py +++ b/tests/test_rdata.py @@ -35,6 +35,7 @@ from dns.rdtypes.ANY.GPOS import GPOS import dns.rdtypes.ANY.RRSIG import dns.rdtypes.util import dns.tokenizer +import dns.ttl import dns.wire import tests.stxt_module @@ -743,5 +744,16 @@ class UtilTestCase(unittest.TestCase): dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, r'\# 4 000aC000') + def test_rdataset_ttl_conversion(self): + rds1 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1') + self.assertEqual(rds1.ttl, 300) + rds2 = dns.rdataset.from_text('in', 'a', '5m', '10.0.0.1') + self.assertEqual(rds2.ttl, 300) + with self.assertRaises(ValueError): + dns.rdataset.from_text('in', 'a', 1.6, '10.0.0.1') + with self.assertRaises(dns.ttl.BadTTL): + dns.rdataset.from_text('in', 'a', '10.0.0.1', '10.0.0.2') + + if __name__ == '__main__': unittest.main()