/* SPDX-License-Identifier: GPL-2.0 */
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
#ifndef _LINUX_SKMSG_H
#define _LINUX_SKMSG_H

#include <linux/filter.h>
#include <linux/scatterlist.h>
#include <linux/skbuff.h>

#include <net/sock.h>
#include <net/strparser.h>

#define MAX_MSG_FRAGS			MAX_SKB_FRAGS
#define NR_MSG_FRAG_IDS			(MAX_MSG_FRAGS + 1)
struct sk_msg_sg {
	u32				start;
	u32				curr;
	u32				end;
	u32				size;
	u32				copybreak;
	DECLARE_BITMAP(copy, MAX_MSG_FRAGS + 2);
	/* The extra two elements:
	 * 1) used for chaining the front and sections when the list becomes
	 *    partitioned (e.g. end < start). The crypto APIs require the
	 *    chaining;
	 * 2) to chain tailer SG entries after the message.
	 */
	struct scatterlist		data[MAX_MSG_FRAGS + 2];
};

/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
struct sk_msg {
	struct sk_msg_sg		sg;
	void				*data;
	void				*data_end;
	u32				apply_bytes;
	u32				cork_bytes;
	u32				flags;
	struct sk_buff			*skb;
	struct sock			*sk_redir;
	struct sock			*sk;
	struct list_head		list;
};
struct sk_psock_progs {
	struct bpf_prog			*msg_parser;
	struct bpf_prog			*stream_parser;
	struct bpf_prog			*stream_verdict;
	struct bpf_prog			*skb_verdict;
	struct bpf_link			*msg_parser_link;
	struct bpf_link			*stream_parser_link;
	struct bpf_link			*stream_verdict_link;
	struct bpf_link			*skb_verdict_link;
};
enum sk_psock_state_bits {
	SK_PSOCK_TX_ENABLED,
	SK_PSOCK_RX_STRP_ENABLED,
};
struct sk_psock_link {
	struct list_head		list;
	struct bpf_map			*map;
	void				*link_raw;
};
struct sk_psock_work_state {
	u32				len;
	u32				off;
};
struct sk_psock {
	struct sock			*sk;
	struct sock			*sk_redir;
	u32				apply_bytes;
	u32				cork_bytes;
	u32				eval;
	bool				redir_ingress; /* undefined if sk_redir is null */
	struct sk_msg			*cork;
	struct sk_psock_progs		progs;
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
	struct strparser		strp;
#endif
	struct sk_buff_head		ingress_skb;
	struct list_head		ingress_msg;
	spinlock_t			ingress_lock;
	unsigned long			state;
	struct list_head		link;
	spinlock_t			link_lock;
	refcount_t			refcnt;
	void (*saved_unhash)(struct sock *sk);
	void (*saved_destroy)(struct sock *sk);
	void (*saved_close)(struct sock *sk, long timeout);
	void (*saved_write_space)(struct sock *sk);
	void (*saved_data_ready)(struct sock *sk);
	/* psock_update_sk_prot may be called with restore=false many times
	 * so the handler must be safe for this case. It will be called
	 * exactly once with restore=true when the psock is being destroyed
	 * and psock refcnt is zero, but before an RCU grace period.
	 */
	int (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
				    bool restore);
	struct proto			*sk_proto;
	struct mutex			work_mutex;
	struct sk_psock_work_state	work_state;
	struct delayed_work		work;
	struct sock			*sk_pair;
	struct rcu_work			rwork;
};
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce);
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len);
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
int sk_msg_free(struct sock *sk, struct sk_msg *msg);
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes);

void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);

int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes);
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
		   int len, int flags);
bool sk_msg_is_readable(struct sock *sk);
static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
{
	WARN_ON(i == msg->sg.end && bytes);
}
static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
{
	if (psock->apply_bytes) {
		if (psock->apply_bytes < bytes)
			psock->apply_bytes = 0;
		else
			psock->apply_bytes -= bytes;
	}
}
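/* For example, with psock->apply_bytes == 10, sk_msg_apply_bytes(psock, 4)
 * leaves 6 bytes still to apply, while any bytes >= 10 clamps apply_bytes
 * to 0.
 */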
static inline u32 sk_msg_iter_dist(u32 start, u32 end)
{
	return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
}
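/* Worked example (assuming MAX_SKB_FRAGS == 17, so NR_MSG_FRAG_IDS == 18):
 * ring indices live in [0, NR_MSG_FRAG_IDS) and may wrap, so start == 16,
 * end == 2 covers elements 16, 17, 0, 1 and
 * sk_msg_iter_dist(16, 2) == 2 + (18 - 16) == 4.
 */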
#define sk_msg_iter_var_prev(var)			\
	do {						\
		if (var == 0)				\
			var = NR_MSG_FRAG_IDS - 1;	\
		else					\
			var--;				\
	} while (0)

#define sk_msg_iter_var_next(var)			\
	do {						\
		var++;					\
		if (var == NR_MSG_FRAG_IDS)		\
			var = 0;			\
	} while (0)

#define sk_msg_iter_prev(msg, which)			\
	sk_msg_iter_var_prev(msg->sg.which)

#define sk_msg_iter_next(msg, which)			\
	sk_msg_iter_var_next(msg->sg.which)
static inline void sk_msg_init(struct sk_msg *msg)
{
	BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
	memset(msg, 0, sizeof(*msg));
	sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
}
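/* Usage sketch: zero the message and place the scatterlist end marker
 * before adding elements, e.g.
 *
 *	struct sk_msg msg;
 *
 *	sk_msg_init(&msg);
 */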
static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
			       int which, u32 size)
{
	dst->sg.data[which] = src->sg.data[which];
	dst->sg.data[which].length = size;
	dst->sg.size += size;
	src->sg.size -= size;
	src->sg.data[which].length -= size;
	src->sg.data[which].offset += size;
}
static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
{
	memcpy(dst, src, sizeof(*src));
	sk_msg_init(src);
}
static inline bool sk_msg_full(const struct sk_msg *msg)
{
	return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
}

static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{
	return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
}
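/* Because the ring has NR_MSG_FRAG_IDS slots but holds at most MAX_MSG_FRAGS
 * elements, a full message never wraps around to start == end: a distance of
 * 0 always means "empty" and a distance of MAX_MSG_FRAGS means "full".
 */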
static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
{
	return &msg->sg.data[which];
}

static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
{
	return msg->sg.data[which];
}

static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
{
	return sg_page(sk_msg_elem(msg, which));
}
static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
{
	return msg->flags & BPF_F_INGRESS;
}
static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
{
	struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);

	if (test_bit(msg->sg.start, msg->sg.copy)) {
		msg->data = NULL;
		msg->data_end = NULL;
	} else {
		msg->data = sg_virt(sge);
		msg->data_end = msg->data + sge->length;
	}
}
static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
				   u32 len, u32 offset)
{
	struct scatterlist *sge;

	sge = sk_msg_elem(msg, msg->sg.end);
	sg_set_page(sge, page, len, offset);
	sg_unmark_end(sge);

	__set_bit(msg->sg.end, msg->sg.copy);
	msg->sg.size += len;
	sk_msg_iter_next(msg, end);
}
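/* Usage sketch (socket memory accounting and error handling omitted):
 *
 *	if (!sk_msg_full(msg))
 *		sk_msg_page_add(msg, page, len, offset);
 *
 * The element is placed at sg.end and the end index advances with wraparound.
 */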
static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
{
	do {
		if (copy_state)
			__set_bit(i, msg->sg.copy);
		else
			__clear_bit(i, msg->sg.copy);
		sk_msg_iter_var_next(i);
		if (i == msg->sg.end)
			break;
	} while (1);
}
static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
{
	sk_msg_sg_copy(msg, start, true);
}

static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
{
	sk_msg_sg_copy(msg, start, false);
}
static inline struct sk_psock *sk_psock(const struct sock *sk)
{
	return __rcu_dereference_sk_user_data_with_flags(sk,
							 SK_USER_DATA_PSOCK);
}
static inline void sk_psock_set_state(struct sk_psock *psock,
				      enum sk_psock_state_bits bit)
{
	set_bit(bit, &psock->state);
}

static inline void sk_psock_clear_state(struct sk_psock *psock,
					enum sk_psock_state_bits bit)
{
	clear_bit(bit, &psock->state);
}

static inline bool sk_psock_test_state(const struct sk_psock *psock,
				       enum sk_psock_state_bits bit)
{
	return test_bit(bit, &psock->state);
}
static inline void sock_drop(struct sock *sk, struct sk_buff *skb)
{
	sk_drops_add(sk, skb);
	kfree_skb(skb);
}
static inline bool sk_psock_queue_msg(struct sk_psock *psock,
				      struct sk_msg *msg)
{
	bool ret;

	spin_lock_bh(&psock->ingress_lock);
	if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
		list_add_tail(&msg->list, &psock->ingress_msg);
		ret = true;
	} else {
		sk_msg_free(psock->sk, msg);
		kfree(msg);
		ret = false;
	}
	spin_unlock_bh(&psock->ingress_lock);
	return ret;
}
static inline struct sk_msg *sk_psock_dequeue_msg(struct sk_psock *psock)
{
	struct sk_msg *msg;

	spin_lock_bh(&psock->ingress_lock);
	msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
	if (msg)
		list_del(&msg->list);
	spin_unlock_bh(&psock->ingress_lock);
	return msg;
}
static inline struct sk_msg *sk_psock_peek_msg(struct sk_psock *psock)
{
	struct sk_msg *msg;

	spin_lock_bh(&psock->ingress_lock);
	msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
	spin_unlock_bh(&psock->ingress_lock);
	return msg;
}
static inline struct sk_msg *sk_psock_next_msg(struct sk_psock *psock,
					       struct sk_msg *msg)
{
	struct sk_msg *ret;

	spin_lock_bh(&psock->ingress_lock);
	if (list_is_last(&msg->list, &psock->ingress_msg))
		ret = NULL;
	else
		ret = list_next_entry(msg, list);
	spin_unlock_bh(&psock->ingress_lock);
	return ret;
}
static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
{
	return psock ? list_empty(&psock->ingress_msg) : true;
}
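/* A sketch of how a reader typically walks the ingress queue
 * (sk_msg_recvmsg() is the real consumer):
 *
 *	msg = sk_psock_peek_msg(psock);
 *	while (msg) {
 *		... copy data out of msg->sg ...
 *		msg = sk_psock_next_msg(psock, msg);
 *	}
 */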
static inline void kfree_sk_msg(struct sk_msg *msg)
{
	if (msg->skb)
		consume_skb(msg->skb);
	kfree(msg);
}
static inline void sk_psock_report_error(struct sk_psock *psock, int err)
{
	struct sock *sk = psock->sk;

	sk->sk_err = err;
	sk_error_report(sk);
}
struct sk_psock *sk_psock_init(struct sock *sk, int node);
void sk_psock_stop(struct sk_psock *psock);
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
#else
static inline int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	return -EOPNOTSUPP;
}

static inline void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
}

static inline void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
}
#endif
void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock);

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg);
/*
 * This specialized allocator has to be a macro for its allocations to be
 * accounted separately (to have a separate alloc_tag). The typecast is
 * intentional to enforce typesafety.
 */
#define sk_psock_init_link()						\
		((struct sk_psock_link *)kzalloc(sizeof(struct sk_psock_link),	\
						 GFP_ATOMIC | __GFP_NOWARN))
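/* Usage sketch:
 *
 *	struct sk_psock_link *link = sk_psock_init_link();
 *
 *	if (!link)
 *		return -ENOMEM;
 */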
static inline void sk_psock_free_link(struct sk_psock_link *link)
{
	kfree(link);
}

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
static inline void sk_psock_cork_free(struct sk_psock *psock)
{
	if (psock->cork) {
		sk_msg_free(psock->sk, psock->cork);
		kfree(psock->cork);
		psock->cork = NULL;
	}
}
static inline void sk_psock_restore_proto(struct sock *sk,
					  struct sk_psock *psock)
{
	if (psock->psock_update_sk_prot)
		psock->psock_update_sk_prot(sk, psock, true);
}
static inline struct sk_psock *sk_psock_get(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock && !refcount_inc_not_zero(&psock->refcnt))
		psock = NULL;
	rcu_read_unlock();
	return psock;
}
void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
{
	if (refcount_dec_and_test(&psock->refcnt))
		sk_psock_drop(sk, psock);
}
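/* Typical lookup pattern, a sketch:
 *
 *	psock = sk_psock_get(sk);
 *	if (psock) {
 *		... use psock ...
 *		sk_psock_put(sk, psock);
 *	}
 *
 * sk_psock_get() only succeeds while the refcount is non-zero, so the
 * caller holds a reference until the matching sk_psock_put().
 */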
static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
{
	read_lock_bh(&sk->sk_callback_lock);
	if (psock->saved_data_ready)
		psock->saved_data_ready(sk);
	else
		sk->sk_data_ready(sk);
	read_unlock_bh(&sk->sk_callback_lock);
}
static inline void psock_set_prog(struct bpf_prog **pprog,
				  struct bpf_prog *prog)
{
	prog = xchg(pprog, prog);
	if (prog)
		bpf_prog_put(prog);
}
static inline int psock_replace_prog(struct bpf_prog **pprog,
				     struct bpf_prog *prog,
				     struct bpf_prog *old)
{
	if (cmpxchg(pprog, old, prog) != old)
		return -ENOENT;

	if (old)
		bpf_prog_put(old);

	return 0;
}
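/* Usage sketch ("new" and "old" stand for whatever programs the caller
 * holds): replace a verdict program only if the slot still contains the
 * expected old program, e.g.
 *
 *	err = psock_replace_prog(&progs->stream_verdict, new, old);
 *
 * A non-zero return means the slot changed under the caller.
 */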
static inline void psock_progs_drop(struct sk_psock_progs *progs)
{
	psock_set_prog(&progs->msg_parser, NULL);
	psock_set_prog(&progs->stream_parser, NULL);
	psock_set_prog(&progs->stream_verdict, NULL);
	psock_set_prog(&progs->skb_verdict, NULL);
}
int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb);
static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
{
	if (!psock)
		return false;
	return !!psock->saved_data_ready;
}
#if IS_ENABLED(CONFIG_NET_SOCK_MSG)

#define BPF_F_STRPARSER	(1UL << 1)

/* We only have two bits so far. */
#define BPF_F_PTR_MASK	~(BPF_F_INGRESS | BPF_F_STRPARSER)
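/* A sketch of the tagging scheme used below: skb->_sk_redir holds a
 * struct sock pointer with the two flag bits OR-ed into its low bits
 * (which are zero for a suitably aligned pointer), e.g.
 *
 *	skb->_sk_redir = (unsigned long)sk | BPF_F_INGRESS;
 *	sk = (struct sock *)(skb->_sk_redir & BPF_F_PTR_MASK);
 */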
static inline bool skb_bpf_strparser(const struct sk_buff *skb)
{
	unsigned long sk_redir = skb->_sk_redir;

	return sk_redir & BPF_F_STRPARSER;
}

static inline void skb_bpf_set_strparser(struct sk_buff *skb)
{
	skb->_sk_redir |= BPF_F_STRPARSER;
}
static inline bool skb_bpf_ingress(const struct sk_buff *skb)
{
	unsigned long sk_redir = skb->_sk_redir;

	return sk_redir & BPF_F_INGRESS;
}

static inline void skb_bpf_set_ingress(struct sk_buff *skb)
{
	skb->_sk_redir |= BPF_F_INGRESS;
}
static inline void skb_bpf_set_redir(struct sk_buff *skb, struct sock *sk_redir,
				     bool ingress)
{
	skb->_sk_redir = (unsigned long)sk_redir;
	if (ingress)
		skb->_sk_redir |= BPF_F_INGRESS;
}
static inline struct sock *skb_bpf_redirect_fetch(const struct sk_buff *skb)
{
	unsigned long sk_redir = skb->_sk_redir;

	return (struct sock *)(sk_redir & BPF_F_PTR_MASK);
}
static inline void skb_bpf_redirect_clear(struct sk_buff *skb)
{
	skb->_sk_redir = 0;
}
#endif /* CONFIG_NET_SOCK_MSG */
#endif /* _LINUX_SKMSG_H */