// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "device.h"
#include "peer.h"
#include "socket.h"
#include "queueing.h"
#include "messages.h"

#include <linux/ctype.h>
#include <linux/net.h>
#include <linux/if_vlan.h>
#include <linux/if_ether.h>
#include <linux/inetdevice.h>
#include <net/udp_tunnel.h>
#include <net/ipv6.h>

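/* Transmit skb to the peer's IPv4 endpoint. The route is looked up (or taken
 * from the per-peer dst cache when one is supplied), re-resolved if the
 * sticky source address or interface has become invalid, and the packet is
 * handed to the UDP tunnel layer. The skb is consumed on both success and
 * error.
 */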
static int send4(struct wg_device *wg, struct sk_buff *skb,
                 struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
{
        struct flowi4 fl = {
                .saddr = endpoint->src4.s_addr,
                .daddr = endpoint->addr4.sin_addr.s_addr,
                .fl4_dport = endpoint->addr4.sin_port,
                .flowi4_mark = wg->fwmark,
                .flowi4_proto = IPPROTO_UDP
        };
        struct rtable *rt = NULL;
        struct sock *sock;
        int ret = 0;

        skb_mark_not_on_list(skb);
        skb->dev = wg->dev;
        skb->mark = wg->fwmark;

        rcu_read_lock_bh();
        sock = rcu_dereference_bh(wg->sock4);

        if (unlikely(!sock)) {
                ret = -ENONET;
                goto err;
        }

        fl.fl4_sport = inet_sk(sock)->inet_sport;

        if (cache)
                rt = dst_cache_get_ip4(cache, &fl.saddr);

        if (!rt) {
                security_sk_classify_flow(sock, flowi4_to_flowi(&fl));
                if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0,
                                                fl.saddr, RT_SCOPE_HOST))) {
                        endpoint->src4.s_addr = 0;
                        *(__force __be32 *)&endpoint->src_if4 = 0;
                        fl.saddr = 0;
                        if (cache)
                                dst_cache_reset(cache);
                }
                rt = ip_route_output_flow(sock_net(sock), &fl, sock);
                if (unlikely(endpoint->src_if4 && ((IS_ERR(rt) &&
                             PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) &&
                             rt->dst.dev->ifindex != endpoint->src_if4)))) {
                        endpoint->src4.s_addr = 0;
                        *(__force __be32 *)&endpoint->src_if4 = 0;
                        fl.saddr = 0;
                        if (cache)
                                dst_cache_reset(cache);
                        if (!IS_ERR(rt))
                                ip_rt_put(rt);
                        rt = ip_route_output_flow(sock_net(sock), &fl, sock);
                }
                if (unlikely(IS_ERR(rt))) {
                        ret = PTR_ERR(rt);
                        net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n",
                                            wg->dev->name, &endpoint->addr, ret);
                        goto err;
                }
                if (cache)
                        dst_cache_set_ip4(cache, &rt->dst, fl.saddr);
        }

        skb->ignore_df = 1;
        udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds,
                            ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport,
                            fl.fl4_dport, false, false);
        goto out;

err:
        kfree_skb(skb);
out:
        rcu_read_unlock_bh();
        return ret;
}

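/* IPv6 counterpart of send4(), with the same caching and source-address
 * revalidation logic; compiles down to a stub returning -EAFNOSUPPORT when
 * CONFIG_IPV6 is disabled.
 */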
static int send6(struct wg_device *wg, struct sk_buff *skb,
                 struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
{
#if IS_ENABLED(CONFIG_IPV6)
        struct flowi6 fl = {
                .saddr = endpoint->src6,
                .daddr = endpoint->addr6.sin6_addr,
                .fl6_dport = endpoint->addr6.sin6_port,
                .flowi6_mark = wg->fwmark,
                .flowi6_oif = endpoint->addr6.sin6_scope_id,
                .flowi6_proto = IPPROTO_UDP
                /* TODO: addr->sin6_flowinfo */
        };
        struct dst_entry *dst = NULL;
        struct sock *sock;
        int ret = 0;

        skb_mark_not_on_list(skb);
        skb->dev = wg->dev;
        skb->mark = wg->fwmark;

        rcu_read_lock_bh();
        sock = rcu_dereference_bh(wg->sock6);

        if (unlikely(!sock)) {
                ret = -ENONET;
                goto err;
        }

        fl.fl6_sport = inet_sk(sock)->inet_sport;

        if (cache)
                dst = dst_cache_get_ip6(cache, &fl.saddr);

        if (!dst) {
                security_sk_classify_flow(sock, flowi6_to_flowi(&fl));
                if (unlikely(!ipv6_addr_any(&fl.saddr) &&
                             !ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) {
                        endpoint->src6 = fl.saddr = in6addr_any;
                        if (cache)
                                dst_cache_reset(cache);
                }
                dst = ipv6_stub->ipv6_dst_lookup_flow(sock_net(sock), sock, &fl,
                                                      NULL);
                if (unlikely(IS_ERR(dst))) {
                        ret = PTR_ERR(dst);
                        net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n",
                                            wg->dev->name, &endpoint->addr, ret);
                        goto err;
                }
                if (cache)
                        dst_cache_set_ip6(cache, dst, &fl.saddr);
        }

        skb->ignore_df = 1;
        udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds,
                             ip6_dst_hoplimit(dst), 0, fl.fl6_sport,
                             fl.fl6_dport, false);
        goto out;

err:
        kfree_skb(skb);
out:
        rcu_read_unlock_bh();
        return ret;
#else
        /* IPv6 is compiled out: drop the skb rather than leak it. */
        kfree_skb(skb);
        return -EAFNOSUPPORT;
#endif
}

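/* Send an skb to the peer's current endpoint, dispatching on address family
 * under the endpoint read lock. tx_bytes is credited only if the send
 * succeeded.
 */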
int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds)
{
        size_t skb_len = skb->len;
        int ret = -EAFNOSUPPORT;

        read_lock_bh(&peer->endpoint_lock);
        if (peer->endpoint.addr.sa_family == AF_INET)
                ret = send4(peer->device, skb, &peer->endpoint, ds,
                            &peer->endpoint_cache);
        else if (peer->endpoint.addr.sa_family == AF_INET6)
                ret = send6(peer->device, skb, &peer->endpoint, ds,
                            &peer->endpoint_cache);
        else
                dev_kfree_skb(skb);
        if (likely(!ret))
                peer->tx_bytes += skb_len;
        read_unlock_bh(&peer->endpoint_lock);

        return ret;
}

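/* Convenience wrapper that copies a flat buffer into a newly allocated skb
 * before sending; used for control messages (e.g. handshakes) that are built
 * outside the normal skb path.
 */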
int wg_socket_send_buffer_to_peer(struct wg_peer *peer, void *buffer,
                                  size_t len, u8 ds)
{
        struct sk_buff *skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC);

        if (unlikely(!skb))
                return -ENOMEM;

        skb_reserve(skb, SKB_HEADER_LEN);
        skb_set_inner_network_header(skb, 0);
        skb_put_data(skb, buffer, len);
        return wg_socket_send_skb_to_peer(peer, skb, ds);
}

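/* Reply to the transient source of in_skb rather than to a configured peer
 * endpoint; since the endpoint is short-lived, no dst cache is passed. This
 * is how cookie replies reach not-yet-authenticated senders.
 */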
int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg,
                                          struct sk_buff *in_skb, void *buffer,
                                          size_t len)
{
        int ret = 0;
        struct sk_buff *skb;
        struct endpoint endpoint;

        if (unlikely(!in_skb))
                return -EINVAL;
        ret = wg_socket_endpoint_from_skb(&endpoint, in_skb);
        if (unlikely(ret < 0))
                return ret;

        skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC);
        if (unlikely(!skb))
                return -ENOMEM;
        skb_reserve(skb, SKB_HEADER_LEN);
        skb_set_inner_network_header(skb, 0);
        skb_put_data(skb, buffer, len);

        if (endpoint.addr.sa_family == AF_INET)
                ret = send4(wg, skb, &endpoint, 0, NULL);
        else if (endpoint.addr.sa_family == AF_INET6)
                ret = send6(wg, skb, &endpoint, 0, NULL);
        /* No other possibilities if the endpoint is valid, which it is,
         * as we checked above.
         */
        return ret;
}

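/* Extract an endpoint from a received packet: the remote address/port it
 * came from, plus the local address and (for IPv4) the interface it arrived
 * on, so replies leave the way the packet came in.
 */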
int wg_socket_endpoint_from_skb(struct endpoint *endpoint,
                                const struct sk_buff *skb)
{
        memset(endpoint, 0, sizeof(*endpoint));
        if (skb->protocol == htons(ETH_P_IP)) {
                endpoint->addr4.sin_family = AF_INET;
                endpoint->addr4.sin_port = udp_hdr(skb)->source;
                endpoint->addr4.sin_addr.s_addr = ip_hdr(skb)->saddr;
                endpoint->src4.s_addr = ip_hdr(skb)->daddr;
                endpoint->src_if4 = skb->skb_iif;
        } else if (skb->protocol == htons(ETH_P_IPV6)) {
                endpoint->addr6.sin6_family = AF_INET6;
                endpoint->addr6.sin6_port = udp_hdr(skb)->source;
                endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr;
                endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id(
                        &ipv6_hdr(skb)->saddr, skb->skb_iif);
                endpoint->src6 = ipv6_hdr(skb)->daddr;
        } else {
                return -EPROTONOSUPPORT;
        }
        return 0;
}

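/* Endpoint equality includes the sticky source address and interface, so a
 * change in our local binding also counts as an endpoint change. Two
 * zero-family (unset) endpoints compare equal.
 */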
static bool endpoint_eq(const struct endpoint *a, const struct endpoint *b)
{
        return (a->addr.sa_family == AF_INET && b->addr.sa_family == AF_INET &&
                a->addr4.sin_port == b->addr4.sin_port &&
                a->addr4.sin_addr.s_addr == b->addr4.sin_addr.s_addr &&
                a->src4.s_addr == b->src4.s_addr && a->src_if4 == b->src_if4) ||
               (a->addr.sa_family == AF_INET6 &&
                b->addr.sa_family == AF_INET6 &&
                a->addr6.sin6_port == b->addr6.sin6_port &&
                ipv6_addr_equal(&a->addr6.sin6_addr, &b->addr6.sin6_addr) &&
                a->addr6.sin6_scope_id == b->addr6.sin6_scope_id &&
                ipv6_addr_equal(&a->src6, &b->src6)) ||
               unlikely(!a->addr.sa_family && !b->addr.sa_family);
}

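/* Adopt a new endpoint for the peer (roaming), resetting the dst cache so
 * the next transmission re-routes.
 */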
void wg_socket_set_peer_endpoint(struct wg_peer *peer,
                                 const struct endpoint *endpoint)
{
        /* First we check unlocked, in order to optimize, since it's pretty rare
         * that an endpoint will change. If we happen to be mid-write, and two
         * CPUs wind up writing the same thing or something slightly different,
         * it doesn't really matter much either.
         */
        if (endpoint_eq(endpoint, &peer->endpoint))
                return;
        write_lock_bh(&peer->endpoint_lock);
        if (endpoint->addr.sa_family == AF_INET) {
                peer->endpoint.addr4 = endpoint->addr4;
                peer->endpoint.src4 = endpoint->src4;
                peer->endpoint.src_if4 = endpoint->src_if4;
        } else if (endpoint->addr.sa_family == AF_INET6) {
                peer->endpoint.addr6 = endpoint->addr6;
                peer->endpoint.src6 = endpoint->src6;
        } else {
                goto out;
        }
        dst_cache_reset(&peer->endpoint_cache);
out:
        write_unlock_bh(&peer->endpoint_lock);
}

void wg_socket_set_peer_endpoint_from_skb(struct wg_peer *peer,
                                          const struct sk_buff *skb)
{
        struct endpoint endpoint;

        if (!wg_socket_endpoint_from_skb(&endpoint, skb))
                wg_socket_set_peer_endpoint(peer, &endpoint);
}

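/* Zero the peer's sticky source address so the next send re-selects one.
 * Note: src6 overlays src4/src_if4 in the endpoint union (see socket.h), so
 * clearing src6 clears the IPv4 fields too.
 */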
void wg_socket_clear_peer_endpoint_src(struct wg_peer *peer)
{
        write_lock_bh(&peer->endpoint_lock);
        memset(&peer->endpoint.src6, 0, sizeof(peer->endpoint.src6));
        dst_cache_reset(&peer->endpoint_cache);
        write_unlock_bh(&peer->endpoint_lock);
}

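/* encap_rcv handler installed on our UDP sockets: every datagram arriving
 * on the WireGuard port is queued to the device's receive path. Always
 * returns 0, since the skb is consumed either way.
 */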
static int wg_receive(struct sock *sk, struct sk_buff *skb)
{
        struct wg_device *wg;

        if (unlikely(!sk))
                goto err;
        wg = sk->sk_user_data;
        if (unlikely(!wg))
                goto err;
        skb_mark_not_on_list(skb);
        wg_packet_receive(wg, skb);
        return 0;

err:
        kfree_skb(skb);
        return 0;
}

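/* Socket lifetime helpers. The sockets are marked memalloc so that sends
 * can make progress under memory pressure, and that mark has to be cleared
 * again before release.
 */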
static void sock_free(struct sock *sock)
{
        if (unlikely(!sock))
                return;
        sk_clear_memalloc(sock);
        udp_tunnel_sock_release(sock->sk_socket);
}

static void set_sock_opts(struct socket *sock)
{
        sock->sk->sk_allocation = GFP_ATOMIC;
        sock->sk->sk_sndbuf = INT_MAX;
        sk_set_memalloc(sock->sk);
}

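/* Create the device's underlying IPv4 and IPv6 UDP sockets in its creating
 * namespace. Both must bind the same port; when port is 0, the kernel picks
 * an ephemeral port for IPv4, and an IPv6 collision on that port triggers a
 * retry with a fresh port, up to 100 times.
 */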
int wg_socket_init(struct wg_device *wg, u16 port)
{
        int ret;
        struct udp_tunnel_sock_cfg cfg = {
                .sk_user_data = wg,
                .encap_type = 1,
                .encap_rcv = wg_receive
        };
        struct socket *new4 = NULL, *new6 = NULL;
        struct udp_port_cfg port4 = {
                .family = AF_INET,
                .local_ip.s_addr = htonl(INADDR_ANY),
                .local_udp_port = htons(port),
                .use_udp_checksums = true
        };
#if IS_ENABLED(CONFIG_IPV6)
        int retries = 0;
        struct udp_port_cfg port6 = {
                .family = AF_INET6,
                .local_ip6 = IN6ADDR_ANY_INIT,
                .use_udp6_tx_checksums = true,
                .use_udp6_rx_checksums = true,
                .ipv6_v6only = true
        };
#endif

#if IS_ENABLED(CONFIG_IPV6)
retry:
#endif

        ret = udp_sock_create(wg->creating_net, &port4, &new4);
        if (ret < 0) {
                pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
                return ret;
        }
        set_sock_opts(new4);
        setup_udp_tunnel_sock(wg->creating_net, new4, &cfg);

#if IS_ENABLED(CONFIG_IPV6)
        if (ipv6_mod_enabled()) {
                port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
                ret = udp_sock_create(wg->creating_net, &port6, &new6);
                if (ret < 0) {
                        udp_tunnel_sock_release(new4);
                        if (ret == -EADDRINUSE && !port && retries++ < 100)
                                goto retry;
                        pr_err("%s: Could not create IPv6 socket\n",
                               wg->dev->name);
                        return ret;
                }
                set_sock_opts(new6);
                setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
        }
#endif

        wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
        return 0;
}

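/* Publish new underlying sockets via RCU under the update mutex, then wait
 * for a grace period before freeing the old ones, so that senders still
 * inside rcu_read_lock_bh() sections never touch freed sockets.
 */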
void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
                      struct sock *new6)
{
        struct sock *old4, *old6;

        mutex_lock(&wg->socket_update_lock);
        old4 = rcu_dereference_protected(wg->sock4,
                                lockdep_is_held(&wg->socket_update_lock));
        old6 = rcu_dereference_protected(wg->sock6,
                                lockdep_is_held(&wg->socket_update_lock));
        rcu_assign_pointer(wg->sock4, new4);
        rcu_assign_pointer(wg->sock6, new6);
        if (new4)
                wg->incoming_port = ntohs(inet_sk(new4)->inet_sport);
        mutex_unlock(&wg->socket_update_lock);
        synchronize_rcu();
        synchronize_net();
        sock_free(old4);
        sock_free(old6);
}