]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Support create bpf map with BPF_MAP_TYPE_LPM_TRIE type
authorY7n05h <Y7n05h@protonmail.com>
Wed, 13 Apr 2022 19:31:37 +0000 (03:31 +0800)
committerY7n05h <Y7n05h@protonmail.com>
Thu, 21 Apr 2022 08:55:21 +0000 (16:55 +0800)
Signed-off-by: Y7n05h <Y7n05h@protonmail.com>
pdns/bpf-filter.cc
pdns/bpf-filter.hh

index 2323a05f1ee839a65a5cfa7d853ac22562543c44..c24155c8edadb1dfc0c66716cdbc8711f7ef4028 100644 (file)
@@ -20,6 +20,7 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
 #include "bpf-filter.hh"
+#include "iputils.hh"
 
 #ifdef HAVE_EBPF
 
@@ -81,7 +82,7 @@ static void bpf_check_map_sizes(int fd, uint32_t expectedKeySize, uint32_t expec
 }
 
 int bpf_create_map(enum bpf_map_type map_type, int key_size, int value_size,
-                   int max_entries)
+                   int max_entries, int map_flags)
 {
   union bpf_attr attr;
   memset(&attr, 0, sizeof(attr));
@@ -89,6 +90,7 @@ int bpf_create_map(enum bpf_map_type map_type, int key_size, int value_size,
   attr.key_size = key_size;
   attr.value_size = value_size;
   attr.max_entries = max_entries;
+  attr.map_flags = map_flags;
   return syscall(SYS_bpf, BPF_MAP_CREATE, &attr, sizeof(attr));
 }
 
@@ -200,7 +202,7 @@ BPFFilter::Map::Map(const BPFFilter::MapConfiguration& config, BPFFilter::MapFor
 {
   if (d_config.d_type == BPFFilter::MapType::Filters) {
     /* special case, this is a map of eBPF programs */
-    d_fd = FDWrapper(bpf_create_map(BPF_MAP_TYPE_PROG_ARRAY, sizeof(uint32_t), sizeof(uint32_t), d_config.d_maxItems));
+    d_fd = FDWrapper(bpf_create_map(BPF_MAP_TYPE_PROG_ARRAY, sizeof(uint32_t), sizeof(uint32_t), d_config.d_maxItems, 0));
     if (d_fd.getHandle() == -1) {
       throw std::runtime_error("Error creating a BPF program map of size " + std::to_string(d_config.d_maxItems) + ": " + stringerror());
     }
@@ -208,6 +210,8 @@ BPFFilter::Map::Map(const BPFFilter::MapConfiguration& config, BPFFilter::MapFor
   else {
     int keySize = 0;
     int valueSize = 0;
+    int flags = 0;
+    bpf_map_type type = BPF_MAP_TYPE_HASH;
     if (format == MapFormat::Legacy) {
       switch (d_config.d_type) {
       case MapType::IPv4:
@@ -236,6 +240,18 @@ BPFFilter::Map::Map(const BPFFilter::MapConfiguration& config, BPFFilter::MapFor
         keySize = sizeof(KeyV6);
         valueSize = sizeof(CounterAndActionValue);
         break;
+      case MapType::CIDR4:
+        keySize = sizeof(CIDR4);
+        valueSize = sizeof(CounterAndActionValue);
+        flags = BPF_F_NO_PREALLOC;
+        type = BPF_MAP_TYPE_LPM_TRIE;
+        break;
+      case MapType::CIDR6:
+        keySize = sizeof(CIDR6);
+        valueSize = sizeof(CounterAndActionValue);
+        flags = BPF_F_NO_PREALLOC;
+        type = BPF_MAP_TYPE_LPM_TRIE;
+        break;
       case MapType::QNames:
         keySize = sizeof(QNameAndQTypeKey);
         valueSize = sizeof(CounterAndActionValue);
@@ -251,21 +267,35 @@ BPFFilter::Map::Map(const BPFFilter::MapConfiguration& config, BPFFilter::MapFor
       if (d_fd.getHandle() != -1) {
         /* sanity checks: key and value size */
         bpf_check_map_sizes(d_fd.getHandle(), keySize, valueSize);
-
-        if (d_config.d_type == MapType::IPv4) {
+        switch (d_config.d_type) {
+        case MapType::IPv4: {
           uint32_t key = 0;
           while (bpf_get_next_key(d_fd.getHandle(), &key, &key) == 0) {
             ++d_count;
           }
-        }
-        else if (d_config.d_type == MapType::IPv6) {
+        } break;
+        case MapType::IPv6: {
           KeyV6 key;
           memset(&key, 0, sizeof(key));
           while (bpf_get_next_key(d_fd.getHandle(), &key, &key) == 0) {
             ++d_count;
           }
-        }
-        else if (d_config.d_type == MapType::QNames) {
+        } break;
+        case MapType::CIDR4: {
+          CIDR4 key;
+          memset(&key, 0, sizeof(key));
+          while (bpf_get_next_key(d_fd.getHandle(), &key, &key) == 0) {
+            ++d_count;
+          }
+        } break;
+        case MapType::CIDR6: {
+          CIDR6 key;
+          memset(&key, 0, sizeof(key));
+          while (bpf_get_next_key(d_fd.getHandle(), &key, &key) == 0) {
+            ++d_count;
+          }
+        } break;
+        case MapType::QNames: {
           if (format == MapFormat::Legacy) {
             QNameKey key;
             memset(&key, 0, sizeof(key));
@@ -280,12 +310,16 @@ BPFFilter::Map::Map(const BPFFilter::MapConfiguration& config, BPFFilter::MapFor
               ++d_count;
             }
           }
+        } break;
+
+        default:
+          throw std::runtime_error("Unsupported eBPF map type: " + std::to_string(static_cast<uint8_t>(d_config.d_type)));
         }
       }
     }
 
     if (d_fd.getHandle() == -1) {
-      d_fd = FDWrapper(bpf_create_map(BPF_MAP_TYPE_HASH, keySize, valueSize, static_cast<int>(d_config.d_maxItems)));
+      d_fd = FDWrapper(bpf_create_map(type, keySize, valueSize, static_cast<int>(d_config.d_maxItems), flags));
       if (d_fd.getHandle() == -1) {
         throw std::runtime_error("Error creating a BPF map of size " + std::to_string(d_config.d_maxItems) + ": " + stringerror());
       }
@@ -470,6 +504,101 @@ void BPFFilter::unblock(const ComboAddress& addr)
   }
 }
 
+void BPFFilter::block(const Netmask& addr, BPFFilter::MatchAction action)
+{
+  CounterAndActionValue value;
+
+  int res = 0;
+  if (addr.isIPv4()) {
+    CIDR4 key(addr);
+    auto maps = d_maps.lock();
+    auto& map = maps->d_cidr4;
+    if (map.d_count >= map.d_config.d_maxItems) {
+      throw std::runtime_error("Table full when trying to block " + addr.toString());
+    }
+
+    res = bpf_lookup_elem(map.d_fd.getHandle(), &key, &value);
+    if (res != -1 && value.action == action) {
+      throw std::runtime_error("Trying to block an already blocked address: " + addr.toString());
+    }
+
+    value.counter = 0;
+    value.action = action;
+
+    res = bpf_update_elem(map.d_fd.getHandle(), &key, &value, BPF_NOEXIST);
+    if (res == 0) {
+      ++map.d_count;
+    }
+  }
+  else if (addr.isIPv6()) {
+    CIDR6 key(addr);
+
+    auto maps = d_maps.lock();
+    auto& map = maps->d_cidr6;
+    if (map.d_count >= map.d_config.d_maxItems) {
+      throw std::runtime_error("Table full when trying to block " + addr.toString());
+    }
+
+    res = bpf_lookup_elem(map.d_fd.getHandle(), &key, &value);
+    if (res != -1 && value.action == action) {
+      throw std::runtime_error("Trying to block an already blocked address: " + addr.toString());
+    }
+
+    value.counter = 0;
+    value.action = action;
+
+    res = bpf_update_elem(map.d_fd.getHandle(), &key, &value, BPF_NOEXIST);
+    if (res == 0) {
+      map.d_count++;
+    }
+  }
+
+  if (res != 0) {
+    throw std::runtime_error("Error adding blocked address " + addr.toString() + ": " + stringerror());
+  }
+}
+
+void BPFFilter::unblock(const Netmask& addr)
+{
+  int res = 0;
+  CounterAndActionValue value;
+  value.counter = 0;
+  value.action = MatchAction::Pass;
+  if (addr.isIPv4()) {
+    CIDR4 key(addr);
+    auto maps = d_maps.lock();
+    auto& map = maps->d_cidr4;
+    res = bpf_delete_elem(map.d_fd.getHandle(), &key);
+    if (res == 0) {
+      --map.d_count;
+    }
+    else {
+      res = bpf_update_elem(map.d_fd.getHandle(), &key, &value, BPF_NOEXIST);
+      if (res == 0)
+        ++map.d_count;
+    }
+  }
+  else if (addr.isIPv6()) {
+    CIDR6 key(addr);
+
+    auto maps = d_maps.lock();
+    auto& map = maps->d_cidr6;
+    res = bpf_delete_elem(map.d_fd.getHandle(), &key);
+    if (res == 0) {
+      --map.d_count;
+    }
+    else {
+      res = bpf_update_elem(map.d_fd.getHandle(), &key, &value, BPF_NOEXIST);
+      if (res == 0)
+        ++map.d_count;
+    }
+  }
+
+  if (res != 0) {
+    throw std::runtime_error("Error removing blocked address " + addr.toString() + ": " + stringerror());
+  }
+}
+
 void BPFFilter::block(const DNSName& qname, BPFFilter::MatchAction action, uint16_t qtype)
 {
   CounterAndActionValue cadvalue;
@@ -720,6 +849,15 @@ void BPFFilter::unblock(const DNSName&, uint16_t)
   throw std::runtime_error("eBPF support not enabled");
 }
 
+void BPFFilter::block(const Netmask&, BPFFilter::MatchAction)
+{
+  throw std::runtime_error("eBPF support not enabled");
+}
+void BPFFilter::unblock(const Netmask&)
+{
+  throw std::runtime_error("eBPF support not enabled");
+}
+
 std::vector<std::pair<ComboAddress, uint64_t> > BPFFilter::getAddrStats()
 {
   std::vector<std::pair<ComboAddress, uint64_t> > result;
index ed1a4024005d1261e8e75fff92dc58ca550c9b60..5fc991e47a596ce2ace549b0cc613b6ed675bb85 100644 (file)
@@ -26,6 +26,8 @@
 
 #include "iputils.hh"
 #include "lock.hh"
+#include <netinet/in.h>
+#include <stdexcept>
 
 class BPFFilter
 {
@@ -34,7 +36,9 @@ public:
     IPv4,
     IPv6,
     QNames,
-    Filters
+    Filters,
+    CIDR4,
+    CIDR6
   };
 
   enum class MapFormat : uint8_t {
@@ -65,8 +69,10 @@ public:
   void addSocket(int sock);
   void removeSocket(int sock);
   void block(const ComboAddress& addr, MatchAction action);
+  void block(const Netmask& address, BPFFilter::MatchAction action);
   void block(const DNSName& qname, MatchAction action, uint16_t qtype=255);
   void unblock(const ComboAddress& addr);
+  void unblock(const Netmask& address);
   void unblock(const DNSName& qname, uint16_t qtype=255);
 
   std::vector<std::pair<ComboAddress, uint64_t> > getAddrStats();
@@ -94,6 +100,8 @@ private:
   {
     Map d_v4;
     Map d_v6;
+    Map d_cidr4;
+    Map d_cidr6;
     Map d_qnames;
     /* The qname filter program held in d_qnamefilter is
        stored in an eBPF map, so we can call it from the
@@ -107,7 +115,34 @@ private:
   FDWrapper d_mainfilter;
   /* qname filtering program */
   FDWrapper d_qnamefilter;
-
+  struct CIDR4
+  {
+    uint32_t cidr;
+    struct in_addr addr;
+    explicit CIDR4(Netmask address)
+    {
+      if (!address.isIPv4()) {
+        throw std::runtime_error("ComboAddress is invalid");
+      }
+      addr = address.getNetwork().sin4.sin_addr;
+      cidr = address.getBits();
+    }
+    CIDR4() = default;
+  };
+  struct CIDR6
+  {
+    uint32_t cidr;
+    struct in6_addr addr;
+    CIDR6(Netmask address)
+    {
+      if (address.isIPv6()) {
+        throw std::runtime_error("ComboAddress is invalid");
+      }
+      addr = address.getNetwork().sin6.sin6_addr;
+      cidr = address.getBits();
+    }
+    CIDR6() = default;
+  };
   /* whether the maps are in the 'old' format, which we need
      to keep to prevent going over the 4k instructions per eBPF
      program limit in kernels < 5.2, as well as the complexity limit: