]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
boundif: introduce API for socket binding
authorJason A. Donenfeld <Jason@zx2c4.com>
Sun, 3 Mar 2019 04:01:06 +0000 (05:01 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Mon, 4 Mar 2019 15:37:11 +0000 (16:37 +0100)
device/boundif_android.go [new file with mode: 0644]
device/boundif_darwin.go [new file with mode: 0644]
device/boundif_windows.go [new file with mode: 0644]
device/conn_default.go
device/conn_linux.go
device/mark_default.go
device/mark_unix.go
device/peer.go

diff --git a/device/boundif_android.go b/device/boundif_android.go
new file mode 100644 (file)
index 0000000..ecc9331
--- /dev/null
@@ -0,0 +1,34 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
+       sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
+       if err != nil {
+               return
+       }
+       err = sysconn.Control(func(f uintptr) {
+               fd = int(f)
+       })
+       if err != nil {
+               return
+       }
+       return
+}
+
+func (device *Device) PeekLookAtSocketFd6() (fd int, err error) {
+       sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
+       if err != nil {
+               return
+       }
+       err = sysconn.Control(func(f uintptr) {
+               fd = int(f)
+       })
+       if err != nil {
+               return
+       }
+       return
+}
diff --git a/device/boundif_darwin.go b/device/boundif_darwin.go
new file mode 100644 (file)
index 0000000..b3d10ba
--- /dev/null
@@ -0,0 +1,44 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+       "golang.org/x/sys/unix"
+)
+
+func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
+       sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
+       if err != nil {
+               return nil
+       }
+       err2 := sysconn.Control(func(fd uintptr) {
+               err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, int(interfaceIndex))
+       })
+       if err2 != nil {
+               return err2
+       }
+       if err != nil {
+               return err
+       }
+       return nil
+}
+
+func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
+       sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
+       if err != nil {
+               return nil
+       }
+       err2 := sysconn.Control(func(fd uintptr) {
+               err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, int(interfaceIndex))
+       })
+       if err2 != nil {
+               return err2
+       }
+       if err != nil {
+               return err
+       }
+       return nil
+}
\ No newline at end of file
diff --git a/device/boundif_windows.go b/device/boundif_windows.go
new file mode 100644 (file)
index 0000000..00631cb
--- /dev/null
@@ -0,0 +1,56 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+       "encoding/binary"
+       "golang.org/x/sys/windows"
+       "unsafe"
+)
+
+const (
+       sockoptIP_UNICAST_IF   = 31
+       sockoptIPV6_UNICAST_IF = 31
+)
+
+func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
+       /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
+       bytes := make([]byte, 4)
+       binary.BigEndian.PutUint32(bytes, interfaceIndex)
+       interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
+
+       sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
+       if err != nil {
+               return err
+       }
+       err2 := sysconn.Control(func(fd uintptr) {
+               err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex))
+       })
+       if err2 != nil {
+               return err2
+       }
+       if err != nil {
+               return err
+       }
+       return nil
+}
+
+func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
+       sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
+       if err != nil {
+               return err
+       }
+       err2 := sysconn.Control(func(fd uintptr) {
+               err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex))
+       })
+       if err2 != nil {
+               return err2
+       }
+       if err != nil {
+               return err
+       }
+       return nil
+}
\ No newline at end of file
index 8a86719a3ad1dd734079932aa42a153b31d1f27a..820bb96b693fea8a73484e996d5e16fda0a72e2f 100644 (file)
@@ -20,14 +20,14 @@ import (
  * See conn_linux.go for an implementation on the linux platform.
  */
 
-type NativeBind struct {
+type nativeBind struct {
        ipv4 *net.UDPConn
        ipv6 *net.UDPConn
 }
 
 type NativeEndpoint net.UDPAddr
 
-var _ Bind = (*NativeBind)(nil)
+var _ Bind = (*nativeBind)(nil)
 var _ Endpoint = (*NativeEndpoint)(nil)
 
 func CreateEndpoint(s string) (Endpoint, error) {
@@ -100,7 +100,7 @@ func extractErrno(err error) error {
 
 func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
        var err error
-       var bind NativeBind
+       var bind nativeBind
 
        port := int(uport)
 
@@ -119,7 +119,7 @@ func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
        return &bind, uint16(port), nil
 }
 
-func (bind *NativeBind) Close() error {
+func (bind *nativeBind) Close() error {
        var err1, err2 error
        if bind.ipv4 != nil {
                err1 = bind.ipv4.Close()
@@ -133,7 +133,7 @@ func (bind *NativeBind) Close() error {
        return err2
 }
 
-func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
        if bind.ipv4 == nil {
                return 0, nil, syscall.EAFNOSUPPORT
        }
@@ -144,7 +144,7 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
        return n, (*NativeEndpoint)(endpoint), err
 }
 
-func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
        if bind.ipv6 == nil {
                return 0, nil, syscall.EAFNOSUPPORT
        }
@@ -152,7 +152,7 @@ func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
        return n, (*NativeEndpoint)(endpoint), err
 }
 
-func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
+func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
        var err error
        nend := endpoint.(*NativeEndpoint)
        if nend.IP.To4() != nil {
index 49949d575e59dcdb2a51a3c3f2bcb4b7c1bd9b65..6a8520eccd7700cf360e17d83413ae72717f76cf 100644 (file)
@@ -63,7 +63,7 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
        return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
 }
 
-type NativeBind struct {
+type nativeBind struct {
        sock4         int
        sock6         int
        netlinkSock   int
@@ -72,7 +72,7 @@ type NativeBind struct {
 }
 
 var _ Endpoint = (*NativeEndpoint)(nil)
-var _ Bind = (*NativeBind)(nil)
+var _ Bind = (*nativeBind)(nil)
 
 func CreateEndpoint(s string) (Endpoint, error) {
        var end NativeEndpoint
@@ -127,9 +127,9 @@ func createNetlinkRouteSocket() (int, error) {
 
 }
 
-func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
+func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
        var err error
-       var bind NativeBind
+       var bind nativeBind
        var newPort uint16
 
        bind.netlinkSock, err = createNetlinkRouteSocket()
@@ -176,7 +176,7 @@ func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
        return &bind, port, nil
 }
 
-func (bind *NativeBind) SetMark(value uint32) error {
+func (bind *nativeBind) SetMark(value uint32) error {
        if bind.sock6 != -1 {
                err := unix.SetsockoptInt(
                        bind.sock6,
@@ -213,7 +213,7 @@ func closeUnblock(fd int) error {
        return unix.Close(fd)
 }
 
-func (bind *NativeBind) Close() error {
+func (bind *nativeBind) Close() error {
        var err1, err2, err3 error
        if bind.sock6 != -1 {
                err1 = closeUnblock(bind.sock6)
@@ -232,7 +232,7 @@ func (bind *NativeBind) Close() error {
        return err3
 }
 
-func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
        var end NativeEndpoint
        if bind.sock6 == -1 {
                return 0, nil, syscall.EAFNOSUPPORT
@@ -245,7 +245,7 @@ func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
        return n, &end, err
 }
 
-func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
        var end NativeEndpoint
        if bind.sock4 == -1 {
                return 0, nil, syscall.EAFNOSUPPORT
@@ -258,7 +258,7 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
        return n, &end, err
 }
 
-func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
+func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
        nend := end.(*NativeEndpoint)
        if !nend.isV6 {
                if bind.sock4 == -1 {
@@ -592,7 +592,7 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
        return size, nil
 }
 
-func (bind *NativeBind) routineRouteListener(device *Device) {
+func (bind *nativeBind) routineRouteListener(device *Device) {
        type peerEndpointPtr struct {
                peer     *Peer
                endpoint *Endpoint
index 76b101554e3eb0a8979db7534123e4602ee6ce86..7de2524c01b5124e90b776a953b7a48f139ffd5a 100644 (file)
@@ -7,6 +7,6 @@
 
 package device
 
-func (bind *NativeBind) SetMark(mark uint32) error {
+func (bind *nativeBind) SetMark(mark uint32) error {
        return nil
 }
index ee64cc9facb351e8b94af6d45206bfec9111a33f..a791c71248a66bc036b29c075d206e67e5277e42 100644 (file)
@@ -25,7 +25,7 @@ func init() {
        }
 }
 
-func (bind *NativeBind) SetMark(mark uint32) error {
+func (bind *nativeBind) SetMark(mark uint32) error {
        var operr error
        if fwmarkIoctl == 0 {
                return nil
index af3ef9d726520a3e68ed63a296b5d15b24736701..815dff4c7d7702d8fc304bfe2fa63049b3dd04f1 100644 (file)
@@ -258,10 +258,10 @@ func (peer *Peer) Stop() {
        peer.ZeroAndFlushAll()
 }
 
-var roamingDisabled bool
+var RoamingDisabled bool
 
 func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
-       if roamingDisabled {
+       if RoamingDisabled {
                return
        }
        peer.Lock()