]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
allow zonefile directives to be specified explicitly
authorBob Halley <halley@dnspython.org>
Sun, 14 Aug 2022 17:33:50 +0000 (10:33 -0700)
committerBob Halley <halley@dnspython.org>
Sun, 14 Aug 2022 17:33:50 +0000 (10:33 -0700)
dns/zone.py
dns/zonefile.py
tests/test_zone.py

index f5f0933b2900a52514555a50a0d31f46114453ed..248800d71fd0438ff99c7e829a330d484cbfb3ec 100644 (file)
@@ -17,7 +17,7 @@
 
 """DNS Zones."""
 
-from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterator, Iterable, List, Optional, Set, Tuple, Union
 
 import contextlib
 import io
@@ -1175,6 +1175,7 @@ def from_text(
     allow_include: bool = False,
     check_origin: bool = True,
     idna_codec: Optional[dns.name.IDNACodec] = None,
+    allow_directives: Union[bool, Iterable[str]] = True,
 ) -> Zone:
     """Build a zone object from a zone file format string.
 
@@ -1209,6 +1210,13 @@ def from_text(
     encoder/decoder.  If ``None``, the default IDNA 2003 encoder/decoder
     is used.
 
+    *allow_directives*, a ``bool`` or an iteratable of `str`.  If ``True``, the default,
+    then directives are permitted, and the *allow_include* parameter controls whether
+    ``$INCLUDE`` is permitted.  If ``False`` or an empty iterable, then no directive
+    processing is done and any directive-like text will be treated as a regular owner
+    name.  If a non-empty iterable, then only the listed directives (including the
+    ``$``) are allowed.
+
     Raises ``dns.zone.NoSOA`` if there is no SOA RRset.
 
     Raises ``dns.zone.NoNS`` if there is no NS RRset.
@@ -1227,7 +1235,13 @@ def from_text(
     zone = zone_factory(origin, rdclass, relativize=relativize)
     with zone.writer(True) as txn:
         tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
-        reader = dns.zonefile.Reader(tok, rdclass, txn, allow_include=allow_include)
+        reader = dns.zonefile.Reader(
+            tok,
+            rdclass,
+            txn,
+            allow_include=allow_include,
+            allow_directives=allow_directives,
+        )
         try:
             reader.read()
         except dns.zonefile.UnknownOrigin:
@@ -1249,6 +1263,7 @@ def from_file(
     allow_include: bool = True,
     check_origin: bool = True,
     idna_codec: Optional[dns.name.IDNACodec] = None,
+    allow_directives: Union[bool, Iterable[str]] = True,
 ) -> Zone:
     """Read a zone file and build a zone object.
 
