]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Inital implementation of trie
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 1 Jun 2017 19:31:30 +0000 (21:31 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Thu, 1 Jun 2017 19:31:30 +0000 (21:31 +0200)
src/config.go
src/device.go
src/noise.go
src/peer.go
src/ping-test.go [deleted file]
src/routing.go [new file with mode: 0644]
src/trie.go
src/trie_test.go

index f6f1378cfb8a7cfa90c9421c8ee7820db5db7398..62af67a1c7ee2725c4407e68a063593cd1e9ba1a 100644 (file)
@@ -6,6 +6,7 @@ import (
        "fmt"
        "io"
        "log"
+       "net"
 )
 
 /* todo : use real error code
@@ -18,6 +19,7 @@ const (
        ipcErrorInvalidPrivateKey = 3
        ipcErrorInvalidPublicKey  = 4
        ipcErrorInvalidPort       = 5
+       ipcErrorInvalidIPAddress  = 6
 )
 
 type IPCError struct {
@@ -104,6 +106,10 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
                        }
 
                case "replace_peers":
+                       if key == "true" {
+                               dev.RemoveAllPeers()
+                       }
+                       // todo: else fail
 
                default:
                        /* Peer configuration */
@@ -116,20 +122,27 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
 
                        case "remove":
                                peer.mutex.Lock()
-
+                               dev.RemovePeer(peer.publicKey)
                                peer = nil
 
                        case "preshared_key":
-                               func() {
+                               err := func() error {
                                        peer.mutex.Lock()
                                        defer peer.mutex.Unlock()
+                                       return peer.presharedKey.FromHex(value)
                                }()
+                               if err != nil {
+                                       return &IPCError{Code: ipcErrorInvalidPublicKey}
+                               }
 
                        case "endpoint":
-                               func() {
-                                       peer.mutex.Lock()
-                                       defer peer.mutex.Unlock()
-                               }()
+                               ip := net.ParseIP(value)
+                               if ip == nil {
+                                       return &IPCError{Code: ipcErrorInvalidIPAddress}
+                               }
+                               peer.mutex.Lock()
+                               peer.endpoint = ip
+                               peer.mutex.Unlock()
 
                        case "persistent_keepalive_interval":
                                func() {
index cd0835c1feb877984b5c5fa2dbe74933b1b88a86..d03057dd2a28cd4e71217714875939042967d4f6 100644 (file)
@@ -5,10 +5,39 @@ import (
 )
 
 type Device struct {
-       mutex      sync.RWMutex
-       peers      map[NoisePublicKey]*Peer
-       privateKey NoisePrivateKey
-       publicKey  NoisePublicKey
-       fwMark     uint32
-       listenPort uint16
+       mutex        sync.RWMutex
+       peers        map[NoisePublicKey]*Peer
+       privateKey   NoisePrivateKey
+       publicKey    NoisePublicKey
+       fwMark       uint32
+       listenPort   uint16
+       routingTable RoutingTable
+}
+
+func (dev *Device) RemovePeer(key NoisePublicKey) {
+       dev.mutex.Lock()
+       defer dev.mutex.Unlock()
+       peer, ok := dev.peers[key]
+       if !ok {
+               return
+       }
+       peer.mutex.Lock()
+       dev.routingTable.RemovePeer(peer)
+       delete(dev.peers, key)
+}
+
+func (dev *Device) RemoveAllAllowedIps(peer *Peer) {
+
+}
+
+func (dev *Device) RemoveAllPeers() {
+       dev.mutex.Lock()
+       defer dev.mutex.Unlock()
+
+       for key, peer := range dev.peers {
+               peer.mutex.Lock()
+               dev.routingTable.RemovePeer(peer)
+               delete(dev.peers, key)
+               peer.mutex.Unlock()
+       }
 }
index d13bdd64ab281ddfd27c57e5b647a776b558cc4e..5508f9a526f138659eefe3f27baff9283cb81921 100644 (file)
@@ -18,34 +18,38 @@ type (
        NoiseNonce        uint64 // padded to 12-bytes
 )
 
-func (key *NoisePrivateKey) FromHex(s string) error {
-       slice, err := hex.DecodeString(s)
+func loadExactHex(dst []byte, src string) error {
+       slice, err := hex.DecodeString(src)
        if err != nil {
                return err
        }
-       if len(slice) != NoisePrivateKeySize {
-               return errors.New("Invalid length of hex string for curve25519 point")
+       if len(slice) != len(dst) {
+               return errors.New("Hex string does not fit the slice")
        }
-       copy(key[:], slice)
+       copy(dst, slice)
        return nil
 }
 
-func (key *NoisePrivateKey) ToHex() string {
+func (key *NoisePrivateKey) FromHex(src string) error {
+       return loadExactHex(key[:], src)
+}
+
+func (key NoisePrivateKey) ToHex() string {
        return hex.EncodeToString(key[:])
 }
 
-func (key *NoisePublicKey) FromHex(s string) error {
-       slice, err := hex.DecodeString(s)
-       if err != nil {
-               return err
-       }
-       if len(slice) != NoisePublicKeySize {
-               return errors.New("Invalid length of hex string for curve25519 scalar")
-       }
-       copy(key[:], slice)
-       return nil
+func (key *NoisePublicKey) FromHex(src string) error {
+       return loadExactHex(key[:], src)
+}
+
+func (key NoisePublicKey) ToHex() string {
+       return hex.EncodeToString(key[:])
+}
+
+func (key *NoiseSymmetricKey) FromHex(src string) error {
+       return loadExactHex(key[:], src)
 }
 
-func (key *NoisePublicKey) ToHex() string {
+func (key NoiseSymmetricKey) ToHex() string {
        return hex.EncodeToString(key[:])
 }
index 7c000da7ccef61be49392d7acb81cfc8f0996388..7b2b2a6ce71ab7a2578620563fec86e17b5301cc 100644 (file)
@@ -1,6 +1,7 @@
 package main
 
 import (
+       "net"
        "sync"
 )
 
@@ -15,4 +16,5 @@ type Peer struct {
        mutex        sync.RWMutex
        publicKey    NoisePublicKey
        presharedKey NoiseSymmetricKey
+       endpoint     net.IP
 }
diff --git a/src/ping-test.go b/src/ping-test.go
deleted file mode 100644 (file)
index 4b58891..0000000
+++ /dev/null
@@ -1,175 +0,0 @@
-/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
-
-package main
-
-import (
-       "crypto/rand"
-       "encoding/base64"
-       "encoding/binary"
-       "log"
-       "net"
-       "time"
-
-       "github.com/dchest/blake2s"
-       "github.com/titanous/noise"
-       "golang.org/x/net/icmp"
-       "golang.org/x/net/ipv4"
-)
-
-func ipChecksum(buf []byte) uint16 {
-       sum := uint32(0)
-       for ; len(buf) >= 2; buf = buf[2:] {
-               sum += uint32(buf[0])<<8 | uint32(buf[1])
-       }
-       if len(buf) > 0 {
-               sum += uint32(buf[0]) << 8
-       }
-       for sum > 0xffff {
-               sum = (sum >> 16) + (sum & 0xffff)
-       }
-       csum := ^uint16(sum)
-       if csum == 0 {
-               csum = 0xffff
-       }
-       return csum
-}
-
-func main() {
-       ourPrivate, _ := base64.StdEncoding.DecodeString("WAmgVYXkbT2bCtdcDwolI88/iVi/aV3/PHcUBTQSYmo=")
-       ourPublic, _ := base64.StdEncoding.DecodeString("K5sF9yESrSBsOXPd6TcpKNgqoy1Ik3ZFKl4FolzrRyI=")
-       theirPublic, _ := base64.StdEncoding.DecodeString("qRCwZSKInrMAq5sepfCdaCsRJaoLe5jhtzfiw7CjbwM=")
-       preshared, _ := base64.StdEncoding.DecodeString("FpCyhws9cxwWoV4xELtfJvjJN+zQVRPISllRWgeopVE=")
-       cs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s)
-       hs := noise.NewHandshakeState(noise.Config{
-               CipherSuite:   cs,
-               Random:        rand.Reader,
-               Pattern:       noise.HandshakeIK,
-               Initiator:     true,
-               Prologue:      []byte("WireGuard v1 zx2c4 Jason@zx2c4.com"),
-               PresharedKey:  preshared,
-               PresharedKeyPlacement: 2,
-               StaticKeypair: noise.DHKey{Private: ourPrivate, Public: ourPublic},
-               PeerStatic:    theirPublic,
-       })
-       conn, err := net.Dial("udp", "demo.wireguard.io:12913")
-       if err != nil {
-               log.Fatalf("error dialing udp socket: %s", err)
-       }
-       defer conn.Close()
-
-       // write handshake initiation packet
-       now := time.Now()
-       tai64n := make([]byte, 12)
-       binary.BigEndian.PutUint64(tai64n[:], 4611686018427387914+uint64(now.Unix()))
-       binary.BigEndian.PutUint32(tai64n[8:], uint32(now.UnixNano()))
-       initiationPacket := make([]byte, 8)
-       initiationPacket[0] = 1 // Type: Initiation
-       initiationPacket[1] = 0 // Reserved
-       initiationPacket[2] = 0 // Reserved
-       initiationPacket[3] = 0 // Reserved
-       binary.LittleEndian.PutUint32(initiationPacket[4:], 28) // Sender index: 28 (arbitrary)
-       initiationPacket, _, _ = hs.WriteMessage(initiationPacket, tai64n)
-       hasher, _ := blake2s.New(&blake2s.Config{Size: 32})
-       hasher.Write([]byte("mac1----"))
-       hasher.Write(theirPublic)
-       hasher, _ = blake2s.New(&blake2s.Config{Size: 16, Key: hasher.Sum(nil)})
-       hasher.Write(initiationPacket)
-       initiationPacket = append(initiationPacket, hasher.Sum(nil)[:16]...)
-       initiationPacket = append(initiationPacket, make([]byte, 16)...)
-       if _, err := conn.Write(initiationPacket); err != nil {
-               log.Fatalf("error writing initiation packet: %s", err)
-       }
-
-       // read handshake response packet
-       responsePacket := make([]byte, 92)
-       n, err := conn.Read(responsePacket)
-       if err != nil {
-               log.Fatalf("error reading response packet: %s", err)
-       }
-       if n != len(responsePacket) {
-               log.Fatalf("response packet too short: want %d, got %d", len(responsePacket), n)
-       }
-       if responsePacket[0] != 2 { // Type: Response
-               log.Fatalf("response packet type wrong: want %d, got %d", 2, responsePacket[0])
-       }
-       if responsePacket[1] != 0 || responsePacket[2] != 0 || responsePacket[3] != 0 {
-               log.Fatalf("response packet has non-zero reserved fields")
-       }
-       theirIndex := binary.LittleEndian.Uint32(responsePacket[4:])
-       ourIndex := binary.LittleEndian.Uint32(responsePacket[8:])
-       if ourIndex != 28 {
-               log.Fatalf("response packet index wrong: want %d, got %d", 28, ourIndex)
-       }
-       payload, sendCipher, receiveCipher, err := hs.ReadMessage(nil, responsePacket[12:60])
-       if err != nil {
-               log.Fatalf("error reading handshake message: %s", err)
-       }
-       if len(payload) > 0 {
-               log.Fatalf("unexpected payload: %x", payload)
-       }
-
-       // write ICMP Echo packet
-       pingMessage, _ := (&icmp.Message{
-               Type: ipv4.ICMPTypeEcho,
-               Body: &icmp.Echo{
-                       ID:   921,
-                       Seq:  438,
-                       Data: []byte("WireGuard"),
-               },
-       }).Marshal(nil)
-       pingHeader, err := (&ipv4.Header{
-               Version:  ipv4.Version,
-               Len:      ipv4.HeaderLen,
-               TotalLen: ipv4.HeaderLen + len(pingMessage),
-               Protocol: 1, // ICMP
-               TTL:      20,
-               Src:      net.IPv4(10, 189, 129, 2),
-               Dst:      net.IPv4(10, 189, 129, 1),
-       }).Marshal()
-       binary.BigEndian.PutUint16(pingHeader[2:], uint16(ipv4.HeaderLen+len(pingMessage))) // fix the length endianness on BSDs
-       pingData := append(pingHeader, pingMessage...)
-       binary.BigEndian.PutUint16(pingData[10:], ipChecksum(pingData))
-       pingPacket := make([]byte, 16)
-       pingPacket[0] = 4 // Type: Data
-       pingPacket[1] = 0 // Reserved
-       pingPacket[2] = 0 // Reserved
-       pingPacket[3] = 0 // Reserved
-       binary.LittleEndian.PutUint32(pingPacket[4:], theirIndex)
-       binary.LittleEndian.PutUint64(pingPacket[8:], 0) // Nonce
-       pingPacket = sendCipher.Encrypt(pingPacket, nil, pingData)
-       if _, err := conn.Write(pingPacket); err != nil {
-               log.Fatalf("error writing ping message: %s", err)
-       }
-
-       // read ICMP Echo Reply packet
-       replyPacket := make([]byte, 128)
-       n, err = conn.Read(replyPacket)
-       if err != nil {
-               log.Fatalf("error reading ping reply message: %s", err)
-       }
-       replyPacket = replyPacket[:n]
-       if replyPacket[0] != 4 { // Type: Data
-               log.Fatalf("unexpected reply packet type: %d", replyPacket[0])
-       }
-       if replyPacket[1] != 0 || replyPacket[2] != 0 || replyPacket[3] != 0 {
-               log.Fatalf("reply packet has non-zero reserved fields")
-       }
-       replyPacket, err = receiveCipher.Decrypt(nil, nil, replyPacket[16:])
-       if err != nil {
-               log.Fatalf("error decrypting reply packet: %s", err)
-       }
-       replyHeaderLen := int(replyPacket[0]&0x0f) << 2
-       replyLen := binary.BigEndian.Uint16(replyPacket[2:])
-       replyMessage, err := icmp.ParseMessage(1, replyPacket[replyHeaderLen:replyLen])
-       if err != nil {
-               log.Fatalf("error parsing echo: %s", err)
-       }
-       echo, ok := replyMessage.Body.(*icmp.Echo)
-       if !ok {
-               log.Fatalf("unexpected reply body type %T", replyMessage.Body)
-       }
-
-       if echo.ID != 921 || echo.Seq != 438 || string(echo.Data) != "WireGuard" {
-               log.Fatalf("incorrect echo response: %#v", echo)
-       }
-}
diff --git a/src/routing.go b/src/routing.go
new file mode 100644 (file)
index 0000000..99b180c
--- /dev/null
@@ -0,0 +1,22 @@
+package main
+
+import (
+       "sync"
+)
+
+/* Thread-safe high level functions for cryptkey routing.
+ *
+ */
+
+type RoutingTable struct {
+       IPv4  *Trie
+       IPv6  *Trie
+       mutex sync.RWMutex
+}
+
+func (table *RoutingTable) RemovePeer(peer *Peer) {
+       table.mutex.Lock()
+       defer table.mutex.Unlock()
+       table.IPv4 = table.IPv4.RemovePeer(peer)
+       table.IPv6 = table.IPv6.RemovePeer(peer)
+}
index 7fd7c5f9d8ad478818f634155e744c2a3f269ac7..31a4d9230e50f3f963ba6bb8ed9916d112c620e9 100644 (file)
@@ -1,9 +1,11 @@
 package main
 
-import "fmt"
-
-/* Syncronization must be done seperatly
+/* Binary trie
+ *
+ * Syncronization done seperatly
+ * See: routing.go
  *
+ * Todo: Better commenting
  */
 
 type Trie struct {
@@ -13,7 +15,6 @@ type Trie struct {
        peer  *Peer
 
        // Index of "branching" bit
-       // bit_at_shift
        bit_at_byte  uint
        bit_at_shift uint
 }
@@ -92,7 +93,14 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
        return node.child[0]
 }
 
+func (node *Trie) choose(key []byte) byte {
+       return (key[node.bit_at_byte] >> node.bit_at_shift) & 1
+}
+
 func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
+
+       // At leaf
+
        if node == nil {
                return &Trie{
                        bits:         key,
@@ -107,22 +115,17 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
 
        common := commonBits(node.bits, key)
        if node.cidr <= cidr && common >= node.cidr {
-               // Check if match the t.bits[:t.cidr] exactly
                if node.cidr == cidr {
                        node.peer = peer
                        return node
                }
-
-               // Go to child
-               bit := (key[node.bit_at_byte] >> node.bit_at_shift) & 1
+               bit := node.choose(key)
                node.child[bit] = node.child[bit].Insert(key, cidr, peer)
                return node
        }
 
        // Split node
 
-       fmt.Println("new", common)
-
        newNode := &Trie{
                bits:         key,
                peer:         peer,
@@ -132,23 +135,53 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
        }
 
        cidr = min(cidr, common)
-       node.cidr = cidr
-       node.bit_at_byte = cidr / 8
-       node.bit_at_shift = 7 - (cidr % 8)
 
-       // bval := node.bits[node.bit_at_byte] >> node.bit_at_shift // todo : remember index
-       // Work in progress
-       node.child[0] = newNode
-       node.child[1] = newNode
+       // Check for shorter prefix
 
-       return node
-}
+       if newNode.cidr == cidr {
+               bit := newNode.choose(node.bits)
+               newNode.child[bit] = node
+               return newNode
+       }
+
+       // Create new parent for node & newNode
 
-func (t *Trie) Lookup(key []byte) *Peer {
-       if t == nil {
-               return nil
+       parent := &Trie{
+               bits:         key,
+               peer:         nil,
+               cidr:         cidr,
+               bit_at_byte:  cidr / 8,
+               bit_at_shift: 7 - (cidr % 8),
        }
 
-       return nil
+       bit := parent.choose(key)
+       parent.child[bit] = newNode
+       parent.child[bit^1] = node
+
+       return parent
+}
+
+func (node *Trie) Lookup(key []byte) *Peer {
+       var found *Peer
+       size := uint(len(key))
+       for node != nil && commonBits(node.bits, key) >= node.cidr {
+               if node.peer != nil {
+                       found = node.peer
+               }
+               if node.bit_at_byte == size {
+                       break
+               }
+               bit := node.choose(key)
+               node = node.child[bit]
+       }
+       return found
+}
 
+func (node *Trie) Count() uint {
+       if node == nil {
+               return 0
+       }
+       l := node.child[0].Count()
+       r := node.child[1].Count()
+       return l + r
 }
index ec4cde349d83c9d84783216d5a5d4c145a47ea71..35af0aaa82080ac89eb732b7c68643460deb1b10 100644 (file)
@@ -4,6 +4,9 @@ import (
        "testing"
 )
 
+/* Todo: More comprehensive
+ */
+
 type testPairCommonBits struct {
        s1    []byte
        s2    []byte
@@ -16,6 +19,11 @@ type testPairTrieInsert struct {
        peer *Peer
 }
 
+type testPairTrieLookup struct {
+       key  []byte
+       peer *Peer
+}
+
 func printTrie(t *testing.T, p *Trie) {
        if p == nil {
                return
@@ -41,26 +49,176 @@ func TestCommonBits(t *testing.T) {
                        t.Error(
                                "For slice", p.s1, p.s2,
                                "expected match", p.match,
-                               "got", v,
+                               ",but got", v,
                        )
                }
        }
 }
 
-func TestTrieInsertV4(t *testing.T) {
+/* Test ported from kernel implementation:
+ * selftest/routingtable.h
+ */
+func TestTrieIPv4(t *testing.T) {
+       a := &Peer{}
+       b := &Peer{}
+       c := &Peer{}
+       d := &Peer{}
+       e := &Peer{}
+       g := &Peer{}
+       h := &Peer{}
+
        var trie *Trie
 
-       peer1 := Peer{}
-       peer2 := Peer{}
+       insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
+               trie = trie.Insert([]byte{a, b, c, d}, cidr, peer)
+       }
 
-       tests := []testPairTrieInsert{
-               {key: []byte{192, 168, 1, 1}, cidr: 24, peer: &peer1},
-               {key: []byte{192, 169, 1, 1}, cidr: 24, peer: &peer2},
+       assertEQ := func(peer *Peer, a, b, c, d byte) {
+               p := trie.Lookup([]byte{a, b, c, d})
+               if p != peer {
+                       t.Error("Assert EQ failed")
+               }
        }
 
-       for _, p := range tests {
-               trie = trie.Insert(p.key, p.cidr, p.peer)
-               printTrie(t, trie)
+       assertNEQ := func(peer *Peer, a, b, c, d byte) {
+               p := trie.Lookup([]byte{a, b, c, d})
+               if p == peer {
+                       t.Error("Assert NEQ failed")
+               }
+       }
+
+       insert(a, 192, 168, 4, 0, 24)
+       insert(b, 192, 168, 4, 4, 32)
+       insert(c, 192, 168, 0, 0, 16)
+       insert(d, 192, 95, 5, 64, 27)
+       insert(c, 192, 95, 5, 65, 27) /* replaces previous entry, and maskself is required */
+       insert(e, 0, 0, 0, 0, 0)
+       insert(g, 64, 15, 112, 0, 20)
+       insert(h, 64, 15, 123, 211, 25) /* maskself is required */
+       insert(a, 10, 0, 0, 0, 25)
+       insert(b, 10, 0, 0, 128, 25)
+       insert(a, 10, 1, 0, 0, 30)
+       insert(b, 10, 1, 0, 4, 30)
+       insert(c, 10, 1, 0, 8, 29)
+       insert(d, 10, 1, 0, 16, 29)
+
+       assertEQ(a, 192, 168, 4, 20)
+       assertEQ(a, 192, 168, 4, 0)
+       assertEQ(b, 192, 168, 4, 4)
+       assertEQ(c, 192, 168, 200, 182)
+       assertEQ(c, 192, 95, 5, 68)
+       assertEQ(e, 192, 95, 5, 96)
+       assertEQ(g, 64, 15, 116, 26)
+       assertEQ(g, 64, 15, 127, 3)
+
+       insert(a, 1, 0, 0, 0, 32)
+       insert(a, 64, 0, 0, 0, 32)
+       insert(a, 128, 0, 0, 0, 32)
+       insert(a, 192, 0, 0, 0, 32)
+       insert(a, 255, 0, 0, 0, 32)
+
+       assertEQ(a, 1, 0, 0, 0)
+       assertEQ(a, 64, 0, 0, 0)
+       assertEQ(a, 128, 0, 0, 0)
+       assertEQ(a, 192, 0, 0, 0)
+       assertEQ(a, 255, 0, 0, 0)
+
+       trie = trie.RemovePeer(a)
+
+       assertNEQ(a, 1, 0, 0, 0)
+       assertNEQ(a, 64, 0, 0, 0)
+       assertNEQ(a, 128, 0, 0, 0)
+       assertNEQ(a, 192, 0, 0, 0)
+       assertNEQ(a, 255, 0, 0, 0)
+
+       trie = nil
+
+       insert(a, 192, 168, 0, 0, 16)
+       insert(a, 192, 168, 0, 0, 24)
+
+       trie = trie.RemovePeer(a)
+
+       assertNEQ(a, 192, 168, 0, 1)
+}
+
+/* Test ported from kernel implementation:
+ * selftest/routingtable.h
+ */
+func TestTrieIPv6(t *testing.T) {
+       a := &Peer{}
+       b := &Peer{}
+       c := &Peer{}
+       d := &Peer{}
+       e := &Peer{}
+       f := &Peer{}
+       g := &Peer{}
+       h := &Peer{}
+
+       var trie *Trie
+
+       expand := func(a uint32) []byte {
+               var out [4]byte
+               out[0] = byte(a >> 24 & 0xff)
+               out[1] = byte(a >> 16 & 0xff)
+               out[2] = byte(a >> 8 & 0xff)
+               out[3] = byte(a & 0xff)
+               return out[:]
        }
 
+       insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
+               var addr []byte
+               addr = append(addr, expand(a)...)
+               addr = append(addr, expand(b)...)
+               addr = append(addr, expand(c)...)
+               addr = append(addr, expand(d)...)
+               trie = trie.Insert(addr, cidr, peer)
+       }
+
+       assertEQ := func(peer *Peer, a, b, c, d uint32) {
+               var addr []byte
+               addr = append(addr, expand(a)...)
+               addr = append(addr, expand(b)...)
+               addr = append(addr, expand(c)...)
+               addr = append(addr, expand(d)...)
+               p := trie.Lookup(addr)
+               if p != peer {
+                       t.Error("Assert EQ failed")
+               }
+       }
+
+       /*
+               assertNEQ := func(peer *Peer, a, b, c, d uint32) {
+                       var addr []byte
+                       addr = append(addr, expand(a)...)
+                       addr = append(addr, expand(b)...)
+                       addr = append(addr, expand(c)...)
+                       addr = append(addr, expand(d)...)
+                       p := trie.Lookup(addr)
+                       if p == peer {
+                               t.Error("Assert NEQ failed")
+                       }
+               }
+       */
+
+       insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
+       insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
+       insert(e, 0, 0, 0, 0, 0)
+       insert(f, 0, 0, 0, 0, 0)
+       insert(g, 0x24046800, 0, 0, 0, 32)
+       insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64)
+       insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128)
+       insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+       insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+
+       assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543)
+       assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee)
+       assertEQ(f, 0x26075300, 0x60006b01, 0, 0)
+       assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006)
+       assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678)
+       assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678)
+       assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678)
+       assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678)
+       assertEQ(h, 0x24046800, 0x40040800, 0, 0)
+       assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
+       assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
 }