]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
wgcfg: add fast CIDR.Contains implementation
authorTyler Kropp <kropptyler@gmail.com>
Tue, 3 Mar 2020 00:41:28 +0000 (19:41 -0500)
committerDavid Crawshaw <david@zentus.com>
Mon, 30 Mar 2020 22:32:57 +0000 (09:32 +1100)
Signed-off-by: Tyler Kropp <kropptyler@gmail.com>
wgcfg/ip.go
wgcfg/ip_test.go [new file with mode: 0644]

index ecf5faff71c776f76d9e84d01b1335cd4eeb36f5..7541d185df3b45a331f138d59304c13e8199cd13 100644 (file)
@@ -2,6 +2,7 @@ package wgcfg
 
 import (
        "fmt"
+       "math"
        "net"
 )
 
@@ -106,12 +107,33 @@ func (r *CIDR) IPNet() *net.IPNet {
        }
        return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)}
 }
+
 func (r *CIDR) Contains(ip *IP) bool {
        if r == nil || ip == nil {
                return false
        }
-       // TODO: this isn't hard, write a more efficient implementation.
-       return r.IPNet().Contains(ip.IP())
+       c := int8(r.Mask)
+       i := 0
+       if r.IP.Is4() {
+               i = 12
+               if ip.Is6() {
+                       return false
+               }
+       }
+       for ; i < 16 && c > 0; i++ {
+               var x uint8
+               if c < 8 {
+                       x = 8 - uint8(c)
+               }
+               m := uint8(math.MaxUint8) >> x << x
+               a := r.IP.Addr[i] & m
+               b := ip.Addr[i] & m
+               if a != b {
+                       return false
+               }
+               c -= 8
+       }
+       return true
 }
 
 func (r CIDR) MarshalText() ([]byte, error) {
diff --git a/wgcfg/ip_test.go b/wgcfg/ip_test.go
new file mode 100644 (file)
index 0000000..d3682bb
--- /dev/null
@@ -0,0 +1,118 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package wgcfg_test
+
+import (
+       "testing"
+
+       "golang.zx2c4.com/wireguard/wgcfg"
+)
+
+func TestCIDRContains(t *testing.T) {
+       t.Run("home router test", func(t *testing.T) {
+               r, err := wgcfg.ParseCIDR("192.168.0.0/24")
+               if err != nil {
+                       t.Fatal(err)
+               }
+               ip := wgcfg.ParseIP("192.168.0.1")
+               if ip == nil {
+                       t.Fatalf("address failed to parse")
+               }
+               if !r.Contains(ip) {
+                       t.Fatalf("'%s' should contain '%s'", r, ip)
+               }
+       })
+
+       t.Run("IPv4 outside network", func(t *testing.T) {
+               r, err := wgcfg.ParseCIDR("192.168.0.0/30")
+               if err != nil {
+                       t.Fatal(err)
+               }
+               ip := wgcfg.ParseIP("192.168.0.4")
+               if ip == nil {
+                       t.Fatalf("address failed to parse")
+               }
+               if r.Contains(ip) {
+                       t.Fatalf("'%s' should not contain '%s'", r, ip)
+               }
+       })
+
+       t.Run("IPv4 does not contain IPv6", func(t *testing.T) {
+               r, err := wgcfg.ParseCIDR("192.168.0.0/24")
+               if err != nil {
+                       t.Fatal(err)
+               }
+               ip := wgcfg.ParseIP("2001:db8:85a3:0:0:8a2e:370:7334")
+               if ip == nil {
+                       t.Fatalf("address failed to parse")
+               }
+               if r.Contains(ip) {
+                       t.Fatalf("'%s' should not contain '%s'", r, ip)
+               }
+       })
+
+       t.Run("IPv6 inside network", func(t *testing.T) {
+               r, err := wgcfg.ParseCIDR("2001:db8:1234::/48")
+               if err != nil {
+                       t.Fatal(err)
+               }
+               ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001")
+               if ip == nil {
+                       t.Fatalf("ParseIP returned nil pointer")
+               }
+               if !r.Contains(ip) {
+                       t.Fatalf("'%s' should not contain '%s'", r, ip)
+               }
+       })
+
+       t.Run("IPv6 outside network", func(t *testing.T) {
+               r, err := wgcfg.ParseCIDR("2001:db8:1234:0:190b:0:1982::/126")
+               if err != nil {
+                       t.Fatal(err)
+               }
+               ip := wgcfg.ParseIP("2001:db8:1234:0:190b:0:1982:4")
+               if ip == nil {
+                       t.Fatalf("ParseIP returned nil pointer")
+               }
+               if r.Contains(ip) {
+                       t.Fatalf("'%s' should not contain '%s'", r, ip)
+               }
+       })
+}
+
+func BenchmarkCIDRContainsIPv4(b *testing.B) {
+       b.Run("IPv4", func(b *testing.B) {
+               r, err := wgcfg.ParseCIDR("192.168.1.0/24")
+               if err != nil {
+                       b.Fatal(err)
+               }
+               ip := wgcfg.ParseIP("1.2.3.4")
+               if ip == nil {
+                       b.Fatalf("ParseIP returned nil pointer")
+               }
+
+               b.ResetTimer()
+               for i := 0; i < b.N; i++ {
+                       r.Contains(ip)
+               }
+       })
+
+       b.Run("IPv6", func(b *testing.B) {
+               r, err := wgcfg.ParseCIDR("2001:db8:1234::/48")
+               if err != nil {
+                       b.Fatal(err)
+               }
+               ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001")
+               if ip == nil {
+                       b.Fatalf("ParseIP returned nil pointer")
+               }
+
+               b.ResetTimer()
+               for i := 0; i < b.N; i++ {
+                       r.Contains(ip)
+               }
+       })
+}