]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
GH-128682: Spill the stack pointer in labels, as well as instructions (GH-129618)
authorMark Shannon <mark@hotpy.org>
Tue, 4 Feb 2025 12:18:31 +0000 (12:18 +0000)
committerGitHub <noreply@github.com>
Tue, 4 Feb 2025 12:18:31 +0000 (12:18 +0000)
15 files changed:
Include/internal/pycore_optimizer.h
Lib/test/test_generated_cases.py
Python/bytecodes.c
Python/executor_cases.c.h
Python/generated_cases.c.h
Python/optimizer.c
Tools/cases_generator/analyzer.py
Tools/cases_generator/generators_common.py
Tools/cases_generator/lexer.py
Tools/cases_generator/optimizer_generator.py
Tools/cases_generator/parser.py
Tools/cases_generator/parsing.py
Tools/cases_generator/stack.py
Tools/cases_generator/tier1_generator.py
Tools/cases_generator/tier2_generator.py

index e806e306d2d57f20e477a4b1bcf513f6dae01413..00fc4338b0a412ade9e47884db30183cdc1701fc 100644 (file)
@@ -282,7 +282,7 @@ extern int _Py_uop_frame_pop(JitOptContext *ctx);
 
 PyAPI_FUNC(PyObject *) _Py_uop_symbols_test(PyObject *self, PyObject *ignored);
 
