// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

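/* Copy data queued on the psock's ingress_msg list into the caller's
 * msghdr, up to @len bytes. Unless MSG_PEEK is set, the copied bytes are
 * uncharged from the socket's receive memory and fully consumed messages
 * are unlinked and freed. Returns the number of bytes copied.
 */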
int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	int i, ret, copied = 0;
	struct sk_msg *msg_rx;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}

		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

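/* Redirect up to @apply_bytes of @msg to this socket's own ingress queue.
 * The covered scatterlist entries are transferred into a freshly allocated
 * sk_msg, charged to the receive memory of @sk, and the psock is marked
 * data-ready so a local reader can consume them.
 */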
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

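/* Push up to @apply_bytes of @msg out through the TCP stack, starting at
 * the scatterlist head. If a software TLS TX context is present, the pages
 * go through kernel_sendpage_locked() with MSG_SENDPAGE_NOPOLICY set so
 * they are not run through the BPF policy again; otherwise plain
 * do_tcp_sendpages() is used. Short sends are retried until each element
 * is fully consumed.
 */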
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off  = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off  += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

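/* Send path used for SK_REDIRECT verdicts: deliver @bytes of @msg either
 * into the target socket's ingress queue or out through its TCP send path,
 * depending on the ingress flag recorded in the sk_msg.
 */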
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}
	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_STREAM_PARSER

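/* ->stream_memory_read() replacement: report readable data whenever the
 * psock ingress queue is non-empty, not only when the TCP receive queue
 * has data.
 */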
static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}

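/* Wait until the psock ingress queue or the TCP receive queue has data,
 * or the timeout expires. Returns the value of the wait condition at
 * wakeup (nonzero when data became available).
 */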
static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

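/* recvmsg() handler installed in tcp_bpf_prots. Data queued on the psock
 * is consumed first via __tcp_bpf_recvmsg(); if no psock is attached, or
 * only the normal receive queue holds data, the call falls back to
 * tcp_recvmsg().
 */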
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

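/* Run the msg verdict program (once per apply window) and act on the
 * result: __SK_PASS pushes the data to the TCP stack, __SK_REDIRECT hands
 * it to tcp_bpf_sendmsg_redir() for the target socket, and __SK_DROP frees
 * it and returns -EACCES. While corking, the sk_msg is stashed on
 * psock->cork until cork_bytes have been gathered.
 */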
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg);
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}

	return ret;
}

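/* sendmsg() handler installed in tcp_bpf_prots when a msg parser program
 * is attached. User data is copied into an sk_msg (or the pending cork
 * buffer) and then run through tcp_bpf_send_verdict().
 */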
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal do_tcp_sendpages() flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

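/* sendpage() handler: attach the page to an sk_msg (or the pending cork
 * buffer) without copying and run it through the same verdict path as
 * tcp_bpf_sendmsg().
 */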
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

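/* Build the tcp_bpf proto templates from a base proto: TCP_BPF_BASE
 * overrides unhash/close/recvmsg and the readable-data check, while
 * TCP_BPF_TX additionally overrides sendmsg/sendpage for sockets that
 * have a msg parser program attached.
 */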
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE]			= *base;
	prot[TCP_BPF_BASE].unhash		= sock_map_unhash;
	prot[TCP_BPF_BASE].close		= sock_map_close;
	prot[TCP_BPF_BASE].recvmsg		= tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read	= tcp_bpf_stream_read;

	prot[TCP_BPF_TX]			= prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg		= tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;
}

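/* tcp_prot for IPv4 is available at build time, but the IPv6 proto ops
 * may come from a module. Capture them the first time an IPv6 socket is
 * seen and rebuild the IPv6 slot of tcp_bpf_prots under tcpv6_prot_lock.
 */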
static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg  == tcp_recvmsg &&
	       ops->sendmsg  == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

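/* Pick the proto template for this psock: IPv4 vs IPv6, and the TX
 * variant when a msg parser program is attached. On first use, verify
 * the socket still has the stock tcp callbacks that the fallback paths
 * above call directly.
 */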
struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

	if (!psock->sk_proto) {
		struct proto *ops = READ_ONCE(sk->sk_prot);

		if (tcp_bpf_assert_proto_ops(ops))
			return ERR_PTR(-EINVAL);

		tcp_bpf_check_v6_needs_rebuild(sk, ops);
	}

	return &tcp_bpf_prots[family][config];
}

/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	struct proto *prot = newsk->sk_prot;

	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
		newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_STREAM_PARSER */