]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
dnamelimiting: Add NN packet classification for limited packets
authorHynek Šabacký <hynek.sabacky@nic.cz>
Thu, 27 Mar 2025 15:16:01 +0000 (16:16 +0100)
committerVladimír Čunát <vladimir.cunat@nic.cz>
Fri, 19 Sep 2025 08:21:27 +0000 (10:21 +0200)
daemon/dnamelimiting.c
daemon/dnamelimiting.test/tests.c
daemon/dnamelimiting.test/tests.inc.c
tests/unit/meson.build

index 3dc2dff55b38da8b00faba2981dec7cd1413e68d..3d955bca15e11774e51408cadcde537b1054d3aa 100644 (file)
@@ -4,6 +4,7 @@
 
 #include <stdatomic.h>
 #include "daemon/dnamelimiting.h"
+#include "daemon/libblcnn.h"
 #include "lib/mmapped.h"
 #include "lib/utils.h"
 #include "lib/resolve.h"
@@ -18,6 +19,8 @@
 #define V6_PREFIXES_CNT (sizeof(V6_PREFIXES) / sizeof(*V6_PREFIXES))
 #define MAX_PREFIXES_CNT ((V4_PREFIXES_CNT > V6_PREFIXES_CNT) ? V4_PREFIXES_CNT : V6_PREFIXES_CNT)
 
+#define DNAME_SCALE_FACTOR 25
+
 struct dnamelimiting {
        size_t capacity;
        uint32_t instant_limit;
@@ -26,6 +29,7 @@ struct dnamelimiting {
        uint16_t slip;
        bool dry_run;
        bool using_avx2;
+       TorchModule net;
        _Atomic uint32_t log_time;
        kru_price_t v4_prices[V4_PREFIXES_CNT];
        kru_price_t v6_prices[V6_PREFIXES_CNT];
@@ -99,6 +103,9 @@ int dnamelimiting_init(const char *mmap_file, size_t capacity, uint32_t instant_
                        dnamelimiting->v6_prices[i] = base_price / V6_RATE_MULT[i];
                }
 
+               dnamelimiting->net = load_model();
+               if (!dnamelimiting->net) goto fail;
+
                ret = mmapped_init_continue(&dnamelimiting_mmapped);
                if (ret != 0) goto fail;
 
@@ -118,6 +125,7 @@ fail:
 
 void dnamelimiting_deinit(void)
 {
+       free_model(dnamelimiting->net);
        mmapped_deinit(&dnamelimiting_mmapped);
        dnamelimiting = NULL;
 }
@@ -130,12 +138,17 @@ bool dnamelimiting_request_begin(struct kr_request *req)
                return false;  // don't consider internal requests
        if (req->qsource.price_factor16 == 0)
                return false;  // whitelisted
+       if (!req->current_query)
+               return false;
+       if (!req->current_query->sname)
+               return false;
 
        // We only do this on pure UDP.  (also TODO if cookies get implemented)
        const bool ip_validated = req->qsource.flags.tcp || req->qsource.flags.tls;
        if (ip_validated) return false;
 
        const uint32_t time_now = kr_now();
+       uint32_t price_scale_factor = (strlen((char *)req->current_query->sname) << 16)/ DNAME_SCALE_FACTOR;
 
        // classify
        _Alignas(16) uint8_t key[16] = {0, };
@@ -147,8 +160,8 @@ bool dnamelimiting_request_begin(struct kr_request *req)
                // compute adjusted prices, using standard rounding
                kru_price_t prices[V6_PREFIXES_CNT];
                for (int i = 0; i < V6_PREFIXES_CNT; ++i) {
-                       prices[i] = (req->qsource.price_factor16
-                                       * (uint64_t)dnamelimiting->v6_prices[i] + (1<<15)) >> 16;
+                       prices[i] = (req->qsource.price_factor16 * (uint64_t)price_scale_factor
+                                       * (uint64_t)dnamelimiting->v6_prices[i] + (1<<15)) >> 32;
                }
                limited_prefix = KRU.limited_multi_prefix_or((struct kru *)dnamelimiting->kru, time_now,
                                1, key, V6_PREFIXES, prices, V6_PREFIXES_CNT, NULL);
@@ -159,13 +172,19 @@ bool dnamelimiting_request_begin(struct kr_request *req)
                // compute adjusted prices, using standard rounding
                kru_price_t prices[V4_PREFIXES_CNT];
                for (int i = 0; i < V4_PREFIXES_CNT; ++i) {
-                       prices[i] = (req->qsource.price_factor16
-                                       * (uint64_t)dnamelimiting->v4_prices[i] + (1<<15)) >> 16;
+                       prices[i] = (req->qsource.price_factor16 * (uint64_t)price_scale_factor
+                                       * (uint64_t)dnamelimiting->v4_prices[i] + (1<<15)) >> 32;
                }
                limited_prefix = KRU.limited_multi_prefix_or((struct kru *)dnamelimiting->kru, time_now,
                                0, key, V4_PREFIXES, prices, V4_PREFIXES_CNT, NULL);
        }
        if (!limited_prefix) return false;  // not limited
 
+       uint8_t *packet = req->qsource.packet->wire;
+       size_t packet_size = req->qsource.size;
+
+       float ret = predict_packet(dnamelimiting->net, packet, packet_size);
+       if (ret > 0.95)
+               printf("Potentially malicious packet (%f %%)\n", (ret - 0.95) * 100 * 20);
        return true;
 }
