]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Support IPv6-less kernels
authorJason A. Donenfeld <Jason@zx2c4.com>
Mon, 11 Jun 2018 17:04:38 +0000 (19:04 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Mon, 11 Jun 2018 23:32:46 +0000 (01:32 +0200)
conn_default.go
conn_linux.go

index 14ed56cbda28e6b968714131a47964d195a5374b..92135cb1486eb17938bf129173381f5f443046f6 100644 (file)
@@ -11,7 +11,9 @@ package main
 import (
        "golang.org/x/sys/unix"
        "net"
+       "os"
        "runtime"
+       "syscall"
 )
 
 /* This code is meant to be a temporary solution
@@ -87,6 +89,18 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
        return conn, uaddr.Port, nil
 }
 
+func extractErrno(err error) error {
+       opErr, ok := err.(*net.OpError)
+       if !ok {
+               return nil
+       }
+       syscallErr, ok := opErr.Err.(*os.SyscallError)
+       if !ok {
+               return nil
+       }
+       return syscallErr.Err
+}
+
 func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
        var err error
        var bind NativeBind
@@ -94,13 +108,15 @@ func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
        port := int(uport)
 
        bind.ipv4, port, err = listenNet("udp4", port)
-       if err != nil {
+       if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
                return nil, 0, err
        }
 
        bind.ipv6, port, err = listenNet("udp6", port)
-       if err != nil {
+       if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
+               return nil, 0, err
                bind.ipv4.Close()
+               bind.ipv4 = nil
                return nil, 0, err
        }
 
@@ -108,8 +124,13 @@ func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
 }
 
 func (bind *NativeBind) Close() error {
-       err1 := bind.ipv4.Close()
-       err2 := bind.ipv6.Close()
+       var err1, err2 error
+       if bind.ipv4 != nil {
+               err1 = bind.ipv4.Close()
+       }
+       if bind.ipv6 != nil {
+               err2 = bind.ipv6.Close()
+       }
        if err1 != nil {
                return err1
        }
@@ -117,6 +138,9 @@ func (bind *NativeBind) Close() error {
 }
 
 func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+       if bind.ipv4 == nil {
+               return 0, nil, syscall.EAFNOSUPPORT
+       }
        n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
        if endpoint != nil {
                endpoint.IP = endpoint.IP.To4()
@@ -125,6 +149,9 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
 }
 
 func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+       if bind.ipv6 == nil {
+               return 0, nil, syscall.EAFNOSUPPORT
+       }
        n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
        return n, (*NativeEndpoint)(endpoint), err
 }
@@ -133,8 +160,14 @@ func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
        var err error
        nend := endpoint.(*NativeEndpoint)
        if nend.IP.To4() != nil {
+               if bind.ipv4 == nil {
+                       return syscall.EAFNOSUPPORT
+               }
                _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
        } else {
+               if bind.ipv6 == nil {
+                       return syscall.EAFNOSUPPORT
+               }
                _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
        }
        return err
@@ -157,31 +190,29 @@ func (bind *NativeBind) SetMark(mark uint32) error {
        if fwmarkIoctl == 0 {
                return nil
        }
-       fd4, err1 := bind.ipv4.SyscallConn()
-       fd6, err2 := bind.ipv6.SyscallConn()
-       if err1 != nil {
-               return err1
-       }
-       if err2 != nil {
-               return err2
-       }
-       err3 := fd4.Control(func(fd uintptr) {
-               err1 = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
-       })
-       err4 := fd6.Control(func(fd uintptr) {
-               err2 = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
-       })
-       if err1 != nil {
-               return err1
-       }
-       if err2 != nil {
-               return err2
-       }
-       if err3 != nil {
-               return err3
-       }
-       if err4 != nil {
-               return err4
+       if bind.ipv4 != nil {
+               fd, err := bind.ipv4.SyscallConn()
+               if err != nil {
+                       return err
+               }
+               err = fd.Control(func(fd uintptr) {
+                       err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
+               })
+               if err != nil {
+                       return err
+               }
+       }
+       if bind.ipv6 != nil {
+               fd, err := bind.ipv6.SyscallConn()
+               if err != nil {
+                       return err
+               }
+               err = fd.Control(func(fd uintptr) {
+                       err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
+               })
+               if err != nil {
+                       return err
+               }
        }
        return nil
 }
index 0227f044b3395a50ab77ee4a7a82029a9499c410..2b15d05534e50662e6ae22a9b5fb951a158e5742 100644 (file)
@@ -24,6 +24,7 @@ import (
        "net"
        "strconv"
        "sync"
+       "syscall"
        "unsafe"
 )
 
@@ -140,40 +141,45 @@ func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
        go bind.routineRouteListener(device)
 
        bind.sock6, port, err = create6(port)
-       if err != nil {
+       if err != nil && err != syscall.EAFNOSUPPORT {
                bind.netlinkCancel.Cancel()
-               return nil, port, err
+               return nil, 0, err
        }
 
        bind.sock4, port, err = create4(port)
-       if err != nil {
+       if err != nil && err != syscall.EAFNOSUPPORT {
                bind.netlinkCancel.Cancel()
                unix.Close(bind.sock6)
+               return nil, 0, err
        }
-       return &bind, port, err
+       return &bind, port, nil
 }
 
 func (bind *NativeBind) SetMark(value uint32) error {
-       err := unix.SetsockoptInt(
-               bind.sock6,
-               unix.SOL_SOCKET,
-               unix.SO_MARK,
-               int(value),
-       )
+       if bind.sock6 != -1 {
+               err := unix.SetsockoptInt(
+                       bind.sock6,
+                       unix.SOL_SOCKET,
+                       unix.SO_MARK,
+                       int(value),
+               )
 
-       if err != nil {
-               return err
+               if err != nil {
+                       return err
+               }
        }
 
-       err = unix.SetsockoptInt(
-               bind.sock4,
-               unix.SOL_SOCKET,
-               unix.SO_MARK,
-               int(value),
-       )
+       if bind.sock4 != -1 {
+               err := unix.SetsockoptInt(
+                       bind.sock4,
+                       unix.SOL_SOCKET,
+                       unix.SO_MARK,
+                       int(value),
+               )
 
-       if err != nil {
-               return err
+               if err != nil {
+                       return err
+               }
        }
 
        bind.lastMark = value
@@ -187,9 +193,14 @@ func closeUnblock(fd int) error {
 }
 
 func (bind *NativeBind) Close() error {
-       err1 := closeUnblock(bind.sock6)
-       err2 := closeUnblock(bind.sock4)
-       err3 := bind.netlinkCancel.Cancel()
+       var err1, err2, err3 error
+       if bind.sock6 != -1 {
+               err1 = closeUnblock(bind.sock6)
+       }
+       if bind.sock4 != -1 {
+               err2 = closeUnblock(bind.sock4)
+       }
+       err3 = bind.netlinkCancel.Cancel()
 
        if err1 != nil {
                return err1
@@ -202,6 +213,9 @@ func (bind *NativeBind) Close() error {
 
 func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
        var end NativeEndpoint
+       if bind.sock6 == -1 {
+               return 0, nil, syscall.EAFNOSUPPORT
+       }
        n, err := receive6(
                bind.sock6,
                buff,
@@ -212,6 +226,9 @@ func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
 
 func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
        var end NativeEndpoint
+       if bind.sock4 == -1 {
+               return 0, nil, syscall.EAFNOSUPPORT
+       }
        n, err := receive4(
                bind.sock4,
                buff,
@@ -223,8 +240,14 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
 func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
        nend := end.(*NativeEndpoint)
        if !nend.isV6 {
+               if bind.sock4 == -1 {
+                       return syscall.EAFNOSUPPORT
+               }
                return send4(bind.sock4, nend, buff)
        } else {
+               if bind.sock6 == -1 {
+                       return syscall.EAFNOSUPPORT
+               }
                return send6(bind.sock6, nend, buff)
        }
 }