Merge tag 'for-netdev' of https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf...

diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 21f80383f8b215b72df159fcdce0e6073d58c25f..63749ad5ac6b8d63f108b92690897f032c7eacb6 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -533,6 +533,16 @@ static bool is_async_callback_calling_insn(struct bpf_insn *insn)
        return bpf_helper_call(insn) && is_async_callback_calling_function(insn->imm);
 }
 
+static bool is_may_goto_insn(struct bpf_insn *insn)
+{
+       return insn->code == (BPF_JMP | BPF_JCOND) && insn->src_reg == BPF_MAY_GOTO;
+}
+
+static bool is_may_goto_insn_at(struct bpf_verifier_env *env, int insn_idx)
+{
+       return is_may_goto_insn(&env->prog->insnsi[insn_idx]);
+}
+
 static bool is_storage_get_function(enum bpf_func_id func_id)
 {
        return func_id == BPF_FUNC_sk_storage_get ||
@@ -1429,6 +1439,7 @@ static int copy_verifier_state(struct bpf_verifier_state *dst_state,
        dst_state->dfs_depth = src->dfs_depth;
        dst_state->callback_unroll_depth = src->callback_unroll_depth;
        dst_state->used_as_loop_entry = src->used_as_loop_entry;
+       dst_state->may_goto_depth = src->may_goto_depth;
        for (i = 0; i <= src->curframe; i++) {
                dst = dst_state->frame[i];
                if (!dst) {
@@ -4375,6 +4386,7 @@ static bool is_spillable_regtype(enum bpf_reg_type type)
        case PTR_TO_MEM:
        case PTR_TO_FUNC:
        case PTR_TO_MAP_KEY:
+       case PTR_TO_ARENA:
                return true;
        default:
                return false;
@@ -5262,7 +5274,7 @@ bad_type:
 
 static bool in_sleepable(struct bpf_verifier_env *env)
 {
-       return env->prog->aux->sleepable;
+       return env->prog->sleepable;
 }
 
 /* The non-sleepable programs and sleepable programs with explicit bpf_rcu_read_lock()
@@ -5817,6 +5829,8 @@ static int check_ptr_alignment(struct bpf_verifier_env *env,
        case PTR_TO_XDP_SOCK:
                pointer_desc = "xdp_sock ";
                break;
+       case PTR_TO_ARENA:
+               return 0;
        default:
                break;
        }
@@ -6926,6 +6940,9 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
 
                if (!err && value_regno >= 0 && (rdonly_mem || t == BPF_READ))
                        mark_reg_unknown(env, regs, value_regno);
+       } else if (reg->type == PTR_TO_ARENA) {
+               if (t == BPF_READ && value_regno >= 0)
+                       mark_reg_unknown(env, regs, value_regno);
        } else {
                verbose(env, "R%d invalid mem access '%s'\n", regno,
                        reg_type_str(env, reg->type));
@@ -8397,6 +8414,7 @@ static int check_func_arg_reg_off(struct bpf_verifier_env *env,
        case PTR_TO_MEM | MEM_RINGBUF:
        case PTR_TO_BUF:
        case PTR_TO_BUF | MEM_RDONLY:
+       case PTR_TO_ARENA:
        case SCALAR_VALUE:
                return 0;
        /* All the rest must be rejected, except PTR_TO_BTF_ID which allows
@@ -9361,6 +9379,18 @@ static int btf_check_func_arg_match(struct bpf_verifier_env *env, int subprog,
                                bpf_log(log, "arg#%d is expected to be non-NULL\n", i);
                                return -EINVAL;
                        }
+               } else if (base_type(arg->arg_type) == ARG_PTR_TO_ARENA) {
+                       /*
+                        * Can pass any value and the kernel won't crash, but
+                        * only PTR_TO_ARENA or SCALAR make sense. Everything
+                        * else is a bug in the bpf program. Point it out to
+                        * the user at the verification time instead of
+                        * run-time debug nightmare.
+                        */
+                       if (reg->type != PTR_TO_ARENA && reg->type != SCALAR_VALUE) {
+                               bpf_log(log, "R%d is not a pointer to arena or scalar.\n", regno);
+                               return -EINVAL;
+                       }
                } else if (arg->arg_type == (ARG_PTR_TO_DYNPTR | MEM_RDONLY)) {
                        ret = process_dynptr_func(env, regno, -1, arg->arg_type, 0);
                        if (ret)
@@ -10741,6 +10771,11 @@ static bool is_kfunc_arg_ignore(const struct btf *btf, const struct btf_param *a
        return btf_param_match_suffix(btf, arg, "__ign");
 }
 
+static bool is_kfunc_arg_map(const struct btf *btf, const struct btf_param *arg)
+{
+       return btf_param_match_suffix(btf, arg, "__map");
+}
+
 static bool is_kfunc_arg_alloc_obj(const struct btf *btf, const struct btf_param *arg)
 {
        return btf_param_match_suffix(btf, arg, "__alloc");
@@ -10910,6 +10945,7 @@ enum kfunc_ptr_arg_type {
        KF_ARG_PTR_TO_RB_NODE,
        KF_ARG_PTR_TO_NULL,
        KF_ARG_PTR_TO_CONST_STR,
+       KF_ARG_PTR_TO_MAP,
 };
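
For illustration, a hypothetical kfunc that would exercise the new KF_ARG_PTR_TO_MAP path; the function name and second parameter are made up, only the "__map" suffix convention and the expected 'struct bpf_map *' type come from this patch:

	/* Hypothetical declaration: the "__map" suffix on p__map makes the
	 * verifier classify the argument as KF_ARG_PTR_TO_MAP and match it
	 * against 'struct bpf_map *' (reg2btf_ids[CONST_PTR_TO_MAP]).
	 */
	__bpf_kfunc void *bpf_example_alloc(void *p__map, __u32 size);
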
 
 enum special_kfunc_type {
@@ -11063,6 +11099,9 @@ get_kfunc_ptr_arg_type(struct bpf_verifier_env *env,
        if (is_kfunc_arg_const_str(meta->btf, &args[argno]))
                return KF_ARG_PTR_TO_CONST_STR;
 
+       if (is_kfunc_arg_map(meta->btf, &args[argno]))
+               return KF_ARG_PTR_TO_MAP;
+
        if ((base_type(reg->type) == PTR_TO_BTF_ID || reg2btf_ids[base_type(reg->type)])) {
                if (!btf_type_is_struct(ref_t)) {
                        verbose(env, "kernel function %s args#%d pointer type %s %s is not supported\n",
@@ -11663,6 +11702,7 @@ static int check_kfunc_args(struct bpf_verifier_env *env, struct bpf_kfunc_call_
                switch (kf_arg_type) {
                case KF_ARG_PTR_TO_NULL:
                        continue;
+               case KF_ARG_PTR_TO_MAP:
                case KF_ARG_PTR_TO_ALLOC_BTF_ID:
                case KF_ARG_PTR_TO_BTF_ID:
                        if (!is_kfunc_trusted_args(meta) && !is_kfunc_rcu(meta))
@@ -11879,6 +11919,12 @@ static int check_kfunc_args(struct bpf_verifier_env *env, struct bpf_kfunc_call_
                        if (ret < 0)
                                return ret;
                        break;
+               case KF_ARG_PTR_TO_MAP:
+                       /* If argument has '__map' suffix expect 'struct bpf_map *' */
+                       ref_id = *reg2btf_ids[CONST_PTR_TO_MAP];
+                       ref_t = btf_type_by_id(btf_vmlinux, ref_id);
+                       ref_tname = btf_name_by_offset(btf, ref_t->name_off);
+                       fallthrough;
                case KF_ARG_PTR_TO_BTF_ID:
                        /* Only base_type is checked, further checks are done here */
                        if ((base_type(reg->type) != PTR_TO_BTF_ID ||
@@ -12353,6 +12399,9 @@ static int check_kfunc_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
                                        meta.func_name);
                                return -EFAULT;
                        }
+               } else if (btf_type_is_void(ptr_type)) {
+                       /* kfunc returning 'void *' is equivalent to returning scalar */
+                       mark_reg_unknown(env, regs, BPF_REG_0);
                } else if (!__btf_type_is_struct(ptr_type)) {
                        if (!meta.r0_size) {
                                __u32 sz;
@@ -13822,6 +13871,21 @@ static int adjust_reg_min_max_vals(struct bpf_verifier_env *env,
 
        dst_reg = &regs[insn->dst_reg];
        src_reg = NULL;
+
+       if (dst_reg->type == PTR_TO_ARENA) {
+               struct bpf_insn_aux_data *aux = cur_aux(env);
+
+               if (BPF_CLASS(insn->code) == BPF_ALU64)
+                       /*
+                        * 32-bit operations zero upper bits automatically.
+                        * 64-bit operations need to be converted to 32.
+                        */
+                       aux->needs_zext = true;
+
+               /* Any arithmetic operations are allowed on arena pointers */
+               return 0;
+       }
+
        if (dst_reg->type != SCALAR_VALUE)
                ptr_reg = dst_reg;
        else
@@ -13939,19 +14003,20 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
        } else if (opcode == BPF_MOV) {
 
                if (BPF_SRC(insn->code) == BPF_X) {
-                       if (insn->imm != 0) {
-                               verbose(env, "BPF_MOV uses reserved fields\n");
-                               return -EINVAL;
-                       }
-
                        if (BPF_CLASS(insn->code) == BPF_ALU) {
-                               if (insn->off != 0 && insn->off != 8 && insn->off != 16) {
+                               if ((insn->off != 0 && insn->off != 8 && insn->off != 16) ||
+                                   insn->imm) {
                                        verbose(env, "BPF_MOV uses reserved fields\n");
                                        return -EINVAL;
                                }
+                       } else if (insn->off == BPF_ADDR_SPACE_CAST) {
+                               if (insn->imm != 1 && insn->imm != 1u << 16) {
+                                       verbose(env, "addr_space_cast insn can only convert between address space 1 and 0\n");
+                                       return -EINVAL;
+                               }
                        } else {
-                               if (insn->off != 0 && insn->off != 8 && insn->off != 16 &&
-                                   insn->off != 32) {
+                               if ((insn->off != 0 && insn->off != 8 && insn->off != 16 &&
+                                    insn->off != 32) || insn->imm) {
                                        verbose(env, "BPF_MOV uses reserved fields\n");
                                        return -EINVAL;
                                }
@@ -13978,7 +14043,12 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
                        struct bpf_reg_state *dst_reg = regs + insn->dst_reg;
 
                        if (BPF_CLASS(insn->code) == BPF_ALU64) {
-                               if (insn->off == 0) {
+                               if (insn->imm) {
+                                       /* off == BPF_ADDR_SPACE_CAST */
+                                       mark_reg_unknown(env, regs, insn->dst_reg);
+                                       if (insn->imm == 1) /* cast from as(1) to as(0) */
+                                               dst_reg->type = PTR_TO_ARENA;
+                               } else if (insn->off == 0) {
                                        /* case: R1 = R2
                                         * copy register state to dest reg
                                         */
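
For reference, a minimal sketch of the two addr_space_cast encodings these checks accept; the register numbers are arbitrary and the comments restate the verifier's own interpretation above:

	#include <linux/bpf.h>	/* struct bpf_insn, BPF_ADDR_SPACE_CAST */

	/* imm == 1 casts from as(1) to as(0); the destination register is
	 * then marked PTR_TO_ARENA.
	 */
	struct bpf_insn cast_to_kern = {
		.code    = BPF_ALU64 | BPF_MOV | BPF_X,
		.dst_reg = BPF_REG_1,
		.src_reg = BPF_REG_2,
		.off     = BPF_ADDR_SPACE_CAST,
		.imm     = 1,
	};

	/* imm == 1 << 16 is the opposite direction, as(0) to as(1);
	 * that conversion is left to the JIT.
	 */
	struct bpf_insn cast_to_user = {
		.code    = BPF_ALU64 | BPF_MOV | BPF_X,
		.dst_reg = BPF_REG_1,
		.src_reg = BPF_REG_2,
		.off     = BPF_ADDR_SPACE_CAST,
		.imm     = 1u << 16,
	};
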
@@ -14871,11 +14941,36 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
        int err;
 
        /* Only conditional jumps are expected to reach here. */
-       if (opcode == BPF_JA || opcode > BPF_JSLE) {
+       if (opcode == BPF_JA || opcode > BPF_JCOND) {
                verbose(env, "invalid BPF_JMP/JMP32 opcode %x\n", opcode);
                return -EINVAL;
        }
 
+       if (opcode == BPF_JCOND) {
+               struct bpf_verifier_state *cur_st = env->cur_state, *queued_st, *prev_st;
+               int idx = *insn_idx;
+
+               if (insn->code != (BPF_JMP | BPF_JCOND) ||
+                   insn->src_reg != BPF_MAY_GOTO ||
+                   insn->dst_reg || insn->imm || insn->off == 0) {
+                       verbose(env, "invalid may_goto off %d imm %d\n",
+                               insn->off, insn->imm);
+                       return -EINVAL;
+               }
+               prev_st = find_prev_entry(env, cur_st->parent, idx);
+
+               /* branch out 'fallthrough' insn as a new state to explore */
+               queued_st = push_stack(env, idx + 1, idx, false);
+               if (!queued_st)
+                       return -ENOMEM;
+
+               queued_st->may_goto_depth++;
+               if (prev_st)
+                       widen_imprecise_scalars(env, prev_st, queued_st);
+               *insn_idx += insn->off;
+               return 0;
+       }
+
        /* check src2 operand */
        err = check_reg_arg(env, insn->dst_reg, SRC_OP);
        if (err)
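
A minimal sketch of the encoding that passes the checks above; the jump offset is arbitrary, while dst_reg and imm must be zero and off must be non-zero:

	#include <linux/bpf.h>	/* struct bpf_insn, BPF_JMP, BPF_JCOND, BPF_MAY_GOTO */

	/* may_goto +5: falls through normally, branches forward by 5 insns
	 * once the hidden per-frame counter is exhausted (see the rewrite in
	 * do_misc_fixups() below).
	 */
	struct bpf_insn may_goto = {
		.code    = BPF_JMP | BPF_JCOND,
		.dst_reg = 0,
		.src_reg = BPF_MAY_GOTO,
		.off     = 5,
		.imm     = 0,
	};
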
@@ -15127,6 +15222,10 @@ static int check_ld_imm(struct bpf_verifier_env *env, struct bpf_insn *insn)
 
        if (insn->src_reg == BPF_PSEUDO_MAP_VALUE ||
            insn->src_reg == BPF_PSEUDO_MAP_IDX_VALUE) {
+               if (map->map_type == BPF_MAP_TYPE_ARENA) {
+                       __mark_reg_unknown(env, dst_reg);
+                       return 0;
+               }
                dst_reg->type = PTR_TO_MAP_VALUE;
                dst_reg->off = aux->map_off;
                WARN_ON_ONCE(map->max_entries != 1);
@@ -15659,6 +15758,8 @@ static int visit_insn(int t, struct bpf_verifier_env *env)
        default:
                /* conditional jump with two edges */
                mark_prune_point(env, t);
+               if (is_may_goto_insn(insn))
+                       mark_force_checkpoint(env, t);
 
                ret = push_insn(t, t + 1, FALLTHROUGH, env);
                if (ret)
@@ -16222,8 +16323,8 @@ static int check_btf_info(struct bpf_verifier_env *env,
 }
 
 /* check %cur's range satisfies %old's */
-static bool range_within(struct bpf_reg_state *old,
-                        struct bpf_reg_state *cur)
+static bool range_within(const struct bpf_reg_state *old,
+                        const struct bpf_reg_state *cur)
 {
        return old->umin_value <= cur->umin_value &&
               old->umax_value >= cur->umax_value &&
@@ -16387,21 +16488,28 @@ static bool regs_exact(const struct bpf_reg_state *rold,
               check_ids(rold->ref_obj_id, rcur->ref_obj_id, idmap);
 }
 
+enum exact_level {
+       NOT_EXACT,
+       EXACT,
+       RANGE_WITHIN
+};
+
 /* Returns true if (rold safe implies rcur safe) */
 static bool regsafe(struct bpf_verifier_env *env, struct bpf_reg_state *rold,
-                   struct bpf_reg_state *rcur, struct bpf_idmap *idmap, bool exact)
+                   struct bpf_reg_state *rcur, struct bpf_idmap *idmap,
+                   enum exact_level exact)
 {
-       if (exact)
+       if (exact == EXACT)
                return regs_exact(rold, rcur, idmap);
 
-       if (!(rold->live & REG_LIVE_READ))
+       if (!(rold->live & REG_LIVE_READ) && exact == NOT_EXACT)
                /* explored state didn't use this */
                return true;
-       if (rold->type == NOT_INIT)
-               /* explored state can't have used this */
-               return true;
-       if (rcur->type == NOT_INIT)
-               return false;
+       if (rold->type == NOT_INIT) {
+               if (exact == NOT_EXACT || rcur->type == NOT_INIT)
+                       /* explored state can't have used this */
+                       return true;
+       }
 
        /* Enforce that register types have to match exactly, including their
         * modifiers (like PTR_MAYBE_NULL, MEM_RDONLY, etc), as a general
@@ -16436,7 +16544,7 @@ static bool regsafe(struct bpf_verifier_env *env, struct bpf_reg_state *rold,
                        return memcmp(rold, rcur, offsetof(struct bpf_reg_state, id)) == 0 &&
                               check_scalar_ids(rold->id, rcur->id, idmap);
                }
-               if (!rold->precise)
+               if (!rold->precise && exact == NOT_EXACT)
                        return true;
                /* Why check_ids() for scalar registers?
                 *
@@ -16504,6 +16612,8 @@ static bool regsafe(struct bpf_verifier_env *env, struct bpf_reg_state *rold,
                 * the same stack frame, since fp-8 in foo != fp-8 in bar
                 */
                return regs_exact(rold, rcur, idmap) && rold->frameno == rcur->frameno;
+       case PTR_TO_ARENA:
+               return true;
        default:
                return regs_exact(rold, rcur, idmap);
        }
@@ -16547,7 +16657,8 @@ static struct bpf_reg_state *scalar_reg_for_stack(struct bpf_verifier_env *env,
 }
 
 static bool stacksafe(struct bpf_verifier_env *env, struct bpf_func_state *old,
-                     struct bpf_func_state *cur, struct bpf_idmap *idmap, bool exact)
+                     struct bpf_func_state *cur, struct bpf_idmap *idmap,
+                     enum exact_level exact)
 {
        int i, spi;
 
@@ -16560,12 +16671,13 @@ static bool stacksafe(struct bpf_verifier_env *env, struct bpf_func_state *old,
 
                spi = i / BPF_REG_SIZE;
 
-               if (exact &&
+               if (exact != NOT_EXACT &&
                    old->stack[spi].slot_type[i % BPF_REG_SIZE] !=
                    cur->stack[spi].slot_type[i % BPF_REG_SIZE])
                        return false;
 
-               if (!(old->stack[spi].spilled_ptr.live & REG_LIVE_READ) && !exact) {
+               if (!(old->stack[spi].spilled_ptr.live & REG_LIVE_READ)
+                   && exact == NOT_EXACT) {
                        i += BPF_REG_SIZE - 1;
                        /* explored state didn't use this */
                        continue;
@@ -16711,7 +16823,7 @@ static bool refsafe(struct bpf_func_state *old, struct bpf_func_state *cur,
  * the current state will reach 'bpf_exit' instruction safely
  */
 static bool func_states_equal(struct bpf_verifier_env *env, struct bpf_func_state *old,
-                             struct bpf_func_state *cur, bool exact)
+                             struct bpf_func_state *cur, enum exact_level exact)
 {
        int i;
 
@@ -16741,7 +16853,7 @@ static void reset_idmap_scratch(struct bpf_verifier_env *env)
 static bool states_equal(struct bpf_verifier_env *env,
                         struct bpf_verifier_state *old,
                         struct bpf_verifier_state *cur,
-                        bool exact)
+                        enum exact_level exact)
 {
        int i;
 
@@ -17115,7 +17227,7 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
                         * => unsafe memory access at 11 would not be caught.
                         */
                        if (is_iter_next_insn(env, insn_idx)) {
-                               if (states_equal(env, &sl->state, cur, true)) {
+                               if (states_equal(env, &sl->state, cur, RANGE_WITHIN)) {
                                        struct bpf_func_state *cur_frame;
                                        struct bpf_reg_state *iter_state, *iter_reg;
                                        int spi;
@@ -17138,15 +17250,23 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
                                }
                                goto skip_inf_loop_check;
                        }
+                       if (is_may_goto_insn_at(env, insn_idx)) {
+                               if (states_equal(env, &sl->state, cur, RANGE_WITHIN)) {
+                                       update_loop_entry(cur, &sl->state);
+                                       goto hit;
+                               }
+                               goto skip_inf_loop_check;
+                       }
                        if (calls_callback(env, insn_idx)) {
-                               if (states_equal(env, &sl->state, cur, true))
+                               if (states_equal(env, &sl->state, cur, RANGE_WITHIN))
                                        goto hit;
                                goto skip_inf_loop_check;
                        }
                        /* attempt to detect infinite loop to avoid unnecessary doomed work */
                        if (states_maybe_looping(&sl->state, cur) &&
-                           states_equal(env, &sl->state, cur, true) &&
+                           states_equal(env, &sl->state, cur, EXACT) &&
                            !iter_active_depths_differ(&sl->state, cur) &&
+                           sl->state.may_goto_depth == cur->may_goto_depth &&
                            sl->state.callback_unroll_depth == cur->callback_unroll_depth) {
                                verbose_linfo(env, insn_idx, "; ");
                                verbose(env, "infinite loop detected at insn %d\n", insn_idx);
@@ -17202,7 +17322,7 @@ skip_inf_loop_check:
                 */
                loop_entry = get_loop_entry(&sl->state);
                force_exact = loop_entry && loop_entry->branches > 0;
-               if (states_equal(env, &sl->state, cur, force_exact)) {
+               if (states_equal(env, &sl->state, cur, force_exact ? RANGE_WITHIN : NOT_EXACT)) {
                        if (force_exact)
                                update_loop_entry(cur, loop_entry);
 hit:
@@ -17372,6 +17492,7 @@ static bool reg_type_mismatch_ok(enum bpf_reg_type type)
        case PTR_TO_TCP_SOCK:
        case PTR_TO_XDP_SOCK:
        case PTR_TO_BTF_ID:
+       case PTR_TO_ARENA:
                return false;
        default:
                return true;
@@ -18019,7 +18140,7 @@ static int check_map_prog_compatibility(struct bpf_verifier_env *env,
                return -EINVAL;
        }
 
-       if (prog->aux->sleepable)
+       if (prog->sleepable)
                switch (map->map_type) {
                case BPF_MAP_TYPE_HASH:
                case BPF_MAP_TYPE_LRU_HASH:
@@ -18037,6 +18158,7 @@ static int check_map_prog_compatibility(struct bpf_verifier_env *env,
                case BPF_MAP_TYPE_CGRP_STORAGE:
                case BPF_MAP_TYPE_QUEUE:
                case BPF_MAP_TYPE_STACK:
+               case BPF_MAP_TYPE_ARENA:
                        break;
                default:
                        verbose(env,
@@ -18206,7 +18328,7 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
                                return -E2BIG;
                        }
 
-                       if (env->prog->aux->sleepable)
+                       if (env->prog->sleepable)
                                atomic64_inc(&map->sleepable_refcnt);
                        /* hold the map. If the program is rejected by verifier,
                         * the map will be released by release_maps() or it
@@ -18224,6 +18346,31 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
                                fdput(f);
                                return -EBUSY;
                        }
+                       if (map->map_type == BPF_MAP_TYPE_ARENA) {
+                               if (env->prog->aux->arena) {
+                                       verbose(env, "Only one arena per program\n");
+                                       fdput(f);
+                                       return -EBUSY;
+                               }
+                               if (!env->allow_ptr_leaks || !env->bpf_capable) {
+                                       verbose(env, "CAP_BPF and CAP_PERFMON are required to use arena\n");
+                                       fdput(f);
+                                       return -EPERM;
+                               }
+                               if (!env->prog->jit_requested) {
+                                       verbose(env, "JIT is required to use arena\n");
+                                       return -EOPNOTSUPP;
+                               }
+                               if (!bpf_jit_supports_arena()) {
+                                       verbose(env, "JIT doesn't support arena\n");
+                                       return -EOPNOTSUPP;
+                               }
+                               env->prog->aux->arena = (void *)map;
+                               if (!bpf_arena_get_user_vm_start(env->prog->aux->arena)) {
+                                       verbose(env, "arena's user address must be set via map_extra or mmap()\n");
+                                       return -EINVAL;
+                               }
+                       }
 
                        fdput(f);
 next_insn:
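
For context, a sketch of how a program would declare the arena this code validates, assuming the libbpf map-definition macros and the constraints imposed by the arena map implementation (BPF_F_MMAPABLE required, max_entries counted in pages) rather than by this hunk alone:

	struct {
		__uint(type, BPF_MAP_TYPE_ARENA);
		__uint(map_flags, BPF_F_MMAPABLE);
		__uint(max_entries, 100);	/* arena size in pages */
	} arena SEC(".maps");

Per the checks above, a program may reference at most one arena, needs CAP_BPF and CAP_PERFMON, must be JITed by a JIT with arena support, and the arena's user VM start has to be fixed before load, either via the map_extra attribute or by mmap()ing the arena map fd.
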
@@ -18845,6 +18992,14 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
                                env->prog->aux->num_exentries++;
                        }
                        continue;
+               case PTR_TO_ARENA:
+                       if (BPF_MODE(insn->code) == BPF_MEMSX) {
+                               verbose(env, "sign extending loads from arena are not supported yet\n");
+                               return -EOPNOTSUPP;
+                       }
+                       insn->code = BPF_CLASS(insn->code) | BPF_PROBE_MEM32 | BPF_SIZE(insn->code);
+                       env->prog->aux->num_exentries++;
+                       continue;
                default:
                        continue;
                }
@@ -19030,13 +19185,19 @@ static int jit_subprogs(struct bpf_verifier_env *env)
                func[i]->aux->nr_linfo = prog->aux->nr_linfo;
                func[i]->aux->jited_linfo = prog->aux->jited_linfo;
                func[i]->aux->linfo_idx = env->subprog_info[i].linfo_idx;
+               func[i]->aux->arena = prog->aux->arena;
                num_exentries = 0;
                insn = func[i]->insnsi;
                for (j = 0; j < func[i]->len; j++, insn++) {
                        if (BPF_CLASS(insn->code) == BPF_LDX &&
                            (BPF_MODE(insn->code) == BPF_PROBE_MEM ||
+                            BPF_MODE(insn->code) == BPF_PROBE_MEM32 ||
                             BPF_MODE(insn->code) == BPF_PROBE_MEMSX))
                                num_exentries++;
+                       if ((BPF_CLASS(insn->code) == BPF_STX ||
+                            BPF_CLASS(insn->code) == BPF_ST) &&
+                            BPF_MODE(insn->code) == BPF_PROBE_MEM32)
+                               num_exentries++;
                }
                func[i]->aux->num_exentries = num_exentries;
                func[i]->aux->tail_call_reachable = env->subprog_info[i].tail_call_reachable;
@@ -19411,7 +19572,10 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
        struct bpf_insn insn_buf[16];
        struct bpf_prog *new_prog;
        struct bpf_map *map_ptr;
-       int i, ret, cnt, delta = 0;
+       int i, ret, cnt, delta = 0, cur_subprog = 0;
+       struct bpf_subprog_info *subprogs = env->subprog_info;
+       u16 stack_depth = subprogs[cur_subprog].stack_depth;
+       u16 stack_depth_extra = 0;
 
        if (env->seen_exception && !env->exception_callback_subprog) {
                struct bpf_insn patch[] = {
@@ -19431,7 +19595,22 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
                mark_subprog_exc_cb(env, env->exception_callback_subprog);
        }
 
-       for (i = 0; i < insn_cnt; i++, insn++) {
+       for (i = 0; i < insn_cnt;) {
+               if (insn->code == (BPF_ALU64 | BPF_MOV | BPF_X) && insn->imm) {
+                       if ((insn->off == BPF_ADDR_SPACE_CAST && insn->imm == 1) ||
+                           (((struct bpf_map *)env->prog->aux->arena)->map_flags & BPF_F_NO_USER_CONV)) {
+                               /* convert to 32-bit mov that clears upper 32-bit */
+                               insn->code = BPF_ALU | BPF_MOV | BPF_X;
+                               /* clear off, so it's a normal 'wX = wY' from JIT pov */
+                               insn->off = 0;
+                       } /* cast from as(0) to as(1) should be handled by JIT */
+                       goto next_insn;
+               }
+
+               if (env->insn_aux_data[i + delta].needs_zext)
+                       /* Convert BPF_CLASS(insn->code) == BPF_ALU64 to 32-bit ALU */
+                       insn->code = BPF_ALU | BPF_OP(insn->code) | BPF_SRC(insn->code);
+
                /* Make divide-by-zero exceptions impossible. */
                if (insn->code == (BPF_ALU64 | BPF_MOD | BPF_X) ||
                    insn->code == (BPF_ALU64 | BPF_DIV | BPF_X) ||
@@ -19470,7 +19649,7 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
                        delta    += cnt - 1;
                        env->prog = prog = new_prog;
                        insn      = new_prog->insnsi + i + delta;
-                       continue;
+                       goto next_insn;
                }
 
                /* Implement LD_ABS and LD_IND with a rewrite, if supported by the program type. */
@@ -19490,7 +19669,7 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
                        delta    += cnt - 1;
                        env->prog = prog = new_prog;
                        insn      = new_prog->insnsi + i + delta;
-                       continue;
+                       goto next_insn;
                }
 
                /* Rewrite pointer arithmetic to mitigate speculation attacks. */
@@ -19505,7 +19684,7 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
                        aux = &env->insn_aux_data[i + delta];
                        if (!aux->alu_state ||
                            aux->alu_state == BPF_ALU_NON_POINTER)
-                               continue;
+                               goto next_insn;
 
                        isneg = aux->alu_state & BPF_ALU_NEG_VALUE;
                        issrc = (aux->alu_state & BPF_ALU_SANITIZE) ==
@@ -19543,19 +19722,39 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
                        delta    += cnt - 1;
                        env->prog = prog = new_prog;
                        insn      = new_prog->insnsi + i + delta;
-                       continue;
+                       goto next_insn;
+               }
+
+               if (is_may_goto_insn(insn)) {
+                       int stack_off = -stack_depth - 8;
+
+                       stack_depth_extra = 8;
+                       insn_buf[0] = BPF_LDX_MEM(BPF_DW, BPF_REG_AX, BPF_REG_10, stack_off);
+                       insn_buf[1] = BPF_JMP_IMM(BPF_JEQ, BPF_REG_AX, 0, insn->off + 2);
+                       insn_buf[2] = BPF_ALU64_IMM(BPF_SUB, BPF_REG_AX, 1);
+                       insn_buf[3] = BPF_STX_MEM(BPF_DW, BPF_REG_10, BPF_REG_AX, stack_off);
+                       cnt = 4;
+
+                       new_prog = bpf_patch_insn_data(env, i + delta, insn_buf, cnt);
+                       if (!new_prog)
+                               return -ENOMEM;
+
+                       delta += cnt - 1;
+                       env->prog = prog = new_prog;
+                       insn = new_prog->insnsi + i + delta;
+                       goto next_insn;
                }
 
                if (insn->code != (BPF_JMP | BPF_CALL))
-                       continue;
+                       goto next_insn;
                if (insn->src_reg == BPF_PSEUDO_CALL)
-                       continue;
+                       goto next_insn;
                if (insn->src_reg == BPF_PSEUDO_KFUNC_CALL) {
                        ret = fixup_kfunc_call(env, insn, insn_buf, i + delta, &cnt);
                        if (ret)
                                return ret;
                        if (cnt == 0)
-                               continue;
+                               goto next_insn;
 
                        new_prog = bpf_patch_insn_data(env, i + delta, insn_buf, cnt);
                        if (!new_prog)
@@ -19564,7 +19763,7 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
                        delta    += cnt - 1;
                        env->prog = prog = new_prog;
                        insn      = new_prog->insnsi + i + delta;
-                       continue;
+                       goto next_insn;
                }
 
                if (insn->imm == BPF_FUNC_get_route_realm)
@@ -19612,11 +19811,11 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
                                }
 
                                insn->imm = ret + 1;
-                               continue;
+                               goto next_insn;
                        }
 
                        if (!bpf_map_ptr_unpriv(aux))
-                               continue;
+                               goto next_insn;
 
                        /* instead of changing every JIT dealing with tail_call
                         * emit two extra insns:
@@ -19645,7 +19844,7 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
                        delta    += cnt - 1;
                        env->prog = prog = new_prog;
                        insn      = new_prog->insnsi + i + delta;
-                       continue;
+                       goto next_insn;
                }
 
                if (insn->imm == BPF_FUNC_timer_set_callback) {
@@ -19757,7 +19956,7 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
                                delta    += cnt - 1;
                                env->prog = prog = new_prog;
                                insn      = new_prog->insnsi + i + delta;
-                               continue;
+                               goto next_insn;
                        }
 
                        BUILD_BUG_ON(!__same_type(ops->map_lookup_elem,
@@ -19788,31 +19987,31 @@ patch_map_ops_generic:
                        switch (insn->imm) {
                        case BPF_FUNC_map_lookup_elem:
                                insn->imm = BPF_CALL_IMM(ops->map_lookup_elem);
-                               continue;
+                               goto next_insn;
                        case BPF_FUNC_map_update_elem:
                                insn->imm = BPF_CALL_IMM(ops->map_update_elem);
-                               continue;
+                               goto next_insn;
                        case BPF_FUNC_map_delete_elem:
                                insn->imm = BPF_CALL_IMM(ops->map_delete_elem);
-                               continue;
+                               goto next_insn;
                        case BPF_FUNC_map_push_elem:
                                insn->imm = BPF_CALL_IMM(ops->map_push_elem);
-                               continue;
+                               goto next_insn;
                        case BPF_FUNC_map_pop_elem:
                                insn->imm = BPF_CALL_IMM(ops->map_pop_elem);
-                               continue;
+                               goto next_insn;
                        case BPF_FUNC_map_peek_elem:
                                insn->imm = BPF_CALL_IMM(ops->map_peek_elem);
-                               continue;
+                               goto next_insn;
                        case BPF_FUNC_redirect_map:
                                insn->imm = BPF_CALL_IMM(ops->map_redirect);
-                               continue;
+                               goto next_insn;
                        case BPF_FUNC_for_each_map_elem:
                                insn->imm = BPF_CALL_IMM(ops->map_for_each_callback);
-                               continue;
+                               goto next_insn;
                        case BPF_FUNC_map_lookup_percpu_elem:
                                insn->imm = BPF_CALL_IMM(ops->map_lookup_percpu_elem);
-                               continue;
+                               goto next_insn;
                        }
 
                        goto patch_call_imm;
@@ -19840,7 +20039,7 @@ patch_map_ops_generic:
                        delta    += cnt - 1;
                        env->prog = prog = new_prog;
                        insn      = new_prog->insnsi + i + delta;
-                       continue;
+                       goto next_insn;
                }
 
                /* Implement bpf_get_func_arg inline. */
@@ -19865,7 +20064,7 @@ patch_map_ops_generic:
                        delta    += cnt - 1;
                        env->prog = prog = new_prog;
                        insn      = new_prog->insnsi + i + delta;
-                       continue;
+                       goto next_insn;
                }
 
                /* Implement bpf_get_func_ret inline. */
@@ -19893,7 +20092,7 @@ patch_map_ops_generic:
                        delta    += cnt - 1;
                        env->prog = prog = new_prog;
                        insn      = new_prog->insnsi + i + delta;
-                       continue;
+                       goto next_insn;
                }
 
                /* Implement get_func_arg_cnt inline. */
@@ -19908,7 +20107,7 @@ patch_map_ops_generic:
 
                        env->prog = prog = new_prog;
                        insn      = new_prog->insnsi + i + delta;
-                       continue;
+                       goto next_insn;
                }
 
                /* Implement bpf_get_func_ip inline. */
@@ -19923,7 +20122,7 @@ patch_map_ops_generic:
 
                        env->prog = prog = new_prog;
                        insn      = new_prog->insnsi + i + delta;
-                       continue;
+                       goto next_insn;
                }
 
                /* Implement bpf_kptr_xchg inline */
@@ -19941,7 +20140,7 @@ patch_map_ops_generic:
                        delta    += cnt - 1;
                        env->prog = prog = new_prog;
                        insn      = new_prog->insnsi + i + delta;
-                       continue;
+                       goto next_insn;
                }
 patch_call_imm:
                fn = env->ops->get_func_proto(insn->imm, env->prog);
@@ -19955,6 +20154,40 @@ patch_call_imm:
                        return -EFAULT;
                }
                insn->imm = fn->func - __bpf_call_base;
+next_insn:
+               if (subprogs[cur_subprog + 1].start == i + delta + 1) {
+                       subprogs[cur_subprog].stack_depth += stack_depth_extra;
+                       subprogs[cur_subprog].stack_extra = stack_depth_extra;
+                       cur_subprog++;
+                       stack_depth = subprogs[cur_subprog].stack_depth;
+                       stack_depth_extra = 0;
+               }
+               i++;
+               insn++;
+       }
+
+       env->prog->aux->stack_depth = subprogs[0].stack_depth;
+       for (i = 0; i < env->subprog_cnt; i++) {
+               int subprog_start = subprogs[i].start;
+               int stack_slots = subprogs[i].stack_extra / 8;
+
+               if (!stack_slots)
+                       continue;
+               if (stack_slots > 1) {
+                       verbose(env, "verifier bug: stack_slots supports may_goto only\n");
+                       return -EFAULT;
+               }
+
+               /* Add ST insn to subprog prologue to init extra stack */
+               insn_buf[0] = BPF_ST_MEM(BPF_DW, BPF_REG_FP,
+                                        -subprogs[i].stack_depth, BPF_MAX_LOOPS);
+               /* Copy first actual insn to preserve it */
+               insn_buf[1] = env->prog->insnsi[subprog_start];
+
+               new_prog = bpf_patch_insn_data(env, subprog_start, insn_buf, 2);
+               if (!new_prog)
+                       return -ENOMEM;
+               env->prog = prog = new_prog;
        }
 
        /* Since poke tab is now finalized, publish aux to tracker. */
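
Taken together with the may_goto rewrite earlier in this function, the injected instructions behave like the following model (a sketch in ordinary C; the names are illustrative, only the constant and the load/compare/decrement/store sequence mirror the patch):

	#include <stdbool.h>
	#include <stdint.h>

	/* One hidden 64-bit slot per frame, located 8 bytes below the
	 * subprog's original stack depth and seeded in the prologue.
	 */
	static uint64_t hidden_slot;

	static void prologue_init(void)
	{
		hidden_slot = 8 * 1024 * 1024;	/* BPF_MAX_LOOPS */
	}

	/* Each may_goto becomes: load the slot, branch if it reached zero,
	 * otherwise decrement it, store it back and fall through.
	 */
	static bool may_goto_takes_branch(void)
	{
		if (hidden_slot == 0)
			return true;	/* jump forward by insn->off */
		hidden_slot--;
		return false;
	}
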
@@ -20230,6 +20463,9 @@ static int do_check_common(struct bpf_verifier_env *env, int subprog)
                                reg->btf = bpf_get_btf_vmlinux(); /* can't fail at this point */
                                reg->btf_id = arg->btf_id;
                                reg->id = ++env->id_gen;
+                       } else if (base_type(arg->arg_type) == ARG_PTR_TO_ARENA) {
+                               /* caller can pass either PTR_TO_ARENA or SCALAR */
+                               mark_reg_unknown(env, regs, i);
                        } else {
                                WARN_ONCE(1, "BUG: unhandled arg#%d type %d\n",
                                          i - BPF_REG_1, arg->arg_type);
@@ -20705,7 +20941,7 @@ int bpf_check_attach_target(struct bpf_verifier_log *log,
                        }
                }
 
-               if (prog->aux->sleepable) {
+               if (prog->sleepable) {
                        ret = -EINVAL;
                        switch (prog->type) {
                        case BPF_PROG_TYPE_TRACING:
@@ -20816,14 +21052,14 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
        u64 key;
 
        if (prog->type == BPF_PROG_TYPE_SYSCALL) {
-               if (prog->aux->sleepable)
+               if (prog->sleepable)
                        /* attach_btf_id checked to be zero already */
                        return 0;
                verbose(env, "Syscall programs can only be sleepable\n");
                return -EINVAL;
        }
 
-       if (prog->aux->sleepable && !can_be_sleepable(prog)) {
+       if (prog->sleepable && !can_be_sleepable(prog)) {
                verbose(env, "Only fentry/fexit/fmod_ret, lsm, iter, uprobe, and struct_ops programs can be sleepable\n");
                return -EINVAL;
        }