]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wgcfg: clean up IP type/method signatures
authorBrad Fitzpatrick <bradfitz@tailscale.com>
Tue, 17 Mar 2020 03:28:29 +0000 (20:28 -0700)
committerDavid Crawshaw <david@zentus.com>
Mon, 30 Mar 2020 22:33:03 +0000 (09:33 +1100)
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
wgcfg/ip.go
wgcfg/ip_test.go
wgcfg/parser.go

index 7541d185df3b45a331f138d59304c13e8199cd13..47fa91c27656b8b3c6d4b3b02ea9fb6b46dd1c1d 100644 (file)
@@ -16,9 +16,14 @@ type IP struct {
 
 func (ip IP) String() string { return net.IP(ip.Addr[:]).String() }
 
-func (ip *IP) IP() net.IP { return net.IP(ip.Addr[:]) }
-func (ip *IP) Is6() bool  { return !ip.Is4() }
-func (ip *IP) Is4() bool {
+// IP converts ip into a standard library net.IP.
+func (ip IP) IP() net.IP { return net.IP(ip.Addr[:]) }
+
+// Is6 reports whether ip is an IPv6 address.
+func (ip IP) Is6() bool { return !ip.Is4() }
+
+// Is4 reports whether ip is an IPv4 address.
+func (ip IP) Is4() bool {
        return ip.Addr[0] == 0 && ip.Addr[1] == 0 &&
                ip.Addr[2] == 0 && ip.Addr[3] == 0 &&
                ip.Addr[4] == 0 && ip.Addr[5] == 0 &&
@@ -26,19 +31,20 @@ func (ip *IP) Is4() bool {
                ip.Addr[8] == 0 && ip.Addr[9] == 0 &&
                ip.Addr[10] == 0xff && ip.Addr[11] == 0xff
 }
-func (ip *IP) To4() []byte {
+
+// To4 returns either a 4 byte slice for an IPv4 address, or nil if
+// it's not IPv4.
+func (ip IP) To4() []byte {
        if ip.Is4() {
                return ip.Addr[12:16]
        } else {
                return nil
        }
 }
-func (ip *IP) Equal(x *IP) bool {
-       if ip == nil || x == nil {
-               return false
-       }
-       // TODO: this isn't hard, write a more efficient implementation.
-       return ip.IP().Equal(x.IP())
+
+// Equal reports whether ip == x.
+func (ip IP) Equal(x IP) bool {
+       return ip == x
 }
 
 func (ip IP) MarshalText() ([]byte, error) {
@@ -46,11 +52,11 @@ func (ip IP) MarshalText() ([]byte, error) {
 }
 
 func (ip *IP) UnmarshalText(text []byte) error {
-       parsedIP := ParseIP(string(text))
-       if parsedIP == nil {
-               return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", string(text))
+       parsedIP, ok := ParseIP(string(text))
+       if !ok {
+               return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", text)
        }
-       *ip = *parsedIP
+       *ip = parsedIP
        return nil
 }
 
@@ -66,15 +72,14 @@ func IPv4(b0, b1, b2, b3 byte) (ip IP) {
 // ParseIP parses the string representation of an address into an IP.
 //
 // It accepts IPv4 notation such as "1.2.3.4" and IPv6 notation like ""::0".
-// If the string is not a valid IP address, ParseIP returns nil.
-func ParseIP(s string) *IP {
+// The ok result reports whether s was a valid IP and ip is valid.
+func ParseIP(s string) (ip IP, ok bool) {
        netIP := net.ParseIP(s)
        if netIP == nil {
-               return nil
+               return IP{}, false
        }
-       ip := new(IP)
        copy(ip.Addr[:], netIP.To16())
-       return ip
+       return ip, true
 }
 
 // CIDR is a compact IP address and subnet mask.
@@ -85,12 +90,12 @@ type CIDR struct {
 
 // ParseCIDR parses CIDR notation into a CIDR type.
 // Typical CIDR strings look like "192.168.1.0/24".
-func ParseCIDR(s string) (cidr *CIDR, err error) {
+func ParseCIDR(s string) (CIDR, error) {
        netIP, netAddr, err := net.ParseCIDR(s)
        if err != nil {
-               return nil, err
+               return CIDR{}, err
        }
-       cidr = new(CIDR)
+       var cidr CIDR
        copy(cidr.IP.Addr[:], netIP.To16())
        ones, _ := netAddr.Mask.Size()
        cidr.Mask = uint8(ones)
@@ -100,7 +105,7 @@ func ParseCIDR(s string) (cidr *CIDR, err error) {
 
 func (r CIDR) String() string { return r.IPNet().String() }
 
-func (r *CIDR) IPNet() *net.IPNet {
+func (r CIDR) IPNet() *net.IPNet {
        bits := 128
        if r.IP.Is4() {
                bits = 32
@@ -108,10 +113,7 @@ func (r *CIDR) IPNet() *net.IPNet {
        return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)}
 }
 
-func (r *CIDR) Contains(ip *IP) bool {
-       if r == nil || ip == nil {
-               return false
-       }
+func (r CIDR) Contains(ip IP) bool {
        c := int8(r.Mask)
        i := 0
        if r.IP.Is4() {
@@ -145,6 +147,6 @@ func (r *CIDR) UnmarshalText(text []byte) error {
        if err != nil {
                return fmt.Errorf("wgcfg.CIDR: UnmarshalText: %v", err)
        }
-       *r = *cidr
+       *r = cidr
        return nil
 }
index d3682bbdc25ed31b5f79092499501d7adbb8765f..6cd41d3193ecf2bf08b151e519d4f945f8919c88 100644 (file)
@@ -11,18 +11,24 @@ import (
        "golang.zx2c4.com/wireguard/wgcfg"
 )
 
+func parseIP(t testing.TB, ipStr string) wgcfg.IP {
+       t.Helper()
+       ip, ok := wgcfg.ParseIP(ipStr)
+       if !ok {
+               t.Fatalf("failed to parse IP: %q", ipStr)
+       }
+       return ip
+}
+
 func TestCIDRContains(t *testing.T) {
        t.Run("home router test", func(t *testing.T) {
                r, err := wgcfg.ParseCIDR("192.168.0.0/24")
                if err != nil {
                        t.Fatal(err)
                }
-               ip := wgcfg.ParseIP("192.168.0.1")
-               if ip == nil {
-                       t.Fatalf("address failed to parse")
-               }
+               ip := parseIP(t, "192.168.0.1")
                if !r.Contains(ip) {
-                       t.Fatalf("'%s' should contain '%s'", r, ip)
+                       t.Fatalf("%q should contain %q", r, ip)
                }
        })
 
@@ -31,12 +37,9 @@ func TestCIDRContains(t *testing.T) {
                if err != nil {
                        t.Fatal(err)
                }
-               ip := wgcfg.ParseIP("192.168.0.4")
-               if ip == nil {
-                       t.Fatalf("address failed to parse")
-               }
+               ip := parseIP(t, "192.168.0.4")
                if r.Contains(ip) {
-                       t.Fatalf("'%s' should not contain '%s'", r, ip)
+                       t.Fatalf("%q should not contain %q", r, ip)
                }
        })
 
@@ -45,12 +48,9 @@ func TestCIDRContains(t *testing.T) {
                if err != nil {
                        t.Fatal(err)
                }
-               ip := wgcfg.ParseIP("2001:db8:85a3:0:0:8a2e:370:7334")
-               if ip == nil {
-                       t.Fatalf("address failed to parse")
-               }
+               ip := parseIP(t, "2001:db8:85a3:0:0:8a2e:370:7334")
                if r.Contains(ip) {
-                       t.Fatalf("'%s' should not contain '%s'", r, ip)
+                       t.Fatalf("%q should not contain %q", r, ip)
                }
        })
 
@@ -59,12 +59,9 @@ func TestCIDRContains(t *testing.T) {
                if err != nil {
                        t.Fatal(err)
                }
-               ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001")
-               if ip == nil {
-                       t.Fatalf("ParseIP returned nil pointer")
-               }
+               ip := parseIP(t, "2001:db8:1234:0000:0000:0000:0000:0001")
                if !r.Contains(ip) {
-                       t.Fatalf("'%s' should not contain '%s'", r, ip)
+                       t.Fatalf("%q should not contain %q", r, ip)
                }
        })
 
@@ -73,12 +70,9 @@ func TestCIDRContains(t *testing.T) {
                if err != nil {
                        t.Fatal(err)
                }
-               ip := wgcfg.ParseIP("2001:db8:1234:0:190b:0:1982:4")
-               if ip == nil {
-                       t.Fatalf("ParseIP returned nil pointer")
-               }
+               ip := parseIP(t, "2001:db8:1234:0:190b:0:1982:4")
                if r.Contains(ip) {
-                       t.Fatalf("'%s' should not contain '%s'", r, ip)
+                       t.Fatalf("%q should not contain %q", r, ip)
                }
        })
 }
@@ -89,12 +83,9 @@ func BenchmarkCIDRContainsIPv4(b *testing.B) {
                if err != nil {
                        b.Fatal(err)
                }
-               ip := wgcfg.ParseIP("1.2.3.4")
-               if ip == nil {
-                       b.Fatalf("ParseIP returned nil pointer")
-               }
-
+               ip := parseIP(b, "1.2.3.4")
                b.ResetTimer()
+
                for i := 0; i < b.N; i++ {
                        r.Contains(ip)
                }
@@ -105,12 +96,9 @@ func BenchmarkCIDRContainsIPv4(b *testing.B) {
                if err != nil {
                        b.Fatal(err)
                }
-               ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001")
-               if ip == nil {
-                       b.Fatalf("ParseIP returned nil pointer")
-               }
-
+               ip := parseIP(b, "2001:db8:1234:0000:0000:0000:0000:0001")
                b.ResetTimer()
+
                for i := 0; i < b.N; i++ {
                        r.Contains(ip)
                }
index 45a60577a69ff2caf66c865ff3013cd302fc6792..e71d32b1f722dbfc7ed8652cdb0bf8e8d0aaaaa0 100644 (file)
@@ -219,7 +219,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
                                        if err != nil {
                                                return nil, err
                                        }
-                                       conf.Addresses = append(conf.Addresses, *a)
+                                       conf.Addresses = append(conf.Addresses, a)
                                }
                        case "dns":
                                addresses, err := splitList(val)
@@ -227,11 +227,11 @@ func FromWgQuick(s string, name string) (*Config, error) {
                                        return nil, err
                                }
                                for _, address := range addresses {
-                                       a := ParseIP(address)
-                                       if a == nil {
+                                       a, ok := ParseIP(address)
+                                       if !ok {
                                                return nil, &ParseError{"Invalid IP address", address}
                                        }
-                                       conf.DNS = append(conf.DNS, *a)
+                                       conf.DNS = append(conf.DNS, a)
                                }
                        default:
                                return nil, &ParseError{"Invalid key for [Interface] section", key}
@@ -260,7 +260,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
                                        if err != nil {
                                                return nil, err
                                        }
-                                       peer.AllowedIPs = append(peer.AllowedIPs, *a)
+                                       peer.AllowedIPs = append(peer.AllowedIPs, a)
                                }
                        case "persistentkeepalive":
                                p, err := parsePersistentKeepalive(val)
@@ -373,7 +373,7 @@ func Broken_FromUAPI(s string, existingConfig *Config) (*Config, error) {
                                if err != nil {
                                        return nil, err
                                }
-                               peer.AllowedIPs = append(peer.AllowedIPs, *a)
+                               peer.AllowedIPs = append(peer.AllowedIPs, a)
                        case "persistent_keepalive_interval":
                                p, err := parsePersistentKeepalive(val)
                                if err != nil {