"""Low-level optimization of textual assembly."""
import dataclasses
+import enum
import pathlib
import re
import typing
# MyPy doesn't understand that a invariant variable can be initialized by a covariant value
CUSTOM_AARCH64_BRANCH19: str | None = "CUSTOM_AARCH64_BRANCH19"
-# Branches are either b.{cond} or bc.{cond}
-_AARCH64_BRANCHES: dict[str, tuple[str | None, str | None]] = {
- "b." + cond: (("b." + inverse if inverse else None), CUSTOM_AARCH64_BRANCH19)
- for (cond, inverse) in _AARCH64_COND_CODES.items()
-} | {
- "bc." + cond: (("bc." + inverse if inverse else None), CUSTOM_AARCH64_BRANCH19)
- for (cond, inverse) in _AARCH64_COND_CODES.items()
+_AARCH64_SHORT_BRANCHES = {
+ "tbz": "tbnz",
+ "tbnz": "tbz",
}
+# Branches are either b.{cond}, bc.{cond}, cbz, cbnz, tbz or tbnz
+_AARCH64_BRANCHES: dict[str, tuple[str | None, str | None]] = (
+ {
+ "b." + cond: (("b." + inverse if inverse else None), CUSTOM_AARCH64_BRANCH19)
+ for (cond, inverse) in _AARCH64_COND_CODES.items()
+ }
+ | {
+ "bc." + cond: (("bc." + inverse if inverse else None), CUSTOM_AARCH64_BRANCH19)
+ for (cond, inverse) in _AARCH64_COND_CODES.items()
+ }
+ | {
+ "cbz": ("cbnz", CUSTOM_AARCH64_BRANCH19),
+ "cbnz": ("cbz", CUSTOM_AARCH64_BRANCH19),
+ }
+ | {cond: (inverse, None) for (cond, inverse) in _AARCH64_SHORT_BRANCHES.items()}
+)
+
+
+@enum.unique
+class InstructionKind(enum.Enum):
+
+ JUMP = enum.auto()
+ LONG_BRANCH = enum.auto()
+ SHORT_BRANCH = enum.auto()
+ RETURN = enum.auto()
+ OTHER = enum.auto()
+
@dataclasses.dataclass
+class Instruction:
+ kind: InstructionKind
+ name: str
+ text: str
+ target: str | None
+
+ def is_branch(self) -> bool:
+ return self.kind in (InstructionKind.LONG_BRANCH, InstructionKind.SHORT_BRANCH)
+
+ def update_target(self, target: str) -> "Instruction":
+ assert self.target is not None
+ return Instruction(
+ self.kind, self.name, self.text.replace(self.target, target), target
+ )
+
+ def update_name_and_target(self, name: str, target: str) -> "Instruction":
+ assert self.target is not None
+ return Instruction(
+ self.kind,
+ name,
+ self.text.replace(self.name, name).replace(self.target, target),
+ target,
+ )
+
+
+@dataclasses.dataclass(eq=False)
class _Block:
label: str | None = None
# Non-instruction lines like labels, directives, and comments:
noninstructions: list[str] = dataclasses.field(default_factory=list)
# Instruction lines:
- instructions: list[str] = dataclasses.field(default_factory=list)
+ instructions: list[Instruction] = dataclasses.field(default_factory=list)
# If this block ends in a jump, where to?
target: typing.Self | None = None
# The next block in the linked list:
# Prefixes used to mangle local labels and symbols:
label_prefix: str
symbol_prefix: str
+ re_global: re.Pattern[str]
# The first block in the linked list:
_root: _Block = dataclasses.field(init=False, default_factory=_Block)
_labels: dict[str, _Block] = dataclasses.field(init=False, default_factory=dict)
# Override everything that follows in subclasses:
_supports_external_relocations = True
_branches: typing.ClassVar[dict[str, tuple[str | None, str | None]]] = {}
+ # Short branches are instructions that can branch within a micro-op,
+ # but might not have the reach to branch anywhere within a trace.
+ _short_branches: typing.ClassVar[dict[str, str]] = {}
# Two groups (instruction and target):
_re_branch: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH
# One group (target):
_re_jump: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH
# No groups:
_re_return: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH
+ text: str = ""
+ globals: set[str] = dataclasses.field(default_factory=set)
def __post_init__(self) -> None:
# Split the code into a linked list of basic blocks. A basic block is an
# optional label, followed by zero or more non-instruction lines,
# followed by zero or more instruction lines (only the last of which may
# be a branch, jump, or return):
- text = self._preprocess(self.path.read_text())
+ self.text = self._preprocess(self.path.read_text())
block = self._root
- for line in text.splitlines():
+ for line in self.text.splitlines():
# See if we need to start a new block:
if match := self._re_label.match(line):
# Label. New block:
block.link = block = self._lookup_label(match["label"])
block.noninstructions.append(line)
continue
+ if match := self.re_global.match(line):
+ self.globals.add(match["label"])
+ block.noninstructions.append(line)
+ continue
if self._re_noninstructions.match(line):
if block.instructions:
# Non-instruction lines. New block:
if block.target or not block.fallthrough:
# Current block ends with a branch, jump, or return. New block:
block.link = block = _Block()
- block.instructions.append(line)
- if match := self._re_branch.match(line):
+ inst = self._parse_instruction(line)
+ block.instructions.append(inst)
+ if inst.is_branch():
# A block ending in a branch has a target and fallthrough:
- block.target = self._lookup_label(match["target"])
+ assert inst.target is not None
+ block.target = self._lookup_label(inst.target)
assert block.fallthrough
- elif match := self._re_jump.match(line):
+ elif inst.kind == InstructionKind.JUMP:
# A block ending in a jump has a target and no fallthrough:
- block.target = self._lookup_label(match["target"])
+ assert inst.target is not None
+ block.target = self._lookup_label(inst.target)
block.fallthrough = False
- elif self._re_return.match(line):
+ elif inst.kind == InstructionKind.RETURN:
# A block ending in a return has no target and fallthrough:
assert not block.target
block.fallthrough = False
continue_label = f"{self.label_prefix}_JIT_CONTINUE"
return re.sub(continue_symbol, continue_label, text)
- @classmethod
- def _invert_branch(cls, line: str, target: str) -> str | None:
- match = cls._re_branch.match(line)
- assert match
- inverted_reloc = cls._branches.get(match["instruction"])
+ def _parse_instruction(self, line: str) -> Instruction:
+ target = None
+ if match := self._re_branch.match(line):
+ target = match["target"]
+ name = match["instruction"]
+ if name in self._short_branches:
+ kind = InstructionKind.SHORT_BRANCH
+ else:
+ kind = InstructionKind.LONG_BRANCH
+ elif match := self._re_jump.match(line):
+ target = match["target"]
+ name = line[: -len(target)].strip()
+ kind = InstructionKind.JUMP
+ elif match := self._re_return.match(line):
+ name = line
+ kind = InstructionKind.RETURN
+ else:
+ name, *_ = line.split(" ")
+ kind = InstructionKind.OTHER
+ return Instruction(kind, name, line, target)
+
+ def _invert_branch(self, inst: Instruction, target: str) -> Instruction | None:
+ assert inst.is_branch()
+ if inst.kind == InstructionKind.SHORT_BRANCH and self._is_far_target(target):
+ return None
+ inverted_reloc = self._branches.get(inst.name)
if inverted_reloc is None:
return None
inverted = inverted_reloc[0]
if not inverted:
return None
- (a, b), (c, d) = match.span("instruction"), match.span("target")
- # Before:
- # je FOO
- # After:
- # jne BAR
- return "".join([line[:a], inverted, line[b:c], target, line[d:]])
-
- @classmethod
- def _update_jump(cls, line: str, target: str) -> str:
- match = cls._re_jump.match(line)
- assert match
- a, b = match.span("target")
- # Before:
- # jmp FOO
- # After:
- # jmp BAR
- return "".join([line[:a], target, line[b:]])
+ return inst.update_name_and_target(inverted, target)
def _lookup_label(self, label: str) -> _Block:
if label not in self._labels:
self._labels[label] = _Block(label)
return self._labels[label]
+ def _is_far_target(self, label: str) -> bool:
+ return not label.startswith(self.label_prefix)
+
def _blocks(self) -> typing.Generator[_Block, None, None]:
block: _Block | None = self._root
while block:
block = block.link
def _body(self) -> str:
- lines = []
+ lines = ["#" + line for line in self.text.splitlines()]
hot = True
for block in self._blocks():
if hot != block.hot:
# Make it easy to tell at a glance where cold code is:
lines.append(f"# JIT: {'HOT' if hot else 'COLD'} ".ljust(80, "#"))
lines.extend(block.noninstructions)
- lines.extend(block.instructions)
+ for inst in block.instructions:
+ lines.append(inst.text)
return "\n".join(lines)
def _predecessors(self, block: _Block) -> typing.Generator[_Block, None, None]:
if inverted is None:
continue
branch.instructions[-1] = inverted
- jump.instructions[-1] = self._update_jump(
- jump.instructions[-1], branch.target.label
+ jump.instructions[-1] = jump.instructions[-1].update_target(
+ branch.target.label
)
branch.target, jump.target = jump.target, branch.target
jump.hot = True
# Zero-length jumps can be introduced by _insert_continue_label and
# _invert_hot_branches:
for block in self._blocks():
+ target = block.target
+ if target is None:
+ continue
+ target = target.resolve()
# Before:
# jmp FOO
# FOO:
# After:
# FOO:
- if (
- block.target
- and block.link
- and block.target.resolve() is block.link.resolve()
- ):
+ if block.link and target is block.link.resolve():
block.target = None
block.fallthrough = True
block.instructions.pop()
+ # Before:
+ # br ? FOO:
+ # ...
+ # FOO:
+ # jump BAR
+ # After:
+ # br cond BAR
+ # ...
+ elif (
+ len(target.instructions) == 1
+ and target.instructions[0].kind == InstructionKind.JUMP
+ ):
+ assert target.target is not None
+ assert target.target.label is not None
+ if block.instructions[
+ -1
+ ].kind == InstructionKind.SHORT_BRANCH and self._is_far_target(
+ target.target.label
+ ):
+ continue
+ block.target = target.target
+ block.instructions[-1] = block.instructions[-1].update_target(
+ target.target.label
+ )
+
+ def _find_live_blocks(self) -> set[_Block]:
+ live: set[_Block] = set()
+ # Externally reachable blocks are live
+ todo: set[_Block] = {b for b in self._blocks() if b.label in self.globals}
+ while todo:
+ block = todo.pop()
+ live.add(block)
+ if block.fallthrough:
+ next = block.link
+ if next is not None and next not in live:
+ todo.add(next)
+ next = block.target
+ if next is not None and next not in live:
+ todo.add(next)
+ return live
+
+ def _remove_unreachable(self) -> None:
+ live = self._find_live_blocks()
+ continuation = self._lookup_label(f"{self.label_prefix}_JIT_CONTINUE")
+ # Keep blocks after continuation as they may contain data and
+ # metadata that the assembler needs
+ prev: _Block | None = None
+ block = self._root
+ while block is not continuation:
+ next = block.link
+ assert next is not None
+ if not block in live and prev:
+ prev.link = next
+ else:
+ prev = block
+ block = next
+ assert prev.link is block
def _fixup_external_labels(self) -> None:
if self._supports_external_relocations:
# Nothing to fix up
return
- for block in self._blocks():
+ for index, block in enumerate(self._blocks()):
if block.target and block.fallthrough:
branch = block.instructions[-1]
- match = self._re_branch.match(branch)
- assert match is not None
- target = match["target"]
- reloc = self._branches[match["instruction"]][1]
- if reloc is not None and not target.startswith(self.label_prefix):
+ assert branch.is_branch()
+ target = branch.target
+ assert target is not None
+ reloc = self._branches[branch.name][1]
+ if reloc is not None and self._is_far_target(target):
name = target[len(self.symbol_prefix) :]
- block.instructions[-1] = (
- f"// target='{target}' prefix='{self.label_prefix}'"
- )
- block.instructions.append(
- f"{self.symbol_prefix}{reloc}_JIT_RELOCATION_{name}:"
+ label = f"{self.symbol_prefix}{reloc}_JIT_RELOCATION_{name}_JIT_RELOCATION_{index}:"
+ block.instructions[-1] = Instruction(
+ InstructionKind.OTHER, "", label, None
)
- a, b = match.span("target")
- branch = "".join([branch[:a], "0", branch[b:]])
- block.instructions.append(branch)
+ block.instructions.append(branch.update_target("0"))
def run(self) -> None:
"""Run this optimizer."""
self._insert_continue_label()
self._mark_hot_blocks()
- self._invert_hot_branches()
- self._remove_redundant_jumps()
+ # Removing branches can expose opportunities for more branch removal.
+ # Repeat a few times. 2 would probably do, but it's fast enough with 4.
+ for _ in range(4):
+ self._invert_hot_branches()
+ self._remove_redundant_jumps()
+ self._remove_unreachable()
self._fixup_external_labels()
self.path.write_text(self._body())
"""aarch64-pc-windows-msvc/aarch64-apple-darwin/aarch64-unknown-linux-gnu"""
_branches = _AARCH64_BRANCHES
+ _short_branches = _AARCH64_SHORT_BRANCHES
# Mach-O does not support the 19 bit branch locations needed for branch reordering
_supports_external_relocations = False
+ _branch_patterns = [name.replace(".", r"\.") for name in _AARCH64_BRANCHES]
_re_branch = re.compile(
- rf"\s*(?P<instruction>{'|'.join(_AARCH64_BRANCHES)})\s+(.+,\s+)*(?P<target>[\w.]+)"
+ rf"\s*(?P<instruction>{'|'.join(_branch_patterns)})\s+(.+,\s+)*(?P<target>[\w.]+)"
)
# https://developer.arm.com/documentation/ddi0602/2025-03/Base-Instructions/B--Branch-
"""i686-pc-windows-msvc/x86_64-apple-darwin/x86_64-unknown-linux-gnu"""
_branches = _X86_BRANCHES
+ _short_branches = {}
_re_branch = re.compile(
rf"\s*(?P<instruction>{'|'.join(_X86_BRANCHES)})\s+(?P<target>[\w.]+)"
)