selftests/bpf: validate jit behaviour for tail calls
author Eduard Zingerman <eddyz87@gmail.com>
Tue, 20 Aug 2024 10:23:56 +0000 (03:23 -0700)
committer Alexei Starovoitov <ast@kernel.org>
Wed, 21 Aug 2024 18:03:01 +0000 (11:03 -0700)
Add a test with a program calling a sub-program which does a tail call.
The idea is to verify the instructions generated by the JIT for tail calls
(a host-side model of the counter-passing scheme follows the list):
- in the program and sub-program prologues;
- for the subprogram call instruction;
- for the tail call itself.
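
For reference, a minimal host-side C model (illustration only, not kernel
code; the helper function name is made up for this sketch) of the scheme
the jited assertions check: a raw tail call counter can never exceed
MAX_TAIL_CALL_CNT (33), so a value in rax above 33 must instead be the
address of the counter spilled by the entry program.

    #include <stdint.h>
    #include <stdio.h>

    #define MAX_TAIL_CALL_CNT 33

    /* prologue logic: decide whether rax holds a raw counter (entry
     * program) or a pointer to the counter (tail-called program)
     */
    static uint64_t *tail_call_cnt_ptr(uint64_t rax, uint64_t *slot)
    {
            if (rax <= MAX_TAIL_CALL_CNT) { /* raw counter: spill it */
                    *slot = rax;
                    return slot;
            }
            return (uint64_t *)rax;         /* already a pointer */
    }

    int main(void)
    {
            uint64_t slot;
            uint64_t *cnt = tail_call_cnt_ptr(0, &slot); /* entry path */

            /* bpf_tail_call: take the jump only while *cnt < 33 */
            if (*cnt < MAX_TAIL_CALL_CNT) {
                    *cnt += 1;
                    printf("tail call taken, cnt=%llu\n",
                           (unsigned long long)*cnt);
            }
            return 0;
    }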

Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
Link: https://lore.kernel.org/r/20240820102357.3372779-9-eddyz87@gmail.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
tools/testing/selftests/bpf/prog_tests/verifier.c
tools/testing/selftests/bpf/progs/verifier_tailcall_jit.c [new file with mode: 0644]

index f8f546eba4885473693662f8d9bb69b3e5d3e839..cf3662dbd24fdbed437153fe7f39ef3be8c52a48 100644 (file)
@@ -75,6 +75,7 @@
 #include "verifier_stack_ptr.skel.h"
 #include "verifier_subprog_precision.skel.h"
 #include "verifier_subreg.skel.h"
+#include "verifier_tailcall_jit.skel.h"
 #include "verifier_typedef.skel.h"
 #include "verifier_uninit.skel.h"
 #include "verifier_unpriv.skel.h"
@@ -198,6 +199,7 @@ void test_verifier_spin_lock(void)            { RUN(verifier_spin_lock); }
 void test_verifier_stack_ptr(void)            { RUN(verifier_stack_ptr); }
 void test_verifier_subprog_precision(void)    { RUN(verifier_subprog_precision); }
 void test_verifier_subreg(void)               { RUN(verifier_subreg); }
+void test_verifier_tailcall_jit(void)         { RUN(verifier_tailcall_jit); }
 void test_verifier_typedef(void)              { RUN(verifier_typedef); }
 void test_verifier_uninit(void)               { RUN(verifier_uninit); }
 void test_verifier_unpriv(void)               { RUN(verifier_unpriv); }
diff --git a/tools/testing/selftests/bpf/progs/verifier_tailcall_jit.c b/tools/testing/selftests/bpf/progs/verifier_tailcall_jit.c
new file mode 100644 (file)
index 0000000..06d327c
--- /dev/null
@@ -0,0 +1,105 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <linux/bpf.h>
+#include <bpf/bpf_helpers.h>
+#include "bpf_misc.h"
+
+int main(void);
+
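+/* prog array with a single slot pointing back at main(): sub() below
+ * tail calls into jmp_table[0], i.e. back into main()
+ */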
+struct {
+       __uint(type, BPF_MAP_TYPE_PROG_ARRAY);
+       __uint(max_entries, 1);
+       __uint(key_size, sizeof(__u32));
+       __array(values, void (void));
+} jmp_table SEC(".maps") = {
+       .values = {
+               [0] = (void *) &main,
+       },
+};
+
+__noinline __auxiliary
+static __naked int sub(void)
+{
+       asm volatile (
+       "r2 = %[jmp_table] ll;"
+       "r3 = 0;"
+       "call 12;"
+       "exit;"
+       :
+       : __imm_addr(jmp_table)
+       : __clobber_all);
+}
+
+__success
+__arch_x86_64
+/* program entry for main(), regular function prologue */
+__jited("      endbr64")
+__jited("      nopl    (%rax,%rax)")
+__jited("      xorq    %rax, %rax")
+__jited("      pushq   %rbp")
+__jited("      movq    %rsp, %rbp")
+/* tail call prologue for program:
+ * - establish memory location for tail call counter at &rbp[-8];
+ * - spill tail_call_cnt_ptr at &rbp[-16];
+ * - expect tail call counter to be passed in rax;
+ * - for entry program rax is a raw counter, value < 33;
+ * - for tail called program rax is tail_call_cnt_ptr (value > 33).
+ */
+__jited("      endbr64")
+__jited("      cmpq    $0x21, %rax")
+__jited("      ja      L0")
+__jited("      pushq   %rax")
+__jited("      movq    %rsp, %rax")
+__jited("      jmp     L1")
+__jited("L0:   pushq   %rax")                  /* rbp[-8]  = rax         */
+__jited("L1:   pushq   %rax")                  /* rbp[-16] = rax         */
+/* on subprogram call, restore rax to tail_call_cnt_ptr from rbp[-16]
+ * (because the original rax might have been clobbered by this point)
+ */
+__jited("      movq    -0x10(%rbp), %rax")
+__jited("      callq   0x{{.*}}")              /* call to sub()          */
+__jited("      xorl    %eax, %eax")
+__jited("      leave")
+__jited("      retq")
+__jited("...")
+/* subprogram entry for sub(), regular function prologue */
+__jited("      endbr64")
+__jited("      nopl    (%rax,%rax)")
+__jited("      nopl    (%rax)")
+__jited("      pushq   %rbp")
+__jited("      movq    %rsp, %rbp")
+/* tail call prologue for subprogram: the address of the tail call
+ * counter (tail_call_cnt_ptr, passed in rax) is spilled at rbp[-16].
+ */
+__jited("      endbr64")
+__jited("      pushq   %rax")                  /* rbp[-8]  = rax          */
+__jited("      pushq   %rax")                  /* rbp[-16] = rax          */
+__jited("      movabsq ${{.*}}, %rsi")         /* r2 = &jmp_table         */
+__jited("      xorl    %edx, %edx")            /* r3 = 0                  */
+/* bpf_tail_call implementation:
+ * - load tail_call_cnt_ptr from rbp[-16];
+ * - if *tail_call_cnt_ptr < 33, increment it and jump to target;
+ * - otherwise do nothing.
+ */
+__jited("      movq    -0x10(%rbp), %rax")
+__jited("      cmpq    $0x21, (%rax)")
+__jited("      jae     L0")
+__jited("      nopl    (%rax,%rax)")
+__jited("      addq    $0x1, (%rax)")          /* *tail_call_cnt_ptr += 1 */
+__jited("      popq    %rax")
+__jited("      popq    %rax")
+__jited("      jmp     {{.*}}")                /* jump to tail call tgt   */
+__jited("L0:   leave")
+__jited("      retq")
+SEC("tc")
+__naked int main(void)
+{
+       asm volatile (
+       "call %[sub];"
+       "r0 = 0;"
+       "exit;"
+       :
+       : __imm(sub)
+       : __clobber_all);
+}
+
+char __license[] SEC("license") = "GPL";
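
With the skeleton registered in verifier.c, the test can be run on its own
through the usual selftests workflow (assuming an x86-64 kernel with the
BPF JIT enabled):

    cd tools/testing/selftests/bpf
    make
    ./test_progs -t verifier_tailcall_jit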