]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
receive: implement flush semantics
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 21 Mar 2019 20:43:04 +0000 (14:43 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 21 Mar 2019 20:45:41 +0000 (14:45 -0600)
12 files changed:
device/boundif_darwin.go
device/boundif_windows.go
device/conn.go
device/queueconstants_android.go
device/receive.go
tun/operateonfd.go [moved from tun/tun_default.go with 100% similarity]
tun/tun.go
tun/tun_darwin.go
tun/tun_freebsd.go
tun/tun_linux.go
tun/tun_openbsd.go
tun/tun_windows.go

index b3d10bacee7a18b518af58f369c8ac6196b1a169..a93441c85d64673e9a7f237810842729d518b5d9 100644 (file)
@@ -41,4 +41,4 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
                return err
        }
        return nil
-}
\ No newline at end of file
+}
index 00631cb9eaea2cbcf81afea242a4efea34972c8e..97381adf930d4e8a4f98d4c4fd6b61d2b1677cf3 100644 (file)
@@ -53,4 +53,4 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
                return err
        }
        return nil
-}
\ No newline at end of file
+}
index 25946800e86dfeb14dd2bf91556c9e48704e78e9..3c2aa04d8bdf12f9514c85a016d5b6e357dddb65 100644 (file)
@@ -177,4 +177,4 @@ func (device *Device) BindClose() error {
        err := unsafeCloseBind(device)
        device.net.Unlock()
        return err
-}
\ No newline at end of file
+}
index 8d051ad4180f8b0dbf93d2dbdd21eb911104724d..f5c042d2498fcb9172131a49aea1d371797809fa 100644 (file)
@@ -13,4 +13,4 @@ const (
        QueueHandshakeSize         = 1024
        MaxSegmentSize             = 2200
        PreallocatedBuffersPerPool = 4096
-)
\ No newline at end of file
+)
index 09fae5958e1f8a2e9dbc67fa996ce1cf6e154ca4..747a188cc5872b713f23bdf97bfd9b4b7f1ffa7d 100644 (file)
@@ -482,6 +482,33 @@ func (device *Device) RoutineHandshake() {
        }
 }
 
