]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
conn: fix StdNetBind fallback on Windows
authorJordan Whited <jordan@tailscale.com>
Mon, 6 Mar 2023 23:58:32 +0000 (15:58 -0800)
committerJason A. Donenfeld <Jason@zx2c4.com>
Fri, 10 Mar 2023 13:52:24 +0000 (14:52 +0100)
If RIO is unavailable, NewWinRingBind() falls back to StdNetBind.
StdNetBind uses x/net/ipv{4,6}.PacketConn for sending and receiving
datagrams, specifically via the {Read,Write}Batch methods.
These methods are unimplemented on Windows and will return runtime
errors as a result. Additionally, only Linux benefits from these
x/net types for reading and writing, so we update StdNetBind to fall
back to the standard library net package for all platforms other than
Linux.

Reviewed-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
conn/bind_std.go
conn/bind_std_test.go [new file with mode: 0644]

index b9da4c3afef702bc32964cf9a4e6a35be8e8e25f..a842b12e4df47f4bbdfc7cbc7f8825bf5c7d7dec 100644 (file)
@@ -10,6 +10,7 @@ import (
        "errors"
        "net"
        "net/netip"
+       "runtime"
        "strconv"
        "sync"
        "syscall"
@@ -22,16 +23,21 @@ var (
        _ Bind = (*StdNetBind)(nil)
 )
 
-// StdNetBind implements Bind for all platforms except Windows.
+// StdNetBind implements Bind for all platforms. While Windows has its own Bind
+// (see bind_windows.go), it may fall back to StdNetBind.
+// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
+// methods for sending and receiving multiple datagrams per-syscall. See the
+// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
 type StdNetBind struct {
-       mu           sync.Mutex // protects following fields
-       ipv4         *net.UDPConn
-       ipv6         *net.UDPConn
-       blackhole4   bool
-       blackhole6   bool
-       ipv4PC       *ipv4.PacketConn
-       ipv6PC       *ipv6.PacketConn
-       udpAddrPool  sync.Pool
+       mu         sync.Mutex // protects following fields
+       ipv4       *net.UDPConn
+       ipv6       *net.UDPConn
+       blackhole4 bool
+       blackhole6 bool
+       ipv4PC     *ipv4.PacketConn // will be nil on non-Linux
+       ipv6PC     *ipv6.PacketConn // will be nil on non-Linux
+
+       udpAddrPool  sync.Pool // following fields are not guarded by mu
        ipv4MsgsPool sync.Pool
        ipv6MsgsPool sync.Pool
 }
@@ -154,6 +160,8 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
 again:
        port := int(uport)
        var v4conn, v6conn *net.UDPConn
+       var v4pc *ipv4.PacketConn
+       var v6pc *ipv6.PacketConn
 
        v4conn, port, err = listenNet("udp4", port)
        if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
@@ -173,63 +181,92 @@ again:
        }
        var fns []ReceiveFunc
        if v4conn != nil {
-               fns = append(fns, s.receiveIPv4)
+               if runtime.GOOS == "linux" {
+                       v4pc = ipv4.NewPacketConn(v4conn)
+                       s.ipv4PC = v4pc
+               }
+               fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
                s.ipv4 = v4conn
        }
        if v6conn != nil {
-               fns = append(fns, s.receiveIPv6)
+               if runtime.GOOS == "linux" {
+                       v6pc = ipv6.NewPacketConn(v6conn)
+                       s.ipv6PC = v6pc
+               }
+               fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
                s.ipv6 = v6conn
        }
        if len(fns) == 0 {
                return nil, 0, syscall.EAFNOSUPPORT
        }
 
-       s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
-       s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
-
        return fns, uint16(port), nil
 }
 
-func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
-       msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
-       defer s.ipv4MsgsPool.Put(msgs)
-       for i := range buffs {
-               (*msgs)[i].Buffers[0] = buffs[i]
-       }
-       numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
-       if err != nil {
-               return 0, err
-       }
-       for i := 0; i < numMsgs; i++ {
-               msg := &(*msgs)[i]
-               sizes[i] = msg.N
-               addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
-               ep := asEndpoint(addrPort)
-               getSrcFromControl(msg.OOB, ep)
-               eps[i] = ep
+func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
+       return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+               msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
+               defer s.ipv4MsgsPool.Put(msgs)
+               for i := range buffs {
+                       (*msgs)[i].Buffers[0] = buffs[i]
+               }
+               var numMsgs int
+               if runtime.GOOS == "linux" {
+                       numMsgs, err = pc.ReadBatch(*msgs, 0)
+                       if err != nil {
+                               return 0, err
+                       }
+               } else {
+                       msg := &(*msgs)[0]
+                       msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+                       if err != nil {
+                               return 0, err
+                       }
+                       numMsgs = 1
+               }
+               for i := 0; i < numMsgs; i++ {
+                       msg := &(*msgs)[i]
+                       sizes[i] = msg.N
+                       addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+                       ep := asEndpoint(addrPort)
+                       getSrcFromControl(msg.OOB, ep)
+                       eps[i] = ep
+               }
+               return numMsgs, nil
        }
-       return numMsgs, nil
 }
 
