]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
GH-119726: Deduplicate AArch64 trampolines within a trace (GH-123872)
authorDiego Russo <diego.russo@arm.com>
Wed, 2 Oct 2024 19:07:20 +0000 (20:07 +0100)
committerGitHub <noreply@github.com>
Wed, 2 Oct 2024 19:07:20 +0000 (12:07 -0700)
Misc/NEWS.d/next/Core_and_Builtins/2024-09-19-16-57-34.gh-issue-119726.DseseK.rst [new file with mode: 0644]
Python/jit.c
Tools/jit/_stencils.py
Tools/jit/_targets.py
Tools/jit/_writer.py

diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2024-09-19-16-57-34.gh-issue-119726.DseseK.rst b/Misc/NEWS.d/next/Core_and_Builtins/2024-09-19-16-57-34.gh-issue-119726.DseseK.rst
new file mode 100644 (file)
index 0000000..c01eeff
--- /dev/null
@@ -0,0 +1,2 @@
+The JIT now generates more efficient code for calls to C functions resulting
+in up to 0.8% memory savings and 1.5% speed improvement on AArch64. Patch by Diego Russo.
index 33320761621c4c5eadd2849910bb598f36faf374..234fc7dda832311889a5756042fef8f7b5146d70 100644 (file)
@@ -3,6 +3,7 @@
 #include "Python.h"
 
 #include "pycore_abstract.h"
+#include "pycore_bitutils.h"
 #include "pycore_call.h"
 #include "pycore_ceval.h"
 #include "pycore_critical_section.h"
@@ -113,6 +114,21 @@ mark_executable(unsigned char *memory, size_t size)
 
 // JIT compiler stuff: /////////////////////////////////////////////////////////
 
+#define SYMBOL_MASK_WORDS 4
+
+typedef uint32_t symbol_mask[SYMBOL_MASK_WORDS];
+
+typedef struct {
+    unsigned char *mem;
+    symbol_mask mask;
+    size_t size;
+} trampoline_state;
+
+typedef struct {
+    trampoline_state trampolines;
+    uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
+} jit_state;
+
 // Warning! AArch64 requires you to get your hands dirty. These are your gloves:
 
 // value[value_start : value_start + len]
@@ -390,66 +406,126 @@ patch_x86_64_32rx(unsigned char *location, uint64_t value)
     patch_32r(location, value);
 }
 
+void patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state);
+
 #include "jit_stencils.h"
 
