]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun/netstack: implement ICMP ping
authorThomas H. Ptacek <thomas@sockpuppet.org>
Mon, 31 Jan 2022 22:55:36 +0000 (16:55 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Wed, 2 Feb 2022 22:09:37 +0000 (23:09 +0100)
Provide a PacketConn interface for netstack's ICMP endpoint; netstack
currently only provides EchoRequest/EchoResponse ICMP support, so this
code exposes only an interface for doing ping.

Signed-off-by: Thomas Ptacek <thomas@sockpuppet.org>
[Jason: rework structure, match std go interfaces, add example code]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
tun/netstack/examples/ping_client.go [new file with mode: 0644]
tun/netstack/tun.go

diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go
new file mode 100644 (file)
index 0000000..cbd54b8
--- /dev/null
@@ -0,0 +1,76 @@
+//go:build ignore
+// +build ignore
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package main
+
+import (
+       "bytes"
+       "log"
+       "math/rand"
+       "time"
+
+       "golang.org/x/net/icmp"
+       "golang.org/x/net/ipv4"
+
+       "golang.zx2c4.com/go118/netip"
+       "golang.zx2c4.com/wireguard/conn"
+       "golang.zx2c4.com/wireguard/device"
+       "golang.zx2c4.com/wireguard/tun/netstack"
+)
+
+func main() {
+       tun, tnet, err := netstack.CreateNetTUN(
+               []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+               []netip.Addr{netip.MustParseAddr("8.8.8.8")},
+               1420)
+       if err != nil {
+               log.Panic(err)
+       }
+       dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
+       dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
+public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
+endpoint=163.172.161.0:12912
+allowed_ip=0.0.0.0/0
+`)
+       err = dev.Up()
+       if err != nil {
+               log.Panic(err)
+       }
+
+       socket, err := tnet.Dial("ping4", "zx2c4.com")
+       if err != nil {
+               log.Panic(err)
+       }
+       requestPing := icmp.Echo{
+               Seq:  rand.Intn(1 << 16),
+               Data: []byte("gopher burrow"),
+       }
+       icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
+       socket.SetReadDeadline(time.Now().Add(time.Second * 10))
+       start := time.Now()
+       _, err = socket.Write(icmpBytes)
+       if err != nil {
+               log.Panic(err)
+       }
+       n, err := socket.Read(icmpBytes[:])
+       if err != nil {
+               log.Panic(err)
+       }
+       replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
+       if err != nil {
+               log.Panic(err)
+       }
+       replyPing, ok := replyPacket.Body.(*icmp.Echo)
+       if !ok {
+               log.Panicf("invalid reply type: %v", replyPacket)
+       }
+       if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
+               log.Panicf("invalid ping reply: %v", replyPing)
+       }
+       log.Printf("Ping latency: %v", time.Since(start))
+}
index fb7f07d5484195f8af5abd4da81be216b8925e55..97983b415f2d63ed824cc04a519fece84df2ebd6 100644 (file)
@@ -14,8 +14,10 @@ import (
        "io"
        "net"
        "os"
+       "regexp"
        "strconv"
        "strings"
+       "sync"
        "time"
 
        "golang.zx2c4.com/go118/netip"
@@ -29,8 +31,10 @@ import (
        "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
        "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
+       "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
        "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
        "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+       "gvisor.dev/gvisor/pkg/waiter"
 )
 
 type netTun struct {
@@ -101,7 +105,7 @@ func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.Network
 func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
        opts := stack.Options{
                NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
-               TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+               TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
                HandleLocal:        true,
        }
        dev := &netTun{
@@ -270,6 +274,10 @@ func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, er
        return gonet.DialUDP(net.stack, lfa, rfa, pn)
 }
 
+func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
+       return net.DialUDPAddrPort(laddr, netip.AddrPort{})
+}
+
 func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
        var la, ra netip.AddrPort
        if laddr != nil {
@@ -281,6 +289,233 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
        return net.DialUDPAddrPort(la, ra)
 }
 
+func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
+       return net.DialUDP(laddr, nil)
+}
+
+type PingConn struct {
+       laddr           PingAddr
+       raddr           PingAddr
+       wq              waiter.Queue
+       ep              tcpip.Endpoint
+       mu              sync.RWMutex
+       deadline        time.Time
+       deadlineBreaker chan struct{}
+}
+
+type PingAddr struct{ addr netip.Addr }
+
+func (ia PingAddr) String() string {
+       return ia.addr.String()
+}
+
+func (ia PingAddr) Network() string {
+       if ia.addr.Is4() {
+               return "ping4"
+       } else if ia.addr.Is6() {
+               return "ping6"
+       }
+       return "ping"
+}
+
+func (ia PingAddr) Addr() netip.Addr {
+       return ia.addr
+}
+
+func PingAddrFromAddr(addr netip.Addr) *PingAddr {
+       return &PingAddr{addr}
+}
+
+func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
+       if !laddr.IsValid() && !raddr.IsValid() {
+               return nil, errors.New("ping dial: invalid address")
+       }
+       v6 := laddr.Is6() || raddr.Is6()
+       bind := laddr.IsValid()
+       if !bind {
+               if v6 {
+                       laddr = netip.IPv6Unspecified()
+               } else {
+                       laddr = netip.IPv4Unspecified()
+               }
+       }
+
+       tn := icmp.ProtocolNumber4
+       pn := ipv4.ProtocolNumber
+       if v6 {
+               tn = icmp.ProtocolNumber6
+               pn = ipv6.ProtocolNumber
+       }
+
+       pc := &PingConn{
+               laddr:           PingAddr{laddr},
+               deadlineBreaker: make(chan struct{}, 1),
+       }
+
+       ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
+       if tcpipErr != nil {
+               return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
+       }
+       pc.ep = ep
+
+       if bind {
+               fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
+               if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
+                       return nil, fmt.Errorf("ping bind: %s", tcpipErr)
+               }
+       }
+
+       if raddr.IsValid() {
+               pc.raddr = PingAddr{raddr}
+               fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
+               if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
+                       return nil, fmt.Errorf("ping connect: %s", tcpipErr)
+               }
+       }
+
+       return pc, nil
+}
+
+func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
+       return net.DialPingAddr(laddr, netip.Addr{})
+}
+
+func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
+       var la, ra netip.Addr
+       if laddr != nil {
+               la = laddr.addr
+       }
+       if raddr != nil {
+               ra = raddr.addr
+       }
+       return net.DialPingAddr(la, ra)
+}
+
+func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
+       var la netip.Addr
+       if laddr != nil {
+               la = laddr.addr
+       }
+       return net.ListenPingAddr(la)
+}
+
+func (pc *PingConn) LocalAddr() net.Addr {
+       return pc.laddr
+}
+
+func (pc *PingConn) RemoteAddr() net.Addr {
+       return pc.raddr
+}
+
+func (pc *PingConn) Close() error {
+       close(pc.deadlineBreaker)
+       pc.ep.Close()
+       return nil
+}
+
+func (pc *PingConn) SetWriteDeadline(t time.Time) error {
+       return errors.New("not implemented")
+}
+
+func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+       var na netip.Addr
+       switch v := addr.(type) {
+       case *PingAddr:
+               na = v.addr
+       case *net.IPAddr:
+               na = netip.AddrFromSlice(v.IP)
+       default:
+               return 0, fmt.Errorf("ping write: wrong net.Addr type")
+       }
+       if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
+               return 0, fmt.Errorf("ping write: mismatched protocols")
+       }
+
+       buf := buffer.NewViewFromBytes(p)
+       rdr := buf.Reader()
+       rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
+       // won't block, no deadlines
+       n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{
+               To: &rfa,
+       })
+       if tcpipErr != nil {
+               return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
+       }
+
+       return int(n64), nil
+}
+
+func (pc *PingConn) Write(p []byte) (n int, err error) {
+       return pc.WriteTo(p, &pc.raddr)
+}
+
+func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+       e, notifyCh := waiter.NewChannelEntry(nil)
+       pc.wq.EventRegister(&e, waiter.EventIn)
+       defer pc.wq.EventUnregister(&e)
+
+       ready := false
+
+       for !ready {
+               pc.mu.RLock()
+               deadlineBreaker := pc.deadlineBreaker
+               deadline := pc.deadline
+               pc.mu.RUnlock()
+
+               if deadline.IsZero() {
+                       select {
+                       case <-deadlineBreaker:
+                       case <-notifyCh:
+                               ready = true
+                       }
+               } else {
+                       t := time.NewTimer(deadline.Sub(time.Now()))
+                       defer t.Stop()
+
+                       select {
+                       case <-t.C:
+                               return 0, nil, os.ErrDeadlineExceeded
+
+                       case <-deadlineBreaker:
+                       case <-notifyCh:
+                               ready = true
+                       }
+               }
+       }
+
+       w := tcpip.SliceWriter(p)
+
+       res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
+               NeedRemoteAddr: true,
+       })
+       if tcpipErr != nil {
+               return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
+       }
+
+       addr = &PingAddr{netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))}
+       return res.Count, addr, nil
+}
+
+func (pc *PingConn) Read(p []byte) (n int, err error) {
+       n, _, err = pc.ReadFrom(p)
+       return
+}
+
+func (pc *PingConn) SetDeadline(t time.Time) error {
+       // pc.SetWriteDeadline is unimplemented
+
+       return pc.SetReadDeadline(t)
+}
+
+func (pc *PingConn) SetReadDeadline(t time.Time) error {
+       pc.mu.Lock()
+       defer pc.mu.Unlock()
+       close(pc.deadlineBreaker)
+       pc.deadlineBreaker = make(chan struct{}, 1)
+       pc.deadline = t
+       return nil
+}
+
 var (
        errNoSuchHost                   = errors.New("no such host")
        errLameReferral                 = errors.New("lame referral")
@@ -755,33 +990,38 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er
        return now.Add(timeout), nil
 }
 
+var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
+
 func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
        if ctx == nil {
                panic("nil context")
        }
-       var acceptV4, acceptV6, useUDP bool
-       if len(network) == 3 {
+       var acceptV4, acceptV6 bool
+       matches := protoSplitter.FindStringSubmatch(network)
+       if matches == nil {
+               return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
+       } else if len(matches[2]) == 0 {
                acceptV4 = true
                acceptV6 = true
-       } else if len(network) == 4 {
-               acceptV4 = network[3] == '4'
-               acceptV6 = network[3] == '6'
-       }
-       if !acceptV4 && !acceptV6 {
-               return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
-       }
-       if network[:3] == "udp" {
-               useUDP = true
-       } else if network[:3] != "tcp" {
-               return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
-       }
-       host, sport, err := net.SplitHostPort(address)
-       if err != nil {
-               return nil, &net.OpError{Op: "dial", Err: err}
+       } else {
+               acceptV4 = matches[2][0] == '4'
+               acceptV6 = !acceptV4
        }
-       port, err := strconv.Atoi(sport)
-       if err != nil || port < 0 || port > 65535 {
-               return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+       var host string
+       var port int
+       if matches[1] == "ping" {
+               host = address
+       } else {
+               var sport string
+               var err error
+               host, sport, err = net.SplitHostPort(address)
+               if err != nil {
+                       return nil, &net.OpError{Op: "dial", Err: err}
+               }
+               port, err = strconv.Atoi(sport)
+               if err != nil || port < 0 || port > 65535 {
+                       return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+               }
        }
        allAddr, err := tnet.LookupContextHost(ctx, host)
        if err != nil {
@@ -829,10 +1069,13 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
                }
 
                var c net.Conn
-               if useUDP {
-                       c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
-               } else {
+               switch matches[1] {
+               case "tcp":
                        c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
+               case "udp":
+                       c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
+               case "ping":
+                       c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
                }
                if err == nil {
                        return c, nil