Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Merge branch 'bpf-fix-precision-backtracking-bug-with-linked-registers'

Eduard Zingerman says:

====================
bpf: Fix precision backtracking bug with linked registers

Emil Tsalapatis reported a verifier bug hit by the scx_lavd sched_ext
scheduler. The essential part of the verifier log looks as follows:

436: ...
// checkpoint hit for 438: (1d) if r7 == r8 goto ...
frame 3: propagating r2,r7,r8
frame 2: propagating r6
mark_precise: frame3: last_idx ...
mark_precise: frame3: regs=r2,r7,r8 stack= before 436: ...
mark_precise: frame3: regs=r2,r7 stack= before 435: ...
mark_precise: frame3: regs=r2,r7 stack= before 434: (85) call bpf_trace_vprintk#177
verifier bug: backtracking call unexpected regs 84

The log complains that registers r2 and r7 are tracked as precise
while processing the bpf_trace_vprintk() call in precision backtracking.
This can't be right, as r2 is reset by the call and there is nothing
to backtrack it to. The precision propagation is triggered when
a checkpoint is hit at instruction 438, where r2 is dead.

This happens because of the following sequence of events:
- Instruction 438 is first reached with registers r2 and r7 having
the same id via a path that does not call bpf_trace_vprintk():
- Checkpoint is created at 438.
- The jump at 438 is predicted, hence r7 and registers linked to it
(r2) are propagated as precise, marking r2 and r7 precise in the
checkpoint.
- Instruction 438 is reached a second time with r2 undefined and via
a path that calls bpf_trace_vprintk():
- Checkpoint is hit.
- propagate_precision() picks registers r2 and r7 and propagates
precision marks for those up to the helper call.

The root cause is the fact that states_equal() and
propagate_precision() assume that the precision flag can't be set for a
dead register (as computed by compute_live_registers()).
However, this is not the case when linked registers are at play.
Fix this by accounting for live register flags in
collect_linked_regs().
---
====================

Link: https://patch.msgid.link/20260306-linked-regs-and-propagate-precision-v1-0-18e859be570d@gmail.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>

