]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
global: use netip where possible now
authorJason A. Donenfeld <Jason@zx2c4.com>
Fri, 5 Nov 2021 00:52:54 +0000 (01:52 +0100)
committerJason A. Donenfeld <Jason@zx2c4.com>
Tue, 23 Nov 2021 21:03:15 +0000 (22:03 +0100)
There are more places where we'll need to add it later, when Go 1.18
comes out with support for it in the "net" package. Also, allowedips
still uses slices internally, which might be suboptimal.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
22 files changed:
conn/bind_linux.go
conn/bind_std.go
conn/bind_windows.go
conn/bindtest/bindtest.go
conn/conn.go
device/allowedips.go
device/allowedips_rand_test.go
device/allowedips_test.go
device/device_test.go
device/endpoint_test.go
device/receive.go
device/uapi.go
go.mod
go.sum
ratelimiter/ratelimiter.go
ratelimiter/ratelimiter_test.go
tun/netstack/examples/http_client.go
tun/netstack/examples/http_server.go
tun/netstack/go.mod
tun/netstack/go.sum
tun/netstack/tun.go
tun/tuntest/tuntest.go

index 7b970e65a0a3f0a609f6a4b74681bb35797aaa0f..da0670ad7702b3f2a1c8542acfb150a2c290c4f6 100644 (file)
@@ -14,6 +14,7 @@ import (
        "unsafe"
 
        "golang.org/x/sys/unix"
+       "golang.zx2c4.com/go118/netip"
 )
 
 type ipv4Source struct {
@@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil)
 
 func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
        var end LinuxSocketEndpoint
-       addr, err := parseEndpoint(s)
+       e, err := netip.ParseAddrPort(s)
        if err != nil {
                return nil, err
        }
 
-       ipv4 := addr.IP.To4()
-       if ipv4 != nil {
+       if e.Addr().Is4() {
                dst := end.dst4()
                end.isV6 = false
-               dst.Port = addr.Port
-               copy(dst.Addr[:], ipv4)
+               dst.Port = int(e.Port())
+               dst.Addr = e.Addr().As4()
                end.ClearSrc()
                return &end, nil
        }
 
-       ipv6 := addr.IP.To16()
-       if ipv6 != nil {
-               zone, err := zoneToUint32(addr.Zone)
+       if e.Addr().Is6() {
+               zone, err := zoneToUint32(e.Addr().Zone())
                if err != nil {
                        return nil, err
                }
                dst := end.dst6()
                end.isV6 = true
-               dst.Port = addr.Port
+               dst.Port = int(e.Port())
                dst.ZoneId = zone
-               copy(dst.Addr[:], ipv6[:])
+               dst.Addr = e.Addr().As16()
                end.ClearSrc()
                return &end, nil
        }
@@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
        }
 }
 
-func (end *LinuxSocketEndpoint) SrcIP() net.IP {
+func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
        if !end.isV6 {
-               return net.IPv4(
-                       end.src4().Src[0],
-                       end.src4().Src[1],
-                       end.src4().Src[2],
-                       end.src4().Src[3],
-               )
+               return netip.AddrFrom4(end.src4().Src)
        } else {
-               return end.src6().src[:]
+               return netip.AddrFrom16(end.src6().src)
        }
 }
 
-func (end *LinuxSocketEndpoint) DstIP() net.IP {
+func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
        if !end.isV6 {
-               return net.IPv4(
-                       end.dst4().Addr[0],
-                       end.dst4().Addr[1],
-                       end.dst4().Addr[2],
-                       end.dst4().Addr[3],
-               )
+               return netip.AddrFrom4(end.dst4().Addr)
        } else {
-               return end.dst6().Addr[:]
+               return netip.AddrFrom16(end.dst6().Addr)
        }
 }
 
@@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string {
 }
 
 func (end *LinuxSocketEndpoint) DstToString() string {
-       var udpAddr net.UDPAddr
-       udpAddr.IP = end.DstIP()
+       var port int
        if !end.isV6 {
-               udpAddr.Port = end.dst4().Port
+               port = end.dst4().Port
        } else {
-               udpAddr.Port = end.dst6().Port
+               port = end.dst6().Port
        }
-       return udpAddr.String()
+       return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
 }
 
 func (end *LinuxSocketEndpoint) ClearDst() {
index cb85cfdaff5a9edd6f1d7617abf3cb708ac1de72..a3cbb158447119939c7fe1cbc244bc659e3de5a7 100644 (file)
@@ -10,6 +10,8 @@ import (
        "net"
        "sync"
        "syscall"
+
+       "golang.zx2c4.com/go118/netip"
 )
 
 // StdNetBind is meant to be a temporary solution on platforms for which
@@ -32,18 +34,23 @@ var _ Bind = (*StdNetBind)(nil)
 var _ Endpoint = (*StdNetEndpoint)(nil)
 
 func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
-       addr, err := parseEndpoint(s)
-       return (*StdNetEndpoint)(addr), err
+       e, err := netip.ParseAddrPort(s)
+       return (*StdNetEndpoint)(&net.UDPAddr{
+               IP:   e.Addr().AsSlice(),
+               Port: int(e.Port()),
+               Zone: e.Addr().Zone(),
+       }), err
 }
 
 func (*StdNetEndpoint) ClearSrc() {}
 
