]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
WIP
authorVladimír Čunát <vladimir.cunat@nic.cz>
Sat, 20 Sep 2025 09:19:44 +0000 (11:19 +0200)
committerVladimír Čunát <vladimir.cunat@nic.cz>
Thu, 9 Oct 2025 09:04:34 +0000 (11:04 +0200)
struct dns_tunnel_filter:
 - TorchModule can't be in the mmapped structure, as it's a pointer
 - drop unneeded parts

daemon/lua/kres-gen-33.lua
daemon/lua/kres-gen.sh
daemon/main.c
modules/dns_tunnel_filter/dns_tunnel_filter.c
modules/dns_tunnel_filter/dns_tunnel_filter.h [deleted file]
modules/dns_tunnel_filter/libblcnn.cpp
modules/dns_tunnel_filter/libblcnn.h
modules/policy/policy.lua

index 696c201e45b74b2f947e229353d7d10756942f2b..5aef601dd98c7da9f5b84af594224d2beb7610c9 100644 (file)
@@ -644,8 +644,6 @@ _Bool ratelimiting_request_begin(struct kr_request *);
 int ratelimiting_init(const char *, size_t, uint32_t, uint32_t, uint16_t, uint32_t, _Bool);
 int defer_init(const char *, uint32_t, int);
 void defer_set_price_factor16(struct kr_request *, uint32_t);
-int dns_tunnel_filter_init(const char *, size_t, uint32_t, uint32_t, uint16_t, uint32_t, _Bool);
-_Bool dns_tunnel_filter_request_begin(struct kr_request *);
 struct engine {
        char _stub[];
 };
index 41b0b7bf721edd9514780162ccb592dc82e0e611..ec9f47694a32e5a7e2911e3c64f4d5f24fc3f668 100755 (executable)
@@ -361,8 +361,6 @@ ${CDEFS} ${KRESD} functions <<-EOF
        ratelimiting_init
        defer_init
        defer_set_price_factor16
-       dns_tunnel_filter_request_begin
-       dns_tunnel_filter_init
 EOF
 
 echo "struct engine" | ${CDEFS} ${KRESD} types | sed '/module_array_t/,$ d'
index b1ce8a2d9b479caf919ecc905c15394fc2a36754..3925a15c18c80b82b0c319d7951a418d6038f40b 100644 (file)
@@ -630,17 +630,6 @@ int main(int argc, char **argv)
                goto cleanup;
        }
 
