git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-118926: Spill deferred references to stack in cases generator (#122748)
author Sam Gross <colesbury@gmail.com>
Wed, 7 Aug 2024 17:23:53 +0000 (13:23 -0400)
committer GitHub <noreply@github.com>
Wed, 7 Aug 2024 17:23:53 +0000 (13:23 -0400)
This automatically spills the results from `_PyStackRef_FromPyObjectNew`
to the in-memory stack so that the deferred references are visible to
the GC before we make any possibly escaping call.

Co-authored-by: Ken Jin <kenjin@python.org>
Python/bytecodes.c
Python/executor_cases.c.h
Python/generated_cases.c.h
Tools/cases_generator/analyzer.py
Tools/cases_generator/generators_common.py
Tools/cases_generator/lexer.py
Tools/cases_generator/stack.py
Tools/cases_generator/tier1_generator.py
Tools/cases_generator/tier2_generator.py

index d28cbd767e787813731b745e3ad3bb5dfca2a323..b68f9327d898c2c3cbd020cdad1761f4f8fafb0b 100644 (file)
@@ -1424,7 +1424,7 @@ dummy_func(
                                  "no locals found");
                 ERROR_IF(true, error);
             }
-            locals = PyStackRef_FromPyObjectNew(l);;
+            locals = PyStackRef_FromPyObjectNew(l);
         }
 
         inst(LOAD_FROM_DICT_OR_GLOBALS, (mod_or_class_dict -- v)) {
index 4def11c515fd3cc1ab1a9013006d270581e2f8ea..f2741286d6019749d0279d2ffbf5786871600007 100644 (file)
             }
             STAT_INC(UNPACK_SEQUENCE, hit);
             val0 = PyStackRef_FromPyObjectNew(PyTuple_GET_ITEM(seq_o, 0));
+            stack_pointer[0] = val0;
             val1 = PyStackRef_FromPyObjectNew(PyTuple_GET_ITEM(seq_o, 1));
-            PyStackRef_CLOSE(seq);
             stack_pointer[-1] = val1;
-            stack_pointer[0] = val0;
+            PyStackRef_CLOSE(seq);
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
             break;
                                  "no locals found");
                 if (true) JUMP_TO_ERROR();
             }
-            locals = PyStackRef_FromPyObjectNew(l);;
+            locals = PyStackRef_FromPyObjectNew(l);
             stack_pointer[0] = locals;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
             STAT_INC(LOAD_ATTR, hit);
             null = PyStackRef_NULL;
             attr = PyStackRef_FromPyObjectNew(attr_o);
-            PyStackRef_CLOSE(owner);
             stack_pointer[-1] = attr;
+            PyStackRef_CLOSE(owner);
             break;
         }
 
             STAT_INC(LOAD_ATTR, hit);
             null = PyStackRef_NULL;
             attr = PyStackRef_FromPyObjectNew(attr_o);
-            PyStackRef_CLOSE(owner);
             stack_pointer[-1] = attr;
+            PyStackRef_CLOSE(owner);
             stack_pointer[0] = null;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
             STAT_INC(LOAD_ATTR, hit);
             assert(descr != NULL);
             attr = PyStackRef_FromPyObjectNew(descr);
+            stack_pointer[-1] = attr;
             null = PyStackRef_NULL;
             PyStackRef_CLOSE(owner);
-            stack_pointer[-1] = attr;
             break;
         }
 
             STAT_INC(LOAD_ATTR, hit);
             assert(descr != NULL);
             attr = PyStackRef_FromPyObjectNew(descr);
+            stack_pointer[-1] = attr;
             null = PyStackRef_NULL;
             PyStackRef_CLOSE(owner);
-            stack_pointer[-1] = attr;
             stack_pointer[0] = null;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
             assert(descr != NULL);
             assert(_PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_METHOD_DESCRIPTOR));
             attr = PyStackRef_FromPyObjectNew(descr);
