]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
all: make conn.Bind.Open return a slice of receive functions
authorJosh Bleecher Snyder <josharian@gmail.com>
Wed, 31 Mar 2021 20:55:18 +0000 (13:55 -0700)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 2 Apr 2021 17:07:08 +0000 (11:07 -0600)
Instead of hard-coding exactly two sources from which
to receive packets (an IPv4 source and an IPv6 source),
allow the conn.Bind to specify a set of sources.

Beneficial consequences:

* If there's no IPv6 support on a system,
  conn.Bind.Open can choose not to return a receive function for it,
  which is simpler than tracking that state in the bind.
  This simplification removes existing data races from both
  conn.StdNetBind and bindtest.ChannelBind.
* If there are more than two sources on a system,
  the conn.Bind no longer needs to add a separate muxing layer.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
conn/bind_linux.go
conn/bind_std.go
conn/bind_windows.go
conn/bindtest/bindtest.go
conn/conn.go
device/device.go
device/receive.go

index 70ea609d63cb519f7776757f2ea1021af2166122..9eec384aa73ac6301892567067ac954ae225aee6 100644 (file)
@@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
 
 // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
 type LinuxSocketBind struct {
-       sock4    int
-       sock6    int
-       lastMark uint32
-       closing  sync.RWMutex
+       // mu guards sock4 and sock6 and the associated fds.
+       // As long as someone holds mu (read or write), the associated fds are valid.
+       mu    sync.RWMutex
+       sock4 int
+       sock6 int
 }
 
 func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
@@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
        return nil, errors.New("invalid IP address")
 }
 
