// SPDX-License-Identifier: GPL-2.0
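/* espintcp: TCP encapsulation of ESP packets (RFC 8229), implemented as a
 * TCP ULP. Outgoing ESP packets are framed with a 16-bit length prefix;
 * the incoming byte stream is split (via strparser) into ESP packets,
 * which are fed to the xfrm input path, and IKE messages, which are
 * queued for the userspace keying daemon.
 */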
#include <net/tcp.h>
#include <net/strparser.h>
#include <net/xfrm.h>
#include <net/esp.h>
#include <net/espintcp.h>
#include <linux/skmsg.h>
#include <net/inet_common.h>

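/* A message with a zero non-ESP marker is IKE traffic: charge it to the
 * socket's receive buffer and queue it for delivery to userspace through
 * espintcp_recvmsg(). Drop it if the receive buffer is full.
 */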
static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
			  struct sock *sk)
{
	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
	    !sk_rmem_schedule(sk, skb, skb->truesize)) {
		kfree_skb(skb);
		return;
	}

	skb_set_owner_r(skb, sk);

	memset(skb->cb, 0, sizeof(skb->cb));
	skb_queue_tail(&ctx->ike_queue, skb);
	ctx->saved_data_ready(sk);
}

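/* Anything else is an ESP packet: hand it to the IPv4 xfrm input path,
 * marked as ESP-in-TCP encapsulated so the matching state is found.
 */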
static void handle_esp(struct sk_buff *skb, struct sock *sk)
{
	skb_reset_transport_header(skb);
	memset(skb->cb, 0, sizeof(skb->cb));

	rcu_read_lock();
	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
	local_bh_disable();
	xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
	local_bh_enable();
	rcu_read_unlock();
}

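/* strparser rcv_msg callback: skb holds one complete length-prefixed
 * message. Strip the 2-byte length header, then dispatch on the 32-bit
 * word that follows it: zero is the non-ESP marker (IKE), anything else
 * is an ESP SPI.
 */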
static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
						strp);
	struct strp_msg *rxm = strp_msg(skb);
	u32 nonesp_marker;
	int err;

	err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
			    sizeof(nonesp_marker));
	if (err < 0) {
		kfree_skb(skb);
		return;
	}

	/* remove header, leave non-ESP marker/SPI */
	if (!__pskb_pull(skb, rxm->offset + 2)) {
		kfree_skb(skb);
		return;
	}

	if (pskb_trim(skb, rxm->full_len - 2) != 0) {
		kfree_skb(skb);
		return;
	}

	if (nonesp_marker == 0)
		handle_nonesp(ctx, skb, strp->sk);
	else
		handle_esp(skb, strp->sk);
}

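/* strparser parse_msg callback: return the total length of the next
 * message, taken from the 2-byte big-endian length field, or 0 if not
 * enough bytes have arrived yet. The minimum valid message is the length
 * field plus a 4-byte marker/SPI, i.e. 6 bytes.
 */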
static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	__be16 blen;
	u16 len;
	int err;

	if (skb->len < rxm->offset + 2)
		return 0;

	err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
	if (err < 0)
		return err;

	len = be16_to_cpu(blen);
	if (len < 6)
		return -EINVAL;

	return len;
}

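/* recvmsg replacement for the ULP: userspace only ever reads the queued
 * IKE messages, never raw ESP. Datagram semantics apply: a short read
 * truncates the message and sets MSG_TRUNC.
 */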
static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			    int nonblock, int flags, int *addr_len)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff *skb;
	int err = 0;
	int copied;
	int off = 0;

	flags |= nonblock ? MSG_DONTWAIT : 0;

	skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
	if (!skb)
		return err;

	copied = len;
	if (copied > skb->len)
		copied = skb->len;
	else if (copied < skb->len)
		msg->msg_flags |= MSG_TRUNC;

	err = skb_copy_datagram_msg(skb, 0, msg, copied);
	if (unlikely(err)) {
		kfree_skb(skb);
		return err;
	}

	if (flags & MSG_TRUNC)
		copied = skb->len;
	kfree_skb(skb);
	return copied;
}

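/* Queue an ESP packet for later transmission when the socket can't be
 * written to directly (it is owned by the user); the actual send happens
 * from espintcp_release() with the socket lock held. Bounded by
 * netdev_max_backlog to cap memory use.
 */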
int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
		return -ENOBUFS;

	__skb_queue_tail(&ctx->out_queue, skb);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_queue_out);

/* espintcp length field is 2B and length includes the length field's size */
#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)

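/* Transmit a pending skb (the ESP output path); loops until the whole
 * message has been handed to the TCP stream or the socket returns an
 * error. Caller holds the socket lock.
 */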
static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
				   int flags)
{
	do {
		int ret;

		ret = skb_send_sock_locked(sk, emsg->skb,
					   emsg->offset, emsg->len);
		if (ret < 0)
			return ret;

		emsg->len -= ret;
		emsg->offset += ret;
	} while (emsg->len > 0);

	kfree_skb(emsg->skb);
	memset(emsg, 0, sizeof(*emsg));

	return 0;
}

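/* Transmit a pending sk_msg (the sendmsg path), page by page via
 * do_tcp_sendpages(). On failure, emsg->offset and skmsg->sg.start record
 * how far we got so that a later push can resume from the same point.
 */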
static int espintcp_sendskmsg_locked(struct sock *sk,
				     struct espintcp_msg *emsg, int flags)
{
	struct sk_msg *skmsg = &emsg->skmsg;
	struct scatterlist *sg;
	int done = 0;
	int ret;

	flags |= MSG_SENDPAGE_NOTLAST;
	sg = &skmsg->sg.data[skmsg->sg.start];
	do {
		size_t size = sg->length - emsg->offset;
		int offset = sg->offset + emsg->offset;
		struct page *p;

		emsg->offset = 0;

		if (sg_is_last(sg))
			flags &= ~MSG_SENDPAGE_NOTLAST;

		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, flags);
		if (ret < 0) {
			emsg->offset = offset - sg->offset;
			skmsg->sg.start += done;
			return ret;
		}

		if (ret != size) {
			offset += ret;
			size -= ret;
			goto retry;
		}

		done++;
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
	} while (sg);

	memset(emsg, 0, sizeof(*emsg));

	return 0;
}

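/* Flush the pending partial message, if any. tx_running serializes
 * concurrent pushers (sendmsg, ESP output, tx work); a contending caller
 * backs off and lets the one already running make progress.
 */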
static int espintcp_push_msgs(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	int err;

	if (!emsg->len)
		return 0;

	if (ctx->tx_running)
		return -EAGAIN;
	ctx->tx_running = 1;

	if (emsg->skb)
		err = espintcp_sendskb_locked(sk, emsg, 0);
	else
		err = espintcp_sendskmsg_locked(sk, emsg, 0);
	if (err == -EAGAIN) {
		ctx->tx_running = 0;
		return 0;
	}
	if (!err)
		memset(emsg, 0, sizeof(*emsg));

	ctx->tx_running = 0;

	return err;
}

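/* Transmit one ESP skb (framed with its length header by the ESP output
 * path, starting at the transport offset). If a previous message is still
 * only partially sent, drop the new packet rather than block the output
 * path.
 */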
int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	unsigned int len;
	int offset;

	if (sk->sk_state != TCP_ESTABLISHED) {
		kfree_skb(skb);
		return -ECONNRESET;
	}

	offset = skb_transport_offset(skb);
	len = skb->len - offset;

	espintcp_push_msgs(sk);

	if (emsg->len) {
		kfree_skb(skb);
		return -ENOBUFS;
	}

	skb_set_owner_w(skb, sk);

	emsg->offset = offset;
	emsg->len = len;
	emsg->skb = skb;

	espintcp_push_msgs(sk);

	return 0;
}
EXPORT_SYMBOL_GPL(espintcp_push_skb);

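/* sendmsg replacement for the ULP: userspace sends one IKE message at a
 * time. Build an sk_msg holding the 2-byte length prefix followed by the
 * payload, then push it out; on partial transmission the message is kept
 * in ctx->partial and finished later.
 */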
static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;
	struct iov_iter pfx_iter;
	struct kvec pfx_iov = {};
	size_t msglen = size + 2;
	char buf[2] = {0};
	int err, end;

	if (msg->msg_flags)
		return -EOPNOTSUPP;

	if (size > MAX_ESPINTCP_MSG)
		return -EMSGSIZE;

	if (msg->msg_controllen)
		return -EOPNOTSUPP;

	lock_sock(sk);

	err = espintcp_push_msgs(sk);
	if (err < 0) {
		err = -ENOBUFS;
		goto unlock;
	}

	sk_msg_init(&emsg->skmsg);
	while (1) {
		/* only -ENOMEM is possible since we don't coalesce */
		err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
		if (!err)
			break;

		err = sk_stream_wait_memory(sk, &timeo);
		if (err)
			goto fail;
	}

	*((__be16 *)buf) = cpu_to_be16(msglen);
	pfx_iov.iov_base = buf;
	pfx_iov.iov_len = sizeof(buf);
	iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);

	err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
				       pfx_iov.iov_len);
	if (err < 0)
		goto fail;

	err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
	if (err < 0)
		goto fail;

	end = emsg->skmsg.sg.end;
	emsg->len = size;
	sk_msg_iter_var_prev(end);
	sg_mark_end(sk_msg_elem(&emsg->skmsg, end));

	tcp_rate_check_app_limited(sk);

	err = espintcp_push_msgs(sk);
	/* this message could be partially sent, keep it */
	if (err < 0)
		goto unlock;
	release_sock(sk);

	return size;

fail:
	sk_msg_free(sk, &emsg->skmsg);
	memset(emsg, 0, sizeof(*emsg));
unlock:
	release_sock(sk);
	return err;
}

static struct proto espintcp_prot __ro_after_init;
static struct proto_ops espintcp_ops __ro_after_init;

static void espintcp_data_ready(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	strp_data_ready(&ctx->strp);
}

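/* Deferred transmit: scheduled from espintcp_write_space() when the TCP
 * socket has room again, to finish sending a partial message.
 */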
static void espintcp_tx_work(struct work_struct *work)
{
	struct espintcp_ctx *ctx = container_of(work,
						struct espintcp_ctx, work);
	struct sock *sk = ctx->strp.sk;

	lock_sock(sk);
	if (!ctx->tx_running)
		espintcp_push_msgs(sk);
	release_sock(sk);
}

static void espintcp_write_space(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	schedule_work(&ctx->work);
	ctx->saved_write_space(sk);
}

static void espintcp_destruct(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	ctx->saved_destruct(sk);
	kfree(ctx);
}

bool tcp_is_ulp_esp(struct sock *sk)
{
	return sk->sk_prot == &espintcp_prot;
}
EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);

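/* ULP init, called when userspace attaches the "espintcp" ULP to a
 * connected TCP socket: allocate the context, attach the strparser, and
 * divert the socket's prot/ops and callbacks to the espintcp versions.
 */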
static int espintcp_init_sk(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct strp_callbacks cb = {
		.rcv_msg = espintcp_rcv,
		.parse_msg = espintcp_parse,
	};
	struct espintcp_ctx *ctx;
	int err;

	/* sockmap is not compatible with espintcp */
	if (sk->sk_user_data)
		return -EBUSY;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	err = strp_init(&ctx->strp, sk, &cb);
	if (err)
		goto free;

	__sk_dst_reset(sk);

	strp_check_rcv(&ctx->strp);
	skb_queue_head_init(&ctx->ike_queue);
	skb_queue_head_init(&ctx->out_queue);
	sk->sk_prot = &espintcp_prot;
	sk->sk_socket->ops = &espintcp_ops;
	ctx->saved_data_ready = sk->sk_data_ready;
	ctx->saved_write_space = sk->sk_write_space;
	ctx->saved_destruct = sk->sk_destruct;
	sk->sk_data_ready = espintcp_data_ready;
	sk->sk_write_space = espintcp_write_space;
	sk->sk_destruct = espintcp_destruct;
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	INIT_WORK(&ctx->work, espintcp_tx_work);

	/* avoid using task_frag */
	sk->sk_allocation = GFP_ATOMIC;

	return 0;

free:
	kfree(ctx);
	return err;
}

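/* release_cb replacement: runs when the socket lock is released, pushes
 * out the ESP packets queued by espintcp_queue_out(), then runs the
 * regular TCP release work.
 */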
static void espintcp_release(struct sock *sk)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct sk_buff_head queue;
	struct sk_buff *skb;

	__skb_queue_head_init(&queue);
	skb_queue_splice_init(&ctx->out_queue, &queue);

	while ((skb = __skb_dequeue(&queue)))
		espintcp_push_skb(sk, skb);

	tcp_release_cb(sk);
}

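/* close replacement: stop the strparser and tx work, restore tcp_prot,
 * free any queued or partially-sent messages, then close the TCP socket.
 */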
static void espintcp_close(struct sock *sk, long timeout)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;

	strp_stop(&ctx->strp);

	sk->sk_prot = &tcp_prot;
	barrier();

	cancel_work_sync(&ctx->work);
	strp_done(&ctx->strp);

	skb_queue_purge(&ctx->out_queue);
	skb_queue_purge(&ctx->ike_queue);

	if (emsg->len) {
		if (emsg->skb)
			kfree_skb(emsg->skb);
		else
			sk_msg_free(sk, &emsg->skmsg);
	}

	tcp_close(sk, timeout);
}

static __poll_t espintcp_poll(struct file *file, struct socket *sock,
			      poll_table *wait)
{
	__poll_t mask = datagram_poll(file, sock, wait);
	struct sock *sk = sock->sk;
	struct espintcp_ctx *ctx = espintcp_getctx(sk);

	if (!skb_queue_empty(&ctx->ike_queue))
		mask |= EPOLLIN | EPOLLRDNORM;

	return mask;
}

static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
	.name = "espintcp",
	.owner = THIS_MODULE,
	.init = espintcp_init_sk,
};

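/* Init: clone tcp_prot/inet_stream_ops, override the entry points that
 * espintcp needs to intercept, and register the ULP.
 */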
void __init espintcp_init(void)
{
	memcpy(&espintcp_prot, &tcp_prot, sizeof(tcp_prot));
	memcpy(&espintcp_ops, &inet_stream_ops, sizeof(inet_stream_ops));
	espintcp_prot.sendmsg = espintcp_sendmsg;
	espintcp_prot.recvmsg = espintcp_recvmsg;
	espintcp_prot.close = espintcp_close;
	espintcp_prot.release_cb = espintcp_release;
	espintcp_ops.poll = espintcp_poll;

	tcp_register_ulp(&espintcp_ulp);
}