]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-126175: Add attributes to TOMLDecodeError. Deprecate free-form `__init__` args...
authorTaneli Hukkinen <3275109+hukkin@users.noreply.github.com>
Wed, 13 Nov 2024 12:52:16 +0000 (14:52 +0200)
committerGitHub <noreply@github.com>
Wed, 13 Nov 2024 12:52:16 +0000 (13:52 +0100)
Doc/library/tomllib.rst
Lib/test/test_tomllib/test_error.py
Lib/tomllib/_parser.py
Misc/NEWS.d/next/Library/2024-11-05-09-54-49.gh-issue-126175.spnjJr.rst [new file with mode: 0644]

index 521a7a17fb3e8bc9984e0d875ba7a3c89de231d8..4b88b2e29e782200f85d148e6680f23fde07b290 100644 (file)
@@ -60,9 +60,36 @@ This module defines the following functions:
 
 The following exceptions are available:
 
-.. exception:: TOMLDecodeError
+.. exception:: TOMLDecodeError(msg, doc, pos)
 
-   Subclass of :exc:`ValueError`.
+   Subclass of :exc:`ValueError` with the following additional attributes:
+
+   .. attribute:: msg
+
+      The unformatted error message.
+
+   .. attribute:: doc
+
+      The TOML document being parsed.
+
+   .. attribute:: pos
+
+      The index of *doc* where parsing failed.
+
+   .. attribute:: lineno
+
+      The line corresponding to *pos*.
+
+   .. attribute:: colno
+
+      The column corresponding to *pos*.
+
+   .. versionchanged:: next
+      Added the *msg*, *doc* and *pos* parameters.
+      Added the :attr:`msg`, :attr:`doc`, :attr:`pos`, :attr:`lineno` and :attr:`colno` attributes.
+
+   .. deprecated:: next
+      Passing free-form positional arguments is deprecated.
 
 
 Examples
index d2ef59a29ca350cb4c9654d7fce9aa4416a56bcf..3a8587492859ca65f60c51cd354f1da2e576ebe5 100644 (file)
@@ -49,7 +49,9 @@ class TestError(unittest.TestCase):
         self.assertEqual(str(exc_info.exception), "Expected str object, not 'bool'")
 
     def test_module_name(self):
-        self.assertEqual(tomllib.TOMLDecodeError().__module__, tomllib.__name__)
+        self.assertEqual(
+            tomllib.TOMLDecodeError("", "", 0).__module__, tomllib.__name__
+        )
 
     def test_invalid_parse_float(self):
         def dict_returner(s: str) -> dict:
@@ -64,3 +66,33 @@ class TestError(unittest.TestCase):
             self.assertEqual(
                 str(exc_info.exception), "parse_float must not return dicts or lists"
             )
+
+    def test_deprecated_tomldecodeerror(self):
+        for args in [
+            (),
+            ("err msg",),
+            (None,),
+            (None, "doc"),
+            ("err msg", None),
+            (None, "doc", None),
+            ("err msg", "doc", None),
+            ("one", "two", "three", "four"),
+            ("one", "two", 3, "four", "five"),
+        ]:
+            with self.assertWarns(DeprecationWarning):
+                e = tomllib.TOMLDecodeError(*args)  # type: ignore[arg-type]
+            self.assertEqual(e.args, args)
+
+    def test_tomldecodeerror(self):
+        msg = "error parsing"
+        doc = "v=1\n[table]\nv='val'"
+        pos = 13
+        formatted_msg = "error parsing (at line 3, column 2)"
+        e = tomllib.TOMLDecodeError(msg, doc, pos)
+        self.assertEqual(e.args, (formatted_msg,))
+        self.assertEqual(str(e), formatted_msg)
+        self.assertEqual(e.msg, msg)
+        self.assertEqual(e.doc, doc)
+        self.assertEqual(e.pos, pos)
+        self.assertEqual(e.lineno, 3)
+        self.assertEqual(e.colno, 2)
index 5671326646ca5a6a186b1ef5b9dbb012a1a5f82d..4d208bcfb4a9a679f094410bdc4a8a1faa8c6bf2 100644 (file)
@@ -8,6 +8,7 @@ from collections.abc import Iterable
 import string
 from types import MappingProxyType
 from typing import Any, BinaryIO, NamedTuple
