git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
GH-142305: JIT: Deduplicating GOT symbols in the trace (#142316)
author: Diego Russo <diego.russo@arm.com>
Wed, 10 Dec 2025 16:04:04 +0000 (16:04 +0000)
committer: GitHub <noreply@github.com>
Wed, 10 Dec 2025 16:04:04 +0000 (16:04 +0000)
Misc/NEWS.d/next/Core_and_Builtins/2025-12-05-17-24-34.gh-issue-142305.ybXvtr.rst [new file with mode: 0644]
Python/jit.c
Tools/jit/_stencils.py
Tools/jit/_writer.py

diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2025-12-05-17-24-34.gh-issue-142305.ybXvtr.rst b/Misc/NEWS.d/next/Core_and_Builtins/2025-12-05-17-24-34.gh-issue-142305.ybXvtr.rst
new file mode 100644 (file)
index 0000000..9e6d25a
--- /dev/null
@@ -0,0 +1 @@
+Decrease the size of the generated stencils and the runtime JIT code. Patch by Diego Russo.
diff --git a/Python/jit.c b/Python/jit.c
index b0d53d156fa440210de678b729539f2749ef3f74..1e066a58974e1d237378d7edfd4e9f471862f175 100644 (file)
@@ -134,7 +134,8 @@ mark_executable(unsigned char *memory, size_t size)
 
 // JIT compiler stuff: /////////////////////////////////////////////////////////
 
-#define SYMBOL_MASK_WORDS 4
+#define GOT_SLOT_SIZE sizeof(uintptr_t)
+#define SYMBOL_MASK_WORDS 8
 
 typedef uint32_t symbol_mask[SYMBOL_MASK_WORDS];
 
@@ -142,10 +143,11 @@ typedef struct {
     unsigned char *mem;
     symbol_mask mask;
     size_t size;
-} trampoline_state;
+} symbol_state;
 
 typedef struct {
-    trampoline_state trampolines;
+    symbol_state trampolines;
+    symbol_state got_symbols;
     uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
 } jit_state;
 
@@ -210,6 +212,33 @@ set_bits(uint32_t *loc, uint8_t loc_start, uint64_t value, uint8_t value_start,
 // - x86_64-unknown-linux-gnu:
 //   - https://github.com/llvm/llvm-project/blob/main/lld/ELF/Arch/X86_64.cpp
 
+// Return the address of the slot reserved for symbol `ordinal` inside
+// state->mem. Slots are `size` bytes each, packed in ascending ordinal order.
+static unsigned char *
+get_symbol_slot(int ordinal, symbol_state *state, int size)
+{
+    const uint32_t symbol_mask = 1U << (ordinal % 32);
+    const uint32_t state_mask = state->mask[ordinal / 32];
+    assert(symbol_mask & state_mask);  // only ordinals set in the mask have a slot
+
+     // Slot index = set bits below ordinal in its own word, plus all set bits in earlier words
+    size_t index = _Py_popcount32(state_mask & (symbol_mask - 1));
+    for (int i = 0; i < ordinal / 32; i++) {
+        index += _Py_popcount32(state->mask[i]);
+    }
+
+    unsigned char *slot = state->mem + index * size;
+    assert((size_t)(index + 1) * size <= state->size);  // slot must end within state->size bytes
+    return slot;
+}
+
+// Return the address of the GOT_SLOT_SIZE-byte slot in state->got_symbols for symbol `ordinal`.
+static uintptr_t
+got_symbol_address(int ordinal, jit_state *state)
+{
+    return (uintptr_t)get_symbol_slot(ordinal, &state->got_symbols, GOT_SLOT_SIZE);
+}
+
 // Many of these patches are "relaxing", meaning that they can rewrite the
 // code they're patching to be more efficient (like turning a 64-bit memory
 // load into a 32-bit immediate load). These patches have an "x" in their name.
@@ -452,6 +481,7 @@ patch_x86_64_32rx(unsigned char *location, uint64_t value)
     patch_32r(location, value);
 }
 
