]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
GH-135904: Optimize the JIT's assembly control flow (GH-135905)
authorBrandt Bucher <brandtbucher@microsoft.com>
Fri, 27 Jun 2025 15:20:51 +0000 (08:20 -0700)
committerGitHub <noreply@github.com>
Fri, 27 Jun 2025 15:20:51 +0000 (08:20 -0700)
Misc/NEWS.d/next/Core_and_Builtins/2025-06-24-16-46-34.gh-issue-135904.78xfon.rst [new file with mode: 0644]
Tools/jit/_optimizers.py [new file with mode: 0644]
Tools/jit/_stencils.py
Tools/jit/_targets.py

diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2025-06-24-16-46-34.gh-issue-135904.78xfon.rst b/Misc/NEWS.d/next/Core_and_Builtins/2025-06-24-16-46-34.gh-issue-135904.78xfon.rst
new file mode 100644 (file)
index 0000000..ecbd8fd
--- /dev/null
@@ -0,0 +1,2 @@
+Perform more aggressive control-flow optimizations on the machine code
+templates emitted by the experimental JIT compiler.
diff --git a/Tools/jit/_optimizers.py b/Tools/jit/_optimizers.py
new file mode 100644 (file)
index 0000000..1077e41
--- /dev/null
@@ -0,0 +1,319 @@
+"""Low-level optimization of textual assembly."""
+
+import dataclasses
+import pathlib
+import re
+import typing
+
+# Same as saying "not string.startswith('')":
+_RE_NEVER_MATCH = re.compile(r"(?!)")
+# Dictionary mapping branch instructions to their inverted branch instructions.
+# If a branch cannot be inverted, the value is None:
+_X86_BRANCHES = {
+    # https://www.felixcloutier.com/x86/jcc
+    "ja": "jna",
+    "jae": "jnae",
+    "jb": "jnb",
+    "jbe": "jnbe",
+    "jc": "jnc",
+    "jcxz": None,
+    "je": "jne",
+    "jecxz": None,
+    "jg": "jng",
+    "jge": "jnge",
+    "jl": "jnl",
+    "jle": "jnle",
+    "jo": "jno",
+    "jp": "jnp",
+    "jpe": "jpo",
+    "jrcxz": None,
+    "js": "jns",
+    "jz": "jnz",
+    # https://www.felixcloutier.com/x86/loop:loopcc
+    "loop": None,
+    "loope": None,
+    "loopne": None,
+    "loopnz": None,
+    "loopz": None,
+}
+# Update with all of the inverted branches, too:
+_X86_BRANCHES |= {v: k for k, v in _X86_BRANCHES.items() if v}
+
+
+@dataclasses.dataclass
+class _Block:
+    label: str | None = None
+    # Non-instruction lines like labels, directives, and comments:
+    noninstructions: list[str] = dataclasses.field(default_factory=list)
+    # Instruction lines:
+    instructions: list[str] = dataclasses.field(default_factory=list)
+    # If this block ends in a jump, where to?
+    target: typing.Self | None = None
+    # The next block in the linked list:
+    link: typing.Self | None = None
+    # Whether control flow can fall through to the linked block above:
+    fallthrough: bool = True
+    # Whether this block can eventually reach the next uop (_JIT_CONTINUE):
+    hot: bool = False
+
+    def resolve(self) -> typing.Self:
+        """Find the first non-empty block reachable from this one."""
+        block = self
+        while block.link and not block.instructions:
+            block = block.link
+        return block
+
+
+@dataclasses.dataclass
+class Optimizer:
+    """Several passes of analysis and optimization for textual assembly."""
+
+    path: pathlib.Path
+    _: dataclasses.KW_ONLY
+    # prefix used to mangle symbols on some platforms:
+    prefix: str = ""
+    # The first block in the linked list:
+    _root: _Block = dataclasses.field(init=False, default_factory=_Block)
+    _labels: dict[str, _Block] = dataclasses.field(init=False, default_factory=dict)
+    # No groups:
+    _re_noninstructions: typing.ClassVar[re.Pattern[str]] = re.compile(
+        r"\s*(?:\.|#|//|$)"
+    )
+    # One group (label):
+    _re_label: typing.ClassVar[re.Pattern[str]] = re.compile(
+        r'\s*(?P<label>[\w."$?@]+):'
+    )
+    # Override everything that follows in subclasses:
+    _alignment: typing.ClassVar[int] = 1
+    _branches: typing.ClassVar[dict[str, str | None]] = {}
+    # Two groups (instruction and target):
+    _re_branch: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH
+    # One group (target):
+    _re_jump: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH
+    # No groups:
+    _re_return: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH
+
+    def __post_init__(self) -> None:
+        # Split the code into a linked list of basic blocks. A basic block is an
+        # optional label, followed by zero or more non-instruction lines,
+        # followed by zero or more instruction lines (only the last of which may
+        # be a branch, jump, or return):
+        text = self._preprocess(self.path.read_text())
+        block = self._root
+        for line in text.splitlines():
+            # See if we need to start a new block:
+            if match := self._re_label.match(line):
+                # Label. New block:
+                block.link = block = self._lookup_label(match["label"])
+                block.noninstructions.append(line)
+                continue
+            if self._re_noninstructions.match(line):
+                if block.instructions:
+                    # Non-instruction lines. New block:
+                    block.link = block = _Block()
+                block.noninstructions.append(line)
+                continue
+            if block.target or not block.fallthrough:
+                # Current block ends with a branch, jump, or return. New block:
+                block.link = block = _Block()
+            block.instructions.append(line)
+            if match := self._re_branch.match(line):
+                # A block ending in a branch has a target and fallthrough:
+                block.target = self._lookup_label(match["target"])
+                assert block.fallthrough
+            elif match := self._re_jump.match(line):
+                # A block ending in a jump has a target and no fallthrough:
+                block.target = self._lookup_label(match["target"])
+                block.fallthrough = False
+            elif self._re_return.match(line):
+                # A block ending in a return has no target and fallthrough:
+                assert not block.target
+                block.fallthrough = False
+
+    def _preprocess(self, text: str) -> str:
+        # Override this method to do preprocessing of the textual assembly:
+        return text
+
+    @classmethod
+    def _invert_branch(cls, line: str, target: str) -> str | None:
+        match = cls._re_branch.match(line)
+        assert match
+        inverted = cls._branches.get(match["instruction"])
+        if not inverted:
+            return None
+        (a, b), (c, d) = match.span("instruction"), match.span("target")
+        # Before:
+        #     je FOO
+        # After:
+        #     jne BAR
+        return "".join([line[:a], inverted, line[b:c], target, line[d:]])
+
+    @classmethod
+    def _update_jump(cls, line: str, target: str) -> str:
+        match = cls._re_jump.match(line)
+        assert match
+        a, b = match.span("target")
+        # Before:
+        #     jmp FOO
+        # After:
+        #     jmp BAR
+        return "".join([line[:a], target, line[b:]])
+
+    def _lookup_label(self, label: str) -> _Block:
+        if label not in self._labels:
+            self._labels[label] = _Block(label)
+        return self._labels[label]
+
+    def _blocks(self) -> typing.Generator[_Block, None, None]:
+        block: _Block | None = self._root
+        while block:
+            yield block
+            block = block.link
+
+    def _body(self) -> str:
+        lines = []
+        hot = True
+        for block in self._blocks():
+            if hot != block.hot:
+                hot = block.hot
+                # Make it easy to tell at a glance where cold code is:
+                lines.append(f"# JIT: {'HOT' if hot else 'COLD'} ".ljust(80, "#"))
+            lines.extend(block.noninstructions)
+            lines.extend(block.instructions)
+        return "\n".join(lines)
+
+    def _predecessors(self, block: _Block) -> typing.Generator[_Block, None, None]:
+        # This is inefficient, but it's never wrong:
+        for pre in self._blocks():
+            if pre.target is block or pre.fallthrough and pre.link is block:
+                yield pre
+
+    def _insert_continue_label(self) -> None:
+        # Find the block with the last instruction:
+        for end in reversed(list(self._blocks())):
+            if end.instructions:
+                break
+        # Before:
+        #    jmp FOO
+        # After:
+        #    jmp FOO
+        #    .balign 8
+        #    _JIT_CONTINUE:
+        # This lets the assembler encode _JIT_CONTINUE jumps at build time!
+        align = _Block()
+        align.noninstructions.append(f"\t.balign\t{self._alignment}")
+        continuation = self._lookup_label(f"{self.prefix}_JIT_CONTINUE")
+        assert continuation.label
+        continuation.noninstructions.append(f"{continuation.label}:")
+        end.link, align.link, continuation.link = align, continuation, end.link
+
+    def _mark_hot_blocks(self) -> None:
+        # Start with the last block, and perform a DFS to find all blocks that
+        # can eventually reach it:
+        todo = list(self._blocks())[-1:]
+        while todo:
+            block = todo.pop()
+            block.hot = True
+            todo.extend(pre for pre in self._predecessors(block) if not pre.hot)
+
+    def _invert_hot_branches(self) -> None:
+        for branch in self._blocks():
+            link = branch.link
+            if link is None:
+                continue
+            jump = link.resolve()
+            # Before:
+            #    je HOT
+            #    jmp COLD
+            # After:
+            #    jne COLD
+            #    jmp HOT
+            if (
+                # block ends with a branch to hot code...
+                branch.target
+                and branch.fallthrough
+                and branch.target.hot
+                # ...followed by a jump to cold code with no other predecessors:
+                and jump.target
+                and not jump.fallthrough
+                and not jump.target.hot
+                and len(jump.instructions) == 1
+                and list(self._predecessors(jump)) == [branch]
+            ):
+                assert jump.target.label
+                assert branch.target.label
+                inverted = self._invert_branch(
+                    branch.instructions[-1], jump.target.label
+                )
+                # Check to see if the branch can even be inverted:
+                if inverted is None:
+                    continue
+                branch.instructions[-1] = inverted
+                jump.instructions[-1] = self._update_jump(
+                    jump.instructions[-1], branch.target.label
+                )
+                branch.target, jump.target = jump.target, branch.target
+                jump.hot = True
+
+    def _remove_redundant_jumps(self) -> None:
+        # Zero-length jumps can be introduced by _insert_continue_label and
+        # _invert_hot_branches:
+        for block in self._blocks():
+            # Before:
+            #    jmp FOO
+            #    FOO:
+            # After:
+            #    FOO:
+            if (
+                block.target
+                and block.link
+                and block.target.resolve() is block.link.resolve()
+            ):
+                block.target = None
+                block.fallthrough = True
+                block.instructions.pop()
+
+    def run(self) -> None:
+        """Run this optimizer."""
+        self._insert_continue_label()
+        self._mark_hot_blocks()
+        self._invert_hot_branches()
+        self._remove_redundant_jumps()
+        self.path.write_text(self._body())
+
+
+class OptimizerAArch64(Optimizer):  # pylint: disable = too-few-public-methods
+    """aarch64-apple-darwin/aarch64-pc-windows-msvc/aarch64-unknown-linux-gnu"""
+
+    # TODO: @diegorusso
+    _alignment = 8
+    # https://developer.arm.com/documentation/ddi0602/2025-03/Base-Instructions/B--Branch-
+    _re_jump = re.compile(r"\s*b\s+(?P<target>[\w.]+)")
+
+
+class OptimizerX86(Optimizer):  # pylint: disable = too-few-public-methods
+    """i686-pc-windows-msvc/x86_64-apple-darwin/x86_64-unknown-linux-gnu"""
+
+    _branches = _X86_BRANCHES
+    _re_branch = re.compile(
+        rf"\s*(?P<instruction>{'|'.join(_X86_BRANCHES)})\s+(?P<target>[\w.]+)"
+    )
+    # https://www.felixcloutier.com/x86/jmp
+    _re_jump = re.compile(r"\s*jmp\s+(?P<target>[\w.]+)")
+    # https://www.felixcloutier.com/x86/ret
+    _re_return = re.compile(r"\s*ret\b")
+
+
+class OptimizerX8664Windows(OptimizerX86):  # pylint: disable = too-few-public-methods
+    """x86_64-pc-windows-msvc"""
+
+    def _preprocess(self, text: str) -> str:
+        text = super()._preprocess(text)
+        # Before:
+        #     rex64 jmpq *__imp__JIT_CONTINUE(%rip)
+        # After:
+        #     jmp _JIT_CONTINUE
+        far_indirect_jump = (
+            rf"rex64\s+jmpq\s+\*__imp_(?P<target>{self.prefix}_JIT_\w+)\(%rip\)"
+        )
+        return re.sub(far_indirect_jump, r"jmp\t\g<target>", text)
index 03b0ba647b0db774e0c7249e7b7c080ad9ebc99b..1d82f5366f6ce0b0ef5acdf22db2e0b7dfdc19ef 100644 (file)
@@ -17,8 +17,6 @@ class HoleValue(enum.Enum):
 
     # The base address of the machine code for the current uop (exposed as _JIT_ENTRY):
     CODE = enum.auto()
