]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Definition of platform specific socket bind
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 6 Oct 2017 20:56:01 +0000 (22:56 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Fri, 6 Oct 2017 20:56:01 +0000 (22:56 +0200)
src/conn.go
src/conn_default.go
src/conn_linux.go
src/uapi.go

index 2cf588d8425eac98e054f0274e9b3760a1cc43f2..60cd789c866b4c4ad145dda1a33c9c6ad1d79523 100644 (file)
@@ -56,7 +56,7 @@ func updateUDPConn(device *Device) error {
 
                // set fwmark
 
-               err = setMark(netc.conn, netc.fwmark)
+               err = SetMark(netc.conn, netc.fwmark)
                if err != nil {
                        return err
                }
index e7c60a8cc3ee906949bdd849b9cc9fae163f57be..279643e3105719f9c1225e580a2f8ca37891e2f4 100644 (file)
@@ -6,6 +6,6 @@ import (
        "net"
 )
 
-func setMark(conn *net.UDPConn, value uint32) error {
+func SetMark(conn *net.UDPConn, value uint32) error {
        return nil
 }
index a349a9e59d8862cd8c888cbc3ee0e83fa46d5484..64447a547800503bac49a9927dbd0cdca2dc081d 100644 (file)
@@ -14,23 +14,30 @@ import (
        "unsafe"
 )
 
+import "fmt"
+
 /* Supports source address caching
- *
- * It is important that the endpoint is only updated after the packet content has been authenticated.
  *
  * Currently there is no way to achieve this within the net package:
  * See e.g. https://github.com/golang/go/issues/17930
+ * So this code is platform dependent.
+ *
+ * It is important that the endpoint is only updated after the packet content has been authenticated!
  */
+
 type Endpoint struct {
        // source (selected based on dst type)
        // (could use RawSockaddrAny and unsafe)
-       srcIPv6 unix.RawSockaddrInet6
-       srcIPv4 unix.RawSockaddrInet4
-       srcIf4  int32
+       src6   unix.RawSockaddrInet6
+       src4   unix.RawSockaddrInet4
+       src4if int32
 
        dst unix.RawSockaddrAny
 }
 
+type IPv4Socket int
+type IPv6Socket int
+
 func zoneToUint32(zone string) (uint32, error) {
        if zone == "" {
                return 0, nil
@@ -42,10 +49,115 @@ func zoneToUint32(zone string) (uint32, error) {
        return uint32(n), err
 }
 
+func CreateIPv4Socket(port int) (IPv4Socket, error) {
+
+       // create socket
+
+       fd, err := unix.Socket(
+               unix.AF_INET,
+               unix.SOCK_DGRAM,
+               0,
+       )
+
+       if err != nil {
+               return -1, err
+       }
+
+       // set sockopts and bind
+
+       if err := func() error {
+
+               if err := unix.SetsockoptInt(
+                       fd,
+                       unix.SOL_SOCKET,
+                       unix.SO_REUSEADDR,
+                       1,
+               ); err != nil {
+                       return err
+               }
+
+               if err := unix.SetsockoptInt(
+                       fd,
+                       unix.IPPROTO_IP,
+                       unix.IP_PKTINFO,
+                       1,
+               ); err != nil {
+                       return err
+               }
+
+               addr := unix.SockaddrInet4{
+                       Port: port,
+               }
+               return unix.Bind(fd, &addr)
+
+       }(); err != nil {
+               unix.Close(fd)
+       }
+
+       return IPv4Socket(fd), err
+}
+
+func CreateIPv6Socket(port int) (IPv6Socket, error) {
+
+       // create socket
+
+       fd, err := unix.Socket(
+               unix.AF_INET,
+               unix.SOCK_DGRAM,
+               0,
+       )
+
+       if err != nil {
+               return -1, err
+       }
+
+       // set sockopts and bind
+
+       if err := func() error {
+
+               if err := unix.SetsockoptInt(
+                       fd,
+                       unix.SOL_SOCKET,
+                       unix.SO_REUSEADDR,
+                       1,
+               ); err != nil {
+                       return err
+               }
+
+               if err := unix.SetsockoptInt(
+                       fd,
+                       unix.IPPROTO_IPV6,
+                       unix.IPV6_RECVPKTINFO,
+                       1,
+               ); err != nil {
+                       return err
+               }
+
+               if err := unix.SetsockoptInt(
+                       fd,
+                       unix.IPPROTO_IPV6,
+                       unix.IPV6_V6ONLY,
+                       1,
+               ); err != nil {
+                       return err
+               }
+
+               addr := unix.SockaddrInet6{
+                       Port: port,
+               }
+               return unix.Bind(fd, &addr)
+
+       }(); err != nil {
+               unix.Close(fd)
+       }
+
+       return IPv6Socket(fd), err
+}
+
 func (end *Endpoint) ClearSrc() {
-       end.srcIf4 = 0
-       end.srcIPv4 = unix.RawSockaddrInet4{}
-       end.srcIPv6 = unix.RawSockaddrInet6{}
+       end.src4if = 0
+       end.src4 = unix.RawSockaddrInet4{}
+       end.src6 = unix.RawSockaddrInet6{}
 }
 
 func (end *Endpoint) Set(s string) error {
@@ -85,8 +197,10 @@ func (end *Endpoint) Set(s string) error {
 }
 
 func send6(sock uintptr, end *Endpoint, buff []byte) error {
-       var iovec unix.Iovec
 
+       // construct message header
+
+       var iovec unix.Iovec
        iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
        iovec.SetLen(len(buff))
 
@@ -100,8 +214,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
                        Len:   unix.SizeofInet6Pktinfo,
                },
                unix.Inet6Pktinfo{
-                       Addr:    end.srcIPv6.Addr,
-                       Ifindex: end.srcIPv6.Scope_id,
+                       Addr:    end.src6.Addr,
+                       Ifindex: end.src6.Scope_id,
                },
        }
 
@@ -130,8 +244,10 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
 }
 
 func send4(sock uintptr, end *Endpoint, buff []byte) error {
-       var iovec unix.Iovec
 
+       // construct message header
+
+       var iovec unix.Iovec
        iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
        iovec.SetLen(len(buff))
 
@@ -142,11 +258,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
                unix.Cmsghdr{
                        Level: unix.IPPROTO_IP,
                        Type:  unix.IP_PKTINFO,
-                       Len:   unix.SizeofInet6Pktinfo,
+                       Len:   unix.SizeofInet4Pktinfo,
                },
                unix.Inet4Pktinfo{
-                       Spec_dst: end.srcIPv4.Addr,
-                       Ifindex:  end.srcIf4,
+                       Spec_dst: end.src4.Addr,
+                       Ifindex:  end.src4if,
                },
        }
 
@@ -174,7 +290,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
        return errno
 }
 
-func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
+func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error {
 
        // extract underlying file descriptor
 
@@ -195,60 +311,102 @@ func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
        return errors.New("Unknown address family of source")
 }
 
-func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) {
+func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
 
-       file, err := c.File()
-       if err != nil {
-               return err, nil, nil
+       // contruct message header
+
+       var iovec unix.Iovec
+       iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
+       iovec.SetLen(len(buff))
+
+       var cmsg struct {
+               cmsghdr unix.Cmsghdr
+               pktinfo unix.Inet4Pktinfo
        }
 
+       var msghdr unix.Msghdr
+       msghdr.Iov = &iovec
+       msghdr.Iovlen = 1
+       msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
+       msghdr.Namelen = unix.SizeofSockaddrInet4
+       msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
+       msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
+
+       // recvmsg(sock, &mskhdr, 0)
+
+       size, _, errno := unix.Syscall(
+               unix.SYS_RECVMSG,
+               uintptr(sock),
+               uintptr(unsafe.Pointer(&msghdr)),
+               0,
+       )
+
+       if errno != 0 {
+               return 0, errno
+       }
+
+       fmt.Println(msghdr)
+       fmt.Println(cmsg)
+
+       // update source cache
+
+       if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
+               cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
+               cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
+               end.src4.Addr = cmsg.pktinfo.Spec_dst
+               end.src4if = cmsg.pktinfo.Ifindex
+       }
+
+       return int(size), nil
+}
+
+func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
+
+       // contruct message header
+
        var iovec unix.Iovec
        iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
        iovec.SetLen(len(buff))
 
        var cmsg struct {
                cmsghdr unix.Cmsghdr
-               pktinfo unix.Inet6Pktinfo // big enough
+               pktinfo unix.Inet6Pktinfo
        }
 
        var msg unix.Msghdr
        msg.Iov = &iovec
        msg.Iovlen = 1
        msg.Name = (*byte)(unsafe.Pointer(&end.dst))
-       msg.Namelen = uint32(unix.SizeofSockaddrAny)
+       msg.Namelen = uint32(unix.SizeofSockaddrInet6)
        msg.Control = (*byte)(unsafe.Pointer(&cmsg))
        msg.SetControllen(int(unsafe.Sizeof(cmsg)))
 
+       // recvmsg(sock, &mskhdr, 0)
+
        _, _, errno := unix.Syscall(
                unix.SYS_RECVMSG,
-               file.Fd(),
+               uintptr(sock),
                uintptr(unsafe.Pointer(&msg)),
                0,
        )
 
        if errno != 0 {
-               return errno, nil, nil
+               return errno
        }
 
+       // update source cache
+
        if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
                cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
                cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
-
-       }
-
-       if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
-               cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
-               cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
-
-               info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
-               println(info)
-
+               end.src6.Addr = cmsg.pktinfo.Addr
+               end.src6.Scope_id = cmsg.pktinfo.Ifindex
        }
 
-       return nil, nil, nil
+       return nil
 }
 
-func setMark(conn *net.UDPConn, value uint32) error {
+func SetMark(conn *net.UDPConn, value uint32) error {
        if conn == nil {
                return nil
        }
index 326216bb5848a2766ee0c63629a8c053fc67adce..7d08e561eef9ab89fbfda2d8f8200dbf877a4abc 100644 (file)
@@ -166,7 +166,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                device.net.mutex.Lock()
                                if fwmark > 0 || device.net.fwmark > 0 {
                                        device.net.fwmark = uint32(fwmark)
-                                       err := setMark(
+                                       err := SetMark(
                                                device.net.conn,
                                                device.net.fwmark,
                                        )