]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Add SPF flattening tool with macro preservation
authorVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 20 Nov 2025 11:43:40 +0000 (11:43 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Thu, 20 Nov 2025 11:52:39 +0000 (11:52 +0000)
- Add new 'spf-flatten' command to dns_tool for optimizing SPF records
- Introduce RSPAMD_SPF_FLAG_MACRO_UNRESOLVED flag to preserve SPF macros
- Prevent macro expansion when sender IP is unavailable (flatten mode)
- SPF elements with macros (exists:, a:, mx:, ptr:) now preserved correctly
- Add multiple output formats: default, json, compact (BIND-style)
- Optimize IP addresses by removing default /32 and /128 masks
- Automatically split large SPF records into multiple includes
- Preserve qualifiers and 'all' mechanism in flattened records

lualib/rspamadm/dns_tool.lua
src/libserver/spf.c
src/libserver/spf.h

index b9324fc2464cacbf81bc0e7e4eda9bc9fa18af96..cabc6aa75a9719390237e936f72a1a0f85b9b7e4 100644 (file)
@@ -49,6 +49,16 @@ spf:option "-i --ip"
 spf:flag "-a --all"
    :description "Print all records"
 
+local spf_flatten = parser:command "spf-flatten"
+                          :description "Flattens SPF records by resolving all includes and optimizing"
+spf_flatten:argument "domain"
+           :description "Domain to flatten SPF for"
+           :argname("<domain>")
+spf_flatten:option "-f --format"
+           :description "Output format: default, json, compact"
+           :argname("<format>")
+           :default("default")
+
 local function printf(fmt, ...)
   if fmt then
     io.write(string.format(fmt, ...))
@@ -211,6 +221,314 @@ local function spf_handler(opts)
   rspamd_spf.resolve(task, cb)
 end
 
+local function spf_flatten_handler(opts)
+  local rspamd_spf = require "rspamd_spf"
+  local rspamd_task = require "rspamd_task"
+
+  if not opts.domain then
+    io.stderr:write('Domain is required\n')
+    os.exit(1)
+  end
+
+  local task = rspamd_task.create(rspamd_config, rspamadm_ev_base)
+  task:set_session(rspamadm_session)
+  task:set_resolver(rspamadm_dns_resolver)
+  task:set_from('smtp', { user = 'user', domain = opts.domain })
+
+  local function has_macro(str)
+    return str and str:find('%%')
+  end
+
+  local function result_to_qualifier(result)
+    if result == rspamd_spf.policy.pass then
+      return '+'
+    elseif result == rspamd_spf.policy.fail then
+      return '-'
+    elseif result == rspamd_spf.policy.soft_fail then
+      return '~'
+    elseif result == rspamd_spf.policy.neutral then
+      return '?'
+    end
+    return '+'
+  end
+
+  local function is_all_mechanism(str)
+    return str and str:match('^[+-~?]?all$')
+  end
+
+  local function is_valid_ip_net(addr)
+    if not addr or addr == '' or addr == 'any' then
+      return false
+    end
+    return addr:match('^[0-9]') or addr:match('^[a-fA-F0-9]*:')
+  end
+
+  local function has_macro_unresolved_flag(elt)
+    if not elt.flags then
+      return false
+    end
+    local RSPAMD_SPF_FLAG_MACRO_UNRESOLVED = bit.lshift(1, 14)
+    return bit.band(elt.flags, RSPAMD_SPF_FLAG_MACRO_UNRESOLVED) ~= 0
+  end
+
+  local function collect_mechanisms(elts)
+    local ipv4_nets = {}
+    local ipv6_nets = {}
+    local dynamic_mechanisms = {}
+    local other_mechanisms = {}
+    local seen_all = false
+
+    for _, elt in ipairs(elts) do
+      local processed = false
+
+      if elt.str and is_all_mechanism(elt.str) then
+        local qualifier = result_to_qualifier(elt.result)
+        local all_mech = (qualifier == '+' and 'all') or (qualifier .. 'all')
+        table.insert(other_mechanisms, all_mech)
+        seen_all = true
+        processed = true
+      elseif elt.str and (has_macro(elt.str) or elt.str:match('^[+-~?]?redirect')) then
+        table.insert(other_mechanisms, elt.str)
+        processed = true
+      elseif has_macro_unresolved_flag(elt) then
+        table.insert(other_mechanisms, elt.str)
+        processed = true
+      elseif elt.addr and not seen_all and is_valid_ip_net(elt.addr) then
+        local qualifier = result_to_qualifier(elt.result)
+        local net = elt.addr
+
+        if net:find(':') then
+          table.insert(ipv6_nets, { net = net, qual = qualifier })
+        else
+          table.insert(ipv4_nets, { net = net, qual = qualifier })
+        end
+        processed = true
+      end
+
+      if not processed and elt.str and not elt.str:match('^[+-~?]?include:') then
+        table.insert(other_mechanisms, elt.str)
+      end
+    end
+
+    return ipv4_nets, ipv6_nets, dynamic_mechanisms, other_mechanisms
+  end
+
+  local function optimize_ip_net(net, is_ipv6)
+    local default_mask = is_ipv6 and '/128' or '/32'
+    if net:sub(-#default_mask) == default_mask then
+      return net:sub(1, -#default_mask - 1)
+    end
+    return net
+  end
+
+  local function build_spf_record(ipv4_nets, ipv6_nets, dynamic_mechanisms, other_mechanisms, includes)
+    local parts = { 'v=spf1' }
+
+    if includes then
+      for _, inc in ipairs(includes) do
+        table.insert(parts, 'include:' .. inc)
+      end
+    end
+
+    for _, mech in ipairs(dynamic_mechanisms) do
+      table.insert(parts, mech)
+    end
+
+    for _, item in ipairs(ipv4_nets) do
+      local prefix = item.qual == '+' and '' or item.qual
+      local optimized_net = optimize_ip_net(item.net, false)
+      table.insert(parts, prefix .. 'ip4:' .. optimized_net)
+    end
+
+    for _, item in ipairs(ipv6_nets) do
+      local prefix = item.qual == '+' and '' or item.qual
+      local optimized_net = optimize_ip_net(item.net, true)
+      table.insert(parts, prefix .. 'ip6:' .. optimized_net)
+    end
+
+    for _, mech in ipairs(other_mechanisms) do
+      table.insert(parts, mech)
+    end
+
+    return table.concat(parts, ' ')
+  end
+
+  local function split_networks_into_chunks(ipv4_nets, ipv6_nets, base_domain, all_mechanism)
+    local max_record_length = 450
+    local chunks = {}
+    local current_chunk_v4 = {}
+    local current_chunk_v6 = {}
+    local all_v4 = {}
+    local all_v6 = {}
+
+    local all_mechs = all_mechanism and {all_mechanism} or {}
+
+    for _, item in ipairs(ipv4_nets) do
+      table.insert(all_v4, item)
+    end
+    for _, item in ipairs(ipv6_nets) do
+      table.insert(all_v6, item)
+    end
+
+    local chunk_idx = 1
+    local function finalize_chunk()
+      if #current_chunk_v4 > 0 or #current_chunk_v6 > 0 then
+        local record = build_spf_record(current_chunk_v4, current_chunk_v6, {}, all_mechs, nil)
+        table.insert(chunks, {
+          name = string.format('%d._spf.%s', chunk_idx, base_domain),
+          record = record
+        })
+        chunk_idx = chunk_idx + 1
+        current_chunk_v4 = {}
+        current_chunk_v6 = {}
+      end
+    end
+
+    for _, item in ipairs(all_v4) do
+      table.insert(current_chunk_v4, item)
+      local test_record = build_spf_record(current_chunk_v4, current_chunk_v6, {}, all_mechs, nil)
+      if #test_record > max_record_length then
+        table.remove(current_chunk_v4)
+        finalize_chunk()
+        table.insert(current_chunk_v4, item)
+      end
+    end
+
+    for _, item in ipairs(all_v6) do
+      table.insert(current_chunk_v6, item)
+      local test_record = build_spf_record(current_chunk_v4, current_chunk_v6, {}, all_mechs, nil)
+      if #test_record > max_record_length then
+        table.remove(current_chunk_v6)
+        finalize_chunk()
+        table.insert(current_chunk_v6, item)
+      end
+    end
+
+    finalize_chunk()
+    return chunks
+  end
+
+  local function cb(record, flags, err)
+    if not record then
+      printf('Cannot get SPF record: %s', err)
+      os.exit(1)
+    end
+
+    local elts = record:get_elts()
+    local ipv4_nets, ipv6_nets, dynamic_mechanisms, other_mechanisms = collect_mechanisms(elts)
+
+    local all_mechanism = nil
+    local other_without_all = {}
+    for _, mech in ipairs(other_mechanisms) do
+      if is_all_mechanism(mech) then
+        all_mechanism = mech
+      else
+        table.insert(other_without_all, mech)
+      end
+    end
+
+    local test_record = build_spf_record(ipv4_nets, ipv6_nets, dynamic_mechanisms, other_mechanisms, nil)
+    local needs_split = #test_record > 450
+
+    if opts.format == 'json' then
+      local ucl = require "ucl"
+      local result = {
+        domain = opts.domain,
+        ipv4_count = #ipv4_nets,
+        ipv6_count = #ipv6_nets,
+        dynamic_mechanisms = dynamic_mechanisms,
+        other_mechanisms = other_mechanisms,
+        needs_split = needs_split
+      }
+
+      if needs_split then
+        local chunks = split_networks_into_chunks(ipv4_nets, ipv6_nets, opts.domain, all_mechanism)
+        local include_names = {}
+        for _, chunk in ipairs(chunks) do
+          table.insert(include_names, chunk.name)
+        end
+        local main_record = build_spf_record({}, {}, dynamic_mechanisms, other_without_all, include_names)
+        if all_mechanism then
+          main_record = main_record .. ' ' .. all_mechanism
+        end
+
+        result.main_record = main_record
+        result.additional_records = {}
+        for _, chunk in ipairs(chunks) do
+          table.insert(result.additional_records, {
+            name = chunk.name,
+            value = chunk.record
+          })
+        end
+      else
+        result.record = test_record
+      end
+
+      printf('%s', ucl.to_format(result, 'json'))
+    elseif opts.format == 'compact' then
+      if needs_split then
+        local chunks = split_networks_into_chunks(ipv4_nets, ipv6_nets, opts.domain, all_mechanism)
+        local include_names = {}
+        for _, chunk in ipairs(chunks) do
+          table.insert(include_names, chunk.name)
+        end
+        local main_record = build_spf_record({}, {}, dynamic_mechanisms, other_without_all, include_names)
+        if all_mechanism then
+          main_record = main_record .. ' ' .. all_mechanism
+        end
+
+        printf('%s. IN TXT "%s"', opts.domain, main_record)
+        for _, chunk in ipairs(chunks) do
+          printf('%s. IN TXT "%s"', chunk.name, chunk.record)
+        end
+      else
+        printf('%s. IN TXT "%s"', opts.domain, test_record)
+      end
+    else
+      printf('Flattened SPF record for %s:', highlight(opts.domain))
+      printf('')
+      printf('Found %s IPv4 networks, %s IPv6 networks, %s dynamic mechanisms, %s other mechanisms',
+             highlight(tostring(#ipv4_nets)),
+             highlight(tostring(#ipv6_nets)),
+             highlight(tostring(#dynamic_mechanisms)),
+             highlight(tostring(#other_mechanisms)))
+      printf('')
+
+      if needs_split then
+        printf('%s: Needs splitting (full length: %d)', red('Result'), #test_record)
+        printf('')
+
+        local chunks = split_networks_into_chunks(ipv4_nets, ipv6_nets, opts.domain, all_mechanism)
+        local include_names = {}
+        for _, chunk in ipairs(chunks) do
+          table.insert(include_names, chunk.name)
+        end
+
+        local main_record = build_spf_record({}, {}, dynamic_mechanisms, other_without_all, include_names)
+        if all_mechanism then
+          main_record = main_record .. ' ' .. all_mechanism
+        end
+
+        printf('%s:', highlight('Main record'))
+        printf('%s', main_record)
+        printf('')
+
+        for _, chunk in ipairs(chunks) do
+          printf('%s:', highlight('TXT record for ' .. chunk.name))
+          printf('%s', chunk.record)
+          printf('')
+        end
+      else
+        printf('%s: Single record (length: %d)', green('Result'), #test_record)
+        printf('')
+        printf('%s', test_record)
+      end
+    end
+  end
+
+  rspamd_spf.resolve(task, cb)
+end
+
 local function handler(args)
   local opts = parser:parse(args)
   load_config(opts)
@@ -219,6 +537,8 @@ local function handler(args)
 
   if command == 'spf' then
     spf_handler(opts)
+  elseif command == 'spf-flatten' then
+    spf_flatten_handler(opts)
   else
     parser:error('command %s is not implemented', command)
   end
index c91cc5245559f3f0fbd247add33eab00ae5581c5..5707b006ae93499a60c928860f66e799589b07cd 100644 (file)
@@ -1329,6 +1329,14 @@ parse_spf_a(struct spf_record *rec,
 
        CHECK_REC(rec);
 
+       /* Check if element has unresolved macros */
+       if (addr->flags & RSPAMD_SPF_FLAG_MACRO_UNRESOLVED) {
+               msg_debug_spf("a element has unresolved macros: %s", addr->spf_string);
+               addr->flags |= RSPAMD_SPF_FLAG_RESOLVED;
+               spf_record_addr_set(addr, FALSE);
+               return TRUE;
+       }
+
        host = parse_spf_domain_mask(rec, addr, resolved, TRUE);
 
        if (host == NULL) {
@@ -1386,8 +1394,24 @@ parse_spf_ptr(struct spf_record *rec,
 
        CHECK_REC(rec);
 
+       /* Check if element has unresolved macros */
+       if (addr->flags & RSPAMD_SPF_FLAG_MACRO_UNRESOLVED) {
+               msg_debug_spf("ptr element has unresolved macros: %s", addr->spf_string);
+               addr->flags |= RSPAMD_SPF_FLAG_RESOLVED;
+               spf_record_addr_set(addr, FALSE);
+               return TRUE;
+       }
+
        host = parse_spf_domain_mask(rec, addr, resolved, FALSE);
 
+       if (!task->from_addr) {
+               /* PTR requires from_addr to generate reverse DNS query */
+               msg_debug_spf("ptr element requires sender IP: %s", addr->spf_string);
+               addr->flags |= RSPAMD_SPF_FLAG_RESOLVED;
+               spf_record_addr_set(addr, FALSE);
+               return TRUE;
+       }
+
        rec->dns_requests++;
        cb = rspamd_mempool_alloc0(task->task_pool, sizeof(struct spf_dns_cb));
        cb->rec = rec;
@@ -1395,6 +1419,7 @@ parse_spf_ptr(struct spf_record *rec,
        cb->initiated_by = SPF_RESOLVE_PTR;
        cb->resolved = resolved;
        cb->initiated_dns_name = rspamd_mempool_strdup(task->task_pool, host);
+
        ptr =
                rdns_generate_ptr_from_str(rspamd_inet_address_to_string(
                        task->from_addr));
@@ -1432,6 +1457,14 @@ parse_spf_mx(struct spf_record *rec,
 
        CHECK_REC(rec);
 
+       /* Check if element has unresolved macros */
+       if (addr->flags & RSPAMD_SPF_FLAG_MACRO_UNRESOLVED) {
+               msg_debug_spf("mx element has unresolved macros: %s", addr->spf_string);
+               addr->flags |= RSPAMD_SPF_FLAG_RESOLVED;
+               spf_record_addr_set(addr, FALSE);
+               return TRUE;
+       }
+
        host = parse_spf_domain_mask(rec, addr, resolved, TRUE);
 
        if (host == NULL) {
@@ -1753,6 +1786,14 @@ parse_spf_exists(struct spf_record *rec, struct spf_addr *addr)
        resolved = g_ptr_array_index(rec->resolved, rec->resolved->len - 1);
        CHECK_REC(rec);
 
+       /* Check if element has unresolved macros */
+       if (addr->flags & RSPAMD_SPF_FLAG_MACRO_UNRESOLVED) {
+               msg_debug_spf("exists element has unresolved macros: %s", addr->spf_string);
+               addr->flags |= RSPAMD_SPF_FLAG_RESOLVED;
+               spf_record_addr_set(addr, FALSE);
+               return TRUE;
+       }
+
        host = strchr(addr->spf_string, ':');
        if (host == NULL) {
                host = strchr(addr->spf_string, '=');
@@ -1765,6 +1806,7 @@ parse_spf_exists(struct spf_record *rec, struct spf_addr *addr)
        }
 
        host++;
+
        rec->dns_requests++;
 
        cb = rspamd_mempool_alloc0(task->task_pool, sizeof(struct spf_dns_cb));
@@ -1897,7 +1939,7 @@ rspamd_spf_process_substitution(const char *macro_value,
 
 static const char *
 expand_spf_macro(struct spf_record *rec, struct spf_resolved_element *resolved,
-                                const char *begin)
+                                const char *begin, gboolean *macro_unresolved)
 {
        const char *p, *macro_value = NULL;
        char *c, *new, *tmp, delim = '.';
@@ -1910,6 +1952,10 @@ expand_spf_macro(struct spf_record *rec, struct spf_resolved_element *resolved,
        g_assert(rec != NULL);
        g_assert(begin != NULL);
 
+       if (macro_unresolved) {
+               *macro_unresolved = FALSE;
+       }
+
        task = rec->task;
        p = begin;
        /* Calculate length */
@@ -2031,14 +2077,25 @@ expand_spf_macro(struct spf_record *rec, struct spf_resolved_element *resolved,
                return begin;
        }
 
-       new = rspamd_mempool_alloc(task->task_pool, len + 1);
-
        /* Reduce TTL to avoid caching of records with macros */
        if (rec->ttl != 0) {
                rec->ttl = 0;
                msg_debug_spf("disable SPF caching as there is macro expansion");
        }
 
+       /* Check if we have necessary data for macro expansion */
+       if (!task->from_addr || !rec->sender) {
+               /* Cannot expand macros without sender IP and sender, return original */
+               msg_debug_spf("SPF macro expansion skipped: missing required data (from_addr=%p, sender=%s) for %s",
+                                         task->from_addr, rec->sender ? rec->sender : "null", begin);
+               if (macro_unresolved) {
+                       *macro_unresolved = TRUE;
+               }
+               return begin;
+       }
+
+       new = rspamd_mempool_alloc(task->task_pool, len + 1);
+
        c = new;
        p = begin;
        state = 0;
@@ -2289,9 +2346,13 @@ spf_process_element(struct spf_record *rec,
                return TRUE;
        }
 
-       begin = expand_spf_macro(rec, resolved, elt);
+       gboolean macro_unresolved = FALSE;
+       begin = expand_spf_macro(rec, resolved, elt, &macro_unresolved);
        addr = rspamd_spf_new_addr(rec, resolved, begin);
        g_assert(addr != NULL);
+       if (macro_unresolved) {
+               addr->flags |= RSPAMD_SPF_FLAG_MACRO_UNRESOLVED;
+       }
        t = g_ascii_tolower(addr->spf_string[0]);
        begin = addr->spf_string;
 
index 9c133e266cb26f89d3ebabd02e9fe84b8b8ede16..327f7ebd27c68810c6c1f07ee6e4a346966956ed 100644 (file)
@@ -78,6 +78,7 @@ typedef enum spf_action_e {
 #define RSPAMD_SPF_FLAG_RESOLVED (1u << 11u)
 #define RSPAMD_SPF_FLAG_CACHED (1u << 12u)
 #define RSPAMD_SPF_FLAG_PLUSALL (1u << 13u)
+#define RSPAMD_SPF_FLAG_MACRO_UNRESOLVED (1u << 14u)
 
 /** Default SPF limits for avoiding abuse **/
 #define SPF_MAX_NESTING 10