From d33bc6bcb6572e23ff3bfb7091a7a1046bac7e60 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Tue, 1 Sep 2020 06:08:43 -0700 Subject: [PATCH] check for TTL type errors in rdataset/rrset from_text; allow text-form TTLs there. --- dns/rdataset.py | 4 ++-- dns/ttl.py | 9 +++++++++ tests/test_rdata.py | 12 ++++++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/dns/rdataset.py b/dns/rdataset.py index 2a42e424..2e3b4d4a 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -80,9 +80,9 @@ class Rdataset(dns.set.Set): TTL or the specified TTL. If the set contains no rdatas, set the TTL to the specified TTL. - *ttl*, an ``int``. + *ttl*, an ``int`` or ``str``. """ - + ttl = dns.ttl.make(ttl) if len(self) == 0: self.ttl = ttl elif ttl < self.ttl: diff --git a/dns/ttl.py b/dns/ttl.py index 1a2aaeb7..8ea52135 100644 --- a/dns/ttl.py +++ b/dns/ttl.py @@ -73,3 +73,12 @@ def from_text(text): if total < 0 or total > MAX_TTL: raise BadTTL("TTL should be between 0 and 2^31 - 1 (inclusive)") return total + + +def make(value): + if isinstance(value, int): + return value + elif isinstance(value, str): + return dns.ttl.from_text(value) + else: + raise ValueError('cannot convert value to TTL') diff --git a/tests/test_rdata.py b/tests/test_rdata.py index 40714037..66ed67cc 100644 --- a/tests/test_rdata.py +++ b/tests/test_rdata.py @@ -35,6 +35,7 @@ from dns.rdtypes.ANY.GPOS import GPOS import dns.rdtypes.ANY.RRSIG import dns.rdtypes.util import dns.tokenizer +import dns.ttl import dns.wire import tests.stxt_module @@ -743,5 +744,16 @@ class UtilTestCase(unittest.TestCase): dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.MX, r'\# 4 000aC000') + def test_rdataset_ttl_conversion(self): + rds1 = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1') + self.assertEqual(rds1.ttl, 300) + rds2 = dns.rdataset.from_text('in', 'a', '5m', '10.0.0.1') + self.assertEqual(rds2.ttl, 300) + with self.assertRaises(ValueError): + dns.rdataset.from_text('in', 'a', 1.6, '10.0.0.1') + with self.assertRaises(dns.ttl.BadTTL): + dns.rdataset.from_text('in', 'a', '10.0.0.1', '10.0.0.2') + + if __name__ == '__main__': unittest.main() -- 2.47.3