-func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) {
+func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
+       bind.mu.Lock()
+       defer bind.mu.Unlock()
+
        var err error
        var newPort uint16
        var tries int
 
        if bind.sock4 != -1 || bind.sock6 != -1 {
-               return 0, ErrBindAlreadyOpen
+               return nil, 0, ErrBindAlreadyOpen
        }
 
        originalPort := port
 
 again:
        port = originalPort
+       var sock4, sock6 int
        // Attempt ipv6 bind, update port if successful.
-       bind.sock6, newPort, err = create6(port)
+       sock6, newPort, err = create6(port)
        if err != nil {
-               if err != syscall.EAFNOSUPPORT {
-                       return 0, err
+               if !errors.Is(err, syscall.EAFNOSUPPORT) {
+                       return nil, 0, err
                }
        } else {
                port = newPort
        }
 
        // Attempt ipv4 bind, update port if successful.
-       bind.sock4, newPort, err = create4(port)
+       sock4, newPort, err = create4(port)
        if err != nil {
-               if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 {
-                       unix.Close(bind.sock6)
+               if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
+                       unix.Close(sock6)
                        tries++
                        goto again
                }
-               if err != syscall.EAFNOSUPPORT {
-                       unix.Close(bind.sock6)
-                       return 0, err
+               if !errors.Is(err, syscall.EAFNOSUPPORT) {
+                       unix.Close(sock6)
+                       return nil, 0, err
                }
        } else {
                port = newPort
        }
 
-       if bind.sock4 == -1 && bind.sock6 == -1 {
-               return 0, syscall.EAFNOSUPPORT
+       var fns []ReceiveFunc
+       if sock4 != -1 {
+               fns = append(fns, makeReceiveIPv4(sock4))
+               bind.sock4 = sock4
+       }
+       if sock6 != -1 {
+               fns = append(fns, makeReceiveIPv6(sock6))
+               bind.sock6 = sock6
+       }
+       if len(fns) == 0 {
+               return nil, 0, syscall.EAFNOSUPPORT
        }
-       return port, nil
+       return fns, port, nil
 }
 
 func (bind *LinuxSocketBind) SetMark(value uint32) error {
-       bind.closing.RLock()
-       defer bind.closing.RUnlock()
+       bind.mu.RLock()
+       defer bind.mu.RUnlock()
 
        if bind.sock6 != -1 {
                err := unix.SetsockoptInt(
@@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
                }
        }
 
-       bind.lastMark = value
        return nil
 }
 
 func (bind *LinuxSocketBind) Close() error {
-       var err1, err2 error
-       bind.closing.RLock()
+       // Take a readlock to shut down the sockets...
+       bind.mu.RLock()
        if bind.sock6 != -1 {
                unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
        }
        if bind.sock4 != -1 {
                unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
        }
-       bind.closing.RUnlock()
-       bind.closing.Lock()
+       bind.mu.RUnlock()
+       // ...and a write lock to close the fd.
+       // This ensures that no one else is using the fd.
+       bind.mu.Lock()
+       defer bind.mu.Unlock()
+       var err1, err2 error
        if bind.sock6 != -1 {
                err1 = unix.Close(bind.sock6)
                bind.sock6 = -1
@@ -200,7 +217,6 @@ func (bind *LinuxSocketBind) Close() error {
                err2 = unix.Close(bind.sock4)
                bind.sock4 = -1
        }
-       bind.closing.Unlock()
 
        if err1 != nil {
                return err1
@@ -208,46 +224,29 @@ func (bind *LinuxSocketBind) Close() error {
        return err2
 }
 
-func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
-       bind.closing.RLock()
-       defer bind.closing.RUnlock()
-
-       var end LinuxSocketEndpoint
-       if bind.sock6 == -1 {
-               return 0, nil, net.ErrClosed
+func makeReceiveIPv6(sock int) ReceiveFunc {
+       return func(buff []byte) (int, Endpoint, error) {
+               var end LinuxSocketEndpoint
+               n, err := receive6(sock, buff, &end)
+               return n, &end, err
        }
-       n, err := receive6(
-               bind.sock6,
-               buff,
-               &end,
-       )
-       return n, &end, err
 }
 
-func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
-       bind.closing.RLock()
-       defer bind.closing.RUnlock()
-
-       var end LinuxSocketEndpoint
-       if bind.sock4 == -1 {
-               return 0, nil, net.ErrClosed
+func makeReceiveIPv4(sock int) ReceiveFunc {
+       return func(buff []byte) (int, Endpoint, error) {
+               var end LinuxSocketEndpoint
+               n, err := receive4(sock, buff, &end)
+               return n, &end, err
        }
-       n, err := receive4(
-               bind.sock4,
-               buff,
-               &end,
-       )
-       return n, &end, err
 }
 
 func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
-       bind.closing.RLock()
-       defer bind.closing.RUnlock()
-
        nend, ok := end.(*LinuxSocketEndpoint)
        if !ok {
                return ErrWrongEndpointType
        }
+       bind.mu.RLock()
+       defer bind.mu.RUnlock()
        if !nend.isV6 {
                if bind.sock4 == -1 {
                        return net.ErrClosed
index f8b8a1b055ed1ea94864615fd65b54fc77818d78..52617793bd911b2c8c761890349993ed21963374 100644 (file)
@@ -8,6 +8,7 @@ package conn
 import (
        "errors"
        "net"
+       "sync"
        "syscall"
 )
 
@@ -16,6 +17,7 @@ import (
 // It uses the Go's net package to implement networking.
 // See LinuxSocketBind for a proper implementation on the Linux platform.
 type StdNetBind struct {
+       mu         sync.Mutex // protects following fields
        ipv4       *net.UDPConn
        ipv6       *net.UDPConn
        blackhole4 bool
@@ -81,12 +83,15 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
        return conn, uaddr.Port, nil
 }
 
-func (bind *StdNetBind) Open(uport uint16) (uint16, error) {
+func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+       bind.mu.Lock()
+       defer bind.mu.Unlock()
+
        var err error
        var tries int
 
        if bind.ipv4 != nil || bind.ipv6 != nil {
-               return 0, ErrBindAlreadyOpen
+               return nil, 0, ErrBindAlreadyOpen
        }
 
        // Attempt to open ipv4 and ipv6 listeners on the same port.
@@ -97,7 +102,7 @@ again:
 
        ipv4, port, err = listenNet("udp4", port)
        if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
-               return 0, err
+               return nil, 0, err
        }
 
        // Listen on the same port as we're using for ipv4.
@@ -109,17 +114,27 @@ again:
        }
        if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
                ipv4.Close()
-               return 0, err
+               return nil, 0, err
        }
-       if ipv4 == nil && ipv6 == nil {
-               return 0, syscall.EAFNOSUPPORT
+       var fns []ReceiveFunc
+       if ipv4 != nil {
+               fns = append(fns, makeReceiveFunc(ipv4, true))
+               bind.ipv4 = ipv4
        }
-       bind.ipv4 = ipv4
-       bind.ipv6 = ipv6
-       return uint16(port), nil
+       if ipv6 != nil {
+               fns = append(fns, makeReceiveFunc(ipv6, false))
+               bind.ipv6 = ipv6
+       }
+       if len(fns) == 0 {
+               return nil, 0, syscall.EAFNOSUPPORT
+       }
+       return fns, uint16(port), nil
 }
 
 func (bind *StdNetBind) Close() error {
+       bind.mu.Lock()
+       defer bind.mu.Unlock()
+
        var err1, err2 error
        if bind.ipv4 != nil {
                err1 = bind.ipv4.Close()
@@ -137,23 +152,14 @@ func (bind *StdNetBind) Close() error {
        return err2
 }
 
-func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
-       if bind.ipv4 == nil {
-               return 0, nil, syscall.EAFNOSUPPORT
+func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc {
+       return func(buff []byte) (int, Endpoint, error) {
+               n, endpoint, err := conn.ReadFromUDP(buff)
+               if isIPv4 && endpoint != nil {
+                       endpoint.IP = endpoint.IP.To4()
+               }
+               return n, (*StdNetEndpoint)(endpoint), err
        }
-       n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
-       if endpoint != nil {
-               endpoint.IP = endpoint.IP.To4()
-       }
-       return n, (*StdNetEndpoint)(endpoint), err
-}
-
-func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
-       if bind.ipv6 == nil {
-               return 0, nil, syscall.EAFNOSUPPORT
-       }
-       n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
-       return n, (*StdNetEndpoint)(endpoint), err
 }
 
 func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
@@ -162,15 +168,16 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
        if !ok {
                return ErrWrongEndpointType
        }
-       var conn *net.UDPConn
-       var blackhole bool
-       if nend.IP.To4() != nil {
-               blackhole = bind.blackhole4
-               conn = bind.ipv4
-       } else {
+
+       bind.mu.Lock()
+       blackhole := bind.blackhole4
+       conn := bind.ipv4
+       if nend.IP.To4() == nil {
                blackhole = bind.blackhole6
                conn = bind.ipv6
        }
+       bind.mu.Unlock()
+
        if blackhole {
                return nil
        }
index 1e2712eaa22cf32a5bd461bd5e7854c62c92fa4d..6cabee19c7d12485f2261bed710df26c0c13cde2 100644 (file)
@@ -266,7 +266,7 @@ func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sock
        return sa, nil
 }
 
-func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
+func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
        bind.mu.Lock()
        defer bind.mu.Unlock()
        defer func() {
@@ -275,30 +275,30 @@ func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
                }
        }()
        if atomic.LoadUint32(&bind.isOpen) != 0 {
-               return 0, ErrBindAlreadyOpen
+               return nil, 0, ErrBindAlreadyOpen
        }
        var sa windows.Sockaddr
        sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
        if err != nil {
-               return 0, err
+               return nil, 0, err
        }
        sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
        if err != nil {
-               return 0, err
+               return nil, 0, err
        }
        selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
        for i := 0; i < packetsPerRing; i++ {
                err = bind.v4.InsertReceiveRequest()
                if err != nil {
-                       return 0, err
+                       return nil, 0, err
                }
                err = bind.v6.InsertReceiveRequest()
                if err != nil {
-                       return 0, err
+                       return nil, 0, err
                }
        }
        atomic.StoreUint32(&bind.isOpen, 1)
-       return
+       return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
 }
 
 func (bind *WinRingBind) Close() error {
@@ -395,13 +395,13 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
        return n, &ep, nil
 }
 
-func (bind *WinRingBind) ReceiveIPv4(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
        bind.mu.RLock()
        defer bind.mu.RUnlock()
        return bind.v4.Receive(buf, &bind.isOpen)
 }
 
-func (bind *WinRingBind) ReceiveIPv6(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
        bind.mu.RLock()
        defer bind.mu.RUnlock()
        return bind.v6.Receive(buf, &bind.isOpen)
@@ -482,6 +482,8 @@ func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
 }
 
 func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+       bind.mu.Lock()
+       defer bind.mu.Unlock()
        sysconn, err := bind.ipv4.SyscallConn()
        if err != nil {
                return err
@@ -500,6 +502,8 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
 }
 
 func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+       bind.mu.Lock()
+       defer bind.mu.Unlock()
        sysconn, err := bind.ipv6.SyscallConn()
        if err != nil {
                return err
index ad8fa05efe7b355c6e250f34e8081b47b355458c..7d43fb30b5485d8e513fe471564d4704c250bb2d 100644 (file)
@@ -65,12 +65,14 @@ func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
 
 func (c ChannelEndpoint) SrcIP() net.IP { return nil }
 
-func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) {
+func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
        c.closeSignal = make(chan bool)
+       fns = append(fns, c.makeReceiveFunc(*c.rx4))
+       fns = append(fns, c.makeReceiveFunc(*c.rx6))
        if rand.Uint32()&1 == 0 {
-               return uint16(c.source4), nil
+               return fns, uint16(c.source4), nil
        } else {
-               return uint16(c.source6), nil
+               return fns, uint16(c.source6), nil
        }
 }
 
@@ -87,21 +89,14 @@ func (c *ChannelBind) Close() error {
 
 func (c *ChannelBind) SetMark(mark uint32) error { return nil }
 
-func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) {
-       select {
-       case <-c.closeSignal:
-               return 0, nil, net.ErrClosed
-       case rx := <-*c.rx6:
-               return copy(b, rx), c.target6, nil
-       }
-}
-
-func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
-       select {
-       case <-c.closeSignal:
-               return 0, nil, net.ErrClosed
-       case rx := <-*c.rx4:
-               return copy(b, rx), c.target4, nil
+func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
+       return func(b []byte) (n int, ep conn.Endpoint, err error) {
+               select {
+               case <-c.closeSignal:
+                       return 0, nil, net.ErrClosed
+               case rx := <-ch:
+                       return copy(b, rx), c.target6, nil
+               }
        }
 }
 
index 6fd232f1b96daaf09fcddb58d3a715099da41211..3c7fcd0098c10c12349f1a6d50706313e039544e 100644 (file)
@@ -12,6 +12,11 @@ import (
        "strings"
 )
 
+// A ReceiveFunc receives a single inbound packet from the network.
+// It writes the data into b. n is the length of the packet.
+// ep is the remote endpoint.
+type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
+
 // A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
 //
 // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
@@ -19,23 +24,17 @@ import (
 type Bind interface {
        // Open puts the Bind into a listening state on a given port and reports the actual
        // port that it bound to. Passing zero results in a random selection.
-       Open(port uint16) (actualPort uint16, err error)
+       // fns is the set of functions that will be called to receive packets.
+       Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
 
        // Close closes the Bind listener.
+       // All fns returned by Open must return net.ErrClosed after a call to Close.
        Close() error
 
        // SetMark sets the mark for each packet sent through this Bind.
        // This mark is passed to the kernel as the socket option SO_MARK.
        SetMark(mark uint32) error
 
-       // ReceiveIPv6 reads an IPv6 UDP packet into b.  It reports the number of bytes read,
-       // n, the packet source address ep, and any error.
-       ReceiveIPv6(b []byte) (n int, ep Endpoint, err error)
-
-       // ReceiveIPv4 reads an IPv4 UDP packet into b. It reports the number of bytes read,
-       // n, the packet source address ep, and any error.
-       ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
-
        // Send writes a packet b to address ep.
        Send(b []byte, ep Endpoint) error
 
index 1e32db67d13afa89f01d4bde3a7f335e3d38f309..a635e687cdfe2b549cc2bb0e80d30288f6c92f16 100644 (file)
@@ -11,9 +11,6 @@ import (
        "sync/atomic"
        "time"
 
-       "golang.org/x/net/ipv4"
-       "golang.org/x/net/ipv6"
-
        "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/ratelimiter"
        "golang.zx2c4.com/wireguard/rwcancel"
@@ -468,8 +465,9 @@ func (device *Device) BindUpdate() error {
 
        // bind to new port
        var err error
+       var recvFns []conn.ReceiveFunc
        netc := &device.net
-       netc.port, err = netc.bind.Open(netc.port)
+       recvFns, netc.port, err = netc.bind.Open(netc.port)
        if err != nil {
                netc.port = 0
                return err
@@ -501,11 +499,12 @@ func (device *Device) BindUpdate() error {
        device.peers.RUnlock()
 
        // start receiving routines
-       device.net.stopping.Add(2)
-       device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
-       device.queue.handshake.wg.Add(2)  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
-       go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
-       go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
+       device.net.stopping.Add(len(recvFns))
+       device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
+       device.queue.handshake.wg.Add(len(recvFns))  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+       for _, fn := range recvFns {
+               go device.RoutineReceiveIncoming(fn)
+       }
 
        device.log.Verbosef("UDP bind has been updated")
        return nil
index 5ddb66c015f1ae2aabb174ea786d11d80a1cd1f0..fa5c0a603f4a57be7ec355270ae86ee5b347a2a0 100644 (file)
@@ -68,15 +68,15 @@ func (peer *Peer) keepKeyFreshReceiving() {
  * Every time the bind is updated a new routine is started for
  * IPv4 and IPv6 (separately)
  */
-func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
+func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
        defer func() {
-               device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
+               device.log.Verbosef("Routine: receive incoming %p - stopped", recv)
                device.queue.decryption.wg.Done()
                device.queue.handshake.wg.Done()
                device.net.stopping.Done()
        }()
 
-       device.log.Verbosef("Routine: receive incoming IPv%d - started", IP)
+       device.log.Verbosef("Routine: receive incoming %p - started", recv)
 
        // receive datagrams until conn is closed
 
@@ -90,14 +90,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
        )
 
        for {
-               switch IP {
-               case ipv4.Version:
-                       size, endpoint, err = bind.ReceiveIPv4(buffer[:])
-               case ipv6.Version:
-                       size, endpoint, err = bind.ReceiveIPv6(buffer[:])
-               default:
-                       panic("invalid IP version")
-               }
+               size, endpoint, err = recv(buffer[:])
 
                if err != nil {
                        device.PutMessageBuffer(buffer)