]> git.ipfire.org Git - thirdparty/kernel/stable.git/commitdiff
bpf: set 'changed' status if propagate_precision() did any updates
authorEduard Zingerman <eddyz87@gmail.com>
Wed, 11 Jun 2025 20:08:30 +0000 (13:08 -0700)
committerAlexei Starovoitov <ast@kernel.org>
Thu, 12 Jun 2025 23:52:43 +0000 (16:52 -0700)
Add an out parameter to `propagate_precision()` to record whether any
new precision bits were set during its execution.

Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
Link: https://lore.kernel.org/r/20250611200836.4135542-5-eddyz87@gmail.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
kernel/bpf/verifier.c

index cca8858c3caafd26f8ea303c4a4de50bb0104a84..b00769ceef8cc11643debb28868e9f362629fac6 100644 (file)
@@ -4678,7 +4678,9 @@ static void mark_all_scalars_imprecise(struct bpf_verifier_env *env, struct bpf_
  * finalized states which help in short circuiting more future states.
  */
 static int __mark_chain_precision(struct bpf_verifier_env *env,
-                                 struct bpf_verifier_state *starting_state, int regno)
+                                 struct bpf_verifier_state *starting_state,
+                                 int regno,
+                                 bool *changed)
 {
        struct bpf_verifier_state *st = starting_state;
        struct backtrack_state *bt = &env->bt;
@@ -4686,13 +4688,14 @@ static int __mark_chain_precision(struct bpf_verifier_env *env,
        int last_idx = starting_state->insn_idx;
        int subseq_idx = -1;
        struct bpf_func_state *func;
+       bool tmp, skip_first = true;
        struct bpf_reg_state *reg;
-       bool skip_first = true;
        int i, fr, err;
 
        if (!env->bpf_capable)
                return 0;
 
+       changed = changed ?: &tmp;
        /* set frame number from which we are starting to backtrack */
        bt_init(bt, starting_state->curframe);
 
@@ -4738,8 +4741,10 @@ static int __mark_chain_precision(struct bpf_verifier_env *env,
                                for_each_set_bit(i, mask, 32) {
                                        reg = &st->frame[0]->regs[i];
                                        bt_clear_reg(bt, i);
-                                       if (reg->type == SCALAR_VALUE)
+                                       if (reg->type == SCALAR_VALUE) {
                                                reg->precise = true;
+                                               *changed = true;
+                                       }
                                }
                                return 0;
                        }
@@ -4798,10 +4803,12 @@ static int __mark_chain_precision(struct bpf_verifier_env *env,
                                        bt_clear_frame_reg(bt, fr, i);
                                        continue;
                                }
-                               if (reg->precise)
+                               if (reg->precise) {
                                        bt_clear_frame_reg(bt, fr, i);
-                               else
+                               } else {
                                        reg->precise = true;
+                                       *changed = true;
+                               }
                        }
 
                        bitmap_from_u64(mask, bt_frame_stack_mask(bt, fr));
@@ -4816,10 +4823,12 @@ static int __mark_chain_precision(struct bpf_verifier_env *env,
                                        continue;
                                }
                                reg = &func->stack[i].spilled_ptr;
-                               if (reg->precise)
+                               if (reg->precise) {
                                        bt_clear_frame_slot(bt, fr, i);
-                               else
+                               } else {
                                        reg->precise = true;
+                                       *changed = true;
+                               }
                        }
                        if (env->log.level & BPF_LOG_LEVEL2) {
                                fmt_reg_mask(env->tmp_str_buf, TMP_STR_BUF_LEN,
@@ -4855,7 +4864,7 @@ static int __mark_chain_precision(struct bpf_verifier_env *env,
 
 int mark_chain_precision(struct bpf_verifier_env *env, int regno)
 {
-       return __mark_chain_precision(env, env->cur_state, regno);
+       return __mark_chain_precision(env, env->cur_state, regno, NULL);
 }
 
 /* mark_chain_precision_batch() assumes that env->bt is set in the caller to
@@ -4864,7 +4873,7 @@ int mark_chain_precision(struct bpf_verifier_env *env, int regno)
 static int mark_chain_precision_batch(struct bpf_verifier_env *env,
                                      struct bpf_verifier_state *starting_state)
 {
-       return __mark_chain_precision(env, starting_state, -1);
+       return __mark_chain_precision(env, starting_state, -1, NULL);
 }
 
 static bool is_spillable_regtype(enum bpf_reg_type type)
@@ -18893,7 +18902,9 @@ static int propagate_liveness(struct bpf_verifier_env *env,
  * propagate them into the current state
  */
 static int propagate_precision(struct bpf_verifier_env *env,
-                              const struct bpf_verifier_state *old)
+                              const struct bpf_verifier_state *old,
+                              struct bpf_verifier_state *cur,
+                              bool *changed)
 {
        struct bpf_reg_state *state_reg;
        struct bpf_func_state *state;
@@ -18941,7 +18952,7 @@ static int propagate_precision(struct bpf_verifier_env *env,
                        verbose(env, "\n");
        }
 
-       err = mark_chain_precision_batch(env, env->cur_state);
+       err = __mark_chain_precision(env, cur, -1, changed);
        if (err < 0)
                return err;
 
@@ -19264,7 +19275,7 @@ hit:
                         */
                        if (is_jmp_point(env, env->insn_idx))
                                err = err ? : push_jmp_history(env, cur, 0, 0);
-                       err = err ? : propagate_precision(env, &sl->state);
+                       err = err ? : propagate_precision(env, &sl->state, cur, NULL);
                        if (err)
                                return err;
                        return 1;