-            self = owner;
             stack_pointer[-1] = attr;
+            self = owner;
             stack_pointer[0] = self;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
             assert(descr != NULL);
             assert(_PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_METHOD_DESCRIPTOR));
             attr = PyStackRef_FromPyObjectNew(descr);
-            self = owner;
             stack_pointer[-1] = attr;
+            self = owner;
             stack_pointer[0] = self;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
             assert(descr != NULL);
             assert(_PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_METHOD_DESCRIPTOR));
             attr = PyStackRef_FromPyObjectNew(descr);
-            self = owner;
             stack_pointer[-1] = attr;
+            self = owner;
             stack_pointer[0] = self;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
                 PyObject *callable_o = PyStackRef_AsPyObjectBorrow(callable);
                 PyObject *self = ((PyMethodObject *)callable_o)->im_self;
                 maybe_self = PyStackRef_FromPyObjectNew(self);
+                stack_pointer[-1 - oparg] = maybe_self;
                 PyObject *method = ((PyMethodObject *)callable_o)->im_func;
                 func = PyStackRef_FromPyObjectNew(method);
+                stack_pointer[-2 - oparg] = func;
                 /* Make sure that callable and all args are in memory */
                 args[-2] = func;
                 args[-1] = maybe_self;
                 func = callable;
                 maybe_self = self_or_null;
             }
-            stack_pointer[-2 - oparg] = func;
-            stack_pointer[-1 - oparg] = maybe_self;
             break;
         }
 
             assert(PyStackRef_IsNull(null));
             assert(Py_TYPE(callable_o) == &PyMethod_Type);
             self = PyStackRef_FromPyObjectNew(((PyMethodObject *)callable_o)->im_self);
+            stack_pointer[-1 - oparg] = self;
             method = PyStackRef_FromPyObjectNew(((PyMethodObject *)callable_o)->im_func);
+            stack_pointer[-2 - oparg] = method;
             assert(PyStackRef_FunctionCheck(method));
             PyStackRef_CLOSE(callable);
-            stack_pointer[-2 - oparg] = method;
-            stack_pointer[-1 - oparg] = self;
             break;
         }
 
             PyObject *callable_o = PyStackRef_AsPyObjectBorrow(callable);
             STAT_INC(CALL, hit);
             self = PyStackRef_FromPyObjectNew(((PyMethodObject *)callable_o)->im_self);
+            stack_pointer[-1 - oparg] = self;
             func = PyStackRef_FromPyObjectNew(((PyMethodObject *)callable_o)->im_func);
-            PyStackRef_CLOSE(callable);
             stack_pointer[-2 - oparg] = func;
-            stack_pointer[-1 - oparg] = self;
+            PyStackRef_CLOSE(callable);
             break;
         }
 
             _PyStackRef null;
             PyObject *ptr = (PyObject *)CURRENT_OPERAND();
             value = PyStackRef_FromPyObjectNew(ptr);
-            null = PyStackRef_NULL;
             stack_pointer[0] = value;
+            null = PyStackRef_NULL;
             stack_pointer[1] = null;
             stack_pointer += 2;
             assert(WITHIN_STACK_BOUNDS());
index 5890fcea8e64d5ea5f36adb57565586bb5837064..31f95eb4686eb713b3c14a2970cc5995fd990c02 100644 (file)
                     PyObject *callable_o = PyStackRef_AsPyObjectBorrow(callable);
                     PyObject *self = ((PyMethodObject *)callable_o)->im_self;
                     maybe_self = PyStackRef_FromPyObjectNew(self);
+                    stack_pointer[-1 - oparg] = maybe_self;
                     PyObject *method = ((PyMethodObject *)callable_o)->im_func;
                     func = PyStackRef_FromPyObjectNew(method);
+                    stack_pointer[-2 - oparg] = func;
                     /* Make sure that callable and all args are in memory */
                     args[-2] = func;
                     args[-1] = maybe_self;
                 PyObject *callable_o = PyStackRef_AsPyObjectBorrow(callable);
                 STAT_INC(CALL, hit);
                 self = PyStackRef_FromPyObjectNew(((PyMethodObject *)callable_o)->im_self);
+                stack_pointer[-1 - oparg] = self;
                 func = PyStackRef_FromPyObjectNew(((PyMethodObject *)callable_o)->im_func);
+                stack_pointer[-2 - oparg] = func;
                 PyStackRef_CLOSE(callable);
             }
             // flush
-            stack_pointer[-2 - oparg] = func;
-            stack_pointer[-1 - oparg] = self;
             // _CHECK_FUNCTION_VERSION
             callable = stack_pointer[-2 - oparg];
             {
                 assert(PyStackRef_IsNull(null));
                 assert(Py_TYPE(callable_o) == &PyMethod_Type);
                 self = PyStackRef_FromPyObjectNew(((PyMethodObject *)callable_o)->im_self);
+                stack_pointer[-1 - oparg] = self;
                 method = PyStackRef_FromPyObjectNew(((PyMethodObject *)callable_o)->im_func);
+                stack_pointer[-2 - oparg] = method;
                 assert(PyStackRef_FunctionCheck(method));
                 PyStackRef_CLOSE(callable);
             }
             // flush
-            stack_pointer[-2 - oparg] = method;
-            stack_pointer[-1 - oparg] = self;
             // _PY_FRAME_GENERAL
             args = &stack_pointer[-oparg];
             self_or_null = stack_pointer[-1 - oparg];
             int matches = PyErr_GivenExceptionMatches(exc_value, PyExc_StopIteration);
             if (matches) {
                 value = PyStackRef_FromPyObjectNew(((PyStopIterationObject *)exc_value)->value);
+                stack_pointer[-2] = value;
                 PyStackRef_CLOSE(sub_iter_st);
                 PyStackRef_CLOSE(last_sent_val_st);
                 PyStackRef_CLOSE(exc_value_st);
                 goto exception_unwind;
             }
             stack_pointer[-3] = none;
-            stack_pointer[-2] = value;
             stack_pointer += -1;
             assert(WITHIN_STACK_BOUNDS());
             DISPATCH();
                 assert(seq);
                 assert(it->it_index < PyList_GET_SIZE(seq));
                 next = PyStackRef_FromPyObjectNew(PyList_GET_ITEM(seq, it->it_index++));
+                stack_pointer[0] = next;
             }
-            stack_pointer[0] = next;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
             DISPATCH();
                 assert(seq);
                 assert(it->it_index < PyTuple_GET_SIZE(seq));
                 next = PyStackRef_FromPyObjectNew(PyTuple_GET_ITEM(seq, it->it_index++));
+                stack_pointer[0] = next;
             }
-            stack_pointer[0] = next;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
             DISPATCH();
                     PyObject *callable_o = PyStackRef_AsPyObjectBorrow(callable);
                     PyObject *self = ((PyMethodObject *)callable_o)->im_self;
                     maybe_self = PyStackRef_FromPyObjectNew(self);
+                    stack_pointer[-1 - oparg] = maybe_self;
                     PyObject *method = ((PyMethodObject *)callable_o)->im_func;
                     func = PyStackRef_FromPyObjectNew(method);
+                    stack_pointer[-2 - oparg] = func;
                     /* Make sure that callable and all args are in memory */
                     args[-2] = func;
                     args[-1] = maybe_self;
             // _LOAD_CONST
             {
                 value = PyStackRef_FromPyObjectNew(GETITEM(FRAME_CO_CONSTS, oparg));
+                stack_pointer[0] = value;
             }
             // _RETURN_VALUE_EVENT
             val = value;
                 STAT_INC(LOAD_ATTR, hit);
                 assert(descr != NULL);
                 attr = PyStackRef_FromPyObjectNew(descr);
