]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
conn: linux: protect read fds
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 20 May 2021 16:09:55 +0000 (18:09 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 20 May 2021 16:09:55 +0000 (18:09 +0200)
The -1 protection was removed and the wrong error was returned, causing
us to read from a bogus fd. As well, remove the useless closures that
aren't doing anything, since this is all synchronized anyway.

Fixes: 10533c3 ("all: make conn.Bind.Open return a slice of receive functions")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
conn/bind_linux.go

index c10d8b68e15fb04895f7ebf3cdeac4aa6f519523..7b970e65a0a3f0a609f6a4b74681bb35797aaa0f 100644 (file)
@@ -148,12 +148,12 @@ again:
 
        var fns []ReceiveFunc
        if sock4 != -1 {
-               fns = append(fns, bind.makeReceiveIPv4(sock4))
                bind.sock4 = sock4
+               fns = append(fns, bind.receiveIPv4)
        }
        if sock6 != -1 {
-               fns = append(fns, bind.makeReceiveIPv6(sock6))
                bind.sock6 = sock6
+               fns = append(fns, bind.receiveIPv6)
        }
        if len(fns) == 0 {
                return nil, 0, syscall.EAFNOSUPPORT
@@ -224,20 +224,26 @@ func (bind *LinuxSocketBind) Close() error {
        return err2
 }
 
-func (*LinuxSocketBind) makeReceiveIPv6(sock int) ReceiveFunc {
-       return func(buff []byte) (int, Endpoint, error) {
-               var end LinuxSocketEndpoint
-               n, err := receive6(sock, buff, &end)
-               return n, &end, err
+func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
+       bind.mu.RLock()
+       defer bind.mu.RUnlock()
+       if bind.sock4 == -1 {
+               return 0, nil, net.ErrClosed
        }
+       var end LinuxSocketEndpoint
+       n, err := receive4(bind.sock4, buf, &end)
+       return n, &end, err
 }
 
-func (*LinuxSocketBind) makeReceiveIPv4(sock int) ReceiveFunc {
-       return func(buff []byte) (int, Endpoint, error) {
-               var end LinuxSocketEndpoint
-               n, err := receive4(sock, buff, &end)
-               return n, &end, err
+func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
+       bind.mu.RLock()
+       defer bind.mu.RUnlock()
+       if bind.sock6 == -1 {
+               return 0, nil, net.ErrClosed
        }
+       var end LinuxSocketEndpoint
+       n, err := receive6(bind.sock6, buf, &end)
+       return n, &end, err
 }
 
 func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {