1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
4 #include <linux/skmsg.h>
5 #include <linux/skbuff.h>
6 #include <linux/scatterlist.h>
11 #include <trace/events/sock.h>
13 static bool sk_msg_try_coalesce_ok(struct sk_msg
*msg
, int elem_first_coalesce
)
15 if (msg
->sg
.end
> msg
->sg
.start
&&
16 elem_first_coalesce
< msg
->sg
.end
)
19 if (msg
->sg
.end
< msg
->sg
.start
&&
20 (elem_first_coalesce
> msg
->sg
.start
||
21 elem_first_coalesce
< msg
->sg
.end
))
27 int sk_msg_alloc(struct sock
*sk
, struct sk_msg
*msg
, int len
,
28 int elem_first_coalesce
)
30 struct page_frag
*pfrag
= sk_page_frag(sk
);
31 u32 osize
= msg
->sg
.size
;
36 struct scatterlist
*sge
;
40 if (!sk_page_frag_refill(sk
, pfrag
)) {
45 orig_offset
= pfrag
->offset
;
46 use
= min_t(int, len
, pfrag
->size
- orig_offset
);
47 if (!sk_wmem_schedule(sk
, use
)) {
53 sk_msg_iter_var_prev(i
);
54 sge
= &msg
->sg
.data
[i
];
56 if (sk_msg_try_coalesce_ok(msg
, elem_first_coalesce
) &&
57 sg_page(sge
) == pfrag
->page
&&
58 sge
->offset
+ sge
->length
== orig_offset
) {
61 if (sk_msg_full(msg
)) {
66 sge
= &msg
->sg
.data
[msg
->sg
.end
];
68 sg_set_page(sge
, pfrag
->page
, use
, orig_offset
);
69 get_page(pfrag
->page
);
70 sk_msg_iter_next(msg
, end
);
73 sk_mem_charge(sk
, use
);
82 sk_msg_trim(sk
, msg
, osize
);
85 EXPORT_SYMBOL_GPL(sk_msg_alloc
);
87 int sk_msg_clone(struct sock
*sk
, struct sk_msg
*dst
, struct sk_msg
*src
,
90 int i
= src
->sg
.start
;
91 struct scatterlist
*sge
= sk_msg_elem(src
, i
);
92 struct scatterlist
*sgd
= NULL
;
96 if (sge
->length
> off
)
99 sk_msg_iter_var_next(i
);
100 if (i
== src
->sg
.end
&& off
)
102 sge
= sk_msg_elem(src
, i
);
106 sge_len
= sge
->length
- off
;
111 sgd
= sk_msg_elem(dst
, dst
->sg
.end
- 1);
114 (sg_page(sge
) == sg_page(sgd
)) &&
115 (sg_virt(sge
) + off
== sg_virt(sgd
) + sgd
->length
)) {
116 sgd
->length
+= sge_len
;
117 dst
->sg
.size
+= sge_len
;
118 } else if (!sk_msg_full(dst
)) {
119 sge_off
= sge
->offset
+ off
;
120 sk_msg_page_add(dst
, sg_page(sge
), sge_len
, sge_off
);
127 sk_mem_charge(sk
, sge_len
);
128 sk_msg_iter_var_next(i
);
129 if (i
== src
->sg
.end
&& len
)
131 sge
= sk_msg_elem(src
, i
);
136 EXPORT_SYMBOL_GPL(sk_msg_clone
);
138 void sk_msg_return_zero(struct sock
*sk
, struct sk_msg
*msg
, int bytes
)
140 int i
= msg
->sg
.start
;
143 struct scatterlist
*sge
= sk_msg_elem(msg
, i
);
145 if (bytes
< sge
->length
) {
146 sge
->length
-= bytes
;
147 sge
->offset
+= bytes
;
148 sk_mem_uncharge(sk
, bytes
);
152 sk_mem_uncharge(sk
, sge
->length
);
153 bytes
-= sge
->length
;
156 sk_msg_iter_var_next(i
);
157 } while (bytes
&& i
!= msg
->sg
.end
);
160 EXPORT_SYMBOL_GPL(sk_msg_return_zero
);
162 void sk_msg_return(struct sock
*sk
, struct sk_msg
*msg
, int bytes
)
164 int i
= msg
->sg
.start
;
167 struct scatterlist
*sge
= &msg
->sg
.data
[i
];
168 int uncharge
= (bytes
< sge
->length
) ? bytes
: sge
->length
;
170 sk_mem_uncharge(sk
, uncharge
);
172 sk_msg_iter_var_next(i
);
173 } while (i
!= msg
->sg
.end
);
175 EXPORT_SYMBOL_GPL(sk_msg_return
);
177 static int sk_msg_free_elem(struct sock
*sk
, struct sk_msg
*msg
, u32 i
,
180 struct scatterlist
*sge
= sk_msg_elem(msg
, i
);
181 u32 len
= sge
->length
;
183 /* When the skb owns the memory we free it from consume_skb path. */
186 sk_mem_uncharge(sk
, len
);
187 put_page(sg_page(sge
));
189 memset(sge
, 0, sizeof(*sge
));
193 static int __sk_msg_free(struct sock
*sk
, struct sk_msg
*msg
, u32 i
,
196 struct scatterlist
*sge
= sk_msg_elem(msg
, i
);
199 while (msg
->sg
.size
) {
200 msg
->sg
.size
-= sge
->length
;
201 freed
+= sk_msg_free_elem(sk
, msg
, i
, charge
);
202 sk_msg_iter_var_next(i
);
203 sk_msg_check_to_free(msg
, i
, msg
->sg
.size
);
204 sge
= sk_msg_elem(msg
, i
);
206 consume_skb(msg
->skb
);
211 int sk_msg_free_nocharge(struct sock
*sk
, struct sk_msg
*msg
)
213 return __sk_msg_free(sk
, msg
, msg
->sg
.start
, false);
215 EXPORT_SYMBOL_GPL(sk_msg_free_nocharge
);
217 int sk_msg_free(struct sock
*sk
, struct sk_msg
*msg
)
219 return __sk_msg_free(sk
, msg
, msg
->sg
.start
, true);
221 EXPORT_SYMBOL_GPL(sk_msg_free
);
223 static void __sk_msg_free_partial(struct sock
*sk
, struct sk_msg
*msg
,
224 u32 bytes
, bool charge
)
226 struct scatterlist
*sge
;
227 u32 i
= msg
->sg
.start
;
230 sge
= sk_msg_elem(msg
, i
);
233 if (bytes
< sge
->length
) {
235 sk_mem_uncharge(sk
, bytes
);
236 sge
->length
-= bytes
;
237 sge
->offset
+= bytes
;
238 msg
->sg
.size
-= bytes
;
242 msg
->sg
.size
-= sge
->length
;
243 bytes
-= sge
->length
;
244 sk_msg_free_elem(sk
, msg
, i
, charge
);
245 sk_msg_iter_var_next(i
);
246 sk_msg_check_to_free(msg
, i
, bytes
);
251 void sk_msg_free_partial(struct sock
*sk
, struct sk_msg
*msg
, u32 bytes
)
253 __sk_msg_free_partial(sk
, msg
, bytes
, true);
255 EXPORT_SYMBOL_GPL(sk_msg_free_partial
);
257 void sk_msg_free_partial_nocharge(struct sock
*sk
, struct sk_msg
*msg
,
260 __sk_msg_free_partial(sk
, msg
, bytes
, false);
263 void sk_msg_trim(struct sock
*sk
, struct sk_msg
*msg
, int len
)
265 int trim
= msg
->sg
.size
- len
;
273 sk_msg_iter_var_prev(i
);
275 while (msg
->sg
.data
[i
].length
&&
276 trim
>= msg
->sg
.data
[i
].length
) {
277 trim
-= msg
->sg
.data
[i
].length
;
278 sk_msg_free_elem(sk
, msg
, i
, true);
279 sk_msg_iter_var_prev(i
);
284 msg
->sg
.data
[i
].length
-= trim
;
285 sk_mem_uncharge(sk
, trim
);
286 /* Adjust copybreak if it falls into the trimmed part of last buf */
287 if (msg
->sg
.curr
== i
&& msg
->sg
.copybreak
> msg
->sg
.data
[i
].length
)
288 msg
->sg
.copybreak
= msg
->sg
.data
[i
].length
;
290 sk_msg_iter_var_next(i
);
293 /* If we trim data a full sg elem before curr pointer update
294 * copybreak and current so that any future copy operations
295 * start at new copy location.
296 * However trimed data that has not yet been used in a copy op
297 * does not require an update.
300 msg
->sg
.curr
= msg
->sg
.start
;
301 msg
->sg
.copybreak
= 0;
302 } else if (sk_msg_iter_dist(msg
->sg
.start
, msg
->sg
.curr
) >=
303 sk_msg_iter_dist(msg
->sg
.start
, msg
->sg
.end
)) {
304 sk_msg_iter_var_prev(i
);
306 msg
->sg
.copybreak
= msg
->sg
.data
[i
].length
;
309 EXPORT_SYMBOL_GPL(sk_msg_trim
);
311 int sk_msg_zerocopy_from_iter(struct sock
*sk
, struct iov_iter
*from
,
312 struct sk_msg
*msg
, u32 bytes
)
314 int i
, maxpages
, ret
= 0, num_elems
= sk_msg_elem_used(msg
);
315 const int to_max_pages
= MAX_MSG_FRAGS
;
316 struct page
*pages
[MAX_MSG_FRAGS
];
317 ssize_t orig
, copied
, use
, offset
;
322 maxpages
= to_max_pages
- num_elems
;
328 copied
= iov_iter_get_pages2(from
, pages
, bytes
, maxpages
,
336 msg
->sg
.size
+= copied
;
339 use
= min_t(int, copied
, PAGE_SIZE
- offset
);
340 sg_set_page(&msg
->sg
.data
[msg
->sg
.end
],
341 pages
[i
], use
, offset
);
342 sg_unmark_end(&msg
->sg
.data
[msg
->sg
.end
]);
343 sk_mem_charge(sk
, use
);
347 sk_msg_iter_next(msg
, end
);
351 /* When zerocopy is mixed with sk_msg_*copy* operations we
352 * may have a copybreak set in this case clear and prefer
353 * zerocopy remainder when possible.
355 msg
->sg
.copybreak
= 0;
356 msg
->sg
.curr
= msg
->sg
.end
;
359 /* Revert iov_iter updates, msg will need to use 'trim' later if it
360 * also needs to be cleared.
363 iov_iter_revert(from
, msg
->sg
.size
- orig
);
366 EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter
);
368 int sk_msg_memcopy_from_iter(struct sock
*sk
, struct iov_iter
*from
,
369 struct sk_msg
*msg
, u32 bytes
)
371 int ret
= -ENOSPC
, i
= msg
->sg
.curr
;
372 struct scatterlist
*sge
;
377 sge
= sk_msg_elem(msg
, i
);
378 /* This is possible if a trim operation shrunk the buffer */
379 if (msg
->sg
.copybreak
>= sge
->length
) {
380 msg
->sg
.copybreak
= 0;
381 sk_msg_iter_var_next(i
);
382 if (i
== msg
->sg
.end
)
384 sge
= sk_msg_elem(msg
, i
);
387 buf_size
= sge
->length
- msg
->sg
.copybreak
;
388 copy
= (buf_size
> bytes
) ? bytes
: buf_size
;
389 to
= sg_virt(sge
) + msg
->sg
.copybreak
;
390 msg
->sg
.copybreak
+= copy
;
391 if (sk
->sk_route_caps
& NETIF_F_NOCACHE_COPY
)
392 ret
= copy_from_iter_nocache(to
, copy
, from
);
394 ret
= copy_from_iter(to
, copy
, from
);
402 msg
->sg
.copybreak
= 0;
403 sk_msg_iter_var_next(i
);
404 } while (i
!= msg
->sg
.end
);
409 EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter
);
411 /* Receive sk_msg from psock->ingress_msg to @msg. */
412 int sk_msg_recvmsg(struct sock
*sk
, struct sk_psock
*psock
, struct msghdr
*msg
,
415 struct iov_iter
*iter
= &msg
->msg_iter
;
416 int peek
= flags
& MSG_PEEK
;
417 struct sk_msg
*msg_rx
;
420 msg_rx
= sk_psock_peek_msg(psock
);
421 while (copied
!= len
) {
422 struct scatterlist
*sge
;
424 if (unlikely(!msg_rx
))
427 i
= msg_rx
->sg
.start
;
432 sge
= sk_msg_elem(msg_rx
, i
);
435 if (copied
+ copy
> len
)
437 copy
= copy_page_to_iter(page
, sge
->offset
, copy
, iter
);
439 copied
= copied
? copied
: -EFAULT
;
448 sk_mem_uncharge(sk
, copy
);
449 msg_rx
->sg
.size
-= copy
;
452 sk_msg_iter_var_next(i
);
457 /* Lets not optimize peek case if copy_page_to_iter
458 * didn't copy the entire length lets just break.
460 if (copy
!= sge
->length
)
462 sk_msg_iter_var_next(i
);
467 } while ((i
!= msg_rx
->sg
.end
) && !sg_is_last(sge
));
469 if (unlikely(peek
)) {
470 msg_rx
= sk_psock_next_msg(psock
, msg_rx
);
476 msg_rx
->sg
.start
= i
;
477 if (!sge
->length
&& (i
== msg_rx
->sg
.end
|| sg_is_last(sge
))) {
478 msg_rx
= sk_psock_dequeue_msg(psock
);
479 kfree_sk_msg(msg_rx
);
481 msg_rx
= sk_psock_peek_msg(psock
);
486 EXPORT_SYMBOL_GPL(sk_msg_recvmsg
);
488 bool sk_msg_is_readable(struct sock
*sk
)
490 struct sk_psock
*psock
;
494 psock
= sk_psock(sk
);
496 empty
= list_empty(&psock
->ingress_msg
);
500 EXPORT_SYMBOL_GPL(sk_msg_is_readable
);
502 static struct sk_msg
*alloc_sk_msg(gfp_t gfp
)
506 msg
= kzalloc(sizeof(*msg
), gfp
| __GFP_NOWARN
);
509 sg_init_marker(msg
->sg
.data
, NR_MSG_FRAG_IDS
);
513 static struct sk_msg
*sk_psock_create_ingress_msg(struct sock
*sk
,
516 if (atomic_read(&sk
->sk_rmem_alloc
) > sk
->sk_rcvbuf
)
519 if (!sk_rmem_schedule(sk
, skb
, skb
->truesize
))
522 return alloc_sk_msg(GFP_KERNEL
);
525 static int sk_psock_skb_ingress_enqueue(struct sk_buff
*skb
,
527 struct sk_psock
*psock
,
533 num_sge
= skb_to_sgvec(skb
, msg
->sg
.data
, off
, len
);
535 /* skb linearize may fail with ENOMEM, but lets simply try again
536 * later if this happens. Under memory pressure we don't want to
537 * drop the skb. We need to linearize the skb so that the mapping
538 * in skb_to_sgvec can not error.
540 if (skb_linearize(skb
))
543 num_sge
= skb_to_sgvec(skb
, msg
->sg
.data
, off
, len
);
544 if (unlikely(num_sge
< 0))
550 msg
->sg
.size
= copied
;
551 msg
->sg
.end
= num_sge
;
554 sk_psock_queue_msg(psock
, msg
);
555 sk_psock_data_ready(sk
, psock
);
559 static int sk_psock_skb_ingress_self(struct sk_psock
*psock
, struct sk_buff
*skb
,
562 static int sk_psock_skb_ingress(struct sk_psock
*psock
, struct sk_buff
*skb
,
565 struct sock
*sk
= psock
->sk
;
569 /* If we are receiving on the same sock skb->sk is already assigned,
570 * skip memory accounting and owner transition seeing it already set
573 if (unlikely(skb
->sk
== sk
))
574 return sk_psock_skb_ingress_self(psock
, skb
, off
, len
);
575 msg
= sk_psock_create_ingress_msg(sk
, skb
);
579 /* This will transition ownership of the data from the socket where
580 * the BPF program was run initiating the redirect to the socket
581 * we will eventually receive this data on. The data will be released
582 * from skb_consume found in __tcp_bpf_recvmsg() after its been copied
585 skb_set_owner_r(skb
, sk
);
586 err
= sk_psock_skb_ingress_enqueue(skb
, off
, len
, psock
, sk
, msg
);
592 /* Puts an skb on the ingress queue of the socket already assigned to the
593 * skb. In this case we do not need to check memory limits or skb_set_owner_r
594 * because the skb is already accounted for here.
596 static int sk_psock_skb_ingress_self(struct sk_psock
*psock
, struct sk_buff
*skb
,
599 struct sk_msg
*msg
= alloc_sk_msg(GFP_ATOMIC
);
600 struct sock
*sk
= psock
->sk
;
605 skb_set_owner_r(skb
, sk
);
606 err
= sk_psock_skb_ingress_enqueue(skb
, off
, len
, psock
, sk
, msg
);
612 static int sk_psock_handle_skb(struct sk_psock
*psock
, struct sk_buff
*skb
,
613 u32 off
, u32 len
, bool ingress
)
618 if (!sock_writeable(psock
->sk
))
620 return skb_send_sock(psock
->sk
, skb
, off
, len
);
623 err
= sk_psock_skb_ingress(psock
, skb
, off
, len
);
629 static void sk_psock_skb_state(struct sk_psock
*psock
,
630 struct sk_psock_work_state
*state
,
633 spin_lock_bh(&psock
->ingress_lock
);
634 if (sk_psock_test_state(psock
, SK_PSOCK_TX_ENABLED
)) {
638 spin_unlock_bh(&psock
->ingress_lock
);
641 static void sk_psock_backlog(struct work_struct
*work
)
643 struct delayed_work
*dwork
= to_delayed_work(work
);
644 struct sk_psock
*psock
= container_of(dwork
, struct sk_psock
, work
);
645 struct sk_psock_work_state
*state
= &psock
->work_state
;
646 struct sk_buff
*skb
= NULL
;
647 u32 len
= 0, off
= 0;
651 mutex_lock(&psock
->work_mutex
);
652 if (unlikely(state
->len
)) {
657 while ((skb
= skb_peek(&psock
->ingress_skb
))) {
660 if (skb_bpf_strparser(skb
)) {
661 struct strp_msg
*stm
= strp_msg(skb
);
666 ingress
= skb_bpf_ingress(skb
);
667 skb_bpf_redirect_clear(skb
);
670 if (!sock_flag(psock
->sk
, SOCK_DEAD
))
671 ret
= sk_psock_handle_skb(psock
, skb
, off
,
674 if (ret
== -EAGAIN
) {
675 sk_psock_skb_state(psock
, state
, len
, off
);
677 /* Delay slightly to prioritize any
678 * other work that might be here.
680 if (sk_psock_test_state(psock
, SK_PSOCK_TX_ENABLED
))
681 schedule_delayed_work(&psock
->work
, 1);
684 /* Hard errors break pipe and stop xmit. */
685 sk_psock_report_error(psock
, ret
? -ret
: EPIPE
);
686 sk_psock_clear_state(psock
, SK_PSOCK_TX_ENABLED
);
693 skb
= skb_dequeue(&psock
->ingress_skb
);
697 mutex_unlock(&psock
->work_mutex
);
700 struct sk_psock
*sk_psock_init(struct sock
*sk
, int node
)
702 struct sk_psock
*psock
;
705 write_lock_bh(&sk
->sk_callback_lock
);
707 if (sk_is_inet(sk
) && inet_csk_has_ulp(sk
)) {
708 psock
= ERR_PTR(-EINVAL
);
712 if (sk
->sk_user_data
) {
713 psock
= ERR_PTR(-EBUSY
);
717 psock
= kzalloc_node(sizeof(*psock
), GFP_ATOMIC
| __GFP_NOWARN
, node
);
719 psock
= ERR_PTR(-ENOMEM
);
723 prot
= READ_ONCE(sk
->sk_prot
);
725 psock
->eval
= __SK_NONE
;
726 psock
->sk_proto
= prot
;
727 psock
->saved_unhash
= prot
->unhash
;
728 psock
->saved_destroy
= prot
->destroy
;
729 psock
->saved_close
= prot
->close
;
730 psock
->saved_write_space
= sk
->sk_write_space
;
732 INIT_LIST_HEAD(&psock
->link
);
733 spin_lock_init(&psock
->link_lock
);
735 INIT_DELAYED_WORK(&psock
->work
, sk_psock_backlog
);
736 mutex_init(&psock
->work_mutex
);
737 INIT_LIST_HEAD(&psock
->ingress_msg
);
738 spin_lock_init(&psock
->ingress_lock
);
739 skb_queue_head_init(&psock
->ingress_skb
);
741 sk_psock_set_state(psock
, SK_PSOCK_TX_ENABLED
);
742 refcount_set(&psock
->refcnt
, 1);
744 __rcu_assign_sk_user_data_with_flags(sk
, psock
,
745 SK_USER_DATA_NOCOPY
|
750 write_unlock_bh(&sk
->sk_callback_lock
);
753 EXPORT_SYMBOL_GPL(sk_psock_init
);
755 struct sk_psock_link
*sk_psock_link_pop(struct sk_psock
*psock
)
757 struct sk_psock_link
*link
;
759 spin_lock_bh(&psock
->link_lock
);
760 link
= list_first_entry_or_null(&psock
->link
, struct sk_psock_link
,
763 list_del(&link
->list
);
764 spin_unlock_bh(&psock
->link_lock
);
768 static void __sk_psock_purge_ingress_msg(struct sk_psock
*psock
)
770 struct sk_msg
*msg
, *tmp
;
772 list_for_each_entry_safe(msg
, tmp
, &psock
->ingress_msg
, list
) {
773 list_del(&msg
->list
);
774 sk_msg_free(psock
->sk
, msg
);
779 static void __sk_psock_zap_ingress(struct sk_psock
*psock
)
783 while ((skb
= skb_dequeue(&psock
->ingress_skb
)) != NULL
) {
784 skb_bpf_redirect_clear(skb
);
785 sock_drop(psock
->sk
, skb
);
787 __sk_psock_purge_ingress_msg(psock
);
790 static void sk_psock_link_destroy(struct sk_psock
*psock
)
792 struct sk_psock_link
*link
, *tmp
;
794 list_for_each_entry_safe(link
, tmp
, &psock
->link
, list
) {
795 list_del(&link
->list
);
796 sk_psock_free_link(link
);
800 void sk_psock_stop(struct sk_psock
*psock
)
802 spin_lock_bh(&psock
->ingress_lock
);
803 sk_psock_clear_state(psock
, SK_PSOCK_TX_ENABLED
);
804 sk_psock_cork_free(psock
);
805 spin_unlock_bh(&psock
->ingress_lock
);
808 static void sk_psock_done_strp(struct sk_psock
*psock
);
810 static void sk_psock_destroy(struct work_struct
*work
)
812 struct sk_psock
*psock
= container_of(to_rcu_work(work
),
813 struct sk_psock
, rwork
);
814 /* No sk_callback_lock since already detached. */
816 sk_psock_done_strp(psock
);
818 cancel_delayed_work_sync(&psock
->work
);
819 __sk_psock_zap_ingress(psock
);
820 mutex_destroy(&psock
->work_mutex
);
822 psock_progs_drop(&psock
->progs
);
824 sk_psock_link_destroy(psock
);
825 sk_psock_cork_free(psock
);
828 sock_put(psock
->sk_redir
);
833 void sk_psock_drop(struct sock
*sk
, struct sk_psock
*psock
)
835 write_lock_bh(&sk
->sk_callback_lock
);
836 sk_psock_restore_proto(sk
, psock
);
837 rcu_assign_sk_user_data(sk
, NULL
);
838 if (psock
->progs
.stream_parser
)
839 sk_psock_stop_strp(sk
, psock
);
840 else if (psock
->progs
.stream_verdict
|| psock
->progs
.skb_verdict
)
841 sk_psock_stop_verdict(sk
, psock
);
842 write_unlock_bh(&sk
->sk_callback_lock
);
844 sk_psock_stop(psock
);
846 INIT_RCU_WORK(&psock
->rwork
, sk_psock_destroy
);
847 queue_rcu_work(system_wq
, &psock
->rwork
);
849 EXPORT_SYMBOL_GPL(sk_psock_drop
);
851 static int sk_psock_map_verd(int verdict
, bool redir
)
855 return redir
? __SK_REDIRECT
: __SK_PASS
;
864 int sk_psock_msg_verdict(struct sock
*sk
, struct sk_psock
*psock
,
867 struct bpf_prog
*prog
;
871 prog
= READ_ONCE(psock
->progs
.msg_parser
);
872 if (unlikely(!prog
)) {
877 sk_msg_compute_data_pointers(msg
);
879 ret
= bpf_prog_run_pin_on_cpu(prog
, msg
);
880 ret
= sk_psock_map_verd(ret
, msg
->sk_redir
);
881 psock
->apply_bytes
= msg
->apply_bytes
;
882 if (ret
== __SK_REDIRECT
) {
883 if (psock
->sk_redir
) {
884 sock_put(psock
->sk_redir
);
885 psock
->sk_redir
= NULL
;
887 if (!msg
->sk_redir
) {
891 psock
->redir_ingress
= sk_msg_to_ingress(msg
);
892 psock
->sk_redir
= msg
->sk_redir
;
893 sock_hold(psock
->sk_redir
);
899 EXPORT_SYMBOL_GPL(sk_psock_msg_verdict
);
901 static int sk_psock_skb_redirect(struct sk_psock
*from
, struct sk_buff
*skb
)
903 struct sk_psock
*psock_other
;
904 struct sock
*sk_other
;
906 sk_other
= skb_bpf_redirect_fetch(skb
);
907 /* This error is a buggy BPF program, it returned a redirect
908 * return code, but then didn't set a redirect interface.
910 if (unlikely(!sk_other
)) {
911 skb_bpf_redirect_clear(skb
);
912 sock_drop(from
->sk
, skb
);
915 psock_other
= sk_psock(sk_other
);
916 /* This error indicates the socket is being torn down or had another
917 * error that caused the pipe to break. We can't send a packet on
918 * a socket that is in this state so we drop the skb.
920 if (!psock_other
|| sock_flag(sk_other
, SOCK_DEAD
)) {
921 skb_bpf_redirect_clear(skb
);
922 sock_drop(from
->sk
, skb
);
925 spin_lock_bh(&psock_other
->ingress_lock
);
926 if (!sk_psock_test_state(psock_other
, SK_PSOCK_TX_ENABLED
)) {
927 spin_unlock_bh(&psock_other
->ingress_lock
);
928 skb_bpf_redirect_clear(skb
);
929 sock_drop(from
->sk
, skb
);
933 skb_queue_tail(&psock_other
->ingress_skb
, skb
);
934 schedule_delayed_work(&psock_other
->work
, 0);
935 spin_unlock_bh(&psock_other
->ingress_lock
);
939 static void sk_psock_tls_verdict_apply(struct sk_buff
*skb
,
940 struct sk_psock
*from
, int verdict
)
944 sk_psock_skb_redirect(from
, skb
);
953 int sk_psock_tls_strp_read(struct sk_psock
*psock
, struct sk_buff
*skb
)
955 struct bpf_prog
*prog
;
959 prog
= READ_ONCE(psock
->progs
.stream_verdict
);
963 skb_bpf_redirect_clear(skb
);
964 ret
= bpf_prog_run_pin_on_cpu(prog
, skb
);
965 ret
= sk_psock_map_verd(ret
, skb_bpf_redirect_fetch(skb
));
968 sk_psock_tls_verdict_apply(skb
, psock
, ret
);
972 EXPORT_SYMBOL_GPL(sk_psock_tls_strp_read
);
974 static int sk_psock_verdict_apply(struct sk_psock
*psock
, struct sk_buff
*skb
,
977 struct sock
*sk_other
;
984 sk_other
= psock
->sk
;
985 if (sock_flag(sk_other
, SOCK_DEAD
) ||
986 !sk_psock_test_state(psock
, SK_PSOCK_TX_ENABLED
))
989 skb_bpf_set_ingress(skb
);
991 /* If the queue is empty then we can submit directly
992 * into the msg queue. If its not empty we have to
993 * queue work otherwise we may get OOO data. Otherwise,
994 * if sk_psock_skb_ingress errors will be handled by
995 * retrying later from workqueue.
997 if (skb_queue_empty(&psock
->ingress_skb
)) {
1000 if (skb_bpf_strparser(skb
)) {
1001 struct strp_msg
*stm
= strp_msg(skb
);
1004 len
= stm
->full_len
;
1006 err
= sk_psock_skb_ingress_self(psock
, skb
, off
, len
);
1009 spin_lock_bh(&psock
->ingress_lock
);
1010 if (sk_psock_test_state(psock
, SK_PSOCK_TX_ENABLED
)) {
1011 skb_queue_tail(&psock
->ingress_skb
, skb
);
1012 schedule_delayed_work(&psock
->work
, 0);
1015 spin_unlock_bh(&psock
->ingress_lock
);
1021 tcp_eat_skb(psock
->sk
, skb
);
1022 err
= sk_psock_skb_redirect(psock
, skb
);
1027 skb_bpf_redirect_clear(skb
);
1028 tcp_eat_skb(psock
->sk
, skb
);
1029 sock_drop(psock
->sk
, skb
);
1035 static void sk_psock_write_space(struct sock
*sk
)
1037 struct sk_psock
*psock
;
1038 void (*write_space
)(struct sock
*sk
) = NULL
;
1041 psock
= sk_psock(sk
);
1042 if (likely(psock
)) {
1043 if (sk_psock_test_state(psock
, SK_PSOCK_TX_ENABLED
))
1044 schedule_delayed_work(&psock
->work
, 0);
1045 write_space
= psock
->saved_write_space
;
1052 #if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
1053 static void sk_psock_strp_read(struct strparser
*strp
, struct sk_buff
*skb
)
1055 struct sk_psock
*psock
;
1056 struct bpf_prog
*prog
;
1057 int ret
= __SK_DROP
;
1062 psock
= sk_psock(sk
);
1063 if (unlikely(!psock
)) {
1067 prog
= READ_ONCE(psock
->progs
.stream_verdict
);
1071 skb_bpf_redirect_clear(skb
);
1072 ret
= bpf_prog_run_pin_on_cpu(prog
, skb
);
1073 skb_bpf_set_strparser(skb
);
1074 ret
= sk_psock_map_verd(ret
, skb_bpf_redirect_fetch(skb
));
1077 sk_psock_verdict_apply(psock
, skb
, ret
);
1082 static int sk_psock_strp_read_done(struct strparser
*strp
, int err
)
1087 static int sk_psock_strp_parse(struct strparser
*strp
, struct sk_buff
*skb
)
1089 struct sk_psock
*psock
= container_of(strp
, struct sk_psock
, strp
);
1090 struct bpf_prog
*prog
;
1094 prog
= READ_ONCE(psock
->progs
.stream_parser
);
1096 skb
->sk
= psock
->sk
;
1097 ret
= bpf_prog_run_pin_on_cpu(prog
, skb
);
1104 /* Called with socket lock held. */
1105 static void sk_psock_strp_data_ready(struct sock
*sk
)
1107 struct sk_psock
*psock
;
1109 trace_sk_data_ready(sk
);
1112 psock
= sk_psock(sk
);
1113 if (likely(psock
)) {
1114 if (tls_sw_has_ctx_rx(sk
)) {
1115 psock
->saved_data_ready(sk
);
1117 write_lock_bh(&sk
->sk_callback_lock
);
1118 strp_data_ready(&psock
->strp
);
1119 write_unlock_bh(&sk
->sk_callback_lock
);
1125 int sk_psock_init_strp(struct sock
*sk
, struct sk_psock
*psock
)
1129 static const struct strp_callbacks cb
= {
1130 .rcv_msg
= sk_psock_strp_read
,
1131 .read_sock_done
= sk_psock_strp_read_done
,
1132 .parse_msg
= sk_psock_strp_parse
,
1135 ret
= strp_init(&psock
->strp
, sk
, &cb
);
1137 sk_psock_set_state(psock
, SK_PSOCK_RX_STRP_ENABLED
);
1142 void sk_psock_start_strp(struct sock
*sk
, struct sk_psock
*psock
)
1144 if (psock
->saved_data_ready
)
1147 psock
->saved_data_ready
= sk
->sk_data_ready
;
1148 sk
->sk_data_ready
= sk_psock_strp_data_ready
;
1149 sk
->sk_write_space
= sk_psock_write_space
;
1152 void sk_psock_stop_strp(struct sock
*sk
, struct sk_psock
*psock
)
1154 psock_set_prog(&psock
->progs
.stream_parser
, NULL
);
1156 if (!psock
->saved_data_ready
)
1159 sk
->sk_data_ready
= psock
->saved_data_ready
;
1160 psock
->saved_data_ready
= NULL
;
1161 strp_stop(&psock
->strp
);
1164 static void sk_psock_done_strp(struct sk_psock
*psock
)
1166 /* Parser has been stopped */
1167 if (sk_psock_test_state(psock
, SK_PSOCK_RX_STRP_ENABLED
))
1168 strp_done(&psock
->strp
);
1171 static void sk_psock_done_strp(struct sk_psock
*psock
)
1174 #endif /* CONFIG_BPF_STREAM_PARSER */
1176 static int sk_psock_verdict_recv(struct sock
*sk
, struct sk_buff
*skb
)
1178 struct sk_psock
*psock
;
1179 struct bpf_prog
*prog
;
1180 int ret
= __SK_DROP
;
1184 psock
= sk_psock(sk
);
1185 if (unlikely(!psock
)) {
1187 tcp_eat_skb(sk
, skb
);
1191 prog
= READ_ONCE(psock
->progs
.stream_verdict
);
1193 prog
= READ_ONCE(psock
->progs
.skb_verdict
);
1196 skb_bpf_redirect_clear(skb
);
1197 ret
= bpf_prog_run_pin_on_cpu(prog
, skb
);
1198 ret
= sk_psock_map_verd(ret
, skb_bpf_redirect_fetch(skb
));
1200 ret
= sk_psock_verdict_apply(psock
, skb
, ret
);
1208 static void sk_psock_verdict_data_ready(struct sock
*sk
)
1210 struct socket
*sock
= sk
->sk_socket
;
1211 const struct proto_ops
*ops
;
1214 trace_sk_data_ready(sk
);
1216 if (unlikely(!sock
))
1218 ops
= READ_ONCE(sock
->ops
);
1219 if (!ops
|| !ops
->read_skb
)
1221 copied
= ops
->read_skb(sk
, sk_psock_verdict_recv
);
1223 struct sk_psock
*psock
;
1226 psock
= sk_psock(sk
);
1228 psock
->saved_data_ready(sk
);
1233 void sk_psock_start_verdict(struct sock
*sk
, struct sk_psock
*psock
)
1235 if (psock
->saved_data_ready
)
1238 psock
->saved_data_ready
= sk
->sk_data_ready
;
1239 sk
->sk_data_ready
= sk_psock_verdict_data_ready
;
1240 sk
->sk_write_space
= sk_psock_write_space
;
1243 void sk_psock_stop_verdict(struct sock
*sk
, struct sk_psock
*psock
)
1245 psock_set_prog(&psock
->progs
.stream_verdict
, NULL
);
1246 psock_set_prog(&psock
->progs
.skb_verdict
, NULL
);
1248 if (!psock
->saved_data_ready
)
1251 sk
->sk_data_ready
= psock
->saved_data_ready
;
1252 psock
->saved_data_ready
= NULL
;