+                stack_pointer[-1] = attr;
                 null = PyStackRef_NULL;
                 PyStackRef_CLOSE(owner);
             }
-            stack_pointer[-1] = attr;
             if (oparg & 1) stack_pointer[0] = null;
             stack_pointer += (oparg & 1);
             assert(WITHIN_STACK_BOUNDS());
                 assert(descr != NULL);
                 assert(_PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_METHOD_DESCRIPTOR));
                 attr = PyStackRef_FromPyObjectNew(descr);
+                stack_pointer[-1] = attr;
                 self = owner;
             }
-            stack_pointer[-1] = attr;
             stack_pointer[0] = self;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
                 assert(descr != NULL);
                 assert(_PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_METHOD_DESCRIPTOR));
                 attr = PyStackRef_FromPyObjectNew(descr);
+                stack_pointer[-1] = attr;
                 self = owner;
             }
-            stack_pointer[-1] = attr;
             stack_pointer[0] = self;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
                 assert(descr != NULL);
                 assert(_PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_METHOD_DESCRIPTOR));
                 attr = PyStackRef_FromPyObjectNew(descr);
+                stack_pointer[-1] = attr;
                 self = owner;
             }
-            stack_pointer[-1] = attr;
             stack_pointer[0] = self;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
                 assert(descr != NULL);
                 PyStackRef_CLOSE(owner);
                 attr = PyStackRef_FromPyObjectNew(descr);
+                stack_pointer[-1] = attr;
             }
-            stack_pointer[-1] = attr;
             DISPATCH();
         }
 
                 assert(descr != NULL);
                 PyStackRef_CLOSE(owner);
                 attr = PyStackRef_FromPyObjectNew(descr);
+                stack_pointer[-1] = attr;
             }
-            stack_pointer[-1] = attr;
             DISPATCH();
         }
 
                 STAT_INC(LOAD_ATTR, hit);
                 null = PyStackRef_NULL;
                 attr = PyStackRef_FromPyObjectNew(attr_o);
+                stack_pointer[-1] = attr;
                 PyStackRef_CLOSE(owner);
             }
             /* Skip 5 cache entries */
-            stack_pointer[-1] = attr;
             if (oparg & 1) stack_pointer[0] = null;
             stack_pointer += (oparg & 1);
             assert(WITHIN_STACK_BOUNDS());
                                  "no locals found");
                 if (true) goto error;
             }
-            locals = PyStackRef_FromPyObjectNew(l);;
+            locals = PyStackRef_FromPyObjectNew(l);
             stack_pointer[0] = locals;
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
             // _LOAD_CONST
             {
                 value = PyStackRef_FromPyObjectNew(GETITEM(FRAME_CO_CONSTS, oparg));
+                stack_pointer[0] = value;
             }
             // _RETURN_VALUE
             retval = value;
             DEOPT_IF(PyTuple_GET_SIZE(seq_o) != 2, UNPACK_SEQUENCE);
             STAT_INC(UNPACK_SEQUENCE, hit);
             val0 = PyStackRef_FromPyObjectNew(PyTuple_GET_ITEM(seq_o, 0));
+            stack_pointer[0] = val0;
             val1 = PyStackRef_FromPyObjectNew(PyTuple_GET_ITEM(seq_o, 1));
-            PyStackRef_CLOSE(seq);
             stack_pointer[-1] = val1;
-            stack_pointer[0] = val0;
+            PyStackRef_CLOSE(seq);
             stack_pointer += 1;
             assert(WITHIN_STACK_BOUNDS());
             DISPATCH();
index f6625a3f7322d5cd6f8eed3ea039ff804d4eae98..3dc9838d75fa0cd7a4413003650d2d9128edfa59 100644 (file)
@@ -157,6 +157,7 @@ class Uop:
     annotations: list[str]
     stack: StackEffect
     caches: list[CacheEntry]
