]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
tun/netstack: implement ICMP ping
authorThomas H. Ptacek <thomas@sockpuppet.org>
Mon, 31 Jan 2022 22:55:36 +0000 (16:55 -0600)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 1 Feb 2022 19:16:42 +0000 (20:16 +0100)
Provide a PacketConn interface for netstack's ICMP endpoint; netstack
currently only provides EchoRequest/EchoResponse ICMP support, so this
code exposes only an interface for doing ping.

Currently is missing:
- Write deadlines
- Context support

Signed-off-by: Thomas Ptacek <thomas@sockpuppet.org>
[Jason: rework structure, match std go interfaces, add example code]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
tun/netstack/examples/ping_client.go [new file with mode: 0644]
tun/netstack/tun.go

diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go
new file mode 100644 (file)
index 0000000..843a3ee
--- /dev/null
@@ -0,0 +1,57 @@
+//go:build ignore
+// +build ignore
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package main
+
+import (
+       "log"
+       "time"
+
+       "golang.zx2c4.com/go118/netip"
+       "golang.zx2c4.com/wireguard/conn"
+       "golang.zx2c4.com/wireguard/device"
+       "golang.zx2c4.com/wireguard/tun/netstack"
+)
+
+func main() {
+       tun, tnet, err := netstack.CreateNetTUN(
+               []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+               []netip.Addr{netip.MustParseAddr("8.8.8.8")},
+               1420)
+       if err != nil {
+               log.Panic(err)
+       }
+       dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
+       dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
+public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
+endpoint=163.172.161.0:12912
+allowed_ip=0.0.0.0/0
+`)
+       err = dev.Up()
+       if err != nil {
+               log.Panic(err)
+       }
+
+       socket, err := tnet.Dial("ping4", "zx2c4.com")
+       if err != nil {
+               log.Panic(err)
+       }
+       const payload = "gopher burrow"
+       socket.SetReadDeadline(time.Now().Add(time.Second * 10))
+       start := time.Now()
+       _, err = socket.Write([]byte(payload))
+       if err != nil {
+               log.Panic(err)
+       }
+       var reply [len(payload)]byte
+       n, err := socket.Read(reply[:])
+       if err != nil || string(reply[:n]) != payload {
+               log.Panic(err)
+       }
+       log.Printf("Ping latency: %v", time.Since(start))
+}
index fb7f07d5484195f8af5abd4da81be216b8925e55..f0e954b310db69bdc9a2d687fdd5082702338357 100644 (file)
@@ -14,6 +14,7 @@ import (
        "io"
        "net"
        "os"
+       "regexp"
        "strconv"
        "strings"
        "time"
@@ -29,8 +30,10 @@ import (
        "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
        "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
+       "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
        "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
        "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+       "gvisor.dev/gvisor/pkg/waiter"
 )
 
 type netTun struct {
@@ -101,7 +104,7 @@ func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.Network
 func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
        opts := stack.Options{
                NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
-               TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
+               TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
                HandleLocal:        true,
        }
        dev := &netTun{
@@ -281,6 +284,178 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
        return net.DialUDPAddrPort(la, ra)
 }
 
+type PingConn struct {
+       laddr    PingAddr
+       raddr    PingAddr
+       wq       waiter.Queue
+       ep       tcpip.Endpoint
+       deadline time.Time
+}
+
+type PingAddr struct{ addr netip.Addr }
+
+func (ia PingAddr) String() string {
+       return ia.addr.String()
+}
+
+func (ia PingAddr) Network() string {
+       if ia.addr.Is4() {
+               return "ping4"
+       } else if ia.addr.Is6() {
+               return "ping6"
+       }
+       return "ping"
+}
+
+func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
+       v6 := laddr.Is6() || raddr.Is6()
+       bind := laddr.IsValid()
+       if !bind {
+               if v6 {
+                       laddr = netip.IPv6Unspecified()
+               } else {
+                       laddr = netip.IPv4Unspecified()
+               }
+       }
+
+       tn := icmp.ProtocolNumber4
+       pn := ipv4.ProtocolNumber
+       if v6 {
+               tn = icmp.ProtocolNumber6
+               pn = ipv6.ProtocolNumber
+       }
+
+       pc := &PingConn{laddr: PingAddr{laddr}}
+
+       ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
+       if tcpipErr != nil {
+               return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
+       }
+       pc.ep = ep
+
+       if bind {
+               fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
+               if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
+                       return nil, fmt.Errorf("ping bind: %s", tcpipErr)
+               }
+       }
+
+       if raddr.IsValid() {
+               pc.raddr = PingAddr{raddr}
+               fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
+               if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
+                       return nil, fmt.Errorf("ping connect: %s", tcpipErr)
+               }
+       }
+
+       return pc, nil
+}
+
+func (pc *PingConn) LocalAddr() net.Addr {
+       return pc.laddr
+}
+
+func (pc *PingConn) RemoteAddr() net.Addr {
+       return pc.raddr
+}
+
+func (pc *PingConn) Close() error {
+       pc.ep.Close()
+       return nil
+}
+
+func (pc *PingConn) SetWriteDeadline(t time.Time) error {
+       return errors.New("not implemented")
+}
+
+func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+       ia, ok := addr.(PingAddr)
+       if !ok || !((ia.addr.Is4() && pc.laddr.addr.Is4()) || (ia.addr.Is6() && pc.laddr.addr.Is6())) {
+               return 0, fmt.Errorf("ping write: mismatched protocols")
+       }
+
+       var buf buffer.View
+       if ia.addr.Is4() {
+               buf = buffer.NewView(header.ICMPv4MinimumSize + len(p))
+               copy(buf[header.ICMPv4MinimumSize:], p)
+               icmp := header.ICMPv4(buf)
+               icmp.SetType(header.ICMPv4Echo)
+       } else if ia.addr.Is6() {
+               buf = buffer.NewView(header.ICMPv6MinimumSize + len(p))
+               copy(buf[header.ICMPv6MinimumSize:], p)
+               icmp := header.ICMPv6(buf)
+               icmp.SetType(header.ICMPv6EchoRequest)
+       }
+
+       rdr := buf.Reader()
+       rfa, _ := convertToFullAddr(netip.AddrPortFrom(ia.addr, 0))
+       // won't block, no deadlines
+       n64, tcpipErr := pc.ep.Write(&rdr, tcpip.WriteOptions{
+               To: &rfa,
+       })
+       if tcpipErr != nil {
+               return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
+       }
+
+       return int(n64), nil
+}
+
+func (pc *PingConn) Write(p []byte) (n int, err error) {
+       return pc.WriteTo(p, pc.raddr)
+}
+
+func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+       e, notifyCh := waiter.NewChannelEntry(nil)
+       pc.wq.EventRegister(&e, waiter.EventIn)
+       defer pc.wq.EventUnregister(&e)
+
+       deadline := pc.deadline
+
+       if deadline.IsZero() {
+               <-notifyCh
+       } else {
+               select {
+               case <-time.NewTimer(deadline.Sub(time.Now())).C:
+                       return 0, nil, os.ErrDeadlineExceeded
+               case <-notifyCh:
+               }
+       }
+
+       min := header.ICMPv6MinimumSize
+       if pc.laddr.addr.Is4() {
+               min = header.ICMPv4MinimumSize
+       }
+       reply := make([]byte, min+len(p))
+       w := tcpip.SliceWriter(reply)
+
+       res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
+               NeedRemoteAddr: true,
+       })
+       if tcpipErr != nil {
+               return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
+       }
+
+       addr = PingAddr{netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))}
+       copy(p, reply[min:res.Count])
+       return res.Count - min, addr, nil
+}
+
+func (pc *PingConn) Read(p []byte) (n int, err error) {
+       n, _, err = pc.ReadFrom(p)
+       return
+}
+
+func (pc *PingConn) SetDeadline(t time.Time) error {
+       // pc.SetWriteDeadline is unimplemented
+
+       return pc.SetReadDeadline(t)
+}
+
+func (pc *PingConn) SetReadDeadline(t time.Time) error {
+       pc.deadline = t
+       return nil
+}
+
 var (
        errNoSuchHost                   = errors.New("no such host")
        errLameReferral                 = errors.New("lame referral")
@@ -755,33 +930,38 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er
        return now.Add(timeout), nil
 }
 
+var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
+
 func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
        if ctx == nil {
                panic("nil context")
        }
-       var acceptV4, acceptV6, useUDP bool
-       if len(network) == 3 {
+       var acceptV4, acceptV6 bool
+       matches := protoSplitter.FindStringSubmatch(network)
+       if matches == nil {
+               return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
+       } else if len(matches[2]) == 0 {
                acceptV4 = true
                acceptV6 = true
-       } else if len(network) == 4 {
-               acceptV4 = network[3] == '4'
-               acceptV6 = network[3] == '6'
-       }
-       if !acceptV4 && !acceptV6 {
-               return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
-       }
-       if network[:3] == "udp" {
-               useUDP = true
-       } else if network[:3] != "tcp" {
-               return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
-       }
-       host, sport, err := net.SplitHostPort(address)
-       if err != nil {
-               return nil, &net.OpError{Op: "dial", Err: err}
+       } else {
+               acceptV4 = matches[2][0] == '4'
+               acceptV6 = !acceptV4
        }
-       port, err := strconv.Atoi(sport)
-       if err != nil || port < 0 || port > 65535 {
-               return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+       var host string
+       var port int
+       if matches[1] == "ping" {
+               host = address
+       } else {
+               var sport string
+               var err error
+               host, sport, err = net.SplitHostPort(address)
+               if err != nil {
+                       return nil, &net.OpError{Op: "dial", Err: err}
+               }
+               port, err = strconv.Atoi(sport)
+               if err != nil || port < 0 || port > 65535 {
+                       return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+               }
        }
        allAddr, err := tnet.LookupContextHost(ctx, host)
        if err != nil {
@@ -829,10 +1009,13 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
                }
 
                var c net.Conn
-               if useUDP {
-                       c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
-               } else {
+               switch matches[1] {
+               case "tcp":
                        c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
+               case "udp":
+                       c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
+               case "ping":
+                       c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
                }
                if err == nil {
                        return c, nil