--- /dev/null
+The JIT now generates more efficient code for calls to C functions, resulting
+in up to 0.8% memory savings and 1.5% speed improvement on AArch64. Patch by Diego Russo.
#include "Python.h"
#include "pycore_abstract.h"
+#include "pycore_bitutils.h"
#include "pycore_call.h"
#include "pycore_ceval.h"
#include "pycore_critical_section.h"
// JIT compiler stuff: /////////////////////////////////////////////////////////
+// Number of 32-bit words in a symbol_mask. The generated jit_stencils.h
+// static_asserts that this is large enough to hold one bit per known symbol.
+#define SYMBOL_MASK_WORDS 4
+
+// Bitmask with one bit per symbol ordinal: ordinal N is stored as bit
+// (N % 32) of word (N / 32).
+typedef uint32_t symbol_mask[SYMBOL_MASK_WORDS];
+
+// Bookkeeping for the trampolines emitted for one compiled trace.
+typedef struct {
+ unsigned char *mem;  // Start of the trampoline region (placed after code and data)
+ symbol_mask mask;  // Union of the trampoline masks of every StencilGroup in the trace
+ size_t size;  // Size of the trampoline region, in bytes
+} trampoline_state;
+
+// State handed to every stencil emit function while a trace is compiled.
+typedef struct {
+ trampoline_state trampolines;
+ // Start of each uop's code, indexed by position in the trace. Filled in as
+ // offsets first, then rebased to absolute addresses once memory is allocated.
+ uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
+} jit_state;
+
// Warning! AArch64 requires you to get your hands dirty. These are your gloves:
// value[value_start : value_start + len]
patch_32r(location, value);
}
+void patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state);
+
#include "jit_stencils.h"
+// Size in bytes of one trampoline slot. On AArch64 a trampoline is four
+// 32-bit words (two instructions followed by a 64-bit address); on every
+// other platform it is 0, so no trampoline space is reserved.
+#if defined(__aarch64__) || defined(_M_ARM64)
+ #define TRAMPOLINE_SIZE 16
+#else
+ #define TRAMPOLINE_SIZE 0
+#endif
+
+// Emit the AArch64 trampoline for the symbol identified by ordinal and patch
+// the instruction at location (via patch_aarch64_26r) to refer to it.
+//
+// ordinal indexes symbols_map (generated into jit_stencils.h) and selects one
+// bit of state->trampolines.mask. Trampolines are laid out in
+// state->trampolines.mem in increasing ordinal order, one TRAMPOLINE_SIZE
+// slot per set bit, so the slot index is the number of set bits below ordinal.
+void
+patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state)
+{
+    // Guard the mask lookup below against an out-of-range ordinal.
+    assert(ordinal >= 0 && ordinal / 32 < SYMBOL_MASK_WORDS);
+    const int word = ordinal / 32;
+    // The mask is an array of uint32_t, so the bit position is taken modulo
+    // 32. The constant must be unsigned: left-shifting a signed 1 into bit 31
+    // is undefined behavior.
+    const uint32_t ordinal_bit = (uint32_t)1 << (ordinal % 32);
+    const uint32_t trampoline_mask = state->trampolines.mask[word];
+    // The mask accumulated for the trace must already include this symbol.
+    assert(ordinal_bit & trampoline_mask);
+
+    // Count the set bits below ordinal: first within its own word, then in
+    // every preceding word. This gives the index into the array of trampolines.
+    int index = _Py_popcount32(trampoline_mask & (ordinal_bit - 1));
+    for (int i = 0; i < word; i++) {
+        index += _Py_popcount32(state->trampolines.mask[i]);
+    }
+
+    // The slot must lie entirely inside the region reserved for trampolines.
+    assert((size_t)(index + 1) * TRAMPOLINE_SIZE <= state->trampolines.size);
+    uint32_t *p = (uint32_t *)(state->trampolines.mem + (size_t)index * TRAMPOLINE_SIZE);
+
+    uint64_t value = (uintptr_t)symbols_map[ordinal];
+
+    /* Generate the trampoline
+       0: 58000048      ldr     x8, 8
+       4: d61f0100      br      x8
+       8: 00000000      // The next two words contain the 64-bit address to jump to.
+       c: 00000000
+    */
+    p[0] = 0x58000048;
+    p[1] = 0xD61F0100;
+    p[2] = value & 0xffffffff;  // Low 32 bits of the target address
+    p[3] = value >> 32;         // High 32 bits of the target address
+
+    // Redirect the original branch to the freshly written trampoline.
+    patch_aarch64_26r(location, (uintptr_t)p);
+}
+
+// OR every word of src into dest, so that dest accumulates the union of the
+// trampolines required by each StencilGroup merged into it.
+static void
+combine_symbol_mask(const symbol_mask src, symbol_mask dest)
+{
+    size_t word = SYMBOL_MASK_WORDS;
+    while (word-- > 0) {
+        dest[word] = dest[word] | src[word];
+    }
+}
+
// Compiles executor in-place. Don't forget to call _PyJIT_Free later!
int
_PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], size_t length)
{
const StencilGroup *group;
// Loop once to find the total compiled size:
- uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH];
size_t code_size = 0;
size_t data_size = 0;
+ jit_state state = {};
group = &trampoline;
code_size += group->code_size;
data_size += group->data_size;
for (size_t i = 0; i < length; i++) {
const _PyUOpInstruction *instruction = &trace[i];
group = &stencil_groups[instruction->opcode];
- instruction_starts[i] = code_size;
+ state.instruction_starts[i] = code_size;
code_size += group->code_size;
data_size += group->data_size;
+ combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
}
group = &stencil_groups[_FATAL_ERROR];
code_size += group->code_size;
data_size += group->data_size;
+ combine_symbol_mask(group->trampoline_mask, state.trampolines.mask);
+ // Calculate the size of the trampolines required by the whole trace
+ for (size_t i = 0; i < Py_ARRAY_LENGTH(state.trampolines.mask); i++) {
+ state.trampolines.size += _Py_popcount32(state.trampolines.mask[i]) * TRAMPOLINE_SIZE;
+ }
// Round up to the nearest page:
size_t page_size = get_page_size();
assert((page_size & (page_size - 1)) == 0);
- size_t padding = page_size - ((code_size + data_size) & (page_size - 1));
- size_t total_size = code_size + data_size + padding;
+ size_t padding = page_size - ((code_size + data_size + state.trampolines.size) & (page_size - 1));
+ size_t total_size = code_size + data_size + state.trampolines.size + padding;
unsigned char *memory = jit_alloc(total_size);
if (memory == NULL) {
return -1;
}
// Update the offsets of each instruction:
for (size_t i = 0; i < length; i++) {
- instruction_starts[i] += (uintptr_t)memory;
+ state.instruction_starts[i] += (uintptr_t)memory;
}
// Loop again to emit the code:
unsigned char *code = memory;
unsigned char *data = memory + code_size;
+ state.trampolines.mem = memory + code_size + data_size;
// Compile the trampoline, which handles converting between the native
// calling convention and the calling convention used by jitted code
// (which may be different for efficiency reasons). On platforms where
// we don't change calling conventions, the trampoline is empty and
// nothing is emitted here:
group = &trampoline;
- group->emit(code, data, executor, NULL, instruction_starts);
+ group->emit(code, data, executor, NULL, &state);
code += group->code_size;
data += group->data_size;
assert(trace[0].opcode == _START_EXECUTOR);
for (size_t i = 0; i < length; i++) {
const _PyUOpInstruction *instruction = &trace[i];
group = &stencil_groups[instruction->opcode];
- group->emit(code, data, executor, instruction, instruction_starts);
+ group->emit(code, data, executor, instruction, &state);
code += group->code_size;
data += group->data_size;
}
// Protect against accidental buffer overrun into data:
group = &stencil_groups[_FATAL_ERROR];
- group->emit(code, data, executor, NULL, instruction_starts);
+ group->emit(code, data, executor, NULL, &state);
code += group->code_size;
data += group->data_size;
assert(code == memory + code_size);
import dataclasses
import enum
-import sys
import typing
import _schema
HoleValue.OPERAND_HI: "(instruction->operand >> 32)",
HoleValue.OPERAND_LO: "(instruction->operand & UINT32_MAX)",
HoleValue.TARGET: "instruction->target",
- HoleValue.JUMP_TARGET: "instruction_starts[instruction->jump_target]",
- HoleValue.ERROR_TARGET: "instruction_starts[instruction->error_target]",
+ HoleValue.JUMP_TARGET: "state->instruction_starts[instruction->jump_target]",
+ HoleValue.ERROR_TARGET: "state->instruction_starts[instruction->error_target]",
HoleValue.ZERO: "",
}
symbol: str | None
# ...plus this addend:
addend: int
+ need_state: bool = False
func: str = dataclasses.field(init=False)
# Convenience method:
replace = dataclasses.replace
if value:
value += " + "
value += f"(uintptr_t)&{self.symbol}"
- if _signed(self.addend):
+ if _signed(self.addend) or not value:
if value:
value += " + "
value += f"{_signed(self.addend):#x}"
+ if self.need_state:
+ return f"{self.func}({location}, {value}, state);"
return f"{self.func}({location}, {value});"
body: bytearray = dataclasses.field(default_factory=bytearray, init=False)
holes: list[Hole] = dataclasses.field(default_factory=list, init=False)
disassembly: list[str] = dataclasses.field(default_factory=list, init=False)
- trampolines: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
def pad(self, alignment: int) -> None:
"""Pad the stencil to the given alignment."""
self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}")
self.body.extend([0] * padding)
- def emit_aarch64_trampoline(self, hole: Hole, alignment: int) -> Hole:
- """Even with the large code model, AArch64 Linux insists on 28-bit jumps."""
- assert hole.symbol is not None
- reuse_trampoline = hole.symbol in self.trampolines
- if reuse_trampoline:
- # Re-use the base address of the previously created trampoline
- base = self.trampolines[hole.symbol]
- else:
- self.pad(alignment)
- base = len(self.body)
- new_hole = hole.replace(addend=base, symbol=None, value=HoleValue.DATA)
-
- if reuse_trampoline:
- return new_hole
-
- self.disassembly += [
- f"{base + 4 * 0:x}: 58000048 ldr x8, 8",
- f"{base + 4 * 1:x}: d61f0100 br x8",
- f"{base + 4 * 2:x}: 00000000",
- f"{base + 4 * 2:016x}: R_AARCH64_ABS64 {hole.symbol}",
- f"{base + 4 * 3:x}: 00000000",
- ]
- for code in [
- 0x58000048.to_bytes(4, sys.byteorder),
- 0xD61F0100.to_bytes(4, sys.byteorder),
- 0x00000000.to_bytes(4, sys.byteorder),
- 0x00000000.to_bytes(4, sys.byteorder),
- ]:
- self.body.extend(code)
- self.holes.append(hole.replace(offset=base + 8, kind="R_AARCH64_ABS64"))
- self.trampolines[hole.symbol] = base
- return new_hole
-
def remove_jump(self, *, alignment: int = 1) -> None:
"""Remove a zero-length continuation jump, if it exists."""
hole = max(self.holes, key=lambda hole: hole.offset)
default_factory=dict, init=False
)
_got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
-
- def process_relocations(self, *, alignment: int = 1) -> None:
+ _trampolines: set[int] = dataclasses.field(default_factory=set, init=False)
+
+ def process_relocations(
+ self,
+ known_symbols: dict[str, int],
+ *,
+ alignment: int = 1,
+ ) -> None:
"""Fix up all GOT and internal relocations for this stencil group."""
for hole in self.code.holes.copy():
if (
in {"R_AARCH64_CALL26", "R_AARCH64_JUMP26", "ARM64_RELOC_BRANCH26"}
and hole.value is HoleValue.ZERO
):
- new_hole = self.data.emit_aarch64_trampoline(hole, alignment)
- self.code.holes.remove(hole)
- self.code.holes.append(new_hole)
+ hole.func = "patch_aarch64_trampoline"
+ hole.need_state = True
+ assert hole.symbol is not None
+ if hole.symbol in known_symbols:
+ ordinal = known_symbols[hole.symbol]
+ else:
+ ordinal = len(known_symbols)
+ known_symbols[hole.symbol] = ordinal
+ self._trampolines.add(ordinal)
+ hole.addend = ordinal
+ hole.symbol = None
self.code.remove_jump(alignment=alignment)
self.code.pad(alignment)
self.data.pad(8)
)
self.data.body.extend([0] * 8)
+ def _get_trampoline_mask(self) -> str:
+ bitmask: int = 0
+ trampoline_mask: list[str] = []
+ for ordinal in self._trampolines:
+ bitmask |= 1 << ordinal
+ while bitmask:
+ word = bitmask & ((1 << 32) - 1)
+ trampoline_mask.append(f"{word:#04x}")
+ bitmask >>= 32
+ return "{" + ", ".join(trampoline_mask) + "}"
+
def as_c(self, opname: str) -> str:
"""Dump this hole as a StencilGroup initializer."""
- return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}}}"
+ return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}}}"
def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:
stable: bool = False
debug: bool = False
verbose: bool = False
+ known_symbols: dict[str, int] = dataclasses.field(default_factory=dict)
def _compute_digest(self, out: pathlib.Path) -> str:
hasher = hashlib.sha256()
if group.data.body:
line = f"0: {str(bytes(group.data.body)).removeprefix('b')}"
group.data.disassembly.append(line)
- group.process_relocations(alignment=self.alignment)
+ group.process_relocations(
+ known_symbols=self.known_symbols, alignment=self.alignment
+ )
return group
def _handle_section(self, section: _S, group: _stencils.StencilGroup) -> None:
if comment:
file.write(f"// {comment}\n")
file.write("\n")
- for line in _writer.dump(stencil_groups):
+ for line in _writer.dump(stencil_groups, self.known_symbols):
file.write(f"{line}\n")
try:
jit_stencils_new.replace(jit_stencils)
import itertools
import typing
+import math
import _stencils
-def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]:
+def _dump_footer(
+ groups: dict[str, _stencils.StencilGroup], symbols: dict[str, int]
+) -> typing.Iterator[str]:
+ symbol_mask_size = max(math.ceil(len(symbols) / 32), 1)
+ yield f'static_assert(SYMBOL_MASK_WORDS >= {symbol_mask_size}, "SYMBOL_MASK_WORDS too small");'
+ yield ""
yield "typedef struct {"
yield " void (*emit)("
yield " unsigned char *code, unsigned char *data, _PyExecutorObject *executor,"
- yield " const _PyUOpInstruction *instruction, uintptr_t instruction_starts[]);"
+ yield " const _PyUOpInstruction *instruction, jit_state *state);"
yield " size_t code_size;"
yield " size_t data_size;"
+ yield " symbol_mask trampoline_mask;"
yield "} StencilGroup;"
yield ""
yield f"static const StencilGroup trampoline = {groups['trampoline'].as_c('trampoline')};"
continue
yield f" [{opname}] = {group.as_c(opname)},"
yield "};"
+ yield ""
+ yield f"static const void * const symbols_map[{max(len(symbols), 1)}] = {{"
+ for symbol, ordinal in symbols.items():
+ yield f" [{ordinal}] = &{symbol},"
+ yield "};"
def _dump_stencil(opname: str, group: _stencils.StencilGroup) -> typing.Iterator[str]:
yield "void"
yield f"emit_{opname}("
yield " unsigned char *code, unsigned char *data, _PyExecutorObject *executor,"
- yield " const _PyUOpInstruction *instruction, uintptr_t instruction_starts[])"
+ yield " const _PyUOpInstruction *instruction, jit_state *state)"
yield "{"
for part, stencil in [("code", group.code), ("data", group.data)]:
for line in stencil.disassembly:
yield ""
-def dump(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]:
+def dump(
+ groups: dict[str, _stencils.StencilGroup], symbols: dict[str, int]
+) -> typing.Iterator[str]:
"""Yield a JIT compiler line-by-line as a C header file."""
for opname, group in sorted(groups.items()):
yield from _dump_stencil(opname, group)
- yield from _dump_footer(groups)
+ yield from _dump_footer(groups, symbols)