]> git.ipfire.org Git - thirdparty/wireguard-go.git/commitdiff
device: simplify allowedips lookup signature
authorJason A. Donenfeld <Jason@zx2c4.com>
Thu, 3 Jun 2021 14:12:29 +0000 (16:12 +0200)
committerJason A. Donenfeld <Jason@zx2c4.com>
Thu, 3 Jun 2021 14:29:43 +0000 (16:29 +0200)
The inliner should handle this for us.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
device/allowedips.go
device/allowedips_rand_test.go
device/allowedips_test.go
device/receive.go
device/send.go

index 7af9fc7b8035c4326b4890b32f3abca73cdaa078..95615ab3592149d89eb4fcb92792d2fc252b59e3 100644 (file)
@@ -285,14 +285,15 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
        }
 }
 
-func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(address []byte) *Peer {
        table.mutex.RLock()
        defer table.mutex.RUnlock()
-       return table.IPv4.lookup(address)
-}
-
-func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
-       table.mutex.RLock()
-       defer table.mutex.RUnlock()
-       return table.IPv6.lookup(address)
+       switch len(address) {
+       case net.IPv6len:
+               return table.IPv6.lookup(address)
+       case net.IPv4len:
+               return table.IPv4.lookup(address)
+       default:
+               panic(errors.New("looking up unknown address type"))
+       }
 }
index c5f80fe937859ea72a60c57287b0de181a377ce3..8d1e6333849e039ab9c6faf111de3a0120af17cf 100644 (file)
@@ -108,7 +108,7 @@ func TestTrieRandom(t *testing.T) {
                        var addr4 [4]byte
                        rand.Read(addr4[:])
                        peer1 := slow4.Lookup(addr4[:])
-                       peer2 := allowedIPs.LookupIPv4(addr4[:])
+                       peer2 := allowedIPs.Lookup(addr4[:])
                        if peer1 != peer2 {
                                t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
                        }
@@ -116,7 +116,7 @@ func TestTrieRandom(t *testing.T) {
                        var addr6 [16]byte
                        rand.Read(addr6[:])
                        peer1 = slow6.Lookup(addr6[:])
-                       peer2 = allowedIPs.LookupIPv6(addr6[:])
+                       peer2 = allowedIPs.Lookup(addr6[:])
                        if peer1 != peer2 {
                                t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
                        }
index cbd32cc5fea7ad847cd5880702dc84ec16a7a30d..7701cde450d864c6161d6a0c031ea2ef4a3b3284 100644 (file)
@@ -102,14 +102,14 @@ func TestTrieIPv4(t *testing.T) {
        }
 
        assertEQ := func(peer *Peer, a, b, c, d byte) {
-               p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
+               p := allowedIPs.Lookup([]byte{a, b, c, d})
                if p != peer {
                        t.Error("Assert EQ failed")
                }
        }
 
        assertNEQ := func(peer *Peer, a, b, c, d byte) {
-               p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
+               p := allowedIPs.Lookup([]byte{a, b, c, d})
                if p == peer {
                        t.Error("Assert NEQ failed")
                }
@@ -208,7 +208,7 @@ func TestTrieIPv6(t *testing.T) {
                addr = append(addr, expand(b)...)
                addr = append(addr, expand(c)...)
                addr = append(addr, expand(d)...)
-               p := allowedIPs.LookupIPv6(addr)
+               p := allowedIPs.Lookup(addr)
                if p != peer {
                        t.Error("Assert EQ failed")
                }
index 11822464a94592fd247ba27ea781f4c74dc6741c..58574810f812eeea7b03646a22ad2201152ecc75 100644 (file)
@@ -447,7 +447,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                        }
                        elem.packet = elem.packet[:length]
                        src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
-                       if device.allowedips.LookupIPv4(src) != peer {
+                       if device.allowedips.Lookup(src) != peer {
                                device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
                                goto skip
                        }
@@ -464,7 +464,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
                        }
                        elem.packet = elem.packet[:length]
                        src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
-                       if device.allowedips.LookupIPv6(src) != peer {
+                       if device.allowedips.Lookup(src) != peer {
                                device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
                                goto skip
                        }
index a4f07e4470f50f8ad718ca244b90f7feb715b1c2..b05c69e94cc6a8d8dbe68b9f3a239b7d05178d89 100644 (file)
@@ -254,14 +254,14 @@ func (device *Device) RoutineReadFromTUN() {
                                continue
                        }
                        dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
-                       peer = device.allowedips.LookupIPv4(dst)
+                       peer = device.allowedips.Lookup(dst)
 
                case ipv6.Version:
                        if len(elem.packet) < ipv6.HeaderLen {
                                continue
                        }
                        dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
-                       peer = device.allowedips.LookupIPv6(dst)
+                       peer = device.allowedips.Lookup(dst)
 
                default:
                        device.log.Verbosef("Received packet with unknown IP version")