+    deferred_refs: dict[lexer.Token, str | None]
     body: list[lexer.Token]
     properties: Properties
     _size: int = -1
@@ -352,6 +353,47 @@ def analyze_caches(inputs: list[parser.InputEffect]) -> list[CacheEntry]:
     return [CacheEntry(i.name, int(i.size)) for i in caches]
 
 
+def analyze_deferred_refs(node: parser.InstDef) -> dict[lexer.Token, str | None]:
+    """Look for PyStackRef_FromPyObjectNew() calls"""
+
+    def find_assignment_target(idx: int) -> list[lexer.Token]:
+        """Find the tokens that make up the left-hand side of an assignment"""
+        offset = 1
+        for tkn in reversed(node.block.tokens[:idx-1]):
+            if tkn.kind == "SEMI" or tkn.kind == "LBRACE" or tkn.kind == "RBRACE":
+                return node.block.tokens[idx-offset:idx-1]
+            offset += 1
+        return []
+
+    refs: dict[lexer.Token, str | None] = {}
+    for idx, tkn in enumerate(node.block.tokens):
+        if tkn.kind != "IDENTIFIER" or tkn.text != "PyStackRef_FromPyObjectNew":
+            continue
+
+        if idx == 0 or node.block.tokens[idx-1].kind != "EQUALS":
+            raise analysis_error("Expected '=' before PyStackRef_FromPyObjectNew", tkn)
+
+        lhs = find_assignment_target(idx)
+        if len(lhs) == 0:
+            raise analysis_error("PyStackRef_FromPyObjectNew() must be assigned to an output", tkn)
+
+        if lhs[0].kind == "TIMES" or any(t.kind == "ARROW" or t.kind == "LBRACKET" for t in lhs[1:]):
+            # Don't handle: *ptr = ..., ptr->field = ..., or ptr[field] = ...
+            # Assume that they are visible to the GC.
+            refs[tkn] = None
+            continue
+
+        if len(lhs) != 1 or lhs[0].kind != "IDENTIFIER":
+            raise analysis_error("PyStackRef_FromPyObjectNew() must be assigned to an output", tkn)
+
+        name = lhs[0].text
+        if not any(var.name == name for var in node.outputs):
+            raise analysis_error(f"PyStackRef_FromPyObjectNew() must be assigned to an output, not '{name}'", tkn)
+
+        refs[tkn] = name
+
+    return refs
+
 def variable_used(node: parser.InstDef, name: str) -> bool:
     """Determine whether a variable with a given name is used in a node."""
     return any(
@@ -632,6 +674,7 @@ def make_uop(name: str, op: parser.InstDef, inputs: list[parser.InputEffect], uo
         annotations=op.annotations,
         stack=analyze_stack(op),
         caches=analyze_caches(inputs),
+        deferred_refs=analyze_deferred_refs(op),
         body=op.block.tokens,
         properties=compute_properties(op),
     )
@@ -649,6 +692,7 @@ def make_uop(name: str, op: parser.InstDef, inputs: list[parser.InputEffect], uo
                 annotations=op.annotations,
                 stack=analyze_stack(op, bit),
                 caches=analyze_caches(inputs),
+                deferred_refs=analyze_deferred_refs(op),
                 body=op.block.tokens,
                 properties=properties,
             )
@@ -671,6 +715,7 @@ def make_uop(name: str, op: parser.InstDef, inputs: list[parser.InputEffect], uo
             annotations=op.annotations,
             stack=analyze_stack(op),
             caches=analyze_caches(inputs),
+            deferred_refs=analyze_deferred_refs(op),
             body=op.block.tokens,
             properties=properties,
         )
index 2a339f8cd6bb66ee640b1e37791a41f54b3b4ead..37060e2d7e4f502800d4918477ea319dc5c7d2f8 100644 (file)
@@ -6,6 +6,7 @@ from analyzer import (
     Uop,
     Properties,
     StackItem,
+    analysis_error,
 )
 from cwriter import CWriter
 from typing import Callable, Mapping, TextIO, Iterator
@@ -75,6 +76,7 @@ class Emitter:
             "DECREF_INPUTS": self.decref_inputs,
             "CHECK_EVAL_BREAKER": self.check_eval_breaker,
             "SYNC_SP": self.sync_sp,
+            "PyStackRef_FromPyObjectNew": self.py_stack_ref_from_py_object_new,
         }
         self.out = out
 
