]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: test up/down using virtual conn
authorJason A. Donenfeld <Jason@zx2c4.com>
Mon, 22 Feb 2021 03:30:31 +0000 (04:30 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 23 Feb 2021 19:00:57 +0000 (20:00 +0100)
This prevents port clashing bugs.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
conn/bindtest/bindtest.go [new file with mode: 0644]
device/device_test.go
tun/tuntest/tuntest.go

diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
new file mode 100644 (file)
index 0000000..ad8fa05
--- /dev/null
@@ -0,0 +1,136 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package bindtest
+
+import (
+       "fmt"
+       "math/rand"
+       "net"
+       "os"
+       "strconv"
+
+       "golang.zx2c4.com/wireguard/conn"
+)
+
+type ChannelBind struct {
+       rx4, tx4         *chan []byte
+       rx6, tx6         *chan []byte
+       closeSignal      chan bool
+       source4, source6 ChannelEndpoint
+       target4, target6 ChannelEndpoint
+}
+
+type ChannelEndpoint uint16
+
+var _ conn.Bind = (*ChannelBind)(nil)
+var _ conn.Endpoint = (*ChannelEndpoint)(nil)
+
+func NewChannelBinds() [2]conn.Bind {
+       arx4 := make(chan []byte, 8192)
+       brx4 := make(chan []byte, 8192)
+       arx6 := make(chan []byte, 8192)
+       brx6 := make(chan []byte, 8192)
+       var binds [2]ChannelBind
+       binds[0].rx4 = &arx4
+       binds[0].tx4 = &brx4
+       binds[1].rx4 = &brx4
+       binds[1].tx4 = &arx4
+       binds[0].rx6 = &arx6
+       binds[0].tx6 = &brx6
+       binds[1].rx6 = &brx6
+       binds[1].tx6 = &arx6
+       binds[0].target4 = ChannelEndpoint(1)
+       binds[1].target4 = ChannelEndpoint(2)
+       binds[0].target6 = ChannelEndpoint(3)
+       binds[1].target6 = ChannelEndpoint(4)
+       binds[0].source4 = binds[1].target4
+       binds[0].source6 = binds[1].target6
+       binds[1].source4 = binds[0].target4
+       binds[1].source6 = binds[0].target6
+       return [2]conn.Bind{&binds[0], &binds[1]}
+}
+
+func (c ChannelEndpoint) ClearSrc() {}
+
+func (c ChannelEndpoint) SrcToString() string { return "" }
+
+func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
+
+func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
+
+func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
+
+func (c ChannelEndpoint) SrcIP() net.IP { return nil }
+
+func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) {
+       c.closeSignal = make(chan bool)
+       if rand.Uint32()&1 == 0 {
+               return uint16(c.source4), nil
+       } else {
+               return uint16(c.source6), nil
+       }
+}
+
+func (c *ChannelBind) Close() error {
+       if c.closeSignal != nil {
+               select {
+               case <-c.closeSignal:
+               default:
+                       close(c.closeSignal)
+               }
+       }
+       return nil
+}
+
+func (c *ChannelBind) SetMark(mark uint32) error { return nil }
+
+func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) {
+       select {
+       case <-c.closeSignal:
+               return 0, nil, net.ErrClosed
+       case rx := <-*c.rx6:
+               return copy(b, rx), c.target6, nil
+       }
+}
+
+func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
+       select {
+       case <-c.closeSignal:
+               return 0, nil, net.ErrClosed
+       case rx := <-*c.rx4:
+               return copy(b, rx), c.target4, nil
+       }
+}
+
+func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
+       select {
+       case <-c.closeSignal:
+               return net.ErrClosed
+       default:
+               bc := make([]byte, len(b))
+               copy(bc, b)
+               if ep.(ChannelEndpoint) == c.target4 {
+                       *c.tx4 <- bc
+               } else if ep.(ChannelEndpoint) == c.target6 {
+                       *c.tx6 <- bc
+               } else {
+                       return os.ErrInvalid
+               }
+       }
+       return nil
+}
+
+func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
+       _, port, err := net.SplitHostPort(s)
+       if err != nil {
+               return nil, err
+       }
+       i, err := strconv.ParseUint(port, 10, 16)
+       if err != nil {
+               return nil, err
+       }
+       return ChannelEndpoint(i), nil
+}
index 1716f9276f398caf6e509d9880e8d5270e320f04..29daeb9c399e089e18ca804bbb97c22721a5ff19 100644 (file)
@@ -8,7 +8,6 @@ package device
 import (
        "bytes"
        "encoding/hex"
-       "errors"
        "fmt"
        "io"
        "math/rand"
@@ -17,11 +16,11 @@ import (
        "runtime/pprof"
        "sync"
        "sync/atomic"
-       "syscall"
        "testing"
        "time"
 
        "golang.zx2c4.com/wireguard/conn"
+       "golang.zx2c4.com/wireguard/conn/bindtest"
        "golang.zx2c4.com/wireguard/tun/tuntest"
 )
 
@@ -148,8 +147,14 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}
 }
 
 // genTestPair creates a testPair.
