]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Beginning work on TUN interface
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sun, 4 Jun 2017 19:48:15 +0000 (21:48 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sun, 4 Jun 2017 19:48:15 +0000 (21:48 +0200)
And outbound routing

I am not entirely convinced the use of net.IP is a good idea,
since the internal representation of net.IP is a byte slice
and all constructor functions in "net" return 16 byte slices
(padded for IPv4), while the use in this project uses 4 byte slices.
Which may be confusing.

src/config.go
src/ip.go [new file with mode: 0644]
src/main.go
src/peer.go
src/routing.go
src/trie.go
src/trie_test.go
src/tun.go [new file with mode: 0644]
src/tun_linux.go [new file with mode: 0644]

index 62af67a1c7ee2725c4407e68a063593cd1e9ba1a..a61b94055c0bc2ee32c6f58e37b72be643895921 100644 (file)
@@ -7,6 +7,8 @@ import (
        "io"
        "log"
        "net"
+       "strconv"
+       "time"
 )
 
 /* todo : use real error code
@@ -16,6 +18,7 @@ const (
        ipcErrorNoPeer            = 0
        ipcErrorNoKeyValue        = 1
        ipcErrorInvalidKey        = 2
+       ipcErrorInvalidValue      = 2
        ipcErrorInvalidPrivateKey = 3
        ipcErrorInvalidPublicKey  = 4
        ipcErrorInvalidPort       = 5
@@ -34,18 +37,16 @@ func (s *IPCError) ErrorCode() int {
        return s.Code
 }
 
-// Writes the configuration to the socket
 func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
 
 }
 
-// Creates new config, from old and socket message
-func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
+func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
        scanner := bufio.NewScanner(socket)
 
-       dev.mutex.Lock()
-       defer dev.mutex.Unlock()
+       device.mutex.Lock()
+       defer device.mutex.Unlock()
 
        for scanner.Scan() {
                var key string
@@ -71,16 +72,16 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
 
                case "private_key":
                        if value == "" {
-                               dev.privateKey = NoisePrivateKey{}
+                               device.privateKey = NoisePrivateKey{}
                        } else {
-                               err := dev.privateKey.FromHex(value)
+                               err := device.privateKey.FromHex(value)
                                if err != nil {
                                        return &IPCError{Code: ipcErrorInvalidPrivateKey}
                                }
                        }
 
                case "listen_port":
-                       _, err := fmt.Sscanf(value, "%ud", &dev.listenPort)
+                       _, err := fmt.Sscanf(value, "%ud", &device.listenPort)
                        if err != nil {
                                return &IPCError{Code: ipcErrorInvalidPort}
                        }
@@ -94,7 +95,7 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
                        if err != nil {
                                return &IPCError{Code: ipcErrorInvalidPublicKey}
                        }
-                       found, ok := dev.peers[pubKey]
+                       found, ok := device.peers[pubKey]
                        if ok {
                                peer = found
                        } else {
@@ -102,14 +103,16 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
                                        publicKey: pubKey,
                                }
                                peer = newPeer
-                               dev.peers[pubKey] = newPeer
+                               device.peers[pubKey] = newPeer
                        }
 
                case "replace_peers":
                        if key == "true" {
-                               dev.RemoveAllPeers()
+                               device.RemoveAllPeers()
+                       } else if key == "false" {
+                       } else {
+                               return &IPCError{Code: ipcErrorInvalidValue}
                        }
-                       // todo: else fail
 
                default:
                        /* Peer configuration */
@@ -122,7 +125,7 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
 
                        case "remove":
                                peer.mutex.Lock()
-                               dev.RemovePeer(peer.publicKey)
+                               device.RemovePeer(peer.publicKey)
                                peer = nil
 
                        case "preshared_key":
@@ -145,15 +148,29 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
                                peer.mutex.Unlock()
 
                        case "persistent_keepalive_interval":
-                               func() {
-                                       peer.mutex.Lock()
-                                       defer peer.mutex.Unlock()
-                               }()
+                               secs, err := strconv.ParseInt(value, 10, 64)
+                               if secs < 0 || err != nil {
+                                       return &IPCError{Code: ipcErrorInvalidValue}
+                               }
+                               peer.mutex.Lock()
+                               peer.persistentKeepaliveInterval = time.Duration(secs) * time.Second
+                               peer.mutex.Unlock()
 
                        case "replace_allowed_ips":
-                               // remove peer from trie
+                               if key == "true" {
+                                       device.routingTable.RemovePeer(peer)
+                               } else if key == "false" {
+                               } else {
+                                       return &IPCError{Code: ipcErrorInvalidValue}
+                               }
 
                        case "allowed_ip":
+                               _, network, err := net.ParseCIDR(value)
+                               if err != nil {
+                                       return &IPCError{Code: ipcErrorInvalidValue}
+                               }
+                               ones, _ := network.Mask.Size()
+                               device.routingTable.Insert(network.IP, uint(ones), peer)
 
                        /* Invalid key */
 
diff --git a/src/ip.go b/src/ip.go
new file mode 100644 (file)
index 0000000..3137891
--- /dev/null
+++ b/src/ip.go
@@ -0,0 +1,17 @@
+package main
+
+import (
+       "net"
+)
+
+const (
+       IPv4version   = 4
+       IPv4offsetSrc = 12
+       IPv4offsetDst = IPv4offsetSrc + net.IPv4len
+)
+
+const (
+       IPv6version   = 6
+       IPv6offsetSrc = 8
+       IPv6offsetDst = IPv6offsetSrc + net.IPv6len
+)
index 0f5016df9e7110dfbeba75b85cebf06cc24e60d9..af336f03d63ffb6efc4e36fc305cc16a596bc614 100644 (file)
@@ -1,11 +1,33 @@
 package main
 
+import "fmt"
+
+func main() {
+       fd, err := CreateTUN("test0")
+       fmt.Println(fd, err)
+
+       queue := make(chan []byte, 1000)
+
+       var device Device
+
+       go OutgoingRoutingWorker(&device, queue)
+
+       for {
+               tmp := make([]byte, 1<<16)
+               n, err := fd.Read(tmp)
+               if err != nil {
+                       break
+               }
+               queue <- tmp[:n]
+       }
+}
+
+/*
 import (
        "fmt"
        "log"
        "net"
 )
-
 func main() {
        l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
        if err != nil {
@@ -24,5 +46,5 @@ func main() {
                        fmt.Println(err)
                }(fd)
        }
-
 }
+*/
index 7b2b2a6ce71ab7a2578620563fec86e17b5301cc..db5e99f52ba27afb7481d8ca175472696b4a9aa1 100644 (file)
@@ -3,6 +3,7 @@ package main
 import (
        "net"
        "sync"
+       "time"
 )
 
 type KeyPair struct {
@@ -13,8 +14,9 @@ type KeyPair struct {
 }
 
 type Peer struct {
-       mutex        sync.RWMutex
-       publicKey    NoisePublicKey
-       presharedKey NoiseSymmetricKey
-       endpoint     net.IP
+       mutex                       sync.RWMutex
+       publicKey                   NoisePublicKey
+       presharedKey                NoiseSymmetricKey
+       endpoint                    net.IP
+       persistentKeepaliveInterval time.Duration
 }
index 99b180c42399555b64587262f0f14ffda4d6f26a..0aa111ce604f4509a59945365830d301fbd9f0f5 100644 (file)
@@ -1,13 +1,12 @@
 package main
 
 import (
+       "errors"
+       "fmt"
+       "net"
        "sync"
 )
 
-/* Thread-safe high level functions for cryptkey routing.
- *
- */
-
 type RoutingTable struct {
        IPv4  *Trie
        IPv6  *Trie
@@ -20,3 +19,51 @@ func (table *RoutingTable) RemovePeer(peer *Peer) {
        table.IPv4 = table.IPv4.RemovePeer(peer)
        table.IPv6 = table.IPv6.RemovePeer(peer)
 }
+
+func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) {
+       table.mutex.Lock()
+       defer table.mutex.Unlock()
+
+       switch len(ip) {
+       case net.IPv6len:
+               table.IPv6 = table.IPv6.Insert(ip, cidr, peer)
+       case net.IPv4len:
+               table.IPv4 = table.IPv4.Insert(ip, cidr, peer)
+       default:
+               panic(errors.New("Inserting unknown address type"))
+       }
+}
+
+func (table *RoutingTable) LookupIPv4(address []byte) *Peer {
+       table.mutex.RLock()
+       defer table.mutex.RUnlock()
+       return table.IPv4.Lookup(address)
+}
+
+func (table *RoutingTable) LookupIPv6(address []byte) *Peer {
+       table.mutex.RLock()
+       defer table.mutex.RUnlock()
+       return table.IPv6.Lookup(address)
+}
+
+func OutgoingRoutingWorker(device *Device, queue chan []byte) {
+       for {
+               packet := <-queue
+               switch packet[0] >> 4 {
+
+               case IPv4version:
+                       dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+                       peer := device.routingTable.LookupIPv4(dst)
+                       fmt.Println("IPv4", peer)
+
+               case IPv6version:
+                       dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+                       peer := device.routingTable.LookupIPv6(dst)
+                       fmt.Println("IPv6", peer)
+
+               default:
+                       // todo: log
+                       fmt.Println("Unknown IP version")
+               }
+       }
+}
index 31a4d9230e50f3f963ba6bb8ed9916d112c620e9..746c1b4fbda73c10526453661def7f045acdaa8e 100644 (file)
@@ -1,5 +1,9 @@
 package main
 
+import (
+       "net"
+)
+
 /* Binary trie
  *
  * Syncronization done seperatly
@@ -22,13 +26,13 @@ type Trie struct {
 /* Finds length of matching prefix
  * Maybe there is a faster way
  *
- * Assumption: len(s1) == len(s2)
+ * Assumption: len(ip1) == len(ip2)
  */
-func commonBits(s1 []byte, s2 []byte) uint {
+func commonBits(ip1 net.IP, ip2 net.IP) uint {
        var i uint
-       size := uint(len(s1))
+       size := uint(len(ip1))
        for i = 0; i < size; i += 1 {
-               v := s1[i] ^ s2[i]
+               v := ip1[i] ^ ip2[i]
                if v != 0 {
                        v >>= 1
                        if v == 0 {
@@ -93,17 +97,17 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
        return node.child[0]
 }
 
-func (node *Trie) choose(key []byte) byte {
-       return (key[node.bit_at_byte] >> node.bit_at_shift) & 1
+func (node *Trie) choose(ip net.IP) byte {
+       return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
 }
 
-func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
+func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 
        // At leaf
 
        if node == nil {
                return &Trie{
-                       bits:         key,
+                       bits:         ip,
                        peer:         peer,
                        cidr:         cidr,
                        bit_at_byte:  cidr / 8,
@@ -113,21 +117,21 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
 
        // Traverse deeper
 
-       common := commonBits(node.bits, key)
+       common := commonBits(node.bits, ip)
        if node.cidr <= cidr && common >= node.cidr {
                if node.cidr == cidr {
                        node.peer = peer
                        return node
                }
-               bit := node.choose(key)
-               node.child[bit] = node.child[bit].Insert(key, cidr, peer)
+               bit := node.choose(ip)
+               node.child[bit] = node.child[bit].Insert(ip, cidr, peer)
                return node
        }
 
        // Split node
 
        newNode := &Trie{
-               bits:         key,
+               bits:         ip,
                peer:         peer,
                cidr:         cidr,
                bit_at_byte:  cidr / 8,
@@ -147,31 +151,31 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
        // Create new parent for node & newNode
 
        parent := &Trie{
-               bits:         key,
+               bits:         ip,
                peer:         nil,
                cidr:         cidr,
                bit_at_byte:  cidr / 8,
                bit_at_shift: 7 - (cidr % 8),
        }
 
-       bit := parent.choose(key)
+       bit := parent.choose(ip)
        parent.child[bit] = newNode
        parent.child[bit^1] = node
 
        return parent
 }
 
-func (node *Trie) Lookup(key []byte) *Peer {
+func (node *Trie) Lookup(ip net.IP) *Peer {
        var found *Peer
-       size := uint(len(key))
-       for node != nil && commonBits(node.bits, key) >= node.cidr {
+       size := uint(len(ip))
+       for node != nil && commonBits(node.bits, ip) >= node.cidr {
                if node.peer != nil {
                        found = node.peer
                }
                if node.bit_at_byte == size {
                        break
                }
-               bit := node.choose(key)
+               bit := node.choose(ip)
                node = node.child[bit]
        }
        return found
index 35af0aaa82080ac89eb732b7c68643460deb1b10..9d53df3408c4bf799be49c6e9c87d7056234305c 100644 (file)
@@ -1,6 +1,8 @@
 package main
 
 import (
+       "math/rand"
+       "net"
        "testing"
 )
 
@@ -55,6 +57,49 @@ func TestCommonBits(t *testing.T) {
        }
 }
 
+func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
+       var trie *Trie
+       var peers []*Peer
+
+       rand.Seed(1)
+
+       const AddressLength = 4
+
+       for n := 0; n < peerNumber; n += 1 {
+               peers = append(peers, &Peer{})
+       }
+
+       for n := 0; n < addressNumber; n += 1 {
+               var addr [AddressLength]byte
+               rand.Read(addr[:])
+               cidr := uint(rand.Uint32() % (AddressLength * 8))
+               index := rand.Int() % peerNumber
+               trie = trie.Insert(addr[:], cidr, peers[index])
+       }
+
+       for n := 0; n < b.N; n += 1 {
+               var addr [AddressLength]byte
+               rand.Read(addr[:])
+               trie.Lookup(addr[:])
+       }
+}
+
+func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
+       benchmarkTrie(100, 1000, net.IPv4len, b)
+}
+
+func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
+       benchmarkTrie(10, 10, net.IPv4len, b)
+}
+
+func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
+       benchmarkTrie(100, 1000, net.IPv6len, b)
+}
+
+func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
+       benchmarkTrie(10, 10, net.IPv6len, b)
+}
+
 /* Test ported from kernel implementation:
  * selftest/routingtable.h
  */
@@ -91,10 +136,10 @@ func TestTrieIPv4(t *testing.T) {
        insert(b, 192, 168, 4, 4, 32)
        insert(c, 192, 168, 0, 0, 16)
        insert(d, 192, 95, 5, 64, 27)
-       insert(c, 192, 95, 5, 65, 27) /* replaces previous entry, and maskself is required */
+       insert(c, 192, 95, 5, 65, 27)
        insert(e, 0, 0, 0, 0, 0)
        insert(g, 64, 15, 112, 0, 20)
-       insert(h, 64, 15, 123, 211, 25) /* maskself is required */
+       insert(h, 64, 15, 123, 211, 25)
        insert(a, 10, 0, 0, 0, 25)
        insert(b, 10, 0, 0, 128, 25)
        insert(a, 10, 1, 0, 0, 30)
@@ -186,20 +231,6 @@ func TestTrieIPv6(t *testing.T) {
                }
        }
 
-       /*
-               assertNEQ := func(peer *Peer, a, b, c, d uint32) {
-                       var addr []byte
-                       addr = append(addr, expand(a)...)
-                       addr = append(addr, expand(b)...)
-                       addr = append(addr, expand(c)...)
-                       addr = append(addr, expand(d)...)
-                       p := trie.Lookup(addr)
-                       if p == peer {
-                               t.Error("Assert NEQ failed")
-                       }
-               }
-       */
-
        insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
        insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
        insert(e, 0, 0, 0, 0, 0)
diff --git a/src/tun.go b/src/tun.go
new file mode 100644 (file)
index 0000000..1a8bb82
--- /dev/null
@@ -0,0 +1,8 @@
+package main
+
+type TUN interface {
+       Read([]byte) (int, error)
+       Write([]byte) (int, error)
+       Name() string
+       MTU() uint
+}
diff --git a/src/tun_linux.go b/src/tun_linux.go
new file mode 100644 (file)
index 0000000..d545dfa
--- /dev/null
@@ -0,0 +1,80 @@
+package main
+
+import (
+       "encoding/binary"
+       "errors"
+       "os"
+       "strings"
+       "syscall"
+       "unsafe"
+)
+
+/* Platform dependent functions for interacting with
+ * TUN devices on linux systems
+ *
+ */
+
+const CloneDevicePath = "/dev/net/tun"
+
+const (
+       IFF_NO_PI = 0x1000
+       IFF_TUN   = 0x1
+       IFNAMSIZ  = 0x10
+       TUNSETIFF = 0x400454CA
+)
+
+type NativeTun struct {
+       fd   *os.File
+       name string
+       mtu  uint
+}
+
+func (tun *NativeTun) Name() string {
+       return tun.name
+}
+
+func (tun *NativeTun) MTU() uint {
+       return tun.mtu
+}
+
+func (tun *NativeTun) Write(d []byte) (int, error) {
+       return tun.fd.Write(d)
+}
+
+func (tun *NativeTun) Read(d []byte) (int, error) {
+       return tun.fd.Read(d)
+}
+
+func CreateTUN(name string) (TUN, error) {
+       // Open clone device
+       fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0)
+       if err != nil {
+               return nil, err
+       }
+
+       // Prepare ifreq struct
+       var ifr [18]byte
+       var flags uint16 = IFF_TUN | IFF_NO_PI
+       nameBytes := []byte(name)
+       if len(nameBytes) >= IFNAMSIZ {
+               return nil, errors.New("Name size too long")
+       }
+       copy(ifr[:], nameBytes)
+       binary.LittleEndian.PutUint16(ifr[16:], flags)
+
+       // Create new device
+       _, _, errno := syscall.Syscall(syscall.SYS_IOCTL,
+               uintptr(fd.Fd()), uintptr(TUNSETIFF),
+               uintptr(unsafe.Pointer(&ifr[0])))
+       if errno != 0 {
+               return nil, errors.New("Failed to create tun, ioctl call failed")
+       }
+
+       // Read name of interface
+       newName := string(ifr[:])
+       newName = newName[:strings.Index(newName, "\000")]
+       return &NativeTun{
+               fd:   fd,
+               name: newName,
+       }, nil
+}