// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>
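/* stream_memory_read() hook: lets tcp_poll() report readable data that sits
 * on the psock ingress list rather than on sk_receive_queue.
 */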
static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}
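/* Block until either the psock ingress list or the regular receive queue has
 * data, or the receive timeout expires.
 */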
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}
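/* Copy data queued on the psock ingress list into the user iov. Consumed
 * scatterlist entries are uncharged and released unless MSG_PEEK is set.
 */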
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	int i, ret, copied = 0;
	struct sk_msg *msg_rx;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
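/* recvmsg() replacement installed on psock-enabled sockets. Sockets without a
 * psock, or with data still on the regular receive queue, fall back to
 * tcp_recvmsg().
 */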
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (skb_queue_empty(&sk->sk_receive_queue))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}
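/* SK_REDIRECT to ingress: move scatterlist entries from @msg onto the target
 * psock's ingress list so a local receiver can read them without another trip
 * through the TCP stack.
 */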
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}
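/* Push up to @apply_bytes of @msg out through the TCP stack. When a TLS TX
 * context is active the locked sendpage path is used with
 * MSG_SENDPAGE_NOPOLICY so the redirected data is not run through the BPF
 * policy a second time.
 */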
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off  = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off  += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}
static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}
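/* Entry point for SK_REDIRECT verdicts: deliver @msg either to the target
 * socket's ingress queue or out through its egress path.
 */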
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}

	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
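/* Run the attached SK_MSG verdict program over @msg and carry out the result:
 * __SK_PASS pushes to the local TCP stack, __SK_REDIRECT hands the data to
 * another socket, __SK_DROP frees it and reports -EACCES. Corking and
 * apply_bytes state set via bpf_msg_cork_bytes()/bpf_msg_apply_bytes() is
 * honoured here as well.
 */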
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = msg->sg.start == msg->sg.end;
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		if (msg->sg.size < delta)
			delta -= msg->sg.size;
		else
			delta = 0;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}
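/* sendmsg() replacement: copy user data into an sk_msg (the cork buffer if
 * one is active), honour cork_bytes, then hand the message to
 * tcp_bpf_send_verdict() for the BPF decision.
 */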
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}
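/* sendpage() replacement: append the page to an sk_msg (the cork buffer if
 * one is active) and run the same verdict path as tcp_bpf_sendmsg().
 */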
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}
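/* Teardown path: tcp_bpf_remove() drops all psock links; the unhash and close
 * handlers below then fall back to the proto callbacks saved at attach time.
 */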
static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sk_psock_unlink(sk, link);
		sk_psock_free_link(link);
	}
}
static void tcp_bpf_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}
static void tcp_bpf_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}
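/* One proto instance per (address family, config) pair. TCP_BPF_BASE carries
 * the receive-side hooks only; TCP_BPF_TX additionally overrides sendmsg and
 * sendpage when an SK_MSG (msg_parser) program is attached. The index enums
 * below are restored here; the fragments use them but their definitions were
 * lost in extraction.
 */
enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};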
static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE]			= *base;
	prot[TCP_BPF_BASE].unhash		= tcp_bpf_unhash;
	prot[TCP_BPF_BASE].close		= tcp_bpf_close;
	prot[TCP_BPF_BASE].recvmsg		= tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read	= tcp_bpf_stream_read;

	prot[TCP_BPF_TX]			= prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg		= tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;
}
static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}
static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);
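/* Select the proto variant for this socket: the address family picks the row,
 * the presence of an attached msg_parser program picks BASE vs TX.
 */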
static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}
static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
	 * or added requiring sk_prot hook updates. We keep original saved
	 * hooks in this case.
	 */
	sk->sk_prot = &tcp_bpf_prots[family][config];
}
static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg  == tcp_recvmsg &&
	       ops->sendmsg  == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}
void tcp_bpf_reinit(struct sock *sk)
{
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	tcp_bpf_reinit_sk_prot(sk, psock);
	rcu_read_unlock();
}
int tcp_bpf_init(struct sock *sk)
{
	struct proto *ops = READ_ONCE(sk->sk_prot);
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock || psock->sk_proto ||
		     tcp_bpf_assert_proto_ops(ops))) {
		rcu_read_unlock();
		return -EINVAL;
	}
	tcp_bpf_check_v6_needs_rebuild(sk, ops);
	tcp_bpf_update_sk_prot(sk, psock);
	rcu_read_unlock();
	return 0;
}
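
/* Illustrative only, not part of this file: a minimal userspace sketch of how
 * these hooks typically get engaged. The fds below are assumed to exist
 * already (map_fd for a BPF_MAP_TYPE_SOCKMAP, prog_fd for an SK_MSG verdict
 * program, sock_fd for a connected TCP socket).
 *
 *	// Attach the SK_MSG verdict program to the sockmap.
 *	bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
 *
 *	// Adding the socket to the map creates its psock and ends up calling
 *	// tcp_bpf_init(), which swaps in the tcp_bpf_prots hooks above.
 *	int key = 0;
 *	bpf_map_update_elem(map_fd, &key, &sock_fd, BPF_ANY);
 */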