]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
fix addressing and deadlines
authortqbf <thomas@fly.io>
Wed, 2 Feb 2022 18:14:39 +0000 (12:14 -0600)
committertqbf <thomas@fly.io>
Wed, 2 Feb 2022 18:14:39 +0000 (12:14 -0600)
- setting a now deadline unblocks a blocked read
- setting a specific deadline over a previous deadline honors
  the new one
- WriteTo will accept a net.Addr, not just a PingAddr

Signed-off-by: Thomas Ptacek <thomas@sockpuppet.org>
tun/netstack/tun.go

index f0e954b310db69bdc9a2d687fdd5082702338357..058aca50f464870c0aab66989f74b34efc59d14f 100644 (file)
@@ -17,6 +17,7 @@ import (
        "regexp"
        "strconv"
        "strings"
+       "sync"
        "time"
 
        "golang.zx2c4.com/go118/netip"
@@ -285,11 +286,13 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
 }
 
 type PingConn struct {
-       laddr    PingAddr
-       raddr    PingAddr
-       wq       waiter.Queue
-       ep       tcpip.Endpoint
-       deadline time.Time
+       laddr           PingAddr
+       raddr           PingAddr
+       wq              waiter.Queue
+       ep              tcpip.Endpoint
+       mu              sync.RWMutex
+       deadline        time.Time
+       deadlineBreaker chan struct{}
 }
 
 type PingAddr struct{ addr netip.Addr }
@@ -307,6 +310,20 @@ func (ia PingAddr) Network() string {
        return "ping"
 }
 
+func PingAddrFromAddr(addr net.Addr) (PingAddr, error) {
+       switch v := addr.(type) {
+       case PingAddr:
+               return v, nil
+
+       case *net.IPAddr:
+               nip := netip.AddrFromSlice(v.IP)
+               return PingAddr{nip}, nil
+
+       default:
+               return PingAddr{}, errors.New("wrong address format")
+       }
+}
+
 func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
        v6 := laddr.Is6() || raddr.Is6()
        bind := laddr.IsValid()
@@ -325,7 +342,10 @@ func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
                pn = ipv6.ProtocolNumber
        }
 
-       pc := &PingConn{laddr: PingAddr{laddr}}
+       pc := &PingConn{
+               laddr:           PingAddr{laddr},
+               deadlineBreaker: make(chan struct{}, 1),
+       }
 
        ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
        if tcpipErr != nil {
@@ -360,6 +380,7 @@ func (pc *PingConn) RemoteAddr() net.Addr {
 }
 
 func (pc *PingConn) Close() error {
+       close(pc.deadlineBreaker)
        pc.ep.Close()
        return nil
 }
@@ -369,8 +390,11 @@ func (pc *PingConn) SetWriteDeadline(t time.Time) error {
 }
 
 func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
-       ia, ok := addr.(PingAddr)
-       if !ok || !((ia.addr.Is4() && pc.laddr.addr.Is4()) || (ia.addr.Is6() && pc.laddr.addr.Is6())) {
+       ia, err := PingAddrFromAddr(addr)
+       if err != nil {
+               return 0, fmt.Errorf("ping write: %w", err)
+       }
+       if !((ia.addr.Is4() && pc.laddr.addr.Is4()) || (ia.addr.Is6() && pc.laddr.addr.Is6())) {
                return 0, fmt.Errorf("ping write: mismatched protocols")
        }
 
@@ -409,15 +433,32 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
        pc.wq.EventRegister(&e, waiter.EventIn)
        defer pc.wq.EventUnregister(&e)
 
-       deadline := pc.deadline
+       ready := false
 
-       if deadline.IsZero() {
-               <-notifyCh
-       } else {
-               select {
-               case <-time.NewTimer(deadline.Sub(time.Now())).C:
-                       return 0, nil, os.ErrDeadlineExceeded
-               case <-notifyCh:
+       for !ready {
+               pc.mu.RLock()
+               deadlineBreaker := pc.deadlineBreaker
+               deadline := pc.deadline
+               pc.mu.RUnlock()
+
+               if deadline.IsZero() {
+                       select {
+                       case <-deadlineBreaker:
+                       case <-notifyCh:
+                               ready = true
+                       }
+               } else {
+                       t := time.NewTimer(deadline.Sub(time.Now()))
+                       defer t.Stop()
+
+                       select {
+                       case <-t.C:
+                               return 0, nil, os.ErrDeadlineExceeded
+
+                       case <-deadlineBreaker:
+                       case <-notifyCh:
+                               ready = true
+                       }
                }
        }
 
@@ -452,6 +493,10 @@ func (pc *PingConn) SetDeadline(t time.Time) error {
 }
 
 func (pc *PingConn) SetReadDeadline(t time.Time) error {
+       pc.mu.Lock()
+       defer pc.mu.Unlock()
+       close(pc.deadlineBreaker)
+       pc.deadlineBreaker = make(chan struct{}, 1)
        pc.deadline = t
        return nil
 }