+#if defined(__aarch64__) || defined(_M_ARM64)
+    #define TRAMPOLINE_SIZE 16
+#else
+    #define TRAMPOLINE_SIZE 0
+#endif
+
+// Generate and patch AArch64 trampolines. The symbols to jump to are stored
+// in the jit_stencils.h in the symbols_map.
+void
+patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state)
+{
+    // Masking is done modulo 32 as the mask is stored as an array of uint32_t
+    const uint32_t symbol_mask = 1 << (ordinal % 32);
+    const uint32_t trampoline_mask = state->trampolines.mask[ordinal / 32];
+    assert(symbol_mask & trampoline_mask);
+
+    // Count the number of set bits in the trampoline mask lower than ordinal,
+    // this gives the index into the array of trampolines.
+    int index = _Py_popcount32(trampoline_mask & (symbol_mask - 1));
+    for (int i = 0; i < ordinal / 32; i++) {
+        index += _Py_popcount32(state->trampolines.mask[i]);
+    }
+
+    uint32_t *p = (uint32_t*)(state->trampolines.mem + index * TRAMPOLINE_SIZE);
+    assert((size_t)(index + 1) * TRAMPOLINE_SIZE <= state->trampolines.size);
+
+    uint64_t value = (uintptr_t)symbols_map[ordinal];
+
+    /* Generate the trampoline
+       0: 58000048      ldr     x8, 8
+       4: d61f0100      br      x8
+       8: 00000000      // The next two words contain the 64-bit address to jump to.
+       c: 00000000
+    */
+    p[0] = 0x58000048;
+    p[1] = 0xD61F0100;
+    p[2] = value & 0xffffffff;
+    p[3] = value >> 32;
+
+    patch_aarch64_26r(location, (uintptr_t)p);
+}
+
+static void
+combine_symbol_mask(const symbol_mask src, symbol_mask dest)
+{
+    // Calculate the union of the trampolines required by each StencilGroup
+    for (size_t i = 0; i < SYMBOL_MASK_WORDS; i++) {
+        dest[i] |= src[i];
+    }
+}
+
 // Compiles executor in-place. Don't forget to call _PyJIT_Free later!
 int
 _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], size_t length)
 {
     const StencilGroup *group;
     // Loop once to find the total compiled size:
-    uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
     size_t code_size = 0;
     size_t data_size = 0;
+    jit_state state = {};
     group = &trampoline;
     code_size += group->code_size;
     data_size += group->data_size;
     for (size_t i = 0; i < length; i++) {
         const _PyUOpInstruction *instruction = &trace[i];
         group = &stencil_groups[instruction->opcode];
-        instruction_starts[i] = code_size;
+        state.instruction_starts[i] = code_size;
         code_size += group->code_size;
         data_size += group->data_size;
+        combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
     }
     group = &stencil_groups[_FATAL_ERROR];
     code_size += group->code_size;
     data_size += group->data_size;
+    combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
+    // Calculate the size of the trampolines required by the whole trace
+    for (size_t i = 0; i < Py_ARRAY_LENGTH(state.trampolines.mask); i++) {
+        state.trampolines.size += _Py_popcount32(state.trampolines.mask[i]) * TRAMPOLINE_SIZE;
+    }
     // Round up to the nearest page:
     size_t page_size = get_page_size();
     assert((page_size & (page_size - 1)) == 0);
-    size_t padding = page_size - ((code_size + data_size) & (page_size - 1));
-    size_t total_size = code_size + data_size + padding;
+    size_t padding = page_size - ((code_size + data_size + state.trampolines.size) & (page_size - 1));
+    size_t total_size = code_size + data_size + state.trampolines.size + padding;
     unsigned char *memory = jit_alloc(total_size);
     if (memory == NULL) {
         return -1;
     }
     // Update the offsets of each instruction:
     for (size_t i = 0; i < length; i++) {
-        instruction_starts[i] += (uintptr_t)memory;
+        state.instruction_starts[i] += (uintptr_t)memory;
     }
     // Loop again to emit the code:
     unsigned char *code = memory;
     unsigned char *data = memory + code_size;
+    state.trampolines.mem = memory + code_size + data_size;
     // Compile the trampoline, which handles converting between the native
     // calling convention and the calling convention used by jitted code
     // (which may be different for efficiency reasons). On platforms where
     // we don't change calling conventions, the trampoline is empty and
     // nothing is emitted here:
     group = &trampoline;
-    group->emit(code, data, executor, NULL, instruction_starts);
+    group->emit(code, data, executor, NULL, &state);
     code += group->code_size;
     data += group->data_size;
     assert(trace[0].opcode == _START_EXECUTOR);
     for (size_t i = 0; i < length; i++) {
         const _PyUOpInstruction *instruction = &trace[i];
         group = &stencil_groups[instruction->opcode];
-        group->emit(code, data, executor, instruction, instruction_starts);
+        group->emit(code, data, executor, instruction, &state);
         code += group->code_size;
         data += group->data_size;
     }
     // Protect against accidental buffer overrun into data:
     group = &stencil_groups[_FATAL_ERROR];
-    group->emit(code, data, executor, NULL, instruction_starts);
+    group->emit(code, data, executor, NULL, &state);
     code += group->code_size;
     data += group->data_size;
     assert(code == memory + code_size);
index 1c6a9edb39840d8b4e4abe296ea6903401e0e108..bbb52f391f4b01abfc9e838ff22f6a1e8b3aa1f1 100644 (file)
@@ -2,7 +2,6 @@
 
 import dataclasses
 import enum
-import sys
 import typing
 
 import _schema
@@ -103,8 +102,8 @@ _HOLE_EXPRS = {
     HoleValue.OPERAND_HI: "(instruction->operand >> 32)",
     HoleValue.OPERAND_LO: "(instruction->operand & UINT32_MAX)",
     HoleValue.TARGET: "instruction->target",
-    HoleValue.JUMP_TARGET: "instruction_starts[instruction->jump_target]",
-    HoleValue.ERROR_TARGET: "instruction_starts[instruction->error_target]",
+    HoleValue.JUMP_TARGET: "state->instruction_starts[instruction->jump_target]",
+    HoleValue.ERROR_TARGET: "state->instruction_starts[instruction->error_target]",
     HoleValue.ZERO: "",
 }
 
@@ -125,6 +124,7 @@ class Hole:
     symbol: str | None
     # ...plus this addend:
     addend: int
+    need_state: bool = False
     func: str = dataclasses.field(init=False)
     # Convenience method:
     replace = dataclasses.replace
@@ -157,10 +157,12 @@ class Hole:
             if value:
                 value += " + "
             value += f"(uintptr_t)&{self.symbol}"
-        if _signed(self.addend):
+        if _signed(self.addend) or not value:
             if value:
                 value += " + "
             value += f"{_signed(self.addend):#x}"
+        if self.need_state:
+            return f"{self.func}({location}, {value}, state);"
         return f"{self.func}({location}, {value});"
 
 
@@ -175,7 +177,6 @@ class Stencil:
     body: bytearray = dataclasses.field(default_factory=bytearray, init=False)
     holes: list[Hole] = dataclasses.field(default_factory=list, init=False)
     disassembly: list[str] = dataclasses.field(default_factory=list, init=False)
-    trampolines: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
 
     def pad(self, alignment: int) -> None:
         """Pad the stencil to the given alignment."""
@@ -184,39 +185,6 @@ class Stencil:
         self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}")
         self.body.extend([0] * padding)
 