+func (peer *Peer) elementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, elem *QueueInboundElement) {
+       if !*shouldFlush {
+               select {
+               case <-peer.routines.stop:
+                       stop = true
+                       return
+               case elem, elemOk = <-peer.queue.inbound:
+                       return
+               }
+       } else {
+               select {
+               case <-peer.routines.stop:
+                       stop = true
+                       return
+               case elem, elemOk = <-peer.queue.inbound:
+                       return
+               default:
+                       *shouldFlush = false
+                       err := peer.device.tun.device.Flush()
+                       if err != nil {
+                               peer.device.log.Error.Printf("Unable to flush packets: %v", err)
+                       }
+                       return peer.elementStopOrFlush(shouldFlush)
+               }
+       }
+}
+
 func (peer *Peer) RoutineSequentialReceiver() {
 
        device := peer.device
@@ -491,6 +518,9 @@ func (peer *Peer) RoutineSequentialReceiver() {
 
        var elem *QueueInboundElement
        var ok bool
+       var stop bool
+
+       shouldFlush := false
 
        defer func() {
                logDebug.Println(peer, "- Routine: sequential receiver - stopped")
@@ -516,126 +546,122 @@ func (peer *Peer) RoutineSequentialReceiver() {
                        elem = nil
                }
 
-               select {
-
-               case <-peer.routines.stop:
+               stop, ok, elem = peer.elementStopOrFlush(&shouldFlush)
+               if stop || !ok {
                        return
+               }
 
-               case elem, ok = <-peer.queue.inbound:
-
-                       if !ok {
-                               return
-                       }
-
-                       // wait for decryption
+               // wait for decryption
 
-                       elem.Lock()
+               elem.Lock()
 
-                       if elem.IsDropped() {
-                               continue
-                       }
+               if elem.IsDropped() {
+                       continue
+               }
 
-                       // check for replay
+               // check for replay
 
-                       if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
-                               continue
-                       }
+               if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
+                       continue
+               }
 
-                       // update endpoint
-                       peer.SetEndpointFromPacket(elem.endpoint)
+               // update endpoint
+               peer.SetEndpointFromPacket(elem.endpoint)
 
-                       // check if using new keypair
-                       if peer.ReceivedWithKeypair(elem.keypair) {
-                               peer.timersHandshakeComplete()
-                               select {
-                               case peer.signals.newKeypairArrived <- struct{}{}:
-                               default:
-                               }
+               // check if using new keypair
+               if peer.ReceivedWithKeypair(elem.keypair) {
+                       peer.timersHandshakeComplete()
+                       select {
+                       case peer.signals.newKeypairArrived <- struct{}{}:
+                       default:
                        }
+               }
 
-                       peer.keepKeyFreshReceiving()
-                       peer.timersAnyAuthenticatedPacketTraversal()
-                       peer.timersAnyAuthenticatedPacketReceived()
-
-                       // check for keepalive
+               peer.keepKeyFreshReceiving()
+               peer.timersAnyAuthenticatedPacketTraversal()
+               peer.timersAnyAuthenticatedPacketReceived()
 
-                       if len(elem.packet) == 0 {
-                               logDebug.Println(peer, "- Receiving keepalive packet")
-                               continue
-                       }
-                       peer.timersDataReceived()
+               // check for keepalive
 
-                       // verify source and strip padding
+               if len(elem.packet) == 0 {
+                       logDebug.Println(peer, "- Receiving keepalive packet")
+                       continue
+               }
+               peer.timersDataReceived()
 
-                       switch elem.packet[0] >> 4 {
-                       case ipv4.Version:
+               // verify source and strip padding
 
-                               // strip padding
+               switch elem.packet[0] >> 4 {
+               case ipv4.Version:
 
-                               if len(elem.packet) < ipv4.HeaderLen {
-                                       continue
-                               }
+                       // strip padding
 
-                               field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
-                               length := binary.BigEndian.Uint16(field)
-                               if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
-                                       continue
-                               }
+                       if len(elem.packet) < ipv4.HeaderLen {
+                               continue
+                       }
 
-                               elem.packet = elem.packet[:length]
+                       field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+                       length := binary.BigEndian.Uint16(field)
+                       if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+                               continue
+                       }
 
-                               // verify IPv4 source
+                       elem.packet = elem.packet[:length]
 
-                               src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
-                               if device.allowedips.LookupIPv4(src) != peer {
-                                       logInfo.Println(
-                                               "IPv4 packet with disallowed source address from",
-                                               peer,
-                                       )
-                                       continue
-                               }
+                       // verify IPv4 source
 
-                       case ipv6.Version:
+                       src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+                       if device.allowedips.LookupIPv4(src) != peer {
+                               logInfo.Println(
+                                       "IPv4 packet with disallowed source address from",
+                                       peer,
+                               )
+                               continue
+                       }
 
-                               // strip padding
+               case ipv6.Version:
 
-                               if len(elem.packet) < ipv6.HeaderLen {
-                                       continue
-                               }
+                       // strip padding
 
-                               field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
-                               length := binary.BigEndian.Uint16(field)
-                               length += ipv6.HeaderLen
-                               if int(length) > len(elem.packet) {
-                                       continue
-                               }
+                       if len(elem.packet) < ipv6.HeaderLen {
+                               continue
+                       }
 
-                               elem.packet = elem.packet[:length]
+                       field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+                       length := binary.BigEndian.Uint16(field)
+                       length += ipv6.HeaderLen
+                       if int(length) > len(elem.packet) {
+                               continue
+                       }
 
-                               // verify IPv6 source
+                       elem.packet = elem.packet[:length]
 
-                               src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
-                               if device.allowedips.LookupIPv6(src) != peer {
-                                       logInfo.Println(
-                                               peer,
-                                               "sent packet with disallowed IPv6 source",
-                                       )
-                                       continue
-                               }
+                       // verify IPv6 source
 
-                       default:
-                               logInfo.Println("Packet with invalid IP version from", peer)
+                       src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+                       if device.allowedips.LookupIPv6(src) != peer {
+                               logInfo.Println(
+                                       peer,
+                                       "sent packet with disallowed IPv6 source",
+                               )
                                continue
                        }
 
-                       // write to tun device
+               default:
+                       logInfo.Println("Packet with invalid IP version from", peer)
+                       continue
+               }
+
+               // write to tun device
 
-                       offset := MessageTransportOffsetContent
-                       atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
-                       _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
-                       if err != nil && !device.isClosed.Get() {
-                               logError.Println("Failed to write packet to TUN device:", err)
-                       }
+               offset := MessageTransportOffsetContent
+               atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+               _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
+               if err == nil {
+                       shouldFlush = true
+               }
+               if err != nil && !device.isClosed.Get() {
+                       logError.Println("Failed to write packet to TUN device:", err)
                }
        }
 }
