]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Initial implementation of source caching
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Mon, 16 Oct 2017 19:33:47 +0000 (21:33 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Mon, 16 Oct 2017 19:33:47 +0000 (21:33 +0200)
Yet untested.

src/conn.go
src/conn_linux.go
src/device.go
src/main.go
src/peer.go
src/receive.go
src/send.go
src/timers.go
src/tun.go
src/uapi.go

index db4020d61ba9c01920f9461b3163b232a6cb7a7d..012e24e0a8a087bbbc061ce588e1c15b891b928e 100644 (file)
@@ -34,15 +34,20 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
        return addr, err
 }
 
-func ListeningUpdate(device *Device) error {
+func UpdateUDPListener(device *Device) error {
+       device.mutex.Lock()
+       defer device.mutex.Unlock()
+
        netc := &device.net
        netc.mutex.Lock()
        defer netc.mutex.Unlock()
 
        // close existing sockets
 
-       if err := device.net.bind.Close(); err != nil {
-               return err
+       if netc.bind != nil {
+               if err := netc.bind.Close(); err != nil {
+                       return err
+               }
        }
 
        // open new sockets
@@ -64,13 +69,19 @@ func ListeningUpdate(device *Device) error {
                        return err
                }
 
-               // TODO: clear endpoint (src) caches
+               // clear cached source addresses
+
+               for _, peer := range device.peers {
+                       peer.mutex.Lock()
+                       peer.endpoint.value.ClearSrc()
+                       peer.mutex.Unlock()
+               }
        }
 
        return nil
 }
 
-func ListeningClose(device *Device) error {
+func CloseUDPListener(device *Device) error {
        netc := &device.net
        netc.mutex.Lock()
        defer netc.mutex.Unlock()
index 8942b03a7c85e105f45c6cbec255f7d82a0f6aec..4a5a3f082673e6b9f4b6a37519cf43dbb26eec28 100644 (file)
@@ -133,7 +133,7 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string {
        }
 }
 
-func (end *Endpoint) DestinationIP() net.IP {
+func (end *Endpoint) DstIP() net.IP {
        switch end.dst.Family {
        case unix.AF_INET6:
                return end.dst.Addr[:]
@@ -150,20 +150,24 @@ func (end *Endpoint) DestinationIP() net.IP {
        }
 }
 
-func (end *Endpoint) SourceToBytes() []byte {
+func (end *Endpoint) SrcToBytes() []byte {
        ptr := unsafe.Pointer(&end.src)
        arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
        return arr[:]
 }
 
-func (end *Endpoint) SourceToString() string {
+func (end *Endpoint) SrcToString() string {
        return sockaddrToString(end.src)
 }
 
-func (end *Endpoint) DestinationToString() string {
+func (end *Endpoint) DstToString() string {
        return sockaddrToString(end.dst)
 }
 
+func (end *Endpoint) ClearDst() {
+       end.dst = unix.RawSockaddrInet6{}
+}
+
 func (end *Endpoint) ClearSrc() {
        end.src = unix.RawSockaddrInet6{}
 }
index d1e06859093b41ae3529fe0a572cb8372b26055c..1aae4488351c9ef641a81c09d7c0064e50c5a44a 100644 (file)
@@ -205,7 +205,7 @@ func (device *Device) RemoveAllPeers() {
 func (device *Device) Close() {
        device.RemoveAllPeers()
        close(device.signal.stop)
-       ListeningClose(device)
+       CloseUDPListener(device)
 }
 
 func (device *Device) WaitChannel() chan struct{} {
index a05dbba238fef3384cfc111faf09c06370e19336..5aaed9bbe00c96d6ed52baa8b55a6883b0c6f188 100644 (file)
@@ -14,8 +14,6 @@ func printUsage() {
 }
 
 func main() {
-       test()
-
        // parse arguments
 
        var foreground bool
index 791c091319a7760d84b24ddde53f3ae20248a8ce..f24dcd8207f52befd1f3f272a3b4cc5832fbaf79 100644 (file)
@@ -14,9 +14,12 @@ type Peer struct {
        persistentKeepaliveInterval uint64
        keyPairs                    KeyPairs
        handshake                   Handshake
-       endpoint                    Endpoint
        device                      *Device
-       stats                       struct {
+       endpoint                    struct {
+               set   bool     // has a known endpoint been discovered
+               value Endpoint // source / destination cache
+       }
+       stats struct {
                txBytes           uint64 // bytes send to peer (endpoint)
                rxBytes           uint64 // bytes received from peer
                lastHandshakeNano int64  // nano seconds since epoch
@@ -105,6 +108,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
        handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
        handshake.mutex.Unlock()
 
+       // reset endpoint
+
+       peer.endpoint.set = false
+       peer.endpoint.value.ClearDst()
+       peer.endpoint.value.ClearSrc()
+
        // prepare queuing
 
        peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
@@ -129,11 +138,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
        return peer, nil
 }
 
+/* Returns a short string identification for logging
+ */
 func (peer *Peer) String() string {
+       if !peer.endpoint.set {
+               return fmt.Sprintf(
+                       "peer(%d unknown %s)",
+                       peer.id,
+                       base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
+               )
+       }
        return fmt.Sprintf(
                "peer(%d %s %s)",
                peer.id,
-               peer.endpoint.DestinationToString(),
+               peer.endpoint.value.DstToString(),
                base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
        )
 }
index 664f1ba674fdee8a0578245019ed64f56a154e37..1f05b2fa064de5d35665d131359a7cd77df64e5d 100644 (file)
@@ -331,7 +331,7 @@ func (device *Device) RoutineHandshake() {
                                return
                        }
 
-                       srcBytes := elem.endpoint.SourceToBytes()
+                       srcBytes := elem.endpoint.SrcToBytes()
                        if device.IsUnderLoad() {
 
                                // verify MAC2 field
@@ -340,8 +340,7 @@ func (device *Device) RoutineHandshake() {
 
                                        // construct cookie reply
 
-                                       logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString())
-
+                                       logDebug.Println("Sending cookie reply to:", elem.endpoint.SrcToString())
                                        sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
                                        reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
                                        if err != nil {
@@ -365,9 +364,7 @@ func (device *Device) RoutineHandshake() {
 
                                // check ratelimiter
 
-                               if !device.ratelimiter.Allow(
-                                       elem.endpoint.DestinationIP(),
-                               ) {
+                               if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
                                        continue
                                }
                        }
@@ -398,7 +395,7 @@ func (device *Device) RoutineHandshake() {
                        if peer == nil {
                                logInfo.Println(
                                        "Recieved invalid initiation message from",
-                                       elem.endpoint.DestinationToString(),
+                                       elem.endpoint.DstToString(),
                                )
                                continue
                        }
@@ -412,7 +409,8 @@ func (device *Device) RoutineHandshake() {
                        // TODO: Discover destination address also, only update on change
 
                        peer.mutex.Lock()
-                       peer.endpoint = elem.endpoint
+                       peer.endpoint.set = true
+                       peer.endpoint.value = elem.endpoint
                        peer.mutex.Unlock()
 
                        // create response
@@ -435,7 +433,7 @@ func (device *Device) RoutineHandshake() {
 
                        // send response
 
-                       _, err = peer.SendBuffer(packet)
+                       err = peer.SendBuffer(packet)
                        if err == nil {
                                peer.TimerAnyAuthenticatedPacketTraversal()
                        }
@@ -458,7 +456,7 @@ func (device *Device) RoutineHandshake() {
                        if peer == nil {
                                logInfo.Println(
                                        "Recieved invalid response message from",
-                                       elem.endpoint.DestinationToString(),
+                                       elem.endpoint.DstToString(),
                                )
                                continue
                        }
index 5c88ead2ae664ac5637566c45ed11167ebe00edd..e37a736d9ae289260a19f6b44f55b2a4433aa608 100644 (file)
@@ -105,24 +105,15 @@ func addToEncryptionQueue(
        }
 }
 
-func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
+func (peer *Peer) SendBuffer(buffer []byte) error {
        peer.device.net.mutex.RLock()
        defer peer.device.net.mutex.RUnlock()
-
        peer.mutex.RLock()
        defer peer.mutex.RUnlock()
-
-       endpoint := peer.endpoint
-       if endpoint == nil {
-               return 0, errors.New("No known endpoint for peer")
+       if !peer.endpoint.set {
+               return errors.New("No known endpoint for peer")
        }
-
-       conn := peer.device.net.conn
-       if conn == nil {
-               return 0, errors.New("No UDP socket for device")
-       }
-
-       return conn.WriteToUDP(buffer, endpoint)
+       return peer.device.net.bind.Send(buffer, &peer.endpoint.value)
 }
 
 /* Reads packets from the TUN and inserts
@@ -343,7 +334,7 @@ func (peer *Peer) RoutineSequentialSender() {
                        // send message and return buffer to pool
 
                        length := uint64(len(elem.packet))
-                       _, err := peer.SendBuffer(elem.packet)
+                       err := peer.SendBuffer(elem.packet)
                        device.PutMessageBuffer(elem.buffer)
                        if err != nil {
                                logDebug.Println("Failed to send authenticated packet to peer", peer.String())
index 99695ba59cb61fdd6e07a33b3a6aaf82caa82201..2a940056dc464e78b672fae9702c7b9d73d403d0 100644 (file)
@@ -288,7 +288,7 @@ func (peer *Peer) RoutineHandshakeInitiator() {
                        packet := writer.Bytes()\r
                        peer.mac.AddMacs(packet)\r
 \r
-                       _, err = peer.SendBuffer(packet)\r
+                       err = peer.SendBuffer(packet)\r
                        if err != nil {\r
                                logError.Println(\r
                                        "Failed to send handshake initiation message to",\r
index 8e8c759ff1ec0050744d7d99b1655a3b96f80d7e..9eed98747829e7fe87924afdc6fd0285821052c2 100644 (file)
@@ -47,7 +47,7 @@ func (device *Device) RoutineTUNEventReader() {
                        if !device.tun.isUp.Get() {
                                logInfo.Println("Interface set up")
                                device.tun.isUp.Set(true)
-                               updateUDPConn(device)
+                               UpdateUDPListener(device)
                        }
                }
 
@@ -55,7 +55,7 @@ func (device *Device) RoutineTUNEventReader() {
                        if device.tun.isUp.Get() {
                                logInfo.Println("Interface set down")
                                device.tun.isUp.Set(false)
-                               closeUDPConn(device)
+                               CloseUDPListener(device)
                        }
                }
        }
index 7d08e561eef9ab89fbfda2d8f8200dbf877a4abc..2de26ee8216740aaf0331bdd6f3d1c04731079b8 100644 (file)
@@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                send("private_key=" + device.privateKey.ToHex())
        }
 
-       if device.net.addr != nil {
-               send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
+       if device.net.port != 0 {
+               send(fmt.Sprintf("listen_port=%d", device.net.port))
        }
+
        if device.net.fwmark != 0 {
                send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
        }
@@ -52,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                        defer peer.mutex.RUnlock()
                        send("public_key=" + peer.handshake.remoteStatic.ToHex())
                        send("preshared_key=" + peer.handshake.presharedKey.ToHex())
-                       if peer.endpoint != nil {
-                               send("endpoint=" + peer.endpoint.String())
+                       if peer.endpoint.set {
+                               send("endpoint=" + peer.endpoint.value.DstToString())
                        }
 
                        nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
@@ -137,53 +138,24 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        logError.Println("Failed to set listen_port:", err)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
-
-                               addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
-                               if err != nil {
-                                       logError.Println("Failed to set listen_port:", err)
-                                       return &IPCError{Code: ipcErrorInvalid}
-                               }
-
-                               device.net.mutex.Lock()
-                               device.net.addr = addr
-                               device.net.mutex.Unlock()
-
-                               err = updateUDPConn(device)
-                               if err != nil {
+                               device.net.port = uint16(port)
+                               if err := UpdateUDPListener(device); err != nil {
                                        logError.Println("Failed to set listen_port:", err)
                                        return &IPCError{Code: ipcErrorPortInUse}
                                }
 
-                               // TODO: Clear source address of all peers
-
                        case "fwmark":
                                fwmark, err := strconv.ParseUint(value, 10, 32)
                                if err != nil {
                                        logError.Println("Invalid fwmark", err)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
-
                                device.net.mutex.Lock()
-                               if fwmark > 0 || device.net.fwmark > 0 {
-                                       device.net.fwmark = uint32(fwmark)
-                                       err := SetMark(
-                                               device.net.conn,
-                                               device.net.fwmark,
-                                       )
-                                       if err != nil {
-                                               logError.Println("Failed to set fwmark:", err)
-                                               device.net.mutex.Unlock()
-                                               return &IPCError{Code: ipcErrorIO}
-                                       }
-
-                                       // TODO: Clear source address of all peers
-                               }
+                               device.net.fwmark = uint32(fwmark)
                                device.net.mutex.Unlock()
 
                        case "public_key":
-
                                // switch to peer configuration
-
                                deviceConfig = false
 
                        case "replace_peers":
@@ -218,7 +190,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                device.mutex.RLock()
                                if device.publicKey.Equals(pubKey) {
 
-                                       // create dummy instance
+                                       // create dummy instance (not added to device)
 
                                        peer = &Peer{}
                                        dummy = true
@@ -244,6 +216,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                }
 
                        case "remove":
+
+                               // remove currently selected peer from device
+
                                if value != "true" {
                                        logError.Println("Failed to set remove, invalid value:", value)
                                        return &IPCError{Code: ipcErrorInvalid}
@@ -256,6 +231,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                dummy = true
 
                        case "preshared_key":
+
+                               // update PSK
+
                                peer.mutex.Lock()
                                err := peer.handshake.presharedKey.FromHex(value)
                                peer.mutex.Unlock()
@@ -265,14 +243,17 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                }
 
                        case "endpoint":
-                               addr, err := parseEndpoint(value)
+
+                               // set endpoint destination and reset handshake timer
+
+                               peer.mutex.Lock()
+                               err := peer.endpoint.value.Set(value)
+                               peer.endpoint.set = (err == nil)
+                               peer.mutex.Unlock()
                                if err != nil {
                                        logError.Println("Failed to set endpoint:", value)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
-                               peer.mutex.Lock()
-                               peer.endpoint = addr
-                               peer.mutex.Unlock()
                                signalSend(peer.signal.handshakeReset)
 
                        case "persistent_keepalive_interval":