]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
conn: fix getSrcFromControl() iteration
authorJordan Whited <jordan@tailscale.com>
Wed, 15 Mar 2023 03:28:07 +0000 (20:28 -0700)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 16 Mar 2023 16:45:41 +0000 (17:45 +0100)
We only expect a single control message in the normal case, but this
would loop infinitely if there were more.

Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
conn/sticky_linux.go
conn/sticky_linux_test.go

index 342e7396d25d0fd4732b7b0b71aad7209b8bc377..278eb195a344b95b9a54b3246d9ce802bb7398e3 100644 (file)
@@ -25,7 +25,7 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
        )
 
        for len(rem) > unix.SizeofCmsghdr {
-               hdr, data, rem, err = unix.ParseOneSocketControlMessage(control)
+               hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
                if err != nil {
                        return
                }
index 672b67e4c80f69c00874242fb75a79f8e50827dd..503c3427f84a5a907fc83d9bdfa0df00da86ea0b 100644 (file)
@@ -150,6 +150,34 @@ func Test_getSrcFromControl(t *testing.T) {
                        t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
                }
        })
+       t.Run("Multiple", func(t *testing.T) {
+               zeroControl := make([]byte, unix.CmsgSpace(0))
+               zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
+               zeroHdr.SetLen(unix.CmsgLen(0))
+
+               control := make([]byte, srcControlSize)
+               hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+               hdr.Level = unix.IPPROTO_IP
+               hdr.Type = unix.IP_PKTINFO
+               hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+               info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+               info.Spec_dst = [4]byte{127, 0, 0, 1}
+               info.Ifindex = 5
+
+               combined := make([]byte, 0)
+               combined = append(combined, zeroControl...)
+               combined = append(combined, control...)
+
+               ep := &StdNetEndpoint{}
+               getSrcFromControl(combined, ep)
+
+               if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
+                       t.Errorf("unexpected address: %v", ep.src.Addr)
+               }
+               if ep.src.ifidx != 5 {
+                       t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+               }
+       })
 }
 
 func Test_listenConfig(t *testing.T) {