]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-104683: Rework Argument Clinic error handling (#107551)
authorErlend E. Aasland <erlend@python.org>
Thu, 3 Aug 2023 00:00:06 +0000 (02:00 +0200)
committerGitHub <noreply@github.com>
Thu, 3 Aug 2023 00:00:06 +0000 (00:00 +0000)
Introduce ClinicError, and use it in fail(). The CLI runs main(),
catches ClinicError, formats the error message, prints to stderr
and exits with an error.

As a side effect, this refactor greatly improves the accuracy of
reported line numbers in case of error.

Also, adapt the test suite to work with ClinicError.

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Lib/test/test_clinic.py
Tools/clinic/clinic.py

index cdabcbaa6f03ca3659c546c4f10bf90770ab5d1a..127008d443e4c63b06d80250abfba44d0bba14ec 100644 (file)
@@ -11,6 +11,7 @@ import collections
 import contextlib
 import inspect
 import os.path
+import re
 import sys
 import unittest
 
@@ -20,17 +21,24 @@ with test_tools.imports_under_tool('clinic'):
     from clinic import DSLParser
 
 
-class _ParserBase(TestCase):
-    maxDiff = None
-
-    def expect_parser_failure(self, parser, _input):
-        with support.captured_stdout() as stdout:
-            with self.assertRaises(SystemExit):
-                parser(_input)
-        return stdout.getvalue()
+def _expect_failure(tc, parser, code, errmsg, *, filename=None, lineno=None):
+    """Helper for the parser tests.
 
-    def parse_function_should_fail(self, _input):
-        return self.expect_parser_failure(self.parse_function, _input)
+    tc: unittest.TestCase; passed self in the wrapper
+    parser: the clinic parser used for this test case
+    code: a str with input text (clinic code)
+    errmsg: the expected error message
+    filename: str, optional filename
+    lineno: int, optional line number
+    """
+    code = dedent(code).strip()
+    errmsg = re.escape(errmsg)
+    with tc.assertRaisesRegex(clinic.ClinicError, errmsg) as cm:
+        parser(code)
+    if filename is not None:
+        tc.assertEqual(cm.exception.filename, filename)
+    if lineno is not None:
+        tc.assertEqual(cm.exception.lineno, lineno)
 
 
 class FakeConverter:
@@ -108,14 +116,15 @@ class FakeClinic:
         return "<FakeClinic object>"
 
 
-class ClinicWholeFileTest(_ParserBase):
+class ClinicWholeFileTest(TestCase):
+
+    def expect_failure(self, raw, errmsg, *, filename=None, lineno=None):
+        _expect_failure(self, self.clinic.parse, raw, errmsg,
+                        filename=filename, lineno=lineno)
+
     def setUp(self):
         self.clinic = clinic.Clinic(clinic.CLanguage(None), filename="test.c")
 
-    def expect_failure(self, raw):
-        _input = dedent(raw).strip()
-        return self.expect_parser_failure(self.clinic.parse, _input)
-
     def test_eol(self):
         # regression test:
         # clinic's block parser didn't recognize
@@ -139,12 +148,11 @@ class ClinicWholeFileTest(_ParserBase):
             [clinic start generated code]*/
             /*[clinic end generated code: foo]*/
         """
-        msg = (
-            'Error in file "test.c" on line 3:\n'
-            "Mangled Argument Clinic marker line: '/*[clinic end generated code: foo]*/'\n"
+        err = (
+            "Mangled Argument Clinic marker line: "
+            "'/*[clinic end generated code: foo]*/'"
         )
-        out = self.expect_failure(raw)
-        self.assertEqual(out, msg)
+        self.expect_failure(raw, err, filename="test.c", lineno=3)
 
     def test_checksum_mismatch(self):
         raw = """
@@ -152,38 +160,31 @@ class ClinicWholeFileTest(_ParserBase):
             [clinic start generated code]*/
             /*[clinic end generated code: output=0123456789abcdef input=fedcba9876543210]*/
         """
-        msg = (
-            'Error in file "test.c" on line 3:\n'
+        err = (
             'Checksum mismatch!\n'
             'Expected: 0123456789abcdef\n'
             'Computed: da39a3ee5e6b4b0d\n'
         )
-        out = self.expect_failure(raw)
-        self.assertIn(msg, out)
+        self.expect_failure(raw, err, filename="test.c", lineno=3)
 
     def test_garbage_after_stop_line(self):
         raw = """
             /*[clinic input]
             [clinic start generated code]*/foobarfoobar!
         """
-        msg = (
-            'Error in file "test.c" on line 2:\n'
-            "Garbage after stop line: 'foobarfoobar!'\n"
-        )
-        out = self.expect_failure(raw)
-        self.assertEqual(out, msg)
+        err = "Garbage after stop line: 'foobarfoobar!'"
+        self.expect_failure(raw, err, filename="test.c", lineno=2)
 
     def test_whitespace_before_stop_line(self):
         raw = """
             /*[clinic input]
              [clinic start generated code]*/
         """
-        msg = (
-            'Error in file "test.c" on line 2:\n'
-            "Whitespace is not allowed before the stop line: ' [clinic start generated code]*/'\n"
+        err = (
+            "Whitespace is not allowed before the stop line: "
+            "' [clinic start generated code]*/'"
         )
-        out = self.expect_failure(raw)
-        self.assertEqual(out, msg)
+        self.expect_failure(raw, err, filename="test.c", lineno=2)
 
     def test_parse_with_body_prefix(self):
         clang = clinic.CLanguage(None)
@@ -213,12 +214,8 @@ class ClinicWholeFileTest(_ParserBase):
             */
             */
         """
-        msg = (
-            'Error in file "test.c" on line 2:\n'
-            'Nested block comment!\n'
-        )
-        out = self.expect_failure(raw)
-        self.assertEqual(out, msg)
+        err = 'Nested block comment!'
+        self.expect_failure(raw, err, filename="test.c", lineno=2)
 
     def test_cpp_monitor_fail_invalid_format_noarg(self):
         raw = """
@@ -226,12 +223,8 @@ class ClinicWholeFileTest(_ParserBase):
             a()
             #endif
         """
-        msg = (
-            'Error in file "test.c" on line 1:\n'
-            'Invalid format for #if line: no argument!\n'
-        )
-        out = self.expect_failure(raw)
-        self.assertEqual(out, msg)
+        err = 'Invalid format for #if line: no argument!'
+        self.expect_failure(raw, err, filename="test.c", lineno=1)
 
     def test_cpp_monitor_fail_invalid_format_toomanyargs(self):
         raw = """
@@ -239,39 +232,31 @@ class ClinicWholeFileTest(_ParserBase):
             a()
             #endif
         """
-        msg = (
-            'Error in file "test.c" on line 1:\n'
-            'Invalid format for #ifdef line: should be exactly one argument!\n'
-        )
-        out = self.expect_failure(raw)
-        self.assertEqual(out, msg)
+        err = 'Invalid format for #ifdef line: should be exactly one argument!'
+        self.expect_failure(raw, err, filename="test.c", lineno=1)
 
     def test_cpp_monitor_fail_no_matching_if(self):
         raw = '#else'
-        msg = (
-            'Error in file "test.c" on line 1:\n'
-            '#else without matching #if / #ifdef / #ifndef!\n'
-        )
-        out = self.expect_failure(raw)
-        self.assertEqual(out, msg)
+        err = '#else without matching #if / #ifdef / #ifndef!'
+        self.expect_failure(raw, err, filename="test.c", lineno=1)
 
     def test_directive_output_unknown_preset(self):
-        out = self.expect_failure("""
+        raw = """
             /*[clinic input]
             output preset nosuchpreset
             [clinic start generated code]*/
-        """)
-        msg = "Unknown preset 'nosuchpreset'"
-        self.assertIn(msg, out)
+        """
+        err = "Unknown preset 'nosuchpreset'"
+        self.expect_failure(raw, err)
 
     def test_directive_output_cant_pop(self):
-        out = self.expect_failure("""
+        raw = """
             /*[clinic input]
             output pop
             [clinic start generated code]*/
-        """)
-        msg = "Can't 'output pop', stack is empty"
-        self.assertIn(msg, out)
+        """
+        err = "Can't 'output pop', stack is empty"
+        self.expect_failure(raw, err)
 
     def test_directive_output_print(self):
         raw = dedent("""
@@ -309,16 +294,16 @@ class ClinicWholeFileTest(_ParserBase):
         )
 
     def test_unknown_destination_command(self):
-        out = self.expect_failure("""
+        raw = """
             /*[clinic input]
             destination buffer nosuchcommand
             [clinic start generated code]*/
-        """)
-        msg = "unknown destination command 'nosuchcommand'"
-        self.assertIn(msg, out)
+        """
+        err = "unknown destination command 'nosuchcommand'"
+        self.expect_failure(raw, err)
 
     def test_no_access_to_members_in_converter_init(self):
-        out = self.expect_failure("""
+        raw = """
             /*[python input]
             class Custom_converter(CConverter):
                 converter = "some_c_function"
@@ -330,11 +315,11 @@ class ClinicWholeFileTest(_ParserBase):
             test.fn
                 a: Custom
             [clinic start generated code]*/
-        """)
-        msg = (
+        """
+        err = (
             "accessing self.function inside converter_init is disallowed!"
         )
-        self.assertIn(msg, out)
+        self.expect_failure(raw, err)
 
     @staticmethod
     @contextlib.contextmanager
@@ -375,30 +360,30 @@ class ClinicWholeFileTest(_ParserBase):
                 "  Version: 4\n"
                 "  Required: 5"
             )
-            out = self.expect_failure("""
+            block = """
                 /*[clinic input]
                 version 5
                 [clinic start generated code]*/
-            """)
-            self.assertIn(err, out)
+            """
+            self.expect_failure(block, err)
 
     def test_version_directive_illegal_char(self):
         err = "Illegal character 'v' in version string 'v5'"
-        out = self.expect_failure("""
+        block = """
             /*[clinic input]
             version v5
             [clinic start generated code]*/
-        """)
-        self.assertIn(err, out)
+        """
+        self.expect_failure(block, err)
 
     def test_version_directive_unsupported_string(self):
         err = "Unsupported version string: '.-'"
-        out = self.expect_failure("""
+        block = """
             /*[clinic input]
             version .-
             [clinic start generated code]*/
-        """)
-        self.assertIn(err, out)
+        """
+        self.expect_failure(block, err)
 
 
 class ClinicGroupPermuterTest(TestCase):
@@ -577,7 +562,7 @@ xyz
 """)
 
 
