"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
+
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
queue struct {
encryption *outboundQueue
decryption *inboundQueue
- handshake chan QueueHandshakeElement
- }
-
- signals struct {
- stop chan struct{}
+ handshake *handshakeQueue
}
tun struct {
}
ipcMutex sync.RWMutex
+ closed chan struct{}
}
// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
return q
}
+// A handshakeQueue is similar to an outboundQueue; see those docs.
+type handshakeQueue struct {
+ c chan QueueHandshakeElement
+ wg sync.WaitGroup
+}
+
+func newHandshakeQueue() *handshakeQueue {
+ q := &handshakeQueue{
+ c: make(chan QueueHandshakeElement, QueueHandshakeSize),
+ }
+ q.wg.Add(1)
+ go func() {
+ q.wg.Wait()
+ close(q.c)
+ }()
+ return q
+}
+
/* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table.
*
// check if currently under load
now := time.Now()
- underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
+ underLoad := len(device.queue.handshake.c) >= UnderLoadQueueSize
if underLoad {
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
return true
func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
device := new(Device)
+ device.closed = make(chan struct{})
device.log = logger
device.tun.device = tunDevice
mtu, err := device.tun.device.MTU()
// create queues
- device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
+ device.queue.handshake = newHandshakeQueue()
device.queue.encryption = newOutboundQueue()
device.queue.decryption = newInboundQueue()
- // prepare signals
-
- device.signals.stop = make(chan struct{})
-
// prepare net
device.net.port = 0
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}
-func (device *Device) FlushPacketQueues() {
- for {
- select {
- case elem := <-device.queue.handshake:
- device.PutMessageBuffer(elem.buffer)
- default:
- return
- }
- }
-
-}
-
func (device *Device) Close() {
if device.isClosed.Swap(true) {
return
// No new peers are coming; we are done with these queues.
device.queue.encryption.wg.Done()
device.queue.decryption.wg.Done()
- close(device.signals.stop)
+ device.queue.handshake.wg.Done()
device.state.stopping.Wait()
device.RemoveAllPeers()
- device.FlushPacketQueues()
-
device.rate.limiter.Close()
device.state.changing.Set(false)
device.log.Verbosef("Interface closed")
+ close(device.closed)
}
func (device *Device) Wait() chan struct{} {
- return device.signals.stop
+ return device.closed
}
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
device.net.stopping.Add(2)
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
+ device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
elem.endpoint = nil
}
-func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem QueueHandshakeElement) bool {
- select {
- case queue <- elem:
- return true
- default:
- return false
- }
-}
-
/* Called when a new authenticated message has been received
*
* NOTE: Not thread safe, but called by sequential receiver!
defer func() {
device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
device.queue.decryption.wg.Done()
+ device.queue.handshake.wg.Done()
device.net.stopping.Done()
}()
}
if okay {
- if (device.addToHandshakeQueue(
- device.queue.handshake,
- QueueHandshakeElement{
- msgType: msgType,
- buffer: buffer,
- packet: packet,
- endpoint: endpoint,
- },
- )) {
+ select {
+ case device.queue.handshake.c <- QueueHandshakeElement{
+ msgType: msgType,
+ buffer: buffer,
+ packet: packet,
+ endpoint: endpoint,
+ }:
buffer = device.GetMessageBuffer()
+ default:
}
}
}
/* Handles incoming packets related to handshake
*/
func (device *Device) RoutineHandshake() {
- var elem QueueHandshakeElement
- var ok bool
-
defer func() {
device.log.Verbosef("Routine: handshake worker - stopped")
device.state.stopping.Done()
- if elem.buffer != nil {
- device.PutMessageBuffer(elem.buffer)
- }
}()
-
device.log.Verbosef("Routine: handshake worker - started")
- for {
- if elem.buffer != nil {
- device.PutMessageBuffer(elem.buffer)
- elem.buffer = nil
- }
-
- select {
- case elem, ok = <-device.queue.handshake:
- case <-device.signals.stop:
- return
- }
-
- if !ok {
- return
- }
+ for elem := range device.queue.handshake.c {
// handle cookie fields and ratelimiting
err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil {
device.log.Verbosef("Failed to decode cookie reply")
- return
+ goto skip
}
// lookup peer from index
entry := device.indexTable.Lookup(reply.Receiver)
if entry.peer == nil {
- continue
+ goto skip
}
// consume reply
}
}
- continue
+ goto skip
case MessageInitiationType, MessageResponseType:
if !device.cookieChecker.CheckMAC1(elem.packet) {
device.log.Verbosef("Received packet with invalid mac1")
- continue
+ goto skip
}
// endpoints destination address is the source of the datagram
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
device.SendHandshakeCookie(&elem)
- continue
+ goto skip
}
// check ratelimiter
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
- continue
+ goto skip
}
}
default:
device.log.Errorf("Invalid packet ended up in the handshake queue")
- continue
+ goto skip
}
// handle handshake initiation/response content
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
device.log.Errorf("Failed to decode initiation message")
- continue
+ goto skip
}
// consume initiation
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
- continue
+ goto skip
}
// update timers
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
device.log.Errorf("Failed to decode response message")
- continue
+ goto skip
}
// consume response
peer := device.ConsumeMessageResponse(&msg)
if peer == nil {
device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
- continue
+ goto skip
}
// update endpoint
if err != nil {
device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
- continue
+ goto skip
}
peer.timersSessionDerived()
peer.timersHandshakeComplete()
peer.SendKeepalive()
}
+ skip:
+ device.PutMessageBuffer(elem.buffer)
}
}