]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: expand IPCError
authorJosh Bleecher Snyder <josh@tailscale.com>
Fri, 15 Jan 2021 21:24:38 +0000 (13:24 -0800)
committerJosh Bleecher Snyder <josh@tailscale.com>
Mon, 25 Jan 2021 16:47:48 +0000 (08:47 -0800)
Expand IPCError to contain a wrapped error,
and add a helper to make constructing such errors easier.

Add a defer-based "log on returned error" to IpcSetOperation.
This lets us simplify all of the error return paths.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
device/uapi.go

index 4436e72effc6c50ce8bde9d7d655b729e9b98cca..7f508692969d484d79929750d7d17eeb171c9ed7 100644 (file)
@@ -21,15 +21,24 @@ import (
 )
 
 type IPCError struct {
-       int64
+       code int64 // error code
+       err  error // underlying/wrapped error
 }
 
 func (s IPCError) Error() string {
-       return fmt.Sprintf("IPC error: %d", s.int64)
+       return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
+}
+
+func (s IPCError) Unwrap() error {
+       return s.err
 }
 
 func (s IPCError) ErrorCode() int64 {
-       return s.int64
+       return s.code
+}
+
+func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError {
+       return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
 }
 
 func (device *Device) IpcGetOperation(w io.Writer) error {
@@ -100,24 +109,28 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
        for _, line := range lines {
                _, err := io.WriteString(w, line+"\n")
                if err != nil {
-                       return &IPCError{ipc.IpcErrorIO}
+                       return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
                }
        }
 
        return nil
 }
 
-func (device *Device) IpcSetOperation(r io.Reader) error {
-       scanner := bufio.NewScanner(r)
-       logError := device.log.Error
+func (device *Device) IpcSetOperation(r io.Reader) (err error) {
+       defer func() {
+               if err != nil {
+                       device.log.Error.Println(err)
+               }
+       }()
+
        logDebug := device.log.Debug
 
        var peer *Peer
-
        dummy := false
        createdNewPeer := false
        deviceConfig := true
 
+       scanner := bufio.NewScanner(r)
        for scanner.Scan() {
 
                // parse line
@@ -128,7 +141,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                }
                parts := strings.Split(line, "=")
                if len(parts) != 2 {
-                       return &IPCError{ipc.IpcErrorProtocol}
+                       return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts))
                }
                key := parts[0]
                value := parts[1]
@@ -142,8 +155,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                                var sk NoisePrivateKey
                                err := sk.FromMaybeZeroHex(value)
                                if err != nil {
-                                       logError.Println("Failed to set private_key:", err)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
                                }
                                logDebug.Println("UAPI: Updating private key")
                                device.SetPrivateKey(sk)
@@ -154,8 +166,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
 
                                port, err := strconv.ParseUint(value, 10, 16)
                                if err != nil {
-                                       logError.Println("Failed to parse listen_port:", err)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
                                }
 
                                // update port and rebind
@@ -167,8 +178,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                                device.net.Unlock()
 
                                if err := device.BindUpdate(); err != nil {
-                                       logError.Println("Failed to set listen_port:", err)
-                                       return &IPCError{ipc.IpcErrorPortInUse}
+                                       return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
                                }
 
                        case "fwmark":
@@ -184,15 +194,13 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                                }()
 
                                if err != nil {
-                                       logError.Println("Invalid fwmark", err)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
                                }
 
                                logDebug.Println("UAPI: Updating fwmark")
 
                                if err := device.BindSetMark(uint32(fwmark)); err != nil {
-                                       logError.Println("Failed to update fwmark:", err)
-                                       return &IPCError{ipc.IpcErrorPortInUse}
+                                       return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
                                }
 
                        case "public_key":
@@ -202,15 +210,13 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
 
                        case "replace_peers":
                                if value != "true" {
-                                       logError.Println("Failed to set replace_peers, invalid value:", value)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
                                }
                                logDebug.Println("UAPI: Removing all peers")
                                device.RemoveAllPeers()
 
                        default:
