git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
rrset-reader PR
author Bob Halley <halley@dnspython.org>
Sun, 31 Oct 2021 21:16:36 +0000 (14:16 -0700)
committer Bob Halley <halley@dnspython.org>
Mon, 1 Nov 2021 16:12:17 +0000 (09:12 -0700)
dns/zonefile.py
tests/test_rrset_reader.py [new file with mode: 0644]

index 92e2f0cf0005ae1d2e8562b7fd82b6bd9597387b..d3b96566c610c0e363f3d75e32954ea1b9e9ddd5 100644 (file)
@@ -42,21 +42,35 @@ class Reader:
 
     """Read a DNS zone file into a transaction."""
 
-    def __init__(self, tok, rdclass, txn, allow_include=False):
+    def __init__(self, tok, rdclass, txn, allow_include=False,
+                 allow_directives=True, force_name=None,
+                 force_ttl=None, force_rdclass=None, force_rdtype=None,
+                 default_ttl=None):
         self.tok = tok
         (self.zone_origin, self.relativize, _) = \
             txn.manager.origin_information()
         self.current_origin = self.zone_origin
         self.last_ttl = 0
         self.last_ttl_known = False
-        self.default_ttl = 0
-        self.default_ttl_known = False
+        if force_ttl is not None:
+            default_ttl = force_ttl
+        if default_ttl is None:
+            self.default_ttl = 0
+            self.default_ttl_known = False
+        else:
+            self.default_ttl = default_ttl
+            self.default_ttl_known = True
         self.last_name = self.current_origin
         self.zone_rdclass = rdclass
         self.txn = txn
         self.saved_state = []
         self.current_file = None
         self.allow_include = allow_include
+        self.allow_directives = allow_directives
+        self.force_name = force_name
+        self.force_ttl = force_ttl
+        self.force_rdclass = force_rdclass
+        self.force_rdtype = force_rdtype
 
     def _eat_line(self):
         while 1:
@@ -64,63 +78,85 @@ class Reader:
             if token.is_eol_or_eof():
                 break
 
+    def _get_identifier(self):
+        token = self.tok.get()
+        if not token.is_identifier():
+            raise dns.exception.SyntaxError
+        return token
+
     def _rr_line(self):
         """Process one line from a DNS zone file."""
+        token = None
         # Name
-        if self.current_origin is None:
-            raise UnknownOrigin
-        token = self.tok.get(want_leading=True)
-        if not token.is_whitespace():
-            self.last_name = self.tok.as_name(token, self.current_origin)
+        if self.force_name is not None:
+            name = self.force_name
         else:
-            token = self.tok.get()
-            if token.is_eol_or_eof():
-                # treat leading WS followed by EOL/EOF as if they were EOL/EOF.
+            if self.current_origin is None:
+                raise UnknownOrigin
+            token = self.tok.get(want_leading=True)
+            if not token.is_whitespace():
+                self.last_name = self.tok.as_name(token, self.current_origin)
+            else:
+                token = self.tok.get()
+                if token.is_eol_or_eof():
+                    # treat leading WS followed by EOL/EOF as if they were EOL/EOF.
+                    return
+                self.tok.unget(token)
+            name = self.last_name
+            if not name.is_subdomain(self.zone_origin):
+                self._eat_line()
                 return
-            self.tok.unget(token)
-        name = self.last_name
-        if not name.is_subdomain(self.zone_origin):
-            self._eat_line()
-            return
-        if self.relativize:
-            name = name.relativize(self.zone_origin)
-        token = self.tok.get()
-        if not token.is_identifier():
-            raise dns.exception.SyntaxError
+            if self.relativize:
+                name = name.relativize(self.zone_origin)
 
         # TTL
-        ttl = None
-        try:
-            ttl = dns.ttl.from_text(token.value)
+        if self.force_ttl is not None:
+            ttl = self.force_ttl
             self.last_ttl = ttl
             self.last_ttl_known = True