-    def emit_aarch64_trampoline(self, hole: Hole, alignment: int) -> Hole:
-        """Even with the large code model, AArch64 Linux insists on 28-bit jumps."""
-        assert hole.symbol is not None
-        reuse_trampoline = hole.symbol in self.trampolines
-        if reuse_trampoline:
-            # Re-use the base address of the previously created trampoline
-            base = self.trampolines[hole.symbol]
-        else:
-            self.pad(alignment)
-            base = len(self.body)
-        new_hole = hole.replace(addend=base, symbol=None, value=HoleValue.DATA)
-
-        if reuse_trampoline:
-            return new_hole
-
-        self.disassembly += [
-            f"{base + 4 * 0:x}: 58000048      ldr     x8, 8",
-            f"{base + 4 * 1:x}: d61f0100      br      x8",
-            f"{base + 4 * 2:x}: 00000000",
-            f"{base + 4 * 2:016x}:  R_AARCH64_ABS64    {hole.symbol}",
-            f"{base + 4 * 3:x}: 00000000",
-        ]
-        for code in [
-            0x58000048.to_bytes(4, sys.byteorder),
-            0xD61F0100.to_bytes(4, sys.byteorder),
-            0x00000000.to_bytes(4, sys.byteorder),
-            0x00000000.to_bytes(4, sys.byteorder),
-        ]:
-            self.body.extend(code)
-        self.holes.append(hole.replace(offset=base + 8, kind="R_AARCH64_ABS64"))
-        self.trampolines[hole.symbol] = base
-        return new_hole
-
     def remove_jump(self, *, alignment: int = 1) -> None:
         """Remove a zero-length continuation jump, if it exists."""
         hole = max(self.holes, key=lambda hole: hole.offset)
@@ -282,8 +250,14 @@ class StencilGroup:
         default_factory=dict, init=False
     )
     _got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
-
-    def process_relocations(self, *, alignment: int = 1) -> None:
+    _trampolines: set[int] = dataclasses.field(default_factory=set, init=False)
+
+    def process_relocations(
+        self,
+        known_symbols: dict[str, int],
+        *,
+        alignment: int = 1,
+    ) -> None:
         """Fix up all GOT and internal relocations for this stencil group."""
         for hole in self.code.holes.copy():
             if (
@@ -291,9 +265,17 @@ class StencilGroup:
                 in {"R_AARCH64_CALL26", "R_AARCH64_JUMP26", "ARM64_RELOC_BRANCH26"}
                 and hole.value is HoleValue.ZERO
             ):
-                new_hole = self.data.emit_aarch64_trampoline(hole, alignment)
-                self.code.holes.remove(hole)
-                self.code.holes.append(new_hole)
+                hole.func = "patch_aarch64_trampoline"
+                hole.need_state = True
+                assert hole.symbol is not None
+                if hole.symbol in known_symbols:
+                    ordinal = known_symbols[hole.symbol]
+                else:
+                    ordinal = len(known_symbols)
+                    known_symbols[hole.symbol] = ordinal
+                self._trampolines.add(ordinal)
+                hole.addend = ordinal
+                hole.symbol = None
         self.code.remove_jump(alignment=alignment)
         self.code.pad(alignment)
         self.data.pad(8)
@@ -348,9 +330,20 @@ class StencilGroup:
             )
             self.data.body.extend([0] * 8)
 
+    def _get_trampoline_mask(self) -> str:
+        bitmask: int = 0
+        trampoline_mask: list[str] = []
+        for ordinal in self._trampolines:
+            bitmask |= 1 << ordinal
+        while bitmask:
+            word = bitmask & ((1 << 32) - 1)
+            trampoline_mask.append(f"{word:#04x}")
+            bitmask >>= 32
+        return "{" + ", ".join(trampoline_mask) + "}"
+
     def as_c(self, opname: str) -> str:
         """Dump this hole as a StencilGroup initializer."""
