]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Beginning work on UAPI and routing table
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Tue, 30 May 2017 20:36:49 +0000 (22:36 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Tue, 30 May 2017 20:36:49 +0000 (22:36 +0200)
src/config.go [new file with mode: 0644]
src/device.go [new file with mode: 0644]
src/main.go [new file with mode: 0644]
src/misc.go [new file with mode: 0644]
src/noise.go [new file with mode: 0644]
src/peer.go [new file with mode: 0644]
src/trie.go [new file with mode: 0644]
src/trie_test.go [new file with mode: 0644]

diff --git a/src/config.go b/src/config.go
new file mode 100644 (file)
index 0000000..f6f1378
--- /dev/null
@@ -0,0 +1,190 @@
+package main
+
+import (
+       "bufio"
+       "errors"
+       "fmt"
+       "io"
+       "log"
+)
+
+/* todo : use real error code
+ * Many of which will be the same
+ */
+const (
+       ipcErrorNoPeer            = 0
+       ipcErrorNoKeyValue        = 1
+       ipcErrorInvalidKey        = 2
+       ipcErrorInvalidPrivateKey = 3
+       ipcErrorInvalidPublicKey  = 4
+       ipcErrorInvalidPort       = 5
+)
+
+type IPCError struct {
+       Code int
+}
+
+func (s *IPCError) Error() string {
+       return fmt.Sprintf("IPC error: %d", s.Code)
+}
+
+func (s *IPCError) ErrorCode() int {
+       return s.Code
+}
+
+// Writes the configuration to the socket
+func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
+
+}
+
+// Creates new config, from old and socket message
+func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
+
+       scanner := bufio.NewScanner(socket)
+
+       dev.mutex.Lock()
+       defer dev.mutex.Unlock()
+
+       for scanner.Scan() {
+               var key string
+               var value string
+               var peer *Peer
+
+               // Parse line
+
+               line := scanner.Text()
+               if line == "\n" {
+                       break
+               }
+               fmt.Println(line)
+               n, err := fmt.Sscanf(line, "%s=%s\n", &key, &value)
+               if n != 2 || err != nil {
+                       fmt.Println(err, n)
+                       return &IPCError{Code: ipcErrorNoKeyValue}
+               }
+
+               switch key {
+
+               /* Interface configuration */
+
+               case "private_key":
+                       if value == "" {
+                               dev.privateKey = NoisePrivateKey{}
+                       } else {
+                               err := dev.privateKey.FromHex(value)
+                               if err != nil {
+                                       return &IPCError{Code: ipcErrorInvalidPrivateKey}
+                               }
+                       }
+
+               case "listen_port":
+                       _, err := fmt.Sscanf(value, "%ud", &dev.listenPort)
+                       if err != nil {
+                               return &IPCError{Code: ipcErrorInvalidPort}
+                       }
+
+               case "fwmark":
+                       panic(nil) // not handled yet
+
+               case "public_key":
+                       var pubKey NoisePublicKey
+                       err := pubKey.FromHex(value)
+                       if err != nil {
+                               return &IPCError{Code: ipcErrorInvalidPublicKey}
+                       }
+                       found, ok := dev.peers[pubKey]
+                       if ok {
+                               peer = found
+                       } else {
+                               newPeer := &Peer{
+                                       publicKey: pubKey,
+                               }
+                               peer = newPeer
+                               dev.peers[pubKey] = newPeer
+                       }
+
+               case "replace_peers":
+
+               default:
+                       /* Peer configuration */
+
+                       if peer == nil {
+                               return &IPCError{Code: ipcErrorNoPeer}
+                       }
+
+                       switch key {
+
+                       case "remove":
+                               peer.mutex.Lock()
+
+                               peer = nil
+
+                       case "preshared_key":
+                               func() {
+                                       peer.mutex.Lock()
+                                       defer peer.mutex.Unlock()
+                               }()
+
+                       case "endpoint":
+                               func() {
+                                       peer.mutex.Lock()
+                                       defer peer.mutex.Unlock()
+                               }()
+
+                       case "persistent_keepalive_interval":
+                               func() {
+                                       peer.mutex.Lock()
+                                       defer peer.mutex.Unlock()
+                               }()
+
+                       case "replace_allowed_ips":
+                               // remove peer from trie
+
+                       case "allowed_ip":
+
+                       /* Invalid key */
+
+                       default:
+                               return &IPCError{Code: ipcErrorInvalidKey}
+                       }
+               }
+       }
+
+       return nil
+}
+
+func ipcListen(dev *Device, socket io.ReadWriter) error {
+
+       buffered := func(s io.ReadWriter) *bufio.ReadWriter {
+               reader := bufio.NewReader(s)
+               writer := bufio.NewWriter(s)
+               return bufio.NewReadWriter(reader, writer)
+       }(socket)
+
+       for {
+               op, err := buffered.ReadString('\n')
+               if err != nil {
+                       return err
+               }
+               log.Println(op)
+
+               switch op {
+
+               case "set=1\n":
+                       err := ipcSetOperation(dev, buffered)
+                       if err != nil {
+                               fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode())
+                               return err
+                       } else {
+                               fmt.Fprintf(buffered, "errno=0\n")
+                       }
+                       buffered.Flush()
+
+               case "get=1\n":
+
+               default:
+                       return errors.New("handle this please")
+               }
+       }
+
+}
diff --git a/src/device.go b/src/device.go
new file mode 100644 (file)
index 0000000..cd0835c
--- /dev/null
@@ -0,0 +1,14 @@
+package main
+
+import (
+       "sync"
+)
+
+type Device struct {
+       mutex      sync.RWMutex
+       peers      map[NoisePublicKey]*Peer
+       privateKey NoisePrivateKey
+       publicKey  NoisePublicKey
+       fwMark     uint32
+       listenPort uint16
+}
diff --git a/src/main.go b/src/main.go
new file mode 100644 (file)
index 0000000..0f5016d
--- /dev/null
@@ -0,0 +1,28 @@
+package main
+
+import (
+       "fmt"
+       "log"
+       "net"
+)
+
+func main() {
+       l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
+       if err != nil {
+               log.Fatal("listen error:", err)
+       }
+
+       for {
+               fd, err := l.Accept()
+               if err != nil {
+                       log.Fatal("accept error:", err)
+               }
+
+               var dev Device
+               go func(conn net.Conn) {
+                       err := ipcListen(&dev, conn)
+                       fmt.Println(err)
+               }(fd)
+       }
+
+}
diff --git a/src/misc.go b/src/misc.go
new file mode 100644 (file)
index 0000000..e1244d6
--- /dev/null
@@ -0,0 +1,8 @@
+package main
+
+func min(a uint, b uint) uint {
+       if a > b {
+               return b
+       }
+       return a
+}
diff --git a/src/noise.go b/src/noise.go
new file mode 100644 (file)
index 0000000..d13bdd6
--- /dev/null
@@ -0,0 +1,51 @@
+package main
+
+import (
+       "encoding/hex"
+       "errors"
+)
+
+const (
+       NoisePublicKeySize    = 32
+       NoisePrivateKeySize   = 32
+       NoiseSymmetricKeySize = 32
+)
+
+type (
+       NoisePublicKey    [NoisePublicKeySize]byte
+       NoisePrivateKey   [NoisePrivateKeySize]byte
+       NoiseSymmetricKey [NoiseSymmetricKeySize]byte
+       NoiseNonce        uint64 // padded to 12-bytes
+)
+
+func (key *NoisePrivateKey) FromHex(s string) error {
+       slice, err := hex.DecodeString(s)
+       if err != nil {
+               return err
+       }
+       if len(slice) != NoisePrivateKeySize {
+               return errors.New("Invalid length of hex string for curve25519 point")
+       }
+       copy(key[:], slice)
+       return nil
+}
+
+func (key *NoisePrivateKey) ToHex() string {
+       return hex.EncodeToString(key[:])
+}
+
+func (key *NoisePublicKey) FromHex(s string) error {
+       slice, err := hex.DecodeString(s)
+       if err != nil {
+               return err
+       }
+       if len(slice) != NoisePublicKeySize {
+               return errors.New("Invalid length of hex string for curve25519 scalar")
+       }
+       copy(key[:], slice)
+       return nil
+}
+
+func (key *NoisePublicKey) ToHex() string {
+       return hex.EncodeToString(key[:])
+}
diff --git a/src/peer.go b/src/peer.go
new file mode 100644 (file)
index 0000000..7c000da
--- /dev/null
@@ -0,0 +1,18 @@
+package main
+
+import (
+       "sync"
+)
+
+type KeyPair struct {
+       recieveKey   NoiseSymmetricKey
+       recieveNonce NoiseNonce
+       sendKey      NoiseSymmetricKey
+       sendNonce    NoiseNonce
+}
+
+type Peer struct {
+       mutex        sync.RWMutex
+       publicKey    NoisePublicKey
+       presharedKey NoiseSymmetricKey
+}
diff --git a/src/trie.go b/src/trie.go
new file mode 100644 (file)
index 0000000..7fd7c5f
--- /dev/null
@@ -0,0 +1,154 @@
+package main
+
+import "fmt"
+
+/* Syncronization must be done seperatly
+ *
+ */
+
+type Trie struct {
+       cidr  uint
+       child [2]*Trie
+       bits  []byte
+       peer  *Peer
+
+       // Index of "branching" bit
+       // bit_at_shift
+       bit_at_byte  uint
+       bit_at_shift uint
+}
+
+/* Finds length of matching prefix
+ * Maybe there is a faster way
+ *
+ * Assumption: len(s1) == len(s2)
+ */
+func commonBits(s1 []byte, s2 []byte) uint {
+       var i uint
+       size := uint(len(s1))
+       for i = 0; i < size; i += 1 {
+               v := s1[i] ^ s2[i]
+               if v != 0 {
+                       v >>= 1
+                       if v == 0 {
+                               return i*8 + 7
+                       }
+
+                       v >>= 1
+                       if v == 0 {
+                               return i*8 + 6
+                       }
+
+                       v >>= 1
+                       if v == 0 {
+                               return i*8 + 5
+                       }
+
+                       v >>= 1
+                       if v == 0 {
+                               return i*8 + 4
+                       }
+
+                       v >>= 1
+                       if v == 0 {
+                               return i*8 + 3
+                       }
+
+                       v >>= 1
+                       if v == 0 {
+                               return i*8 + 2
+                       }
+
+                       v >>= 1
+                       if v == 0 {
+                               return i*8 + 1
+                       }
+                       return i * 8
+               }
+       }
+       return i * 8
+}
+
+func (node *Trie) RemovePeer(p *Peer) *Trie {
+       if node == nil {
+               return node
+       }
+
+       // Walk recursivly
+
+       node.child[0] = node.child[0].RemovePeer(p)
+       node.child[1] = node.child[1].RemovePeer(p)
+
+       if node.peer != p {
+               return node
+       }
+
+       // Remove peer & merge
+
+       node.peer = nil
+       if node.child[0] == nil {
+               return node.child[1]
+       }
+       return node.child[0]
+}
+
+func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
+       if node == nil {
+               return &Trie{
+                       bits:         key,
+                       peer:         peer,
+                       cidr:         cidr,
+                       bit_at_byte:  cidr / 8,
+                       bit_at_shift: 7 - (cidr % 8),
+               }
+       }
+
+       // Traverse deeper
+
+       common := commonBits(node.bits, key)
+       if node.cidr <= cidr && common >= node.cidr {
+               // Check if match the t.bits[:t.cidr] exactly
+               if node.cidr == cidr {
+                       node.peer = peer
+                       return node
+               }
+
+               // Go to child
+               bit := (key[node.bit_at_byte] >> node.bit_at_shift) & 1
+               node.child[bit] = node.child[bit].Insert(key, cidr, peer)
+               return node
+       }
+
+       // Split node
+
+       fmt.Println("new", common)
+
+       newNode := &Trie{
+               bits:         key,
+               peer:         peer,
+               cidr:         cidr,
+               bit_at_byte:  cidr / 8,
+               bit_at_shift: 7 - (cidr % 8),
+       }
+
+       cidr = min(cidr, common)
+       node.cidr = cidr
+       node.bit_at_byte = cidr / 8
+       node.bit_at_shift = 7 - (cidr % 8)
+
+       // bval := node.bits[node.bit_at_byte] >> node.bit_at_shift // todo : remember index
+       // Work in progress
+       node.child[0] = newNode
+       node.child[1] = newNode
+
+       return node
+}
+
+func (t *Trie) Lookup(key []byte) *Peer {
+       if t == nil {
+               return nil
+       }
+
+       return nil
+
+}
diff --git a/src/trie_test.go b/src/trie_test.go
new file mode 100644 (file)
index 0000000..ec4cde3
--- /dev/null
@@ -0,0 +1,66 @@
+package main
+
+import (
+       "testing"
+)
+
+type testPairCommonBits struct {
+       s1    []byte
+       s2    []byte
+       match uint
+}
+
+type testPairTrieInsert struct {
+       key  []byte
+       cidr uint
+       peer *Peer
+}
+
+func printTrie(t *testing.T, p *Trie) {
+       if p == nil {
+               return
+       }
+       t.Log(p)
+       printTrie(t, p.child[0])
+       printTrie(t, p.child[1])
+}
+
+func TestCommonBits(t *testing.T) {
+
+       tests := []testPairCommonBits{
+               {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
+               {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
+               {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
+               {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
+               {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
+       }
+
+       for _, p := range tests {
+               v := commonBits(p.s1, p.s2)
+               if v != p.match {
+                       t.Error(
+                               "For slice", p.s1, p.s2,
+                               "expected match", p.match,
+                               "got", v,
+                       )
+               }
+       }
+}
+
+func TestTrieInsertV4(t *testing.T) {
+       var trie *Trie
+
+       peer1 := Peer{}
+       peer2 := Peer{}
+
+       tests := []testPairTrieInsert{
+               {key: []byte{192, 168, 1, 1}, cidr: 24, peer: &peer1},
+               {key: []byte{192, 169, 1, 1}, cidr: 24, peer: &peer2},
+       }
+
+       for _, p := range tests {
+               trie = trie.Insert(p.key, p.cidr, p.peer)
+               printTrie(t, trie)
+       }
+
+}