]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Fixed receive path infinite loop
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 30 Nov 2017 23:03:06 +0000 (00:03 +0100)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 30 Nov 2017 23:03:06 +0000 (00:03 +0100)
src/receive.go

index 7d493b088e6ff4c9b78fc654c4150ee1dd18ce26..fd1993eab9afd0b05206289fcebc163207fb79f4 100644 (file)
@@ -98,118 +98,115 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
        logDebug := device.log.Debug
        logDebug.Println("Routine, receive incoming, IP version:", IP)
 
-       for {
-
-               // receive datagrams until conn is closed
+       // receive datagrams until conn is closed
 
-               buffer := device.GetMessageBuffer()
+       buffer := device.GetMessageBuffer()
 
-               var (
-                       err      error
-                       size     int
-                       endpoint Endpoint
-               )
+       var (
+               err      error
+               size     int
+               endpoint Endpoint
+       )
 
-               for {
+       for {
 
-                       // read next datagram
+               // read next datagram
 
-                       switch IP {
-                       case ipv4.Version:
-                               size, endpoint, err = bind.ReceiveIPv4(buffer[:])
-                       case ipv6.Version:
-                               size, endpoint, err = bind.ReceiveIPv6(buffer[:])
-                       default:
-                               return
-                       }
+               switch IP {
+               case ipv4.Version:
+                       size, endpoint, err = bind.ReceiveIPv4(buffer[:])
+               case ipv6.Version:
+                       size, endpoint, err = bind.ReceiveIPv6(buffer[:])
+               default:
+                       return
+               }
 
-                       if err != nil {
-                               break
-                       }
+               if err != nil {
+                       return
+               }
 
-                       if size < MinMessageSize {
-                               continue
-                       }
+               if size < MinMessageSize {
+                       continue
+               }
 
-                       // check size of packet
+               // check size of packet
 
-                       packet := buffer[:size]
-                       msgType := binary.LittleEndian.Uint32(packet[:4])
+               packet := buffer[:size]
+               msgType := binary.LittleEndian.Uint32(packet[:4])
 
-                       var okay bool
+               var okay bool
 
-                       switch msgType {
+               switch msgType {
 
-                       // check if transport
+               // check if transport
 
-                       case MessageTransportType:
+               case MessageTransportType:
 
-                               // check size
+                       // check size
 
-                               if len(packet) < MessageTransportType {
-                                       continue
-                               }
+                       if len(packet) < MessageTransportType {
+                               continue
+                       }
 
-                               // lookup key pair
+                       // lookup key pair
 
-                               receiver := binary.LittleEndian.Uint32(
-                                       packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
-                               )
-                               value := device.indices.Lookup(receiver)
-                               keyPair := value.keyPair
-                               if keyPair == nil {
-                                       continue
-                               }
+                       receiver := binary.LittleEndian.Uint32(
+                               packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+                       )
+                       value := device.indices.Lookup(receiver)
+                       keyPair := value.keyPair
+                       if keyPair == nil {
+                               continue
+                       }
 
-                               // check key-pair expiry
+                       // check key-pair expiry
 
-                               if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
-                                       continue
-                               }
+                       if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
+                               continue
+                       }
 
-                               // create work element
+                       // create work element
 
-                               peer := value.peer
-                               elem := &QueueInboundElement{
-                                       packet:   packet,
-                                       buffer:   buffer,
-                                       keyPair:  keyPair,
-                                       dropped:  AtomicFalse,
-                                       endpoint: endpoint,
-                               }
-                               elem.mutex.Lock()
+                       peer := value.peer
+                       elem := &QueueInboundElement{
+                               packet:   packet,
+                               buffer:   buffer,
+                               keyPair:  keyPair,
+                               dropped:  AtomicFalse,
+                               endpoint: endpoint,
+                       }
+                       elem.mutex.Lock()
 
-                               // add to decryption queues
+                       // add to decryption queues
 
-                               device.addToDecryptionQueue(device.queue.decryption, elem)
-                               device.addToInboundQueue(peer.queue.inbound, elem)
-                               buffer = device.GetMessageBuffer()
-                               continue
+                       device.addToDecryptionQueue(device.queue.decryption, elem)
+                       device.addToInboundQueue(peer.queue.inbound, elem)
+                       buffer = device.GetMessageBuffer()
+                       continue
 
-                       // otherwise it is a fixed size & handshake related packet
+               // otherwise it is a fixed size & handshake related packet
 
-                       case MessageInitiationType:
-                               okay = len(packet) == MessageInitiationSize
+               case MessageInitiationType:
+                       okay = len(packet) == MessageInitiationSize
 
-                       case MessageResponseType:
-                               okay = len(packet) == MessageResponseSize
+               case MessageResponseType:
+                       okay = len(packet) == MessageResponseSize
 
-                       case MessageCookieReplyType:
-                               okay = len(packet) == MessageCookieReplySize
-                       }
+               case MessageCookieReplyType:
+                       okay = len(packet) == MessageCookieReplySize
+               }
 
-                       if okay {
-                               device.addToHandshakeQueue(
-                                       device.queue.handshake,
-                                       QueueHandshakeElement{
-                                               msgType:  msgType,
-                                               buffer:   buffer,
-                                               packet:   packet,
-                                               endpoint: endpoint,
-                                       },
-                               )
-                               buffer = device.GetMessageBuffer()
-                       }
+               if okay {
+                       device.addToHandshakeQueue(
+                               device.queue.handshake,
+                               QueueHandshakeElement{
+                                       msgType:  msgType,
+                                       buffer:   buffer,
+                                       packet:   packet,
+                                       endpoint: endpoint,
+                               },
+                       )
+                       buffer = device.GetMessageBuffer()
                }
        }
 }