]> git.ipfire.org Git - thirdparty/linux.git/commitdiff
bpf: Simplify tnum_step()
authorHao Sun <sunhao.th@gmail.com>
Fri, 20 Mar 2026 16:23:36 +0000 (17:23 +0100)
committerAlexei Starovoitov <ast@kernel.org>
Tue, 24 Mar 2026 15:45:29 +0000 (08:45 -0700)
Simplify tnum_step() from a 10-variable algorithm into a straight
line sequence of bitwise operations.

Problem Reduction:

tnum_step(): Given a tnum `(tval, tmask)` where `tval & tmask == 0`,
and a value `z` with `tval ≤ z < (tval | tmask)`, find the smallest
`r > z`, a tnum-satisfying value, i.e., `r & ~tmask == tval`.

Every tnum-satisfying value has the form tval | s where s is a subset
of tmask bits (s & ~tmask == 0).  Since tval and tmask are disjoint:

    tval | s  =  tval + s

Similarly z = tval + d where d = z - tval, so r > z becomes:

    tval + s  >  tval + d
    s > d

The problem reduces to: find the smallest s, a subset of tmask, such
that s > d.

Notice that `s` must be a subset of tmask, the problem now is simplified.

Algorithm:

The mask bits of `d` form a "counter" that we want to increment by one,
but the counter has gaps at the fixed-bit positions.  A normal +1 would
stop at the first 0-bit it meets; we need it to skip over fixed-bit
gaps and land on the next mask bit.

Step 1 -- plug the gaps:

    d | carry_mask | ~tmask

  - ~tmask fills all fixed-bit positions with 1.
  - carry_mask = (1 << fls64(d & ~tmask)) - 1 fills all positions
    (including mask positions) below the highest non-mask bit of d.

After this, the only remaining 0s are mask bits above the highest
non-mask bit of d where d is also 0 -- exactly the positions where
the carry can validly land.

Step 2 -- increment:

    (d | carry_mask | ~tmask) + 1

Adding 1 flips all trailing 1s to 0 and sets the first 0 to 1.  Since
every gap has been plugged, that first 0 is guaranteed to be a mask bit
above all non-mask bits of d.

Step 3 -- mask:

    ((d | carry_mask | ~tmask) + 1) & tmask

Strip the scaffolding, keeping only mask bits.  Call the result inc.

Step 4 -- result:

    tval | inc

Reattach the fixed bits.

A simple 8-bit example:
    tmask:        1  1  0  1  0  1  1  0
    d:            1  0  1  0  0  0  1  0     (d = 162)
                        ^
                        non-mask 1 at bit 5

With carry_mask = 0b00111111 (smeared from bit 5):

    d|carry|~tm   1  0  1  1  1  1  1  1
    + 1           1  1  0  0  0  0  0  0
    & tmask       1  1  0  0  0  0  0  0

The patch passes my local test: test_verifier, test_progs for
`-t verifier` and `-t reg_bounds`.

CBMC shows the new code is equiv to original one[1], and
a lean4 proof of correctness is available[2]:

theorem tnumStep_correct (tval tmask z : BitVec 64)
    -- Precondition: valid tnum and input z
    (h_consistent : (tval &&& tmask) = 0)
    (h_lo : tval ≤ z)
    (h_hi : z < (tval ||| tmask)) :
    -- Postcondition: r must be:
    --    (1) tnum member
    --    (2) z < r
    --    (3) for any other member w > z, r <= w
    let r := tnumStep tval tmask z
    satisfiesTnum64 r tval tmask ∧
    tval ≤ r ∧ r ≤ (tval ||| tmask) ∧
    z < r ∧
    ∀ w, satisfiesTnum64 w tval tmask → z < w → r ≤ w := by
  -- unfold definition
  unfold tnumStep satisfiesTnum64
  simp only []
  refine ⟨?_, ?_, ?_, ?_, ?_⟩
  -- the solver proves each conjunct
  · bv_decide
  · bv_decide
  · bv_decide
  · bv_decide
  · intro w hw1 hw2; bv_decide

[1] https://github.com/eddyz87/tnum-step-verif/blob/master/main.c
[2] https://pastebin.com/raw/czHKiyY0

Signed-off-by: Hao Sun <hao.sun@inf.ethz.ch>
Acked-by: Eduard Zingerman <eddyz87@gmail.com>
Acked-by: Shung-Hsi Yu <shung-hsi.yu@suse.com>
Reviewed-by: Harishankar Vishwanathan <harishankar.vishwanathan@gmail.com>
Link: https://lore.kernel.org/r/20260320162336.166542-1-hao.sun@inf.ethz.ch
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
kernel/bpf/tnum.c

index 4abc359b3db0123d7b19e951626a4aef26aeb45b..ec9c310cf5d7f22536b2a485c5a006adc3c73586 100644 (file)
@@ -286,8 +286,7 @@ struct tnum tnum_bswap64(struct tnum a)
  */
 u64 tnum_step(struct tnum t, u64 z)
 {
-       u64 tmax, j, p, q, r, s, v, u, w, res;
-       u8 k;
+       u64 tmax, d, carry_mask, filled, inc;
 
        tmax = t.value | t.mask;
 
@@ -299,29 +298,22 @@ u64 tnum_step(struct tnum t, u64 z)
        if (z < t.value)
                return t.value;
 
-       /* keep t's known bits, and match all unknown bits to z */
-       j = t.value | (z & t.mask);
-
-       if (j > z) {
-               p = ~z & t.value & ~t.mask;
-               k = fls64(p); /* k is the most-significant 0-to-1 flip */
-               q = U64_MAX << k;
-               r = q & z; /* positions > k matched to z */
-               s = ~q & t.value; /* positions <= k matched to t.value */
-               v = r | s;
-               res = v;
-       } else {
-               p = z & ~t.value & ~t.mask;
-               k = fls64(p); /* k is the most-significant 1-to-0 flip */
-               q = U64_MAX << k;
-               r = q & t.mask & z; /* unknown positions > k, matched to z */
-               s = q & ~t.mask; /* known positions > k, set to 1 */
-               v = r | s;
-               /* add 1 to unknown positions > k to make value greater than z */
-               u = v + (1ULL << k);
-               /* extract bits in unknown positions > k from u, rest from t.value */
-               w = (u & t.mask) | t.value;
-               res = w;
-       }
-       return res;
+       /*
+        * Let r be the result tnum member, z = t.value + d.
+        * Every tnum member is t.value | s for some submask s of t.mask,
+        * and since t.value & t.mask == 0, t.value | s == t.value + s.
+        * So r > z becomes s > d where d = z - t.value.
+        *
+        * Find the smallest submask s of t.mask greater than d by
+        * "incrementing d within the mask": fill every non-mask
+        * position with 1 (`filled`) so +1 ripples through the gaps,
+        * then keep only mask bits. `carry_mask` additionally fills
+        * positions below the highest non-mask 1 in d, preventing
+        * it from trapping the carry.
+        */
+       d = z - t.value;
+       carry_mask = (1ULL << fls64(d & ~t.mask)) - 1;
+       filled = d | carry_mask | ~t.mask;
+       inc = (filled + 1) & t.mask;
+       return t.value | inc;
 }