]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Fix compilation errors and simplify HTML shingles
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 4 Oct 2025 10:24:46 +0000 (11:24 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 4 Oct 2025 10:24:46 +0000 (11:24 +0100)
- Export rspamd_shingles_get_keys_cached() for use in HTML shingles
- Simplify extract_etld1_from_url(): use existing url->tld field
  (in Rspamd, tld already contains eTLD+1/eSLD, no need to parse)
- Add proper reinterpret_cast for const char* to unsigned char*
- Fix variable name conflict (html_content parameter vs local var)
- Use rspamd_url_tld_unsafe() and rspamd_url_host_unsafe() macros

src/libutil/shingles.c
src/libutil/shingles.h
src/libutil/shingles_html.cxx

index e96f32199f1b7dbbeddb517d256bf5b651e27076..56d144549a80b10b8fa7537d0c876e042584610c 100644 (file)
@@ -64,7 +64,7 @@ rspamd_shingles_keys_new(void)
        return k;
 }
 
-static unsigned char **
+unsigned char **
 rspamd_shingles_get_keys_cached(const unsigned char key[SHINGLES_KEY_SIZE])
 {
        static GHashTable *ht = NULL;
index 3d66e97ebbc547200fc9c7b6ec854420669c5701..206e88bbcc3b886675eff707749663622cf98ff6 100644 (file)
@@ -149,6 +149,11 @@ double rspamd_html_shingles_compare(const struct rspamd_html_shingle *a,
 uint64_t rspamd_shingles_default_filter(uint64_t *input, gsize count,
                                                                                int shno, const unsigned char *key, gpointer ud);
 
+/**
+ * Get cached shingle keys (internal helper, exposed for HTML shingles)
+ */
+unsigned char **rspamd_shingles_get_keys_cached(const unsigned char key[16]);
+
 #ifdef __cplusplus
 }
 #endif
index d0e1adf7f33844e26a49eae0f5c423bc0de9f9ef..4b8fbf63d0df34ca5e6f10b73218c265a7311e96 100644 (file)
 #include "libserver/html/html.hxx"
 #include "libserver/url.h"
 #include "cryptobox.h"
-#include "fstring.h"
+#include "contrib/ankerl/unordered_dense.h"
+#include <vector>
+#include <string>
+#include <string_view>
+#include <algorithm>
 
 using rspamd::html::html_content;
 using rspamd::html::html_tag;
 
 /* Forward declarations for C linkage */
 extern "C" {
-static unsigned char **rspamd_shingles_get_keys_cached(const unsigned char key[16]);
 uint64_t rspamd_shingles_default_filter(uint64_t *input, gsize count,
                                                                                int shno, const unsigned char *key, gpointer ud);
+unsigned char **rspamd_shingles_get_keys_cached(const unsigned char key[16]);
 }
 
 #define SHINGLES_WINDOW 3
@@ -46,39 +50,16 @@ bucket_value(unsigned int val, const int *buckets, int nbuckets)
        return (uint8_t) nbuckets;
 }
 
-/* Helper: extract eTLD+1 from URL */
-static const char *
-extract_etld1_from_url(struct rspamd_url *url, rspamd_mempool_t *pool)
+/* Helper: extract eTLD+1 from URL (in Rspamd, tld is already eTLD+1/eSLD) */
+static std::string_view
+extract_etld1_from_url(struct rspamd_url *url)
 {
-       if (!url || !url->host || url->hostlen == 0) {
-               return NULL;
+       if (!url || url->tldlen == 0) {
+               return {};
        }
 
-       rspamd_ftok_t tld;
-       if (rspamd_url_find_tld(url->host, url->hostlen, &tld)) {
-               const char *host_start = url->host;
-               const char *tld_start = tld.begin;
-
-               /* Find start of registrable domain (before TLD) */
-               const char *p = tld_start;
-               while (p > host_start && *(p - 1) != '.') {
-                       p--;
-               }
-
-               gsize etld1_len = (url->host + url->hostlen) - p;
-               char *etld1 = rspamd_mempool_alloc(pool, etld1_len + 1);
-               memcpy(etld1, p, etld1_len);
-               etld1[etld1_len] = '\0';
-
-               /* Lowercase */
-               for (gsize i = 0; i < etld1_len; i++) {
-                       etld1[i] = g_ascii_tolower(etld1[i]);
-               }
-
-               return etld1;
-       }
-
-       return NULL;
+       /* In Rspamd, tld field already contains eTLD+1 (registrable domain) */
+       return {rspamd_url_tld_unsafe(url), url->tldlen};
 }
 
 /* Helper: check if class name contains tracking-like tokens */
@@ -105,7 +86,7 @@ static char *
 normalize_class(const char *cls, gsize len, rspamd_mempool_t *pool)
 {
        if (len == 0 || len > 64) {
-               return NULL;
+               return nullptr;
        }
 
        /* Skip if mostly digits */
@@ -116,10 +97,10 @@ normalize_class(const char *cls, gsize len, rspamd_mempool_t *pool)
                }
        }
        if (digit_count > len / 2) {
-               return NULL;
+               return nullptr;
        }
 