@@ -203,6 +205,29 @@ class Emitter:
         if not uop.properties.ends_with_eval_breaker:
             self.out.emit_at("CHECK_EVAL_BREAKER();", tkn)
 
+    def py_stack_ref_from_py_object_new(
+        self,
+        tkn: Token,
+        tkn_iter: Iterator[Token],
+        uop: Uop,
+        stack: Stack,
+        inst: Instruction | None,
+    ) -> None:
+        self.out.emit(tkn)
+        emit_to(self.out, tkn_iter, "SEMI")
+        self.out.emit(";\n")
+
+        target = uop.deferred_refs[tkn]
+        if target is None:
+            # An assignment we don't handle, such as to a pointer or array.
+            return
+
+        # Flush the assignment to the stack.  Note that we don't flush the
+        # stack pointer here, and instead are currently relying on initializing
+        # unused portions of the stack to NULL.
+        stack.flush_single_var(self.out, target, uop.stack.outputs)
+
+
     def emit_tokens(
         self,
         uop: Uop,
index 13aee94f2b957c6b952e89ebbfb293fef7c94d64..d5831593215f76fc5340d182704c9ddd87c6c730 100644 (file)
@@ -242,7 +242,7 @@ def make_syntax_error(
     return SyntaxError(message, (filename, line, column, line_text))
 
 
-@dataclass(slots=True)
+@dataclass(slots=True, frozen=True)
 class Token:
     filename: str
     kind: str
index d2d598a120892d96dbc97a708630c331299e7f13..b44e48af09b3f0fc343ea8df4a0064330c461667 100644 (file)
@@ -256,21 +256,26 @@ class Stack:
                 top_offset.push(var)
         return "\n".join(res)
 
+    @staticmethod
+    def _do_emit(out: CWriter, var: StackItem, base_offset: StackOffset,
+                 cast_type: str = "uintptr_t", extract_bits: bool = False) -> None:
+        cast = f"({cast_type})" if var.type else ""
+        bits = ".bits" if cast and not extract_bits else ""
+        if var.condition == "0":
+            return
+        if var.condition and var.condition != "1":
+            out.emit(f"if ({var.condition}) ")
+        out.emit(
+            f"stack_pointer[{base_offset.to_c()}]{bits} = {cast}{var.name};\n"
+        )
+
     @staticmethod
     def _do_flush(out: CWriter, variables: list[Local], base_offset: StackOffset, top_offset: StackOffset,
                   cast_type: str = "uintptr_t", extract_bits: bool = False) -> None:
         out.start_line()
         for var in variables:
             if var.cached and not var.in_memory and not var.item.peek and not var.name in UNUSED:
-                cast = f"({cast_type})" if var.item.type else ""
-                bits = ".bits" if cast and not extract_bits else ""
-                if var.condition == "0":
-                    continue
-                if var.condition and var.condition != "1":
-                    out.emit(f"if ({var.condition}) ")
-                out.emit(
-                    f"stack_pointer[{base_offset.to_c()}]{bits} = {cast}{var.name};\n"
-                )
+                Stack._do_emit(out, var.item, base_offset, cast_type, extract_bits)
             base_offset.push(var.item)
         if base_offset.to_c() != top_offset.to_c():
             print("base", base_offset, "top", top_offset)
@@ -290,6 +295,26 @@ class Stack:
         self.base_offset.clear()
         self.top_offset.clear()
 
+    def flush_single_var(self, out: CWriter, var_name: str, outputs: list[StackItem],
+                         cast_type: str = "uintptr_t", extract_bits: bool = False) -> None:
+        assert any(var.name == var_name for var in outputs)
+        base_offset = self.base_offset.copy()
+        top_offset = self.top_offset.copy()
+        for var in self.variables:
+            base_offset.push(var.item)
+        for var in outputs:
+            if any(var == v.item for v in self.variables):
+                # The variable is already on the stack, such as a peeked value
+                # in the tier1 generator
+                continue
+            if var.name == var_name:
+                Stack._do_emit(out, var, base_offset, cast_type, extract_bits)
+            base_offset.push(var)
+            top_offset.push(var)
+        if base_offset.to_c() != top_offset.to_c():
+            print("base", base_offset, "top", top_offset)
+            assert False
+
     def peek_offset(self) -> str:
         return self.top_offset.to_c()
 
index 6c13d1f10b39f982a4cb24aa78878d9d2b2c24e6..1ea31a041ce3ae000d5c673daa06f500b45ca96c 100644 (file)
@@ -93,6 +93,16 @@ def write_uop(
         if braces:
             emitter.emit("{\n")
         emitter.out.emit(stack.define_output_arrays(uop.stack.outputs))
+        outputs: list[Local] = []
+        for var in uop.stack.outputs:
+            if not var.peek:
+                if var.name in locals:
+                    local = locals[var.name]
+                elif var.name == "unused":
+                    local = Local.unused(var)
+                else:
+                    local = Local.local(var)
+                outputs.append(local)
 
         for cache in uop.caches:
             if cache.name != "unused":
@@ -109,15 +119,11 @@ def write_uop(
                     emitter.emit(f"(void){cache.name};\n")
             offset += cache.size
         emitter.emit_tokens(uop, stack, inst)
-        for i, var in enumerate(uop.stack.outputs):
-            if not var.peek:
-                if var.name in locals:
-                    local = locals[var.name]
-                elif var.name == "unused":
-                    local = Local.unused(var)
-                else:
-                    local = Local.local(var)
-                emitter.emit(stack.push(local))
+        for output in outputs:
+            if output.name in uop.deferred_refs.values():
+                # We've already spilled this when emitting tokens
+                output.cached = False
+            emitter.emit(stack.push(output))
         if braces:
             emitter.out.start_line()
             emitter.emit("}\n")
index 8c212f75878984297178940571155d13d6e78ad7..461375c71fae83b2228f9abde71b7ed2b94a35e7 100644 (file)
@@ -166,6 +166,13 @@ def write_uop(uop: Uop, emitter: Emitter, stack: Stack) -> None:
             if local.defined:
                 locals[local.name] = local
         emitter.emit(stack.define_output_arrays(uop.stack.outputs))
+        outputs: list[Local] = []
+        for var in uop.stack.outputs:
+            if var.name in locals:
+                local = locals[var.name]
+            else:
+                local = Local.local(var)
+            outputs.append(local)
         for cache in uop.caches:
             if cache.name != "unused":
                 if cache.size == 4:
@@ -175,12 +182,11 @@ def write_uop(uop: Uop, emitter: Emitter, stack: Stack) -> None:
                     cast = f"uint{cache.size*16}_t"
                 emitter.emit(f"{type}{cache.name} = ({cast})CURRENT_OPERAND();\n")
         emitter.emit_tokens(uop, stack, None)
-        for i, var in enumerate(uop.stack.outputs):
-            if var.name in locals:
-                local = locals[var.name]
-            else:
-                local = Local.local(var)
-            emitter.emit(stack.push(local))
+        for output in outputs:
+            if output.name in uop.deferred_refs.values():
+                # We've already spilled this when emitting tokens
+                output.cached = False
+            emitter.emit(stack.push(output))
     except StackError as ex:
         raise analysis_error(ex.args[0], uop.body[0]) from None