-       #if 0
-       if (!dns_tunnel_filter_initialized) {
-               kr_log_warning(TUNNEL, "Tunneling filter not initialized from Lua, using hardcoded default.\n");
-               ret = dns_tunnel_filter_init("dns_tunnel_filter", (1 << 20), (1 << 8), (1 << 17), 0, 0, false);
-               if (ret) {
-                       ret = EXIT_FAILURE;
-                       goto cleanup;
-               }
-       }
-       #endif
-
        ret = kr_rules_init_ensure();
        if (ret) {
                kr_log_error(RULES, "failed to initialize policy rule engine: %s\n",
@@ -670,8 +659,7 @@ int main(int argc, char **argv)
 cleanup:/* Cleanup. */
        network_unregister();
 
-       //dns_tunnel_filter_deinit();
-       ratelimiting_deinit();  
+       ratelimiting_deinit();
        kr_resolver_deinit();
        worker_deinit();
        engine_deinit();
index 3debd17d2e118cd97709e4f0bd0e057037bbc662..a52825513b01b86c69bf54e74ad0599602b5cfbc 100644 (file)
@@ -3,8 +3,8 @@
 */
 
 #include <stdatomic.h>
-#include "dns_tunnel_filter.h"
 #include "libblcnn.h"
+#include "lib/kru.h"
 #include "lib/kru-utils.h"
 #include "lib/mmapped.h"
 #include "lib/utils.h"
@@ -18,24 +18,29 @@ struct dns_tunnel_filter {
        size_t capacity;
        uint32_t instant_limit;
        uint32_t rate_limit;
-       uint32_t log_period;
-       uint16_t slip;
-       bool dry_run;
        bool using_avx2;
-       TorchModule net;
-       _Atomic uint32_t log_time;
+
        kru_price_t v4_prices[V4_PREFIXES_CNT];
        kru_price_t v6_prices[V6_PREFIXES_CNT];
        _Alignas(64) uint8_t kru[];
 };
 struct dns_tunnel_filter *dns_tunnel_filter = NULL;
 struct mmapped dns_tunnel_filter_mmapped = {0};
-bool dns_tunnel_filter_initialized = false;
 
-int dns_tunnel_filter_init(const char *mmap_file, size_t capacity, uint32_t instant_limit,
-               uint32_t rate_limit, uint16_t slip, uint32_t log_period, bool dry_run)
+bool load_attempted = false;
+TorchModule net = NULL;
+
+
+int dns_tunnel_filter_setup(const char *nn_file, const char *mmap_file,
+               size_t capacity, uint32_t instant_limit, uint32_t rate_limit)
 {
-       dns_tunnel_filter_initialized = true;
+       int ret;
+       net = load_model(nn_file);
+       if (!net) {
+               ret = kr_error(EINVAL); // we don't know what's wrong
+               goto fail;
+       }
+
        size_t capacity_log = 0;
        for (size_t c = capacity - 1; c > 0; c >>= 1) capacity_log++;
 
@@ -45,24 +50,18 @@ int dns_tunnel_filter_init(const char *mmap_file, size_t capacity, uint32_t inst
                .capacity = capacity,
                .instant_limit = instant_limit,
                .rate_limit = rate_limit,
-               .log_period = log_period,
-               .slip = slip,
-               .dry_run = dry_run,
                .using_avx2 = using_avx2()
        };
 
        size_t header_size = offsetof(struct dns_tunnel_filter, using_avx2) + sizeof(header.using_avx2);
-       static_assert(  // no padding up to .using_avx2
+       static_assert(  // no implicit padding up to .using_avx2
                offsetof(struct dns_tunnel_filter, using_avx2) ==
                        sizeof(header.capacity) +
                        sizeof(header.instant_limit) +
-                       sizeof(header.rate_limit) +
-                       sizeof(header.log_period) +
-                       sizeof(header.slip) +
-                       sizeof(header.dry_run),
+                       sizeof(header.rate_limit),
                "detected padding with undefined data inside mmapped header");
 
-       int ret = mmapped_init(&dns_tunnel_filter_mmapped, mmap_file, size, &header, header_size);
+       ret = mmapped_init(&dns_tunnel_filter_mmapped, mmap_file, size, &header, header_size);
        if (ret == MMAPPED_WAS_FIRST) {
                kr_log_info(TUNNEL, "Initializing DNS tunnel filter...\n");
 
@@ -79,8 +78,6 @@ int dns_tunnel_filter_init(const char *mmap_file, size_t capacity, uint32_t inst
                        goto fail;
                }
 
-               dns_tunnel_filter->log_time = kr_now() - log_period;
-
                for (size_t i = 0; i < V4_PREFIXES_CNT; i++) {
                        dns_tunnel_filter->v4_prices[i] = base_price / V4_RATE_MULT[i];
                }
@@ -89,9 +86,6 @@ int dns_tunnel_filter_init(const char *mmap_file, size_t capacity, uint32_t inst
                        dns_tunnel_filter->v6_prices[i] = base_price / V6_RATE_MULT[i];
                }
 
-               dns_tunnel_filter->net = load_model();
-               if (!dns_tunnel_filter->net) goto fail;
-
                ret = mmapped_init_continue(&dns_tunnel_filter_mmapped);
                if (ret != 0) goto fail;
 
@@ -106,33 +100,44 @@ int dns_tunnel_filter_init(const char *mmap_file, size_t capacity, uint32_t inst
 fail:
 
        kr_log_crit(TUNNEL, "Initialization of shared DNS tunnel filter data failed.\n");
+       load_attempted = true;
        return ret;
 }
 
-void dns_tunnel_filter_deinit(void)
+/// Ensure that the filter is loaded; return false if failed.
+static bool ensure_loaded(void)
 {
-       free_model(dns_tunnel_filter->net);
-       mmapped_deinit(&dns_tunnel_filter_mmapped);
-       dns_tunnel_filter = NULL;
+       if (dns_tunnel_filter)
+               return true;
+       if (load_attempted)
+               return false;
+
+       kr_log_warning(TUNNEL, "Tunneling filter not initialized from Lua, using hardcoded default.\n");
+       int ret = dns_tunnel_filter_setup("/home/vcunat/dev/nic-notes/vysocina/blcnn.pt", // FIXME TMP
+                                               "dns_tunnel_filter",
+                                               (1 << 20), (1 << 8), (1 << 17));
+       return ret == kr_ok();
 }
 
-bool dns_tunnel_filter_request_begin(struct kr_request *req)
+static int produce(kr_layer_t *ctx, knot_pkt_t *pkt)
 {
-       if (!dns_tunnel_filter) return false;
+       struct kr_request *req = ctx->req;
+       if (!ensure_loaded())
+               return ctx->state;
        if (!req->qsource.addr)
-               return false;  // don't consider internal requests
+               return ctx->state;  // don't consider internal requests
        if (req->qsource.price_factor16 == 0)
-               return false;  // whitelisted
+               return ctx->state;  // whitelisted
        if (!req->current_query)
-               return false;
+               return ctx->state;
        if (req->current_query->flags.CACHED) {
-               return false; // don't consider cached results
+               return ctx->state; // don't consider cached results
        }
        if (!req->current_query->sname)
-               return false;
+               return ctx->state;
 
        const uint32_t time_now = kr_now();
-       uint32_t price_scale_factor = (strlen((char *)req->current_query->sname) << 16)/ DNAME_SCALE_FACTOR;
+       uint32_t price_scale_factor = (knot_dname_size(req->current_query->sname) << 16)/ DNAME_SCALE_FACTOR;
 
        // classify
        _Alignas(16) uint8_t key[16] = {0, };
@@ -162,19 +167,41 @@ bool dns_tunnel_filter_request_begin(struct kr_request *req)
                limited_prefix = KRU.limited_multi_prefix_or((struct kru *)dns_tunnel_filter->kru, time_now,
                                0, key, V4_PREFIXES, prices, V4_PREFIXES_CNT, NULL);
        }
-       if (!limited_prefix) return false;  // not limited
+       if (!limited_prefix) return ctx->state;  // not limited
 
        uint8_t *packet = req->qsource.packet->wire;
        size_t packet_size = req->qsource.size;
 
-       float tunnel_prob = predict_packet(dns_tunnel_filter->net, packet, packet_size);
+       float tunnel_prob = predict_packet(net, packet, packet_size);
        
        if (tunnel_prob > 0.95) {
                kr_log_info(TUNNEL, "Malicious packet detected! (%f %%)\n", (tunnel_prob - 0.95) * 100 * 20);
-               req->options.NO_ANSWER = true;
-               req->state = KR_STATE_FAIL;
-               return true;
+               req->options.NO_ANSWER = true; // FIXME: this isn't a good reaction
+               return ctx->state = req->state = KR_STATE_FAIL;
        } else {
-               return false;
+               return ctx->state;
        }
 }
+
+/// Remove mmapped file data if not used by other processes.
+KR_EXPORT
+int dns_tunnel_filter_deinit(struct kr_module *self)
+{
+       free_model(net);
+       mmapped_deinit(&dns_tunnel_filter_mmapped);
+       dns_tunnel_filter = NULL;
+       return kr_ok();
+}
+
+KR_EXPORT
+int dns_tunnel_filter_init(struct kr_module *module) {
+       static kr_layer_api_t layer = {
+               .produce = produce,
+       };
+       layer.data = module;
+       module->layer = &layer;
+
+       return kr_ok();
+}
+
+KR_MODULE_EXPORT(dns_tunnel_filter)
diff --git a/modules/dns_tunnel_filter/dns_tunnel_filter.h b/modules/dns_tunnel_filter/dns_tunnel_filter.h
deleted file mode 100644 (file)
index 7d70262..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-/*  Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
- *  SPDX-License-Identifier: GPL-3.0-or-later
- */
-
- #include <stdbool.h>
- #include "lib/defines.h"
- #include "lib/utils.h"
- #include "lib/kru.h"
- struct kr_request;
- extern bool dns_tunnel_filter_initialized;
- /** Initialize rate-limiting with shared mmapped memory.
-  * The existing data are used if another instance is already using the file
-  * and it was initialized with the same parameters; it fails on mismatch. */
- KR_EXPORT
- int dns_tunnel_filter_init(const char *mmap_file, size_t capacity, uint32_t instant_limit,
-               uint32_t rate_limit, uint16_t slip, uint32_t log_period, bool dry_run);
- /** Do rate-limiting, during knot_layer_api::begin. */
- KR_EXPORT
- bool dns_tunnel_filter_request_begin(struct kr_request *req);
- /** Remove mmapped file data if not used by other processes. */
- KR_EXPORT
- void dns_tunnel_filter_deinit(void);
\ No newline at end of file
index 942ce70966fe34558f3cd49d9eb7349fb4032f75..74f23f188d19b1b6ad4b82f9fb15d102fcfda438 100644 (file)
@@ -13,22 +13,17 @@ struct TorchModuleWrapper {
        torch::jit::script::Module model;
 };
 
-TorchModule load_model(void) {
+TorchModule load_model(const char *nn_file) {
+       auto *wrapper = new TorchModuleWrapper();
        try {
-               namespace fs = std::filesystem;
-               auto *wrapper = new TorchModuleWrapper();
-
-               // FIXME: the path to this file
-               fs::path file_path = fs::relative(__FILE__, "../");
-               fs::path absolute_path = fs::absolute(file_path.parent_path()) / "blcnn.pt";
-               wrapper->model = torch::jit::load(absolute_path);
+               wrapper->model = torch::jit::load(nn_file);
                wrapper->model.to(torch::kCPU);
                wrapper->model.eval();
 
                return static_cast<TorchModule>(wrapper);
        } catch (const c10::Error &e) {
                std::cerr << "Error loading model: " << e.what() << std::endl;
-
+               delete wrapper;
                return nullptr;
        }
 }
index e191eac53850875416beaa527a3308ed98ec5f98..186784a140bed081fdb90299103bbfd7b11fcb3f 100644 (file)
@@ -5,9 +5,11 @@
 extern "C" {
 #endif
 
+#include <stddef.h>
+
 typedef void* TorchModule;
 
-TorchModule load_model(void);
+TorchModule load_model(const char *nn_file);
 float predict_packet(TorchModule model, const unsigned char *data, size_t size);
 void free_model(TorchModule model);
 
index 0e7884d53fa11132c5585f54a29dacda297c4a10..5311d0ff7664c2d6c72411b3261081c8107161ea 100644 (file)
@@ -960,9 +960,6 @@ policy.layer = {
                return policy.evaluate(policy.rules, req, qry, state)
                        or state
        end,
-       produce = function(state, req)
-               ffi.C.dns_tunnel_filter_request_begin(req)
-       end,
        finish = function(state, req)
                -- Optimization for the typical case
                if #policy.postrules == 0 then return state end