-func genTestPair(tb testing.TB) (pair testPair) {
+func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
        cfg, endpointCfg := genConfigs(tb)
+       var binds [2]conn.Bind
+       if realSocket {
+               binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
+       } else {
+               binds = bindtest.NewChannelBinds()
+       }
        // Bring up a ChannelTun for each config.
        for i := range pair {
                p := &pair[i]
@@ -159,7 +164,7 @@ func genTestPair(tb testing.TB) (pair testPair) {
                if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
                        level = LogLevelError
                }
-               p.dev = NewDevice(p.tun.TUN(), conn.NewDefaultBind(), NewLogger(level, fmt.Sprintf("dev%d: ", i)))
+               p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
                if err := p.dev.IpcSet(cfg[i]); err != nil {
                        tb.Errorf("failed to configure device %d: %v", i, err)
                        p.dev.Close()
@@ -187,7 +192,7 @@ func genTestPair(tb testing.TB) (pair testPair) {
 
 func TestTwoDevicePing(t *testing.T) {
        goroutineLeakCheck(t)
-       pair := genTestPair(t)
+       pair := genTestPair(t, true)
        t.Run("ping 1.0.0.1", func(t *testing.T) {
                pair.Send(t, Ping, nil)
        })
@@ -198,11 +203,11 @@ func TestTwoDevicePing(t *testing.T) {
 
 func TestUpDown(t *testing.T) {
        goroutineLeakCheck(t)
-       const itrials = 20
-       const otrials = 1
+       const itrials = 50
+       const otrials = 10
 
        for n := 0; n < otrials; n++ {
-               pair := genTestPair(t)
+               pair := genTestPair(t, false)
                for i := range pair {
                        for k := range pair[i].dev.peers.keyMap {
                                pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
@@ -214,17 +219,8 @@ func TestUpDown(t *testing.T) {
                        go func(d *Device) {
                                defer wg.Done()
                                for i := 0; i < itrials; i++ {
-                                       start := time.Now()
-                                       for {
-                                               if err := d.Up(); err != nil {
-                                                       if errors.Is(err, syscall.EADDRINUSE) && time.Now().Sub(start) < time.Second*4 {
-                                                               // Some other test process is racing with us, so try again.
-                                                               time.Sleep(time.Millisecond * 10)
-                                                               continue
-                                                       }
-                                                       t.Errorf("failed up bring up device: %v", err)
-                                               }
-                                               break
+                                       if err := d.Up(); err != nil {
+                                               t.Errorf("failed up bring up device: %v", err)
                                        }
                                        time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
                                        if err := d.Down(); err != nil {
@@ -245,7 +241,7 @@ func TestUpDown(t *testing.T) {
 // TestConcurrencySafety does other things concurrently with tunnel use.
 // It is intended to be used with the race detector to catch data races.
 func TestConcurrencySafety(t *testing.T) {
-       pair := genTestPair(t)
+       pair := genTestPair(t, true)
        done := make(chan struct{})
 
        const warmupIters = 10
@@ -315,7 +311,7 @@ func TestConcurrencySafety(t *testing.T) {
 }
 
 func BenchmarkLatency(b *testing.B) {
-       pair := genTestPair(b)
+       pair := genTestPair(b, true)
 
        // Establish a connection.
        pair.Send(b, Ping, nil)
@@ -329,7 +325,7 @@ func BenchmarkLatency(b *testing.B) {
 }
 
 func BenchmarkThroughput(b *testing.B) {
-       pair := genTestPair(b)
+       pair := genTestPair(b, true)
 
        // Establish a connection.
        pair.Send(b, Ping, nil)
@@ -373,7 +369,7 @@ func BenchmarkThroughput(b *testing.B) {
 }
 
 func BenchmarkUAPIGet(b *testing.B) {
-       pair := genTestPair(b)
+       pair := genTestPair(b, true)
        pair.Send(b, Ping, nil)
        pair.Send(b, Pong, nil)
        b.ReportAllocs()
index 80ccdf9eb75aa68ae7bd0c0564bd3ef327e76d2d..92aa9d82b959edc8fa28b657997c0515feb30c22 100644 (file)
@@ -79,7 +79,6 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
        return pkt
 }
 
-// TODO(crawshaw): find a reusable home for this. package devicetest?
 type ChannelTUN struct {
        Inbound  chan []byte // incoming packets, closed on TUN close
        Outbound chan []byte // outbound packets, blocks forever on TUN close