@@ -1283,6 +1298,13 @@ def from_file(
     encoder/decoder.  If ``None``, the default IDNA 2003 encoder/decoder
     is used.
 
+    *allow_directives*, a ``bool`` or an iteratable of `str`.  If ``True``, the default,
+    then directives are permitted, and the *allow_include* parameter controls whether
+    ``$INCLUDE`` is permitted.  If ``False`` or an empty iterable, then no directive
+    processing is done and any directive-like text will be treated as a regular owner
+    name.  If a non-empty iterable, then only the listed directives (including the
+    ``$``) are allowed.
+
     Raises ``dns.zone.NoSOA`` if there is no SOA RRset.
 
     Raises ``dns.zone.NoNS`` if there is no NS RRset.
@@ -1309,6 +1331,7 @@ def from_file(
             allow_include,
             check_origin,
             idna_codec,
+            allow_directives,
         )
     assert False  # make mypy happy  lgtm[py/unreachable-statement]
 
index 68c6314827db4c8cd673de36ead9128ad358d53b..3d3aa61d9b5df92f0813253c294e0d5a6a0df27e 100644 (file)
@@ -17,7 +17,7 @@
 
 """DNS Zones."""
 
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Set, Tuple, Union
 
 import re
 import sys
@@ -89,7 +89,7 @@ class Reader:
         rdclass: dns.rdataclass.RdataClass,
         txn: dns.transaction.Transaction,
         allow_include: bool = False,
-        allow_directives: bool = True,
+        allow_directives: Union[bool, Iterable[str]] = True,
         force_name: Optional[dns.name.Name] = None,
         force_ttl: Optional[int] = None,
         force_rdclass: Optional[dns.rdataclass.RdataClass] = None,
@@ -114,8 +114,19 @@ class Reader:
         self.txn = txn
         self.saved_state: List[SavedStateType] = []
         self.current_file: Optional[Any] = None
-        self.allow_include = allow_include
-        self.allow_directives = allow_directives
+        self.allowed_directives: Set[str]
+        if allow_directives is True:
+            self.allowed_directives = {"$GENERATE", "$ORIGIN", "$TTL"}
+            if allow_include:
+                self.allowed_directives.add("$INCLUDE")
+        elif allow_directives is False:
+            # allow_include was ignored in earlier releases if allow_directives was
+            # False, so we continue that.
+            self.allowed_directives = set()
+        else:
+            # Note that if directives are explicitly specified, then allow_include
+            # is ignored.
+            self.allowed_directives = set(d.upper() for d in allow_directives)
         self.force_name = force_name
         self.force_ttl = force_ttl
         self.force_rdclass = force_rdclass
@@ -283,13 +294,9 @@ class Reader:
         width = int(width)
 
         if sign not in ["+", "-"]:
-            raise dns.exception.SyntaxError(
-                "invalid offset sign %s" % sign
-            )
+            raise dns.exception.SyntaxError("invalid offset sign %s" % sign)
         if base not in ["d", "o", "x", "X", "n", "N"]:
-            raise dns.exception.SyntaxError(
-                "invalid type %s" % base
-            )
+            raise dns.exception.SyntaxError("invalid type %s" % base)
 
         return mod, sign, offset, width, base
 
@@ -457,8 +464,14 @@ class Reader:
                 elif token.is_comment():
                     self.tok.get_eol()
                     continue
-                elif token.value[0] == "$" and self.allow_directives:
+                elif token.value[0] == "$" and len(self.allowed_directives) > 0:
+                    # Note that we only run directive processing code if at least
+                    # one directive is allowed in order to be backwards compatible
                     c = token.value.upper()
+                    if not c in self.allowed_directives:
+                        raise dns.exception.SyntaxError(
+                            f"zone file directive '{c}' is not allowed"
+                        )
                     if c == "$TTL":
                         token = self.tok.get()
                         if not token.is_identifier():
@@ -472,7 +485,7 @@ class Reader:
                         if self.zone_origin is None:
                             self.zone_origin = self.current_origin
                         self.txn._set_origin(self.current_origin)
-                    elif c == "$INCLUDE" and self.allow_include:
+                    elif c == "$INCLUDE":
                         token = self.tok.get()
                         filename = token.value
                         token = self.tok.get()
@@ -505,7 +518,7 @@ class Reader:
                         self._generate_line()
                     else:
                         raise dns.exception.SyntaxError(
-                            "Unknown zone file directive '" + c + "'"
+                            f"Unknown zone file directive '{c}'"
                         )
                     continue
                 self.tok.unget(token)
index 2d10274f7c3b6efaba30fe1872b6cefcbbc32c5a..5dfa6ff00079d77a0d172f74786fd074ac8723c7 100644 (file)
@@ -854,12 +854,75 @@ class ZoneTestCase(unittest.TestCase):
         z2 = dns.zone.from_file(here("example"), "example.", relativize=True)
         self.assertEqual(z1, z2)
 
+    def testNoInclude(self):
+        def bad():
+            dns.zone.from_text(
+                include_text, "example.", relativize=True, allow_include=False
+            )
+
+        self.assertRaises(dns.exception.SyntaxError, bad)
+
+    def testExplicitInclude(self):
+        z1 = dns.zone.from_text(
+            include_text,
+            "example.",
+            relativize=True,
+            allow_directives={"$INCLUDE", "$ORIGIN", "$TTL"},
+        )
+        z2 = dns.zone.from_file(here("example"), "example.", relativize=True)
+        self.assertEqual(z1, z2)
+
+    def testExplicitLowerCase(self):
+        z1 = dns.zone.from_text(
+            include_text,
+            "example.",
+            relativize=True,
+            allow_directives={"$include", "$origin", "$ttl"},
+        )
+        z2 = dns.zone.from_file(here("example"), "example.", relativize=True)
+        self.assertEqual(z1, z2)
+
+    def testExplicitWithoutInclude1(self):
+        def bad():
+            dns.zone.from_text(
+                include_text,
+                "example.",
+                relativize=True,
+                allow_include=False,
+                allow_directives={"$ORIGIN", "$TTL"},
+            )
+
+        self.assertRaises(dns.exception.SyntaxError, bad)
+
+    def testExplicitWithoutInclude2(self):
+        def bad():
+            dns.zone.from_text(
+                include_text,
+                "example.",
+                relativize=True,
+                allow_include=True,
+                allow_directives={"$ORIGIN", "$TTL"},
+            )
+
+        self.assertRaises(dns.exception.SyntaxError, bad)
+
     def testBadDirective(self):
         def bad():
             dns.zone.from_text(bad_directive_text, "example.", relativize=True)
 
         self.assertRaises(dns.exception.SyntaxError, bad)
 
+    def testAllowedButNotImplementedDirective(self):
+        def bad():
+            dns.zone.from_text(
+                bad_directive_text,
+                "example.",
+                relativize=True,
+                allow_directives={"$FOO", "$ORIGIN"},
+            )
+
+        self.assertRaises(dns.exception.SyntaxError, bad)
+
     def testFirstRRStartsWithWhitespace(self):
         # no name is specified, so default to the initial origin
         z = dns.zone.from_text(