]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
Moved endpoint into interface and simplified peer
authorMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 18 Nov 2017 22:34:02 +0000 (23:34 +0100)
committerMathias Hall-Andersen <mathias@hall-andersen.dk>
Sat, 18 Nov 2017 22:34:02 +0000 (23:34 +0100)
src/conn.go
src/conn_linux.go
src/device.go
src/peer.go
src/receive.go
src/uapi.go

index 3cf00ab5d5f52da0be66c63690f295ee672066bb..74bb075699cd71fa448f9bf6800db9d3edbc30be 100644 (file)
@@ -7,26 +7,28 @@ import (
        "net"
 )
 
-type UDPBind interface {
+/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
+ */
+type Bind interface {
        SetMark(value uint32) error
-       ReceiveIPv6(buff []byte, end *Endpoint) (int, error)
-       ReceiveIPv4(buff []byte, end *Endpoint) (int, error)
-       Send(buff []byte, end *Endpoint) error
+       ReceiveIPv6(buff []byte) (int, Endpoint, error)
+       ReceiveIPv4(buff []byte) (int, Endpoint, error)
+       Send(buff []byte, end Endpoint) error
        Close() error
 }
 
 /* An Endpoint maintains the source/destination caching for a peer
  *
- * dst : the remote address of a peer
+ * dst : the remote address of a peer ("endpoint" in uapi terminology)
  * src : the local address from which datagrams originate going to the peer
- *
  */
-type UDPEndpoint interface {
+type Endpoint interface {
        ClearSrc()           // clears the source address
        ClearDst()           // clears the destination address
        SrcToString() string // returns the local source address (ip:port)
        DstToString() string // returns the destination address (ip:port)
        DstToBytes() []byte  // used for mac2 cookie calculations
+       SetDst(string) error // used for manually setting the endpoint (uapi)
        DstIP() net.IP
        SrcIP() net.IP
 }
@@ -107,7 +109,9 @@ func UpdateUDPListener(device *Device) error {
 
                for _, peer := range device.peers {
                        peer.mutex.Lock()
-                       peer.endpoint.value.ClearSrc()
+                       if peer.endpoint != nil {
+                               peer.endpoint.ClearSrc()
+                       }
                        peer.mutex.Unlock()
                }
 
index fb576b1f9aed352b774ff4a972a9955ceeeb28c6..46f873fa5f86f7c1fe6fe53ae0ce55234a78026f 100644 (file)
@@ -21,22 +21,24 @@ import (
  * See e.g. https://github.com/golang/go/issues/17930
  * So this code is remains platform dependent.
  */
-
-type Endpoint struct {
+type NativeEndpoint struct {
        src unix.RawSockaddrInet6
        dst unix.RawSockaddrInet6
 }
 
-type IPv4Source struct {
-       src     unix.RawSockaddrInet4
-       Ifindex int32
-}
-
 type NativeBind struct {
        sock4 int
        sock6 int
 }
 
+var _ Endpoint = (*NativeEndpoint)(nil)
+var _ Bind = NativeBind{}
+
+type IPv4Source struct {
+       src     unix.RawSockaddrInet4
+       Ifindex int32
+}
+
 func htons(val uint16) uint16 {
        var out [unsafe.Sizeof(val)]byte
        binary.BigEndian.PutUint16(out[:], val)
@@ -48,7 +50,11 @@ func ntohs(val uint16) uint16 {
        return binary.BigEndian.Uint16((*tmp)[:])
 }
 
-func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
+func NewEndpoint() Endpoint {
+       return &NativeEndpoint{}
+}
+
+func CreateUDPBind(port uint16) (Bind, uint16, error) {
        var err error
        var bind NativeBind
 
@@ -99,28 +105,33 @@ func (bind NativeBind) Close() error {
        return err2
 }
 
-func (bind NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) {
-       return receive6(
+func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+       var end NativeEndpoint
+       n, err := receive6(
                bind.sock6,
                buff,
-               end,
+               &end,
        )
+       return n, &end, err
 }
 
-func (bind NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) {
-       return receive4(
+func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+       var end NativeEndpoint
+       n, err := receive4(
                bind.sock4,
                buff,
-               end,
+               &end,
        )
+       return n, &end, err
 }
 
-func (bind NativeBind) Send(buff []byte, end *Endpoint) error {
-       switch end.dst.Family {
+func (bind NativeBind) Send(buff []byte, end Endpoint) error {
+       nend := end.(*NativeEndpoint)
+       switch nend.dst.Family {
        case unix.AF_INET6:
-               return send6(bind.sock6, end, buff)
+               return send6(bind.sock6, nend, buff)
        case unix.AF_INET:
-               return send4(bind.sock4, end, buff)
+               return send4(bind.sock4, nend, buff)
        default:
                return errors.New("Unknown address family of destination")
        }
@@ -151,12 +162,12 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string {
        }
 }
 
-func (end *Endpoint) DstIP() net.IP {
-       switch end.dst.Family {
+func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
+       switch addr.Family {
        case unix.AF_INET6:
-               return end.dst.Addr[:]
+               return addr.Addr[:]
        case unix.AF_INET:
-               ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
+               ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
                return net.IPv4(
                        ptr.Addr[0],
                        ptr.Addr[1],
@@ -168,25 +179,33 @@ func (end *Endpoint) DstIP() net.IP {
        }
 }
 
-func (end *Endpoint) DstToBytes() []byte {
+func (end *NativeEndpoint) SrcIP() net.IP {
+       return rawAddrToIP(end.src)
+}
+
+func (end *NativeEndpoint) DstIP() net.IP {
+       return rawAddrToIP(end.dst)
+}
+
+func (end *NativeEndpoint) DstToBytes() []byte {
        ptr := unsafe.Pointer(&end.src)
        arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
        return arr[:]
 }
 
-func (end *Endpoint) SrcToString() string {
+func (end *NativeEndpoint) SrcToString() string {
        return sockaddrToString(end.src)
 }
 
-func (end *Endpoint) DstToString() string {
+func (end *NativeEndpoint) DstToString() string {
        return sockaddrToString(end.dst)
 }
 
-func (end *Endpoint) ClearDst() {
+func (end *NativeEndpoint) ClearDst() {
        end.dst = unix.RawSockaddrInet6{}
 }
 
-func (end *Endpoint) ClearSrc() {
+func (end *NativeEndpoint) ClearSrc() {
        end.src = unix.RawSockaddrInet6{}
 }
 
@@ -306,7 +325,7 @@ func create6(port uint16) (int, uint16, error) {
        return fd, uint16(addr.Port), err
 }
 
-func (end *Endpoint) SetDst(s string) error {
+func (end *NativeEndpoint) SetDst(s string) error {
        addr, err := parseEndpoint(s)
        if err != nil {
                return err
@@ -342,7 +361,7 @@ func (end *Endpoint) SetDst(s string) error {
        return errors.New("Failed to recognize IP address format")
 }
 
-func send6(sock int, end *Endpoint, buff []byte) error {
+func send6(sock int, end *NativeEndpoint, buff []byte) error {
 
        // construct message header
 
@@ -404,7 +423,7 @@ func send6(sock int, end *Endpoint, buff []byte) error {
        return errno
 }
 
-func send4(sock int, end *Endpoint, buff []byte) error {
+func send4(sock int, end *NativeEndpoint, buff []byte) error {
 
        // construct message header
 
@@ -470,7 +489,7 @@ func send4(sock int, end *Endpoint, buff []byte) error {
        return errno
 }
 
-func receive4(sock int, buff []byte, end *Endpoint) (int, error) {
+func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 
        // contruct message header
 
@@ -518,7 +537,7 @@ func receive4(sock int, buff []byte, end *Endpoint) (int, error) {
        return int(size), nil
 }
 
-func receive6(sock int, buff []byte, end *Endpoint) (int, error) {
+func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
 
        // contruct message header
 
index 0085ceea298266f35a25907dccf28a98ec51032b..76235bd7b0c90934f254a4021e68e9f1a35e470b 100644 (file)
@@ -22,9 +22,9 @@ type Device struct {
        }
        net struct {
                mutex  sync.RWMutex
-               bind   UDPBind // bind interface
-               port   uint16  // listening port
-               fwmark uint32  // mark value (0 = disabled)
+               bind   Bind   // bind interface
+               port   uint16 // listening port
+               fwmark uint32 // mark value (0 = disabled)
        }
        mutex        sync.RWMutex
        privateKey   NoisePrivateKey
index a98fc973fa7a5d3391fd40b38717d0abe6f81e3a..f3eb6c28518e74377715d259c04c4624103695b5 100644 (file)
@@ -15,11 +15,8 @@ type Peer struct {
        keyPairs                    KeyPairs
        handshake                   Handshake
        device                      *Device
-       endpoint                    struct {
-               set   bool     // has a known endpoint been discovered
-               value Endpoint // source / destination cache
-       }
-       stats struct {
+       endpoint                    Endpoint
+       stats                       struct {
                txBytes           uint64 // bytes send to peer (endpoint)
                rxBytes           uint64 // bytes received from peer
                lastHandshakeNano int64  // nano seconds since epoch
@@ -110,9 +107,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
 
        // reset endpoint
 
-       peer.endpoint.set = false
-       peer.endpoint.value.ClearDst()
-       peer.endpoint.value.ClearSrc()
+       peer.endpoint = nil
 
        // prepare queuing
 
@@ -143,16 +138,16 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
        defer peer.device.net.mutex.RUnlock()
        peer.mutex.RLock()
        defer peer.mutex.RUnlock()
-       if !peer.endpoint.set {
+       if peer.endpoint == nil {
                return errors.New("No known endpoint for peer")
        }
-       return peer.device.net.bind.Send(buffer, &peer.endpoint.value)
+       return peer.device.net.bind.Send(buffer, peer.endpoint)
 }
 
 /* Returns a short string identification for logging
  */
 func (peer *Peer) String() string {
-       if !peer.endpoint.set {
+       if peer.endpoint == nil {
                return fmt.Sprintf(
                        "peer(%d unknown %s)",
                        peer.id,
@@ -162,7 +157,7 @@ func (peer *Peer) String() string {
        return fmt.Sprintf(
                "peer(%d %s %s)",
                peer.id,
-               peer.endpoint.value.DstToString(),
+               peer.endpoint.DstToString(),
                base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
        )
 }
index b8b06f722409797502f86c43ed5f239653f841e2..27fdb8ac044a1b358c55579dbd0835293c5a9b7a 100644 (file)
@@ -93,7 +93,7 @@ func (device *Device) addToHandshakeQueue(
        }
 }
 
-func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) {
+func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) {
 
        logDebug := device.log.Debug
        logDebug.Println("Routine, receive incomming, IP version:", IP)
@@ -104,20 +104,21 @@ func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) {
 
                buffer := device.GetMessageBuffer()
 
-               var size int
-               var err error
+               var (
+                       err      error
+                       size     int
+                       endpoint Endpoint
+               )
 
                for {
 
                        // read next datagram
 
-                       var endpoint Endpoint
-
                        switch IP {
                        case ipv4.Version:
-                               size, err = bind.ReceiveIPv4(buffer[:], &endpoint)
+                               size, endpoint, err = bind.ReceiveIPv4(buffer[:])
                        case ipv6.Version:
-                               size, err = bind.ReceiveIPv6(buffer[:], &endpoint)
+                               size, endpoint, err = bind.ReceiveIPv6(buffer[:])
                        default:
                                return
                        }
@@ -339,10 +340,7 @@ func (device *Device) RoutineHandshake() {
 
                                        writer := bytes.NewBuffer(temp[:0])
                                        binary.Write(writer, binary.LittleEndian, reply)
-                                       device.net.bind.Send(
-                                               writer.Bytes(),
-                                               &elem.endpoint,
-                                       )
+                                       device.net.bind.Send(writer.Bytes(), elem.endpoint)
                                        if err != nil {
                                                logDebug.Println("Failed to send cookie reply:", err)
                                        }
@@ -395,8 +393,7 @@ func (device *Device) RoutineHandshake() {
                        // update endpoint
 
                        peer.mutex.Lock()
-                       peer.endpoint.set = true
-                       peer.endpoint.value = elem.endpoint
+                       peer.endpoint = elem.endpoint
                        peer.mutex.Unlock()
 
                        // create response
@@ -452,8 +449,7 @@ func (device *Device) RoutineHandshake() {
                        // update endpoint
 
                        peer.mutex.Lock()
-                       peer.endpoint.set = true
-                       peer.endpoint.value = elem.endpoint
+                       peer.endpoint = elem.endpoint
                        peer.mutex.Unlock()
 
                        logDebug.Println("Received handshake initation from", peer)
@@ -527,8 +523,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                        // update endpoint
 
                        peer.mutex.Lock()
-                       peer.endpoint.set = true
-                       peer.endpoint.value = elem.endpoint
+                       peer.endpoint = elem.endpoint
                        peer.mutex.Unlock()
 
                        // check for keep-alive
index e1d092953fea020a286526811885ac40be164b69..670ecc4e8296aa71218436876a833361cb202279 100644 (file)
@@ -53,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
                        defer peer.mutex.RUnlock()
                        send("public_key=" + peer.handshake.remoteStatic.ToHex())
                        send("preshared_key=" + peer.handshake.presharedKey.ToHex())
-                       if peer.endpoint.set {
-                               send("endpoint=" + peer.endpoint.value.DstToString())
+                       if peer.endpoint != nil {
+                               send("endpoint=" + peer.endpoint.DstToString())
                        }
 
                        nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
@@ -255,17 +255,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
 
                        case "endpoint":
 
-                               // set endpoint destination and reset handshake timer
+                               // set endpoint destination
+
+                               err := func() error {
+                                       peer.mutex.Lock()
+                                       defer peer.mutex.Unlock()
+
+                                       endpoint := NewEndpoint()
+                                       if err := endpoint.SetDst(value); err != nil {
+                                               return err
+                                       }
+                                       peer.endpoint = endpoint
+                                       signalSend(peer.signal.handshakeReset)
+                                       return nil
+                               }()
 
-                               peer.mutex.Lock()
-                               err := peer.endpoint.value.SetDst(value)
-                               peer.endpoint.set = (err == nil)
-                               peer.mutex.Unlock()
                                if err != nil {
                                        logError.Println("Failed to set endpoint:", value)
                                        return &IPCError{Code: ipcErrorInvalid}
                                }
-                               signalSend(peer.signal.handshakeReset)
 
                        case "persistent_keepalive_interval":