-        return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}}}"
+        return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}}}"
 
 
 def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:
index b6c0e79e72fb3e982af435659d40d2902aaf7a9d..5eb316e782fda8f82f509c4bca3890f3abecd665 100644 (file)
@@ -44,6 +44,7 @@ class _Target(typing.Generic[_S, _R]):
     stable: bool = False
     debug: bool = False
     verbose: bool = False
+    known_symbols: dict[str, int] = dataclasses.field(default_factory=dict)
 
     def _compute_digest(self, out: pathlib.Path) -> str:
         hasher = hashlib.sha256()
@@ -95,7 +96,9 @@ class _Target(typing.Generic[_S, _R]):
         if group.data.body:
             line = f"0: {str(bytes(group.data.body)).removeprefix('b')}"
             group.data.disassembly.append(line)
-        group.process_relocations(alignment=self.alignment)
+        group.process_relocations(
+            known_symbols=self.known_symbols, alignment=self.alignment
+        )
         return group
 
     def _handle_section(self, section: _S, group: _stencils.StencilGroup) -> None:
@@ -231,7 +234,7 @@ class _Target(typing.Generic[_S, _R]):
                 if comment:
                     file.write(f"// {comment}\n")
                 file.write("\n")
-                for line in _writer.dump(stencil_groups):
+                for line in _writer.dump(stencil_groups, self.known_symbols):
                     file.write(f"{line}\n")
             try:
                 jit_stencils_new.replace(jit_stencils)
index 9d11094f85c7fffa8745fb92a1f518259cff0ac3..7b99d10310a645e9bf2eb61a0d93327514160359 100644 (file)
@@ -2,17 +2,24 @@
 
 import itertools
 import typing
+import math
 
 import _stencils
 
 
-def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]:
+def _dump_footer(
+    groups: dict[str, _stencils.StencilGroup], symbols: dict[str, int]
+) -> typing.Iterator[str]:
+    symbol_mask_size = max(math.ceil(len(symbols) / 32), 1)
+    yield f'static_assert(SYMBOL_MASK_WORDS >= {symbol_mask_size}, "SYMBOL_MASK_WORDS too small");'
+    yield ""
     yield "typedef struct {"
     yield "    void (*emit)("
     yield "        unsigned char *code, unsigned char *data, _PyExecutorObject *executor,"
-    yield "        const _PyUOpInstruction *instruction, uintptr_t instruction_starts[]);"
+    yield "        const _PyUOpInstruction *instruction, jit_state *state);"
     yield "    size_t code_size;"
     yield "    size_t data_size;"
+    yield "    symbol_mask trampoline_mask;"
     yield "} StencilGroup;"
     yield ""
     yield f"static const StencilGroup trampoline = {groups['trampoline'].as_c('trampoline')};"
@@ -23,13 +30,18 @@ def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[s
             continue
         yield f"    [{opname}] = {group.as_c(opname)},"
     yield "};"
+    yield ""
+    yield f"static const void * const symbols_map[{max(len(symbols), 1)}] = {{"
+    for symbol, ordinal in symbols.items():
+        yield f"    [{ordinal}] = &{symbol},"
+    yield "};"
 
 
 def _dump_stencil(opname: str, group: _stencils.StencilGroup) -> typing.Iterator[str]:
     yield "void"
     yield f"emit_{opname}("
     yield "    unsigned char *code, unsigned char *data, _PyExecutorObject *executor,"
-    yield "    const _PyUOpInstruction *instruction, uintptr_t instruction_starts[])"
+    yield "    const _PyUOpInstruction *instruction, jit_state *state)"
     yield "{"
     for part, stencil in [("code", group.code), ("data", group.data)]:
         for line in stencil.disassembly:
@@ -58,8 +70,10 @@ def _dump_stencil(opname: str, group: _stencils.StencilGroup) -> typing.Iterator
     yield ""
 
 
-def dump(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]:
+def dump(
+    groups: dict[str, _stencils.StencilGroup], symbols: dict[str, int]
+) -> typing.Iterator[str]:
     """Yield a JIT compiler line-by-line as a C header file."""
     for opname, group in sorted(groups.items()):
         yield from _dump_stencil(opname, group)
-    yield from _dump_footer(groups)
+    yield from _dump_footer(groups, symbols)