-                               logError.Println("Invalid UAPI device key:", key)
-                               return &IPCError{ipc.IpcErrorInvalid}
+                               return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
                        }
                }
 
@@ -224,8 +230,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                                var publicKey NoisePublicKey
                                err := publicKey.FromHex(value)
                                if err != nil {
-                                       logError.Println("Failed to get peer by public key:", err)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
                                }
 
                                // ignore peer with public key of device
@@ -244,8 +249,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                                if createdNewPeer {
                                        peer, err = device.NewPeer(publicKey)
                                        if err != nil {
-                                               logError.Println("Failed to create new peer:", err)
-                                               return &IPCError{ipc.IpcErrorInvalid}
+                                               return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
                                        }
                                        logDebug.Println(peer, "- UAPI: Created")
                                }
@@ -255,8 +259,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                                // allow disabling of creation
 
                                if value != "true" {
-                                       logError.Println("Failed to set update only, invalid value:", value)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
                                }
                                if createdNewPeer && !dummy {
                                        device.RemovePeer(peer.handshake.remoteStatic)
@@ -269,8 +272,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                                // remove currently selected peer from device
 
                                if value != "true" {
-                                       logError.Println("Failed to set remove, invalid value:", value)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
                                }
                                if !dummy {
                                        logDebug.Println(peer, "- UAPI: Removing")
@@ -290,8 +292,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                                peer.handshake.mutex.Unlock()
 
                                if err != nil {
-                                       logError.Println("Failed to set preshared key:", err)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
                                }
 
                        case "endpoint":
@@ -312,8 +313,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                                }()
 
                                if err != nil {
-                                       logError.Println("Failed to set endpoint:", err, ":", value)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
                                }
 
                        case "persistent_keepalive_interval":
@@ -324,8 +324,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
 
                                secs, err := strconv.ParseUint(value, 10, 16)
                                if err != nil {
-                                       logError.Println("Failed to set persistent keepalive interval:", err)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
                                }
 
                                old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
@@ -334,8 +333,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
 
                                if old == 0 && secs != 0 {
                                        if err != nil {
-                                               logError.Println("Failed to get tun device status:", err)
-                                               return &IPCError{ipc.IpcErrorIO}
+                                               return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
                                        }
                                        if device.isUp.Get() && !dummy {
                                                peer.SendKeepalive()
@@ -347,8 +345,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                                logDebug.Println(peer, "- UAPI: Removing all allowedips")
 
                                if value != "true" {
-                                       logError.Println("Failed to replace allowedips, invalid value:", value)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
                                }
 
                                if dummy {
@@ -363,8 +360,7 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
 
                                _, network, err := net.ParseCIDR(value)
                                if err != nil {
-                                       logError.Println("Failed to set allowed ip:", err)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
                                }
 
                                if dummy {
@@ -377,13 +373,11 @@ func (device *Device) IpcSetOperation(r io.Reader) error {
                        case "protocol_version":
 
                                if value != "1" {
-                                       logError.Println("Invalid protocol version:", value)
-                                       return &IPCError{ipc.IpcErrorInvalid}
+                                       return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
                                }
 
                        default:
-                               logError.Println("Invalid UAPI peer key:", key)
-                               return &IPCError{ipc.IpcErrorInvalid}
+                               return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
                        }
                }
        }
@@ -431,16 +425,14 @@ func (device *Device) IpcHandle(socket net.Conn) {
                err = device.IpcSetOperation(buffered.Reader)
                if err != nil && !errors.As(err, &status) {
                        // should never happen
-                       device.log.Error.Println("Invalid UAPI error:", err)
-                       status = &IPCError{1}
+                       status = ipcErrorf(1, "invalid UAPI error: %w", err)
                }
 
        case "get=1\n":
                err = device.IpcGetOperation(buffered.Writer)
                if err != nil && !errors.As(err, &status) {
                        // should never happen
-                       device.log.Error.Println("Invalid UAPI error:", err)
-                       status = &IPCError{1}
+                       status = ipcErrorf(1, "invalid UAPI error: %w", err)
                }
 
        default: