From 9bb03a890e42d08fcb6cb2768cdbb8c251bcabc6 Mon Sep 17 00:00:00 2001 From: Ken Jin <28750310+Fidget-Spinner@users.noreply.github.com> Date: Sun, 19 Oct 2025 20:26:01 +0100 Subject: [PATCH] Handle EXTENDED_ARG --- Python/bytecodes.c | 4 +- Python/ceval_macros.h | 2 +- Python/executor_cases.c.h | 3 +- Python/generated_tracer_cases.c.h | 1 + Python/optimizer.c | 67 +++++++++++++++++++------------ 5 files changed, 47 insertions(+), 30 deletions(-) diff --git a/Python/bytecodes.c b/Python/bytecodes.c index e75f1d05a484..68a8d380643f 100644 --- a/Python/bytecodes.c +++ b/Python/bytecodes.c @@ -1378,6 +1378,7 @@ dummy_func( if (err == 0) { assert(retval_o != NULL); JUMPBY(oparg); + RECORD_JUMP_TAKEN(); } else { PyStackRef_CLOSE(v); @@ -5464,9 +5465,8 @@ dummy_func( // from a single exit! tier2 op(_DYNAMIC_EXIT, (exit_p/4 --)) { _Py_CODEUNIT *target = frame->instr_ptr; - _PyExitData *exit = (_PyExitData *)exit_p; - _Py_BackoffCounter temperature = exit->temperature; #if defined(Py_DEBUG) && !defined(_Py_JIT) + _PyExitData *exit = (_PyExitData *)exit_p; OPT_HIST(trace_uop_execution_counter, trace_run_length_hist); if (frame->lltrace >= 2) { printf("DYNAMIC EXIT: [UOp "); diff --git a/Python/ceval_macros.h b/Python/ceval_macros.h index dab0ff33c9ae..3adbadab4c53 100644 --- a/Python/ceval_macros.h +++ b/Python/ceval_macros.h @@ -426,7 +426,7 @@ do { \ frame = tstate->current_frame; \ stack_pointer = _PyFrame_GetStackPointer(frame); \ int keep_tracing_bit = (uintptr_t)next_instr & 1; \ - next_instr = (_Py_CODEUNIT *)(((uintptr_t)next_instr) >> 1 << 1); \ + next_instr = (_Py_CODEUNIT *)(((uintptr_t)next_instr) & (~1)); \ if (next_instr == NULL) { \ next_instr = frame->instr_ptr; \ JUMP_TO_LABEL(error); \ diff --git a/Python/executor_cases.c.h b/Python/executor_cases.c.h index 8b2b6b899274..e32c144b9e91 100644 --- a/Python/executor_cases.c.h +++ b/Python/executor_cases.c.h @@ -7524,9 +7524,8 @@ case _DYNAMIC_EXIT: { PyObject *exit_p = (PyObject *)CURRENT_OPERAND0(); _Py_CODEUNIT *target = frame->instr_ptr; - _PyExitData *exit = (_PyExitData *)exit_p; - _Py_BackoffCounter temperature = exit->temperature; #if defined(Py_DEBUG) && !defined(_Py_JIT) + _PyExitData *exit = (_PyExitData *)exit_p; OPT_HIST(trace_uop_execution_counter, trace_run_length_hist); if (frame->lltrace >= 2) { _PyFrame_SetStackPointer(frame, stack_pointer); diff --git a/Python/generated_tracer_cases.c.h b/Python/generated_tracer_cases.c.h index 3feeb0d05362..4c5a9c33fb96 100644 --- a/Python/generated_tracer_cases.c.h +++ b/Python/generated_tracer_cases.c.h @@ -12162,6 +12162,7 @@ if (err == 0) { assert(retval_o != NULL); JUMPBY(oparg); + RECORD_JUMP_TAKEN(); } else { stack_pointer += -1; diff --git a/Python/optimizer.c b/Python/optimizer.c index 18b7d1d5aa29..a11244d4b54f 100644 --- a/Python/optimizer.c +++ b/Python/optimizer.c @@ -570,10 +570,9 @@ _PyJIT_translate_single_bytecode_to_trace( { int is_first_instr = tstate->interp->jit_state.jit_tracer_initial_instr == this_instr; - bool progress_needed = (tstate->interp->jit_state.jit_tracer_initial_chain_depth % MAX_CHAIN_DEPTH) == 0 && is_first_instr;; + bool progress_needed = (tstate->interp->jit_state.jit_tracer_initial_chain_depth % MAX_CHAIN_DEPTH) == 0;; _PyBloomFilter *dependencies = &tstate->interp->jit_state.jit_tracer_dependencies; _Py_BloomFilter_Add(dependencies, old_code); - _Py_CODEUNIT *target_instr = this_instr; int trace_length = tstate->interp->jit_state.jit_tracer_code_curr_size; _PyUOpInstruction *trace = tstate->interp->jit_state.jit_tracer_code_buffer; int max_length = tstate->interp->jit_state.jit_tracer_code_max_size; @@ -585,11 +584,26 @@ _PyJIT_translate_single_bytecode_to_trace( lltrace = *python_lltrace - '0'; // TODO: Parse an int and all that } #endif - + _Py_CODEUNIT *target_instr = this_instr; uint32_t target = 0; target = INSTR_IP(target_instr, old_code); + // Rewind EXTENDED_ARG so that we see the whole thing. + // We must point to the first EXTENDED_ARG when deopting. + int rewind_oparg = oparg; + while (rewind_oparg > 255) { + rewind_oparg >>= 8; + target--; + } +#ifdef Py_DEBUG + if (oparg > 255) { + assert(_Py_GetBaseCodeUnit(old_code, target).op.code == EXTENDED_ARG); + } +#endif + + DPRINTF(2, "%p %d: %s(%d) %d\n", old_code, target, _PyOpcode_OpName[opcode], oparg, progress_needed); + bool needs_guard_ip = _PyOpcode_NeedsGuardIp[opcode] && !(opcode == FOR_ITER_RANGE || opcode == FOR_ITER_LIST || opcode == FOR_ITER_TUPLE) && !(opcode == JUMP_BACKWARD_NO_INTERRUPT || opcode == JUMP_BACKWARD || opcode == JUMP_BACKWARD_JIT) && @@ -600,8 +614,7 @@ _PyJIT_translate_single_bytecode_to_trace( // This happens when a recursive call happens that we can't trace. Such as Python -> C -> Python calls // If we haven't guarded the IP, then it's untraceable. (frame != tstate->interp->jit_state.jit_tracer_current_frame && !needs_guard_ip) || - // TODO handle extended args. - oparg > 255 || opcode == EXTENDED_ARG || + (oparg > 0xFFFF) || // TODO handle BINARY_OP_INPLACE_ADD_UNICODE opcode == BINARY_OP_INPLACE_ADD_UNICODE || // TODO (gh-140277): The constituent uops are invalid. @@ -633,8 +646,6 @@ _PyJIT_translate_single_bytecode_to_trace( tstate->interp->jit_state.jit_tracer_current_frame = frame; - DPRINTF(2, "%p %d: %s(%d)\n", old_code, target, _PyOpcode_OpName[opcode], oparg); - if (opcode == NOP) { return 1; } @@ -643,6 +654,10 @@ _PyJIT_translate_single_bytecode_to_trace( return 1; } + if (opcode == EXTENDED_ARG) { + return 1; + } + // One for possible _DEOPT, one because _CHECK_VALIDITY itself might _DEOPT max_length -= 2; @@ -663,7 +678,7 @@ _PyJIT_translate_single_bytecode_to_trace( /* Special case the first instruction, * so that we can guarantee forward progress */ - if (progress_needed && is_first_instr) { + if (progress_needed && tstate->interp->jit_state.jit_tracer_code_curr_size <= 2) { if (OPCODE_HAS_EXIT(opcode) || OPCODE_HAS_DEOPT(opcode)) { opcode = _PyOpcode_Deopt[opcode]; } @@ -695,12 +710,13 @@ _PyJIT_translate_single_bytecode_to_trace( case POP_JUMP_IF_FALSE: case POP_JUMP_IF_TRUE: { - RESERVE(1); - _Py_CODEUNIT *computed_next_instr = target_instr + 1 + _PyOpcode_Caches[_PyOpcode_Deopt[opcode]]; - _Py_CODEUNIT *computed_jump_instr = computed_next_instr + oparg; - int jump_likely = computed_jump_instr == next_instr; - uint32_t uopcode = BRANCH_TO_GUARD[opcode - POP_JUMP_IF_FALSE][jump_likely]; - ADD_TO_TRACE(uopcode, 0, 0, INSTR_IP(jump_likely ? computed_next_instr : computed_jump_instr, old_code)); + _Py_CODEUNIT *computed_next_instr_without_modifiers = target_instr + 1 + _PyOpcode_Caches[_PyOpcode_Deopt[opcode]]; + _Py_CODEUNIT *computed_next_instr = computed_next_instr_without_modifiers + (computed_next_instr_without_modifiers->op.code == NOT_TAKEN); + _Py_CODEUNIT *computed_jump_instr = computed_next_instr_without_modifiers + oparg; + assert(next_instr == computed_next_instr || next_instr == computed_jump_instr); + int jump_happened = computed_jump_instr == next_instr; + uint32_t uopcode = BRANCH_TO_GUARD[opcode - POP_JUMP_IF_FALSE][jump_happened]; + ADD_TO_TRACE(uopcode, 0, 0, INSTR_IP(jump_happened ? computed_next_instr : computed_jump_instr, old_code)); break; } case JUMP_BACKWARD_JIT: @@ -731,8 +747,10 @@ _PyJIT_translate_single_bytecode_to_trace( assert(nuops > 0); RESERVE(nuops + 1); /* One extra for exit */ uint32_t orig_oparg = oparg; // For OPARG_TOP/BOTTOM + uint32_t orig_target = target; for (int i = 0; i < nuops; i++) { oparg = orig_oparg; + target = orig_target; uint32_t uop = expansion->uops[i].uop; uint64_t operand = 0; // Add one to account for the actual opcode/oparg pair: @@ -751,9 +769,11 @@ _PyJIT_translate_single_bytecode_to_trace( operand = read_u64(&this_instr[offset].cache); break; case OPARG_TOP: // First half of super-instr + assert(orig_oparg <= 255); oparg = orig_oparg >> 4; break; case OPARG_BOTTOM: // Second half of super-instr + assert(orig_oparg <= 255); oparg = orig_oparg & 0xF; break; case OPARG_SAVE_RETURN_OFFSET: // op=_SAVE_RETURN_OFFSET; oparg=return_offset @@ -768,13 +788,15 @@ _PyJIT_translate_single_bytecode_to_trace( if (uop == _TIER2_RESUME_CHECK) { target = next_inst; } -#ifdef Py_DEBUG else if (uop != _FOR_ITER_TIER_TWO) { - uint32_t jump_target = next_inst + oparg; + int extended_arg = orig_oparg > 255; + uint32_t jump_target = next_inst + orig_oparg + extended_arg; assert(_Py_GetBaseCodeUnit(old_code, jump_target).op.code == END_FOR); assert(_Py_GetBaseCodeUnit(old_code, jump_target+1).op.code == POP_ITER); + if (is_for_iter_test[uop]) { + target = jump_target + 1; + } } -#endif break; case OPERAND1_1: assert(trace[trace_length-1].opcode == uop); @@ -859,11 +881,12 @@ _PyJIT_InitializeTracing(PyThreadState *tstate, _PyInterpreterFrame *frame, _Py_ lltrace = *python_lltrace - '0'; // TODO: Parse an int and all that } DPRINTF(2, - "Tracing %s (%s:%d) at byte offset %d\n", + "Tracing %s (%s:%d) at byte offset %d at chain depth %d\n", PyUnicode_AsUTF8(code->co_qualname), PyUnicode_AsUTF8(code->co_filename), code->co_firstlineno, - 2 * INSTR_IP(next_instr, code)); + 2 * INSTR_IP(next_instr, code), + chain_depth); #endif add_to_trace(tstate->interp->jit_state.jit_tracer_code_buffer, 0, _START_EXECUTOR, 0, (uintptr_t)next_instr, INSTR_IP(next_instr, code)); add_to_trace(tstate->interp->jit_state.jit_tracer_code_buffer, 1, _MAKE_WARM, 0, 0, 0); @@ -976,12 +999,6 @@ prepare_for_execution(_PyUOpInstruction *buffer, int length) exit_op = _DYNAMIC_EXIT; unique_target = true; } - if (is_for_iter_test[opcode]) { - /* Target the POP_TOP immediately after the END_FOR, - * leaving only the iterator on the stack. */ - int32_t next_inst = target + 1 + INLINE_CACHE_ENTRIES_FOR_ITER; - jump_target = next_inst + inst->oparg + 1; - } if (unique_target || jump_target != current_jump_target || current_exit_op != exit_op) { make_exit(&buffer[next_spare], exit_op, jump_target); current_exit_op = exit_op; -- 2.47.3