// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */
#include "device.h"
#include "peer.h"
#include "socket.h"
#include "queueing.h"
#include "messages.h"

#include <linux/ctype.h>
#include <linux/net.h>
#include <linux/if_vlan.h>
#include <linux/if_ether.h>
#include <linux/inetdevice.h>
#include <net/udp_tunnel.h>
#include <net/ipv6.h>
20 static int send4(struct wg_device
*wg
, struct sk_buff
*skb
,
21 struct endpoint
*endpoint
, u8 ds
, struct dst_cache
*cache
)
24 .saddr
= endpoint
->src4
.s_addr
,
25 .daddr
= endpoint
->addr4
.sin_addr
.s_addr
,
26 .fl4_dport
= endpoint
->addr4
.sin_port
,
27 .flowi4_mark
= wg
->fwmark
,
28 .flowi4_proto
= IPPROTO_UDP
30 struct rtable
*rt
= NULL
;
34 skb_mark_not_on_list(skb
);
36 skb
->mark
= wg
->fwmark
;
39 sock
= rcu_dereference_bh(wg
->sock4
);
41 if (unlikely(!sock
)) {
46 fl
.fl4_sport
= inet_sk(sock
)->inet_sport
;
49 rt
= dst_cache_get_ip4(cache
, &fl
.saddr
);
52 security_sk_classify_flow(sock
, flowi4_to_flowi(&fl
));
53 if (unlikely(!inet_confirm_addr(sock_net(sock
), NULL
, 0,
54 fl
.saddr
, RT_SCOPE_HOST
))) {
55 endpoint
->src4
.s_addr
= 0;
56 *(__force __be32
*)&endpoint
->src_if4
= 0;
59 dst_cache_reset(cache
);
61 rt
= ip_route_output_flow(sock_net(sock
), &fl
, sock
);
62 if (unlikely(endpoint
->src_if4
&& ((IS_ERR(rt
) &&
63 PTR_ERR(rt
) == -EINVAL
) || (!IS_ERR(rt
) &&
64 rt
->dst
.dev
->ifindex
!= endpoint
->src_if4
)))) {
65 endpoint
->src4
.s_addr
= 0;
66 *(__force __be32
*)&endpoint
->src_if4
= 0;
69 dst_cache_reset(cache
);
72 rt
= ip_route_output_flow(sock_net(sock
), &fl
, sock
);
74 if (unlikely(IS_ERR(rt
))) {
76 net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n",
77 wg
->dev
->name
, &endpoint
->addr
, ret
);
79 } else if (unlikely(rt
->dst
.dev
== skb
->dev
)) {
82 net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n",
83 wg
->dev
->name
, &endpoint
->addr
);
87 dst_cache_set_ip4(cache
, &rt
->dst
, fl
.saddr
);
91 udp_tunnel_xmit_skb(rt
, sock
, skb
, fl
.saddr
, fl
.daddr
, ds
,
92 ip4_dst_hoplimit(&rt
->dst
), 0, fl
.fl4_sport
,
93 fl
.fl4_dport
, false, false);
103 static int send6(struct wg_device
*wg
, struct sk_buff
*skb
,
104 struct endpoint
*endpoint
, u8 ds
, struct dst_cache
*cache
)
106 #if IS_ENABLED(CONFIG_IPV6)
108 .saddr
= endpoint
->src6
,
109 .daddr
= endpoint
->addr6
.sin6_addr
,
110 .fl6_dport
= endpoint
->addr6
.sin6_port
,
111 .flowi6_mark
= wg
->fwmark
,
112 .flowi6_oif
= endpoint
->addr6
.sin6_scope_id
,
113 .flowi6_proto
= IPPROTO_UDP
114 /* TODO: addr->sin6_flowinfo */
116 struct dst_entry
*dst
= NULL
;
120 skb_mark_not_on_list(skb
);
122 skb
->mark
= wg
->fwmark
;
125 sock
= rcu_dereference_bh(wg
->sock6
);
127 if (unlikely(!sock
)) {
132 fl
.fl6_sport
= inet_sk(sock
)->inet_sport
;
135 dst
= dst_cache_get_ip6(cache
, &fl
.saddr
);
138 security_sk_classify_flow(sock
, flowi6_to_flowi(&fl
));
139 if (unlikely(!ipv6_addr_any(&fl
.saddr
) &&
140 !ipv6_chk_addr(sock_net(sock
), &fl
.saddr
, NULL
, 0))) {
141 endpoint
->src6
= fl
.saddr
= in6addr_any
;
143 dst_cache_reset(cache
);
145 dst
= ipv6_stub
->ipv6_dst_lookup_flow(sock_net(sock
), sock
, &fl
,
147 if (unlikely(IS_ERR(dst
))) {
149 net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n",
150 wg
->dev
->name
, &endpoint
->addr
, ret
);
152 } else if (unlikely(dst
->dev
== skb
->dev
)) {
155 net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n",
156 wg
->dev
->name
, &endpoint
->addr
);
160 dst_cache_set_ip6(cache
, dst
, &fl
.saddr
);
164 udp_tunnel6_xmit_skb(dst
, sock
, skb
, skb
->dev
, &fl
.saddr
, &fl
.daddr
, ds
,
165 ip6_dst_hoplimit(dst
), 0, fl
.fl6_sport
,
166 fl
.fl6_dport
, false);
172 rcu_read_unlock_bh();
175 return -EAFNOSUPPORT
;
179 int wg_socket_send_skb_to_peer(struct wg_peer
*peer
, struct sk_buff
*skb
, u8 ds
)
181 size_t skb_len
= skb
->len
;
182 int ret
= -EAFNOSUPPORT
;
184 read_lock_bh(&peer
->endpoint_lock
);
185 if (peer
->endpoint
.addr
.sa_family
== AF_INET
)
186 ret
= send4(peer
->device
, skb
, &peer
->endpoint
, ds
,
187 &peer
->endpoint_cache
);
188 else if (peer
->endpoint
.addr
.sa_family
== AF_INET6
)
189 ret
= send6(peer
->device
, skb
, &peer
->endpoint
, ds
,
190 &peer
->endpoint_cache
);
194 peer
->tx_bytes
+= skb_len
;
195 read_unlock_bh(&peer
->endpoint_lock
);
200 int wg_socket_send_buffer_to_peer(struct wg_peer
*peer
, void *buffer
,
203 struct sk_buff
*skb
= alloc_skb(len
+ SKB_HEADER_LEN
, GFP_ATOMIC
);
208 skb_reserve(skb
, SKB_HEADER_LEN
);
209 skb_set_inner_network_header(skb
, 0);
210 skb_put_data(skb
, buffer
, len
);
211 return wg_socket_send_skb_to_peer(peer
, skb
, ds
);
214 int wg_socket_send_buffer_as_reply_to_skb(struct wg_device
*wg
,
215 struct sk_buff
*in_skb
, void *buffer
,
220 struct endpoint endpoint
;
222 if (unlikely(!in_skb
))
224 ret
= wg_socket_endpoint_from_skb(&endpoint
, in_skb
);
225 if (unlikely(ret
< 0))
228 skb
= alloc_skb(len
+ SKB_HEADER_LEN
, GFP_ATOMIC
);
231 skb_reserve(skb
, SKB_HEADER_LEN
);
232 skb_set_inner_network_header(skb
, 0);
233 skb_put_data(skb
, buffer
, len
);
235 if (endpoint
.addr
.sa_family
== AF_INET
)
236 ret
= send4(wg
, skb
, &endpoint
, 0, NULL
);
237 else if (endpoint
.addr
.sa_family
== AF_INET6
)
238 ret
= send6(wg
, skb
, &endpoint
, 0, NULL
);
239 /* No other possibilities if the endpoint is valid, which it is,
240 * as we checked above.
246 int wg_socket_endpoint_from_skb(struct endpoint
*endpoint
,
247 const struct sk_buff
*skb
)
249 memset(endpoint
, 0, sizeof(*endpoint
));
250 if (skb
->protocol
== htons(ETH_P_IP
)) {
251 endpoint
->addr4
.sin_family
= AF_INET
;
252 endpoint
->addr4
.sin_port
= udp_hdr(skb
)->source
;
253 endpoint
->addr4
.sin_addr
.s_addr
= ip_hdr(skb
)->saddr
;
254 endpoint
->src4
.s_addr
= ip_hdr(skb
)->daddr
;
255 endpoint
->src_if4
= skb
->skb_iif
;
256 } else if (skb
->protocol
== htons(ETH_P_IPV6
)) {
257 endpoint
->addr6
.sin6_family
= AF_INET6
;
258 endpoint
->addr6
.sin6_port
= udp_hdr(skb
)->source
;
259 endpoint
->addr6
.sin6_addr
= ipv6_hdr(skb
)->saddr
;
260 endpoint
->addr6
.sin6_scope_id
= ipv6_iface_scope_id(
261 &ipv6_hdr(skb
)->saddr
, skb
->skb_iif
);
262 endpoint
->src6
= ipv6_hdr(skb
)->daddr
;
269 static bool endpoint_eq(const struct endpoint
*a
, const struct endpoint
*b
)
271 return (a
->addr
.sa_family
== AF_INET
&& b
->addr
.sa_family
== AF_INET
&&
272 a
->addr4
.sin_port
== b
->addr4
.sin_port
&&
273 a
->addr4
.sin_addr
.s_addr
== b
->addr4
.sin_addr
.s_addr
&&
274 a
->src4
.s_addr
== b
->src4
.s_addr
&& a
->src_if4
== b
->src_if4
) ||
275 (a
->addr
.sa_family
== AF_INET6
&&
276 b
->addr
.sa_family
== AF_INET6
&&
277 a
->addr6
.sin6_port
== b
->addr6
.sin6_port
&&
278 ipv6_addr_equal(&a
->addr6
.sin6_addr
, &b
->addr6
.sin6_addr
) &&
279 a
->addr6
.sin6_scope_id
== b
->addr6
.sin6_scope_id
&&
280 ipv6_addr_equal(&a
->src6
, &b
->src6
)) ||
281 unlikely(!a
->addr
.sa_family
&& !b
->addr
.sa_family
);
284 void wg_socket_set_peer_endpoint(struct wg_peer
*peer
,
285 const struct endpoint
*endpoint
)
287 /* First we check unlocked, in order to optimize, since it's pretty rare
288 * that an endpoint will change. If we happen to be mid-write, and two
289 * CPUs wind up writing the same thing or something slightly different,
290 * it doesn't really matter much either.
292 if (endpoint_eq(endpoint
, &peer
->endpoint
))
294 write_lock_bh(&peer
->endpoint_lock
);
295 if (endpoint
->addr
.sa_family
== AF_INET
) {
296 peer
->endpoint
.addr4
= endpoint
->addr4
;
297 peer
->endpoint
.src4
= endpoint
->src4
;
298 peer
->endpoint
.src_if4
= endpoint
->src_if4
;
299 } else if (endpoint
->addr
.sa_family
== AF_INET6
) {
300 peer
->endpoint
.addr6
= endpoint
->addr6
;
301 peer
->endpoint
.src6
= endpoint
->src6
;
305 dst_cache_reset(&peer
->endpoint_cache
);
307 write_unlock_bh(&peer
->endpoint_lock
);
310 void wg_socket_set_peer_endpoint_from_skb(struct wg_peer
*peer
,
311 const struct sk_buff
*skb
)
313 struct endpoint endpoint
;
315 if (!wg_socket_endpoint_from_skb(&endpoint
, skb
))
316 wg_socket_set_peer_endpoint(peer
, &endpoint
);
319 void wg_socket_clear_peer_endpoint_src(struct wg_peer
*peer
)
321 write_lock_bh(&peer
->endpoint_lock
);
322 memset(&peer
->endpoint
.src6
, 0, sizeof(peer
->endpoint
.src6
));
323 dst_cache_reset(&peer
->endpoint_cache
);
324 write_unlock_bh(&peer
->endpoint_lock
);
327 static int wg_receive(struct sock
*sk
, struct sk_buff
*skb
)
329 struct wg_device
*wg
;
333 wg
= sk
->sk_user_data
;
336 skb_mark_not_on_list(skb
);
337 wg_packet_receive(wg
, skb
);
345 static void sock_free(struct sock
*sock
)
349 sk_clear_memalloc(sock
);
350 udp_tunnel_sock_release(sock
->sk_socket
);
353 static void set_sock_opts(struct socket
*sock
)
355 sock
->sk
->sk_allocation
= GFP_ATOMIC
;
356 sock
->sk
->sk_sndbuf
= INT_MAX
;
357 sk_set_memalloc(sock
->sk
);
360 int wg_socket_init(struct wg_device
*wg
, u16 port
)
363 struct udp_tunnel_sock_cfg cfg
= {
366 .encap_rcv
= wg_receive
368 struct socket
*new4
= NULL
, *new6
= NULL
;
369 struct udp_port_cfg port4
= {
371 .local_ip
.s_addr
= htonl(INADDR_ANY
),
372 .local_udp_port
= htons(port
),
373 .use_udp_checksums
= true
375 #if IS_ENABLED(CONFIG_IPV6)
377 struct udp_port_cfg port6
= {
379 .local_ip6
= IN6ADDR_ANY_INIT
,
380 .use_udp6_tx_checksums
= true,
381 .use_udp6_rx_checksums
= true,
386 #if IS_ENABLED(CONFIG_IPV6)
390 ret
= udp_sock_create(wg
->creating_net
, &port4
, &new4
);
392 pr_err("%s: Could not create IPv4 socket\n", wg
->dev
->name
);
396 setup_udp_tunnel_sock(wg
->creating_net
, new4
, &cfg
);
398 #if IS_ENABLED(CONFIG_IPV6)
399 if (ipv6_mod_enabled()) {
400 port6
.local_udp_port
= inet_sk(new4
->sk
)->inet_sport
;
401 ret
= udp_sock_create(wg
->creating_net
, &port6
, &new6
);
403 udp_tunnel_sock_release(new4
);
404 if (ret
== -EADDRINUSE
&& !port
&& retries
++ < 100)
406 pr_err("%s: Could not create IPv6 socket\n",
411 setup_udp_tunnel_sock(wg
->creating_net
, new6
, &cfg
);
415 wg_socket_reinit(wg
, new4
->sk
, new6
? new6
->sk
: NULL
);
419 void wg_socket_reinit(struct wg_device
*wg
, struct sock
*new4
,
422 struct sock
*old4
, *old6
;
424 mutex_lock(&wg
->socket_update_lock
);
425 old4
= rcu_dereference_protected(wg
->sock4
,
426 lockdep_is_held(&wg
->socket_update_lock
));
427 old6
= rcu_dereference_protected(wg
->sock6
,
428 lockdep_is_held(&wg
->socket_update_lock
));
429 rcu_assign_pointer(wg
->sock4
, new4
);
430 rcu_assign_pointer(wg
->sock6
, new6
);
432 wg
->incoming_port
= ntohs(inet_sk(new4
)->inet_sport
);
433 mutex_unlock(&wg
->socket_update_lock
);