-       char *result = rspamd_mempool_alloc(pool, len + 1);
+       auto *result = static_cast<char *>(rspamd_mempool_alloc(pool, len + 1));
        gsize out_len = 0;
 
        for (gsize i = 0; i < len && out_len < 32; i++) {
@@ -130,19 +111,20 @@ normalize_class(const char *cls, gsize len, rspamd_mempool_t *pool)
        }
 
        result[out_len] = '\0';
-       return out_len > 0 ? result : NULL;
+       return out_len > 0 ? result : nullptr;
 }
 
 /* Helper: extract structural tokens from HTML */
-static GPtrArray *
+static void
 html_extract_structural_tokens(html_content *hc,
                                                           rspamd_mempool_t *pool,
-                                                          GPtrArray **cta_domains_out,
-                                                          GPtrArray **all_domains_out)
+                                                          std::vector<std::string> &tokens,
+                                                          std::vector<std::string_view> &cta_domains,
+                                                          std::vector<std::string_view> &all_domains)
 {
-       GPtrArray *tokens = g_ptr_array_sized_new(hc->all_tags.size());
-       GPtrArray *cta_domains = g_ptr_array_sized_new(16);
-       GPtrArray *all_domains = g_ptr_array_sized_new(32);
+       tokens.reserve(hc->all_tags.size());
+       cta_domains.reserve(16);
+       all_domains.reserve(32);
 
        for (const auto &tag_ptr: hc->all_tags) {
                struct html_tag *tag = tag_ptr.get();
@@ -152,15 +134,16 @@ html_extract_structural_tokens(html_content *hc,
                        continue;
                }
 
-               GString *token = g_string_sized_new(64);
+               std::string token;
+               token.reserve(64);
 
                /* 1. Tag name */
                const char *tag_name = rspamd_html_tag_by_id(tag->id);
                if (tag_name) {
-                       g_string_append(token, tag_name);
+                       token = tag_name;
                }
                else {
-                       g_string_append(token, "unknown");
+                       token = "unknown";
                }
 
                /* 2. Structural class (if not tracking-related) */
@@ -170,139 +153,121 @@ html_extract_structural_tokens(html_content *hc,
                        if (!is_tracking_class(class_sv.data(), class_sv.size())) {
                                char *norm_cls = normalize_class(class_sv.data(), class_sv.size(), pool);
                                if (norm_cls) {
-                                       g_string_append_printf(token, ".%s", norm_cls);
+                                       token += '.';
+                                       token += norm_cls;
                                }
                        }
                }
 
                /* 3. Domain from URL (for links/images) */
-               const char *etld1 = NULL;
-               struct rspamd_url *url = NULL;
-
                if (std::holds_alternative<rspamd_url *>(tag->extra)) {
-                       url = std::get<rspamd_url *>(tag->extra);
-                       if (url && url->host && url->hostlen > 0) {
-                               etld1 = extract_etld1_from_url(url, pool);
+                       auto *url = std::get<rspamd_url *>(tag->extra);
+                       if (url && url->tldlen > 0) {
+                               auto etld1 = extract_etld1_from_url(url);
 
-                               if (etld1) {
-                                       g_string_append_printf(token, "@%s", etld1);
+                               if (!etld1.empty()) {
+                                       token += '@';
+                                       token += etld1;
 
                                        /* Add to all_domains */
-                                       g_ptr_array_add(all_domains, (gpointer) etld1);
+                                       all_domains.push_back(etld1);
 
                                        /* Check if this is a CTA link using button weights */
                                        auto weight_it = hc->url_button_weights.find(url);
-                                       if (weight_it != hc->url_button_weights.end() && weight_it->second > 0.3) {
+                                       if (weight_it != hc->url_button_weights.end() && weight_it->second > 0.3f) {
                                                /* This URL has significant button weight -> likely CTA */
-                                               g_ptr_array_add(cta_domains, (gpointer) etld1);
+                                               cta_domains.push_back(etld1);
                                        }
                                }
                        }
                }
 
                /* Add token */
-               g_ptr_array_add(tokens, g_string_free(token, FALSE));
+               tokens.emplace_back(std::move(token));
        }
-
-       *cta_domains_out = cta_domains;
-       *all_domains_out = all_domains;
-
-       return tokens;
-}
-
-/* Helper: string comparison for qsort */
-static int
-compare_strings(gconstpointer a, gconstpointer b)
-{
-       return strcmp(*(const char **) a, *(const char **) b);
 }
 
 /* Helper: hash a sorted list of domains */
 static uint64_t
-hash_domain_list(GPtrArray *domains, const unsigned char key[16])
+hash_domain_list(std::vector<std::string_view> &domains, const unsigned char key[16])
 {
-       rspamd_cryptobox_hash_state_t st;
-       unsigned char digest[rspamd_cryptobox_HASHBYTES];
-       uint64_t result;
-
-       if (domains->len == 0) {
+       if (domains.empty()) {
                return 0;
        }
 
        /* Sort domains for consistent hashing */
-       g_ptr_array_sort(domains, compare_strings);
+       std::sort(domains.begin(), domains.end());
+
+       rspamd_cryptobox_hash_state_t st;
+       unsigned char digest[rspamd_cryptobox_HASHBYTES];
+       uint64_t result;
 
        rspamd_cryptobox_hash_init(&st, key, 16);
 
        /* Hash each unique domain */
-       const char *prev = NULL;
-       for (unsigned int i = 0; i < domains->len; i++) {
-               const char *dom = (const char *) g_ptr_array_index(domains, i);
+       std::string_view prev;
+       for (const auto &dom: domains) {
                /* Skip duplicates */
-               if (prev && strcmp(dom, prev) == 0) {
+               if (!prev.empty() && dom == prev) {
                        continue;
                }
-               rspamd_cryptobox_hash_update(&st, dom, strlen(dom));
+               rspamd_cryptobox_hash_update(&st, reinterpret_cast<const unsigned char *>(dom.data()), dom.size());
                prev = dom;
        }
 
        rspamd_cryptobox_hash_final(&st, digest);
-       memcpy(&result, digest, sizeof(result));
+       std::memcpy(&result, digest, sizeof(result));
 
        return result;
 }
 
 /* Helper: hash top-N most frequent domains */
 static uint64_t
-hash_top_domains(GPtrArray *domains, unsigned int top_n, const unsigned char key[16])
+hash_top_domains(std::vector<std::string_view> &domains, unsigned int top_n, const unsigned char key[16])
 {
-       if (domains->len == 0) {
+       if (domains.empty()) {
                return 0;
        }
 
-       /* Count domain frequencies */
-       GHashTable *freq_table = g_hash_table_new(g_str_hash, g_str_equal);
+       /* Count domain frequencies using modern C++ map */
+       ankerl::unordered_dense::map<std::string_view, unsigned int> freq_map;
+       freq_map.reserve(domains.size());
 
-       for (unsigned int i = 0; i < domains->len; i++) {
-               const char *dom = (const char *) g_ptr_array_index(domains, i);
-               gpointer count_ptr = g_hash_table_lookup(freq_table, dom);
-               unsigned int count = GPOINTER_TO_UINT(count_ptr);
-               g_hash_table_insert(freq_table, (gpointer) dom, GUINT_TO_POINTER(count + 1));
+       for (const auto &dom: domains) {
+               freq_map[dom]++;
        }
 
-       /* Extract and sort by frequency */
-       GPtrArray *sorted_domains = g_ptr_array_sized_new(g_hash_table_size(freq_table));
-
-       GHashTableIter iter;
-       gpointer key_ptr, value_ptr;
-       g_hash_table_iter_init(&iter, freq_table);
+       /* Extract domains and sort by frequency */
+       std::vector<std::pair<std::string_view, unsigned int>> sorted_domains;
+       sorted_domains.reserve(freq_map.size());
 
-       while (g_hash_table_iter_next(&iter, &key_ptr, &value_ptr)) {
-               g_ptr_array_add(sorted_domains, key_ptr);
+       for (const auto &[dom, count]: freq_map) {
+               sorted_domains.emplace_back(dom, count);
        }
 
-       /* Simple sort by frequency (using hash table lookup) */
-       g_ptr_array_sort_with_data(sorted_domains, [](gconstpointer a, gconstpointer b, gpointer user_data) -> int {
-                                                                  GHashTable *freq = (GHashTable *) user_data;
-                                                                  unsigned int freq_a = GPOINTER_TO_UINT(g_hash_table_lookup(freq, a));
-                                                                  unsigned int freq_b = GPOINTER_TO_UINT(g_hash_table_lookup(freq, b));
-                                                                  if (freq_a != freq_b) {
-                                                                          return (int) freq_b - (int) freq_a;/* Descending */
-                                                                  }
-                                                                  return strcmp((const char *) a, (const char *) b); }, freq_table);
+       /* Sort by frequency (descending), then by name (lexicographic) */
+       std::sort(sorted_domains.begin(), sorted_domains.end(),
+                         [](const auto &a, const auto &b) {
+                                 if (a.second != b.second) {
+                                         return a.second > b.second; /* Descending by frequency */
+                                 }
+                                 return a.first < b.first; /* Ascending by name */
+                         });
 
        /* Take top-N */
-       if (sorted_domains->len > top_n) {
-               g_ptr_array_set_size(sorted_domains, top_n);
+       if (sorted_domains.size() > top_n) {
+               sorted_domains.resize(top_n);
        }
 
-       /* Hash the top domains */
-       uint64_t result = hash_domain_list(sorted_domains, key);
-
-       g_ptr_array_free(sorted_domains, TRUE);
-       g_hash_table_destroy(freq_table);
+       /* Extract just the domain names for hashing */
+       std::vector<std::string_view> top_domain_names;
+       top_domain_names.reserve(sorted_domains.size());
+       for (const auto &[dom, _]: sorted_domains) {
+               top_domain_names.push_back(dom);
+       }
 
-       return result;
+       /* Hash the top domains */
+       return hash_domain_list(top_domain_names, key);
 }
 
 /* Helper: hash HTML features (bucketed) */
@@ -346,36 +311,30 @@ hash_html_features(html_content *hc, const unsigned char key[16])
 
 /* Helper: generate shingles from string tokens (like text shingles but for tokens) */
 static struct rspamd_shingle *
-generate_shingles_from_string_tokens(GPtrArray *tokens,
+generate_shingles_from_string_tokens(const std::vector<std::string> &tokens,
                                                                         const unsigned char key[16],
                                                                         rspamd_shingles_filter filter,
                                                                         gpointer filterd,
                                                                         enum rspamd_shingle_alg alg)
 {
-       struct rspamd_shingle *res;
-       uint64_t **hashes;
-       unsigned char **keys;
-       gsize hlen, ilen, beg = 0;
-       unsigned int i, j;
-       enum rspamd_cryptobox_fast_hash_type ht;
-       uint64_t val;
-
-       ilen = tokens->len;
-       if (ilen == 0) {
-               return NULL;
+       if (tokens.empty()) {
+               return nullptr;
        }
 
-       res = g_new0(struct rspamd_shingle, 1);
-       hlen = ilen > SHINGLES_WINDOW ? (ilen - SHINGLES_WINDOW + 1) : 1;
+       auto res = new rspamd_shingle;
+       gsize ilen = tokens.size();
+       gsize hlen = ilen > SHINGLES_WINDOW ? (ilen - SHINGLES_WINDOW + 1) : 1;
 
-       keys = rspamd_shingles_get_keys_cached(key);
-       hashes = g_new(uint64_t *, RSPAMD_SHINGLE_SIZE);
+       auto keys = rspamd_shingles_get_keys_cached(key);
 
-       for (i = 0; i < RSPAMD_SHINGLE_SIZE; i++) {
-               hashes[i] = g_new(uint64_t, hlen);
+       /* Allocate hash arrays using modern C++ */
+       std::vector<std::vector<uint64_t>> hashes(RSPAMD_SHINGLE_SIZE);
+       for (auto &hash_vec: hashes) {
+               hash_vec.resize(hlen);
        }
 
        /* Select hash algorithm */
+       enum rspamd_cryptobox_fast_hash_type ht;
        switch (alg) {
        case RSPAMD_SHINGLES_XXHASH:
                ht = RSPAMD_CRYPTOBOX_XXHASH64;
@@ -392,19 +351,20 @@ generate_shingles_from_string_tokens(GPtrArray *tokens,
        }
 
        /* Generate hashes using sliding window */
-       for (i = 0; i <= ilen; i++) {
+       gsize beg = 0;
+       for (gsize i = 0; i <= ilen; i++) {
                if (i - beg >= SHINGLES_WINDOW || i == ilen) {
                        /* Hash the window */
-                       for (j = 0; j < RSPAMD_SHINGLE_SIZE; j++) {
+                       for (unsigned int j = 0; j < RSPAMD_SHINGLE_SIZE; j++) {
                                uint64_t seed;
-                               memcpy(&seed, keys[j], sizeof(seed));
+                               std::memcpy(&seed, keys[j], sizeof(seed));
 
                                /* Combine hashes of tokens in window */
-                               val = 0;
-                               for (unsigned int k = beg; k < i && k < ilen; k++) {
-                                       const char *token = (const char *) g_ptr_array_index(tokens, k);
+                               uint64_t val = 0;
+                               for (gsize k = beg; k < i && k < ilen; k++) {
+                                       const auto &token = tokens[k];
                                        uint64_t token_hash = rspamd_cryptobox_fast_hash_specific(ht,
-                                                                                                                                                         token, strlen(token), seed);
+                                                                                                                                                         token.data(), token.size(), seed);
                                        val ^= token_hash >> (8 * (k - beg));
                                }
 
@@ -417,13 +377,10 @@ generate_shingles_from_string_tokens(GPtrArray *tokens,
        }
 
        /* Apply filter to get final shingles */
-       for (i = 0; i < RSPAMD_SHINGLE_SIZE; i++) {
-               res->hashes[i] = filter(hashes[i], hlen, i, key, filterd);
-               g_free(hashes[i]);
+       for (unsigned int i = 0; i < RSPAMD_SHINGLE_SIZE; i++) {
+               res->hashes[i] = filter(hashes[i].data(), hlen, i, key, filterd);
        }
 
-       g_free(hashes);
-
        return res;
 }
 
@@ -436,64 +393,56 @@ rspamd_shingles_from_html(void *html_content,
                                                  gpointer filterd,
                                                  enum rspamd_shingle_alg alg)
 {
-       struct rspamd_html_shingle *res;
-       GPtrArray *tokens = NULL, *cta_domains = NULL, *all_domains = NULL;
-       struct rspamd_shingle *struct_sgl = NULL;
-
        if (!html_content) {
-               return NULL;
+               return nullptr;
        }
 
-       html_content *hc = html_content::from_ptr(html_content);
+       auto *hc_ptr = html_content::from_ptr(html_content);
 
-       if (!hc || hc->all_tags.empty()) {
-               return NULL;
+       if (!hc_ptr || hc_ptr->all_tags.empty()) {
+               return nullptr;
        }
 
-       if (pool) {
-               res = rspamd_mempool_alloc0(pool, sizeof(*res));
-       }
-       else {
-               res = g_new0(struct rspamd_html_shingle, 1);
-       }
+       /* 1. Extract structural tokens and domain lists using modern C++ */
+       std::vector<std::string> tokens;
+       std::vector<std::string_view> cta_domains, all_domains;
+       html_extract_structural_tokens(hc_ptr, pool, tokens, cta_domains, all_domains);
 
-       /* 1. Extract structural tokens and domain lists */
-       tokens = html_extract_structural_tokens(hc, pool, &cta_domains, &all_domains);
-
-       if (tokens->len == 0) {
+       if (tokens.empty()) {
                /* Empty HTML structure */
-               g_ptr_array_free(tokens, TRUE);
-               g_ptr_array_free(cta_domains, TRUE);
-               g_ptr_array_free(all_domains, TRUE);
+               return nullptr;
+       }
 
-               if (pool == NULL) {
-                       g_free(res);
-               }
-               return NULL;
+       /* Allocate result */
+       struct rspamd_html_shingle *res;
+       if (pool) {
+               res = static_cast<rspamd_html_shingle *>(rspamd_mempool_alloc0(pool, sizeof(*res)));
+       }
+       else {
+               res = new rspamd_html_shingle{};
        }
 
        /* 2. Generate direct hash of ALL tokens (for exact matching, like text parts) */
        rspamd_cryptobox_hash_state_t st;
        rspamd_cryptobox_hash_init(&st, key, 16);
 
-       for (unsigned int i = 0; i < tokens->len; i++) {
-               const char *token = (const char *) g_ptr_array_index(tokens, i);
-               rspamd_cryptobox_hash_update(&st, token, strlen(token));
+       for (const auto &token: tokens) {
+               rspamd_cryptobox_hash_update(&st, reinterpret_cast<const unsigned char *>(token.data()), token.size());
        }
 
        rspamd_cryptobox_hash_final(&st, res->direct_hash);
 
        /* 3. Generate structure shingles from tokens (for fuzzy matching) */
-       struct_sgl = generate_shingles_from_string_tokens(tokens, key, filter, filterd, alg);
+       auto *struct_sgl = generate_shingles_from_string_tokens(tokens, key, filter, filterd, alg);
 
        if (struct_sgl) {
-               memcpy(&res->structure_shingles, struct_sgl, sizeof(struct rspamd_shingle));
-               if (pool == NULL) {
-                       g_free(struct_sgl);
+               std::memcpy(&res->structure_shingles, struct_sgl, sizeof(struct rspamd_shingle));
+               if (pool == nullptr) {
+                       delete struct_sgl;
                }
        }
        else {
-               memset(&res->structure_shingles, 0, sizeof(struct rspamd_shingle));
+               std::memset(&res->structure_shingles, 0, sizeof(struct rspamd_shingle));
        }
 
        /* 4. Generate CTA domains hash (critical for phishing detection) */
@@ -503,20 +452,12 @@ rspamd_shingles_from_html(void *html_content,
        res->all_domains_hash = hash_top_domains(all_domains, 10, key);
 
        /* 6. Generate features hash (bucketed statistics) */
-       res->features_hash = hash_html_features(hc, key);
+       res->features_hash = hash_html_features(hc_ptr, key);
 
        /* 7. Store metadata */
-       res->tags_count = hc->features.tags_count > UINT16_MAX ? UINT16_MAX : (uint16_t) hc->features.tags_count;
-       res->links_count = hc->features.links.total_links > UINT16_MAX ? UINT16_MAX : (uint16_t) hc->features.links.total_links;
-       res->dom_depth = hc->features.max_dom_depth > UINT8_MAX ? UINT8_MAX : (uint8_t) hc->features.max_dom_depth;
-
-       /* Cleanup */
-       for (unsigned int i = 0; i < tokens->len; i++) {
-               g_free(g_ptr_array_index(tokens, i));
-       }
-       g_ptr_array_free(tokens, TRUE);
-       g_ptr_array_free(cta_domains, TRUE);
-       g_ptr_array_free(all_domains, TRUE);
+       res->tags_count = std::min<unsigned int>(hc_ptr->features.tags_count, UINT16_MAX);
+       res->links_count = std::min<unsigned int>(hc_ptr->features.links.total_links, UINT16_MAX);
+       res->dom_depth = std::min<unsigned int>(hc_ptr->features.max_dom_depth, UINT8_MAX);
 
        return res;
 }