-class ClinicParserTest(_ParserBase):
+class ClinicParserTest(TestCase):
 
     def parse(self, text):
         c = FakeClinic()
@@ -594,6 +579,10 @@ class ClinicParserTest(_ParserBase):
         assert isinstance(s[function_index], clinic.Function)
         return s[function_index]
 
+    def expect_failure(self, block, err, *, filename=None, lineno=None):
+        _expect_failure(self, self.parse_function, block, err,
+                        filename=filename, lineno=lineno)
+
     def checkDocstring(self, fn, expected):
         self.assertTrue(hasattr(fn, "docstring"))
         self.assertEqual(fn.docstring.strip(),
@@ -663,17 +652,16 @@ class ClinicParserTest(_ParserBase):
         self.assertEqual(sys.maxsize, p.default)
         self.assertEqual("MAXSIZE", p.converter.c_default)
 
-        expected_msg = (
-            "Error on line 0:\n"
+        err = (
             "When you specify a named constant ('sys.maxsize') as your default value,\n"
-            "you MUST specify a valid c_default.\n"
+            "you MUST specify a valid c_default."
         )
-        out = self.parse_function_should_fail("""
+        block = """
             module os
             os.access
                 follow_symlinks: int = sys.maxsize
-        """)
-        self.assertEqual(out, expected_msg)
+        """
+        self.expect_failure(block, err, lineno=2)
 
     def test_param_no_docstring(self):
         function = self.parse_function("""
@@ -688,17 +676,17 @@ class ClinicParserTest(_ParserBase):
         self.assertIsInstance(conv, clinic.str_converter)
 
     def test_param_default_parameters_out_of_order(self):
-        expected_msg = (
-            "Error on line 0:\n"
+        err = (
             "Can't have a parameter without a default ('something_else')\n"
-            "after a parameter with a default!\n"
+            "after a parameter with a default!"
         )
-        out = self.parse_function_should_fail("""
+        block = """
             module os
             os.access
                 follow_symlinks: bool = True
-                something_else: str""")
-        self.assertEqual(out, expected_msg)
+                something_else: str
+        """
+        self.expect_failure(block, err, lineno=3)
 
     def disabled_test_converter_arguments(self):
         function = self.parse_function("""
@@ -797,17 +785,19 @@ class ClinicParserTest(_ParserBase):
         self.assertEqual("os_stat_fn", function.c_basename)
 
     def test_cloning_nonexistent_function_correctly_fails(self):
-        stdout = self.parse_function_should_fail("""
-            cloned = fooooooooooooooooooooooo
+        block = """
+            cloned = fooooooooooooooooo
             This is trying to clone a nonexistent function!!
+        """
+        err = "Couldn't find existing function 'fooooooooooooooooo'!"
+        with support.captured_stderr() as stderr:
+            self.expect_failure(block, err, lineno=0)
+        expected_debug_print = dedent("""\
+            cls=None, module=<FakeClinic object>, existing='fooooooooooooooooo'
+            (cls or module).functions=[]
         """)
-        expected_error = """\
-cls=None, module=<FakeClinic object>, existing='fooooooooooooooooooooooo'
-(cls or module).functions=[]
-Error on line 0:
-Couldn't find existing function 'fooooooooooooooooooooooo'!
-"""
-        self.assertEqual(expected_error, stdout)
+        stderr = stderr.getvalue()
+        self.assertIn(expected_debug_print, stderr)
 
     def test_return_converter(self):
         function = self.parse_function("""
@@ -817,30 +807,28 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
         self.assertIsInstance(function.return_converter, clinic.int_return_converter)
 
     def test_return_converter_invalid_syntax(self):
-        stdout = self.parse_function_should_fail("""
+        block = """
             module os
             os.stat -> invalid syntax
-        """)
-        expected_error = "Badly formed annotation for os.stat: 'invalid syntax'"
-        self.assertIn(expected_error, stdout)
+        """
+        err = "Badly formed annotation for os.stat: 'invalid syntax'"
+        self.expect_failure(block, err)
 
     def test_legacy_converter_disallowed_in_return_annotation(self):
-        stdout = self.parse_function_should_fail("""
+        block = """
             module os
             os.stat -> "s"
-        """)
-        expected_error = "Legacy converter 's' not allowed as a return converter"
-        self.assertIn(expected_error, stdout)
+        """
+        err = "Legacy converter 's' not allowed as a return converter"
+        self.expect_failure(block, err)
 
     def test_unknown_return_converter(self):
-        stdout = self.parse_function_should_fail("""
+        block = """
             module os
-            os.stat -> foooooooooooooooooooooooo
-        """)
-        expected_error = (
-            "No available return converter called 'foooooooooooooooooooooooo'"
-        )
-        self.assertIn(expected_error, stdout)
+            os.stat -> fooooooooooooooooooooooo
+        """
+        err = "No available return converter called 'fooooooooooooooooooooooo'"
+        self.expect_failure(block, err)
 
     def test_star(self):
         function = self.parse_function("""
@@ -985,19 +973,12 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
                 Attributes for the character.
         """)
 
-    def parse_function_should_fail(self, s):
-        with support.captured_stdout() as stdout:
-            with self.assertRaises(SystemExit):
-                self.parse_function(s)
-        return stdout.getvalue()
-
     def test_disallowed_grouping__two_top_groups_on_left(self):
-        expected_msg = (
-            'Error on line 0:\n'
+        err = (
             'Function two_top_groups_on_left has an unsupported group '
-            'configuration. (Unexpected state 2.b)\n'
+            'configuration. (Unexpected state 2.b)'
         )
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.two_top_groups_on_left
                 [
@@ -1007,11 +988,11 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
                 group2 : int
                 ]
                 param: int
-        """)
-        self.assertEqual(out, expected_msg)
+        """
+        self.expect_failure(block, err, lineno=5)
 
     def test_disallowed_grouping__two_top_groups_on_right(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.two_top_groups_on_right
                 param: int
@@ -1021,15 +1002,15 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
                 [
                 group2 : int
                 ]
-        """)
-        msg = (
+        """
+        err = (
             "Function two_top_groups_on_right has an unsupported group "
             "configuration. (Unexpected state 6.b)"
         )
-        self.assertIn(msg, out)
+        self.expect_failure(block, err)
 
     def test_disallowed_grouping__parameter_after_group_on_right(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.parameter_after_group_on_right
                 param: int
@@ -1039,15 +1020,15 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
                 ]
                 group2 : int
                 ]
-        """)
-        msg = (
+        """
+        err = (
             "Function parameter_after_group_on_right has an unsupported group "
             "configuration. (Unexpected state 6.a)"
         )
-        self.assertIn(msg, out)
+        self.expect_failure(block, err)
 
     def test_disallowed_grouping__group_after_parameter_on_left(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.group_after_parameter_on_left
                 [
@@ -1057,15 +1038,15 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
                 ]
                 ]
                 param: int
-        """)
-        msg = (
+        """
+        err = (
             "Function group_after_parameter_on_left has an unsupported group "
             "configuration. (Unexpected state 2.b)"
         )
-        self.assertIn(msg, out)
+        self.expect_failure(block, err)
 
     def test_disallowed_grouping__empty_group_on_left(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.empty_group
                 [
@@ -1074,15 +1055,15 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
                 group2 : int
                 ]
                 param: int
-        """)
-        msg = (
+        """
+        err = (
             "Function empty_group has an empty group.\n"
             "All groups must contain at least one parameter."
         )
-        self.assertIn(msg, out)
+        self.expect_failure(block, err)
 
     def test_disallowed_grouping__empty_group_on_right(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.empty_group
                 param: int
@@ -1091,24 +1072,24 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
                 ]
                 group2 : int
                 ]
-        """)
-        msg = (
+        """
+        err = (
             "Function empty_group has an empty group.\n"
             "All groups must contain at least one parameter."
         )
-        self.assertIn(msg, out)
+        self.expect_failure(block, err)
 
     def test_disallowed_grouping__no_matching_bracket(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.empty_group
                 param: int
                 ]
                 group2: int
                 ]
-        """)
-        msg = "Function empty_group has a ] without a matching [."
-        self.assertIn(msg, out)
+        """
+        err = "Function empty_group has a ] without a matching [."
+        self.expect_failure(block, err)
 
     def test_no_parameters(self):
         function = self.parse_function("""
@@ -1137,31 +1118,32 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
         self.assertEqual(1, len(function.parameters))
 
     def test_illegal_module_line(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.bar => int
                 /
-        """)
-        msg = "Illegal function name: foo.bar => int"
-        self.assertIn(msg, out)
+        """
+        err = "Illegal function name: foo.bar => int"
+        self.expect_failure(block, err)
 
     def test_illegal_c_basename(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.bar as 935
                 /
-        """)
-        msg = "Illegal C basename: 935"
-        self.assertIn(msg, out)
+        """
+        err = "Illegal C basename: 935"
+        self.expect_failure(block, err)
 
     def test_single_star(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.bar
                 *
                 *
-        """)
-        self.assertIn("Function bar uses '*' more than once.", out)
+        """
+        err = "Function bar uses '*' more than once."
+        self.expect_failure(block, err)
 
     def test_parameters_required_after_star(self):
         dataset = (
@@ -1170,39 +1152,38 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
             "module foo\nfoo.bar\n  this: int\n  *",
             "module foo\nfoo.bar\n  this: int\n  *\nDocstring.",
         )
-        msg = "Function bar specifies '*' without any parameters afterwards."
+        err = "Function bar specifies '*' without any parameters afterwards."
         for block in dataset:
             with self.subTest(block=block):
-                out = self.parse_function_should_fail(block)
-                self.assertIn(msg, out)
+                self.expect_failure(block, err)
 
     def test_single_slash(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.bar
                 /
                 /
-        """)
-        msg = (
+        """
+        err = (
             "Function bar has an unsupported group configuration. "
             "(Unexpected state 0.d)"
         )
-        self.assertIn(msg, out)
+        self.expect_failure(block, err)
 
     def test_double_slash(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.bar
                 a: int
                 /
                 b: int
                 /
-        """)
-        msg = "Function bar uses '/' more than once."
-        self.assertIn(msg, out)
+        """
+        err = "Function bar uses '/' more than once."
+        self.expect_failure(block, err)
 
     def test_mix_star_and_slash(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.bar
                x: int
@@ -1210,38 +1191,35 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
                *
                z: int
                /
-        """)
-        msg = (
+        """
+        err = (
             "Function bar mixes keyword-only and positional-only parameters, "
             "which is unsupported."
         )
-        self.assertIn(msg, out)
+        self.expect_failure(block, err)
 
     def test_parameters_not_permitted_after_slash_for_now(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.bar
                 /
                 x: int
-        """)
-        msg = (
+        """
+        err = (
             "Function bar has an unsupported group configuration. "
             "(Unexpected state 0.d)"
         )
-        self.assertIn(msg, out)
+        self.expect_failure(block, err)
 
     def test_parameters_no_more_than_one_vararg(self):
-        expected_msg = (
-            "Error on line 0:\n"
-            "Too many var args\n"
-        )
-        out = self.parse_function_should_fail("""
+        err = "Too many var args"
+        block = """
             module foo
             foo.bar
                *vararg1: object
                *vararg2: object
-        """)
-        self.assertEqual(out, expected_msg)
+        """
+        self.expect_failure(block, err, lineno=0)
 
     def test_function_not_at_column_0(self):
         function = self.parse_function("""
@@ -1264,23 +1242,24 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
         """)
 
     def test_indent_stack_no_tabs(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.bar
                *vararg1: object
             \t*vararg2: object
-        """)
-        msg = "Tab characters are illegal in the Clinic DSL."
-        self.assertIn(msg, out)
+        """
+        err = "Tab characters are illegal in the Clinic DSL."
+        self.expect_failure(block, err)
 
     def test_indent_stack_illegal_outdent(self):
-        out = self.parse_function_should_fail("""
+        block = """
             module foo
             foo.bar
               a: object
              b: object
-        """)
-        self.assertIn("Illegal outdent", out)
+        """
+        err = "Illegal outdent"
+        self.expect_failure(block, err)
 
     def test_directive(self):
         c = FakeClinic()
@@ -1298,10 +1277,7 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
         self.assertIsInstance(conv, clinic.str_converter)
 
     def test_legacy_converters_non_string_constant_annotation(self):
-        expected_failure_message = (
-            "Error on line 0:\n"
-            "Annotations must be either a name, a function call, or a string.\n"
-        )
+        err = "Annotations must be either a name, a function call, or a string"
         dataset = (
             'module os\nos.access\n   path: 42',
             'module os\nos.access\n   path: 42.42',
@@ -1310,14 +1286,10 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
         )
         for block in dataset:
             with self.subTest(block=block):
-                out = self.parse_function_should_fail(block)
-                self.assertEqual(out, expected_failure_message)
+                self.expect_failure(block, err, lineno=2)
 
     def test_other_bizarre_things_in_annotations_fail(self):
-        expected_failure_message = (
-            "Error on line 0:\n"
-            "Annotations must be either a name, a function call, or a string.\n"
-        )
+        err = "Annotations must be either a name, a function call, or a string"
         dataset = (
             'module os\nos.access\n   path: {"some": "dictionary"}',
             'module os\nos.access\n   path: ["list", "of", "strings"]',
@@ -1325,30 +1297,24 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
         )
         for block in dataset:
             with self.subTest(block=block):
-                out = self.parse_function_should_fail(block)
-                self.assertEqual(out, expected_failure_message)
+                self.expect_failure(block, err, lineno=2)
 
     def test_kwarg_splats_disallowed_in_function_call_annotations(self):
-        expected_error_msg = (
-            "Error on line 0:\n"
-            "Cannot use a kwarg splat in a function-call annotation\n"
-        )
+        err = "Cannot use a kwarg splat in a function-call annotation"
         dataset = (
             'module fo\nfo.barbaz\n   o: bool(**{None: "bang!"})',
             'module fo\nfo.barbaz -> bool(**{None: "bang!"})',
             'module fo\nfo.barbaz -> bool(**{"bang": 42})',
             'module fo\nfo.barbaz\n   o: bool(**{"bang": None})',
         )
-        for fn in dataset:
-            with self.subTest(fn=fn):
-                out = self.parse_function_should_fail(fn)
-                self.assertEqual(out, expected_error_msg)
+        for block in dataset:
+            with self.subTest(block=block):
+                self.expect_failure(block, err)
 
     def test_self_param_placement(self):
-        expected_error_msg = (
-            "Error on line 0:\n"
+        err = (
             "A 'self' parameter, if specified, must be the very first thing "
-            "in the parameter block.\n"
+            "in the parameter block."
         )
         block = """
             module foo
@@ -1356,27 +1322,21 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
                 a: int
                 self: self(type="PyObject *")
         """
-        out = self.parse_function_should_fail(block)
-        self.assertEqual(out, expected_error_msg)
+        self.expect_failure(block, err, lineno=3)
 
     def test_self_param_cannot_be_optional(self):
-        expected_error_msg = (
-            "Error on line 0:\n"
-            "A 'self' parameter cannot be marked optional.\n"
-        )
+        err = "A 'self' parameter cannot be marked optional."
         block = """
             module foo
             foo.func
                 self: self(type="PyObject *") = None
         """
-        out = self.parse_function_should_fail(block)
-        self.assertEqual(out, expected_error_msg)
+        self.expect_failure(block, err, lineno=2)
 
     def test_defining_class_param_placement(self):
-        expected_error_msg = (
-            "Error on line 0:\n"
+        err = (
             "A 'defining_class' parameter, if specified, must either be the "
-            "first thing in the parameter block, or come just after 'self'.\n"
+            "first thing in the parameter block, or come just after 'self'."
         )
         block = """
             module foo
@@ -1385,21 +1345,16 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
                 a: int
                 cls: defining_class
         """
-        out = self.parse_function_should_fail(block)
-        self.assertEqual(out, expected_error_msg)
+        self.expect_failure(block, err, lineno=4)
 
     def test_defining_class_param_cannot_be_optional(self):
-        expected_error_msg = (
-            "Error on line 0:\n"
-            "A 'defining_class' parameter cannot be marked optional.\n"
-        )
+        err = "A 'defining_class' parameter cannot be marked optional."
         block = """
             module foo
             foo.func
                 cls: defining_class(type="PyObject *") = None
         """
-        out = self.parse_function_should_fail(block)
-        self.assertEqual(out, expected_error_msg)
+        self.expect_failure(block, err, lineno=2)
 
     def test_slot_methods_cannot_access_defining_class(self):
         block = """
@@ -1409,34 +1364,28 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
                 cls: defining_class
                 a: object
         """
-        msg = "Slot methods cannot access their defining class."
-        with self.assertRaisesRegex(ValueError, msg):
+        err = "Slot methods cannot access their defining class."
+        with self.assertRaisesRegex(ValueError, err):
             self.parse_function(block)
 
     def test_new_must_be_a_class_method(self):
-        expected_error_msg = (
-            "Error on line 0:\n"
-            "__new__ must be a class method!\n"
-        )
-        out = self.parse_function_should_fail("""
+        err = "__new__ must be a class method!"
+        block = """
             module foo
             class Foo "" ""
             Foo.__new__
-        """)
-        self.assertEqual(out, expected_error_msg)
+        """
+        self.expect_failure(block, err, lineno=2)
 
     def test_init_must_be_a_normal_method(self):
-        expected_error_msg = (
-            "Error on line 0:\n"
-            "__init__ must be a normal method, not a class or static method!\n"
-        )
-        out = self.parse_function_should_fail("""
+        err = "__init__ must be a normal method, not a class or static method!"
+        block = """
             module foo
             class Foo "" ""
             @classmethod
             Foo.__init__
-        """)
-        self.assertEqual(out, expected_error_msg)
+        """
+        self.expect_failure(block, err, lineno=3)
 
     def test_unused_param(self):
         block = self.parse("""