+137 -38
+10 -3
kernel/bpf/verifier.c
··· 17359 17359 * in verifier state, save R in linked_regs if R->id == id. 17360 17360 * If there are too many Rs sharing same id, reset id for leftover Rs. 17361 17361 */ 17362 - static void collect_linked_regs(struct bpf_verifier_state *vstate, u32 id, 17362 + static void collect_linked_regs(struct bpf_verifier_env *env, 17363 + struct bpf_verifier_state *vstate, 17364 + u32 id, 17363 17365 struct linked_regs *linked_regs) 17364 17366 { 17367 + struct bpf_insn_aux_data *aux = env->insn_aux_data; 17365 17368 struct bpf_func_state *func; 17366 17369 struct bpf_reg_state *reg; 17370 + u16 live_regs; 17367 17371 int i, j; 17368 17372 17369 17373 id = id & ~BPF_ADD_CONST; 17370 17374 for (i = vstate->curframe; i >= 0; i--) { 17375 + live_regs = aux[frame_insn_idx(vstate, i)].live_regs_before; 17371 17376 func = vstate->frame[i]; 17372 17377 for (j = 0; j < BPF_REG_FP; j++) { 17378 + if (!(live_regs & BIT(j))) 17379 + continue; 17373 17380 reg = &func->regs[j]; 17374 17381 __collect_linked_regs(linked_regs, reg, id, i, j, true); 17375 17382 } ··· 17591 17584 * if parent state is created. 17592 17585 */ 17593 17586 if (BPF_SRC(insn->code) == BPF_X && src_reg->type == SCALAR_VALUE && src_reg->id) 17594 - collect_linked_regs(this_branch, src_reg->id, &linked_regs); 17587 + collect_linked_regs(env, this_branch, src_reg->id, &linked_regs); 17595 17588 if (dst_reg->type == SCALAR_VALUE && dst_reg->id) 17596 - collect_linked_regs(this_branch, dst_reg->id, &linked_regs); 17589 + collect_linked_regs(env, this_branch, dst_reg->id, &linked_regs); 17597 17590 if (linked_regs.cnt > 1) { 17598 17591 err = push_jmp_history(env, this_branch, 0, linked_regs_pack(&linked_regs)); 17599 17592 if (err)
+17 -17
tools/testing/selftests/bpf/progs/exceptions_assert.c
··· 18 18 return *(u64 *)num; \ 19 19 } 20 20 21 - __msg(": R0=0xffffffff80000000") 21 + __msg("R{{.}}=0xffffffff80000000") 22 22 check_assert(s64, ==, eq_int_min, INT_MIN); 23 - __msg(": R0=0x7fffffff") 23 + __msg("R{{.}}=0x7fffffff") 24 24 check_assert(s64, ==, eq_int_max, INT_MAX); 25 - __msg(": R0=0") 25 + __msg("R{{.}}=0") 26 26 check_assert(s64, ==, eq_zero, 0); 27 - __msg(": R0=0x8000000000000000 R1=0x8000000000000000") 27 + __msg("R{{.}}=0x8000000000000000") 28 28 check_assert(s64, ==, eq_llong_min, LLONG_MIN); 29 - __msg(": R0=0x7fffffffffffffff R1=0x7fffffffffffffff") 29 + __msg("R{{.}}=0x7fffffffffffffff") 30 30 check_assert(s64, ==, eq_llong_max, LLONG_MAX); 31 31 32 - __msg(": R0=scalar(id=1,smax=0x7ffffffe)") 32 + __msg("R{{.}}=scalar(id=1,smax=0x7ffffffe)") 33 33 check_assert(s64, <, lt_pos, INT_MAX); 34 - __msg(": R0=scalar(id=1,smax=-1,umin=0x8000000000000000,var_off=(0x8000000000000000; 0x7fffffffffffffff))") 34 + __msg("R{{.}}=scalar(id=1,smax=-1,umin=0x8000000000000000,var_off=(0x8000000000000000; 0x7fffffffffffffff))") 35 35 check_assert(s64, <, lt_zero, 0); 36 - __msg(": R0=scalar(id=1,smax=0xffffffff7fffffff") 36 + __msg("R{{.}}=scalar(id=1,smax=0xffffffff7fffffff") 37 37 check_assert(s64, <, lt_neg, INT_MIN); 38 38 39 - __msg(": R0=scalar(id=1,smax=0x7fffffff)") 39 + __msg("R{{.}}=scalar(id=1,smax=0x7fffffff)") 40 40 check_assert(s64, <=, le_pos, INT_MAX); 41 - __msg(": R0=scalar(id=1,smax=0)") 41 + __msg("R{{.}}=scalar(id=1,smax=0)") 42 42 check_assert(s64, <=, le_zero, 0); 43 - __msg(": R0=scalar(id=1,smax=0xffffffff80000000") 43 + __msg("R{{.}}=scalar(id=1,smax=0xffffffff80000000") 44 44 check_assert(s64, <=, le_neg, INT_MIN); 45 45 46 - __msg(": R0=scalar(id=1,smin=umin=0x80000000,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))") 46 + __msg("R{{.}}=scalar(id=1,smin=umin=0x80000000,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))") 47 47 check_assert(s64, >, gt_pos, INT_MAX); 48 - __msg(": R0=scalar(id=1,smin=umin=1,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))") 48 + __msg("R{{.}}=scalar(id=1,smin=umin=1,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))") 49 49 check_assert(s64, >, gt_zero, 0); 50 - __msg(": R0=scalar(id=1,smin=0xffffffff80000001") 50 + __msg("R{{.}}=scalar(id=1,smin=0xffffffff80000001") 51 51 check_assert(s64, >, gt_neg, INT_MIN); 52 52 53 - __msg(": R0=scalar(id=1,smin=umin=0x7fffffff,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))") 53 + __msg("R{{.}}=scalar(id=1,smin=umin=0x7fffffff,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))") 54 54 check_assert(s64, >=, ge_pos, INT_MAX); 55 - __msg(": R0=scalar(id=1,smin=0,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))") 55 + __msg("R{{.}}=scalar(id=1,smin=0,umax=0x7fffffffffffffff,var_off=(0x0; 0x7fffffffffffffff))") 56 56 check_assert(s64, >=, ge_zero, 0); 57 - __msg(": R0=scalar(id=1,smin=0xffffffff80000000") 57 + __msg("R{{.}}=scalar(id=1,smin=0xffffffff80000000") 58 58 check_assert(s64, >=, ge_neg, INT_MIN); 59 59 60 60 SEC("?tc")
+64
tools/testing/selftests/bpf/progs/verifier_linked_scalars.c
··· 363 363 __sink(path[0]); 364 364 } 365 365 366 + void dummy_calls(void) 367 + { 368 + bpf_iter_num_new(0, 0, 0); 369 + bpf_iter_num_next(0); 370 + bpf_iter_num_destroy(0); 371 + } 372 + 373 + SEC("socket") 374 + __success 375 + __flag(BPF_F_TEST_STATE_FREQ) 376 + int spurious_precision_marks(void *ctx) 377 + { 378 + struct bpf_iter_num iter; 379 + 380 + asm volatile( 381 + "r1 = %[iter];" 382 + "r2 = 0;" 383 + "r3 = 10;" 384 + "call %[bpf_iter_num_new];" 385 + "1:" 386 + "r1 = %[iter];" 387 + "call %[bpf_iter_num_next];" 388 + "if r0 == 0 goto 4f;" 389 + "r7 = *(u32 *)(r0 + 0);" 390 + "r8 = *(u32 *)(r0 + 0);" 391 + /* This jump can't be predicted and does not change r7 or r8 state. */ 392 + "if r7 > r8 goto 2f;" 393 + /* Branch explored first ties r2 and r7 as having the same id. */ 394 + "r2 = r7;" 395 + "goto 3f;" 396 + "2:" 397 + /* Branch explored second does not tie r2 and r7 but has a function call. */ 398 + "call %[bpf_get_prandom_u32];" 399 + "3:" 400 + /* 401 + * A checkpoint. 402 + * When first branch is explored, this would inject linked registers 403 + * r2 and r7 into the jump history. 404 + * When second branch is explored, this would be a cache hit point, 405 + * triggering propagate_precision(). 406 + */ 407 + "if r7 <= 42 goto +0;" 408 + /* 409 + * Mark r7 as precise using an if condition that is always true. 410 + * When reached via the second branch, this triggered a bug in the backtrack_insn() 411 + * because r2 (tied to r7) was propagated as precise to a call. 412 + */ 413 + "if r7 <= 0xffffFFFF goto +0;" 414 + "goto 1b;" 415 + "4:" 416 + "r1 = %[iter];" 417 + "call %[bpf_iter_num_destroy];" 418 + : 419 + : __imm_ptr(iter), 420 + __imm(bpf_iter_num_new), 421 + __imm(bpf_iter_num_next), 422 + __imm(bpf_iter_num_destroy), 423 + __imm(bpf_get_prandom_u32) 424 + : __clobber_common, "r7", "r8" 425 + ); 426 + 427 + return 0; 428 + } 429 + 366 430 char _license[] SEC("license") = "GPL";
+42 -14
tools/testing/selftests/bpf/progs/verifier_scalar_ids.c
··· 40 40 */ 41 41 "r3 = r10;" 42 42 "r3 += r0;" 43 + /* Mark r1 and r2 as alive. */ 44 + "r1 = r1;" 45 + "r2 = r2;" 43 46 "r0 = 0;" 44 47 "exit;" 45 48 : ··· 76 73 */ 77 74 "r4 = r10;" 78 75 "r4 += r0;" 76 + /* Mark r1 and r2 as alive. */ 77 + "r1 = r1;" 78 + "r2 = r2;" 79 79 "r0 = 0;" 80 80 "exit;" 81 81 : ··· 112 106 */ 113 107 "r4 = r10;" 114 108 "r4 += r3;" 109 + /* Mark r1 and r2 as alive. */ 110 + "r0 = r0;" 111 + "r1 = r1;" 112 + "r2 = r2;" 115 113 "r0 = 0;" 116 114 "exit;" 117 115 : ··· 153 143 */ 154 144 "r3 = r10;" 155 145 "r3 += r0;" 146 + /* Mark r1 and r2 as alive. */ 147 + "r1 = r1;" 148 + "r2 = r2;" 156 149 "r0 = 0;" 157 150 "exit;" 158 151 : ··· 169 156 */ 170 157 SEC("socket") 171 158 __success __log_level(2) 172 - __msg("12: (0f) r2 += r1") 159 + __msg("17: (0f) r2 += r1") 173 160 /* Current state */ 174 - __msg("frame2: last_idx 12 first_idx 11 subseq_idx -1 ") 175 - __msg("frame2: regs=r1 stack= before 11: (bf) r2 = r10") 161 + __msg("frame2: last_idx 17 first_idx 14 subseq_idx -1 ") 162 + __msg("frame2: regs=r1 stack= before 16: (bf) r2 = r10") 176 163 __msg("frame2: parent state regs=r1 stack=") 177 164 __msg("frame1: parent state regs= stack=") 178 165 __msg("frame0: parent state regs= stack=") 179 166 /* Parent state */ 180 - __msg("frame2: last_idx 10 first_idx 10 subseq_idx 11 ") 181 - __msg("frame2: regs=r1 stack= before 10: (25) if r1 > 0x7 goto pc+0") 167 + __msg("frame2: last_idx 13 first_idx 13 subseq_idx 14 ") 168 + __msg("frame2: regs=r1 stack= before 13: (25) if r1 > 0x7 goto pc+0") 182 169 __msg("frame2: parent state regs=r1 stack=") /* frame1.r{6,7} are marked because mark_precise_scalar_ids() * looks for all registers with frame2.r1.id in the current state ··· 186 173 __msg("frame1: parent state regs=r6,r7 stack=") 187 174 __msg("frame0: parent state regs=r6 stack=") /* Parent state */ 189 - __msg("frame2: last_idx 8 first_idx 8 subseq_idx 10") 190 - __msg("frame2: regs=r1 stack= before 8: (85) call pc+1") 176 + __msg("frame2: last_idx 9 first_idx 9 subseq_idx 13") 177 + __msg("frame2: regs=r1 stack= before 9: (85) call pc+3") 191 178 /* frame1.r1 is marked because of backtracking of call instruction */ 192 179 __msg("frame1: parent state regs=r1,r6,r7 stack=") 193 180 __msg("frame0: parent state regs=r6 stack=") 194 181 /* Parent state */ 195 - __msg("frame1: last_idx 7 first_idx 6 subseq_idx 8") 196 - __msg("frame1: regs=r1,r6,r7 stack= before 7: (bf) r7 = r1") 197 - __msg("frame1: regs=r1,r6 stack= before 6: (bf) r6 = r1") 182 + __msg("frame1: last_idx 8 first_idx 7 subseq_idx 9") 183 + __msg("frame1: regs=r1,r6,r7 stack= before 8: (bf) r7 = r1") 184 + __msg("frame1: regs=r1,r6 stack= before 7: (bf) r6 = r1") 198 185 __msg("frame1: parent state regs=r1 stack=") 199 186 __msg("frame0: parent state regs=r6 stack=") 200 187 /* Parent state */ 201 - __msg("frame1: last_idx 4 first_idx 4 subseq_idx 6") 202 - __msg("frame1: regs=r1 stack= before 4: (85) call pc+1") 188 + __msg("frame1: last_idx 4 first_idx 4 subseq_idx 7") 189 + __msg("frame1: regs=r1 stack= before 4: (85) call pc+2") 203 190 __msg("frame0: parent state regs=r1,r6 stack=") 204 191 /* Parent state */ 205 192 __msg("frame0: last_idx 3 first_idx 1 subseq_idx 4") ··· 217 204 "r1 = r0;" 218 205 "r6 = r0;" 219 206 "call precision_many_frames__foo;" 207 + "r6 = r6;" /* mark r6 as live */ 220 208 "exit;" 221 209 : 222 210 : __imm(bpf_ktime_get_ns) ··· 234 220 "r6 = r1;" 235 221 "r7 = r1;" 236 222 "call precision_many_frames__bar;" 223 + "r6 = r6;" /* mark r6 as live */ 224 + "r7 = r7;" /* mark r7 as live */ 237 225 "exit" 238 226 ::: __clobber_all); 239 227 } ··· 245 229 { 246 230 asm volatile ( 247 231 "if r1 > 7 goto +0;" 232 + "r6 = 0;" /* mark r6 as live */ 233 + "r7 = 0;" /* mark r7 as live */ 248 234 /* force r1 to be precise, this eventually marks: 249 235 * - bar frame r1 250 236 * - foo frame r{1,6,7} ··· 358 340 "r3 += r7;" 359 341 /* force r9 to be precise, this also marks r8 */ 360 342 "r3 += r9;" 343 + "r6 = r6;" /* mark r6 as live */ 344 + "r8 = r8;" /* mark r8 as live */ 361 345 "exit;" 362 346 : 363 347 : __imm(bpf_ktime_get_ns) ··· 373 353 * collect_linked_regs() can't tie more than 6 registers for a single insn. 374 354 */ 375 355 __msg("8: (25) if r0 > 0x7 goto pc+0 ; R0=scalar(id=1") 376 - __msg("9: (bf) r6 = r6 ; R6=scalar(id=2") 356 + __msg("14: (bf) r6 = r6 ; R6=scalar(id=2") 377 357 /* check that r{0-5} are marked precise after 'if' */ 378 358 __msg("frame0: regs=r0 stack= before 8: (25) if r0 > 0x7 goto pc+0") 379 359 __msg("frame0: parent state regs=r0,r1,r2,r3,r4,r5 stack=:") ··· 392 372 "r6 = r0;" 393 373 /* propagate range for r{0-6} */ 394 374 "if r0 > 7 goto +0;" 375 + /* keep r{1-5} live */ 376 + "r1 = r1;" 377 + "r2 = r2;" 378 + "r3 = r3;" 379 + "r4 = r4;" 380 + "r5 = r5;" 395 381 /* make r6 appear in the log */ 396 382 "r6 = r6;" 397 383 /* force r0 to be precise, ··· 543 517 "*(u64*)(r10 - 8) = r1;" 544 518 /* r9 = pointer to stack */ 545 519 "r9 = r10;" 546 - "r9 += -8;" 520 + "r9 += -16;" 547 521 /* r8 = ktime_get_ns() */ 548 522 "call %[bpf_ktime_get_ns];" 549 523 "r8 = r0;" ··· 564 538 "if r7 > 4 goto l2_%=;" 565 539 /* Access memory at r9[r6] */ 566 540 "r9 += r6;" 541 + "r9 += r7;" 542 + "r9 += r8;" 567 543 "r0 = *(u8*)(r9 + 0);" 568 544 "l2_%=:" 569 545 "r0 = 0;"
+4 -4
tools/testing/selftests/bpf/verifier/precise.c
··· 44 44 mark_precise: frame0: regs=r2 stack= before 23\ 45 45 mark_precise: frame0: regs=r2 stack= before 22\ 46 46 mark_precise: frame0: regs=r2 stack= before 20\ 47 - mark_precise: frame0: parent state regs=r2,r9 stack=:\ 47 + mark_precise: frame0: parent state regs=r2 stack=:\ 48 48 mark_precise: frame0: last_idx 19 first_idx 10\ 49 - mark_precise: frame0: regs=r2,r9 stack= before 19\ 49 + mark_precise: frame0: regs=r2 stack= before 19\ 50 50 mark_precise: frame0: regs=r9 stack= before 18\ 51 51 mark_precise: frame0: regs=r8,r9 stack= before 17\ 52 52 mark_precise: frame0: regs=r0,r9 stack= before 15\ ··· 107 107 mark_precise: frame0: parent state regs=r2 stack=:\ 108 108 mark_precise: frame0: last_idx 20 first_idx 20\ 109 109 mark_precise: frame0: regs=r2 stack= before 20\ 110 - mark_precise: frame0: parent state regs=r2,r9 stack=:\ 110 + mark_precise: frame0: parent state regs=r2 stack=:\ 111 111 mark_precise: frame0: last_idx 19 first_idx 17\ 112 - mark_precise: frame0: regs=r2,r9 stack= before 19\ 112 + mark_precise: frame0: regs=r2 stack= before 19\ 113 113 mark_precise: frame0: regs=r9 stack= before 18\ 114 114 mark_precise: frame0: regs=r8,r9 stack= before 17\ 115 115 mark_precise: frame0: parent state regs= stack=:",