]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
manager: datamodel: types: EscapedStr type
authorAleš Mrázek <ales.mrazek@nic.cz>
Tue, 20 Jun 2023 12:30:18 +0000 (14:30 +0200)
committerAleš Mrázek <ales.mrazek@nic.cz>
Thu, 13 Jul 2023 07:50:09 +0000 (09:50 +0200)
manager/knot_resolver_manager/datamodel/types/__init__.py
manager/knot_resolver_manager/datamodel/types/types.py
manager/tests/unit/datamodel/templates/test_types_render.py [new file with mode: 0644]
manager/tests/unit/datamodel/types/test_custom_types.py

index 0b708c4a967ef9ec9213b37b8f668eb2a51b93d7..d70d33326e8e03eeb4920f64aaac0a223a007dde 100644 (file)
@@ -3,6 +3,7 @@ from .files import AbsoluteDir, Dir, File, FilePath
 from .generic_types import ListOrItem
 from .types import (
     DomainName,
+    EscapedStr,
     IDPattern,
     Int0_512,
     Int0_65535,
@@ -31,6 +32,7 @@ __all__ = [
     "PolicyFlagEnum",
     "DNSRecordTypeEnum",
     "DomainName",
+    "EscapedStr",
     "IDPattern",
     "Int0_512",
     "Int0_65535",
index 2b409e7e563ec2f5cc1fff7ff8f935bf851ea865..a2ad074af27dc2208fd74b172cc0e042a7fbf3a1 100644 (file)
@@ -64,6 +64,36 @@ class TimeUnit(UnitBase):
         return self._base_value
 
 
+class EscapedStr(StrBase):
+    """
+    A string where escape sequences are ignored and quotes are escaped.
+    """
+
+    def __init__(self, source_value: Any, object_path: str = "/") -> None:
+        super().__init__(source_value, object_path)
+
+        escape = {
+            "'": r"\'",
+            '"': r"\"",
+            "\a": r"\a",
+            "\n": r"\n",
+            "\r": r"\r",
+            "\t": r"\t",
+            "\b": r"\b",
+            "\f": r"\f",
+            "\v": r"\v",
+            "\0": r"\0",
+        }
+
+        s = list(self._value)
+        for i, c in enumerate(self._value):
+            if c in escape:
+                s[i] = escape[c]
+            elif not c.isalnum():
+                s[i] = repr(c)[1:-1]
+        self._value = "".join(s)
+
+
 class DomainName(StrBase):
     """
     Fully or partially qualified domain name.
diff --git a/manager/tests/unit/datamodel/templates/test_types_render.py b/manager/tests/unit/datamodel/templates/test_types_render.py
new file mode 100644 (file)
index 0000000..15a0611
--- /dev/null
@@ -0,0 +1,32 @@
+from typing import Any
+
+import pytest
+from jinja2 import Template
+
+from knot_resolver_manager.datamodel.types import EscapedStr
+from knot_resolver_manager.utils.modeling import ConfigSchema
+
+str_template = Template("'{{ string }}'")
+
+
+@pytest.mark.parametrize(
+    "val,exp",
+    [
+        ("", ""),
+        ("string", "string"),
+        (2000, "2000"),
+        ('"\a\b\f\n\r\t\v\\"', r"\"\a\b\f\n\r\t\v\\\""),
+        ('""', r"\"\""),
+        ("''", r"\'\'"),
+        # fmt: off
+        ('\"\"', r'\"\"'),
+        ("\'\'", r'\'\''),
+        # fmt: on
+    ],
+)
+def test_escaped_str(val: Any, exp: str):
+    class TestSchema(ConfigSchema):
+        pattern: EscapedStr
+
+    d = TestSchema({"pattern": val})
+    assert str_template.render(string=d.pattern) == f"'{exp}'"
index 3e1f5c61d2e010bc7d31eb137219aec4eec9d758..ac95b79cf7bca20c97c48ce0b9f919fc0ea3ae67 100644 (file)
@@ -9,6 +9,7 @@ from pytest import raises
 from knot_resolver_manager.datamodel.types import (
     Dir,
     DomainName,
+    EscapedStr,
     InterfaceName,
     InterfaceOptionalPort,
     InterfacePort,
@@ -117,6 +118,34 @@ def test_pin_sha256_invalid(val: str):
         PinSha256(val)
 
 
+@pytest.mark.parametrize(
+    "val,exp",
+    [
+        ("", r""),
+        (2000, "2000"),
+        ("string", r"string"),
+        ("\t\n\v", r"\t\n\v"),
+        ("\a\b\f\n\r\t\v\\", r"\a\b\f\n\r\t\v\\"),
+        # fmt: off
+        ("''", r"\'\'"),
+        ('""', r'\"\"'),
+        ("\'\'", r"\'\'"),
+        ('\"\"', r'\"\"'),
+        ('\\"\\"', r'\\\"\\\"'),
+        ("\\'\\'", r"\\\'\\\'"),
+        # fmt: on
+    ],
+)
+def test_escaped_str_valid(val: Any, exp: str):
+    assert str(EscapedStr(val)) == exp
+
+
+@pytest.mark.parametrize("val", [1.1, False])
+def test_escaped_str_invalid(val: Any):
+    with raises(ValueError):
+        EscapedStr(val)
+
+
 @pytest.mark.parametrize(
     "val",
     [