@@ -1487,11 +1436,12 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
             'The igloos are melting!\n'
         )
         with support.captured_stdout() as stdout:
-            with self.assertRaises(SystemExit):
-                clinic.fail('The igloos are melting!',
-                            filename='clown.txt', line_number=69)
-        actual = stdout.getvalue()
-        self.assertEqual(actual, expected)
+            errmsg = 'The igloos are melting'
+            with self.assertRaisesRegex(clinic.ClinicError, errmsg) as cm:
+                clinic.fail(errmsg, filename='clown.txt', line_number=69)
+            exc = cm.exception
+            self.assertEqual(exc.filename, 'clown.txt')
+            self.assertEqual(exc.lineno, 69)
 
     def test_non_ascii_character_in_docstring(self):
         block = """
@@ -1507,19 +1457,21 @@ Couldn't find existing function 'fooooooooooooooooooooooo'!
         expected = dedent("""\
             Warning on line 0:
             Non-ascii characters are not allowed in docstrings: 'á'
+
             Warning on line 0:
             Non-ascii characters are not allowed in docstrings: 'ü', 'á', 'ß'
+
         """)
         self.assertEqual(stdout.getvalue(), expected)
 
     def test_illegal_c_identifier(self):
         err = "Illegal C identifier: 17a"
-        out = self.parse_function_should_fail("""
+        block = """
             module test
             test.fn
                 a as 17a: int
-        """)
-        self.assertIn(err, out)
+        """
+        self.expect_failure(block, err)
 
 
 class ClinicExternalTest(TestCase):
@@ -1607,9 +1559,9 @@ class ClinicExternalTest(TestCase):
             # First, run the CLI without -f and expect failure.
             # Note, we cannot check the entire fail msg, because the path to
             # the tmp file will change for every run.
-            out, _ = self.expect_failure(fn)
-            self.assertTrue(out.endswith(fail_msg),
-                            f"{out!r} does not end with {fail_msg!r}")
+            _, err = self.expect_failure(fn)
+            self.assertTrue(err.endswith(fail_msg),
+                            f"{err!r} does not end with {fail_msg!r}")
             # Then, force regeneration; success expected.
             out = self.expect_success("-f", fn)
             self.assertEqual(out, "")
@@ -2231,8 +2183,11 @@ class ClinicFunctionalTest(unittest.TestCase):
         self.assertEqual(arg_refcount_origin, arg_refcount_after)
 
     def test_gh_99240_double_free(self):
-        expected_error = r'gh_99240_double_free\(\) argument 2 must be encoded string without null bytes, not str'
-        with self.assertRaisesRegex(TypeError, expected_error):
+        err = re.escape(
+            "gh_99240_double_free() argument 2 must be encoded string "
+            "without null bytes, not str"
+        )
+        with self.assertRaisesRegex(TypeError, err):
             ac_tester.gh_99240_double_free('a', '\0b')
 
     def test_cloned_func_exception_message(self):
index ce8184753c72922aa7960a95a82b8841a551b5d8..1bcdb6b1c3640a70993cac8ad2f3d682c97e7f26 100755 (executable)
@@ -28,7 +28,6 @@ import shlex
 import string
 import sys
 import textwrap
-import traceback
 
 from collections.abc import (
     Callable,
@@ -137,6 +136,28 @@ def text_accumulator() -> TextAccumulator:
     text, append, output = _text_accumulator()
     return TextAccumulator(append, output)
 
+
+@dc.dataclass
+class ClinicError(Exception):
+    message: str
+    _: dc.KW_ONLY
+    lineno: int | None = None
+    filename: str | None = None
+
+    def __post_init__(self) -> None:
+        super().__init__(self.message)
+
+    def report(self, *, warn_only: bool = False) -> str:
+        msg = "Warning" if warn_only else "Error"
+        if self.filename is not None:
+            msg += f" in file {self.filename!r}"
+        if self.lineno is not None:
+            msg += f" on line {self.lineno}"
+        msg += ":\n"
+        msg += f"{self.message}\n"
+        return msg
+
+
 @overload
 def warn_or_fail(
     *args: object,
@@ -160,25 +181,16 @@ def warn_or_fail(
     line_number: int | None = None,
 ) -> None:
     joined = " ".join([str(a) for a in args])
-    add, output = text_accumulator()
-    if fail:
-        add("Error")
-    else:
-        add("Warning")
     if clinic:
         if filename is None:
             filename = clinic.filename
         if getattr(clinic, 'block_parser', None) and (line_number is None):
             line_number = clinic.block_parser.line_number
-    if filename is not None:
-        add(' in file "' + filename + '"')
-    if line_number is not None:
-        add(" on line " + str(line_number))
-    add(':\n')
-    add(joined)
-    print(output())
+    error = ClinicError(joined, filename=filename, lineno=line_number)
     if fail:
-        sys.exit(-1)
+        raise error
+    else:
+        print(error.report(warn_only=True))
 
 
 def warn(
@@ -347,7 +359,7 @@ def version_splitter(s: str) -> tuple[int, ...]:
     accumulator: list[str] = []
     def flush() -> None:
         if not accumulator:
-            raise ValueError('Unsupported version string: ' + repr(s))
+            fail(f'Unsupported version string: {s!r}')
         version.append(int(''.join(accumulator)))
         accumulator.clear()
 
@@ -360,7 +372,7 @@ def version_splitter(s: str) -> tuple[int, ...]:
             flush()
             version.append('abc'.index(c) - 3)
         else:
-            raise ValueError('Illegal character ' + repr(c) + ' in version string ' + repr(s))
+            fail(f'Illegal character {c!r} in version string {s!r}')
     flush()
     return tuple(version)
 
@@ -2233,11 +2245,7 @@ impl_definition block
                     assert dsl_name in parsers, f"No parser to handle {dsl_name!r} block."
                     self.parsers[dsl_name] = parsers[dsl_name](self)
                 parser = self.parsers[dsl_name]
-                try:
-                    parser.parse(block)
-                except Exception:
-                    fail('Exception raised during parsing:\n' +
-                         traceback.format_exc().rstrip())
+                parser.parse(block)
             printer.print_block(block)
 
         # these are destinations not buffers
@@ -4600,7 +4608,11 @@ class DSLParser:
         for line_number, line in enumerate(lines, self.clinic.block_parser.block_start_line_number):
             if '\t' in line:
                 fail('Tab characters are illegal in the Clinic DSL.\n\t' + repr(line), line_number=block_start)
-            self.state(line)
+            try:
+                self.state(line)
+            except ClinicError as exc:
+                exc.lineno = line_number
+                raise
 
         self.do_post_block_processing_cleanup()
         block.output.extend(self.clinic.language.render(self.clinic, block.signatures))
@@ -4701,8 +4713,8 @@ class DSLParser:
                     if existing_function.name == function_name:
                         break
                 else:
-                    print(f"{cls=}, {module=}, {existing=}")
-                    print(f"{(cls or module).functions=}")
+                    print(f"{cls=}, {module=}, {existing=}", file=sys.stderr)
+                    print(f"{(cls or module).functions=}", file=sys.stderr)
                     fail(f"Couldn't find existing function {existing!r}!")
 
                 fields = [x.strip() for x in full_name.split('.')]
@@ -5719,8 +5731,13 @@ def run_clinic(parser: argparse.ArgumentParser, ns: argparse.Namespace) -> None:
 def main(argv: list[str] | None = None) -> NoReturn:
     parser = create_cli()
     args = parser.parse_args(argv)
-    run_clinic(parser, args)
-    sys.exit(0)
+    try:
+        run_clinic(parser, args)
+    except ClinicError as exc:
+        sys.stderr.write(exc.report())
+        sys.exit(1)
+    else:
+        sys.exit(0)
 
 
 if __name__ == "__main__":