]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-112720: Move dis's cache output code to the Formatter, labels lookup to the arg_re...
authorIrit Katriel <1055913+iritkatriel@users.noreply.github.com>
Fri, 15 Dec 2023 12:28:22 +0000 (12:28 +0000)
committerGitHub <noreply@github.com>
Fri, 15 Dec 2023 12:28:22 +0000 (12:28 +0000)
Lib/dis.py
Lib/test/test_dis.py

index 183091cb0d609882865a0c3e55929afe926fe5d1..1a2f1032d500afe94e3361f728f1c0d9fa0a3cc1 100644 (file)
@@ -113,7 +113,14 @@ def dis(x=None, *, file=None, depth=None, show_caches=False, adaptive=False,
     elif hasattr(x, 'co_code'): # Code object
         _disassemble_recursive(x, file=file, depth=depth, show_caches=show_caches, adaptive=adaptive, show_offsets=show_offsets)
     elif isinstance(x, (bytes, bytearray)): # Raw bytecode
-        _disassemble_bytes(x, file=file, show_caches=show_caches, show_offsets=show_offsets)
+        labels_map = _make_labels_map(x)
+        label_width = 4 + len(str(len(labels_map)))
+        formatter = Formatter(file=file,
+                              offset_width=len(str(max(len(x) - 2, 9999))) if show_offsets else 0,
+                              label_width=label_width,
+                              show_caches=show_caches)
+        arg_resolver = ArgResolver(labels_map=labels_map)
+        _disassemble_bytes(x, arg_resolver=arg_resolver, formatter=formatter)
     elif isinstance(x, str):    # Source code
         _disassemble_str(x, file=file, depth=depth, show_caches=show_caches, adaptive=adaptive, show_offsets=show_offsets)
     else:
@@ -394,23 +401,41 @@ class Instruction(_Instruction):
 class Formatter:
 
     def __init__(self, file=None, lineno_width=0, offset_width=0, label_width=0,
-                       line_offset=0):
+                       line_offset=0, show_caches=False):
         """Create a Formatter
 
         *file* where to write the output
         *lineno_width* sets the width of the line number field (0 omits it)
         *offset_width* sets the width of the instruction offset field
         *label_width* sets the width of the label field
+        *show_caches* is a boolean indicating whether to display cache lines
 
-        *line_offset* the line number (within the code unit)
         """
         self.file = file
         self.lineno_width = lineno_width
         self.offset_width = offset_width
         self.label_width = label_width
-
+        self.show_caches = show_caches
 
     def print_instruction(self, instr, mark_as_current=False):
+        self.print_instruction_line(instr, mark_as_current)
+        if self.show_caches and instr.cache_info:
+            offset = instr.offset
+            for name, size, data in instr.cache_info:
+                for i in range(size):
+                    offset += 2
+                    # Only show the fancy argrepr for a CACHE instruction when it's
+                    # the first entry for a particular cache value:
+                    if i == 0:
+                        argrepr = f"{name}: {int.from_bytes(data, sys.byteorder)}"
+                    else:
+                        argrepr = ""
+                    self.print_instruction_line(
+                        Instruction("CACHE", CACHE, 0, None, argrepr, offset, offset,
+                                    False, None, None, instr.positions),
+                        False)
+
+    def print_instruction_line(self, instr, mark_as_current):
         """Format instruction details for inclusion in disassembly output."""
         lineno_width = self.lineno_width
         offset_width = self.offset_width
@@ -474,11 +499,14 @@ class Formatter:
 
 
 class ArgResolver:
-    def __init__(self, co_consts, names, varname_from_oparg, labels_map):
+    def __init__(self, co_consts=None, names=None, varname_from_oparg=None, labels_map=None):
         self.co_consts = co_consts
         self.names = names
         self.varname_from_oparg = varname_from_oparg
-        self.labels_map = labels_map
+        self.labels_map = labels_map or {}
+
+    def get_label_for_offset(self, offset):
+        return self.labels_map.get(offset, None)
 
     def get_argval_argrepr(self, op, arg, offset):
         get_name = None if self.names is None else self.names.__getitem__
@@ -547,8 +575,7 @@ class ArgResolver:
                 argrepr = _intrinsic_2_descs[arg]
         return argval, argrepr
 
-
-def get_instructions(x, *, first_line=None, show_caches=False, adaptive=False):
+def get_instructions(x, *, first_line=None, show_caches=None, adaptive=False):
     """Iterator for the opcodes in methods, functions or code
 
     Generates a series of Instruction named tuples giving the details of
@@ -567,9 +594,10 @@ def get_instructions(x, *, first_line=None, show_caches=False, adaptive=False):
         line_offset = 0
 
     original_code = co.co_code
-    labels_map = _make_labels_map(original_code)
-    arg_resolver = ArgResolver(co.co_consts, co.co_names, co._varname_from_oparg,
-                               labels_map)
+    arg_resolver = ArgResolver(co_consts=co.co_consts,
+                               names=co.co_names,
+                               varname_from_oparg=co._varname_from_oparg,
+                               labels_map=_make_labels_map(original_code))
     return _get_instructions_bytes(_get_code_array(co, adaptive),
                                    linestarts=linestarts,
                                    line_offset=line_offset,
@@ -648,7 +676,7 @@ def _is_backward_jump(op):
                           'ENTER_EXECUTOR')
 
 def _get_instructions_bytes(code, linestarts=None, line_offset=0, co_positions=None,
-                            original_code=None, labels_map=None, arg_resolver=None):
+                            original_code=None, arg_resolver=None):
     """Iterate over the instructions in a bytecode string.
 
     Generates a sequence of Instruction namedtuples giving the details of each