-func (e *StdNetEndpoint) DstIP() net.IP {
-       return (*net.UDPAddr)(e).IP
+func (e *StdNetEndpoint) DstIP() netip.Addr {
+       a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP)
+       return a
 }
 
-func (e *StdNetEndpoint) SrcIP() net.IP {
-       return nil // not supported
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+       return netip.Addr{} // not supported
 }
 
 func (e *StdNetEndpoint) DstToBytes() []byte {
index 42e06adbc9113a611cd152d1e10a4371e0dbf661..26a3af85120842c9cff607345f48b334d0860427 100644 (file)
@@ -15,6 +15,7 @@ import (
        "unsafe"
 
        "golang.org/x/sys/windows"
+       "golang.zx2c4.com/go118/netip"
 
        "golang.zx2c4.com/wireguard/conn/winrio"
 )
@@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
 
 func (*WinRingEndpoint) ClearSrc() {}
 
-func (e *WinRingEndpoint) DstIP() net.IP {
+func (e *WinRingEndpoint) DstIP() netip.Addr {
        switch e.family {
        case windows.AF_INET:
-               return append([]byte{}, e.data[2:6]...)
+               return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
        case windows.AF_INET6:
-               return append([]byte{}, e.data[6:22]...)
+               return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
        }
-       return nil
+       return netip.Addr{}
 }
 
-func (e *WinRingEndpoint) SrcIP() net.IP {
-       return nil // not supported
+func (e *WinRingEndpoint) SrcIP() netip.Addr {
+       return netip.Addr{} // not supported
 }
 
 func (e *WinRingEndpoint) DstToBytes() []byte {
@@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
 func (e *WinRingEndpoint) DstToString() string {
        switch e.family {
        case windows.AF_INET:
-               addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
-               return addr.String()
+               netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
        case windows.AF_INET6:
                var zone string
                if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
                        zone = strconv.FormatUint(uint64(scope), 10)
                }
-               addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
-               return addr.String()
+               return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
        }
        return ""
 }
index 7d43fb30b5485d8e513fe471564d4704c250bb2d..6a4589613f8f50fa34f748ed67b74bf6922c19aa 100644 (file)
@@ -10,8 +10,8 @@ import (
        "math/rand"
        "net"
        "os"
-       "strconv"
 
+       "golang.zx2c4.com/go118/netip"
        "golang.zx2c4.com/wireguard/conn"
 )
 
@@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
 
 func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
 
-func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
+func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
 
-func (c ChannelEndpoint) SrcIP() net.IP { return nil }
+func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
 
 func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
        c.closeSignal = make(chan bool)
@@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
 }
 
 func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
-       _, port, err := net.SplitHostPort(s)
+       addr, err := netip.ParseAddrPort(s)
        if err != nil {
                return nil, err
        }
-       i, err := strconv.ParseUint(port, 10, 16)
-       if err != nil {
-               return nil, err
-       }
-       return ChannelEndpoint(i), nil
+       return ChannelEndpoint(addr.Port()), nil
 }
index 9cce9adea39e508597fc349f9b71a3b42b0fdd27..35fb6b1b0020d650bbc4144a9565a98c3bbbe2ef 100644 (file)
@@ -9,10 +9,11 @@ package conn
 import (
        "errors"
        "fmt"
-       "net"
        "reflect"
        "runtime"
        "strings"
+
+       "golang.zx2c4.com/go118/netip"
 )
 
 // A ReceiveFunc receives a single inbound packet from the network.
@@ -68,8 +69,8 @@ type Endpoint interface {
        SrcToString() string // returns the local source address (ip:port)
        DstToString() string // returns the destination address (ip:port)
        DstToBytes() []byte  // used for mac2 cookie calculations
-       DstIP() net.IP
-       SrcIP() net.IP
+       DstIP() netip.Addr
+       SrcIP() netip.Addr
 }
 
 var (
@@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string {
        }
        return name
 }
-
-func parseEndpoint(s string) (*net.UDPAddr, error) {
-       // ensure that the host is an IP address
-
-       host, _, err := net.SplitHostPort(s)
-       if err != nil {
-               return nil, err
-       }
-       if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
-               // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
-               // trying to make sure with a small sanity test that this is a real IP address and
-               // not something that's likely to incur DNS lookups.
-               host = host[:i]
-       }
-       if ip := net.ParseIP(host); ip == nil {
-               return nil, errors.New("Failed to parse IP address: " + host)
-       }
-
-       // parse address and port
-
-       addr, err := net.ResolveUDPAddr("udp", s)
-       if err != nil {
-               return nil, err
-       }
-       ip4 := addr.IP.To4()
-       if ip4 != nil {
-               addr.IP = ip4
-       }
-       return addr, err
-}
index c08399bbf69efb0f841db0a121338fb14753064d..7a0b2756037ad2b40cc9257577dc8f0e59bcddfb 100644 (file)
@@ -12,6 +12,8 @@ import (
        "net"
        "sync"
        "unsafe"
+
+       "golang.zx2c4.com/go118/netip"
 )
 
 type parentIndirection struct {
@@ -26,7 +28,7 @@ type trieEntry struct {
        cidr        uint8
        bitAtByte   uint8
        bitAtShift  uint8
-       bits        net.IP
+       bits        []byte
        perPeerElem *list.Element
 }
 
@@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 {
        return bits.ReverseBytes64(i)
 }
 
-func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
+func commonBits(ip1, ip2 []byte) uint8 {
        size := len(ip1)
        if size == net.IPv4len {
                a := (*uint32)(unsafe.Pointer(&ip1[0]))
@@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() {
        }
 }
 
-func (node *trieEntry) choose(ip net.IP) byte {
+func (node *trieEntry) choose(ip []byte) byte {
        return (ip[node.bitAtByte] >> node.bitAtShift) & 1
 }
 
@@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() {
        node.parent.parentBit = nil
 }
 
-func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
+func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
        for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
                parent = node
                if parent.cidr == cidr {
@@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
        return
 }
 
-func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
+func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
        if *trie.parentBit == nil {
                node := &trieEntry{
                        peer:       peer,
@@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
        }
 }
 
-func (node *trieEntry) lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip []byte) *Peer {
        var found *Peer
        size := uint8(len(ip))
        for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -229,13 +231,14 @@ type AllowedIPs struct {
        mutex sync.RWMutex
 }
 
-func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
        table.mutex.RLock()
        defer table.mutex.RUnlock()
 
        for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
                node := elem.Value.(*trieEntry)
-               if !cb(node.bits, node.cidr) {
+               a, _ := netip.AddrFromSlice(node.bits)
+               if !cb(netip.PrefixFrom(a, int(node.cidr))) {
                        return
                }
        }
@@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
        }
 }
 
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
+func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
        table.mutex.Lock()
        defer table.mutex.Unlock()
 
-       switch len(ip) {
-       case net.IPv6len:
-               parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
-       case net.IPv4len:
-               parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
-       default:
+       if prefix.Addr().Is6() {
+               ip := prefix.Addr().As16()
+               parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+       } else if prefix.Addr().Is4() {
+               ip := prefix.Addr().As4()
+               parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+       } else {
                panic(errors.New("inserting unknown address type"))
        }
 }
 
-func (table *AllowedIPs) Lookup(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(ip []byte) *Peer {
        table.mutex.RLock()
        defer table.mutex.RUnlock()
-       switch len(address) {
+       switch len(ip) {
        case net.IPv6len:
-               return table.IPv6.lookup(address)
+               return table.IPv6.lookup(ip)
        case net.IPv4len:
-               return table.IPv4.lookup(address)
+               return table.IPv4.lookup(ip)
        default:
                panic(errors.New("looking up unknown address type"))
        }
index 16de1704ae2439c180359ce58dbe53be2f8d1a0a..ff56fe6a9050aa43b220aba8f6a7d93ba4cbc38b 100644 (file)
@@ -10,6 +10,8 @@ import (
        "net"
        "sort"
        "testing"
+
+       "golang.zx2c4.com/go118/netip"
 )
 
 const (
@@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) {
                rand.Read(addr4[:])
                cidr := uint8(rand.Intn(32) + 1)
                index := rand.Intn(NumberOfPeers)
-               allowedIPs.Insert(addr4[:], cidr, peers[index])
+               allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
                slow4 = slow4.Insert(addr4[:], cidr, peers[index])
 
                var addr6 [16]byte
                rand.Read(addr6[:])
                cidr = uint8(rand.Intn(128) + 1)
                index = rand.Intn(NumberOfPeers)
-               allowedIPs.Insert(addr6[:], cidr, peers[index])
+               allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
                slow6 = slow6.Insert(addr6[:], cidr, peers[index])
        }
 
index 2059a8836d0f4da2c6e411d01e8c899e31acf2ba..a27499782a0012abd9c49a5fba03aa6eab5be964 100644 (file)
@@ -9,6 +9,8 @@ import (
        "math/rand"
        "net"
        "testing"
+
+       "golang.zx2c4.com/go118/netip"
 )
 
 type testPairCommonBits struct {
@@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) {
        var allowedIPs AllowedIPs
 
        insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
-               allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
+               allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
        }
 
        assertEQ := func(peer *Peer, a, b, c, d byte) {
@@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) {
                addr = append(addr, expand(b)...)
                addr = append(addr, expand(c)...)
                addr = append(addr, expand(d)...)
-               allowedIPs.Insert(addr, cidr, peer)
+               allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
        }
 
        assertEQ := func(peer *Peer, a, b, c, d uint32) {
index 29daeb9c399e089e18ca804bbb97c22721a5ff19..84221bed15e298dc14f4d2a670d9f423dbee949a 100644 (file)
@@ -11,7 +11,6 @@ import (
        "fmt"
        "io"
        "math/rand"
-       "net"
        "runtime"
        "runtime/pprof"
        "sync"
@@ -19,6 +18,7 @@ import (
        "testing"
        "time"
 
+       "golang.zx2c4.com/go118/netip"
        "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/conn/bindtest"
        "golang.zx2c4.com/wireguard/tun/tuntest"
@@ -96,7 +96,7 @@ type testPair [2]testPeer
 type testPeer struct {
        tun *tuntest.ChannelTUN
        dev *Device
-       ip  net.IP
+       ip  netip.Addr
 }
 
 type SendDirection bool
@@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
        for i := range pair {
                p := &pair[i]
                p.tun = tuntest.NewChannelTUN()
-               p.ip = net.IPv4(1, 0, 0, byte(i+1))
+               p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
                level := LogLevelVerbose
                if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
                        level = LogLevelError
index 57c361ca6fab484989e32fdd83a551642dffd79d..f1ae47e34dd549b0b9c3a5dd85a3c2f982479e81 100644 (file)
@@ -7,47 +7,44 @@ package device
 
 import (
        "math/rand"
-       "net"
+
+       "golang.zx2c4.com/go118/netip"
 )
 
 type DummyEndpoint struct {
-       src [16]byte
-       dst [16]byte
+       src, dst netip.Addr
 }
 
 func CreateDummyEndpoint() (*DummyEndpoint, error) {
-       var end DummyEndpoint
-       if _, err := rand.Read(end.src[:]); err != nil {
+       var src, dst [16]byte
+       if _, err := rand.Read(src[:]); err != nil {
                return nil, err
        }
-       _, err := rand.Read(end.dst[:])
-       return &end, err
+       _, err := rand.Read(dst[:])
+       return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
 }
 
 func (e *DummyEndpoint) ClearSrc() {}
 
 func (e *DummyEndpoint) SrcToString() string {
-       var addr net.UDPAddr
-       addr.IP = e.SrcIP()
-       addr.Port = 1000
-       return addr.String()
+       return netip.AddrPortFrom(e.SrcIP(), 1000).String()
 }
 
 func (e *DummyEndpoint) DstToString() string {
-       var addr net.UDPAddr
-       addr.IP = e.DstIP()
-       addr.Port = 1000
-       return addr.String()
+       return netip.AddrPortFrom(e.DstIP(), 1000).String()
 }
 
-func (e *DummyEndpoint) SrcToBytes() []byte {
-       return e.src[:]
+func (e *DummyEndpoint) DstToBytes() []byte {
+       out := e.DstIP().AsSlice()
+       out = append(out, byte(1000&0xff))
+       out = append(out, byte((1000>>8)&0xff))
+       return out
 }
 
-func (e *DummyEndpoint) DstIP() net.IP {
-       return e.dst[:]
+func (e *DummyEndpoint) DstIP() netip.Addr {
+       return e.dst
 }
 
-func (e *DummyEndpoint) SrcIP() net.IP {
-       return e.src[:]
+func (e *DummyEndpoint) SrcIP() netip.Addr {
+       return e.src
 }
index 58574810f812eeea7b03646a22ad2201152ecc75..cc3449801ca1dc6478db2f7fb35829ff075fc1f5 100644 (file)
@@ -17,7 +17,6 @@ import (
        "golang.org/x/crypto/chacha20poly1305"
        "golang.org/x/net/ipv4"
        "golang.org/x/net/ipv6"
-
        "golang.zx2c4.com/wireguard/conn"
 )
 
index 2306183a648a3c7c5f9abbcb7328d51309228ed8..98e83110483162d05cb78ca2c5b0e0c819af5273 100644 (file)
@@ -18,6 +18,7 @@ import (
        "sync/atomic"
        "time"
 
+       "golang.zx2c4.com/go118/netip"
        "golang.zx2c4.com/wireguard/ipc"
 )
 
@@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
                        sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
                        sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
 
-                       device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
-                               sendf("allowed_ip=%s/%d", ip.String(), cidr)
+                       device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+                               sendf("allowed_ip=%s", prefix.String())
                                return true
                        })
                }
@@ -374,16 +375,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
 
        case "allowed_ip":
                device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
-
-               _, network, err := net.ParseCIDR(value)
+               prefix, err := netip.ParsePrefix(value)
                if err != nil {
                        return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
                }
                if peer.dummy {
                        return nil
                }
-               ones, _ := network.Mask.Size()
-               device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
+               device.allowedips.Insert(prefix, peer.Peer)
 
        case "protocol_version":
                if value != "1" {
diff --git a/go.mod b/go.mod
index 856bb6c2190ec50c265d44508cf0c7b64c4f03d5..b51096070479a3e4c8a577d016520bc9b51ee35f 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -3,8 +3,9 @@ module golang.zx2c4.com/wireguard
 go 1.17
 
 require (
-       golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
-       golang.org/x/net v0.0.0-20211101193420-4a448f8816b3
-       golang.org/x/sys v0.0.0-20211103235746-7861aae1554b
+       golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa
+       golang.org/x/net v0.0.0-20211111083644-e5c967477495
+       golang.org/x/sys v0.0.0-20211110154304-99a53858aa08
+       golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d
        golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
 )
diff --git a/go.sum b/go.sum
index 37fe0679648fae4f2454954d126f365d360b4abc..78f7367056d51c7da44f60f8fb67bc0d82edfa2c 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -1,16 +1,19 @@
-golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
-golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa h1:idItI2DDfCokpg0N51B2VtiLdJ4vAuXC9fnCb2gACo4=
+golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20211101193420-4a448f8816b3 h1:VrJZAjbekhoRn7n5FBujY31gboH+iB3pdLxn3gE9FjU=
-golang.org/x/net v0.0.0-20211101193420-4a448f8816b3/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.0.0-20211111083644-e5c967477495 h1:cjxxlQm6d4kYbhpZ2ghvmI8xnq0AG+jXmzrhzfkyu5A=
+golang.org/x/net v0.0.0-20211111083644-e5c967477495/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20211103235746-7861aae1554b h1:1VkfZQv42XQlA/jchYumAnv1UPo6RgF9rJFkTgZIxO4=
 golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20211110154304-99a53858aa08 h1:WecRHqgE09JBkh/584XIE6PMz5KKE/vER4izNUi30AQ=
+golang.org/x/sys v0.0.0-20211110154304-99a53858aa08/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d h1:9+v0G0naRhLPOJEeJOL6NuXTtAHHwmkyZlgQJ0XcQ8I=
+golang.zx2c4.com/go118/netip v0.0.0-20211111135330-a4a02eeacf9d/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
 golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
 golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
index 2f7aa2ae0e2344b727aeaf8bca17ff0f08bd2d82..8e78d5e46ad5d18f3ac8d234a54449d48e9ab2dd 100644 (file)
@@ -6,9 +6,10 @@
 package ratelimiter
 
 import (
-       "net"
        "sync"
        "time"
+
+       "golang.zx2c4.com/go118/netip"
 )
 
 const (
@@ -30,8 +31,7 @@ type Ratelimiter struct {
        timeNow func() time.Time
 
        stopReset chan struct{} // send to reset, close to stop
-       tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
-       tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
+       table     map[netip.Addr]*RatelimiterEntry
 }
 
 func (rate *Ratelimiter) Close() {
@@ -57,8 +57,7 @@ func (rate *Ratelimiter) Init() {
        }
 
        rate.stopReset = make(chan struct{})
-       rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
-       rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
+       rate.table = make(map[netip.Addr]*RatelimiterEntry)
 
        stopReset := rate.stopReset // store in case Init is called again.
 
@@ -87,71 +86,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
        rate.mu.Lock()
        defer rate.mu.Unlock()
 
-       for key, entry := range rate.tableIPv4 {
+       for key, entry := range rate.table {
                entry.mu.Lock()
                if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
-                       delete(rate.tableIPv4, key)
+                       delete(rate.table, key)
                }
                entry.mu.Unlock()
        }
 
-       for key, entry := range rate.tableIPv6 {
-               entry.mu.Lock()
-               if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
-                       delete(rate.tableIPv6, key)
-               }
-               entry.mu.Unlock()
-       }
-
-       return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
+       return len(rate.table) == 0
 }
 
-func (rate *Ratelimiter) Allow(ip net.IP) bool {
+func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
        var entry *RatelimiterEntry
-       var keyIPv4 [net.IPv4len]byte
-       var keyIPv6 [net.IPv6len]byte
-
        // lookup entry
-
-       IPv4 := ip.To4()
-       IPv6 := ip.To16()
-
        rate.mu.RLock()
-
-       if IPv4 != nil {
-               copy(keyIPv4[:], IPv4)
-               entry = rate.tableIPv4[keyIPv4]
-       } else {
-               copy(keyIPv6[:], IPv6)
-               entry = rate.tableIPv6[keyIPv6]
-       }
-
+       entry = rate.table[ip]
        rate.mu.RUnlock()
 
        // make new entry if not found
-
        if entry == nil {
                entry = new(RatelimiterEntry)
                entry.tokens = maxTokens - packetCost
                entry.lastTime = rate.timeNow()
                rate.mu.Lock()
-               if IPv4 != nil {
-                       rate.tableIPv4[keyIPv4] = entry
-                       if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
-                               rate.stopReset <- struct{}{}
-                       }
-               } else {
-                       rate.tableIPv6[keyIPv6] = entry
-                       if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
-                               rate.stopReset <- struct{}{}
-                       }
+               rate.table[ip] = entry
+               if len(rate.table) == 1 {
+                       rate.stopReset <- struct{}{}
                }
                rate.mu.Unlock()
                return true
        }
 
        // add tokens to entry
-
        entry.mu.Lock()
        now := rate.timeNow()
        entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
@@ -161,7 +128,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
        }
 
        // subtract cost of packet
-
        if entry.tokens > packetCost {
                entry.tokens -= packetCost
                entry.mu.Unlock()
index f231fe5b7c0e776352ff94dcf18523e7753d23dd..3e06ff77848843ddf21defd0270d766d0955c1c8 100644 (file)
@@ -6,9 +6,10 @@
 package ratelimiter
 
 import (
-       "net"
        "testing"
        "time"
+
+       "golang.zx2c4.com/go118/netip"
 )
 
 type result struct {
@@ -71,21 +72,21 @@ func TestRatelimiter(t *testing.T) {
                text:    "packet following 2 packet burst",
        })
 
-       ips := []net.IP{
-               net.ParseIP("127.0.0.1"),
-               net.ParseIP("192.168.1.1"),
-               net.ParseIP("172.167.2.3"),
-               net.ParseIP("97.231.252.215"),
-               net.ParseIP("248.97.91.167"),
-               net.ParseIP("188.208.233.47"),
-               net.ParseIP("104.2.183.179"),
-               net.ParseIP("72.129.46.120"),
-               net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
-               net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
-               net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
-               net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
-               net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
-               net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
+       ips := []netip.Addr{
+               netip.MustParseAddr("127.0.0.1"),
+               netip.MustParseAddr("192.168.1.1"),
+               netip.MustParseAddr("172.167.2.3"),
+               netip.MustParseAddr("97.231.252.215"),
+               netip.MustParseAddr("248.97.91.167"),
+               netip.MustParseAddr("188.208.233.47"),
+               netip.MustParseAddr("104.2.183.179"),
+               netip.MustParseAddr("72.129.46.120"),
+               netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
+               netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
+               netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
+               netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
+               netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
+               netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
        }
 
        now := time.Now()
index 6ac28596b67ed275f11c2228ff2e42c7f463ee6b..b39b453164a7ff6692884beef6919370e0ac2be8 100644 (file)
@@ -1,4 +1,5 @@
 //go:build ignore
+// +build ignore
 
 /* SPDX-License-Identifier: MIT
  *
@@ -10,9 +11,9 @@ package main
 import (
        "io"
        "log"
-       "net"
        "net/http"
 
+       "golang.zx2c4.com/go118/netip"
        "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/device"
        "golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,8 +21,8 @@ import (
 
 func main() {
        tun, tnet, err := netstack.CreateNetTUN(
-               []net.IP{net.ParseIP("192.168.4.29")},
-               []net.IP{net.ParseIP("8.8.8.8")},
+               []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+               []netip.Addr{netip.MustParseAddr("8.8.8.8")},
                1420)
        if err != nil {
                log.Panic(err)
index 577c6ea4cd7dda332f64859ce787791ae8f70583..40f780447e85543da3745b487763c5f07937311c 100644 (file)
@@ -1,4 +1,5 @@
 //go:build ignore
+// +build ignore
 
 /* SPDX-License-Identifier: MIT
  *
@@ -13,6 +14,7 @@ import (
        "net"
        "net/http"
 
+       "golang.zx2c4.com/go118/netip"
        "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/device"
        "golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,8 +22,8 @@ import (
 
 func main() {
        tun, tnet, err := netstack.CreateNetTUN(
-               []net.IP{net.ParseIP("192.168.4.29")},
-               []net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")},
+               []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+               []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
                1420,
        )
        if err != nil {
index 8db9f4bcdb51963bb730bdd5b37b2b9d23d4085e..46b57bab10f8114589fbc9897942c61c776c993b 100644 (file)
@@ -6,6 +6,7 @@ require (
        golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
        golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect
        golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
+       golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53
        golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22
        gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6
 )
index 78c025c54f37d5b2cd190b104824f3e7482948b2..01bfbc70cb0e1e94a6695497ca91664423af28d9 100644 (file)
@@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg=
+golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
+golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E=
+golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
 golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw=
 golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
 google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
index 24d0835dd5893d70f718b8f28b5614c790523d9c..f1c03f48cceeebb1382ea7a7de93ecff36dafa5a 100644 (file)
@@ -18,6 +18,7 @@ import (
        "strings"
        "time"
 
+       "golang.zx2c4.com/go118/netip"
        "golang.zx2c4.com/wireguard/tun"
 
        "golang.org/x/net/dns/dnsmessage"
@@ -38,7 +39,7 @@ type netTun struct {
        events         chan tun.Event
        incomingPacket chan buffer.VectorisedView
        mtu            int
-       dnsServers     []net.IP
+       dnsServers     []netip.Addr
        hasV4, hasV6   bool
 }
 type endpoint netTun
@@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
 func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
 }
 
-func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
+func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
        opts := stack.Options{
                NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
                TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
@@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
                return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
        }
        for _, ip := range localAddresses {
-               if ip4 := ip.To4(); ip4 != nil {
-                       protoAddr := tcpip.ProtocolAddress{
-                               Protocol:          ipv4.ProtocolNumber,
-                               AddressWithPrefix: tcpip.Address(ip4).WithPrefix(),
-                       }
-                       tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
-                       if tcpipErr != nil {
-                               return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr)
-                       }
+               var protoNumber tcpip.NetworkProtocolNumber
+               if ip.Is4() {
+                       protoNumber = ipv4.ProtocolNumber
+               } else if ip.Is6() {
+                       protoNumber = ipv6.ProtocolNumber
+               }
+               protoAddr := tcpip.ProtocolAddress{
+                       Protocol:          protoNumber,
+                       AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
+               }
+               tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+               if tcpipErr != nil {
+                       return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
+               }
+               if ip.Is4() {
                        dev.hasV4 = true
-               } else {
-                       protoAddr := tcpip.ProtocolAddress{
-                               Protocol:          ipv6.ProtocolNumber,
-                               AddressWithPrefix: tcpip.Address(ip).WithPrefix(),
-                       }
-                       tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
-                       if tcpipErr != nil {
-                               return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
-                       }
+               } else if ip.Is6() {
                        dev.hasV6 = true
                }
        }
@@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) {
        return tun.mtu, nil
 }
 
-func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
-       if ip4 := ip.To4(); ip4 != nil {
-               return tcpip.FullAddress{
-                       NIC:  1,
-                       Addr: tcpip.Address(ip4),
-                       Port: uint16(port),
-               }, ipv4.ProtocolNumber
+func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
+       var protoNumber tcpip.NetworkProtocolNumber
+       if endpoint.Addr().Is4() {
+               protoNumber = ipv4.ProtocolNumber
        } else {
-               return tcpip.FullAddress{
-                       NIC:  1,
-                       Addr: tcpip.Address(ip),
-                       Port: uint16(port),
-               }, ipv6.ProtocolNumber
+               protoNumber = ipv6.ProtocolNumber
        }
+       return tcpip.FullAddress{
+               NIC:  1,
+               Addr: tcpip.Address(endpoint.Addr().AsSlice()),
+               Port: endpoint.Port(),
+       }, protoNumber
+}
+
+func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
+       fa, pn := convertToFullAddr(addr)
+       return gonet.DialContextTCP(ctx, net.stack, fa, pn)
 }
 
 func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
        if addr == nil {
-               panic("todo: deal with auto addr semantics for nil addr")
+               return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
        }
-       fa, pn := convertToFullAddr(addr.IP, addr.Port)
-       return gonet.DialContextTCP(ctx, net.stack, fa, pn)
+       return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+}
+
+func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
+       fa, pn := convertToFullAddr(addr)
+       return gonet.DialTCP(net.stack, fa, pn)
 }
 
 func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
        if addr == nil {
-               panic("todo: deal with auto addr semantics for nil addr")
+               return net.DialTCPAddrPort(netip.AddrPort{})
        }
-       fa, pn := convertToFullAddr(addr.IP, addr.Port)
-       return gonet.DialTCP(net.stack, fa, pn)
+       return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
+}
+
+func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
+       fa, pn := convertToFullAddr(addr)
+       return gonet.ListenTCP(net.stack, fa, pn)
 }
 
 func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
        if addr == nil {
-               panic("todo: deal with auto addr semantics for nil addr")
+               return net.ListenTCPAddrPort(netip.AddrPort{})
        }
-       fa, pn := convertToFullAddr(addr.IP, addr.Port)
-       return gonet.ListenTCP(net.stack, fa, pn)
+       return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
 }
 
-func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
        var lfa, rfa *tcpip.FullAddress
        var pn tcpip.NetworkProtocolNumber
-       if laddr != nil {
+       if laddr.IsValid() || laddr.Port() > 0 {
                var addr tcpip.FullAddress
-               addr, pn = convertToFullAddr(laddr.IP, laddr.Port)
+               addr, pn = convertToFullAddr(laddr)
                lfa = &addr
        }
-       if raddr != nil {
+       if raddr.IsValid() || raddr.Port() > 0 {
                var addr tcpip.FullAddress
-               addr, pn = convertToFullAddr(raddr.IP, raddr.Port)
+               addr, pn = convertToFullAddr(raddr)
                rfa = &addr
        }
        return gonet.DialUDP(net.stack, lfa, rfa, pn)
 }
 
+func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+       var la, ra netip.AddrPort
+       if laddr != nil {
+               la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port))
+       }
+       if raddr != nil {
+               ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port))
+       }
+       return net.DialUDPAddrPort(la, ra)
+}
+
 var (
        errNoSuchHost                   = errors.New("no such host")
        errLameReferral                 = errors.New("lame referral")
@@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
        return p, h, nil
 }
 
-func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
        q.Class = dnsmessage.ClassINET
        id, udpReq, tcpReq, err := newRequest(q)
        if err != nil {
@@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
                var c net.Conn
                var err error
                if useUDP {
-                       c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53})
+                       c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
                } else {
-                       c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53})
+                       c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
                }
 
                if err != nil {
@@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
                        zlen = zidx
                }
        }
-       if ip := net.ParseIP(host[:zlen]); ip != nil {
-               return []string{host[:zlen]}, nil
+       if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
+               return []string{ip.String()}, nil
        }
 
        if !isDomainName(host) {
@@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
                server string
                error
        }
-       var addrsV4, addrsV6 []net.IP
+       var addrsV4, addrsV6 []netip.Addr
        lanes := 0
        if tnet.hasV4 {
                lanes++
@@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
                                        }
                                        break loop
                                }
-                               addrsV4 = append(addrsV4, net.IP(a.A[:]))
+                               addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
 
                        case dnsmessage.TypeAAAA:
                                aaaa, err := result.p.AAAAResource()
@@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
                                        }
                                        break loop
                                }
-                               addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:]))
+                               addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
 
                        default:
                                if err := result.p.SkipAnswer(); err != nil {
@@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
                }
        }
        // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
-       var addrs []net.IP
+       var addrs []netip.Addr
        if tnet.hasV6 {
                addrs = append(addrsV6, addrsV4...)
        } else {
@@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
        if err != nil {
                return nil, &net.OpError{Op: "dial", Err: err}
        }
-       var addrs []net.IP
+       var addrs []netip.AddrPort
        for _, addr := range allAddr {
-               if strings.IndexByte(addr, ':') != -1 && acceptV6 {
-                       addrs = append(addrs, net.ParseIP(addr))
-               } else if strings.IndexByte(addr, '.') != -1 && acceptV4 {
-                       addrs = append(addrs, net.ParseIP(addr))
+               ip, err := netip.ParseAddr(addr)
+               if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
+                       addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
                }
        }
        if len(addrs) == 0 && len(allAddr) != 0 {
@@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
 
                var c net.Conn
                if useUDP {
-                       c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port})
+                       c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
                } else {
-                       c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port})
+                       c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
                }
                if err == nil {
                        return c, nil
index d89db718724ec6149521fa71cc09d6c7897ff8e4..bdf0467b4f159cff07a140d66ea159f79ab441b6 100644 (file)
@@ -8,13 +8,13 @@ package tuntest
 import (
        "encoding/binary"
        "io"
-       "net"
        "os"
 
+       "golang.zx2c4.com/go118/netip"
        "golang.zx2c4.com/wireguard/tun"
 )
 
-func Ping(dst, src net.IP) []byte {
+func Ping(dst, src netip.Addr) []byte {
        localPort := uint16(1337)
        seq := uint16(0)
 
@@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 {
        return ^uint16(v)
 }
 
-func genICMPv4(payload []byte, dst, src net.IP) []byte {
+func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
        const (
                icmpv4ProtocolNumber = 1
                icmpv4Echo           = 8
@@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
        binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
        ip[8] = ttl
        ip[9] = icmpv4ProtocolNumber
-       copy(ip[12:], src.To4())
-       copy(ip[16:], dst.To4())
+       copy(ip[12:], src.AsSlice())
+       copy(ip[16:], dst.AsSlice())
        chksum = ^checksum(ip[:], 0)
        binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)