/* SPDX-License-Identifier: GPL-2.0 */
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#ifndef _LINUX_SKMSG_H
#define _LINUX_SKMSG_H

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/scatterlist.h>
#include <linux/skbuff.h>

#include <net/sock.h>
#include <net/tcp.h>
#include <net/strparser.h>

#define MAX_MSG_FRAGS			MAX_SKB_FRAGS
#define NR_MSG_FRAG_IDS			(MAX_MSG_FRAGS + 1)

struct sk_msg_sg {
	u32				start;
	u32				curr;
	u32				end;
	u32				size;
	u32				copybreak;
	unsigned long			copy;
	/* The extra two elements:
	 * 1) used for chaining the front and sections when the list becomes
	 *    partitioned (e.g. end < start). The crypto APIs require the
	 *    chaining;
	 * 2) to chain tailer SG entries after the message.
	 */
	struct scatterlist		data[MAX_MSG_FRAGS + 2];
};
static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS);

/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
struct sk_msg {
	struct sk_msg_sg		sg;
	void				*data;
	void				*data_end;
	u32				apply_bytes;
	u32				cork_bytes;
	u32				flags;
	struct sk_buff			*skb;
	struct sock			*sk_redir;
	struct sock			*sk;
	struct list_head		list;
};

struct sk_psock_progs {
	struct bpf_prog			*msg_parser;
	struct bpf_prog			*skb_parser;
	struct bpf_prog			*skb_verdict;
};

enum sk_psock_state_bits {
	SK_PSOCK_TX_ENABLED,
};

struct sk_psock_link {
	struct list_head		list;
	struct bpf_map			*map;
	void				*link_raw;
};

struct sk_psock_parser {
	struct strparser		strp;
	bool				enabled;
	void (*saved_data_ready)(struct sock *sk);
};

struct sk_psock_work_state {
	struct sk_buff			*skb;
	u32				len;
	u32				off;
};

struct sk_psock {
	struct sock			*sk;
	struct sock			*sk_redir;
	u32				apply_bytes;
	u32				cork_bytes;
	u32				eval;
	struct sk_msg			*cork;
	struct sk_psock_progs		progs;
	struct sk_psock_parser		parser;
	struct sk_buff_head		ingress_skb;
	struct list_head		ingress_msg;
	unsigned long			state;
	struct list_head		link;
	spinlock_t			link_lock;
	refcount_t			refcnt;
	void (*saved_unhash)(struct sock *sk);
	void (*saved_close)(struct sock *sk, long timeout);
	void (*saved_write_space)(struct sock *sk);
	struct proto			*sk_proto;
	struct sk_psock_work_state	work_state;
	struct work_struct		work;
	union {
		struct rcu_head		rcu;
		struct work_struct	gc;
	};
};

int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce);
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len);
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
int sk_msg_free(struct sock *sk, struct sk_msg *msg);
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes);

void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);

int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes);
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes);
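
/* Sketch (illustrative only, not part of the original header): the
 * typical charge/copy/free cycle on the egress path, loosely modeled
 * on how the tcp_bpf sendmsg path uses the helpers above.
 *
 *	struct sk_msg msg;
 *
 *	sk_msg_init(&msg);
 *	if (sk_msg_alloc(sk, &msg, len, 0))
 *		return -ENOMEM;		(or wait for memory)
 *	if (sk_msg_memcopy_from_iter(sk, from, &msg, len))
 *		sk_msg_free(sk, &msg);	(uncharges and drops the pages)
 */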

static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
{
	WARN_ON(i == msg->sg.end && bytes);
}

static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
{
	if (psock->apply_bytes) {
		if (psock->apply_bytes < bytes)
			psock->apply_bytes = 0;
		else
			psock->apply_bytes -= bytes;
	}
}

static inline u32 sk_msg_iter_dist(u32 start, u32 end)
{
	return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
}
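
/* Worked example (added for clarity, not in the original header): with
 * the common MAX_SKB_FRAGS value of 17, NR_MSG_FRAG_IDS is 18 and ring
 * indices run 0..17.  A wrapped message with start = 16 and end = 2
 * spans sk_msg_iter_dist(16, 2) == 2 + (18 - 16) == 4 elements.
 */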

#define sk_msg_iter_var_prev(var)			\
	do {						\
		if (var == 0)				\
			var = NR_MSG_FRAG_IDS - 1;	\
		else					\
			var--;				\
	} while (0)

#define sk_msg_iter_var_next(var)			\
	do {						\
		var++;					\
		if (var == NR_MSG_FRAG_IDS)		\
			var = 0;			\
	} while (0)

#define sk_msg_iter_prev(msg, which)			\
	sk_msg_iter_var_prev(msg->sg.which)

#define sk_msg_iter_next(msg, which)			\
	sk_msg_iter_var_next(msg->sg.which)
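
/* Usage sketch (illustrative only): walking every used element with the
 * ring iterators above.  sk_msg_elem() is defined further down; used
 * elements always lie between sg.start and sg.end, modulo the wrap.
 *
 *	u32 i = msg->sg.start;
 *
 *	while (i != msg->sg.end) {
 *		struct scatterlist *sge = sk_msg_elem(msg, i);
 *
 *		... operate on sge ...
 *		sk_msg_iter_var_next(i);
 *	}
 */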

static inline void sk_msg_clear_meta(struct sk_msg *msg)
{
	memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy));
}

static inline void sk_msg_init(struct sk_msg *msg)
{
	BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
	memset(msg, 0, sizeof(*msg));
	sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
}

static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
			       int which, u32 size)
{
	dst->sg.data[which] = src->sg.data[which];
	dst->sg.data[which].length  = size;
	dst->sg.size		   += size;
	src->sg.data[which].length -= size;
	src->sg.data[which].offset += size;
}

static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
{
	memcpy(dst, src, sizeof(*src));
	sk_msg_init(src);
}

static inline bool sk_msg_full(const struct sk_msg *msg)
{
	return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
}

static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{
	return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
}

static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
{
	return &msg->sg.data[which];
}

static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
{
	return msg->sg.data[which];
}

static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
{
	return sg_page(sk_msg_elem(msg, which));
}

static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
{
	return msg->flags & BPF_F_INGRESS;
}

static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
{
	struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);

	if (test_bit(msg->sg.start, &msg->sg.copy)) {
		msg->data = NULL;
		msg->data_end = NULL;
	} else {
		msg->data = sg_virt(sge);
		msg->data_end = msg->data + sge->length;
	}
}
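
/* Note (explanatory, not from the original header): a BPF sk_msg program
 * may only dereference bytes inside [data, data_end).  When the first
 * element's copy bit is set, its data cannot be exposed to the program
 * directly, both pointers are NULL, and the program has to call
 * bpf_msg_pull_data() to make the range it wants directly accessible.
 */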

static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
				   u32 len, u32 offset)
{
	struct scatterlist *sge;

	get_page(page);
	sge = sk_msg_elem(msg, msg->sg.end);
	sg_set_page(sge, page, len, offset);
	sg_unmark_end(sge);

	__set_bit(msg->sg.end, &msg->sg.copy);
	msg->sg.size += len;
	sk_msg_iter_next(msg, end);
}
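
/* Design note (added): sk_msg_page_add() takes its own reference on the
 * page via get_page(), so the caller's reference is untouched; the
 * message's reference is dropped later when the element is freed, e.g.
 * by sk_msg_free().
 */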

static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
{
	do {
		if (copy_state)
			__set_bit(i, &msg->sg.copy);
		else
			__clear_bit(i, &msg->sg.copy);
		sk_msg_iter_var_next(i);
		if (i == msg->sg.end)
			break;
	} while (1);
}

static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
{
	sk_msg_sg_copy(msg, start, true);
}

static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
{
	sk_msg_sg_copy(msg, start, false);
}

static inline struct sk_psock *sk_psock(const struct sock *sk)
{
	return rcu_dereference_sk_user_data(sk);
}

static inline void sk_psock_queue_msg(struct sk_psock *psock,
				      struct sk_msg *msg)
{
	list_add_tail(&msg->list, &psock->ingress_msg);
}

static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
{
	return psock ? list_empty(&psock->ingress_msg) : true;
}

static inline void sk_psock_report_error(struct sk_psock *psock, int err)
{
	struct sock *sk = psock->sk;

	sk->sk_err = err;
	sk->sk_error_report(sk);
}

struct sk_psock *sk_psock_init(struct sock *sk, int node);

int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg);

static inline struct sk_psock_link *sk_psock_init_link(void)
{
	return kzalloc(sizeof(struct sk_psock_link),
		       GFP_ATOMIC | __GFP_NOWARN);
}

static inline void sk_psock_free_link(struct sk_psock_link *link)
{
	kfree(link);
}

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);

void __sk_psock_purge_ingress_msg(struct sk_psock *psock);

static inline void sk_psock_cork_free(struct sk_psock *psock)
{
	if (psock->cork) {
		sk_msg_free(psock->sk, psock->cork);
		kfree(psock->cork);
		psock->cork = NULL;
	}
}

static inline void sk_psock_update_proto(struct sock *sk,
					 struct sk_psock *psock,
					 struct proto *ops)
{
	/* Initialize saved callbacks and original proto only once, since this
	 * function may be called multiple times for a psock, e.g. when
	 * psock->progs.msg_parser is updated.
	 *
	 * Since we've not installed the new proto, psock is not yet in use and
	 * we can initialize it without synchronization.
	 */
	if (!psock->sk_proto) {
		struct proto *orig = READ_ONCE(sk->sk_prot);

		psock->saved_unhash = orig->unhash;
		psock->saved_close = orig->close;
		psock->saved_write_space = sk->sk_write_space;

		psock->sk_proto = orig;
	}

	/* Pairs with lockless read in sk_clone_lock() */
	WRITE_ONCE(sk->sk_prot, ops);
}

static inline void sk_psock_restore_proto(struct sock *sk,
					  struct sk_psock *psock)
{
	sk->sk_prot->unhash = psock->saved_unhash;
	if (inet_csk_has_ulp(sk)) {
		tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
	} else {
		sk->sk_write_space = psock->saved_write_space;
		/* Pairs with lockless read in sk_clone_lock() */
		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
	}
}
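
/* Usage sketch (illustrative, simplified from how sockmap drives these
 * helpers): swap in a replacement proto while the psock is attached and
 * put the original one back on teardown.  "my_proto_ops" is a
 * hypothetical struct proto whose handlers wrap the saved callbacks.
 *
 *	lock_sock(sk);
 *	sk_psock_update_proto(sk, psock, &my_proto_ops);
 *	...
 *	sk_psock_restore_proto(sk, psock);
 *	release_sock(sk);
 */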

static inline void sk_psock_set_state(struct sk_psock *psock,
				      enum sk_psock_state_bits bit)
{
	set_bit(bit, &psock->state);
}

static inline void sk_psock_clear_state(struct sk_psock *psock,
					enum sk_psock_state_bits bit)
{
	clear_bit(bit, &psock->state);
}

static inline bool sk_psock_test_state(const struct sk_psock *psock,
				       enum sk_psock_state_bits bit)
{
	return test_bit(bit, &psock->state);
}

static inline struct sk_psock *sk_psock_get(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock && !refcount_inc_not_zero(&psock->refcnt))
		psock = NULL;
	rcu_read_unlock();
	return psock;
}

void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
void sk_psock_destroy(struct rcu_head *rcu);
void sk_psock_drop(struct sock *sk, struct sk_psock *psock);

static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
{
	if (refcount_dec_and_test(&psock->refcnt))
		sk_psock_drop(sk, psock);
}
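
/* Sketch (illustrative only): the canonical lookup pattern built from
 * sk_psock_get()/sk_psock_put().  A NULL return means no psock is
 * attached, or the attached one is already being torn down.
 *
 *	struct sk_psock *psock = sk_psock_get(sk);
 *
 *	if (!psock)
 *		return -EINVAL;
 *	... use psock; the reference pins it ...
 *	sk_psock_put(sk, psock);
 */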

static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
{
	if (psock->parser.enabled)
		psock->parser.saved_data_ready(sk);
	else
		sk->sk_data_ready(sk);
}

static inline void psock_set_prog(struct bpf_prog **pprog,
				  struct bpf_prog *prog)
{
	prog = xchg(pprog, prog);
	if (prog)
		bpf_prog_put(prog);
}
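
/* Note (added): psock_set_prog() transfers the caller's reference on
 * "prog" into *pprog and, through the xchg(), puts the reference held
 * on whatever program it replaces.  Passing NULL therefore detaches and
 * releases the old program, which is how psock_progs_drop() below tears
 * down all three hooks.
 */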

static inline void psock_progs_drop(struct sk_psock_progs *progs)
{
	psock_set_prog(&progs->msg_parser, NULL);
	psock_set_prog(&progs->skb_parser, NULL);
	psock_set_prog(&progs->skb_verdict, NULL);
}

#endif /* _LINUX_SKMSG_H */