similarity index 100%
rename from tun/tun_default.go
rename to tun/operateonfd.go
index c4b6cacec97e88cb19be46635ed4202ea12fcbd0..12febb8890184dc052789ec4bdc6616fde6eb7c8 100644 (file)
@@ -21,6 +21,7 @@ type TUNDevice interface {
        File() *os.File                 // returns the file descriptor of the device
        Read([]byte, int) (int, error)  // read a packet from the device (without any additional headers)
        Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
+       Flush() error                   // flush all previous writes to the device
        MTU() (int, error)              // returns the MTU of the device
        Name() (string, error)          // fetches and returns the current name
        Events() chan TUNEvent          // returns a constant channel of events related to the device
index 3b3998288097701ab12e90cdefdc1e8d32d62676..2077de361d98ac1db0e8fa7413587d82cf109f41 100644 (file)
@@ -281,6 +281,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
        return tun.tunFile.Write(buff)
 }
 
+func (tun *NativeTun) Flush() error {
+       //TODO: can flushing be implemented by buffering and using sendmmsg?
+       return nil
+}
+
 func (tun *NativeTun) Close() error {
        var err2 error
        err1 := tun.tunFile.Close()
index 3a607255a1d3845cd0174f6bdc4a8743221dc07f..01a43486fc28be189959f20b2779247691851375 100644 (file)
@@ -406,6 +406,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
        return tun.tunFile.Write(buff)
 }
 
+func (tun *NativeTun) Flush() error {
+       //TODO: can flushing be implemented by buffering and using sendmmsg?
+       return nil
+}
+
 func (tun *NativeTun) Close() error {
        var err3 error
        err1 := tun.tunFile.Close()
index b7c429c5152fd3a8e40801c55cd863f3cd3a4ef1..784cb9f4f0a2a380e4bc25aef15e189b6b6bbbc0 100644 (file)
@@ -318,6 +318,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
        return tun.tunFile.Write(buff)
 }
 
+func (tun *NativeTun) Flush() error {
+       //TODO: can flushing be implemented by buffering and using sendmmsg?
+       return nil
+}
+
 func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
        select {
        case err := <-tun.errors:
index 57edcb45f36f3e20a3718e978f0b006b97d73e37..645bccad602bd25e8a1db25a8f7542313c666b3f 100644 (file)
@@ -237,6 +237,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
        return tun.tunFile.Write(buff)
 }
 
+func (tun *NativeTun) Flush() error {
+       //TODO: can flushing be implemented by buffering and using sendmmsg?
+       return nil
+}
+
 func (tun *NativeTun) Close() error {
        var err2 error
        err1 := tun.tunFile.Close()
index dcb414a02a69dac7dea50408ae25509eb7380e53..fffd802f57983b0269a85625436124906300767d 100644 (file)
@@ -281,7 +281,11 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
 
 // Note: flush() and putTunPacket() assume the caller comes only from a single thread; there's no locking.
 
-func (tun *NativeTun) flush() error {
+func (tun *NativeTun) Flush() error {
+       if tun.wrBuff.offset == 0 {
+               return nil
+       }
+
        // Get TUN data pipe.
        file, err := tun.getTUN()
        if err != nil {
@@ -322,7 +326,7 @@ func (tun *NativeTun) putTunPacket(buff []byte) error {
 
        if tun.wrBuff.packetNum >= packetExchangeMax || tun.wrBuff.offset+pSize >= packetExchangeSize {
                // Exchange buffer is full -> flush first.
-               err := tun.flush()
+               err := tun.Flush()
                if err != nil {
                        return err
                }
@@ -345,9 +349,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
        if err != nil {
                return 0, err
        }
-
-       // Flush write buffer.
-       return len(buff) - offset, tun.flush()
+       return len(buff) - offset, nil
 }
 
 //