@@ -661,8 +689,6 @@ def _get_instructions_bytes(code, linestarts=None, line_offset=0, co_positions=N
     original_code = original_code or code
     co_positions = co_positions or iter(())
 
-    labels_map = labels_map or _make_labels_map(original_code)
-
     starts_line = False
     local_line_number = None
     line_number = None
@@ -684,10 +710,6 @@ def _get_instructions_bytes(code, linestarts=None, line_offset=0, co_positions=N
         else:
             argval, argrepr = arg, repr(arg)
 
-        instr = Instruction(_all_opname[op], op, arg, argval, argrepr,
-                            offset, start_offset, starts_line, line_number,
-                            labels_map.get(offset, None), positions)
-
         caches = _get_cache_size(_all_opname[deop])
         # Advance the co_positions iterator:
         for _ in range(caches):
@@ -701,10 +723,10 @@ def _get_instructions_bytes(code, linestarts=None, line_offset=0, co_positions=N
         else:
             cache_info = None
 
+        label = arg_resolver.get_label_for_offset(offset) if arg_resolver else None
         yield Instruction(_all_opname[op], op, arg, argval, argrepr,
                           offset, start_offset, starts_line, line_number,
-                          labels_map.get(offset, None), positions, cache_info)
-
+                          label, positions, cache_info)
 
 
 def disassemble(co, lasti=-1, *, file=None, show_caches=False, adaptive=False,
@@ -712,12 +734,20 @@ def disassemble(co, lasti=-1, *, file=None, show_caches=False, adaptive=False,
     """Disassemble a code object."""
     linestarts = dict(findlinestarts(co))
     exception_entries = _parse_exception_table(co)
-    _disassemble_bytes(_get_code_array(co, adaptive),
-                       lasti, co._varname_from_oparg,
-                       co.co_names, co.co_consts, linestarts, file=file,
-                       exception_entries=exception_entries,
-                       co_positions=co.co_positions(), show_caches=show_caches,
-                       original_code=co.co_code, show_offsets=show_offsets)
+    labels_map = _make_labels_map(co.co_code, exception_entries=exception_entries)
+    label_width = 4 + len(str(len(labels_map)))
+    formatter = Formatter(file=file,
+                          lineno_width=_get_lineno_width(linestarts),
+                          offset_width=len(str(max(len(co.co_code) - 2, 9999))) if show_offsets else 0,
+                          label_width=label_width,
+                          show_caches=show_caches)
+    arg_resolver = ArgResolver(co_consts=co.co_consts,
+                               names=co.co_names,
+                               varname_from_oparg=co._varname_from_oparg,
+                               labels_map=labels_map)
+    _disassemble_bytes(_get_code_array(co, adaptive), lasti, linestarts,
+                       exception_entries=exception_entries, co_positions=co.co_positions(),
+                       original_code=co.co_code, arg_resolver=arg_resolver, formatter=formatter)
 
 def _disassemble_recursive(co, *, file=None, depth=None, show_caches=False, adaptive=False, show_offsets=False):
     disassemble(co, file=file, show_caches=show_caches, adaptive=adaptive, show_offsets=show_offsets)
@@ -764,60 +794,29 @@ def _get_lineno_width(linestarts):
     return lineno_width
 
 
-def _disassemble_bytes(code, lasti=-1, varname_from_oparg=None,
-                       names=None, co_consts=None, linestarts=None,
-                       *, file=None, line_offset=0, exception_entries=(),
-                       co_positions=None, show_caches=False, original_code=None,
-                       show_offsets=False):
-
-    offset_width = len(str(max(len(code) - 2, 9999))) if show_offsets else 0
-
-    labels_map = _make_labels_map(original_code or code, exception_entries)
-    label_width = 4 + len(str(len(labels_map)))
+def _disassemble_bytes(code, lasti=-1, linestarts=None,
+                       *, line_offset=0, exception_entries=(),
+                       co_positions=None, original_code=None,
+                       arg_resolver=None, formatter=None):
 
-    formatter = Formatter(file=file,
-                          lineno_width=_get_lineno_width(linestarts),
-                          offset_width=offset_width,
-                          label_width=label_width,
-                          line_offset=line_offset)
+    assert formatter is not None
+    assert arg_resolver is not None
 
-    arg_resolver = ArgResolver(co_consts, names, varname_from_oparg, labels_map)
     instrs = _get_instructions_bytes(code, linestarts=linestarts,
                                            line_offset=line_offset,
                                            co_positions=co_positions,
                                            original_code=original_code,
-                                           labels_map=labels_map,
                                            arg_resolver=arg_resolver)
 
-    print_instructions(instrs, exception_entries, formatter,
-                       show_caches=show_caches, lasti=lasti)
+    print_instructions(instrs, exception_entries, formatter, lasti=lasti)
 
 
-def print_instructions(instrs, exception_entries, formatter, show_caches=False, lasti=-1):
+def print_instructions(instrs, exception_entries, formatter, lasti=-1):
     for instr in instrs:
-        if show_caches:
-            is_current_instr = instr.offset == lasti
-        else:
-            # Each CACHE takes 2 bytes
-            is_current_instr = instr.offset <= lasti \
-                <= instr.offset + 2 * _get_cache_size(_all_opname[_deoptop(instr.opcode)])
+        # Each CACHE takes 2 bytes
+        is_current_instr = instr.offset <= lasti \
+            <= instr.offset + 2 * _get_cache_size(_all_opname[_deoptop(instr.opcode)])
         formatter.print_instruction(instr, is_current_instr)
-        deop = _deoptop(instr.opcode)
-        if show_caches and instr.cache_info:
-            offset = instr.offset
-            for name, size, data in instr.cache_info:
-                for i in range(size):
-                    offset += 2
-                    # Only show the fancy argrepr for a CACHE instruction when it's
-                    # the first entry for a particular cache value:
-                    if i == 0:
-                        argrepr = f"{name}: {int.from_bytes(data, sys.byteorder)}"
-                    else:
-                        argrepr = ""
-                    formatter.print_instruction(
-                        Instruction("CACHE", CACHE, 0, None, argrepr, offset, offset,
-                                    False, None, None, instr.positions),
-                        is_current_instr)
 
     formatter.print_exception_table(exception_entries)
 
@@ -960,14 +959,15 @@ class Bytecode:
         co = self.codeobj
         original_code = co.co_code
         labels_map = _make_labels_map(original_code, self.exception_entries)
-        arg_resolver = ArgResolver(co.co_consts, co.co_names, co._varname_from_oparg,
-                                   labels_map)
+        arg_resolver = ArgResolver(co_consts=co.co_consts,
+                                   names=co.co_names,
+                                   varname_from_oparg=co._varname_from_oparg,
+                                   labels_map=labels_map)
         return _get_instructions_bytes(_get_code_array(co, self.adaptive),
                                        linestarts=self._linestarts,
                                        line_offset=self._line_offset,
                                        co_positions=co.co_positions(),
                                        original_code=original_code,
-                                       labels_map=labels_map,
                                        arg_resolver=arg_resolver)
 
     def __repr__(self):
@@ -995,18 +995,32 @@ class Bytecode:
         else:
             offset = -1
         with io.StringIO() as output:
-            _disassemble_bytes(_get_code_array(co, self.adaptive),
-                               varname_from_oparg=co._varname_from_oparg,
-                               names=co.co_names, co_consts=co.co_consts,
+            code = _get_code_array(co, self.adaptive)
+            offset_width = len(str(max(len(code) - 2, 9999))) if self.show_offsets else 0
+
+
+            labels_map = _make_labels_map(co.co_code, self.exception_entries)
+            label_width = 4 + len(str(len(labels_map)))
+            formatter = Formatter(file=output,
+                                  lineno_width=_get_lineno_width(self._linestarts),
+                                  offset_width=offset_width,
+                                  label_width=label_width,
+                                  line_offset=self._line_offset,
+                                  show_caches=self.show_caches)
+
+            arg_resolver = ArgResolver(co_consts=co.co_consts,
+                                       names=co.co_names,
+                                       varname_from_oparg=co._varname_from_oparg,
+                                       labels_map=labels_map)
+            _disassemble_bytes(code,
                                linestarts=self._linestarts,
                                line_offset=self._line_offset,
-                               file=output,
                                lasti=offset,
                                exception_entries=self.exception_entries,
                                co_positions=co.co_positions(),
-                               show_caches=self.show_caches,
                                original_code=co.co_code,
-                               show_offsets=self.show_offsets)
+                               arg_resolver=arg_resolver,
+                               formatter=formatter)
             return output.getvalue()
 
 
index 12e2c57e50b0ba5efe3935a4b72ab9b737ef614b..0c7fd60f640854d395741c1a747fa381ab33c56d 100644 (file)
@@ -2,6 +2,7 @@
 
 import contextlib
 import dis
+import functools
 import io
 import re
 import sys
@@ -1982,19 +1983,27 @@ class InstructionTests(InstructionTestCase):
         self.assertEqual(f(opcode.opmap["BINARY_OP"], 3, *args), (3, '<<'))
         self.assertEqual(f(opcode.opmap["CALL_INTRINSIC_1"], 2, *args), (2, 'INTRINSIC_IMPORT_STAR'))
 
+    def get_instructions(self, code):
+        return dis._get_instructions_bytes(code)
+
     def test_start_offset(self):
         # When no extended args are present,
         # start_offset should be equal to offset
+
         instructions = list(dis.Bytecode(_f))
         for instruction in instructions:
             self.assertEqual(instruction.offset, instruction.start_offset)
 
+        def last_item(iterable):
+            return functools.reduce(lambda a, b : b, iterable)
+
         code = bytes([
             opcode.opmap["LOAD_FAST"], 0x00,
             opcode.opmap["EXTENDED_ARG"], 0x01,
             opcode.opmap["POP_JUMP_IF_TRUE"], 0xFF,
         ])
-        jump = list(dis._get_instructions_bytes(code))[-1]
+        labels_map = dis._make_labels_map(code)
+        jump = last_item(self.get_instructions(code))
         self.assertEqual(4, jump.offset)
         self.assertEqual(2, jump.start_offset)
 
@@ -2006,7 +2015,7 @@ class InstructionTests(InstructionTestCase):
             opcode.opmap["POP_JUMP_IF_TRUE"], 0xFF,
             opcode.opmap["CACHE"], 0x00,
         ])
-        jump = list(dis._get_instructions_bytes(code))[-1]
+        jump = last_item(self.get_instructions(code))
         self.assertEqual(8, jump.offset)
         self.assertEqual(2, jump.start_offset)
 
@@ -2021,7 +2030,7 @@ class InstructionTests(InstructionTestCase):
             opcode.opmap["POP_JUMP_IF_TRUE"], 0xFF,
             opcode.opmap["CACHE"], 0x00,
         ])
-        instructions = list(dis._get_instructions_bytes(code))
+        instructions = list(self.get_instructions(code))
         # 1st jump
         self.assertEqual(4, instructions[2].offset)
         self.assertEqual(2, instructions[2].start_offset)
@@ -2042,7 +2051,7 @@ class InstructionTests(InstructionTestCase):
             opcode.opmap["CACHE"], 0x00,
             opcode.opmap["CACHE"], 0x00
         ])
-        instructions = list(dis._get_instructions_bytes(code))
+        instructions = list(self.get_instructions(code))
         self.assertEqual(2, instructions[0].cache_offset)
         self.assertEqual(10, instructions[0].end_offset)
         self.assertEqual(12, instructions[1].cache_offset)