]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Added new UDPBind interface
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sun, 8 Oct 2017 20:03:32 +0000 (22:03 +0200)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sun, 8 Oct 2017 20:03:32 +0000 (22:03 +0200)
src/conn.go
src/conn_linux.go
src/cookie.go
src/device.go
src/peer.go
src/receive.go

index 61be3bfc7b25166e7ef7b686b600a5fb16a3a52b..db4020d61ba9c01920f9461b3163b232a6cb7a7d 100644 (file)
@@ -5,6 +5,14 @@ import (
        "net"
 )
 
+type UDPBind interface {
+       SetMark(value uint32) error
+       ReceiveIPv6(buff []byte, end *Endpoint) (int, error)
+       ReceiveIPv4(buff []byte, end *Endpoint) (int, error)
+       Send(buff []byte, end *Endpoint) error
+       Close() error
+}
+
 func parseEndpoint(s string) (*net.UDPAddr, error) {
 
        // ensure that the host is an IP address
@@ -26,19 +34,6 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
        return addr, err
 }
 
-func ListenerClose(l *Listener) (err error) {
-       if l.active {
-               err = CloseIPv4Socket(l.sock)
-               l.active = false
-       }
-       return
-}
-
-func (l *Listener) Init() {
-       l.update = make(chan struct{}, 1)
-       ListenerClose(l)
-}
-
 func ListeningUpdate(device *Device) error {
        netc := &device.net
        netc.mutex.Lock()
@@ -46,11 +41,7 @@ func ListeningUpdate(device *Device) error {
 
        // close existing sockets
 
-       if err := ListenerClose(&netc.ipv4); err != nil {
-               return err
-       }
-
-       if err := ListenerClose(&netc.ipv6); err != nil {
+       if err := device.net.bind.Close(); err != nil {
                return err
        }
 
@@ -58,45 +49,22 @@ func ListeningUpdate(device *Device) error {
 
        if device.tun.isUp.Get() {
 
-               // listen on IPv4
-
-               {
-                       list := &netc.ipv6
-                       sock, port, err := CreateIPv4Socket(netc.port)
-                       if err != nil {
-                               return err
-                       }
-                       netc.port = port
-                       list.sock = sock
-                       list.active = true
-
-                       if err := SetMark(list.sock, netc.fwmark); err != nil {
-                               ListenerClose(list)
-                               return err
-                       }
-                       signalSend(list.update)
+               // bind to new port
+
+               var err error
+               netc.bind, netc.port, err = CreateUDPBind(netc.port)
+               if err != nil {
+                       return err
                }
 
-               // listen on IPv6
-
-               {
-                       list := &netc.ipv6
-                       sock, port, err := CreateIPv6Socket(netc.port)
-                       if err != nil {
-                               return err
-                       }
-                       netc.port = port
-                       list.sock = sock
-                       list.active = true
-
-                       if err := SetMark(list.sock, netc.fwmark); err != nil {
-                               ListenerClose(list)
-                               return err
-                       }
-                       signalSend(list.update)
+               // set mark
+
+               err = netc.bind.SetMark(netc.fwmark)
+               if err != nil {
+                       return err
                }
 
-               // TODO: clear endpoint caches
+               // TODO: clear endpoint (src) caches
        }
 
        return nil
@@ -106,16 +74,5 @@ func ListeningClose(device *Device) error {
        netc := &device.net
        netc.mutex.Lock()
        defer netc.mutex.Unlock()
-
-       if err := ListenerClose(&netc.ipv4); err != nil {
-               return err
-       }
-       signalSend(netc.ipv4.update)
-
-       if err := ListenerClose(&netc.ipv6); err != nil {
-               return err
-       }
-       signalSend(netc.ipv6.update)
-
-       return nil
+       return netc.bind.Close()
 }
index 034fb8bfe616f5174be754d8198e86c966732472..8942b03a7c85e105f45c6cbec255f7d82a0f6aec 100644 (file)
@@ -14,35 +14,158 @@ import (
        "unsafe"
 )
 
-import "fmt"
-
 /* Supports source address caching
  *
  * Currently there is no way to achieve this within the net package:
  * See e.g. https://github.com/golang/go/issues/17930
- * So this code is platform dependent.
- *
- * It is important that the endpoint is only updated after the packet content has been authenticated!
+ * So this code is remains platform dependent.
  */
 
 type Endpoint struct {
-       // source (selected based on dst type)
-       // (could use RawSockaddrAny and unsafe)
-       // TODO: Merge
-       src6   unix.RawSockaddrInet6
-       src4   unix.RawSockaddrInet4
-       src4if int32
-
-       dst unix.RawSockaddrAny
+       src unix.RawSockaddrInet6
+       dst unix.RawSockaddrInet6
+}
+
+type IPv4Source struct {
+       src     unix.RawSockaddrInet4
+       Ifindex int32
 }
 
-type Socket int
+type Bind struct {
+       sock4 int
+       sock6 int
+}
 
-/* Returns a byte representation of the source field(s)
- * for use in "under load" cookie computations.
- */
-func (endpoint *Endpoint) Source() []byte {
-       return nil
+func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
+       var err error
+       var bind Bind
+
+       bind.sock6, port, err = create6(port)
+       if err != nil {
+               return nil, port, err
+       }
+
+       bind.sock4, port, err = create4(port)
+       if err != nil {
+               unix.Close(bind.sock6)
+       }
+       return &bind, port, err
+}
+
+func (bind *Bind) SetMark(value uint32) error {
+       err := unix.SetsockoptInt(
+               bind.sock6,
+               unix.SOL_SOCKET,
+               unix.SO_MARK,
+               int(value),
+       )
+
+       if err != nil {
+               return err
+       }
+
+       return unix.SetsockoptInt(
+               bind.sock4,
+               unix.SOL_SOCKET,
+               unix.SO_MARK,
+               int(value),
+       )
+}
+
+func (bind *Bind) Close() error {
+       err1 := unix.Close(bind.sock6)
+       err2 := unix.Close(bind.sock4)
+       if err1 != nil {
+               return err1
+       }
+       return err2
+}
+
+func (bind *Bind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) {
+       return receive6(
+               bind.sock6,
+               buff,
+               end,
+       )
+}
+
+func (bind *Bind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) {
+       return receive4(
+               bind.sock4,
+               buff,
+               end,
+       )
+}
+
+func (bind *Bind) Send(buff []byte, end *Endpoint) error {
+       switch end.src.Family {
+       case unix.AF_INET6:
+               return send6(bind.sock6, end, buff)
+       case unix.AF_INET:
+               return send4(bind.sock4, end, buff)
+       default:
+               return errors.New("Unknown address family of source")
+       }
+}
+
+func sockaddrToString(addr unix.RawSockaddrInet6) string {
+       var udpAddr net.UDPAddr
+
+       switch addr.Family {
+       case unix.AF_INET6:
+               udpAddr.Port = int(addr.Port)
+               udpAddr.IP = addr.Addr[:]
+               return udpAddr.String()
+
+       case unix.AF_INET:
+               ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
+               udpAddr.Port = int(ptr.Port)
+               udpAddr.IP = net.IPv4(
+                       ptr.Addr[0],
+                       ptr.Addr[1],
+                       ptr.Addr[2],
+                       ptr.Addr[3],
+               )
+               return udpAddr.String()
+
+       default:
+               return "<unknown address family>"
+       }
+}
+
+func (end *Endpoint) DestinationIP() net.IP {
+       switch end.dst.Family {
+       case unix.AF_INET6:
+               return end.dst.Addr[:]
+       case unix.AF_INET:
+               ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
+               return net.IPv4(
+                       ptr.Addr[0],
+                       ptr.Addr[1],
+                       ptr.Addr[2],
+                       ptr.Addr[3],
+               )
+       default:
+               return nil
+       }
+}
+
+func (end *Endpoint) SourceToBytes() []byte {
+       ptr := unsafe.Pointer(&end.src)
+       arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
+       return arr[:]
+}
+
+func (end *Endpoint) SourceToString() string {
+       return sockaddrToString(end.src)
+}
+
+func (end *Endpoint) DestinationToString() string {
+       return sockaddrToString(end.dst)
+}
+
+func (end *Endpoint) ClearSrc() {
+       end.src = unix.RawSockaddrInet6{}
 }
 
 func zoneToUint32(zone string) (uint32, error) {
@@ -56,7 +179,7 @@ func zoneToUint32(zone string) (uint32, error) {
        return uint32(n), err
 }
 
-func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
+func create4(port uint16) (int, uint16, error) {
 
        // create socket
 
@@ -100,18 +223,10 @@ func CreateIPv4Socket(port uint16) (Socket, uint16, error) {
                unix.Close(fd)
        }
 
-       return Socket(fd), uint16(addr.Port), err
+       return fd, uint16(addr.Port), err
 }
 
-func CloseIPv4Socket(sock Socket) error {
-       return unix.Close(int(sock))
-}
-
-func CloseIPv6Socket(sock Socket) error {
-       return unix.Close(int(sock))
-}
-
-func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
+func create6(port uint16) (int, uint16, error) {
 
        // create socket
 
@@ -166,13 +281,7 @@ func CreateIPv6Socket(port uint16) (Socket, uint16, error) {
                unix.Close(fd)
        }
 
-       return Socket(fd), uint16(addr.Port), err
-}
-
-func (end *Endpoint) ClearSrc() {
-       end.src4if = 0
-       end.src4 = unix.RawSockaddrInet4{}
-       end.src6 = unix.RawSockaddrInet6{}
+       return fd, uint16(addr.Port), err
 }
 
 func (end *Endpoint) Set(s string) error {
@@ -187,23 +296,23 @@ func (end *Endpoint) Set(s string) error {
                if err != nil {
                        return err
                }
-               ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
-               ptr.Family = unix.AF_INET6
-               ptr.Port = uint16(addr.Port)
-               ptr.Flowinfo = 0
-               ptr.Scope_id = zone
-               copy(ptr.Addr[:], ipv6[:])
+               dst := &end.dst
+               dst.Family = unix.AF_INET6
+               dst.Port = uint16(addr.Port)
+               dst.Flowinfo = 0
+               dst.Scope_id = zone
+               copy(dst.Addr[:], ipv6[:])
                end.ClearSrc()
                return nil
        }
 
        ipv4 := addr.IP.To4()
        if ipv4 != nil {
-               ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
-               ptr.Family = unix.AF_INET
-               ptr.Port = uint16(addr.Port)
-               ptr.Zero = [8]byte{}
-               copy(ptr.Addr[:], ipv4)
+               dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
+               dst.Family = unix.AF_INET
+               dst.Port = uint16(addr.Port)
+               dst.Zero = [8]byte{}
+               copy(dst.Addr[:], ipv4)
                end.ClearSrc()
                return nil
        }
@@ -211,7 +320,7 @@ func (end *Endpoint) Set(s string) error {
        return errors.New("Failed to recognize IP address format")
 }
 
-func send6(sock uintptr, end *Endpoint, buff []byte) error {
+func send6(sock int, end *Endpoint, buff []byte) error {
 
        // construct message header
 
@@ -229,8 +338,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
                        Len:   unix.SizeofInet6Pktinfo,
                },
                unix.Inet6Pktinfo{
-                       Addr:    end.src6.Addr,
-                       Ifindex: end.src6.Scope_id,
+                       Addr:    end.src.Addr,
+                       Ifindex: end.src.Scope_id,
                },
        }
 
@@ -248,7 +357,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
 
        _, _, errno := unix.Syscall(
                unix.SYS_SENDMSG,
-               sock,
+               uintptr(sock),
                uintptr(unsafe.Pointer(&msghdr)),
                0,
        )
@@ -258,7 +367,7 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
        return errno
 }
 
-func send4(sock uintptr, end *Endpoint, buff []byte) error {
+func send4(sock int, end *Endpoint, buff []byte) error {
 
        // construct message header
 
@@ -266,6 +375,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
        iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
        iovec.SetLen(len(buff))
 
+       src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
+
        cmsg := struct {
                cmsghdr unix.Cmsghdr
                pktinfo unix.Inet4Pktinfo
@@ -276,8 +387,8 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
                        Len:   unix.SizeofInet4Pktinfo,
                },
                unix.Inet4Pktinfo{
-                       Spec_dst: end.src4.Addr,
-                       Ifindex:  end.src4if,
+                       Spec_dst: src4.src.Addr,
+                       Ifindex:  src4.Ifindex,
                },
        }
 
@@ -295,7 +406,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
 
        _, _, errno := unix.Syscall(
                unix.SYS_SENDMSG,
-               sock,
+               uintptr(sock),
                uintptr(unsafe.Pointer(&msghdr)),
                0,
        )
@@ -305,28 +416,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
        return errno
 }
 
-func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error {
-
-       // extract underlying file descriptor
-
-       file, err := c.File()
-       if err != nil {
-               return err
-       }
-       sock := file.Fd()
-
-       // send depending on address family of dst
-
-       family := *((*uint16)(unsafe.Pointer(&end.dst)))
-       if family == unix.AF_INET {
-               return send4(sock, end, buff)
-       } else if family == unix.AF_INET6 {
-               return send6(sock, end, buff)
-       }
-       return errors.New("Unknown address family of source")
-}
-
-func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
+func receive4(sock int, buff []byte, end *Endpoint) (int, error) {
 
        // contruct message header
 
@@ -360,22 +450,21 @@ func (end *Endpoint) ReceiveIPv4(sock Socket, buff []byte) (int, error) {
                return 0, errno
        }
 
-       fmt.Println(msghdr)
-       fmt.Println(cmsg)
-
        // update source cache
 
        if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
                cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
                cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
-               end.src4.Addr = cmsg.pktinfo.Spec_dst
-               end.src4if = cmsg.pktinfo.Ifindex
+               src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
+               src4.src.Family = unix.AF_INET
+               src4.src.Addr = cmsg.pktinfo.Spec_dst
+               src4.Ifindex = cmsg.pktinfo.Ifindex
        }
 
        return int(size), nil
 }
 
-func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
+func receive6(sock int, buff []byte, end *Endpoint) (int, error) {
 
        // contruct message header
 
@@ -414,18 +503,10 @@ func (end *Endpoint) ReceiveIPv6(sock Socket, buff []byte) (int, error) {
        if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
                cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
                cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
-               end.src6.Addr = cmsg.pktinfo.Addr
-               end.src6.Scope_id = cmsg.pktinfo.Ifindex
+               end.src.Family = unix.AF_INET6
+               end.src.Addr = cmsg.pktinfo.Addr
+               end.src.Scope_id = cmsg.pktinfo.Ifindex
        }
 
        return int(size), nil
 }
-
-func SetMark(sock Socket, value uint32) error {
-       return unix.SetsockoptInt(
-               int(sock),
-               unix.SOL_SOCKET,
-               unix.SO_MARK,
-               int(value),
-       )
-}
index a81819b83f09b9203fb00c9d43825211cc03203c..a13ad49e2dc4b4b441b18210c1eea79d4652427d 100644 (file)
@@ -5,10 +5,8 @@ import (
        "crypto/rand"
        "golang.org/x/crypto/blake2s"
        "golang.org/x/crypto/chacha20poly1305"
-       "net"
        "sync"
        "time"
-       "unsafe"
 )
 
 type CookieChecker struct {
@@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
        return hmac.Equal(mac1[:], msg[smac1:smac2])
 }
 
-func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
+func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
        st.mutex.RLock()
        defer st.mutex.RUnlock()
 
@@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
        var cookie [blake2s.Size128]byte
        func() {
                mac, _ := blake2s.New128(st.mac2.secret[:])
-               mac.Write(src.IP)
-               mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
+               mac.Write(src)
                mac.Sum(cookie[:0])
        }()
 
@@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
 func (st *CookieChecker) CreateReply(
        msg []byte,
        recv uint32,
-       src *net.UDPAddr,
+       src []byte,
 ) (*MessageCookieReply, error) {
 
        st.mutex.RLock()
@@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply(
        var cookie [blake2s.Size128]byte
        func() {
                mac, _ := blake2s.New128(st.mac2.secret[:])
-               mac.Write(src.IP)
-               mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
+               mac.Write(src)
                mac.Sum(cookie[:0])
        }()
 
index 509e6a741bc632010fe9eddc72a57d3484001339..d1e06859093b41ae3529fe0a572cb8372b26055c 100644 (file)
@@ -1,18 +1,14 @@
 package main
 
 import (
+       "golang.org/x/net/ipv4"
+       "golang.org/x/net/ipv6"
        "runtime"
        "sync"
        "sync/atomic"
        "time"
 )
 
-type Listener struct {
-       sock   Socket
-       active bool
-       update chan struct{}
-}
-
 type Device struct {
        log       *Logger // collection of loggers for levels
        idCounter uint    // for assigning debug ids to peers
@@ -27,8 +23,7 @@ type Device struct {
        }
        net struct {
                mutex  sync.RWMutex
-               ipv4   Listener
-               ipv6   Listener
+               bind   UDPBind
                port   uint16
                fwmark uint32
        }
@@ -43,9 +38,8 @@ type Device struct {
                handshake  chan QueueHandshakeElement
        }
        signal struct {
-               stop             chan struct{} // halts all go routines
-               updateIPv4Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
-               updateIPv6Socket chan struct{} // a net.conn was set (consumed by the receiver routine)
+               stop       chan struct{}
+               updateBind chan struct{}
        }
        underLoadUntil atomic.Value
        ratelimiter    Ratelimiter
@@ -146,8 +140,6 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
        device.tun.device = tun
 
        device.indices.Init()
-       device.net.ipv4.Init()
-       device.net.ipv6.Init()
        device.ratelimiter.Init()
 
        device.routingTable.Reset()
@@ -181,8 +173,8 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
        go device.RoutineReadFromTUN()
        go device.RoutineTUNEventReader()
        go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
-       go device.RoutineReceiveIncomming(&device.net.ipv4)
-       go device.RoutineReceiveIncomming(&device.net.ipv6)
+       go device.RoutineReceiveIncomming(ipv4.Version)
+       go device.RoutineReceiveIncomming(ipv6.Version)
        return device
 }
 
index 6fea82912faf7c3594b8f63b7ce1cd4e49ecd5d5..791c091319a7760d84b24ddde53f3ae20248a8ce 100644 (file)
@@ -4,7 +4,6 @@ import (
        "encoding/base64"
        "errors"
        "fmt"
-       "net"
        "sync"
        "time"
 )
@@ -15,8 +14,8 @@ type Peer struct {
        persistentKeepaliveInterval uint64
        keyPairs                    KeyPairs
        handshake                   Handshake
+       endpoint                    Endpoint
        device                      *Device
-       endpoint                    *net.UDPAddr
        stats                       struct {
                txBytes           uint64 // bytes send to peer (endpoint)
                rxBytes           uint64 // bytes received from peer
@@ -134,7 +133,7 @@ func (peer *Peer) String() string {
        return fmt.Sprintf(
                "peer(%d %s %s)",
                peer.id,
-               peer.endpoint.String(),
+               peer.endpoint.DestinationToString(),
                base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
        )
 }
index 60c0f2c7137373111ea37817261194a1816dfcb8..664f1ba674fdee8a0578245019ed64f56a154e37 100644 (file)
@@ -97,17 +97,6 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
        logDebug := device.log.Debug
        logDebug.Println("Routine, receive incomming, started")
 
-       var listener *Listener
-
-       switch IPVersion {
-       case ipv4.Version:
-               listener = &device.net.ipv4
-       case ipv6.Version:
-               listener = &device.net.ipv6
-       default:
-               return
-       }
-
        for {
 
                // wait for new conn
@@ -118,15 +107,14 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
                case <-device.signal.stop:
                        return
 
-               case <-listener.update:
+               case <-device.signal.updateBind:
 
                        // fetch new socket
 
                        device.net.mutex.RLock()
-                       sock := listener.sock
-                       okay := listener.active
+                       bind := device.net.bind
                        device.net.mutex.RUnlock()
-                       if !okay {
+                       if bind == nil {
                                continue
                        }
 
@@ -145,10 +133,13 @@ func (device *Device) RoutineReceiveIncomming(IPVersion int) {
 
                                var endpoint Endpoint
 
-                               if IPVersion == ipv6.Version {
-                                       size, err = endpoint.ReceiveIPv4(sock, buffer[:])
-                               } else {
-                                       size, err = endpoint.ReceiveIPv6(sock, buffer[:])
+                               switch IPVersion {
+                               case ipv4.Version:
+                                       size, err = bind.ReceiveIPv4(buffer[:], &endpoint)
+                               case ipv6.Version:
+                                       size, err = bind.ReceiveIPv6(buffer[:], &endpoint)
+                               default:
+                                       return
                                }
 
                                if err != nil {
@@ -340,15 +331,19 @@ func (device *Device) RoutineHandshake() {
                                return
                        }
 
+                       srcBytes := elem.endpoint.SourceToBytes()
                        if device.IsUnderLoad() {
-                               if !device.mac.CheckMAC2(elem.packet, elem.source) {
+
+                               // verify MAC2 field
+
+                               if !device.mac.CheckMAC2(elem.packet, srcBytes) {
 
                                        // construct cookie reply
 
-                                       logDebug.Println("Sending cookie reply to:", elem.source.String())
+                                       logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString())
 
                                        sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
-                                       reply, err := device.mac.CreateReply(elem.packet, sender, elem.source)
+                                       reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
                                        if err != nil {
                                                logError.Println("Failed to create cookie reply:", err)
                                                return
@@ -358,9 +353,9 @@ func (device *Device) RoutineHandshake() {
 
                                        writer := bytes.NewBuffer(temp[:0])
                                        binary.Write(writer, binary.LittleEndian, reply)
-                                       _, err = device.net.conn.WriteToUDP(
+                                       device.net.bind.Send(
                                                writer.Bytes(),
-                                               elem.source,
+                                               &elem.endpoint,
                                        )
                                        if err != nil {
                                                logDebug.Println("Failed to send cookie reply:", err)
@@ -368,7 +363,11 @@ func (device *Device) RoutineHandshake() {
                                        continue
                                }
 
-                               if !device.ratelimiter.Allow(elem.source.IP) {
+                               // check ratelimiter
+
+                               if !device.ratelimiter.Allow(
+                                       elem.endpoint.DestinationIP(),
+                               ) {
                                        continue
                                }
                        }
@@ -399,8 +398,7 @@ func (device *Device) RoutineHandshake() {
                        if peer == nil {
                                logInfo.Println(
                                        "Recieved invalid initiation message from",
-                                       elem.source.IP.String(),
-                                       elem.source.Port,
+                                       elem.endpoint.DestinationToString(),
                                )
                                continue
                        }
@@ -414,7 +412,7 @@ func (device *Device) RoutineHandshake() {
                        // TODO: Discover destination address also, only update on change
 
                        peer.mutex.Lock()
-                       peer.endpoint = elem.source
+                       peer.endpoint = elem.endpoint
                        peer.mutex.Unlock()
 
                        // create response
@@ -460,8 +458,7 @@ func (device *Device) RoutineHandshake() {
                        if peer == nil {
                                logInfo.Println(
                                        "Recieved invalid response message from",
-                                       elem.source.IP.String(),
-                                       elem.source.Port,
+                                       elem.endpoint.DestinationToString(),
                                )
                                continue
                        }