+void patch_got_symbol(jit_state *state, int ordinal);
 void patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state);
 void patch_x86_64_trampoline(unsigned char *location, int ordinal, jit_state *state);
 
@@ -470,23 +500,13 @@ void patch_x86_64_trampoline(unsigned char *location, int ordinal, jit_state *st
     #define DATA_ALIGN 1
 #endif
 
-// Get the trampoline memory location for a given symbol ordinal.
-static unsigned char *
-get_trampoline_slot(int ordinal, jit_state *state)
+// Populate the GOT entry for the given symbol ordinal with its resolved address.
+void
+patch_got_symbol(jit_state *state, int ordinal)
 {
-    const uint32_t symbol_mask = 1 << (ordinal % 32);
-    const uint32_t trampoline_mask = state->trampolines.mask[ordinal / 32];
-    assert(symbol_mask & trampoline_mask);
-
-     // Count the number of set bits in the trampoline mask lower than ordinal
-    int index = _Py_popcount32(trampoline_mask & (symbol_mask - 1));
-    for (int i = 0; i < ordinal / 32; i++) {
-        index += _Py_popcount32(state->trampolines.mask[i]);
-    }
-
-    unsigned char *trampoline = state->trampolines.mem + index * TRAMPOLINE_SIZE;
-    assert((size_t)(index + 1) * TRAMPOLINE_SIZE <= state->trampolines.size);
-    return trampoline;
+    uint64_t value = (uintptr_t)symbols_map[ordinal];
+    unsigned char *location = (unsigned char *)get_symbol_slot(ordinal, &state->got_symbols, GOT_SLOT_SIZE);
+    patch_64(location, value);
 }
 
 // Generate and patch AArch64 trampolines. The symbols to jump to are stored
@@ -506,8 +526,7 @@ patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state)
     }
 
     // Out of range - need a trampoline