-            token = self.tok.get()
-            if not token.is_identifier():
-                raise dns.exception.SyntaxError
-        except dns.ttl.BadTTL:
-            if self.default_ttl_known:
-                ttl = self.default_ttl
-            elif self.last_ttl_known:
-                ttl = self.last_ttl
+        else:
+            token = self._get_identifier()
+            ttl = None
+            try:
+                ttl = dns.ttl.from_text(token.value)
+                self.last_ttl = ttl
+                self.last_ttl_known = True
+                token = None
+            except dns.ttl.BadTTL:
+                if self.default_ttl_known:
+                    ttl = self.default_ttl
+                elif self.last_ttl_known:
+                    ttl = self.last_ttl
+                self.tok.unget(token)
 
         # Class
-        try:
-            rdclass = dns.rdataclass.from_text(token.value)
-            token = self.tok.get()
-            if not token.is_identifier():
-                raise dns.exception.SyntaxError
-        except dns.exception.SyntaxError:
-            raise
-        except Exception:
-            rdclass = self.zone_rdclass
-        if rdclass != self.zone_rdclass:
-            raise dns.exception.SyntaxError("RR class is not zone's class")
+        if self.force_rdclass is not None:
+            rdclass = self.force_rdclass
+        else:
+            token = self._get_identifier()
+            try:
+                rdclass = dns.rdataclass.from_text(token.value)
+            except dns.exception.SyntaxError:
+                raise
+            except Exception:
+                rdclass = self.zone_rdclass
+                self.tok.unget(token)
+            if rdclass != self.zone_rdclass:
+                raise dns.exception.SyntaxError("RR class is not zone's class")
+
         # Type
-        try:
-            rdtype = dns.rdatatype.from_text(token.value)
-        except Exception:
-            raise dns.exception.SyntaxError(
-                "unknown rdatatype '%s'" % token.value)
+        if self.force_rdtype is not None:
+            rdtype = self.force_rdtype
+            # nothing to unget here: the TTL and class steps above already
+            # pushed back any token they read but did not consume
+        else:
+            token = self._get_identifier()
+            try:
+                rdtype = dns.rdatatype.from_text(token.value)
+            except Exception:
+                raise dns.exception.SyntaxError(
+                    "unknown rdatatype '%s'" % token.value)
+
         try:
             rd = dns.rdata.from_text(rdclass, rdtype, self.tok,
                                      self.current_origin, self.relativize,
@@ -341,7 +377,7 @@ class Reader:
                 elif token.is_comment():
                     self.tok.get_eol()
                     continue
-                elif token.value[0] == '$':
+                elif token.value[0] == '$' and self.allow_directives:
                     c = token.value.upper()
                     if c == '$TTL':
                         token = self.tok.get()
@@ -399,3 +435,109 @@ class Reader:
                 "%s:%d: %s" % (filename, line_number, detail))
             tb = sys.exc_info()[2]
             raise ex.with_traceback(tb) from None
