1 /* SPDX-License-Identifier: GPL-2.0 */
2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
8 #include <linux/filter.h>
9 #include <linux/scatterlist.h>
10 #include <linux/skbuff.h>
14 #include <net/strparser.h>
/* A sk_msg can carry at most as many page fragments as an skb. */
#define MAX_MSG_FRAGS			MAX_SKB_FRAGS
31 bool copy
[MAX_MSG_FRAGS
];
32 /* The extra element is used for chaining the front and sections when
33 * the list becomes partitioned (e.g. end < start). The crypto APIs
34 * require the chaining.
36 struct scatterlist data
[MAX_MSG_FRAGS
+ 1];
39 /* UAPI in filter.c depends on struct sk_msg_sg being first element. */
48 struct sock
*sk_redir
;
50 struct list_head list
;
/* BPF programs attached to a psock: msg verdict, skb parser, skb verdict. */
struct sk_psock_progs {
	struct bpf_prog			*msg_parser;
	struct bpf_prog			*skb_parser;
	struct bpf_prog			*skb_verdict;
};
/* Bit indices for psock->state, used with set_bit/clear_bit/test_bit below.
 * NOTE(review): enumerators were lost in extraction; reconstructed from
 * upstream skmsg.h — verify.
 */
enum sk_psock_state_bits {
	SK_PSOCK_TX_ENABLED,
};
63 struct sk_psock_link
{
64 struct list_head list
;
69 struct sk_psock_parser
{
70 struct strparser strp
;
72 void (*saved_data_ready
)(struct sock
*sk
);
75 struct sk_psock_work_state
{
83 struct sock
*sk_redir
;
88 struct sk_psock_progs progs
;
89 struct sk_psock_parser parser
;
90 struct sk_buff_head ingress_skb
;
91 struct list_head ingress_msg
;
93 struct list_head link
;
96 void (*saved_unhash
)(struct sock
*sk
);
97 void (*saved_close
)(struct sock
*sk
, long timeout
);
98 void (*saved_write_space
)(struct sock
*sk
);
99 struct proto
*sk_proto
;
100 struct sk_psock_work_state work_state
;
101 struct work_struct work
;
104 struct work_struct gc
;
108 int sk_msg_alloc(struct sock
*sk
, struct sk_msg
*msg
, int len
,
109 int elem_first_coalesce
);
110 int sk_msg_clone(struct sock
*sk
, struct sk_msg
*dst
, struct sk_msg
*src
,
112 void sk_msg_trim(struct sock
*sk
, struct sk_msg
*msg
, int len
);
113 int sk_msg_free(struct sock
*sk
, struct sk_msg
*msg
);
114 int sk_msg_free_nocharge(struct sock
*sk
, struct sk_msg
*msg
);
115 void sk_msg_free_partial(struct sock
*sk
, struct sk_msg
*msg
, u32 bytes
);
116 void sk_msg_free_partial_nocharge(struct sock
*sk
, struct sk_msg
*msg
,
119 void sk_msg_return(struct sock
*sk
, struct sk_msg
*msg
, int bytes
);
120 void sk_msg_return_zero(struct sock
*sk
, struct sk_msg
*msg
, int bytes
);
122 int sk_msg_zerocopy_from_iter(struct sock
*sk
, struct iov_iter
*from
,
123 struct sk_msg
*msg
, u32 bytes
);
124 int sk_msg_memcopy_from_iter(struct sock
*sk
, struct iov_iter
*from
,
125 struct sk_msg
*msg
, u32 bytes
);
127 static inline void sk_msg_check_to_free(struct sk_msg
*msg
, u32 i
, u32 bytes
)
129 WARN_ON(i
== msg
->sg
.end
&& bytes
);
132 static inline void sk_msg_apply_bytes(struct sk_psock
*psock
, u32 bytes
)
134 if (psock
->apply_bytes
) {
135 if (psock
->apply_bytes
< bytes
)
136 psock
->apply_bytes
= 0;
138 psock
->apply_bytes
-= bytes
;
/* Ring iteration helpers over [0, MAX_MSG_FRAGS): step an index variable
 * backward/forward with wraparound, plus convenience forms operating
 * directly on msg->sg.start / msg->sg.end / etc.
 * NOTE(review): the do/while scaffolding was lost in extraction and is
 * reconstructed from upstream skmsg.h — verify.
 */
#define sk_msg_iter_var_prev(var)			\
	do {						\
		if (var == 0)				\
			var = MAX_MSG_FRAGS - 1;	\
		else					\
			var--;				\
	} while (0)

#define sk_msg_iter_var_next(var)			\
	do {						\
		var++;					\
		if (var == MAX_MSG_FRAGS)		\
			var = 0;			\
	} while (0)

#define sk_msg_iter_prev(msg, which)			\
	sk_msg_iter_var_prev(msg->sg.which)

#define sk_msg_iter_next(msg, which)			\
	sk_msg_iter_var_next(msg->sg.which)
163 static inline void sk_msg_clear_meta(struct sk_msg
*msg
)
165 memset(&msg
->sg
, 0, offsetofend(struct sk_msg_sg
, copy
));
168 static inline void sk_msg_init(struct sk_msg
*msg
)
170 BUILD_BUG_ON(ARRAY_SIZE(msg
->sg
.data
) - 1 != MAX_MSG_FRAGS
);
171 memset(msg
, 0, sizeof(*msg
));
172 sg_init_marker(msg
->sg
.data
, MAX_MSG_FRAGS
);
175 static inline void sk_msg_xfer(struct sk_msg
*dst
, struct sk_msg
*src
,
178 dst
->sg
.data
[which
] = src
->sg
.data
[which
];
179 dst
->sg
.data
[which
].length
= size
;
180 dst
->sg
.size
+= size
;
181 src
->sg
.data
[which
].length
-= size
;
182 src
->sg
.data
[which
].offset
+= size
;
185 static inline void sk_msg_xfer_full(struct sk_msg
*dst
, struct sk_msg
*src
)
187 memcpy(dst
, src
, sizeof(*src
));
191 static inline bool sk_msg_full(const struct sk_msg
*msg
)
193 return (msg
->sg
.end
== msg
->sg
.start
) && msg
->sg
.size
;
196 static inline u32
sk_msg_elem_used(const struct sk_msg
*msg
)
198 if (sk_msg_full(msg
))
199 return MAX_MSG_FRAGS
;
201 return msg
->sg
.end
>= msg
->sg
.start
?
202 msg
->sg
.end
- msg
->sg
.start
:
203 msg
->sg
.end
+ (MAX_MSG_FRAGS
- msg
->sg
.start
);
206 static inline struct scatterlist
*sk_msg_elem(struct sk_msg
*msg
, int which
)
208 return &msg
->sg
.data
[which
];
211 static inline struct scatterlist
sk_msg_elem_cpy(struct sk_msg
*msg
, int which
)
213 return msg
->sg
.data
[which
];
/* Page backing scatterlist element 'which'. */
static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
{
	return sg_page(sk_msg_elem(msg, which));
}
221 static inline bool sk_msg_to_ingress(const struct sk_msg
*msg
)
223 return msg
->flags
& BPF_F_INGRESS
;
226 static inline void sk_msg_compute_data_pointers(struct sk_msg
*msg
)
228 struct scatterlist
*sge
= sk_msg_elem(msg
, msg
->sg
.start
);
230 if (msg
->sg
.copy
[msg
->sg
.start
]) {
232 msg
->data_end
= NULL
;
234 msg
->data
= sg_virt(sge
);
235 msg
->data_end
= msg
->data
+ sge
->length
;
239 static inline void sk_msg_page_add(struct sk_msg
*msg
, struct page
*page
,
242 struct scatterlist
*sge
;
245 sge
= sk_msg_elem(msg
, msg
->sg
.end
);
246 sg_set_page(sge
, page
, len
, offset
);
249 msg
->sg
.copy
[msg
->sg
.end
] = true;
251 sk_msg_iter_next(msg
, end
);
254 static inline void sk_msg_sg_copy(struct sk_msg
*msg
, u32 i
, bool copy_state
)
257 msg
->sg
.copy
[i
] = copy_state
;
258 sk_msg_iter_var_next(i
);
259 if (i
== msg
->sg
.end
)
264 static inline void sk_msg_sg_copy_set(struct sk_msg
*msg
, u32 start
)
266 sk_msg_sg_copy(msg
, start
, true);
269 static inline void sk_msg_sg_copy_clear(struct sk_msg
*msg
, u32 start
)
271 sk_msg_sg_copy(msg
, start
, false);
/* psock attached to a socket, via sk_user_data; caller must be in an RCU
 * read-side section.
 */
static inline struct sk_psock *sk_psock(const struct sock *sk)
{
	return rcu_dereference_sk_user_data(sk);
}
279 static inline void sk_psock_queue_msg(struct sk_psock
*psock
,
282 list_add_tail(&msg
->list
, &psock
->ingress_msg
);
285 static inline bool sk_psock_queue_empty(const struct sk_psock
*psock
)
287 return psock
? list_empty(&psock
->ingress_msg
) : true;
290 static inline void sk_psock_report_error(struct sk_psock
*psock
, int err
)
292 struct sock
*sk
= psock
->sk
;
295 sk
->sk_error_report(sk
);
/* psock lifecycle and strparser hooks, implemented in net/core/skmsg.c.
 * NOTE(review): the continuation line of sk_psock_msg_verdict() was lost in
 * extraction; 'struct sk_msg *msg' reconstructed from upstream — verify.
 */
struct sk_psock *sk_psock_init(struct sock *sk, int node);

int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg);
307 static inline struct sk_psock_link
*sk_psock_init_link(void)
309 return kzalloc(sizeof(struct sk_psock_link
),
310 GFP_ATOMIC
| __GFP_NOWARN
);
/* Release a link allocated by sk_psock_init_link().
 * NOTE(review): the body was lost in extraction; kfree(link) reconstructed
 * from upstream skmsg.h — verify.
 */
static inline void sk_psock_free_link(struct sk_psock_link *link)
{
	kfree(link);
}
/* Detach and return the first link on the psock, or NULL if none. */
struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
/* Real implementation lives in the stream-parser code; stub out when
 * CONFIG_BPF_STREAM_PARSER is disabled.
 * NOTE(review): the #else/#endif scaffolding and empty stub body were lost
 * in extraction and are reconstructed from upstream skmsg.h — verify.
 */
#if defined(CONFIG_BPF_STREAM_PARSER)
void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
#else
static inline void sk_psock_unlink(struct sock *sk,
				   struct sk_psock_link *link)
{
}
#endif
/* Drop all queued ingress messages; see net/core/skmsg.c. */
void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
330 static inline void sk_psock_cork_free(struct sk_psock
*psock
)
333 sk_msg_free(psock
->sk
, psock
->cork
);
339 static inline void sk_psock_update_proto(struct sock
*sk
,
340 struct sk_psock
*psock
,
343 psock
->saved_unhash
= sk
->sk_prot
->unhash
;
344 psock
->saved_close
= sk
->sk_prot
->close
;
345 psock
->saved_write_space
= sk
->sk_write_space
;
347 psock
->sk_proto
= sk
->sk_prot
;
351 static inline void sk_psock_restore_proto(struct sock
*sk
,
352 struct sk_psock
*psock
)
354 sk
->sk_write_space
= psock
->saved_write_space
;
356 if (psock
->sk_proto
) {
357 struct inet_connection_sock
*icsk
= inet_csk(sk
);
358 bool has_ulp
= !!icsk
->icsk_ulp_data
;
361 tcp_update_ulp(sk
, psock
->sk_proto
);
363 sk
->sk_prot
= psock
->sk_proto
;
364 psock
->sk_proto
= NULL
;
368 static inline void sk_psock_set_state(struct sk_psock
*psock
,
369 enum sk_psock_state_bits bit
)
371 set_bit(bit
, &psock
->state
);
374 static inline void sk_psock_clear_state(struct sk_psock
*psock
,
375 enum sk_psock_state_bits bit
)
377 clear_bit(bit
, &psock
->state
);
380 static inline bool sk_psock_test_state(const struct sk_psock
*psock
,
381 enum sk_psock_state_bits bit
)
383 return test_bit(bit
, &psock
->state
);
386 static inline struct sk_psock
*sk_psock_get_checked(struct sock
*sk
)
388 struct sk_psock
*psock
;
391 psock
= sk_psock(sk
);
393 if (sk
->sk_prot
->recvmsg
!= tcp_bpf_recvmsg
) {
394 psock
= ERR_PTR(-EBUSY
);
398 if (!refcount_inc_not_zero(&psock
->refcnt
))
399 psock
= ERR_PTR(-EBUSY
);
406 static inline struct sk_psock
*sk_psock_get(struct sock
*sk
)
408 struct sk_psock
*psock
;
411 psock
= sk_psock(sk
);
412 if (psock
&& !refcount_inc_not_zero(&psock
->refcnt
))
/* psock teardown entry points, implemented in net/core/skmsg.c. */
void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
void sk_psock_destroy(struct rcu_head *rcu);
void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
422 static inline void sk_psock_put(struct sock
*sk
, struct sk_psock
*psock
)
424 if (refcount_dec_and_test(&psock
->refcnt
))
425 sk_psock_drop(sk
, psock
);
428 static inline void sk_psock_data_ready(struct sock
*sk
, struct sk_psock
*psock
)
430 if (psock
->parser
.enabled
)
431 psock
->parser
.saved_data_ready(sk
);
433 sk
->sk_data_ready(sk
);
/* Atomically install 'prog' at *pprog and drop the reference on whatever
 * program was there before (xchg returns the old value into 'prog').
 * NOTE(review): the 'if (prog) bpf_prog_put(prog);' lines were lost in
 * extraction and are reconstructed from upstream skmsg.h (without them the
 * xchg result is dead and the old program leaks) — verify.
 */
static inline void psock_set_prog(struct bpf_prog **pprog,
				  struct bpf_prog *prog)
{
	prog = xchg(pprog, prog);
	if (prog)
		bpf_prog_put(prog);
}
444 static inline void psock_progs_drop(struct sk_psock_progs
*progs
)
446 psock_set_prog(&progs
->msg_parser
, NULL
);
447 psock_set_prog(&progs
->skb_parser
, NULL
);
448 psock_set_prog(&progs
->skb_verdict
, NULL
);
451 #endif /* _LINUX_SKMSG_H */