]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Improved readability of send/receive code
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 9 Sep 2017 13:03:01 +0000 (15:03 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 9 Sep 2017 13:03:01 +0000 (15:03 +0200)
src/receive.go
src/send.go

index 97646d88712594148f34c5234ffcef92c1a0b3ba..09fca77ec5c4351298c23885b5d8c6006b7b0df6 100644 (file)
@@ -128,7 +128,7 @@ func (device *Device) RoutineReceiveIncomming() {
 
                                // read next datagram
 
-                               size, raddr, err := conn.ReadFromUDP(buffer[:]) // Blocks sometimes
+                               size, raddr, err := conn.ReadFromUDP(buffer[:])
 
                                if err != nil {
                                        break
@@ -222,7 +222,7 @@ func (device *Device) RoutineReceiveIncomming() {
 }
 
 func (device *Device) RoutineDecryption() {
-       var elem *QueueInboundElement
+
        var nonce [chacha20poly1305.NonceSize]byte
 
        logDebug := device.log.Debug
@@ -230,50 +230,51 @@ func (device *Device) RoutineDecryption() {
 
        for {
                select {
-               case elem = <-device.queue.decryption:
                case <-device.signal.stop:
+                       logDebug.Println("Routine, decryption worker, stopped")
                        return
-               }
 
-               // check if dropped
+               case elem := <-device.queue.decryption:
 
-               if elem.IsDropped() {
-                       continue
-               }
+                       // check if dropped
 
-               // split message into fields
-
-               counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
-               content := elem.packet[MessageTransportOffsetContent:]
-
-               // decrypt with key-pair
-
-               var err error
-               copy(nonce[4:], counter)
-               elem.counter = binary.LittleEndian.Uint64(counter)
-               elem.keyPair.receive.mutex.RLock()
-               if elem.keyPair.receive.aead == nil {
-                       // very unlikely (the key was deleted during queuing)
-                       elem.Drop()
-               } else {
-                       elem.packet, err = elem.keyPair.receive.aead.Open(
-                               elem.buffer[:0],
-                               nonce[:],
-                               content,
-                               nil,
-                       )
-                       if err != nil {
+                       if elem.IsDropped() {
+                               continue
+                       }
+
+                       // split message into fields
+
+                       counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
+                       content := elem.packet[MessageTransportOffsetContent:]
+
+                       // decrypt with key-pair
+
+                       copy(nonce[4:], counter)
+                       elem.counter = binary.LittleEndian.Uint64(counter)
+                       elem.keyPair.receive.mutex.RLock()
+                       if elem.keyPair.receive.aead == nil {
+                               // very unlikely (the key was deleted during queuing)
                                elem.Drop()
+                       } else {
+                               var err error
+                               elem.packet, err = elem.keyPair.receive.aead.Open(
+                                       elem.buffer[:0],
+                                       nonce[:],
+                                       content,
+                                       nil,
+                               )
+                               if err != nil {
+                                       elem.Drop()
+                               }
                        }
+
+                       elem.keyPair.receive.mutex.RUnlock()
+                       elem.mutex.Unlock()
                }
-               elem.keyPair.receive.mutex.RUnlock()
-               elem.mutex.Unlock()
        }
 }
 
 /* Handles incomming packets related to handshake
- *
- *
  */
 func (device *Device) RoutineHandshake() {
 
@@ -473,7 +474,6 @@ func (device *Device) RoutineHandshake() {
 }
 
 func (peer *Peer) RoutineSequentialReceiver() {
-       var elem *QueueInboundElement
 
        device := peer.device
 
@@ -483,118 +483,119 @@ func (peer *Peer) RoutineSequentialReceiver() {
        logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
 
        for {
-               // wait for decryption
 
                select {
                case <-peer.signal.stop:
+                       logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
                        return
-               case elem = <-peer.queue.inbound:
-               }
-               elem.mutex.Lock()
 
-               // process packet
+               case elem := <-peer.queue.inbound:
 
-               if elem.IsDropped() {
-                       continue
-               }
+                       // wait for decryption
+
+                       elem.mutex.Lock()
+                       if elem.IsDropped() {
+                               continue
+                       }
 
-               // check for replay
+                       // check for replay
 
-               if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
-                       continue
-               }
+                       if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
+                               continue
+                       }
 
-               peer.TimerAnyAuthenticatedPacketTraversal()
-               peer.TimerAnyAuthenticatedPacketReceived()
-               peer.KeepKeyFreshReceiving()
+                       peer.TimerAnyAuthenticatedPacketTraversal()
+                       peer.TimerAnyAuthenticatedPacketReceived()
+                       peer.KeepKeyFreshReceiving()
 
-               // check if using new key-pair
+                       // check if using new key-pair
 
-               kp := &peer.keyPairs
-               kp.mutex.Lock()
-               if kp.next == elem.keyPair {
-                       peer.TimerHandshakeComplete()
-                       if kp.previous != nil {
-                               device.DeleteKeyPair(kp.previous)
+                       kp := &peer.keyPairs
+                       kp.mutex.Lock()
+                       if kp.next == elem.keyPair {
+                               peer.TimerHandshakeComplete()
+                               if kp.previous != nil {
+                                       device.DeleteKeyPair(kp.previous)
+                               }
+                               kp.previous = kp.current
+                               kp.current = kp.next
+                               kp.next = nil
                        }
-                       kp.previous = kp.current
-                       kp.current = kp.next
-                       kp.next = nil
-               }
-               kp.mutex.Unlock()
+                       kp.mutex.Unlock()
 
-               // check for keep-alive
+                       // check for keep-alive
 
-               if len(elem.packet) == 0 {
-                       logDebug.Println("Received keep-alive from", peer.String())
-                       continue
-               }
-               peer.TimerDataReceived()
+                       if len(elem.packet) == 0 {
+                               logDebug.Println("Received keep-alive from", peer.String())
+                               continue
+                       }
+                       peer.TimerDataReceived()
 
-               // verify source and strip padding
+                       // verify source and strip padding
 
-               switch elem.packet[0] >> 4 {
-               case ipv4.Version:
+                       switch elem.packet[0] >> 4 {
+                       case ipv4.Version:
 
-                       // strip padding
+                               // strip padding
 
-                       if len(elem.packet) < ipv4.HeaderLen {
-                               continue
-                       }
+                               if len(elem.packet) < ipv4.HeaderLen {
+                                       continue
+                               }
 
-                       field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
-                       length := binary.BigEndian.Uint16(field)
-                       if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
-                               continue
-                       }
+                               field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+                               length := binary.BigEndian.Uint16(field)
+                               if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+                                       continue
+                               }
 
-                       elem.packet = elem.packet[:length]
+                               elem.packet = elem.packet[:length]
 
-                       // verify IPv4 source
+                               // verify IPv4 source
 
-                       src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
-                       if device.routingTable.LookupIPv4(src) != peer {
-                               logInfo.Println("Packet with unallowed source IP from", peer.String())
-                               continue
-                       }
+                               src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+                               if device.routingTable.LookupIPv4(src) != peer {
+                                       logInfo.Println("Packet with unallowed source IP from", peer.String())
+                                       continue
+                               }
 
-               case ipv6.Version:
+                       case ipv6.Version:
 
-                       // strip padding
+                               // strip padding
 
-                       if len(elem.packet) < ipv6.HeaderLen {
-                               continue
-                       }
+                               if len(elem.packet) < ipv6.HeaderLen {
+                                       continue
+                               }
 
-                       field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
-                       length := binary.BigEndian.Uint16(field)
-                       length += ipv6.HeaderLen
-                       if int(length) > len(elem.packet) {
-                               continue
-                       }
+                               field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+                               length := binary.BigEndian.Uint16(field)
+                               length += ipv6.HeaderLen
+                               if int(length) > len(elem.packet) {
+                                       continue
+                               }
 
-                       elem.packet = elem.packet[:length]
+                               elem.packet = elem.packet[:length]
 
-                       // verify IPv6 source
+                               // verify IPv6 source
+
+                               src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+                               if device.routingTable.LookupIPv6(src) != peer {
+                                       logInfo.Println("Packet with unallowed source IP from", peer.String())
+                                       continue
+                               }
 
-                       src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
-                       if device.routingTable.LookupIPv6(src) != peer {
-                               logInfo.Println("Packet with unallowed source IP from", peer.String())
+                       default:
+                               logInfo.Println("Packet with invalid IP version from", peer.String())
                                continue
                        }
 
-               default:
-                       logInfo.Println("Packet with invalid IP version from", peer.String())
-                       continue
-               }
-
-               // write to tun
+                       // write to tun
 
-               atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
-               _, err := device.tun.device.Write(elem.packet)
-               device.PutMessageBuffer(elem.buffer)
-               if err != nil {
-                       logError.Println("Failed to write packet to TUN device:", err)
+                       atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+                       _, err := device.tun.device.Write(elem.packet)
+                       device.PutMessageBuffer(elem.buffer)
+                       if err != nil {
+                               logError.Println("Failed to write packet to TUN device:", err)
+                       }
                }
        }
 }
index c598ad417a8d119c87961a172112391fea75e034..e9dfb54f9efcf515669aebc18105a5cff72c8632 100644 (file)
@@ -35,7 +35,7 @@ type QueueOutboundElement struct {
        dropped int32
        mutex   sync.Mutex
        buffer  *[MaxMessageSize]byte // slice holding the packet data
-       packet  []byte                // slice of "data" (always!)
+       packet  []byte                // slice of "buffer" (always!)
        nonce   uint64                // nonce for encryption
        keyPair *KeyPair              // key-pair for encryption
        peer    *Peer                 // related peer
@@ -52,11 +52,6 @@ func (peer *Peer) FlushNonceQueue() {
        }
 }
 
-var (
-       ErrorNoEndpoint   = errors.New("No known endpoint for peer")
-       ErrorNoConnection = errors.New("No UDP socket for device")
-)
-
 func (device *Device) NewOutboundElement() *QueueOutboundElement {
        return &QueueOutboundElement{
                dropped: AtomicFalse,
@@ -118,14 +113,13 @@ func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
        defer peer.mutex.RUnlock()
 
        endpoint := peer.endpoint
-       conn := peer.device.net.conn
-
        if endpoint == nil {
-               return 0, ErrorNoEndpoint
+               return 0, errors.New("No known endpoint for peer")
        }
 
+       conn := peer.device.net.conn
        if conn == nil {
-               return 0, ErrorNoConnection
+               return 0, errors.New("No UDP socket for device")
        }
 
        return conn.WriteToUDP(buffer, endpoint)
@@ -189,16 +183,6 @@ func (device *Device) RoutineReadFromTUN() {
                        continue
                }
 
-               // check if known endpoint (drop early)
-
-               peer.mutex.RLock()
-               if peer.endpoint == nil {
-                       peer.mutex.RUnlock()
-                       logDebug.Println("No known endpoint for peer", peer.String())
-                       continue
-               }
-               peer.mutex.RUnlock()
-
                // insert into nonce/pre-handshake queue
 
                signalSend(peer.signal.handshakeReset)
@@ -211,86 +195,61 @@ func (device *Device) RoutineReadFromTUN() {
  * Then assigns nonces to packets sequentially
  * and creates "work" structs for workers
  *
- * TODO: Avoid dynamic allocation of work queue elements
- *
  * Obs. A single instance per peer
  */
 func (peer *Peer) RoutineNonce() {
        var keyPair *KeyPair
-       var elem *QueueOutboundElement
 
        device := peer.device
        logDebug := device.log.Debug
        logDebug.Println("Routine, nonce worker, started for peer", peer.String())
 
-       func() {
-
-               for {
-               NextPacket:
-
-                       // wait for packet
+       for {
+       NextPacket:
+               select {
+               case <-peer.signal.stop:
+                       return
 
-                       if elem == nil {
-                               select {
-                               case elem = <-peer.queue.nonce:
-                               case <-peer.signal.stop:
-                                       return
-                               }
-                       }
+               case elem := <-peer.queue.nonce:
 
                        // wait for key pair
 
                        for {
-                               select {
-                               case <-peer.signal.newKeyPair:
-                               default:
-                               }
-
                                keyPair = peer.keyPairs.Current()
                                if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
                                        if time.Now().Sub(keyPair.created) < RejectAfterTime {
                                                break
                                        }
                                }
+
                                signalSend(peer.signal.handshakeBegin)
                                logDebug.Println("Awaiting key-pair for", peer.String())
 
                                select {
                                case <-peer.signal.newKeyPair:
-                                       logDebug.Println("Key-pair negotiated for", peer.String())
-                                       goto NextPacket
-
                                case <-peer.signal.flushNonceQueue:
                                        logDebug.Println("Clearing queue for", peer.String())
                                        peer.FlushNonceQueue()
-                                       elem = nil
                                        goto NextPacket
-
                                case <-peer.signal.stop:
                                        return
                                }
                        }
 
-                       // process current packet
+                       // populate work element
 
-                       if elem != nil {
-
-                               // create work element
-
-                               elem.keyPair = keyPair
-                               elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
-                               elem.dropped = AtomicFalse
-                               elem.peer = peer
-                               elem.mutex.Lock()
+                       elem.peer = peer
+                       elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
+                       elem.keyPair = keyPair
+                       elem.dropped = AtomicFalse
+                       elem.mutex.Lock()
 
-                               // add to parallel and sequential queue
+                       // add to parallel and sequential queue
 
-                               addToEncryptionQueue(device.queue.encryption, elem)
-                               addToOutboundQueue(peer.queue.outbound, elem)
-                               elem = nil
-                       }
+                       addToEncryptionQueue(device.queue.encryption, elem)
+                       addToOutboundQueue(peer.queue.outbound, elem)
                }
-       }()
+       }
 }
 
 /* Encrypts the elements in the queue
@@ -300,7 +259,6 @@ func (peer *Peer) RoutineNonce() {
  */
 func (device *Device) RoutineEncryption() {
 
-       var elem *QueueOutboundElement
        var nonce [chacha20poly1305.NonceSize]byte
 
        logDebug := device.log.Debug
@@ -311,62 +269,62 @@ func (device *Device) RoutineEncryption() {
                // fetch next element
 
                select {
-               case elem = <-device.queue.encryption:
                case <-device.signal.stop:
                        logDebug.Println("Routine, encryption worker, stopped")
                        return
-               }
 
-               // check if dropped
+               case elem := <-device.queue.encryption:
 
-               if elem.IsDropped() {
-                       continue
-               }
+                       // check if dropped
+
+                       if elem.IsDropped() {
+                               continue
+                       }
 
-               // populate header fields
+                       // populate header fields
 
-               header := elem.buffer[:MessageTransportHeaderSize]
+                       header := elem.buffer[:MessageTransportHeaderSize]
 
-               fieldType := header[0:4]
-               fieldReceiver := header[4:8]
-               fieldNonce := header[8:16]
+                       fieldType := header[0:4]
+                       fieldReceiver := header[4:8]
+                       fieldNonce := header[8:16]
 
-               binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
-               binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex)
-               binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
+                       binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
+                       binary.LittleEndian.PutUint32(fieldReceiver, elem.keyPair.remoteIndex)
+                       binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
 
-               // pad content to MTU size
+                       // pad content to multiple of 16
 
-               mtu := int(atomic.LoadInt32(&device.tun.mtu))
-               pad := len(elem.packet) % PaddingMultiple
-               if pad > 0 {
-                       for i := 0; i < PaddingMultiple-pad && len(elem.packet) < mtu; i++ {
-                               elem.packet = append(elem.packet, 0)
+                       mtu := int(atomic.LoadInt32(&device.tun.mtu))
+                       rem := len(elem.packet) % PaddingMultiple
+                       if rem > 0 {
+                               for i := 0; i < PaddingMultiple-rem && len(elem.packet) < mtu; i++ {
+                                       elem.packet = append(elem.packet, 0)
+                               }
                        }
-                       // TODO: How good is this code
-               }
 
-               // encrypt content (append to header)
-
-               binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
-               elem.keyPair.send.mutex.RLock()
-               if elem.keyPair.send.aead == nil {
-                       // very unlikely (the key was deleted during queuing)
-                       elem.Drop()
-               } else {
-                       elem.packet = elem.keyPair.send.aead.Seal(
-                               header,
-                               nonce[:],
-                               elem.packet,
-                               nil,
-                       )
-               }
-               elem.keyPair.send.mutex.RUnlock()
-               elem.mutex.Unlock()
+                       // encrypt content (append to header)
+
+                       binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
+                       elem.keyPair.send.mutex.RLock()
+                       if elem.keyPair.send.aead == nil {
+                               // very unlikely (the key was deleted during queuing)
+                               elem.Drop()
+                       } else {
+                               elem.packet = elem.keyPair.send.aead.Seal(
+                                       header,
+                                       nonce[:],
+                                       elem.packet,
+                                       nil,
+                               )
+                       }
+                       elem.mutex.Unlock()
+                       elem.keyPair.send.mutex.RUnlock()
 
-               // refresh key if necessary
+                       // refresh key if necessary
 
-               elem.peer.KeepKeyFreshSending()
+                       elem.peer.KeepKeyFreshSending()
+               }
        }
 }
 
@@ -399,6 +357,7 @@ func (peer *Peer) RoutineSequentialSender() {
                        _, err := peer.SendBuffer(elem.packet)
                        device.PutMessageBuffer(elem.buffer)
                        if err != nil {
+                               logDebug.Println("Failed to send authenticated packet to peer", peer.String())
                                continue
                        }
                        atomic.AddUint64(&peer.stats.txBytes, length)