]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: remove listen port race in tests
authorJason A. Donenfeld <Jason@zx2c4.com>
Mon, 8 Feb 2021 23:59:39 +0000 (00:59 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 9 Feb 2021 14:37:04 +0000 (15:37 +0100)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/device_test.go

index 634c999c4faefdaf6d5be0fd84e00fff101f3ff4..8f16207c2341eff3c8baf589e9ad312d403d64da 100644 (file)
@@ -22,15 +22,6 @@ import (
        "golang.zx2c4.com/wireguard/tun/tuntest"
 )
 
-func getFreePort(tb testing.TB) string {
-       l, err := net.ListenPacket("udp", "localhost:0")
-       if err != nil {
-               tb.Fatal(err)
-       }
-       defer l.Close()
-       return fmt.Sprintf("%d", l.LocalAddr().(*net.UDPAddr).Port)
-}
-
 // uapiCfg returns a string that contains cfg formatted use with IpcSet.
 // cfg is a series of alternating key/value strings.
 // uapiCfg exists because editors and humans like to insert
@@ -55,12 +46,7 @@ func uapiCfg(cfg ...string) string {
 
 // genConfigs generates a pair of configs that connect to each other.
 // The configs use distinct, probably-usable ports.
-func genConfigs(tb testing.TB) (cfgs [2]string) {
-       var port1, port2 string
-       for port1 == port2 {
-               port1 = getFreePort(tb)
-               port2 = getFreePort(tb)
-       }
+func genConfigs(tb testing.TB) (cfgs [2]string, endpointCfgs [2]string) {
        var key1, key2 NoisePrivateKey
        _, err := rand.Read(key1[:])
        if err != nil {
@@ -74,23 +60,29 @@ func genConfigs(tb testing.TB) (cfgs [2]string) {
 
        cfgs[0] = uapiCfg(
                "private_key", hex.EncodeToString(key1[:]),
-               "listen_port", port1,
+               "listen_port", "0",
                "replace_peers", "true",
                "public_key", hex.EncodeToString(pub2[:]),
                "protocol_version", "1",
                "replace_allowed_ips", "true",
                "allowed_ip", "1.0.0.2/32",
-               "endpoint", "127.0.0.1:"+port2,
+       )
+       endpointCfgs[0] = uapiCfg(
+               "public_key", hex.EncodeToString(pub2[:]),
+               "endpoint", "127.0.0.1:%d",
        )
        cfgs[1] = uapiCfg(
                "private_key", hex.EncodeToString(key2[:]),
-               "listen_port", port2,
+               "listen_port", "0",
                "replace_peers", "true",
                "public_key", hex.EncodeToString(pub1[:]),
                "protocol_version", "1",
                "replace_allowed_ips", "true",
                "allowed_ip", "1.0.0.1/32",
-               "endpoint", "127.0.0.1:"+port1,
+       )
+       endpointCfgs[1] = uapiCfg(
+               "public_key", hex.EncodeToString(pub1[:]),
+               "endpoint", "127.0.0.1:%d",
        )
        return
 }
@@ -154,52 +146,40 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}
 
 // genTestPair creates a testPair.
 func genTestPair(tb testing.TB) (pair testPair) {
-       const maxAttempts = 10
-NextAttempt:
-       for i := 0; i < maxAttempts; i++ {
-               cfg := genConfigs(tb)
-               // Bring up a ChannelTun for each config.
-               for i := range pair {
-                       p := &pair[i]
-                       p.tun = tuntest.NewChannelTUN()
-                       if i == 0 {
-                               p.ip = net.ParseIP("1.0.0.1")
-                       } else {
-                               p.ip = net.ParseIP("1.0.0.2")
-                       }
-                       level := LogLevelVerbose
-                       if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
-                               level = LogLevelError
-                       }
-                       p.dev = NewDevice(p.tun.TUN(), NewLogger(level, fmt.Sprintf("dev%d: ", i)))
-                       p.dev.Up()
-                       if err := p.dev.IpcSet(cfg[i]); err != nil {
-                               // genConfigs attempted to pick ports that were free.
-                               // There's a tiny window between genConfigs closing the port
-                               // and us opening it, during which another process could
-                               // start using it. We probably just lost that race.
-                               // Try again from the beginning.
-                               // If there's something permanent wrong,
-                               // we'll see that when we run out of attempts.
-                               tb.Logf("failed to configure device %d: %v", i, err)
-                               p.dev.Close()
-                               continue NextAttempt
-                       }
-                       // The device might still not be up, e.g. due to an error
-                       // in RoutineTUNEventReader's call to dev.Up that got swallowed.
-                       // Assume it's due to a transient error (port in use), and retry.
-                       if !p.dev.isUp() {
-                               tb.Logf("device %d did not come up, trying again", i)
-                               p.dev.Close()
-                               continue NextAttempt
-                       }
-                       // The device is up. Close it when the test completes.
-                       tb.Cleanup(p.dev.Close)
+       cfg, endpointCfg := genConfigs(tb)
+       // Bring up a ChannelTun for each config.
+       for i := range pair {
+               p := &pair[i]
+               p.tun = tuntest.NewChannelTUN()
+               p.ip = net.IPv4(1, 0, 0, byte(i+1))
+               level := LogLevelVerbose
+               if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
+                       level = LogLevelError
+               }
+               p.dev = NewDevice(p.tun.TUN(), NewLogger(level, fmt.Sprintf("dev%d: ", i)))
+               p.dev.Up()
+               if err := p.dev.IpcSet(cfg[i]); err != nil {
+                       tb.Errorf("failed to configure device %d: %v", i, err)
+                       p.dev.Close()
+                       continue
                }
-               return // success
+               if !p.dev.isUp() {
+                       tb.Errorf("device %d did not come up", i)
+                       p.dev.Close()
+                       continue
+               }
+               endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port)
+       }
+       for i := range pair {
+               p := &pair[i]
+               if err := p.dev.IpcSet(endpointCfg[i]); err != nil {
+                       tb.Errorf("failed to configure device endpoint %d: %v", i, err)
+                       p.dev.Close()
+                       continue
+               }
+               // The device is ready. Close it when the test completes.
+               tb.Cleanup(p.dev.Close)
        }
-
-       tb.Fatalf("genChannelTUNs: failed %d times", maxAttempts)
        return
 }