+
+
+class RRsetsReaderTransaction(dns.transaction.Transaction):
+
+    def __init__(self, manager, replacement, read_only):
+        assert not read_only
+        super().__init__(manager, replacement, read_only)
+        self.rdatasets = {}
+
+    def _get_rdataset(self, name, rdtype, covers):
+        return self.rdatasets.get((name, rdtype, covers))
+
+    def _put_rdataset(self, name, rdataset):
+        self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+
+    def _delete_name(self, name):
+        # First remove any changes involving the name
+        remove = []
+        for key in self.rdatasets:
+            if key[0] == name:
+                remove.append(key)
+        if len(remove) > 0:
+            for key in remove:
+                del self.rdatasets[key]
+
+    def _delete_rdataset(self, name, rdtype, covers):
+        try:
+            del self.rdatasets[(name, rdtype, covers)]
+        except KeyError:
+            pass
+
+    def _name_exists(self, name):
+        for (n, _, _) in self.rdatasets:
+            if n == name:
+                return True
+        return False
+
+    def _changed(self):
+        return len(self.rdatasets) > 0
+
+    def _end_transaction(self, commit):
+        if commit and self._changed():
+            rrsets = []
+            for (name, _, _), rdataset in self.rdatasets.items():
+                rrset = dns.rrset.RRset(name, rdataset.rdclass, rdataset.rdtype,
+                                        rdataset.covers)
+                rrset.update(rdataset)
+                rrsets.append(rrset)
+            self.manager.set_rrsets(rrsets)
+
+    def _set_origin(self, origin):
+        pass
+
+
+class RRSetsReaderManager(dns.transaction.TransactionManager):
+    def __init__(self, origin=dns.name.root, relativize=False,
+                 rdclass=dns.rdataclass.IN):
+        self.origin = origin
+        self.relativize = relativize
+        self.rdclass = rdclass
+        self.rrsets = []
+
+    def writer(self, replacement=False):
+        assert replacement == True
+        return RRsetsReaderTransaction(self, True, False)
+
+    def get_class(self):
+        return self.rdclass
+
+    def origin_information(self):
+        if self.relativize:
+            effective = dns.name.empty
+        else:
+            effective = self.origin
+        return (self.origin, self.relativize, effective)
+
+    def set_rrsets(self, rrsets):
+        self.rrsets = rrsets
+
+
+def read_rrsets(text, name=None, ttl=None, rdclass=dns.rdataclass.IN,
+                default_rdclass=dns.rdataclass.IN,
+                rdtype=None, default_ttl=None, idna_codec=None,
+                origin=dns.name.root, relativize=False):
+    if isinstance(origin, str):
+        origin = dns.name.from_text(origin, dns.name.root, idna_codec)
+    if isinstance(name, str):
+        name = dns.name.from_text(name, origin, idna_codec)
+    if isinstance(ttl, str):
+        ttl = dns.ttl.from_text(ttl)
+    if isinstance(default_ttl, str):
+        default_ttl = dns.ttl.from_text(default_ttl)
+    if rdclass is not None:
+        rdclass = dns.rdataclass.RdataClass.make(rdclass)
+    default_rdclass = dns.rdataclass.RdataClass.make(default_rdclass)
+    if rdtype is not None:
+        rdtype = dns.rdatatype.RdataType.make(rdtype)
+    manager = RRSetsReaderManager(origin, relativize, default_rdclass)
+    with manager.writer(True) as txn:
+        tok = dns.tokenizer.Tokenizer(text, '<input>', idna_codec=idna_codec)
+        reader = Reader(tok, default_rdclass, txn, allow_directives=False,
+                        force_name=name, force_ttl=ttl, force_rdclass=rdclass,
+                        force_rdtype=rdtype, default_ttl=default_ttl)
+        reader.read()
+    return manager.rrsets
+
diff --git a/tests/test_rrset_reader.py b/tests/test_rrset_reader.py
new file mode 100644 (file)
index 0000000..8d4255e
--- /dev/null
@@ -0,0 +1,131 @@
+import pytest
+
+import dns.rrset
+from dns.zonefile import read_rrsets
+
+expected_mx_1= dns.rrset.from_text('name.', 300, 'in', 'mx', '10 a.', '20 b.')
+expected_mx_2 = dns.rrset.from_text('name.', 10, 'in', 'mx', '10 a.', '20 b.')
+expected_mx_3 = dns.rrset.from_text('foo.', 10, 'in', 'mx', '10 a.')
+expected_mx_4 = dns.rrset.from_text('bar.', 10, 'in', 'mx', '20 b.')
+expected_mx_5 = dns.rrset.from_text('foo.example.', 10, 'in', 'mx',
+                                    '10 a.example.')
+expected_mx_6 = dns.rrset.from_text('bar.example.', 10, 'in', 'mx', '20 b.')
+expected_mx_7 = dns.rrset.from_text('foo', 10, 'in', 'mx', '10 a')
+expected_mx_8 = dns.rrset.from_text('bar', 10, 'in', 'mx', '20 b.')
+expected_ns_1 = dns.rrset.from_text('name.', 300, 'in', 'ns', 'hi.')
+expected_ns_2 = dns.rrset.from_text('name.', 300, 'ch', 'ns', 'hi.')
+
+def equal_rrsets(a, b):
+    # return True iff a and b contain the same rrsets, regardless of order
+    if len(a) != len(b):
+        return False
+    for rrset in a:
+        if not rrset in b:
+            return False
+    return True
+
+def test_name_ttl_rdclass_forced():
+    input=''';
+mx 10 a
+mx 20 b.
+ns hi'''
+    rrsets = read_rrsets(input, name='name', ttl=300)
+    assert equal_rrsets(rrsets, [expected_mx_1, expected_ns_1])
+    assert rrsets[0].ttl == 300
+    assert rrsets[1].ttl == 300
+
+def test_name_ttl_rdclass_forced_rdata_split():
+    input=''';
+mx 10 a
+ns hi
+mx 20 b.'''
+    rrsets = read_rrsets(input, name='name', ttl=300)
+    assert equal_rrsets(rrsets, [expected_mx_1, expected_ns_1])
+
+def test_name_ttl_rdclass_rdtype_forced():
+    input=''';
+10 a
+20 b.'''
+    rrsets = read_rrsets(input, name='name', ttl=300, rdtype='mx')
+    assert equal_rrsets(rrsets, [expected_mx_1])
+
+def test_name_rdclass_forced():
+    input = '''30 mx 10 a
+10 mx 20 b.
+'''
+    rrsets = read_rrsets(input, name='name')
+    assert equal_rrsets(rrsets, [expected_mx_2])
+    assert rrsets[0].ttl == 10
+
+def test_rdclass_forced():
+    input = ''';
+foo 20 mx 10 a
+bar 30 mx 20 b.
+'''
+    rrsets = read_rrsets(input)
+    assert equal_rrsets(rrsets, [expected_mx_3, expected_mx_4])
+
+def test_rdclass_forced_with_origin():
+    input = ''';
+foo 20 mx 10 a
+bar.example. 30 mx 20 b.
+'''
+    rrsets = read_rrsets(input, origin='example')
+    assert equal_rrsets(rrsets, [expected_mx_5, expected_mx_6])
+
+
+def test_rdclass_forced_with_origin_relativized():
+    input = ''';
+foo 20 mx 10 a.example.
+bar.example. 30 mx 20 b.
+'''
+    rrsets = read_rrsets(input, origin='example', relativize=True)
+    assert equal_rrsets(rrsets, [expected_mx_7, expected_mx_8])
+
+def test_rdclass_matching_default_tolerated():
+    input = ''';
+foo 20 mx 10 a.example.
+bar.example. 30 in mx 20 b.
+'''
+    rrsets = read_rrsets(input, origin='example', relativize=True,
+                         rdclass=None)
+    assert equal_rrsets(rrsets, [expected_mx_7, expected_mx_8])
+
+def test_rdclass_not_matching_default_rejected():
+    input = ''';
+foo 20 mx 10 a.example.
+bar.example. 30 ch mx 20 b.
+'''
+    with pytest.raises(dns.exception.SyntaxError):
+        rrsets = read_rrsets(input, origin='example', relativize=True,
+                             rdclass=None)
+
+def test_default_rdclass_is_none():
+    input = ''
+    with pytest.raises(TypeError):
+        rrsets = read_rrsets(input, default_rdclass=None, origin='example',
+                             relativize=True)
+
+def test_name_rdclass_rdtype_force():
+    # No real-world usage should do this, but it can be specified so we test it.
+    input = ''';
+30 10 a
+10 20 b.
+'''
+    rrsets = read_rrsets(input, name='name', rdtype='mx')
+    assert equal_rrsets(rrsets, [expected_mx_1])
+    assert rrsets[0].ttl == 10
+
+def test_rdclass_rdtype_force():
+    # No real-world usage should do this, but it can be specified so we test it.
+    input = ''';
+foo 30 10 a
+bar 30 20 b.
+'''
+    rrsets = read_rrsets(input, rdtype='mx')
+    assert equal_rrsets(rrsets, [expected_mx_3, expected_mx_4])
+
+# also weird but legal
+#input5 = '''foo 30 10 a
+#bar 10 20 foo.
+#'''