*/
#include <stdatomic.h>
-#include "dns_tunnel_filter.h"
#include "libblcnn.h"
+#include "lib/kru.h"
#include "lib/kru-utils.h"
#include "lib/mmapped.h"
#include "lib/utils.h"
size_t capacity;
uint32_t instant_limit;
uint32_t rate_limit;
- uint32_t log_period;
- uint16_t slip;
- bool dry_run;
bool using_avx2;
- TorchModule net;
- _Atomic uint32_t log_time;
+
kru_price_t v4_prices[V4_PREFIXES_CNT];
kru_price_t v6_prices[V6_PREFIXES_CNT];
_Alignas(64) uint8_t kru[];
};
struct dns_tunnel_filter *dns_tunnel_filter = NULL;
struct mmapped dns_tunnel_filter_mmapped = {0};
-bool dns_tunnel_filter_initialized = false;
-int dns_tunnel_filter_init(const char *mmap_file, size_t capacity, uint32_t instant_limit,
- uint32_t rate_limit, uint16_t slip, uint32_t log_period, bool dry_run)
+bool load_attempted = false;
+TorchModule net = NULL;
+
+
+int dns_tunnel_filter_setup(const char *nn_file, const char *mmap_file,
+ size_t capacity, uint32_t instant_limit, uint32_t rate_limit)
{
- dns_tunnel_filter_initialized = true;
+ int ret;
+ net = load_model(nn_file);
+ if (!net) {
+ ret = kr_error(EINVAL); // we don't know what's wrong
+ goto fail;
+ }
+
size_t capacity_log = 0;
for (size_t c = capacity - 1; c > 0; c >>= 1) capacity_log++;
.capacity = capacity,
.instant_limit = instant_limit,
.rate_limit = rate_limit,
- .log_period = log_period,
- .slip = slip,
- .dry_run = dry_run,
.using_avx2 = using_avx2()
};
size_t header_size = offsetof(struct dns_tunnel_filter, using_avx2) + sizeof(header.using_avx2);
- static_assert( // no padding up to .using_avx2
+ static_assert( // no implicit padding up to .using_avx2
offsetof(struct dns_tunnel_filter, using_avx2) ==
sizeof(header.capacity) +
sizeof(header.instant_limit) +
- sizeof(header.rate_limit) +
- sizeof(header.log_period) +
- sizeof(header.slip) +
- sizeof(header.dry_run),
+ sizeof(header.rate_limit),
"detected padding with undefined data inside mmapped header");
- int ret = mmapped_init(&dns_tunnel_filter_mmapped, mmap_file, size, &header, header_size);
+ ret = mmapped_init(&dns_tunnel_filter_mmapped, mmap_file, size, &header, header_size);
if (ret == MMAPPED_WAS_FIRST) {
kr_log_info(TUNNEL, "Initializing DNS tunnel filter...\n");
goto fail;
}
- dns_tunnel_filter->log_time = kr_now() - log_period;
-
for (size_t i = 0; i < V4_PREFIXES_CNT; i++) {
dns_tunnel_filter->v4_prices[i] = base_price / V4_RATE_MULT[i];
}
dns_tunnel_filter->v6_prices[i] = base_price / V6_RATE_MULT[i];
}
- dns_tunnel_filter->net = load_model();
- if (!dns_tunnel_filter->net) goto fail;
-
ret = mmapped_init_continue(&dns_tunnel_filter_mmapped);
if (ret != 0) goto fail;
fail:
kr_log_crit(TUNNEL, "Initialization of shared DNS tunnel filter data failed.\n");
+ load_attempted = true;
return ret;
}
-void dns_tunnel_filter_deinit(void)
+/// Ensure that the filter is loaded; return false if failed.
+static bool ensure_loaded(void)
{
- free_model(dns_tunnel_filter->net);
- mmapped_deinit(&dns_tunnel_filter_mmapped);
- dns_tunnel_filter = NULL;
+ if (dns_tunnel_filter)
+ return true;
+ if (load_attempted)
+ return false;
+
+ kr_log_warning(TUNNEL, "Tunneling filter not initialized from Lua, using hardcoded default.\n");
+ int ret = dns_tunnel_filter_setup("/home/vcunat/dev/nic-notes/vysocina/blcnn.pt", // FIXME TMP
+ "dns_tunnel_filter",
+ (1 << 20), (1 << 8), (1 << 17));
+ return ret == kr_ok();
}
-bool dns_tunnel_filter_request_begin(struct kr_request *req)
+static int produce(kr_layer_t *ctx, knot_pkt_t *pkt)
{
- if (!dns_tunnel_filter) return false;
+ struct kr_request *req = ctx->req;
+ if (!ensure_loaded())
+ return ctx->state;
if (!req->qsource.addr)
- return false; // don't consider internal requests
+ return ctx->state; // don't consider internal requests
if (req->qsource.price_factor16 == 0)
- return false; // whitelisted
+ return ctx->state; // whitelisted
if (!req->current_query)
- return false;
+ return ctx->state;
if (req->current_query->flags.CACHED) {
- return false; // don't consider cached results
+ return ctx->state; // don't consider cached results
}
if (!req->current_query->sname)
- return false;
+ return ctx->state;
const uint32_t time_now = kr_now();
- uint32_t price_scale_factor = (strlen((char *)req->current_query->sname) << 16)/ DNAME_SCALE_FACTOR;
+ uint32_t price_scale_factor = (knot_dname_size(req->current_query->sname) << 16)/ DNAME_SCALE_FACTOR;
// classify
_Alignas(16) uint8_t key[16] = {0, };
limited_prefix = KRU.limited_multi_prefix_or((struct kru *)dns_tunnel_filter->kru, time_now,
0, key, V4_PREFIXES, prices, V4_PREFIXES_CNT, NULL);
}
- if (!limited_prefix) return false; // not limited
+ if (!limited_prefix) return ctx->state; // not limited
uint8_t *packet = req->qsource.packet->wire;
size_t packet_size = req->qsource.size;
- float tunnel_prob = predict_packet(dns_tunnel_filter->net, packet, packet_size);
+ float tunnel_prob = predict_packet(net, packet, packet_size);
if (tunnel_prob > 0.95) {
kr_log_info(TUNNEL, "Malicious packet detected! (%f %%)\n", (tunnel_prob - 0.95) * 100 * 20);
- req->options.NO_ANSWER = true;
- req->state = KR_STATE_FAIL;
- return true;
+ req->options.NO_ANSWER = true; // FIXME: this isn't a good reaction
+ return ctx->state = req->state = KR_STATE_FAIL;
} else {
- return false;
+ return ctx->state;
}
}
+
+/// Remove mmapped file data if not used by other processes.
+KR_EXPORT
+int dns_tunnel_filter_deinit(struct kr_module *self)
+{
+ free_model(net);
+ mmapped_deinit(&dns_tunnel_filter_mmapped);
+ dns_tunnel_filter = NULL;
+ return kr_ok();
+}
+
+KR_EXPORT
+int dns_tunnel_filter_init(struct kr_module *module) {
+ static kr_layer_api_t layer = {
+ .produce = produce,
+ };
+ layer.data = module;
+ module->layer = &layer;
+
+ return kr_ok();
+}
+
+KR_MODULE_EXPORT(dns_tunnel_filter)