]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Completed get/set configuration
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 29 Jun 2017 12:39:21 +0000 (14:39 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 29 Jun 2017 12:39:21 +0000 (14:39 +0200)
For debugging of "outbound flow"
Mostly, a few things still missing

src/config.go
src/device.go
src/main.go
src/routing.go
src/send.go
src/trie.go

index 3b91d00e6e1cf79319913b0b7cba08dbacc6efa7..2f8dc76de543023fc0821eef5b75c951cbc84cf1 100644 (file)
@@ -5,24 +5,22 @@ import (
        "errors"
        "fmt"
        "io"
-       "log"
        "net"
        "strconv"
+       "strings"
        "time"
 )
 
-/* TODO : use real error code
- * Many of which will be the same
+// #include <errno.h>
+import "C"
+
+/* TODO: More fine grained?
  */
 const (
-       ipcErrorNoPeer            = 0
-       ipcErrorNoKeyValue        = 1
-       ipcErrorInvalidKey        = 2
-       ipcErrorInvalidValue      = 2
-       ipcErrorInvalidPrivateKey = 3
-       ipcErrorInvalidPublicKey  = 4
-       ipcErrorInvalidPort       = 5
-       ipcErrorInvalidIPAddress  = 6
+       ipcErrorNoPeer       = C.EPROTO
+       ipcErrorNoKeyValue   = C.EPROTO
+       ipcErrorInvalidKey   = C.EPROTO
+       ipcErrorInvalidValue = C.EPROTO
 )
 
 type IPCError struct {
@@ -78,7 +76,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
        // send lines
 
        for _, line := range lines {
-               device.log.Debug.Println("config:", line)
+               device.log.Debug.Println("Response:", line)
                _, err := socket.WriteString(line + "\n")
                if err != nil {
                        return err
@@ -89,29 +87,26 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
 }
 
 func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
-
+       logger := device.log.Debug
        scanner := bufio.NewScanner(socket)
 
-       device.mutex.Lock()
-       defer device.mutex.Unlock()
-
+       var peer *Peer
        for scanner.Scan() {
-               var key string
-               var value string
-               var peer *Peer
 
                // Parse line
 
                line := scanner.Text()
-               if line == "\n" {
-                       break
+               if line == "" {
+                       return nil
                }
-               fmt.Println(line)
-               n, err := fmt.Sscanf(line, "%s=%s\n", &key, &value)
-               if n != 2 || err != nil {
-                       fmt.Println(err, n)
+               parts := strings.Split(line, "=")
+               if len(parts) != 2 {
+                       device.log.Debug.Println(parts)
                        return &IPCError{Code: ipcErrorNoKeyValue}
                }
+               key := parts[0]
+               value := parts[1]
+               logger.Println("Key-value pair: (", key, ",", value, ")") // TODO: Remove, leaks private key to log
 
                switch key {
 
@@ -119,41 +114,60 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                case "private_key":
                        if value == "" {
+                               device.mutex.Lock()
                                device.privateKey = NoisePrivateKey{}
+                               device.mutex.Unlock()
                        } else {
+                               device.mutex.Lock()
                                err := device.privateKey.FromHex(value)
+                               device.mutex.Unlock()
                                if err != nil {
-                                       return &IPCError{Code: ipcErrorInvalidPrivateKey}
+                                       logger.Println("Failed to set private_key:", err)
+                                       return &IPCError{Code: ipcErrorInvalidValue}
                                }
                        }
 
                case "listen_port":
-                       _, err := fmt.Sscanf(value, "%ud", &device.address.Port)
-                       if err != nil {
-                               return &IPCError{Code: ipcErrorInvalidPort}
+                       var port int
+                       _, err := fmt.Sscanf(value, "%d", &port)
+                       if err != nil || port > (1<<16) || port < 0 {
+                               logger.Println("Failed to set listen_port:", err)
+                               return &IPCError{Code: ipcErrorInvalidValue}
                        }
+                       device.mutex.Lock()
+                       if device.address == nil {
+                               device.address = &net.UDPAddr{}
+                       }
+                       device.address.Port = port
+                       device.mutex.Unlock()
 
                case "fwmark":
-                       panic(nil) // not handled yet
+                       logger.Println("FWMark not handled yet")
 
                case "public_key":
                        var pubKey NoisePublicKey
                        err := pubKey.FromHex(value)
                        if err != nil {
-                               return &IPCError{Code: ipcErrorInvalidPublicKey}
+                               logger.Println("Failed to get peer by public_key:", err)
+                               return &IPCError{Code: ipcErrorInvalidValue}
                        }
+                       device.mutex.RLock()
                        found, ok := device.peers[pubKey]
+                       device.mutex.RUnlock()
                        if ok {
                                peer = found
                        } else {
                                peer = device.NewPeer(pubKey)
                        }
+                       if peer == nil {
+                               panic(errors.New("bug: failed to find peer"))
+                       }
 
                case "replace_peers":
-                       if key == "true" {
+                       if value == "true" {
                                device.RemoveAllPeers()
-                       } else if key == "false" {
                        } else {
+                               logger.Println("Failed to set replace_peers, invalid value:", value)
                                return &IPCError{Code: ipcErrorInvalidValue}
                        }
 
@@ -161,6 +175,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                        /* Peer configuration */
 
                        if peer == nil {
+                               logger.Println("No peer referenced, before peer operation")
                                return &IPCError{Code: ipcErrorNoPeer}
                        }
 
@@ -168,7 +183,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                        case "remove":
                                peer.mutex.Lock()
-                               // device.RemovePeer(peer.publicKey)
+                               device.RemovePeer(peer.handshake.remoteStatic)
+                               peer.mutex.Unlock()
+                               logger.Println("Remove peer")
                                peer = nil
 
                        case "preshared_key":
@@ -178,13 +195,15 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                        return peer.handshake.presharedKey.FromHex(value)
                                }()
                                if err != nil {
-                                       return &IPCError{Code: ipcErrorInvalidPublicKey}
+                                       logger.Println("Failed to set preshared_key:", err)
+                                       return &IPCError{Code: ipcErrorInvalidValue}
                                }
 
                        case "endpoint":
                                ip := net.ParseIP(value)
                                if ip == nil {
-                                       return &IPCError{Code: ipcErrorInvalidIPAddress}
+                                       logger.Println("Failed to set endpoint:", value)
+                                       return &IPCError{Code: ipcErrorInvalidValue}
                                }
                                peer.mutex.Lock()
                                // peer.endpoint = ip FIX
@@ -193,6 +212,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                        case "persistent_keepalive_interval":
                                secs, err := strconv.ParseInt(value, 10, 64)
                                if secs < 0 || err != nil {
+                                       logger.Println("Failed to set persistent_keepalive_interval:", err)
                                        return &IPCError{Code: ipcErrorInvalidValue}
                                }
                                peer.mutex.Lock()
@@ -200,24 +220,27 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                peer.mutex.Unlock()
 
                        case "replace_allowed_ips":
-                               if key == "true" {
+                               if value == "true" {
                                        device.routingTable.RemovePeer(peer)
-                               } else if key == "false" {
                                } else {
+                                       logger.Println("Failed to set replace_allowed_ips, invalid value:", value)
                                        return &IPCError{Code: ipcErrorInvalidValue}
                                }
 
                        case "allowed_ip":
                                _, network, err := net.ParseCIDR(value)
                                if err != nil {
+                                       logger.Println("Failed to set allowed_ip:", err)
                                        return &IPCError{Code: ipcErrorInvalidValue}
                                }
                                ones, _ := network.Mask.Size()
+                               logger.Println(network, ones, network.IP)
                                device.routingTable.Insert(network.IP, uint(ones), peer)
 
                        /* Invalid key */
 
                        default:
+                               logger.Println("Invalid key:", key)
                                return &IPCError{Code: ipcErrorInvalidKey}
                        }
                }
@@ -226,49 +249,48 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
        return nil
 }
 
-func ipcListen(device *Device, socket io.ReadWriter) error {
+func ipcHandle(device *Device, socket net.Conn) {
 
-       buffered := func(s io.ReadWriter) *bufio.ReadWriter {
-               reader := bufio.NewReader(s)
-               writer := bufio.NewWriter(s)
-               return bufio.NewReadWriter(reader, writer)
-       }(socket)
+       func() {
+               buffered := func(s io.ReadWriter) *bufio.ReadWriter {
+                       reader := bufio.NewReader(s)
+                       writer := bufio.NewWriter(s)
+                       return bufio.NewReadWriter(reader, writer)
+               }(socket)
 
-       defer buffered.Flush()
+               defer buffered.Flush()
 
-       for {
                op, err := buffered.ReadString('\n')
                if err != nil {
-                       return err
+                       return
                }
-               log.Println(op)
 
                switch op {
 
                case "set=1\n":
+                       device.log.Debug.Println("Config, set operation")
                        err := ipcSetOperation(device, buffered)
                        if err != nil {
                                fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
-                               return err
                        } else {
                                fmt.Fprintf(buffered, "errno=0\n\n")
                        }
-                       buffered.Flush()
+                       break
 
                case "get=1\n":
+                       device.log.Debug.Println("Config, get operation")
                        err := ipcGetOperation(device, buffered)
                        if err != nil {
                                fmt.Fprintf(buffered, "errno=1\n\n") // fix
-                               return err
                        } else {
                                fmt.Fprintf(buffered, "errno=0\n\n")
                        }
-                       buffered.Flush()
+                       break
 
-               case "\n":
                default:
-                       return errors.New("handle this please")
+                       device.log.Info.Println("Invalid UAPI operation:", op)
                }
-       }
+       }()
 
+       socket.Close()
 }
index a7a5c7bc12ef7bf3db671cce95c505df90b73242..52ac6a499c5354095f8cf54d040fa1801e5d08fb 100644 (file)
@@ -81,10 +81,7 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
        peer.mutex.Lock()
        device.routingTable.RemovePeer(peer)
        delete(device.peers, key)
-}
-
-func (device *Device) RemoveAllAllowedIps(peer *Peer) {
-
+       peer.Close()
 }
 
 func (device *Device) RemoveAllPeers() {
@@ -93,8 +90,7 @@ func (device *Device) RemoveAllPeers() {
 
        for key, peer := range device.peers {
                peer.mutex.Lock()
-               device.routingTable.RemovePeer(peer)
                delete(device.peers, key)
-               peer.mutex.Unlock()
+               peer.Close()
        }
 }
index 7c589721ee115ec9c816bd91a8c8254504d4fe26..9c76ff4cebd11fac76e54c895538c2e89d74050f 100644 (file)
@@ -1,21 +1,28 @@
 package main
 
 import (
+       "fmt"
        "log"
        "net"
+       "os"
 )
 
-/*
- *
- * TODO: Fix logging
+/* TODO: Fix logging
+ * TODO: Fix daemon
  */
 
 func main() {
+
+       if len(os.Args) != 2 {
+               return
+       }
+       deviceName := os.Args[1]
+
        // Open TUN device
 
        // TODO: Fix capabilities
 
-       tun, err := CreateTUN("test0")
+       tun, err := CreateTUN(deviceName)
        log.Println(tun, err)
        if err != nil {
                return
@@ -25,19 +32,17 @@ func main() {
 
        // Start configuration lister
 
-       l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
+       socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
+       l, err := net.Listen("unix", socketPath)
        if err != nil {
                log.Fatal("listen error:", err)
        }
 
        for {
-               fd, err := l.Accept()
+               conn, err := l.Accept()
                if err != nil {
                        log.Fatal("accept error:", err)
                }
-               go func(conn net.Conn) {
-                       err := ipcListen(device, conn)
-                       log.Println(err)
-               }(fd)
+               go ipcHandle(device, conn)
        }
 }
index 6a5e1f36f0c7ca1ad97314687223337ff0acb66f..2a2e237085cda711e15b28191f473646cc675a20 100644 (file)
@@ -16,9 +16,9 @@ func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
        table.mutex.RLock()
        defer table.mutex.RUnlock()
 
-       allowed := make([]net.IPNet, 10)
-       table.IPv4.AllowedIPs(peer, allowed)
-       table.IPv6.AllowedIPs(peer, allowed)
+       allowed := make([]net.IPNet, 0, 10)
+       allowed = table.IPv4.AllowedIPs(peer, allowed)
+       allowed = table.IPv6.AllowedIPs(peer, allowed)
        return allowed
 }
 
index 4ff75db94655f74389df5cfc3d588bc1b8c1858b..ab75750f13669c63a35d80408ced5dd0151b8025 100644 (file)
@@ -61,9 +61,11 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
  * Obs. Single instance per TUN device
  */
 func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
+       device.log.Debug.Println("Routine, TUN Reader: started")
        for {
                // read packet
 
+               device.log.Debug.Println("Read")
                packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
                size, err := tun.Read(packet)
                if err != nil {
@@ -76,8 +78,6 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
                        continue
                }
 
-               device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove.
-
                // lookup peer
 
                var peer *Peer
@@ -85,10 +85,12 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
                case IPv4version:
                        dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
                        peer = device.routingTable.LookupIPv4(dst)
+                       device.log.Debug.Println("New IPv4 packet:", packet, dst)
 
                case IPv6version:
                        dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
                        peer = device.routingTable.LookupIPv6(dst)
+                       device.log.Debug.Println("New IPv6 packet:", packet, dst)
 
                default:
                        device.log.Debug.Println("Receieved packet with unknown IP version")
@@ -97,7 +99,7 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
 
                if peer == nil {
                        device.log.Debug.Println("No peer configured for IP")
-                       return
+                       continue
                }
 
                // insert into nonce/pre-handshake queue
index 4049167692e3c12a11b82254ed8d24bd4f9b7fb2..c2304b2f3d0087cecaebed7c295dfa81fecfa1da 100644 (file)
@@ -195,7 +195,10 @@ func (node *Trie) Count() uint {
        return l + r
 }
 
-func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
+func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
+       if node == nil {
+               return results
+       }
        if node.peer == p {
                var mask net.IPNet
                mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
@@ -213,6 +216,7 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
                }
                results = append(results, mask)
        }
-       node.child[0].AllowedIPs(p, results)
-       node.child[1].AllowedIPs(p, results)
+       results = node.child[0].AllowedIPs(p, results)
+       results = node.child[1].AllowedIPs(p, results)
+       return results
 }