]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-135256: Simplify parsing parameters in Argument Clinic (GH-135257)
authorSerhiy Storchaka <storchaka@gmail.com>
Sun, 13 Jul 2025 20:27:48 +0000 (23:27 +0300)
committerGitHub <noreply@github.com>
Sun, 13 Jul 2025 20:27:48 +0000 (23:27 +0300)
Lib/test/test_clinic.py
Tools/clinic/libclinic/dsl_parser.py

index a6ed887999242f2b79a0e2652034c7509e70e8a7..e83114794519d5c6e1d071f17013f4434af47988 100644 (file)
@@ -1277,12 +1277,8 @@ class ClinicParserTest(TestCase):
             os.stat
                 invalid syntax: int = 42
         """
-        err = dedent(r"""
-            Function 'stat' has an invalid parameter declaration:
-            \s+'invalid syntax: int = 42'
-        """).strip()
-        with self.assertRaisesRegex(ClinicError, err):
-            self.parse_function(block)
+        err = "Function 'stat' has an invalid parameter declaration: 'invalid syntax: int = 42'"
+        self.expect_failure(block, err, lineno=2)
 
     def test_param_default_invalid_syntax(self):
         block = """
@@ -1290,7 +1286,7 @@ class ClinicParserTest(TestCase):
             os.stat
                 x: int = invalid syntax
         """
-        err = r"Syntax error: 'x = invalid syntax\n'"
+        err = "Function 'stat' has an invalid parameter declaration:"
         self.expect_failure(block, err, lineno=2)
 
     def test_cloning_nonexistent_function_correctly_fails(self):
@@ -2510,7 +2506,7 @@ class ClinicParserTest(TestCase):
         self.expect_failure(block, err, lineno=1)
 
     def test_vararg_cannot_take_default_value(self):
-        err = "Vararg can't take a default value!"
+        err = "Function 'fn' has an invalid parameter declaration:"
         block = """
             fn
                 *args: tuple = None
index 282ff64cd33089c568afef64e7ad42e35429c6f2..eca41531f7c8e9b00cf57573d64cfabfdc3320f7 100644 (file)
@@ -877,43 +877,16 @@ class DSLParser:
 
         # handle "as" for  parameters too
         c_name = None
-        name, have_as_token, trailing = line.partition(' as ')
-        if have_as_token:
-            name = name.strip()
-            if ' ' not in name:
-                fields = trailing.strip().split(' ')
-                if not fields:
-                    fail("Invalid 'as' clause!")
-                c_name = fields[0]
-                if c_name.endswith(':'):
-                    name += ':'
-                    c_name = c_name[:-1]
-                fields[0] = name
-                line = ' '.join(fields)
-
-        default: str | None
-        base, equals, default = line.rpartition('=')
-        if not equals:
-            base = default
-            default = None
-
-        module = None
+        m = re.match(r'(?:\* *)?\w+( +as +(\w+))', line)
+        if m:
+            c_name = m[2]
+            line = line[:m.start(1)] + line[m.end(1):]
+
         try:
-            ast_input = f"def x({base}): pass"
+            ast_input = f"def x({line}\n): pass"
             module = ast.parse(ast_input)
         except SyntaxError:
-            try:
-                # the last = was probably inside a function call, like
-                #   c: int(accept={str})
-                # so assume there was no actual default value.
-                default = None
-                ast_input = f"def x({line}): pass"
-                module = ast.parse(ast_input)
-            except SyntaxError:
-                pass
-        if not module:
-            fail(f"Function {self.function.name!r} has an invalid parameter declaration:\n\t",
-                 repr(line))
+            fail(f"Function {self.function.name!r} has an invalid parameter declaration: {line!r}")
 
         function = module.body[0]
         assert isinstance(function, ast.FunctionDef)
@@ -922,9 +895,6 @@ class DSLParser:
         if len(function_args.args) > 1:
             fail(f"Function {self.function.name!r} has an "
                  f"invalid parameter declaration (comma?): {line!r}")
-        if function_args.defaults or function_args.kw_defaults:
-            fail(f"Function {self.function.name!r} has an "
-                 f"invalid parameter declaration (default value?): {line!r}")
         if function_args.kwarg:
             fail(f"Function {self.function.name!r} has an "
                  f"invalid parameter declaration (**kwargs?): {line!r}")
@@ -944,7 +914,7 @@ class DSLParser:
             name = 'varpos_' + name
 
         value: object
-        if not default:
+        if not function_args.defaults:
             if is_vararg:
                 value = NULL
             else:
@@ -955,17 +925,13 @@ class DSLParser:
             if 'py_default' in kwargs:
                 fail("You can't specify py_default without specifying a default value!")
         else:
-            if is_vararg:
-                fail("Vararg can't take a default value!")
+            expr = function_args.defaults[0]
+            default = ast_input[expr.col_offset: expr.end_col_offset].strip()
 
             if self.parameter_state is ParamState.REQUIRED:
                 self.parameter_state = ParamState.OPTIONAL
-            default = default.strip()
             bad = False
-            ast_input = f"x = {default}"
             try:
-                module = ast.parse(ast_input)
-
                 if 'c_default' not in kwargs:
                     # we can only represent very simple data values in C.
                     # detect whether default is okay, via a denylist
@@ -992,13 +958,14 @@ class DSLParser:
                         visit_Starred = bad_node
 
                     denylist = DetectBadNodes()
-                    denylist.visit(module)
+                    denylist.visit(expr)
                     bad = denylist.bad
                 else:
                     # if they specify a c_default, we can be more lenient about the default value.
                     # but at least make an attempt at ensuring it's a valid expression.
+                    code = compile(ast.Expression(expr), '<expr>', 'eval')
                     try:
-                        value = eval(default)
+                        value = eval(code)
                     except NameError:
                         pass # probably a named constant
                     except Exception as e:
@@ -1010,9 +977,6 @@ class DSLParser:
                 if bad:
                     fail(f"Unsupported expression as default value: {default!r}")
 
-                assignment = module.body[0]
-                assert isinstance(assignment, ast.Assign)
-                expr = assignment.value
                 # mild hack: explicitly support NULL as a default value
                 c_default: str | None
                 if isinstance(expr, ast.Name) and expr.id == 'NULL':
@@ -1064,8 +1028,6 @@ class DSLParser:
                     else:
                         c_default = py_default
 
-            except SyntaxError as e:
-                fail(f"Syntax error: {e.text!r}")
             except (ValueError, AttributeError):
                 value = unknown
                 c_default = kwargs.get("c_default")