index 05b14096fdd109d9f7cc7ce38fb9f5cad6cebdd8..1dae27691125d96584e068e5b5bd62c69df5d47e 100644 (file)
@@ -28,16 +28,23 @@ uint32_t _count_test(int expected_passing, uint32_t dname_length, int addr_famil
 {
        uint32_t max_queries = expected_passing > 0 ? 2 * expected_passing : -expected_passing;
        struct sockaddr_storage addr;
-       uint8_t wire[KNOT_WIRE_MIN_PKTSIZE] = { 0 };
-       knot_pkt_t answer = { .wire = wire };
-       unsigned char sname[dname_length + 1];
-       memset(sname, 'a', dname_length);
-       sname[dname_length] = '\0';
-       struct kr_query query = { .sname = sname}; 
-       uint32_t price_factor16 = (strlen((char *)query.sname) << 16) / 25;
+
+       uint8_t wire_answer[KNOT_WIRE_MIN_PKTSIZE] = { 0 };
+       knot_pkt_t answer = { .wire = wire_answer };
+
+       size_t query_packet_size = get_packet_size(dname_length);
+       uint8_t wire_query[query_packet_size];
+       unsigned char dname[256];
+       create_dns_query(dname_length, wire_query, dname);
+       
+       knot_pkt_t query_packet = { .wire = wire_query};
+       struct kr_query query = { .sname = dname}; 
        struct kr_request req = {
                .qsource.addr = (struct sockaddr *) &addr,
-               .qsource.price_factor16 = price_factor16,
+               .qsource.price_factor16 = (1 << 16),
+               .qsource.packet = &query_packet,
+               .qsource.size = query_packet_size,
+               
                .answer = &answer,
                .current_query = &query
        };
@@ -48,7 +55,6 @@ uint32_t _count_test(int expected_passing, uint32_t dname_length, int addr_famil
                (void)snprintf(addr_str, sizeof(addr_str), addr_format,
                                i % (ip_max_value - ip_min_value + 1) + ip_min_value,
                                i / (ip_max_value - ip_min_value + 1) % 256);
-               //printf("string: %s\n", addr_str);
                kr_straddr_socket_set((struct sockaddr *) &addr, addr_str, 0);
                if (dnamelimiting_request_begin(&req)) {
                        cnt = i;
index ca2ffb776223128c93bd56f9797305666eec4554..8c19af01dab135154f619f1f69ec2773d8912194 100644 (file)
@@ -18,6 +18,8 @@
 #include <sched.h>
 #include <stdio.h>
 #include <stdatomic.h>
+#include <string.h>
+#include <time.h>
 
 #include "tests/unit/test.h"
 #include "libdnssec/crypto.h"
@@ -76,6 +78,78 @@ struct kru_avx2 {
 uint64_t fakeclock_tick = 0;
 uint64_t fakeclock_start = 0;
 
+#define DNS_QUERY_TYPE_A 0x0001
+#define DNS_QUERY_CLASS_IN 0x0001
+
+typedef struct {
+       uint16_t id;
+       uint16_t flags;
+       uint16_t qdcount;
+       uint16_t ancount;
+       uint16_t nscount;
+       uint16_t arcount;
+} __attribute__((packed)) dns_header_t;
+
+size_t get_packet_size(uint32_t dname_length) {
+       uint8_t header_size = sizeof(dns_header_t);
+       uint8_t encoded_dname_size = dname_length + 1;
+       uint8_t query_flags_size = 4;
+       return header_size + encoded_dname_size + query_flags_size;
+}
+
+void create_domain_name(uint8_t *dname, size_t domain_length, unsigned char *sname) {
+       uint8_t encoded_size = 0;
+       uint8_t sname_size = 0;
+       int remaining_length = domain_length - 4;
+
+       while (remaining_length > 0) {
+               uint8_t max_label_length = (remaining_length > 64) ? 64 : remaining_length;
+               uint8_t label_length = 2 + (rand() % (max_label_length - 1));
+               while (remaining_length - label_length == 1)
+                       label_length = 2 + (rand() % (max_label_length - 1));
+
+               dname[encoded_size++] = label_length;
+
+               for (uint8_t i = 0; i < label_length - 1; i++) {
+                       char rl = 'a' + (rand() % 26);
+                       sname[sname_size++] = rl;
+                       dname[encoded_size++] = rl;
+               }
+               sname[sname_size++] = '.';
+               remaining_length -= label_length;
+       }
+
+       dname[encoded_size++] = 2;
+       dname[encoded_size++] = 'c';
+       dname[encoded_size++] = 'z';
+       dname[encoded_size++] = 0;
+
+       sname[sname_size++] = 'c';
+       sname[sname_size++] = 'z';
+       sname[sname_size++] = '.';
+       sname[sname_size++] = '\0';
+}
+
+void create_dns_query(uint32_t domain_length, uint8_t *dest, unsigned char* dname) {
+       size_t domain_len = domain_length + 1;
+
+       dns_header_t *header = (dns_header_t *)dest;
+       header->id = htons(0x1234);
+       header->flags = htons(0x0100);
+       header->qdcount = htons(1);
+       header->ancount = 0;
+       header->nscount = 0;
+       header->arcount = 0;
+
+       uint8_t *qname = dest + sizeof(dns_header_t);
+       create_domain_name(qname, domain_len, dname);
+
+       uint16_t *qtype = (uint16_t *)(qname + domain_len);
+       uint16_t *qclass = (uint16_t *)(qtype + 1);
+       *qtype = htons(DNS_QUERY_TYPE_A);
+       *qclass = htons(DNS_QUERY_CLASS_IN);
+}
+
 void fakeclock_init(void)
 {
        fakeclock_start = kr_now();
@@ -126,6 +200,8 @@ static void test_rrl_avx2(void **state) {
 
 int main(int argc, char *argv[])
 {
+       srand(time(NULL));
+       
        assert(KRU_GENERIC.initialize != KRU_AVX2.initialize);
        if (KRU.initialize == KRU_AVX2.initialize) {
                const UnitTest tests[] = {
index 253bf8d26805480f8ab895976924954151db02a8..074dfbf066e0c412b1a6408dc3b166931aba4317 100644 (file)
@@ -19,6 +19,12 @@ foreach unit_test : unit_tests
   extra_link_args = []
   extra_link_libs = []
 
+  if unit_test[0] == 'dnamelimiting'
+    extra_link_args += ['-lstdc++', '-lc10']
+    extra_link_libs += [blcnn_lib]
+  endif
+
+
   exec_test = executable(
     unit_test[0],
     unit_test[1],