git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-124503: Optimize ast.literal_eval() for small input (GH-137010)
author Krzysztof Magusiak <chris.magusiak@gmail.com>
Thu, 31 Jul 2025 09:55:00 +0000 (11:55 +0200)
committer GitHub <noreply@github.com>
Thu, 31 Jul 2025 09:55:00 +0000 (12:55 +0300)
The implementation no longer creates local functions, which reduces
the overhead for small inputs. Some other calls are inlined into a
single `_convert_literal` function.
This gives a gain of 10-20% for small inputs and only 1-2% for bigger
inputs.

Lib/ast.py
Misc/NEWS.d/next/Library/2025-07-30-11-12-22.gh-issue-124503.d4hc7b.rst [new file with mode: 0644]

index 6d3daf64f5c6d758f801334d39a85a903ad08a10..983ac1710d0205b3bb96f288f30d1d1d7396ada3 100644 (file)
@@ -57,53 +57,60 @@ def literal_eval(node_or_string):
     Caution: A complex expression can overflow the C stack and cause a crash.
     """
     if isinstance(node_or_string, str):
-        node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')
-    if isinstance(node_or_string, Expression):
+        node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval').body
+    elif isinstance(node_or_string, Expression):
         node_or_string = node_or_string.body
-    def _raise_malformed_node(node):
-        msg = "malformed node or string"
-        if lno := getattr(node, 'lineno', None):
-            msg += f' on line {lno}'
-        raise ValueError(msg + f': {node!r}')
-    def _convert_num(node):
-        if not isinstance(node, Constant) or type(node.value) not in (int, float, complex):
-            _raise_malformed_node(node)
+    return _convert_literal(node_or_string)
+
+
+def _convert_literal(node):
+    """
+    Used by `literal_eval` to convert an AST node into a value.
+    """
+    if isinstance(node, Constant):
         return node.value
-    def _convert_signed_num(node):
-        if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
-            operand = _convert_num(node.operand)
-            if isinstance(node.op, UAdd):
-                return + operand
-            else:
-                return - operand
-        return _convert_num(node)
-    def _convert(node):
-        if isinstance(node, Constant):
-            return node.value
-        elif isinstance(node, Tuple):
-            return tuple(map(_convert, node.elts))
-        elif isinstance(node, List):
-            return list(map(_convert, node.elts))
-        elif isinstance(node, Set):
-            return set(map(_convert, node.elts))
-        elif (isinstance(node, Call) and isinstance(node.func, Name) and
-              node.func.id == 'set' and node.args == node.keywords == []):
-            return set()
-        elif isinstance(node, Dict):
-            if len(node.keys) != len(node.values):
-                _raise_malformed_node(node)
-            return dict(zip(map(_convert, node.keys),
-                            map(_convert, node.values)))
-        elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
-            left = _convert_signed_num(node.left)
-            right = _convert_num(node.right)
-            if isinstance(left, (int, float)) and isinstance(right, complex):
-                if isinstance(node.op, Add):
-                    return left + right
-                else:
-                    return left - right
-        return _convert_signed_num(node)
-    return _convert(node_or_string)
+    if isinstance(node, Dict) and len(node.keys) == len(node.values):
+        return dict(zip(
+            map(_convert_literal, node.keys),
+            map(_convert_literal, node.values),
+        ))
+    if isinstance(node, Tuple):
+        return tuple(map(_convert_literal, node.elts))
+    if isinstance(node, List):
+        return list(map(_convert_literal, node.elts))
+    if isinstance(node, Set):
+        return set(map(_convert_literal, node.elts))
+    if (
+        isinstance(node, Call) and isinstance(node.func, Name)
+        and node.func.id == 'set' and node.args == node.keywords == []
+    ):
+        return set()
+    if (
+        isinstance(node, UnaryOp)
+        and isinstance(node.op, (UAdd, USub))
+        and isinstance(node.operand, Constant)
+        and type(operand := node.operand.value) in (int, float, complex)
+    ):
+        if isinstance(node.op, UAdd):
+            return + operand
+        else:
+            return - operand
+    if (
+        isinstance(node, BinOp)
+        and isinstance(node.op, (Add, Sub))
+        and isinstance(node.left, (Constant, UnaryOp))
+        and isinstance(node.right, Constant)
+        and type(left := _convert_literal(node.left)) in (int, float)
+        and type(right := _convert_literal(node.right)) is complex
+    ):
+        if isinstance(node.op, Add):
+            return left + right
+        else:
+            return left - right
+    msg = "malformed node or string"
+    if lno := getattr(node, 'lineno', None):
+        msg += f' on line {lno}'
+    raise ValueError(msg + f': {node!r}')
 
 
 def dump(
diff --git a/Misc/NEWS.d/next/Library/2025-07-30-11-12-22.gh-issue-124503.d4hc7b.rst b/Misc/NEWS.d/next/Library/2025-07-30-11-12-22.gh-issue-124503.d4hc7b.rst
new file mode 100644 (file)
index 0000000..c04eba9
--- /dev/null
@@ -0,0 +1 @@
+:func:`ast.literal_eval` is 10-20% faster for small inputs.