-    uint32_t *p = (uint32_t *)get_trampoline_slot(ordinal, state);
-
+    uint32_t *p = (uint32_t *)get_symbol_slot(ordinal, &state->trampolines, TRAMPOLINE_SIZE);
 
     /* Generate the trampoline
        0: 58000048      ldr     x8, 8
@@ -537,7 +556,7 @@ patch_x86_64_trampoline(unsigned char *location, int ordinal, jit_state *state)
     }
 
     // Out of range - need a trampoline
-    unsigned char *trampoline = get_trampoline_slot(ordinal, state);
+    unsigned char *trampoline = get_symbol_slot(ordinal, &state->trampolines, TRAMPOLINE_SIZE);
 
     /* Generate the trampoline (14 bytes, padded to 16):
        0: ff 25 00 00 00 00    jmp *(%rip)
@@ -579,21 +598,26 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
         code_size += group->code_size;
         data_size += group->data_size;
         combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
+        combine_symbol_mask(group->got_mask, state.got_symbols.mask);
     }
     group = &stencil_groups[_FATAL_ERROR];
     code_size += group->code_size;
     data_size += group->data_size;
     combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
+    combine_symbol_mask(group->got_mask, state.got_symbols.mask);
     // Calculate the size of the trampolines required by the whole trace
     for (size_t i = 0; i < Py_ARRAY_LENGTH(state.trampolines.mask); i++) {
         state.trampolines.size += _Py_popcount32(state.trampolines.mask[i]) * TRAMPOLINE_SIZE;
     }
+    for (size_t i = 0; i < Py_ARRAY_LENGTH(state.got_symbols.mask); i++) {
+        state.got_symbols.size += _Py_popcount32(state.got_symbols.mask[i]) * GOT_SLOT_SIZE;
+    }
     // Round up to the nearest page:
     size_t page_size = get_page_size();
     assert((page_size & (page_size - 1)) == 0);
     size_t code_padding = DATA_ALIGN - ((code_size + state.trampolines.size) & (DATA_ALIGN - 1));
-    size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size) & (page_size - 1));
-    size_t total_size = code_size + state.trampolines.size + code_padding + data_size + padding;
+    size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size) & (page_size - 1));
+    size_t total_size = code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size + padding;
     unsigned char *memory = jit_alloc(total_size);
     if (memory == NULL) {
         return -1;
@@ -603,6 +627,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
     OPT_STAT_ADD(jit_code_size, code_size);
     OPT_STAT_ADD(jit_trampoline_size, state.trampolines.size);
     OPT_STAT_ADD(jit_data_size, data_size);
+    OPT_STAT_ADD(jit_got_size, state.got_symbols.size);
     OPT_STAT_ADD(jit_padding_size, padding);
     OPT_HIST(total_size, trace_total_memory_hist);
     // Update the offsets of each instruction:
@@ -613,6 +638,7 @@ _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], siz
     unsigned char *code = memory;
     state.trampolines.mem = memory + code_size;
     unsigned char *data = memory + code_size + state.trampolines.size + code_padding;
+    state.got_symbols.mem = data + data_size;
     assert(trace[0].opcode == _START_EXECUTOR || trace[0].opcode == _COLD_EXIT || trace[0].opcode == _COLD_DYNAMIC_EXIT);
     for (size_t i = 0; i < length; i++) {
         const _PyUOpInstruction *instruction = &trace[i];
@@ -654,12 +680,13 @@ compile_trampoline(void)
     code_size += group->code_size;
     data_size += group->data_size;
     combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
+    combine_symbol_mask(group->got_mask, state.got_symbols.mask);
     // Round up to the nearest page:
     size_t page_size = get_page_size();
     assert((page_size & (page_size - 1)) == 0);
     size_t code_padding = DATA_ALIGN - ((code_size + state.trampolines.size) & (DATA_ALIGN - 1));
-    size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size) & (page_size - 1));
-    size_t total_size = code_size + state.trampolines.size + code_padding + data_size + padding;
+    size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size) & (page_size - 1));
+    size_t total_size = code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size + padding;
     unsigned char *memory = jit_alloc(total_size);
     if (memory == NULL) {
         return NULL;
@@ -667,6 +694,7 @@ compile_trampoline(void)
     unsigned char *code = memory;
     state.trampolines.mem = memory + code_size;
     unsigned char *data = memory + code_size + state.trampolines.size + code_padding;
+    state.got_symbols.mem = data + data_size;
     // Compile the shim, which handles converting between the native
     // calling convention and the calling convention used by jitted code
     // (which may be different for efficiency reasons).
diff --git a/Tools/jit/_stencils.py b/Tools/jit/_stencils.py
index 5c45ab930a4ac4708910bca9c031dbf74d5f9632..cdffd953ef98389b36addab890cd04b88cec771c 100644 (file)
@@ -100,8 +100,8 @@ _HOLE_EXPRS = {
     HoleValue.CODE: "(uintptr_t)code",
     HoleValue.DATA: "(uintptr_t)data",
     HoleValue.EXECUTOR: "(uintptr_t)executor",
+    HoleValue.GOT: "",
     # These should all have been turned into DATA values by process_relocations:
-    # HoleValue.GOT: "",
     HoleValue.OPARG: "instruction->oparg",
     HoleValue.OPERAND0: "instruction->operand0",
     HoleValue.OPERAND0_HI: "(instruction->operand0 >> 32)",
@@ -115,6 +115,24 @@ _HOLE_EXPRS = {
     HoleValue.ZERO: "",
 }
 
+_AARCH64_GOT_RELOCATIONS = {
+    "R_AARCH64_ADR_GOT_PAGE",
+    "R_AARCH64_LD64_GOT_LO12_NC",
+    "ARM64_RELOC_GOT_LOAD_PAGE21",
+    "ARM64_RELOC_GOT_LOAD_PAGEOFF12",
+    "IMAGE_REL_ARM64_PAGEBASE_REL21",
+    "IMAGE_REL_ARM64_PAGEOFFSET_12L",
+    "IMAGE_REL_ARM64_PAGEOFFSET_12A",
+}
+
+_X86_GOT_RELOCATIONS = {
+    "R_X86_64_GOTPCRELX",
+    "R_X86_64_REX_GOTPCRELX",
+    "X86_64_RELOC_GOT",
+    "X86_64_RELOC_GOT_LOAD",
+    "IMAGE_REL_AMD64_REL32",
+}
+
 
 @dataclasses.dataclass
 class Hole:
@@ -133,6 +151,8 @@ class Hole:
     # ...plus this addend:
     addend: int
     need_state: bool = False
+    custom_location: str = ""
+    custom_value: str = ""
     func: str = dataclasses.field(init=False)
     # Convenience method:
     replace = dataclasses.replace
@@ -170,16 +190,22 @@ class Hole:
 
     def as_c(self, where: str) -> str:
         """Dump this hole as a call to a patch_* function."""
-        location = f"{where} + {self.offset:#x}"
-        value = _HOLE_EXPRS[self.value]
-        if self.symbol:
-            if value:
-                value += " + "
-            value += f"(uintptr_t)&{self.symbol}"
-        if _signed(self.addend) or not value:
-            if value:
-                value += " + "
-            value += f"{_signed(self.addend):#x}"
+        if self.custom_location:
+            location = self.custom_location
+        else:
+            location = f"{where} + {self.offset:#x}"
+        if self.custom_value:
+            value = self.custom_value
+        else:
+            value = _HOLE_EXPRS[self.value]
+            if self.symbol:
+                if value:
+                    value += " + "
+                value += f"(uintptr_t)&{self.symbol}"
+            if _signed(self.addend) or not value:
+                if value:
+                    value += " + "
+                value += f"{_signed(self.addend):#x}"
         if self.need_state:
             return f"{self.func}({location}, {value}, state);"
         return f"{self.func}({location}, {value});"
@@ -219,8 +245,11 @@ class StencilGroup:
     symbols: dict[int | str, tuple[HoleValue, int]] = dataclasses.field(
         default_factory=dict, init=False
     )
-    _got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
+    _jit_symbol_table: dict[str, int] = dataclasses.field(
+        default_factory=dict, init=False
+    )
     _trampolines: set[int] = dataclasses.field(default_factory=set, init=False)
+    _got_entries: set[int] = dataclasses.field(default_factory=set, init=False)
 
     def convert_labels_to_relocations(self) -> None:
         for name, hole_plus in self.symbols.items():
@@ -270,13 +299,39 @@ class StencilGroup:
                 self._trampolines.add(ordinal)
                 hole.addend = ordinal
                 hole.symbol = None
+            elif (
+                hole.kind in _AARCH64_GOT_RELOCATIONS | _X86_GOT_RELOCATIONS
+                and hole.symbol
+                and "_JIT_" not in hole.symbol
+                and hole.value is HoleValue.GOT
+            ):
+                if hole.symbol in known_symbols:
+                    ordinal = known_symbols[hole.symbol]
+                else:
+                    ordinal = len(known_symbols)
+                    known_symbols[hole.symbol] = ordinal
+                self._got_entries.add(ordinal)
         self.data.pad(8)
         for stencil in [self.code, self.data]:
             for hole in stencil.holes:
                 if hole.value is HoleValue.GOT:
                     assert hole.symbol is not None
-                    hole.value = HoleValue.DATA
-                    hole.addend += self._global_offset_table_lookup(hole.symbol)
+                    if "_JIT_" in hole.symbol:
+                        # Relocations for local symbols
+                        hole.value = HoleValue.DATA
+                        hole.addend += self._jit_symbol_table_lookup(hole.symbol)
+                    else:
+                        _ordinal = known_symbols[hole.symbol]
+                        _custom_value = f"got_symbol_address({_ordinal:#x}, state)"
+                        if hole.kind in _X86_GOT_RELOCATIONS:
+                            # On x86, apply the relocation's -4 addend (needed for
+                            # the 32-bit RIP-relative displacement to the GOT entry)
+                            # here, since hole.addend is reused for the ordinal
+                            _custom_value = (
+                                f"got_symbol_address({_ordinal:#x}, state) - 4"
+                            )
+                        hole.addend = _ordinal
+                        hole.custom_value = _custom_value
                     hole.symbol = None
                 elif hole.symbol in self.symbols:
                     hole.value, addend = self.symbols[hole.symbol]
@@ -289,16 +344,19 @@ class StencilGroup:
                     raise ValueError(
                         f"Add PyAPI_FUNC(...) or PyAPI_DATA(...) to declaration of {hole.symbol}!"
                     )
+        self._emit_jit_symbol_table()
         self._emit_global_offset_table()
         self.code.holes.sort(key=lambda hole: hole.offset)
         self.data.holes.sort(key=lambda hole: hole.offset)
 
-    def _global_offset_table_lookup(self, symbol: str) -> int:
-        return len(self.data.body) + self._got.setdefault(symbol, 8 * len(self._got))
+    def _jit_symbol_table_lookup(self, symbol: str) -> int:
+        return len(self.data.body) + self._jit_symbol_table.setdefault(
+            symbol, 8 * len(self._jit_symbol_table)
+        )
 
-    def _emit_global_offset_table(self) -> None:
+    def _emit_jit_symbol_table(self) -> None:
         got = len(self.data.body)
-        for s, offset in self._got.items():
+        for s, offset in self._jit_symbol_table.items():
             if s in self.symbols:
                 value, addend = self.symbols[s]
                 symbol = None
@@ -322,20 +380,35 @@ class StencilGroup:
             )
             self.data.body.extend([0] * 8)
 
-    def _get_trampoline_mask(self) -> str:
+    def _emit_global_offset_table(self) -> None:
+        for hole in self.code.holes:
+            if hole.value is HoleValue.GOT:
+                _got_hole = Hole(0, "R_X86_64_64", hole.value, None, hole.addend)
+                _got_hole.func = "patch_got_symbol"
+                _got_hole.custom_location = "state"
+                if _got_hole not in self.data.holes:
+                    self.data.holes.append(_got_hole)
+
+    def _get_symbol_mask(self, ordinals: set[int]) -> str:
         bitmask: int = 0
-        trampoline_mask: list[str] = []
-        for ordinal in self._trampolines:
+        symbol_mask: list[str] = []
+        for ordinal in ordinals:
             bitmask |= 1 << ordinal
         while bitmask:
             word = bitmask & ((1 << 32) - 1)
-            trampoline_mask.append(f"{word:#04x}")
+            symbol_mask.append(f"{word:#04x}")
             bitmask >>= 32
-        return "{" + (", ".join(trampoline_mask) or "0") + "}"
+        return "{" + (", ".join(symbol_mask) or "0") + "}"
+
+    def _get_trampoline_mask(self) -> str:
+        return self._get_symbol_mask(self._trampolines)
+
+    def _get_got_mask(self) -> str:
+        return self._get_symbol_mask(self._got_entries)
 
     def as_c(self, opname: str) -> str:
         """Dump this hole as a StencilGroup initializer."""
-        return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}}}"
+        return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}, {self._get_got_mask()}}}"
 
 
 def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:
diff --git a/Tools/jit/_writer.py b/Tools/jit/_writer.py
index 4f373011ebf079433b804d66135584f46a74bd21..26696bca63aeec503f46891f80abde4674e47491 100644 (file)
@@ -20,6 +20,7 @@ def _dump_footer(
     yield "    size_t code_size;"
     yield "    size_t data_size;"
     yield "    symbol_mask trampoline_mask;"
+    yield "    symbol_mask got_mask;"
     yield "} StencilGroup;"
     yield ""
     yield f"static const StencilGroup trampoline = {groups['trampoline'].as_c('trampoline')};"