#include <stdatomic.h>
#include "daemon/dnamelimiting.h"
+#include "daemon/libblcnn.h"
#include "lib/mmapped.h"
#include "lib/utils.h"
#include "lib/resolve.h"
#define V6_PREFIXES_CNT (sizeof(V6_PREFIXES) / sizeof(*V6_PREFIXES))
#define MAX_PREFIXES_CNT ((V4_PREFIXES_CNT > V6_PREFIXES_CNT) ? V4_PREFIXES_CNT : V6_PREFIXES_CNT)
+#define DNAME_SCALE_FACTOR 25
+
struct dnamelimiting {
size_t capacity;
uint32_t instant_limit;
uint16_t slip;
bool dry_run;
bool using_avx2;
+ TorchModule net;
_Atomic uint32_t log_time;
kru_price_t v4_prices[V4_PREFIXES_CNT];
kru_price_t v6_prices[V6_PREFIXES_CNT];
dnamelimiting->v6_prices[i] = base_price / V6_RATE_MULT[i];
}
+ dnamelimiting->net = load_model();
+ if (!dnamelimiting->net) goto fail;
+
ret = mmapped_init_continue(&dnamelimiting_mmapped);
if (ret != 0) goto fail;
void dnamelimiting_deinit(void)
{
+ free_model(dnamelimiting->net);
mmapped_deinit(&dnamelimiting_mmapped);
dnamelimiting = NULL;
}
return false; // don't consider internal requests
if (req->qsource.price_factor16 == 0)
return false; // whitelisted
+ if (!req->current_query)
+ return false;
+ if (!req->current_query->sname)
+ return false;
// We only do this on pure UDP. (also TODO if cookies get implemented)
const bool ip_validated = req->qsource.flags.tcp || req->qsource.flags.tls;
if (ip_validated) return false;
const uint32_t time_now = kr_now();
+ uint32_t price_scale_factor = (strlen((char *)req->current_query->sname) << 16)/ DNAME_SCALE_FACTOR;
// classify
_Alignas(16) uint8_t key[16] = {0, };
// compute adjusted prices, using standard rounding
kru_price_t prices[V6_PREFIXES_CNT];
for (int i = 0; i < V6_PREFIXES_CNT; ++i) {
- prices[i] = (req->qsource.price_factor16
- * (uint64_t)dnamelimiting->v6_prices[i] + (1<<15)) >> 16;
+ prices[i] = (req->qsource.price_factor16 * (uint64_t)price_scale_factor
+ * (uint64_t)dnamelimiting->v6_prices[i] + (1<<15)) >> 32;
}
limited_prefix = KRU.limited_multi_prefix_or((struct kru *)dnamelimiting->kru, time_now,
1, key, V6_PREFIXES, prices, V6_PREFIXES_CNT, NULL);
// compute adjusted prices, using standard rounding
kru_price_t prices[V4_PREFIXES_CNT];
for (int i = 0; i < V4_PREFIXES_CNT; ++i) {
- prices[i] = (req->qsource.price_factor16
- * (uint64_t)dnamelimiting->v4_prices[i] + (1<<15)) >> 16;
+ prices[i] = (req->qsource.price_factor16 * (uint64_t)price_scale_factor
+ * (uint64_t)dnamelimiting->v4_prices[i] + (1<<15)) >> 32;
}
limited_prefix = KRU.limited_multi_prefix_or((struct kru *)dnamelimiting->kru, time_now,
0, key, V4_PREFIXES, prices, V4_PREFIXES_CNT, NULL);
}
if (!limited_prefix) return false; // not limited
+ uint8_t *packet = req->qsource.packet->wire;
+ size_t packet_size = req->qsource.size;
+
+ float ret = predict_packet(dnamelimiting->net, packet, packet_size);
+ if (ret > 0.95)
+ printf("Potentially malicious packet (%f %%)\n", (ret - 0.95) * 100 * 20);
return true;
}
{
uint32_t max_queries = expected_passing > 0 ? 2 * expected_passing : -expected_passing;
struct sockaddr_storage addr;
- uint8_t wire[KNOT_WIRE_MIN_PKTSIZE] = { 0 };
- knot_pkt_t answer = { .wire = wire };
- unsigned char sname[dname_length + 1];
- memset(sname, 'a', dname_length);
- sname[dname_length] = '\0';
- struct kr_query query = { .sname = sname};
- uint32_t price_factor16 = (strlen((char *)query.sname) << 16) / 25;
+
+ uint8_t wire_answer[KNOT_WIRE_MIN_PKTSIZE] = { 0 };
+ knot_pkt_t answer = { .wire = wire_answer };
+
+ size_t query_packet_size = get_packet_size(dname_length);
+ uint8_t wire_query[query_packet_size];
+ unsigned char dname[256];
+ create_dns_query(dname_length, wire_query, dname);
+
+ knot_pkt_t query_packet = { .wire = wire_query};
+ struct kr_query query = { .sname = dname};
struct kr_request req = {
.qsource.addr = (struct sockaddr *) &addr,
- .qsource.price_factor16 = price_factor16,
+ .qsource.price_factor16 = (1 << 16),
+ .qsource.packet = &query_packet,
+ .qsource.size = query_packet_size,
+
.answer = &answer,
.current_query = &query
};
(void)snprintf(addr_str, sizeof(addr_str), addr_format,
i % (ip_max_value - ip_min_value + 1) + ip_min_value,
i / (ip_max_value - ip_min_value + 1) % 256);
- //printf("string: %s\n", addr_str);
kr_straddr_socket_set((struct sockaddr *) &addr, addr_str, 0);
if (dnamelimiting_request_begin(&req)) {
cnt = i;
#include <sched.h>
#include <stdio.h>
#include <stdatomic.h>
+#include <string.h>
+#include <time.h>
#include "tests/unit/test.h"
#include "libdnssec/crypto.h"
uint64_t fakeclock_tick = 0;
uint64_t fakeclock_start = 0;
+#define DNS_QUERY_TYPE_A 0x0001
+#define DNS_QUERY_CLASS_IN 0x0001
+
+typedef struct {
+ uint16_t id;
+ uint16_t flags;
+ uint16_t qdcount;
+ uint16_t ancount;
+ uint16_t nscount;
+ uint16_t arcount;
+} __attribute__((packed)) dns_header_t;
+
+size_t get_packet_size(uint32_t dname_length) {
+ uint8_t header_size = sizeof(dns_header_t);
+ uint8_t encoded_dname_size = dname_length + 1;
+ uint8_t query_flags_size = 4;
+ return header_size + encoded_dname_size + query_flags_size;
+}
+
+void create_domain_name(uint8_t *dname, size_t domain_length, unsigned char *sname) {
+ uint8_t encoded_size = 0;
+ uint8_t sname_size = 0;
+ int remaining_length = domain_length - 4;
+
+ while (remaining_length > 0) {
+ uint8_t max_label_length = (remaining_length > 64) ? 64 : remaining_length;
+ uint8_t label_length = 2 + (rand() % (max_label_length - 1));
+ while (remaining_length - label_length == 1)
+ label_length = 2 + (rand() % (max_label_length - 1));
+
+ dname[encoded_size++] = label_length;
+
+ for (uint8_t i = 0; i < label_length - 1; i++) {
+ char rl = 'a' + (rand() % 26);
+ sname[sname_size++] = rl;
+ dname[encoded_size++] = rl;
+ }
+ sname[sname_size++] = '.';
+ remaining_length -= label_length;
+ }
+
+ dname[encoded_size++] = 2;
+ dname[encoded_size++] = 'c';
+ dname[encoded_size++] = 'z';
+ dname[encoded_size++] = 0;
+
+ sname[sname_size++] = 'c';
+ sname[sname_size++] = 'z';
+ sname[sname_size++] = '.';
+ sname[sname_size++] = '\0';
+}
+
+void create_dns_query(uint32_t domain_length, uint8_t *dest, unsigned char* dname) {
+ size_t domain_len = domain_length + 1;
+
+ dns_header_t *header = (dns_header_t *)dest;
+ header->id = htons(0x1234);
+ header->flags = htons(0x0100);
+ header->qdcount = htons(1);
+ header->ancount = 0;
+ header->nscount = 0;
+ header->arcount = 0;
+
+ uint8_t *qname = dest + sizeof(dns_header_t);
+ create_domain_name(qname, domain_len, dname);
+
+ uint16_t *qtype = (uint16_t *)(qname + domain_len);
+ uint16_t *qclass = (uint16_t *)(qtype + 1);
+ *qtype = htons(DNS_QUERY_TYPE_A);
+ *qclass = htons(DNS_QUERY_CLASS_IN);
+}
+
void fakeclock_init(void)
{
fakeclock_start = kr_now();
int main(int argc, char *argv[])
{
+ srand(time(NULL));
+
assert(KRU_GENERIC.initialize != KRU_AVX2.initialize);
if (KRU.initialize == KRU_AVX2.initialize) {
const UnitTest tests[] = {
extra_link_args = []
extra_link_libs = []
+ if unit_test[0] == 'dnamelimiting'
+ extra_link_args += ['-lstdc++', '-lc10']
+ extra_link_libs += [blcnn_lib]
+ endif
+
+
exec_test = executable(
unit_test[0],
unit_test[1],