-    # The base address of the machine code for the next uop (exposed as _JIT_CONTINUE):
-    CONTINUE = enum.auto()
     # The base address of the read-only data for this uop:
     DATA = enum.auto()
     # The address of the current executor (exposed as _JIT_EXECUTOR):
@@ -97,7 +95,6 @@ _PATCH_FUNCS = {
 # Translate HoleValues to C expressions:
 _HOLE_EXPRS = {
     HoleValue.CODE: "(uintptr_t)code",
-    HoleValue.CONTINUE: "(uintptr_t)code + sizeof(code_body)",
     HoleValue.DATA: "(uintptr_t)data",
     HoleValue.EXECUTOR: "(uintptr_t)executor",
     # These should all have been turned into DATA values by process_relocations:
@@ -209,64 +206,6 @@ class Stencil:
             self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}")
         self.body.extend([0] * padding)
 
-    def add_nops(self, nop: bytes, alignment: int) -> None:
-        """Add NOPs until there is alignment. Fail if it is not possible."""
-        offset = len(self.body)
-        nop_size = len(nop)
-
-        # Calculate the gap to the next multiple of alignment.
-        gap = -offset % alignment
-        if gap:
-            if gap % nop_size == 0:
-                count = gap // nop_size
-                self.body.extend(nop * count)
-            else:
-                raise ValueError(
-                    f"Cannot add nops of size '{nop_size}' to a body with "
-                    f"offset '{offset}' to align with '{alignment}'"
-                )
-
-    def remove_jump(self) -> None:
-        """Remove a zero-length continuation jump, if it exists."""
-        hole = max(self.holes, key=lambda hole: hole.offset)
-        match hole:
-            case Hole(
-                offset=offset,
-                kind="IMAGE_REL_AMD64_REL32",
-                value=HoleValue.GOT,
-                symbol="_JIT_CONTINUE",
-                addend=-4,
-            ) as hole:
-                # jmp qword ptr [rip]
-                jump = b"\x48\xff\x25\x00\x00\x00\x00"
-                offset -= 3
-            case Hole(
-                offset=offset,
-                kind="IMAGE_REL_I386_REL32" | "R_X86_64_PLT32" | "X86_64_RELOC_BRANCH",
-                value=HoleValue.CONTINUE,
-                symbol=None,
-                addend=addend,
-            ) as hole if (
-                _signed(addend) == -4
-            ):
-                # jmp 5
-                jump = b"\xe9\x00\x00\x00\x00"
-                offset -= 1
-            case Hole(
-                offset=offset,
-                kind="R_AARCH64_JUMP26",
-                value=HoleValue.CONTINUE,
-                symbol=None,
-                addend=0,
-            ) as hole:
-                # b #4
-                jump = b"\x00\x00\x00\x14"
-            case _:
-                return
-        if self.body[offset:] == jump:
-            self.body = self.body[:offset]
-            self.holes.remove(hole)
-
 
 @dataclasses.dataclass
 class StencilGroup:
@@ -284,9 +223,7 @@ class StencilGroup:
     _got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
     _trampolines: set[int] = dataclasses.field(default_factory=set, init=False)
 
-    def process_relocations(
-        self, known_symbols: dict[str, int], *, alignment: int = 1, nop: bytes = b""
-    ) -> None:
+    def process_relocations(self, known_symbols: dict[str, int]) -> None:
         """Fix up all GOT and internal relocations for this stencil group."""
         for hole in self.code.holes.copy():
             if (
@@ -306,8 +243,6 @@ class StencilGroup:
                 self._trampolines.add(ordinal)
                 hole.addend = ordinal
                 hole.symbol = None
-        self.code.remove_jump()
-        self.code.add_nops(nop=nop, alignment=alignment)
         self.data.pad(8)
         for stencil in [self.code, self.data]:
             for hole in stencil.holes:
index b383e39da194562c72c707cf777aa30d4e69f5df..ed10329d25d2f961de54053932ae6a85cab4c075 100644 (file)
@@ -13,6 +13,7 @@ import typing
 import shlex
 
 import _llvm
+import _optimizers
 import _schema
 import _stencils
 import _writer
@@ -41,8 +42,8 @@ class _Target(typing.Generic[_S, _R]):
     triple: str
     condition: str
     _: dataclasses.KW_ONLY
-    alignment: int = 1
     args: typing.Sequence[str] = ()
+    optimizer: type[_optimizers.Optimizer] = _optimizers.Optimizer
     prefix: str = ""
     stable: bool = False
     debug: bool = False
@@ -121,8 +122,9 @@ class _Target(typing.Generic[_S, _R]):
     async def _compile(
         self, opname: str, c: pathlib.Path, tempdir: pathlib.Path
     ) -> _stencils.StencilGroup:
+        s = tempdir / f"{opname}.s"
         o = tempdir / f"{opname}.o"
-        args = [
+        args_s = [
             f"--target={self.triple}",
             "-DPy_BUILD_CORE_MODULE",
             "-D_DEBUG" if self.debug else "-DNDEBUG",
@@ -136,7 +138,7 @@ class _Target(typing.Generic[_S, _R]):
             f"-I{CPYTHON / 'Python'}",
             f"-I{CPYTHON / 'Tools' / 'jit'}",
             "-O3",
-            "-c",
+            "-S",
             # Shorten full absolute file paths in the generated code (like the
             # __FILE__ macro and assert failure messages) for reproducibility:
             f"-ffile-prefix-map={CPYTHON}=.",
@@ -155,13 +157,16 @@ class _Target(typing.Generic[_S, _R]):
             "-fno-stack-protector",
             "-std=c11",
             "-o",
-            f"{o}",
+            f"{s}",
             f"{c}",
             *self.args,
             # Allow user-provided CFLAGS to override any defaults
             *shlex.split(self.cflags),
         ]
-        await _llvm.run("clang", args, echo=self.verbose)
+        await _llvm.run("clang", args_s, echo=self.verbose)
+        self.optimizer(s, prefix=self.prefix).run()
+        args_o = [f"--target={self.triple}", "-c", "-o", f"{o}", f"{s}"]
+        await _llvm.run("clang", args_o, echo=self.verbose)
         return await self._parse(o)
 
     async def _build_stencils(self) -> dict[str, _stencils.StencilGroup]:
@@ -190,11 +195,7 @@ class _Target(typing.Generic[_S, _R]):
                     tasks.append(group.create_task(coro, name=opname))
         stencil_groups = {task.get_name(): task.result() for task in tasks}
         for stencil_group in stencil_groups.values():
-            stencil_group.process_relocations(
-                known_symbols=self.known_symbols,
-                alignment=self.alignment,
-                nop=self._get_nop(),
-            )
+            stencil_group.process_relocations(self.known_symbols)
         return stencil_groups
 
     def build(
@@ -524,42 +525,43 @@ class _MachO(
 
 def get_target(host: str) -> _COFF | _ELF | _MachO:
     """Build a _Target for the given host "triple" and options."""
+    optimizer: type[_optimizers.Optimizer]
     target: _COFF | _ELF | _MachO
     if re.fullmatch(r"aarch64-apple-darwin.*", host):
         condition = "defined(__aarch64__) && defined(__APPLE__)"
-        target = _MachO(host, condition, alignment=8, prefix="_")
+        optimizer = _optimizers.OptimizerAArch64
+        target = _MachO(host, condition, optimizer=optimizer, prefix="_")
     elif re.fullmatch(r"aarch64-pc-windows-msvc", host):
         args = ["-fms-runtime-lib=dll", "-fplt"]
         condition = "defined(_M_ARM64)"
-        target = _COFF(host, condition, alignment=8, args=args)
+        optimizer = _optimizers.OptimizerAArch64
+        target = _COFF(host, condition, args=args, optimizer=optimizer)
     elif re.fullmatch(r"aarch64-.*-linux-gnu", host):
-        args = [
-            "-fpic",
-            # On aarch64 Linux, intrinsics were being emitted and this flag
-            # was required to disable them.
-            "-mno-outline-atomics",
-        ]
+        # -mno-outline-atomics: Keep intrinsics from being emitted.
+        args = ["-fpic", "-mno-outline-atomics"]
         condition = "defined(__aarch64__) && defined(__linux__)"
-        target = _ELF(host, condition, alignment=8, args=args)
+        optimizer = _optimizers.OptimizerAArch64
+        target = _ELF(host, condition, args=args, optimizer=optimizer)
     elif re.fullmatch(r"i686-pc-windows-msvc", host):
-        args = [
-            "-DPy_NO_ENABLE_SHARED",
-            # __attribute__((preserve_none)) is not supported
-            "-Wno-ignored-attributes",
-        ]
+        # -Wno-ignored-attributes: __attribute__((preserve_none)) is not supported here.
+        args = ["-DPy_NO_ENABLE_SHARED", "-Wno-ignored-attributes"]
+        optimizer = _optimizers.OptimizerX86
         condition = "defined(_M_IX86)"
-        target = _COFF(host, condition, args=args, prefix="_")
+        target = _COFF(host, condition, args=args, optimizer=optimizer, prefix="_")
     elif re.fullmatch(r"x86_64-apple-darwin.*", host):
         condition = "defined(__x86_64__) && defined(__APPLE__)"
-        target = _MachO(host, condition, prefix="_")
+        optimizer = _optimizers.OptimizerX86
+        target = _MachO(host, condition, optimizer=optimizer, prefix="_")
     elif re.fullmatch(r"x86_64-pc-windows-msvc", host):
         args = ["-fms-runtime-lib=dll"]
         condition = "defined(_M_X64)"
-        target = _COFF(host, condition, args=args)
+        optimizer = _optimizers.OptimizerX8664Windows
+        target = _COFF(host, condition, args=args, optimizer=optimizer)
     elif re.fullmatch(r"x86_64-.*-linux-gnu", host):
         args = ["-fno-pic", "-mcmodel=medium", "-mlarge-data-threshold=0"]
         condition = "defined(__x86_64__) && defined(__linux__)"
-        target = _ELF(host, condition, args=args)
+        optimizer = _optimizers.OptimizerX86
+        target = _ELF(host, condition, args=args, optimizer=optimizer)
     else:
         raise ValueError(host)
     return target