+import warnings
 
 from ._re import (
     RE_DATETIME,
@@ -50,8 +51,68 @@ BASIC_STR_ESCAPE_REPLACEMENTS = MappingProxyType(
 )
 
 
+class DEPRECATED_DEFAULT:
+    """Sentinel to be used as default arg during deprecation
+    period of TOMLDecodeError's free-form arguments."""
+
+
 class TOMLDecodeError(ValueError):
-    """An error raised if a document is not valid TOML."""
+    """An error raised if a document is not valid TOML.
+
+    Adds the following attributes to ValueError:
+    msg: The unformatted error message
+    doc: The TOML document being parsed
+    pos: The index of doc where parsing failed
+    lineno: The line corresponding to pos
+    colno: The column corresponding to pos
+    """
+
+    def __init__(
+        self,
+        msg: str = DEPRECATED_DEFAULT,  # type: ignore[assignment]
+        doc: str = DEPRECATED_DEFAULT,  # type: ignore[assignment]
+        pos: Pos = DEPRECATED_DEFAULT,  # type: ignore[assignment]
+        *args: Any,
+    ):
+        if (
+            args
+            or not isinstance(msg, str)
+            or not isinstance(doc, str)
+            or not isinstance(pos, int)
+        ):
+            warnings.warn(
+                "Free-form arguments for TOMLDecodeError are deprecated. "
+                "Please set 'msg' (str), 'doc' (str) and 'pos' (int) arguments only.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if pos is not DEPRECATED_DEFAULT:  # type: ignore[comparison-overlap]
+                args = pos, *args
+            if doc is not DEPRECATED_DEFAULT:  # type: ignore[comparison-overlap]
+                args = doc, *args
+            if msg is not DEPRECATED_DEFAULT:  # type: ignore[comparison-overlap]
+                args = msg, *args
+            ValueError.__init__(self, *args)
+            return
+
+        lineno = doc.count("\n", 0, pos) + 1
+        if lineno == 1:
+            colno = pos + 1
+        else:
+            colno = pos - doc.rindex("\n", 0, pos)
+
+        if pos >= len(doc):
+            coord_repr = "end of document"
+        else:
+            coord_repr = f"line {lineno}, column {colno}"
+        errmsg = f"{msg} (at {coord_repr})"
+        ValueError.__init__(self, errmsg)
+
+        self.msg = msg
+        self.doc = doc
+        self.pos = pos
+        self.lineno = lineno
+        self.colno = colno
 
 
 def load(fp: BinaryIO, /, *, parse_float: ParseFloat = float) -> dict[str, Any]:
@@ -118,7 +179,7 @@ def loads(s: str, /, *, parse_float: ParseFloat = float) -> dict[str, Any]:  # n
                 pos, header = create_dict_rule(src, pos, out)
             pos = skip_chars(src, pos, TOML_WS)
         elif char != "#":
-            raise suffixed_err(src, pos, "Invalid statement")
+            raise TOMLDecodeError("Invalid statement", src, pos)
 
         # 3. Skip comment
         pos = skip_comment(src, pos)
@@ -129,8 +190,8 @@ def loads(s: str, /, *, parse_float: ParseFloat = float) -> dict[str, Any]:  # n
         except IndexError:
             break
         if char != "\n":
-            raise suffixed_err(
-                src, pos, "Expected newline or end of document after a statement"
+            raise TOMLDecodeError(
+                "Expected newline or end of document after a statement", src, pos
             )
         pos += 1
 
@@ -256,12 +317,12 @@ def skip_until(
     except ValueError:
         new_pos = len(src)
         if error_on_eof:
-            raise suffixed_err(src, new_pos, f"Expected {expect!r}") from None
+            raise TOMLDecodeError(f"Expected {expect!r}", src, new_pos) from None
 
     if not error_on.isdisjoint(src[pos:new_pos]):
         while src[pos] not in error_on:
             pos += 1
-        raise suffixed_err(src, pos, f"Found invalid character {src[pos]!r}")
+        raise TOMLDecodeError(f"Found invalid character {src[pos]!r}", src, pos)
     return new_pos
 
 
@@ -292,15 +353,17 @@ def create_dict_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
     pos, key = parse_key(src, pos)
 
     if out.flags.is_(key, Flags.EXPLICIT_NEST) or out.flags.is_(key, Flags.FROZEN):
-        raise suffixed_err(src, pos, f"Cannot declare {key} twice")
+        raise TOMLDecodeError(f"Cannot declare {key} twice", src, pos)
     out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
     try:
         out.data.get_or_create_nest(key)
     except KeyError:
-        raise suffixed_err(src, pos, "Cannot overwrite a value") from None
+        raise TOMLDecodeError("Cannot overwrite a value", src, pos) from None
 
     if not src.startswith("]", pos):
-        raise suffixed_err(src, pos, "Expected ']' at the end of a table declaration")
+        raise TOMLDecodeError(
+            "Expected ']' at the end of a table declaration", src, pos
+        )
     return pos + 1, key
 
 
@@ -310,7 +373,7 @@ def create_list_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
     pos, key = parse_key(src, pos)
 
     if out.flags.is_(key, Flags.FROZEN):
-        raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
+        raise TOMLDecodeError(f"Cannot mutate immutable namespace {key}", src, pos)
     # Free the namespace now that it points to another empty list item...
     out.flags.unset_all(key)
     # ...but this key precisely is still prohibited from table declaration
@@ -318,10 +381,12 @@ def create_list_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
     try:
         out.data.append_nest_to_list(key)
     except KeyError:
-        raise suffixed_err(src, pos, "Cannot overwrite a value") from None
+        raise TOMLDecodeError("Cannot overwrite a value", src, pos) from None
 
     if not src.startswith("]]", pos):
-        raise suffixed_err(src, pos, "Expected ']]' at the end of an array declaration")
+        raise TOMLDecodeError(
+            "Expected ']]' at the end of an array declaration", src, pos
+        )
     return pos + 2, key
 
 
@@ -336,22 +401,22 @@ def key_value_rule(
     for cont_key in relative_path_cont_keys:
         # Check that dotted key syntax does not redefine an existing table
         if out.flags.is_(cont_key, Flags.EXPLICIT_NEST):
-            raise suffixed_err(src, pos, f"Cannot redefine namespace {cont_key}")
+            raise TOMLDecodeError(f"Cannot redefine namespace {cont_key}", src, pos)
         # Containers in the relative path can't be opened with the table syntax or
         # dotted key/value syntax in following table sections.
         out.flags.add_pending(cont_key, Flags.EXPLICIT_NEST)
 
     if out.flags.is_(abs_key_parent, Flags.FROZEN):
-        raise suffixed_err(
-            src, pos, f"Cannot mutate immutable namespace {abs_key_parent}"
+        raise TOMLDecodeError(
+            f"Cannot mutate immutable namespace {abs_key_parent}", src, pos
         )
 
     try:
         nest = out.data.get_or_create_nest(abs_key_parent)
     except KeyError:
-        raise suffixed_err(src, pos, "Cannot overwrite a value") from None
+        raise TOMLDecodeError("Cannot overwrite a value", src, pos) from None
     if key_stem in nest:
-        raise suffixed_err(src, pos, "Cannot overwrite a value")
+        raise TOMLDecodeError("Cannot overwrite a value", src, pos)
     # Mark inline table and array namespaces recursively immutable
     if isinstance(value, (dict, list)):
         out.flags.set(header + key, Flags.FROZEN, recursive=True)
@@ -368,7 +433,7 @@ def parse_key_value_pair(
     except IndexError:
         char = None
     if char != "=":
-        raise suffixed_err(src, pos, "Expected '=' after a key in a key/value pair")
+        raise TOMLDecodeError("Expected '=' after a key in a key/value pair", src, pos)
     pos += 1
     pos = skip_chars(src, pos, TOML_WS)
     pos, value = parse_value(src, pos, parse_float)
@@ -406,7 +471,7 @@ def parse_key_part(src: str, pos: Pos) -> tuple[Pos, str]:
         return parse_literal_str(src, pos)
     if char == '"':
         return parse_one_line_basic_str(src, pos)
-    raise suffixed_err(src, pos, "Invalid initial character for a key part")
+    raise TOMLDecodeError("Invalid initial character for a key part", src, pos)
 
 
 def parse_one_line_basic_str(src: str, pos: Pos) -> tuple[Pos, str]:
@@ -430,7 +495,7 @@ def parse_array(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, list]
         if c == "]":
             return pos + 1, array
         if c != ",":
-            raise suffixed_err(src, pos, "Unclosed array")
+            raise TOMLDecodeError("Unclosed array", src, pos)
         pos += 1
 
         pos = skip_comments_and_array_ws(src, pos)
@@ -450,20 +515,20 @@ def parse_inline_table(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos
         pos, key, value = parse_key_value_pair(src, pos, parse_float)
         key_parent, key_stem = key[:-1], key[-1]
         if flags.is_(key, Flags.FROZEN):
-            raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
+            raise TOMLDecodeError(f"Cannot mutate immutable namespace {key}", src, pos)
         try:
             nest = nested_dict.get_or_create_nest(key_parent, access_lists=False)
         except KeyError:
-            raise suffixed_err(src, pos, "Cannot overwrite a value") from None
+            raise TOMLDecodeError("Cannot overwrite a value", src, pos) from None
         if key_stem in nest:
-            raise suffixed_err(src, pos, f"Duplicate inline table key {key_stem!r}")
+            raise TOMLDecodeError(f"Duplicate inline table key {key_stem!r}", src, pos)
         nest[key_stem] = value
         pos = skip_chars(src, pos, TOML_WS)
         c = src[pos : pos + 1]
         if c == "}":
             return pos + 1, nested_dict.dict
         if c != ",":
-            raise suffixed_err(src, pos, "Unclosed inline table")
+            raise TOMLDecodeError("Unclosed inline table", src, pos)
         if isinstance(value, (dict, list)):
             flags.set(key, Flags.FROZEN, recursive=True)
         pos += 1
@@ -485,7 +550,7 @@ def parse_basic_str_escape(
             except IndexError:
                 return pos, ""
             if char != "\n":
-                raise suffixed_err(src, pos, "Unescaped '\\' in a string")
+                raise TOMLDecodeError("Unescaped '\\' in a string", src, pos)
             pos += 1
         pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE)
         return pos, ""
@@ -496,7 +561,7 @@ def parse_basic_str_escape(
     try:
         return pos, BASIC_STR_ESCAPE_REPLACEMENTS[escape_id]
     except KeyError:
-        raise suffixed_err(src, pos, "Unescaped '\\' in a string") from None
+        raise TOMLDecodeError("Unescaped '\\' in a string", src, pos) from None
 
 
 def parse_basic_str_escape_multiline(src: str, pos: Pos) -> tuple[Pos, str]:
@@ -506,11 +571,13 @@ def parse_basic_str_escape_multiline(src: str, pos: Pos) -> tuple[Pos, str]:
 def parse_hex_char(src: str, pos: Pos, hex_len: int) -> tuple[Pos, str]:
     hex_str = src[pos : pos + hex_len]
     if len(hex_str) != hex_len or not HEXDIGIT_CHARS.issuperset(hex_str):
-        raise suffixed_err(src, pos, "Invalid hex value")
+        raise TOMLDecodeError("Invalid hex value", src, pos)
     pos += hex_len
     hex_int = int(hex_str, 16)
     if not is_unicode_scalar_value(hex_int):
-        raise suffixed_err(src, pos, "Escaped character is not a Unicode scalar value")
+        raise TOMLDecodeError(
+            "Escaped character is not a Unicode scalar value", src, pos
+        )
     return pos, chr(hex_int)
 
 
@@ -567,7 +634,7 @@ def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> tuple[Pos, str]:
         try:
             char = src[pos]
         except IndexError:
-            raise suffixed_err(src, pos, "Unterminated string") from None
+            raise TOMLDecodeError("Unterminated string", src, pos) from None
         if char == '"':
             if not multiline:
                 return pos + 1, result + src[start_pos:pos]
@@ -582,7 +649,7 @@ def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> tuple[Pos, str]:
             start_pos = pos
             continue
         if char in error_on:
-            raise suffixed_err(src, pos, f"Illegal character {char!r}")
+            raise TOMLDecodeError(f"Illegal character {char!r}", src, pos)
         pos += 1
 
 
@@ -630,7 +697,7 @@ def parse_value(  # noqa: C901
         try:
             datetime_obj = match_to_datetime(datetime_match)
         except ValueError as e:
-            raise suffixed_err(src, pos, "Invalid date or datetime") from e
+            raise TOMLDecodeError("Invalid date or datetime", src, pos) from e
         return datetime_match.end(), datetime_obj
     localtime_match = RE_LOCALTIME.match(src, pos)
     if localtime_match:
@@ -651,24 +718,7 @@ def parse_value(  # noqa: C901
     if first_four in {"-inf", "+inf", "-nan", "+nan"}:
         return pos + 4, parse_float(first_four)
 
-    raise suffixed_err(src, pos, "Invalid value")
-
-
-def suffixed_err(src: str, pos: Pos, msg: str) -> TOMLDecodeError:
-    """Return a `TOMLDecodeError` where error message is suffixed with
-    coordinates in source."""
-
-    def coord_repr(src: str, pos: Pos) -> str:
-        if pos >= len(src):
-            return "end of document"
-        line = src.count("\n", 0, pos) + 1
-        if line == 1:
-            column = pos + 1
-        else:
-            column = pos - src.rindex("\n", 0, pos)
-        return f"line {line}, column {column}"
-
-    return TOMLDecodeError(f"{msg} (at {coord_repr(src, pos)})")
+    raise TOMLDecodeError("Invalid value", src, pos)
 
 
 def is_unicode_scalar_value(codepoint: int) -> bool:
diff --git a/Misc/NEWS.d/next/Library/2024-11-05-09-54-49.gh-issue-126175.spnjJr.rst b/Misc/NEWS.d/next/Library/2024-11-05-09-54-49.gh-issue-126175.spnjJr.rst
new file mode 100644 (file)
index 0000000..de7ce88
--- /dev/null
@@ -0,0 +1,2 @@
+Add ``msg``, ``doc``, ``pos``, ``lineno`` and ``colno`` attributes to :exc:`tomllib.TOMLDecodeError`.
+Deprecate instantiating with free-form arguments.