]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: return error from Up() and Down()
authorJason A. Donenfeld <Jason@zx2c4.com>
Tue, 9 Feb 2021 23:12:23 +0000 (00:12 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 9 Feb 2021 23:12:23 +0000 (00:12 +0100)
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/device.go
device/device_test.go
main_windows.go
tun/netstack/examples/http_client.go

index 7f96a1e9edd9279b87aab4ec4a100bcaba56ff97..432549d6715852f61426b24cd799149a0a7212de 100644 (file)
@@ -139,37 +139,42 @@ func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
 }
 
 // changeState attempts to change the device state to match want.
-func (device *Device) changeState(want deviceState) {
+func (device *Device) changeState(want deviceState) (err error) {
        device.state.Lock()
        defer device.state.Unlock()
        old := device.deviceState()
        if old == deviceStateClosed {
                // once closed, always closed
                device.log.Verbosef("Interface closed, ignored requested state %s", want)
-               return
+               return nil
        }
        switch want {
        case old:
-               return
+               return nil
        case deviceStateUp:
                atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
-               if ok := device.upLocked(); ok {
+               err = device.upLocked()
+               if err == nil {
                        break
                }
                fallthrough // up failed; bring the device all the way back down
        case deviceStateDown:
                atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
-               device.downLocked()
+               errDown := device.downLocked()
+               if err == nil {
+                       err = errDown
+               }
        }
        device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
+       return
 }
 
 // upLocked attempts to bring the device up and reports whether it succeeded.
 // The caller must hold device.state.mu and is responsible for updating device.state.state.
-func (device *Device) upLocked() bool {
+func (device *Device) upLocked() error {
        if err := device.BindUpdate(); err != nil {
                device.log.Errorf("Unable to update bind: %v", err)
-               return false
+               return err
        }
 
        device.peers.RLock()
@@ -180,12 +185,12 @@ func (device *Device) upLocked() bool {
                }
        }
        device.peers.RUnlock()
-       return true
+       return nil
 }
 
 // downLocked attempts to bring the device down.
 // The caller must hold device.state.mu and is responsible for updating device.state.state.
-func (device *Device) downLocked() {
+func (device *Device) downLocked() error {
        err := device.BindClose()
        if err != nil {
                device.log.Errorf("Bind close failed: %v", err)
@@ -196,14 +201,15 @@ func (device *Device) downLocked() {
                peer.Stop()
        }
        device.peers.RUnlock()
+       return err
 }
 
-func (device *Device) Up() {
-       device.changeState(deviceStateUp)
+func (device *Device) Up() error {
+       return device.changeState(deviceStateUp)
 }
 
-func (device *Device) Down() {
-       device.changeState(deviceStateDown)
+func (device *Device) Down() error {
+       return device.changeState(deviceStateDown)
 }
 
 func (device *Device) IsUnderLoad() bool {
index ce1ba9ba35831248551ee440eb2e7e8eaa085707..c17b35093e63b8163e886a2c8a034b86cc9807f3 100644 (file)
@@ -157,14 +157,13 @@ func genTestPair(tb testing.TB) (pair testPair) {
                        level = LogLevelError
                }
                p.dev = NewDevice(p.tun.TUN(), NewLogger(level, fmt.Sprintf("dev%d: ", i)))
-               p.dev.Up()
                if err := p.dev.IpcSet(cfg[i]); err != nil {
                        tb.Errorf("failed to configure device %d: %v", i, err)
                        p.dev.Close()
                        continue
                }
-               if !p.dev.isUp() {
-                       tb.Errorf("device %d did not come up", i)
+               if err := p.dev.Up(); err != nil {
+                       tb.Errorf("failed to bring up device %d: %v", i, err)
                        p.dev.Close()
                        continue
                }
@@ -212,9 +211,13 @@ func TestUpDown(t *testing.T) {
                        go func(d *Device) {
                                defer wg.Done()
                                for i := 0; i < itrials; i++ {
-                                       d.Up()
+                                       if err := d.Up(); err != nil {
+                                               t.Errorf("failed up bring up device: %v", err)
+                                       }
                                        time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
-                                       d.Down()
+                                       if err := d.Down(); err != nil {
+                                               t.Errorf("failed to bring down device: %v", err)
+                                       }
                                        time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
                                }
                        }(pair[i].dev)
index 10b0c7d151118298f0a7822aa1fde1fba2eeb502..128a0cd8efe8c0f9cea496d838841060a0234e4b 100644 (file)
@@ -48,7 +48,11 @@ func main() {
        }
 
        device := device.NewDevice(tun, logger)
-       device.Up()
+       err = device.Up()
+       if err != nil {
+               logger.Errorf("Failed to bring up device: %v", err)
+               os.Exit(ExitSetupFailed)
+       }
        logger.Verbosef("Device started")
 
        uapi, err := ipc.UAPIListen(interfaceName)
index 2c1f8f4a59ddae795a2951b4ebb40f6327d33f29..25a8d12759f5a1605c628fb66e9ca096a6ccb8d6 100644 (file)
@@ -31,7 +31,10 @@ public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
 endpoint=163.172.161.0:12912
 allowed_ip=0.0.0.0/0
 `)
-       dev.Up()
+       err = dev.Up()
+       if err != nil {
+               log.Panic(err)
+       }
 
        client := http.Client{
                Transport: &http.Transport{