]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Begin work on source address caching (linux)
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sun, 24 Sep 2017 19:35:25 +0000 (21:35 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sun, 24 Sep 2017 19:35:25 +0000 (21:35 +0200)
src/conn.go
src/conn_linux.go
src/misc.go
src/tun_linux.go
src/uapi.go

index 7b35829f83a0889b6689c9abef4fb892295980cd..41a5b855e36eaec5300524f51f51e233ce0f7c28 100644 (file)
@@ -1,9 +1,31 @@
 package main
 
 import (
+       "errors"
        "net"
 )
 
+func parseEndpoint(s string) (*net.UDPAddr, error) {
+
+       // ensure that the host is an IP address
+
+       host, _, err := net.SplitHostPort(s)
+       if err != nil {
+               return nil, err
+       }
+       if ip := net.ParseIP(host); ip == nil {
+               return nil, errors.New("Failed to parse IP address: " + host)
+       }
+
+       // parse address and port
+
+       addr, err := net.ResolveUDPAddr("udp", s)
+       if err != nil {
+               return nil, err
+       }
+       return addr, err
+}
+
 func updateUDPConn(device *Device) error {
        netc := &device.net
        netc.mutex.Lock()
index e973b25eee91dea8f7dcb5c837c81ac547a5d455..a349a9e59d8862cd8c888cbc3ee0e83fa46d5484 100644 (file)
+/* Copyright 2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ *
+ * This implements userspace semantics of "sticky sockets", modeled after
+ * WireGuard's kernelspace implementation.
+ */
+
 package main
 
 import (
+       "errors"
        "golang.org/x/sys/unix"
        "net"
+       "strconv"
+       "unsafe"
 )
 
+/* Supports source address caching
+ *
+ * It is important that the endpoint is only updated after the packet content has been authenticated.
+ *
+ * Currently there is no way to achieve this within the net package:
+ * See e.g. https://github.com/golang/go/issues/17930
+ */
+type Endpoint struct {
+       // source (selected based on dst type)
+       // (could use RawSockaddrAny and unsafe)
+       srcIPv6 unix.RawSockaddrInet6
+       srcIPv4 unix.RawSockaddrInet4
+       srcIf4  int32
+
+       dst unix.RawSockaddrAny
+}
+
+func zoneToUint32(zone string) (uint32, error) {
+       if zone == "" {
+               return 0, nil
+       }
+       if intr, err := net.InterfaceByName(zone); err == nil {
+               return uint32(intr.Index), nil
+       }
+       n, err := strconv.ParseUint(zone, 10, 32)
+       return uint32(n), err
+}
+
+func (end *Endpoint) ClearSrc() {
+       end.srcIf4 = 0
+       end.srcIPv4 = unix.RawSockaddrInet4{}
+       end.srcIPv6 = unix.RawSockaddrInet6{}
+}
+
+func (end *Endpoint) Set(s string) error {
+       addr, err := parseEndpoint(s)
+       if err != nil {
+               return err
+       }
+
+       ipv6 := addr.IP.To16()
+       if ipv6 != nil {
+               zone, err := zoneToUint32(addr.Zone)
+               if err != nil {
+                       return err
+               }
+               ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
+               ptr.Family = unix.AF_INET6
+               ptr.Port = uint16(addr.Port)
+               ptr.Flowinfo = 0
+               ptr.Scope_id = zone
+               copy(ptr.Addr[:], ipv6[:])
+               end.ClearSrc()
+               return nil
+       }
+
+       ipv4 := addr.IP.To4()
+       if ipv4 != nil {
+               ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
+               ptr.Family = unix.AF_INET
+               ptr.Port = uint16(addr.Port)
+               ptr.Zero = [8]byte{}
+               copy(ptr.Addr[:], ipv4)
+               end.ClearSrc()
+               return nil
+       }
+
+       return errors.New("Failed to recognize IP address format")
+}
+
+func send6(sock uintptr, end *Endpoint, buff []byte) error {
+       var iovec unix.Iovec
+
+       iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
+       iovec.SetLen(len(buff))
+
+       cmsg := struct {
+               cmsghdr unix.Cmsghdr
+               pktinfo unix.Inet6Pktinfo
+       }{
+               unix.Cmsghdr{
+                       Level: unix.IPPROTO_IPV6,
+                       Type:  unix.IPV6_PKTINFO,
+                       Len:   unix.SizeofInet6Pktinfo,
+               },
+               unix.Inet6Pktinfo{
+                       Addr:    end.srcIPv6.Addr,
+                       Ifindex: end.srcIPv6.Scope_id,
+               },
+       }
+
+       msghdr := unix.Msghdr{
+               Iov:     &iovec,
+               Iovlen:  1,
+               Name:    (*byte)(unsafe.Pointer(&end.dst)),
+               Namelen: unix.SizeofSockaddrInet6,
+               Control: (*byte)(unsafe.Pointer(&cmsg)),
+       }
+
+       msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
+
+       // sendmsg(sock, &msghdr, 0)
+
+       _, _, errno := unix.Syscall(
+               unix.SYS_SENDMSG,
+               sock,
+               uintptr(unsafe.Pointer(&msghdr)),
+               0,
+       )
+       if errno == unix.EINVAL {
+               end.ClearSrc()
+       }
+       return errno
+}
+
+func send4(sock uintptr, end *Endpoint, buff []byte) error {
+       var iovec unix.Iovec
+
+       iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
+       iovec.SetLen(len(buff))
+
+       cmsg := struct {
+               cmsghdr unix.Cmsghdr
+               pktinfo unix.Inet4Pktinfo
+       }{
+               unix.Cmsghdr{
+                       Level: unix.IPPROTO_IP,
+                       Type:  unix.IP_PKTINFO,
+                       Len:   unix.SizeofInet6Pktinfo,
+               },
+               unix.Inet4Pktinfo{
+                       Spec_dst: end.srcIPv4.Addr,
+                       Ifindex:  end.srcIf4,
+               },
+       }
+
+       msghdr := unix.Msghdr{
+               Iov:     &iovec,
+               Iovlen:  1,
+               Name:    (*byte)(unsafe.Pointer(&end.dst)),
+               Namelen: unix.SizeofSockaddrInet4,
+               Control: (*byte)(unsafe.Pointer(&cmsg)),
+       }
+
+       msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
+
+       // sendmsg(sock, &msghdr, 0)
+
+       _, _, errno := unix.Syscall(
+               unix.SYS_SENDMSG,
+               sock,
+               uintptr(unsafe.Pointer(&msghdr)),
+               0,
+       )
+       if errno == unix.EINVAL {
+               end.ClearSrc()
+       }
+       return errno
+}
+
+func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
+
+       // extract underlying file descriptor
+
+       file, err := c.File()
+       if err != nil {
+               return err
+       }
+       sock := file.Fd()
+
+       // send depending on address family of dst
+
+       family := *((*uint16)(unsafe.Pointer(&end.dst)))
+       if family == unix.AF_INET {
+               return send4(sock, end, buff)
+       } else if family == unix.AF_INET6 {
+               return send6(sock, end, buff)
+       }
+       return errors.New("Unknown address family of source")
+}
+
+func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) {
+
+       file, err := c.File()
+       if err != nil {
+               return err, nil, nil
+       }
+
+       var iovec unix.Iovec
+       iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
+       iovec.SetLen(len(buff))
+
+       var cmsg struct {
+               cmsghdr unix.Cmsghdr
+               pktinfo unix.Inet6Pktinfo // big enough
+       }
+
+       var msg unix.Msghdr
+       msg.Iov = &iovec
+       msg.Iovlen = 1
+       msg.Name = (*byte)(unsafe.Pointer(&end.dst))
+       msg.Namelen = uint32(unix.SizeofSockaddrAny)
+       msg.Control = (*byte)(unsafe.Pointer(&cmsg))
+       msg.SetControllen(int(unsafe.Sizeof(cmsg)))
+
+       _, _, errno := unix.Syscall(
+               unix.SYS_RECVMSG,
+               file.Fd(),
+               uintptr(unsafe.Pointer(&msg)),
+               0,
+       )
+
+       if errno != 0 {
+               return errno, nil, nil
+       }
+
+       if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
+               cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
+               cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
+
+       }
+
+       if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
+               cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
+               cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
+
+               info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
+               println(info)
+
+       }
+
+       return nil, nil, nil
+}
+
 func setMark(conn *net.UDPConn, value uint32) error {
        if conn == nil {
                return nil
index d93849e1b383e77f7e2a19f7efea1a8541313978..bbe0d6841748f6194c0ee9f1dcf8dbc46e743901 100644 (file)
@@ -29,6 +29,11 @@ func (a *AtomicBool) Set(val bool) {
        atomic.StoreInt32(&a.flag, flag)
 }
 
+func toInt32(n uint32) int32 {
+       mask := uint32(1 << 31)
+       return int32(-(n & mask) + (n & ^mask))
+}
+
 func min(a uint, b uint) uint {
        if a > b {
                return b
index 58a762ad56e636d3938c723c9701ed3054c68be1..accc6c6064447910e291d20e281606569c87d9e5 100644 (file)
@@ -120,14 +120,6 @@ func (tun *NativeTun) Name() string {
        return tun.name
 }
 
-func toInt32(val []byte) int32 {
-       n := binary.LittleEndian.Uint32(val[:4])
-       if n >= (1 << 31) {
-               return -int32(^n) - 1
-       }
-       return int32(n)
-}
-
 func getDummySock() (int, error) {
        return unix.Socket(
                unix.AF_INET,
@@ -157,7 +149,8 @@ func getIFIndex(name string) (int32, error) {
                return 0, errno
        }
 
-       return toInt32(ifr[unix.IFNAMSIZ:]), nil
+       index := binary.LittleEndian.Uint32(ifr[unix.IFNAMSIZ:])
+       return toInt32(index), nil
 }
 
 func (tun *NativeTun) setMTU(n int) error {
index 428b17399cd84978b1630e2cc83c1f260be79b4f..3a2f3f984d764d777b19b97f292a40a0abac8b57 100644 (file)
@@ -273,8 +273,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                                }
 
                        case "endpoint":
-                               // TODO: Only IP and port
-                               addr, err := net.ResolveUDPAddr("udp", value)
+                               addr, err := parseEndpoint(value)
                                if err != nil {
                                        logError.Println("Failed to set endpoint:", value)
                                        return &IPCError{Code: ipcErrorInvalid}