-PyAPI_FUNC(int) _PyOptimizer_Optimize(struct _PyInterpreterFrame *frame, _Py_CODEUNIT *start, _PyStackRef *stack_pointer, _PyExecutorObject **exec_ptr, int chain_depth);
+PyAPI_FUNC(int) _PyOptimizer_Optimize(struct _PyInterpreterFrame *frame, _Py_CODEUNIT *start, _PyExecutorObject **exec_ptr, int chain_depth);
 
 static inline int is_terminator(const _PyUOpInstruction *uop)
 {
index 35600ce54866422980037c0363fb60118592c51d..d2b33706ea6b75a79266e63d33c1548439a8aef7 100644 (file)
@@ -286,7 +286,7 @@ class TestGeneratedCases(unittest.TestCase):
             instructions, labels_with_prelude_and_postlude = rest.split(tier1_generator.INSTRUCTION_END_MARKER)
             _, labels_with_postlude = labels_with_prelude_and_postlude.split(tier1_generator.LABEL_START_MARKER)
             labels, _ = labels_with_postlude.split(tier1_generator.LABEL_END_MARKER)
-            actual = instructions + labels
+            actual = instructions.strip() + "\n\n        " + labels.strip()
         # if actual.strip() != expected.strip():
         #     print("Actual:")
         #     print(actual)
@@ -652,6 +652,9 @@ class TestGeneratedCases(unittest.TestCase):
 
     def test_suppress_dispatch(self):
         input = """
+        label(somewhere) {
+        }
+
         inst(OP, (--)) {
             goto somewhere;
         }
@@ -663,6 +666,11 @@ class TestGeneratedCases(unittest.TestCase):
             INSTRUCTION_STATS(OP);
             goto somewhere;
         }
+
+        somewhere:
+        {
+
+        }
     """
         self.run_cases_test(input, output)
 
@@ -1768,9 +1776,15 @@ class TestGeneratedCases(unittest.TestCase):
 
     def test_complex_label(self):
         input = """
+        label(other_label) {
+        }
+
+        label(other_label2) {
+        }
+
         label(my_label) {
             // Comment
-            do_thing()
+            do_thing();
             if (complex) {
                 goto other_label;
             }
@@ -1779,10 +1793,22 @@ class TestGeneratedCases(unittest.TestCase):
         """
 
         output = """
+        other_label:
+        {
+
+        }
+
+        other_label2:
+        {
+
+        }
+
         my_label:
         {
             // Comment
-            do_thing()
+            _PyFrame_SetStackPointer(frame, stack_pointer);
+            do_thing();
+            stack_pointer = _PyFrame_GetStackPointer(frame);
             if (complex) {
                 goto other_label;
             }
@@ -1791,6 +1817,60 @@ class TestGeneratedCases(unittest.TestCase):
         """
         self.run_cases_test(input, output)
 
+    def test_spilled_label(self):
+        input = """
+        spilled label(one) {
+            RELOAD_STACK();
+            goto two;
+        }
+
+        label(two) {
+            SAVE_STACK();
+            goto one;
+        }
+        """
+
+        output = """
+        one:
+        {
+            /* STACK SPILLED */
+            stack_pointer = _PyFrame_GetStackPointer(frame);
+            goto two;
+        }
+
+        two:
+        {
+            _PyFrame_SetStackPointer(frame, stack_pointer);
+            goto one;
+        }
+        """
+        self.run_cases_test(input, output)
+
+
+    def test_incorrect_spills(self):
+        input1 = """
+        spilled label(one) {
+            goto two;
+        }
+
+        label(two) {
+        }
+        """
+
+        input2 = """
+        spilled label(one) {
+        }
+
+        label(two) {
+            goto one;
+        }
+        """
+        with self.assertRaisesRegex(SyntaxError, ".*reload.*"):
+            self.run_cases_test(input1, "")
+        with self.assertRaisesRegex(SyntaxError, ".*spill.*"):
+            self.run_cases_test(input2, "")
+
+
     def test_multiple_labels(self):
         input = """
         label(my_label_1) {
@@ -1802,7 +1882,7 @@ class TestGeneratedCases(unittest.TestCase):
         label(my_label_2) {
             // Comment
             do_thing2();
-            goto my_label_3;
+            goto my_label_1;
         }
         """
 
@@ -1818,7 +1898,7 @@ class TestGeneratedCases(unittest.TestCase):
         {
             // Comment
             do_thing2();
-            goto my_label_3;
+            goto my_label_1;
         }
         """
 
index e679d90620ea9e27b7b5e4e0a9ca4cb54dc913bb..cb88ba74f9a5fe7b5d5b68d96b10eb31fd7fece2 100644 (file)
@@ -2808,7 +2808,7 @@ dummy_func(
                     start--;
                 }
                 _PyExecutorObject *executor;
-                int optimized = _PyOptimizer_Optimize(frame, start, stack_pointer, &executor, 0);
+                int optimized = _PyOptimizer_Optimize(frame, start, &executor, 0);
                 if (optimized <= 0) {
                     this_instr[1].counter = restart_backoff_counter(counter);
                     ERROR_IF(optimized < 0, error);
@@ -5033,7 +5033,7 @@ dummy_func(
                 }
                 else {
                     int chain_depth = current_executor->vm_data.chain_depth + 1;
-                    int optimized = _PyOptimizer_Optimize(frame, target, stack_pointer, &executor, chain_depth);
+                    int optimized = _PyOptimizer_Optimize(frame, target, &executor, chain_depth);
                     if (optimized <= 0) {
                         exit->temperature = restart_backoff_counter(temperature);
                         if (optimized < 0) {
@@ -5134,7 +5134,7 @@ dummy_func(
                     exit->temperature = advance_backoff_counter(exit->temperature);
                     GOTO_TIER_ONE(target);
                 }
-                int optimized = _PyOptimizer_Optimize(frame, target, stack_pointer, &executor, 0);
+                int optimized = _PyOptimizer_Optimize(frame, target, &executor, 0);
                 if (optimized <= 0) {
                     exit->temperature = restart_backoff_counter(exit->temperature);
                     if (optimized < 0) {
@@ -5242,29 +5242,29 @@ dummy_func(
             goto exception_unwind;
         }
 
-        label(exception_unwind) {
+        spilled label(exception_unwind) {
             /* We can't use frame->instr_ptr here, as RERAISE may have set it */
             int offset = INSTR_OFFSET()-1;
             int level, handler, lasti;
-            if (get_exception_handler(_PyFrame_GetCode(frame), offset, &level, &handler, &lasti) == 0) {
+            int handled = get_exception_handler(_PyFrame_GetCode(frame), offset, &level, &handler, &lasti);
+            if (handled == 0) {
                 // No handlers, so exit.
                 assert(_PyErr_Occurred(tstate));
-
                 /* Pop remaining stack entries. */
                 _PyStackRef *stackbase = _PyFrame_Stackbase(frame);
-                while (stack_pointer > stackbase) {
-                    PyStackRef_XCLOSE(POP());
+                while (frame->stackpointer > stackbase) {
+                    _PyStackRef ref = _PyFrame_StackPop(frame);
+                    PyStackRef_XCLOSE(ref);
                 }
-                assert(STACK_LEVEL() == 0);
-                _PyFrame_SetStackPointer(frame, stack_pointer);
                 monitor_unwind(tstate, frame, next_instr-1);
                 goto exit_unwind;
             }
-
             assert(STACK_LEVEL() >= level);
             _PyStackRef *new_top = _PyFrame_Stackbase(frame) + level;
-            while (stack_pointer > new_top) {
-                PyStackRef_XCLOSE(POP());
+            assert(frame->stackpointer >= new_top);
+            while (frame->stackpointer > new_top) {
+                _PyStackRef ref = _PyFrame_StackPop(frame);
+                PyStackRef_XCLOSE(ref);
             }
             if (lasti) {
                 int frame_lasti = _PyInterpreterFrame_LASTI(frame);
@@ -5272,7 +5272,7 @@ dummy_func(
                 if (lasti == NULL) {
                     goto exception_unwind;
                 }
-                PUSH(PyStackRef_FromPyObjectSteal(lasti));
+                _PyFrame_StackPush(frame, PyStackRef_FromPyObjectSteal(lasti));
             }
 
             /* Make the raw exception data
@@ -5280,10 +5280,11 @@ dummy_func(
                 so a program can emulate the
                 Python main loop. */
             PyObject *exc = _PyErr_GetRaisedException(tstate);
-            PUSH(PyStackRef_FromPyObjectSteal(exc));
+            _PyFrame_StackPush(frame, PyStackRef_FromPyObjectSteal(exc));
             next_instr = _PyFrame_GetBytecode(frame) + handler;
 
-            if (monitor_handled(tstate, frame, next_instr, exc) < 0) {
+            int err = monitor_handled(tstate, frame, next_instr, exc);
+            if (err < 0) {
                 goto exception_unwind;
             }
             /* Resume normal execution */
@@ -5292,10 +5293,11 @@ dummy_func(
                 lltrace_resume_frame(frame);
             }
 #endif
+            RELOAD_STACK();
             DISPATCH();
         }
 
-        label(exit_unwind) {
+        spilled label(exit_unwind) {
             assert(_PyErr_Occurred(tstate));
             _Py_LeaveRecursiveCallPy(tstate);
             assert(frame->owner != FRAME_OWNED_BY_INTERPRETER);
@@ -5311,16 +5313,16 @@ dummy_func(
                 return NULL;
             }
             next_instr = frame->instr_ptr;
-            stack_pointer = _PyFrame_GetStackPointer(frame);
+            RELOAD_STACK();
             goto error;
         }
 
-        label(start_frame) {
-            if (_Py_EnterRecursivePy(tstate)) {
+        spilled label(start_frame) {
+            int too_deep = _Py_EnterRecursivePy(tstate);
+            if (too_deep) {
                 goto exit_unwind;
             }
             next_instr = frame->instr_ptr;
-            stack_pointer = _PyFrame_GetStackPointer(frame);
 
         #ifdef LLTRACE
             {
@@ -5339,6 +5341,7 @@ dummy_func(
             assert(!_PyErr_Occurred(tstate));
         #endif
 
+            RELOAD_STACK();
             DISPATCH();
         }
 
index 59f1a1ba4dc92a57f7bfd25c8e43541602630f10..5b19ec182b5805ec848ac25e5f3a31ece768d7c6 100644 (file)
                 else {
                     int chain_depth = current_executor->vm_data.chain_depth + 1;
                     _PyFrame_SetStackPointer(frame, stack_pointer);
-                    int optimized = _PyOptimizer_Optimize(frame, target, stack_pointer, &executor, chain_depth);
+                    int optimized = _PyOptimizer_Optimize(frame, target, &executor, chain_depth);
                     stack_pointer = _PyFrame_GetStackPointer(frame);
                     if (optimized <= 0) {
                         exit->temperature = restart_backoff_counter(temperature);
                     GOTO_TIER_ONE(target);
                 }
                 _PyFrame_SetStackPointer(frame, stack_pointer);
-                int optimized = _PyOptimizer_Optimize(frame, target, stack_pointer, &executor, 0);
+                int optimized = _PyOptimizer_Optimize(frame, target, &executor, 0);
                 stack_pointer = _PyFrame_GetStackPointer(frame);
                 if (optimized <= 0) {
                     exit->temperature = restart_backoff_counter(exit->temperature);
index 7dd9d6528bb49dcef80d94354fb5c85b336d6483..0bc92f30bfded2c529d14cc5c7d315e0fad84125 100644 (file)
                 _PyErr_SetRaisedException(tstate, Py_NewRef(exc_value));
                 monitor_reraise(tstate, frame, this_instr);
                 stack_pointer = _PyFrame_GetStackPointer(frame);
+                _PyFrame_SetStackPointer(frame, stack_pointer);
                 goto exception_unwind;
             }
             stack_pointer[-3] = none;
                 _PyErr_SetRaisedException(tstate, exc);
                 monitor_reraise(tstate, frame, this_instr);
                 stack_pointer = _PyFrame_GetStackPointer(frame);
+                _PyFrame_SetStackPointer(frame, stack_pointer);
                 goto exception_unwind;
             }
             stack_pointer += -2;
                     }
                     _PyExecutorObject *executor;
                     _PyFrame_SetStackPointer(frame, stack_pointer);
-                    int optimized = _PyOptimizer_Optimize(frame, start, stack_pointer, &executor, 0);
+                    int optimized = _PyOptimizer_Optimize(frame, start, &executor, 0);
                     stack_pointer = _PyFrame_GetStackPointer(frame);
                     if (optimized <= 0) {
                         this_instr[1].counter = restart_backoff_counter(counter);
                 _PyFrame_SetStackPointer(frame, stack_pointer);
                 monitor_reraise(tstate, frame, this_instr);
                 stack_pointer = _PyFrame_GetStackPointer(frame);
+                _PyFrame_SetStackPointer(frame, stack_pointer);
                 goto exception_unwind;
             }
             goto error;
             _PyErr_SetRaisedException(tstate, exc);
             monitor_reraise(tstate, frame, this_instr);
             stack_pointer = _PyFrame_GetStackPointer(frame);
+            _PyFrame_SetStackPointer(frame, stack_pointer);
             goto exception_unwind;
         }
 
             /* Double-check exception status. */
             #ifdef NDEBUG
             if (!_PyErr_Occurred(tstate)) {
+                _PyFrame_SetStackPointer(frame, stack_pointer);
                 _PyErr_SetString(tstate, PyExc_SystemError,
                              "error return without exception set");
+                stack_pointer = _PyFrame_GetStackPointer(frame);
             }
             #else
             assert(_PyErr_Occurred(tstate));
             /* Log traceback info. */
             assert(frame->owner != FRAME_OWNED_BY_INTERPRETER);
             if (!_PyFrame_IsIncomplete(frame)) {
+                _PyFrame_SetStackPointer(frame, stack_pointer);
                 PyFrameObject *f = _PyFrame_GetFrameObject(frame);
+                stack_pointer = _PyFrame_GetStackPointer(frame);
                 if (f != NULL) {
+                    _PyFrame_SetStackPointer(frame, stack_pointer);
                     PyTraceBack_Here(f);
+                    stack_pointer = _PyFrame_GetStackPointer(frame);
                 }
             }
+            _PyFrame_SetStackPointer(frame, stack_pointer);
             _PyEval_MonitorRaise(tstate, frame, next_instr-1);
+            stack_pointer = _PyFrame_GetStackPointer(frame);
+            _PyFrame_SetStackPointer(frame, stack_pointer);
             goto exception_unwind;
         }
 
         exception_unwind:
         {
+            /* STACK SPILLED */
             /* We can't use frame->instr_ptr here, as RERAISE may have set it */
             int offset = INSTR_OFFSET()-1;
             int level, handler, lasti;
-            if (get_exception_handler(_PyFrame_GetCode(frame), offset, &level, &handler, &lasti) == 0) {
+            int handled = get_exception_handler(_PyFrame_GetCode(frame), offset, &level, &handler, &lasti);
+            if (handled == 0) {
                 // No handlers, so exit.
                 assert(_PyErr_Occurred(tstate));
                 /* Pop remaining stack entries. */
                 _PyStackRef *stackbase = _PyFrame_Stackbase(frame);
-                while (stack_pointer > stackbase) {
-                    PyStackRef_XCLOSE(POP());
+                while (frame->stackpointer > stackbase) {
+                    _PyStackRef ref = _PyFrame_StackPop(frame);
+                    PyStackRef_XCLOSE(ref);
                 }
-                assert(STACK_LEVEL() == 0);
-                _PyFrame_SetStackPointer(frame, stack_pointer);
                 monitor_unwind(tstate, frame, next_instr-1);
                 goto exit_unwind;
             }
             assert(STACK_LEVEL() >= level);
             _PyStackRef *new_top = _PyFrame_Stackbase(frame) + level;
-            while (stack_pointer > new_top) {
-                PyStackRef_XCLOSE(POP());
+            assert(frame->stackpointer >= new_top);
+            while (frame->stackpointer > new_top) {
+                _PyStackRef ref = _PyFrame_StackPop(frame);
+                PyStackRef_XCLOSE(ref);
             }
             if (lasti) {
                 int frame_lasti = _PyInterpreterFrame_LASTI(frame);
                 if (lasti == NULL) {
                     goto exception_unwind;
                 }
-                PUSH(PyStackRef_FromPyObjectSteal(lasti));
+                _PyFrame_StackPush(frame, PyStackRef_FromPyObjectSteal(lasti));
             }
             /* Make the raw exception data
                available to the handler,
                so a program can emulate the
                Python main loop. */
             PyObject *exc = _PyErr_GetRaisedException(tstate);
-            PUSH(PyStackRef_FromPyObjectSteal(exc));
+            _PyFrame_StackPush(frame, PyStackRef_FromPyObjectSteal(exc));
             next_instr = _PyFrame_GetBytecode(frame) + handler;
-            if (monitor_handled(tstate, frame, next_instr, exc) < 0) {
+            int err = monitor_handled(tstate, frame, next_instr, exc);
+            if (err < 0) {
                 goto exception_unwind;
             }
             /* Resume normal execution */
                 lltrace_resume_frame(frame);
             }
             #endif
+            stack_pointer = _PyFrame_GetStackPointer(frame);
             DISPATCH();
         }
 
         exit_unwind:
         {
+            /* STACK SPILLED */
             assert(_PyErr_Occurred(tstate));
             _Py_LeaveRecursiveCallPy(tstate);
             assert(frame->owner != FRAME_OWNED_BY_INTERPRETER);
 
         start_frame:
         {
-            if (_Py_EnterRecursivePy(tstate)) {
+            /* STACK SPILLED */
+            int too_deep = _Py_EnterRecursivePy(tstate);
+            if (too_deep) {
                 goto exit_unwind;
             }
             next_instr = frame->instr_ptr;
-            stack_pointer = _PyFrame_GetStackPointer(frame);
             #ifdef LLTRACE
             {
                 int lltrace = maybe_lltrace_resume_frame(frame, GLOBALS());
                caller loses its exception */
             assert(!_PyErr_Occurred(tstate));
             #endif
-
+            stack_pointer = _PyFrame_GetStackPointer(frame);
             DISPATCH();
         }
 
index d71abd3224240b05c01e55721ee337547c12f187..97831f58098c95c80ee56ef9047878d8ce9cc47f 100644 (file)
@@ -105,8 +105,9 @@ uop_optimize(_PyInterpreterFrame *frame, _Py_CODEUNIT *instr,
 int
 _PyOptimizer_Optimize(
     _PyInterpreterFrame *frame, _Py_CODEUNIT *start,
-    _PyStackRef *stack_pointer, _PyExecutorObject **executor_ptr, int chain_depth)
+    _PyExecutorObject **executor_ptr, int chain_depth)
 {
+    _PyStackRef *stack_pointer = frame->stackpointer;
     assert(_PyInterpreterState_GET()->jit);
     // The first executor in a chain and the MAX_CHAIN_DEPTH'th executor *must*
     // make progress in order to avoid infinite loops or excessively-long
index eda8d687a70ccd6d3f5c59fda0c50aa85bfa59ad..724fba5f953a4e1427bcfbbef934dc33b183b9e8 100644 (file)
@@ -130,6 +130,8 @@ class Flush:
         return 0
 
 
+
+
 @dataclass
 class StackItem:
     name: str
@@ -228,7 +230,24 @@ class Uop:
         return False
 
 
+class Label:
+
+    def __init__(self, name: str, spilled: bool, body: list[lexer.Token], properties: Properties):
+        self.name = name
+        self.spilled = spilled
+        self.body = body
+        self.properties = properties
+
+    size:int = 0
+    output_stores: list[lexer.Token] = []
+    instruction_size = None
+
+    def __str__(self) -> str:
+        return f"label({self.name})"
+
+
 Part = Uop | Skip | Flush
+CodeSection = Uop | Label
 
 
 @dataclass
@@ -268,12 +287,6 @@ class Instruction:
             return False
 
 
-@dataclass
-class Label:
-    name: str
-    body: list[lexer.Token]
-
-
 @dataclass
 class PseudoInstruction:
     name: str
@@ -503,22 +516,24 @@ def analyze_deferred_refs(node: parser.InstDef) -> dict[lexer.Token, str | None]
     return refs
 
 
-def variable_used(node: parser.InstDef, name: str) -> bool:
+def variable_used(node: parser.CodeDef, name: str) -> bool:
     """Determine whether a variable with a given name is used in a node."""
     return any(
         token.kind == "IDENTIFIER" and token.text == name for token in node.block.tokens
     )
 
 
-def oparg_used(node: parser.InstDef) -> bool:
+def oparg_used(node: parser.CodeDef) -> bool:
     """Determine whether `oparg` is used in a node."""
     return any(
         token.kind == "IDENTIFIER" and token.text == "oparg" for token in node.tokens
     )
 
 
-def tier_variable(node: parser.InstDef) -> int | None:
+def tier_variable(node: parser.CodeDef) -> int | None:
     """Determine whether a tier variable is used in a node."""
+    if isinstance(node, parser.LabelDef):
+        return None
     for token in node.tokens:
         if token.kind == "ANNOTATION":
             if token.text == "specializing":
@@ -528,7 +543,7 @@ def tier_variable(node: parser.InstDef) -> int | None:
     return None
 
 
-def has_error_with_pop(op: parser.InstDef) -> bool:
+def has_error_with_pop(op: parser.CodeDef) -> bool:
     return (
         variable_used(op, "ERROR_IF")
         or variable_used(op, "pop_1_error")
@@ -536,7 +551,7 @@ def has_error_with_pop(op: parser.InstDef) -> bool:
     )
 
 
-def has_error_without_pop(op: parser.InstDef) -> bool:
+def has_error_without_pop(op: parser.CodeDef) -> bool:
     return (
         variable_used(op, "ERROR_NO_POP")
         or variable_used(op, "pop_1_error")
@@ -665,7 +680,7 @@ NON_ESCAPING_FUNCTIONS = (
     "restart_backoff_counter",
 )
 
-def find_stmt_start(node: parser.InstDef, idx: int) -> lexer.Token:
+def find_stmt_start(node: parser.CodeDef, idx: int) -> lexer.Token:
     assert idx < len(node.block.tokens)
     while True:
         tkn = node.block.tokens[idx-1]
@@ -678,7 +693,7 @@ def find_stmt_start(node: parser.InstDef, idx: int) -> lexer.Token:
     return node.block.tokens[idx]
 
 
-def find_stmt_end(node: parser.InstDef, idx: int) -> lexer.Token:
+def find_stmt_end(node: parser.CodeDef, idx: int) -> lexer.Token:
     assert idx < len(node.block.tokens)
     while True:
         idx += 1
@@ -686,7 +701,7 @@ def find_stmt_end(node: parser.InstDef, idx: int) -> lexer.Token:
         if tkn.kind == "SEMI":
             return node.block.tokens[idx+1]
 
-def check_escaping_calls(instr: parser.InstDef, escapes: dict[lexer.Token, EscapingCall]) -> None:
+def check_escaping_calls(instr: parser.CodeDef, escapes: dict[lexer.Token, EscapingCall]) -> None:
     calls = {e.call for e in escapes.values()}
     in_if = 0
     tkn_iter = iter(instr.block.tokens)
@@ -705,7 +720,7 @@ def check_escaping_calls(instr: parser.InstDef, escapes: dict[lexer.Token, Escap
         elif tkn in calls and in_if:
             raise analysis_error(f"Escaping call '{tkn.text} in condition", tkn)
 
-def find_escaping_api_calls(instr: parser.InstDef) -> dict[lexer.Token, EscapingCall]:
+def find_escaping_api_calls(instr: parser.CodeDef) -> dict[lexer.Token, EscapingCall]:
     result: dict[lexer.Token, EscapingCall] = {}
     tokens = instr.block.tokens
     for idx, tkn in enumerate(tokens):
@@ -764,7 +779,7 @@ EXITS = {
 }
 
 
-def always_exits(op: parser.InstDef) -> bool:
+def always_exits(op: parser.CodeDef) -> bool:
     depth = 0
     tkn_iter = iter(op.tokens)
     for tkn in tkn_iter:
@@ -823,7 +838,7 @@ def effect_depends_on_oparg_1(op: parser.InstDef) -> bool:
     return False
 
 
-def compute_properties(op: parser.InstDef) -> Properties:
+def compute_properties(op: parser.CodeDef) -> Properties:
     escaping_calls = find_escaping_api_calls(op)
     has_free = (
         variable_used(op, "PyCell_New")
@@ -851,6 +866,8 @@ def compute_properties(op: parser.InstDef) -> Properties:
         variable_used(op, "Py_CLEAR") or
         variable_used(op, "SETLOCAL")
     )
+    pure = False if isinstance(op, parser.LabelDef) else "pure" in op.annotations
+    no_save_ip = False if isinstance(op, parser.LabelDef) else "no_save_ip" in op.annotations
     return Properties(
         escaping_calls=escaping_calls,
         escapes=escapes,
@@ -870,8 +887,8 @@ def compute_properties(op: parser.InstDef) -> Properties:
             and not has_free,
         uses_opcode=variable_used(op, "opcode"),
         has_free=has_free,
-        pure="pure" in op.annotations,
-        no_save_ip="no_save_ip" in op.annotations,
+        pure=pure,
+        no_save_ip=no_save_ip,
         tier=tier_variable(op),
         needs_prev=variable_used(op, "prev_instr"),
     )
@@ -1050,7 +1067,8 @@ def add_label(
     label: parser.LabelDef,
     labels: dict[str, Label],
 ) -> None:
-    labels[label.name] = Label(label.name, label.block.tokens)
+    properties = compute_properties(label)
+    labels[label.name] = Label(label.name, label.spilled, label.block.tokens, properties)
 
 
 def assign_opcodes(
index 6f2af5fc01c47b4bd1da4e657f5346cfadaf1715..1c572ec0512b37049ace3041e888645d5c42748d 100644 (file)
@@ -7,6 +7,8 @@ from analyzer import (
     Properties,
     StackItem,
     analysis_error,
+    Label,
+    CodeSection,
 )
 from cwriter import CWriter
 from typing import Callable, TextIO, Iterator, Iterable
@@ -90,7 +92,7 @@ def emit_to(out: CWriter, tkn_iter: TokenIterator, end: str) -> Token:
 
 
 ReplacementFunctionType = Callable[
-    [Token, TokenIterator, Uop, Storage, Instruction | None], bool
+    [Token, TokenIterator, CodeSection, Storage, Instruction | None], bool
 ]
 
 def always_true(tkn: Token | None) -> bool:
@@ -106,9 +108,10 @@ NON_ESCAPING_DEALLOCS = {
 
 class Emitter:
     out: CWriter
+    labels: dict[str, Label]
     _replacers: dict[str, ReplacementFunctionType]
 
-    def __init__(self, out: CWriter):
+    def __init__(self, out: CWriter, labels: dict[str, Label]):
         self._replacers = {
             "EXIT_IF": self.exit_if,
             "DEOPT_IF": self.deopt_if,
@@ -124,18 +127,22 @@ class Emitter:
             "PyStackRef_AsPyObjectSteal": self.stackref_steal,
             "DISPATCH": self.dispatch,
             "INSTRUCTION_SIZE": self.instruction_size,
-            "POP_INPUT": self.pop_input
+            "POP_INPUT": self.pop_input,
+            "stack_pointer": self.stack_pointer,
         }
         self.out = out
+        self.labels = labels
 
     def dispatch(
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
+        if storage.spilled:
+            raise analysis_error("stack_pointer needs reloading before dispatch", tkn)
         self.emit(tkn)
         return False
 
@@ -143,7 +150,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -174,7 +181,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -213,7 +220,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -227,7 +234,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -263,7 +270,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -278,7 +285,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -318,7 +325,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -348,7 +355,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -368,7 +375,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -380,6 +387,32 @@ class Emitter:
         self._print_storage(storage)
         return True
 
+    def stack_pointer(
+        self,
+        tkn: Token,
+        tkn_iter: TokenIterator,
+        uop: CodeSection,
+        storage: Storage,
+        inst: Instruction | None,
+    ) -> bool:
+        if storage.spilled:
+            raise analysis_error("stack_pointer is invalid when stack is spilled to memory", tkn)
+        self.emit(tkn)
+        return True
+
+    def goto_label(self, goto: Token, label: Token, storage: Storage) -> None:
+        if label.text not in self.labels:
+            print(self.labels.keys())
+            raise analysis_error(f"Label '{label.text}' does not exist", label)
+        label_node = self.labels[label.text]
+        if label_node.spilled:
+            if not storage.spilled:
+                self.emit_save(storage)
+        elif storage.spilled:
+            raise analysis_error("Cannot jump from spilled label without reloading the stack pointer", goto)
+        self.out.emit(goto)
+        self.out.emit(label)
+
     def emit_save(self, storage: Storage) -> None:
         storage.save(self.out)
         self._print_storage(storage)
@@ -388,7 +421,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -402,7 +435,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -429,7 +462,7 @@ class Emitter:
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -442,7 +475,7 @@ class Emitter:
     def instruction_size(self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -461,7 +494,7 @@ class Emitter:
     def _emit_if(
         self,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> tuple[bool, Token, Storage]:
@@ -521,7 +554,7 @@ class Emitter:
     def _emit_block(
         self,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
         emit_first_brace: bool
@@ -568,8 +601,9 @@ class Emitter:
                         return reachable, tkn, storage
                     self.out.emit(tkn)
                 elif tkn.kind == "GOTO":
+                    label_tkn = next(tkn_iter)
+                    self.goto_label(tkn, label_tkn, storage)
                     reachable = False;
-                    self.out.emit(tkn)
                 elif tkn.kind == "IDENTIFIER":
                     if tkn.text in self._replacers:
                         if not self._replacers[tkn.text](tkn, tkn_iter, uop, storage, inst):
@@ -599,17 +633,18 @@ class Emitter:
 
     def emit_tokens(
         self,
-        uop: Uop,
+        code: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> Storage:
-        tkn_iter = TokenIterator(uop.body)
+        tkn_iter = TokenIterator(code.body)
         self.out.start_line()
-        _, rbrace, storage = self._emit_block(tkn_iter, uop, storage, inst, False)
+        reachable, rbrace, storage = self._emit_block(tkn_iter, code, storage, inst, False)
         try:
-            self._print_storage(storage)
-            storage.push_outputs()
-            self._print_storage(storage)
+            if reachable:
+                self._print_storage(storage)
+                storage.push_outputs()
+                self._print_storage(storage)
         except StackError as ex:
             raise analysis_error(ex.args[0], rbrace) from None
         return storage
index cf3c39762f29cb7298add5d986759b8cbb67a555..6afca750be9b1950794da1b332b439c3e87cb936 100644 (file)
@@ -216,6 +216,8 @@ kwds.append(MACRO)
 # A label in the DSL
 LABEL = "LABEL"
 kwds.append(LABEL)
+SPILLED = "SPILLED"
+kwds.append(SPILLED)
 keywords = {name.lower(): name for name in kwds}
 
 ANNOTATION = "ANNOTATION"
index 5cfec4bfecbf074bb4899a5992e13752762df0b5..6c33debd58e1feacb3e2bd1923136da9d15e2b77 100644 (file)
@@ -112,6 +112,9 @@ class OptimizerEmitter(Emitter):
     def emit_reload(self, storage: Storage) -> None:
         pass
 
+    def goto_label(self, goto: Token, label: Token, storage: Storage) -> None:
+        self.out.emit(goto)
+        self.out.emit(label)
 
 def write_uop(
     override: Uop | None,
@@ -145,7 +148,7 @@ def write_uop(
                         cast = f"uint{cache.size*16}_t"
                     out.emit(f"{type}{cache.name} = ({cast})this_instr->operand0;\n")
         if override:
-            emitter = OptimizerEmitter(out)
+            emitter = OptimizerEmitter(out, {})
             # No reference management of inputs needed.
             for var in storage.inputs:  # type: ignore[possibly-undefined]
                 var.defined = False
index 68bbb88719e682dc8ed6ebfbed81ce251e56f97a..696c5c16432990e10bd81c64c4c4bb67d72ccca4 100644 (file)
@@ -13,6 +13,7 @@ from parsing import (  # noqa: F401
     AstNode,
 )
 
+CodeDef = InstDef | LabelDef
 
 def prettify_filename(filename: str) -> str:
     # Make filename more user-friendly and less platform-specific,
index eb8c8a7ecd32e85c8e323a7d8e23730bd0f20331..011f34de288871021385b68c1bdf579efa880a6c 100644 (file)
@@ -153,6 +153,7 @@ class Pseudo(Node):
 @dataclass
 class LabelDef(Node):
     name: str
+    spilled: bool
     block: Block
 
 
@@ -176,12 +177,15 @@ class Parser(PLexer):
 
     @contextual
     def label_def(self) -> LabelDef | None:
+        spilled = False
+        if self.expect(lx.SPILLED):
+            spilled = True
         if self.expect(lx.LABEL):
             if self.expect(lx.LPAREN):
                 if tkn := self.expect(lx.IDENTIFIER):
                     if self.expect(lx.RPAREN):
                         if block := self.block():
-                            return LabelDef(tkn.text, block)
+                            return LabelDef(tkn.text, spilled, block)
         return None
 
     @contextual
index 5121837ed8334baddb5199162b194ff0ff3d912f..729973f1e32758ac9d4cce94439d275142f20fb4 100644 (file)
@@ -570,7 +570,7 @@ class Storage:
         assert [v.name for v in inputs] == [v.name for v in self.inputs], (inputs, self.inputs)
         return Storage(
             new_stack, inputs,
-            self.copy_list(self.outputs), self.copy_list(self.peeks)
+            self.copy_list(self.outputs), self.copy_list(self.peeks), self.spilled
         )
 
     def sanity_check(self) -> None:
index eed3086c32792665a575022609961763c4190d05..c7cf09e2ec4edef8ceac46fb43eb665ab12ad204 100644 (file)
@@ -184,19 +184,25 @@ def generate_tier1_labels(
     analysis: Analysis, outfile: TextIO, lines: bool
 ) -> None:
     out = CWriter(outfile, 2, lines)
+    emitter = Emitter(out, analysis.labels)
     out.emit("\n")
     for name, label in analysis.labels.items():
         out.emit(f"{name}:\n")
-        for tkn in label.body:
-            out.emit(tkn)
+        out.emit("{\n")
+        storage = Storage(Stack(), [], [], [])
+        if label.spilled:
+            storage.spilled = 1
+            out.emit("/* STACK SPILLED */\n")
+        emitter.emit_tokens(label, storage, None)
         out.emit("\n")
+        out.emit("}\n")
         out.emit("\n")
 
 def generate_tier1_cases(
     analysis: Analysis, outfile: TextIO, lines: bool
 ) -> None:
     out = CWriter(outfile, 2, lines)
-    emitter = Emitter(out)
+    emitter = Emitter(out, analysis.labels)
     out.emit("\n")
     for name, inst in sorted(analysis.instructions.items()):
         needs_this = uses_this(inst)
index 4540eb252634ba2962918310c07291114ef520ed..5e23360cdc0aaf11c149f90ec034d9004fa9ccae 100644 (file)
@@ -9,6 +9,8 @@ from analyzer import (
     Analysis,
     Instruction,
     Uop,
+    Label,
+    CodeSection,
     analyze_files,
     StackItem,
     analysis_error,
@@ -65,8 +67,8 @@ def declare_variables(uop: Uop, out: CWriter) -> None:
 
 class Tier2Emitter(Emitter):
 
-    def __init__(self, out: CWriter):
-        super().__init__(out)
+    def __init__(self, out: CWriter, labels: dict[str, Label]):
+        super().__init__(out, labels)
         self._replacers["oparg"] = self.oparg
 
     def goto_error(self, offset: int, label: str, storage: Storage) -> str:
@@ -79,7 +81,7 @@ class Tier2Emitter(Emitter):
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -100,7 +102,7 @@ class Tier2Emitter(Emitter):
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -120,7 +122,7 @@ class Tier2Emitter(Emitter):
         self,
         tkn: Token,
         tkn_iter: TokenIterator,
-        uop: Uop,
+        uop: CodeSection,
         storage: Storage,
         inst: Instruction | None,
     ) -> bool:
@@ -180,7 +182,7 @@ def generate_tier2(
 """
     )
     out = CWriter(outfile, 2, lines)
-    emitter = Tier2Emitter(out)
+    emitter = Tier2Emitter(out, analysis.labels)
     out.emit("\n")
     for name, uop in analysis.uops.items():
         if uop.properties.tier == 1: