// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/udp.h>
#include <net/ip_tunnels.h>
/* Must be called with bh disabled. */
static void update_rx_stats(struct wg_peer *peer, size_t len)
{
	struct pcpu_sw_netstats *tstats =
		get_cpu_ptr(peer->device->dev->tstats);

	u64_stats_update_begin(&tstats->syncp);
	++tstats->rx_packets;
	tstats->rx_bytes += len;
	peer->rx_bytes += len;
	u64_stats_update_end(&tstats->syncp);
	put_cpu_ptr(tstats);
}
#define SKB_TYPE_LE32(skb) (((struct message_header *)(skb)->data)->type)
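
/* Returns the expected header length for the message type the packet claims
 * to be, or 0 if the type is unknown or the packet length is not plausible
 * for that type.
 */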
static size_t validate_header_len(struct sk_buff *skb)
{
	if (unlikely(skb->len < sizeof(struct message_header)))
		return 0;
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) &&
	    skb->len >= MESSAGE_MINIMUM_LENGTH)
		return sizeof(struct message_data);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) &&
	    skb->len == sizeof(struct message_handshake_initiation))
		return sizeof(struct message_handshake_initiation);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) &&
	    skb->len == sizeof(struct message_handshake_response))
		return sizeof(struct message_handshake_response);
	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) &&
	    skb->len == sizeof(struct message_handshake_cookie))
		return sizeof(struct message_handshake_cookie);
	return 0;
}
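
/* Validates the outer IP/UDP encapsulation, trims the skb to the length that
 * the UDP header advertises, and leaves skb->data pointing at the WireGuard
 * message header. Returns 0 on success or -EINVAL for malformed packets.
 */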
static int prepare_skb_header(struct sk_buff *skb, struct wg_device *wg)
{
	size_t data_offset, data_len, header_len;
	struct udphdr *udp;

	if (unlikely(!wg_check_packet_protocol(skb) ||
		     skb_transport_header(skb) < skb->head ||
		     (skb_transport_header(skb) + sizeof(struct udphdr)) >
			     skb_tail_pointer(skb)))
		return -EINVAL; /* Bogus IP header */
	udp = udp_hdr(skb);
	data_offset = (u8 *)udp - skb->data;
	if (unlikely(data_offset > U16_MAX ||
		     data_offset + sizeof(struct udphdr) > skb->len))
		/* Packet has offset at impossible location or isn't big enough
		 * to have UDP fields.
		 */
		return -EINVAL;
	data_len = ntohs(udp->len);
	if (unlikely(data_len < sizeof(struct udphdr) ||
		     data_len > skb->len - data_offset))
		/* UDP packet is reporting too small of a size or lying about
		 * its size.
		 */
		return -EINVAL;
	data_len -= sizeof(struct udphdr);
	data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data;
	if (unlikely(!pskb_may_pull(skb,
				data_offset + sizeof(struct message_header)) ||
		     pskb_trim(skb, data_len + data_offset) < 0))
		return -EINVAL;
	skb_pull(skb, data_offset);
	if (unlikely(skb->len != data_len))
		/* Final len does not agree with calculated len */
		return -EINVAL;
	header_len = validate_header_len(skb);
	if (unlikely(!header_len))
		return -EINVAL;
	__skb_push(skb, data_offset);
	if (unlikely(!pskb_may_pull(skb, data_offset + header_len)))
		return -EINVAL;
	__skb_pull(skb, data_offset);
	return 0;
}
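
/* Processes one queued handshake packet from workqueue (process) context:
 * checks the MAC/cookie state, replies with a cookie when under load, and
 * otherwise hands the message to the Noise handshake machinery.
 */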
static void wg_receive_handshake_packet(struct wg_device *wg,
					struct sk_buff *skb)
{
	enum cookie_mac_state mac_state;
	struct wg_peer *peer = NULL;
	/* This is global, so that our load calculation applies to the whole
	 * system. We don't care about races with it at all.
	 */
	static u64 last_under_load;
	bool packet_needs_cookie;
	bool under_load;

	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE)) {
		net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n",
					wg->dev->name, skb);
		wg_cookie_message_consume(
			(struct message_handshake_cookie *)skb->data, wg);
		return;
	}

	under_load = skb_queue_len(&wg->incoming_handshakes) >=
		     MAX_QUEUED_INCOMING_HANDSHAKES / 8;
	if (under_load) {
		last_under_load = ktime_get_coarse_boottime_ns();
	} else if (last_under_load) {
		under_load = !wg_birthdate_has_expired(last_under_load, 1);
		if (!under_load)
			last_under_load = 0;
	}
	mac_state = wg_cookie_validate_packet(&wg->cookie_checker, skb,
					      under_load);
	if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) ||
	    (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) {
		packet_needs_cookie = false;
	} else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE) {
		packet_needs_cookie = true;
	} else {
		net_dbg_skb_ratelimited("%s: Invalid MAC of handshake, dropping packet from %pISpfsc\n",
					wg->dev->name, skb);
		return;
	}

	switch (SKB_TYPE_LE32(skb)) {
	case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): {
		struct message_handshake_initiation *message =
			(struct message_handshake_initiation *)skb->data;

		if (packet_needs_cookie) {
			wg_packet_send_handshake_cookie(wg, skb,
							message->sender_index);
			return;
		}
		peer = wg_noise_handshake_consume_initiation(message, wg);
		if (unlikely(!peer)) {
			net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n",
						wg->dev->name, skb);
			return;
		}
		wg_socket_set_peer_endpoint_from_skb(peer, skb);
		net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n",
				    wg->dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		wg_packet_send_handshake_response(peer);
		break;
	}
	case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): {
		struct message_handshake_response *message =
			(struct message_handshake_response *)skb->data;

		if (packet_needs_cookie) {
			wg_packet_send_handshake_cookie(wg, skb,
							message->sender_index);
			return;
		}
		peer = wg_noise_handshake_consume_response(message, wg);
		if (unlikely(!peer)) {
			net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n",
						wg->dev->name, skb);
			return;
		}
		wg_socket_set_peer_endpoint_from_skb(peer, skb);
		net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n",
				    wg->dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		if (wg_noise_handshake_begin_session(&peer->handshake,
						     &peer->keypairs)) {
			wg_timers_session_derived(peer);
			wg_timers_handshake_complete(peer);
			/* Calling this function will either send any existing
			 * packets in the queue and not send a keepalive, which
			 * is the best case, or, if there's nothing in the
			 * queue, it will send a keepalive, in order to give
			 * immediate confirmation of the session.
			 */
			wg_packet_send_keepalive(peer);
		}
		break;
	}
	}

	if (unlikely(!peer)) {
		WARN(1, "Somehow a wrong type of packet wound up in the handshake queue!\n");
		return;
	}

	local_bh_disable();
	update_rx_stats(peer, skb->len);
	local_bh_enable();

	wg_timers_any_authenticated_packet_received(peer);
	wg_timers_any_authenticated_packet_traversal(peer);
	wg_peer_put(peer);
}
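
/* Workqueue handler that drains the device-wide queue of received handshake
 * packets, processing and freeing them one at a time.
 */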
void wg_packet_handshake_receive_worker(struct work_struct *work)
{
	struct wg_device *wg = container_of(work, struct multicore_worker,
					    work)->ptr;
	struct sk_buff *skb;

	while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) {
		wg_receive_handshake_packet(wg, skb);
		dev_kfree_skb(skb);
		cond_resched();
	}
}
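
/* If we are the initiator of the current keypair and it is close to expiring,
 * opportunistically send a new handshake initiation so the session can be
 * rekeyed before REJECT_AFTER_TIME elapses. This is attempted at most once
 * per keypair, tracked by sent_lastminute_handshake.
 */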
static void keep_key_fresh(struct wg_peer *peer)
{
	struct noise_keypair *keypair;
	bool send = false;

	if (peer->sent_lastminute_handshake)
		return;

	rcu_read_lock_bh();
	keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
	if (likely(keypair && READ_ONCE(keypair->sending.is_valid)) &&
	    keypair->i_am_the_initiator &&
	    unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
			REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT)))
		send = true;
	rcu_read_unlock_bh();

	if (send) {
		peer->sent_lastminute_handshake = true;
		wg_packet_send_queued_handshake_initiation(peer, false);
	}
}
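
/* Decrypts and authenticates a data packet in place with ChaCha20-Poly1305,
 * using the counter from the message header as the nonce. Returns false when
 * the key is missing, expired, or exhausted, or when authentication fails;
 * the caller then marks the packet PACKET_STATE_DEAD.
 */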
static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
{
	struct scatterlist sg[MAX_SKB_FRAGS + 8];
	struct sk_buff *trailer;
	unsigned int offset;
	int num_frags;

	if (unlikely(!key))
		return false;

	if (unlikely(!READ_ONCE(key->is_valid) ||
		  wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
		  key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
		WRITE_ONCE(key->is_valid, false);
		return false;
	}

	PACKET_CB(skb)->nonce =
		le64_to_cpu(((struct message_data *)skb->data)->counter);

	/* We ensure that the network header is part of the packet before we
	 * call skb_cow_data, so that there's no chance that data is removed
	 * from the skb, so that later we can extract the original endpoint.
	 */
	offset = skb->data - skb_network_header(skb);
	skb_push(skb, offset);
	num_frags = skb_cow_data(skb, 0, &trailer);
	offset += sizeof(struct message_data);
	skb_pull(skb, offset);
	if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
		return false;

	sg_init_table(sg, num_frags);
	if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0)
		return false;

	if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
						 PACKET_CB(skb)->nonce,
						 key->key))
		return false;

	/* Another ugly situation of pushing and pulling the header so as to
	 * keep endpoint information intact.
	 */
	skb_push(skb, offset);
	if (pskb_trim(skb, skb->len - noise_encrypted_len(0)))
		return false;
	skb_pull(skb, offset);

	return true;
}
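
/* Replay protection: counter_validate() below keeps a bitmap window of the
 * most recent COUNTER_WINDOW_SIZE nonces behind the highest counter seen so
 * far. A nonce is accepted only if it is not older than that window and its
 * bit has not been set yet; e.g. after counter 1000 has been received,
 * counter 995 is still accepted exactly once, while anything older than the
 * window, or any exact replay, is rejected.
 */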
/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
static bool counter_validate(union noise_counter *counter, u64 their_counter)
{
	unsigned long index, index_current, top, i;
	bool ret = false;

	spin_lock_bh(&counter->receive.lock);

	if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
		     their_counter >= REJECT_AFTER_MESSAGES))
		goto out;

	++their_counter;

	if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
		     counter->receive.counter))
		goto out;

	index = their_counter >> ilog2(BITS_PER_LONG);

	if (likely(their_counter > counter->receive.counter)) {
		index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
		top = min_t(unsigned long, index - index_current,
			    COUNTER_BITS_TOTAL / BITS_PER_LONG);
		for (i = 1; i <= top; ++i)
			counter->receive.backtrack[(i + index_current) &
				((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
		counter->receive.counter = their_counter;
	}

	index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
	ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
				&counter->receive.backtrack[index]);

out:
	spin_unlock_bh(&counter->receive.lock);
	return ret;
}
#include "selftest/counter.c"
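
/* Called from the peer's NAPI poll handler once a data packet has been
 * decrypted and its nonce accepted: updates endpoint and timer state, checks
 * the inner IP header against the peer's allowed IPs (cryptokey routing),
 * and hands the packet to the network stack via GRO.
 */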
static void wg_packet_consume_data_done(struct wg_peer *peer,
					struct sk_buff *skb,
					struct endpoint *endpoint)
{
	struct net_device *dev = peer->device->dev;
	unsigned int len, len_before_trim;
	struct wg_peer *routed_peer;

	wg_socket_set_peer_endpoint(peer, endpoint);

	if (unlikely(wg_noise_received_with_keypair(&peer->keypairs,
						    PACKET_CB(skb)->keypair))) {
		wg_timers_handshake_complete(peer);
		wg_packet_send_staged_packets(peer);
	}

	keep_key_fresh(peer);

	wg_timers_any_authenticated_packet_received(peer);
	wg_timers_any_authenticated_packet_traversal(peer);

	/* A packet with length 0 is a keepalive packet */
	if (unlikely(!skb->len)) {
		update_rx_stats(peer, message_data_len(0));
		net_dbg_ratelimited("%s: Receiving keepalive packet from peer %llu (%pISpfsc)\n",
				    dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		goto packet_processed;
	}

	wg_timers_data_received(peer);

	if (unlikely(skb_network_header(skb) < skb->head))
		goto dishonest_packet_size;
	if (unlikely(!(pskb_network_may_pull(skb, sizeof(struct iphdr)) &&
		       (ip_hdr(skb)->version == 4 ||
			(ip_hdr(skb)->version == 6 &&
			 pskb_network_may_pull(skb, sizeof(struct ipv6hdr)))))))
		goto dishonest_packet_type;

	skb->dev = dev;
	/* We've already verified the Poly1305 auth tag, which means this packet
	 * was not modified in transit. We can therefore tell the networking
	 * stack that all checksums of every layer of encapsulation have already
	 * been checked "by the hardware" and therefore is unnecessary to check
	 * again in software.
	 */
	skb->ip_summed = CHECKSUM_UNNECESSARY;
	skb->csum_level = ~0; /* All levels */
	skb->protocol = wg_examine_packet_protocol(skb);
	if (skb->protocol == htons(ETH_P_IP)) {
		len = ntohs(ip_hdr(skb)->tot_len);
		if (unlikely(len < sizeof(struct iphdr)))
			goto dishonest_packet_size;
		if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
			IP_ECN_set_ce(ip_hdr(skb));
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		len = ntohs(ipv6_hdr(skb)->payload_len) +
		      sizeof(struct ipv6hdr);
		if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
			IP6_ECN_set_ce(skb, ipv6_hdr(skb));
	} else {
		goto dishonest_packet_type;
	}

	if (unlikely(len > skb->len))
		goto dishonest_packet_size;
	len_before_trim = skb->len;
	if (unlikely(pskb_trim(skb, len)))
		goto packet_processed;

	routed_peer = wg_allowedips_lookup_src(&peer->device->peer_allowedips,
					       skb);
	wg_peer_put(routed_peer); /* We don't need the extra reference. */

	if (unlikely(routed_peer != peer))
		goto dishonest_packet_peer;

	if (unlikely(napi_gro_receive(&peer->napi, skb) == GRO_DROP)) {
		++dev->stats.rx_dropped;
		net_dbg_ratelimited("%s: Failed to give packet to userspace from peer %llu (%pISpfsc)\n",
				    dev->name, peer->internal_id,
				    &peer->endpoint.addr);
	} else {
		update_rx_stats(peer, message_data_len(len_before_trim));
	}
	return;

dishonest_packet_peer:
	net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n",
				dev->name, skb, peer->internal_id,
				&peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_frame_errors;
	goto packet_processed;
dishonest_packet_type:
	net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n",
			    dev->name, peer->internal_id, &peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_frame_errors;
	goto packet_processed;
dishonest_packet_size:
	net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n",
			    dev->name, peer->internal_id, &peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_length_errors;
	goto packet_processed;
packet_processed:
	dev_kfree_skb(skb);
}
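
/* NAPI poll handler for a single peer. Packets are taken from the peer's rx
 * ring strictly in order, but only after the decryption workers have moved
 * them past PACKET_STATE_UNCRYPTED, so in-order delivery is preserved even
 * though decryption itself runs in parallel on multiple CPUs.
 */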
int wg_packet_rx_poll(struct napi_struct *napi, int budget)
{
	struct wg_peer *peer = container_of(napi, struct wg_peer, napi);
	struct crypt_queue *queue = &peer->rx_queue;
	struct noise_keypair *keypair;
	struct endpoint endpoint;
	enum packet_state state;
	struct sk_buff *skb;
	int work_done = 0;
	bool free;

	if (unlikely(budget <= 0))
		return 0;

	while ((skb = __ptr_ring_peek(&queue->ring)) != NULL &&
	       (state = atomic_read_acquire(&PACKET_CB(skb)->state)) !=
		       PACKET_STATE_UNCRYPTED) {
		__ptr_ring_discard_one(&queue->ring);
		peer = PACKET_PEER(skb);
		keypair = PACKET_CB(skb)->keypair;
		free = true;

		if (unlikely(state != PACKET_STATE_CRYPTED))
			goto next;

		if (unlikely(!counter_validate(&keypair->receiving.counter,
					       PACKET_CB(skb)->nonce))) {
			net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
					    peer->device->dev->name,
					    PACKET_CB(skb)->nonce,
					    keypair->receiving.counter.receive.counter);
			goto next;
		}

		if (unlikely(wg_socket_endpoint_from_skb(&endpoint, skb)))
			goto next;

		wg_reset_packet(skb);
		wg_packet_consume_data_done(peer, skb, &endpoint);
		free = false;

next:
		wg_noise_keypair_put(keypair, false);
		wg_peer_put(peer);
		if (unlikely(free))
			dev_kfree_skb(skb);

		if (++work_done >= budget)
			break;
	}

	if (work_done < budget)
		napi_complete_done(napi, work_done);

	return work_done;
}
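
/* Workqueue handler, one instance per CPU, that consumes packets from the
 * device-wide decrypt ring, decrypts them, and signals the owning peer's
 * NAPI instance by publishing the resulting packet state.
 */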
void wg_packet_decrypt_worker(struct work_struct *work)
{
	struct crypt_queue *queue = container_of(work, struct multicore_worker,
						 work)->ptr;
	struct sk_buff *skb;

	while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
		enum packet_state state =
			likely(decrypt_packet(skb,
					&PACKET_CB(skb)->keypair->receiving)) ?
				PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
		wg_queue_enqueue_per_peer_napi(skb, state);
	}
}
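
/* Looks up the receiving keypair from the key index carried in the data
 * message, takes references on the keypair and its peer, and enqueues the
 * packet on both the device-wide decrypt queue and the peer's rx queue.
 */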
static void wg_packet_consume_data(struct wg_device *wg, struct sk_buff *skb)
{
	__le32 idx = ((struct message_data *)skb->data)->key_idx;
	struct wg_peer *peer = NULL;
	int ret;

	rcu_read_lock_bh();
	PACKET_CB(skb)->keypair =
		(struct noise_keypair *)wg_index_hashtable_lookup(
			wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx,
			&peer);
	if (unlikely(!wg_noise_keypair_get(PACKET_CB(skb)->keypair)))
		goto err_keypair;

	if (unlikely(READ_ONCE(peer->is_dead)))
		goto err;

	ret = wg_queue_enqueue_per_device_and_peer(&wg->decrypt_queue,
						   &peer->rx_queue, skb,
						   wg->packet_crypt_wq,
						   &wg->decrypt_queue.last_cpu);
	if (unlikely(ret == -EPIPE))
		wg_queue_enqueue_per_peer_napi(skb, PACKET_STATE_DEAD);
	if (likely(!ret || ret == -EPIPE)) {
		rcu_read_unlock_bh();
		return;
	}
err:
	wg_noise_keypair_put(PACKET_CB(skb)->keypair, false);
err_keypair:
	rcu_read_unlock_bh();
	wg_peer_put(peer);
	dev_kfree_skb(skb);
}
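
/* Entry point from the UDP socket for every WireGuard datagram. Handshake
 * messages are queued for the handshake workers (or dropped when that queue
 * is full or the RNG is not yet initialized), while data messages are passed
 * directly to the decryption path.
 */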
void wg_packet_receive(struct wg_device *wg, struct sk_buff *skb)
{
	if (unlikely(prepare_skb_header(skb, wg) < 0))
		goto err;
	switch (SKB_TYPE_LE32(skb)) {
	case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION):
	case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE):
	case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): {
		int cpu;

		if (skb_queue_len(&wg->incoming_handshakes) >
			    MAX_QUEUED_INCOMING_HANDSHAKES ||
		    unlikely(!rng_is_initialized())) {
			net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n",
						wg->dev->name, skb);
			goto err;
		}
		skb_queue_tail(&wg->incoming_handshakes, skb);
		/* Queues up a call to packet_process_queued_handshake_
		 * packets(skb):
		 */
		cpu = wg_cpumask_next_online(&wg->incoming_handshake_cpu);
		queue_work_on(cpu, wg->handshake_receive_wq,
			      &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work);
		break;
	}
	case cpu_to_le32(MESSAGE_DATA):
		PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb);
		wg_packet_consume_data(wg, skb);
		break;
	default:
		WARN(1, "Non-exhaustive parsing of packet header lead to unknown packet type!\n");
		goto err;
	}
	return;

err:
	dev_kfree_skb(skb);
}