-func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
-       msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
-       defer s.ipv6MsgsPool.Put(msgs)
-       for i := range buffs {
-               (*msgs)[i].Buffers[0] = buffs[i]
-       }
-       numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
-       if err != nil {
-               return 0, err
-       }
-       for i := 0; i < numMsgs; i++ {
-               msg := &(*msgs)[i]
-               sizes[i] = msg.N
-               addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
-               ep := asEndpoint(addrPort)
-               getSrcFromControl(msg.OOB, ep)
-               eps[i] = ep
+func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
+       return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+               msgs := s.ipv4MsgsPool.Get().(*[]ipv6.Message)
+               defer s.ipv4MsgsPool.Put(msgs)
+               for i := range buffs {
+                       (*msgs)[i].Buffers[0] = buffs[i]
+               }
+               var numMsgs int
+               if runtime.GOOS == "linux" {
+                       numMsgs, err = pc.ReadBatch(*msgs, 0)
+                       if err != nil {
+                               return 0, err
+                       }
+               } else {
+                       msg := &(*msgs)[0]
+                       msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
+                       if err != nil {
+                               return 0, err
+                       }
+                       numMsgs = 1
+               }
+               for i := 0; i < numMsgs; i++ {
+                       msg := &(*msgs)[i]
+                       sizes[i] = msg.N
+                       addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
+                       ep := asEndpoint(addrPort)
+                       getSrcFromControl(msg.OOB, ep)
+                       eps[i] = ep
+               }
+               return numMsgs, nil
        }
-       return numMsgs, nil
 }
 
 // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
@@ -246,10 +283,12 @@ func (s *StdNetBind) Close() error {
        if s.ipv4 != nil {
                err1 = s.ipv4.Close()
                s.ipv4 = nil
+               s.ipv4PC = nil
        }
        if s.ipv6 != nil {
                err2 = s.ipv6.Close()
                s.ipv6 = nil
+               s.ipv6PC = nil
        }
        s.blackhole4 = false
        s.blackhole6 = false
@@ -263,11 +302,18 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
        s.mu.Lock()
        blackhole := s.blackhole4
        conn := s.ipv4
+       var (
+               pc4 *ipv4.PacketConn
+               pc6 *ipv6.PacketConn
+       )
        is6 := false
        if endpoint.DstIP().Is6() {
                blackhole = s.blackhole6
                conn = s.ipv6
+               pc6 = s.ipv6PC
                is6 = true
+       } else {
+               pc4 = s.ipv4PC
        }
        s.mu.Unlock()
 
@@ -278,13 +324,13 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
                return syscall.EAFNOSUPPORT
        }
        if is6 {
-               return s.send6(s.ipv6PC, endpoint, buffs)
+               return s.send6(conn, pc6, endpoint, buffs)
        } else {
-               return s.send4(s.ipv4PC, endpoint, buffs)
+               return s.send4(conn, pc4, endpoint, buffs)
        }
 }
 
-func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
+func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
        ua := s.udpAddrPool.Get().(*net.UDPAddr)
        as4 := ep.DstIP().As4()
        copy(ua.IP, as4[:])
@@ -301,19 +347,28 @@ func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) e
                err   error
                start int
        )
-       for {
-               n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
-               if err != nil || n == len((*msgs)[start:len(buffs)]) {
-                       break
+       if runtime.GOOS == "linux" {
+               for {
+                       n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
+                       if err != nil || n == len((*msgs)[start:len(buffs)]) {
+                               break
+                       }
+                       start += n
+               }
+       } else {
+               for i, buff := range buffs {
+                       _, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
+                       if err != nil {
+                               break
+                       }
                }
-               start += n
        }
        s.udpAddrPool.Put(ua)
        s.ipv4MsgsPool.Put(msgs)
        return err
 }
 
-func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
+func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
        ua := s.udpAddrPool.Get().(*net.UDPAddr)
        as16 := ep.DstIP().As16()
        copy(ua.IP, as16[:])
@@ -330,12 +385,21 @@ func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) e
                err   error
                start int
        )
-       for {
-               n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
-               if err != nil || n == len((*msgs)[start:len(buffs)]) {
-                       break
+       if runtime.GOOS == "linux" {
+               for {
+                       n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
+                       if err != nil || n == len((*msgs)[start:len(buffs)]) {
+                               break
+                       }
+                       start += n
+               }
+       } else {
+               for i, buff := range buffs {
+                       _, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
+                       if err != nil {
+                               break
+                       }
                }
-               start += n
        }
        s.udpAddrPool.Put(ua)
        s.ipv6MsgsPool.Put(msgs)
diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go
new file mode 100644 (file)
index 0000000..76afa30
--- /dev/null
@@ -0,0 +1,22 @@
+package conn
+
+import "testing"
+
+func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
+       bind := NewStdNetBind().(*StdNetBind)
+       fns, _, err := bind.Open(0)
+       if err != nil {
+               t.Fatal(err)
+       }
+       bind.Close()
+       buffs := make([][]byte, 1)
+       buffs[0] = make([]byte, 1)
+       sizes := make([]int, 1)
+       eps := make([]Endpoint, 1)
+       for _, fn := range fns {
+               // The ReceiveFuncs must not access conn-related fields on StdNetBind
+               // unguarded. Close() nils the conn-related fields resulting in a panic
+               // if they violate the mutex.
+               fn(buffs, sizes, eps)
+       }
+}