]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
modules/policy: add policy.slice() function
authorTomas Krizek <tomas.krizek@nic.cz>
Mon, 15 Jul 2019 15:44:28 +0000 (17:44 +0200)
committerTomas Krizek <tomas.krizek@nic.cz>
Mon, 5 Aug 2019 12:52:45 +0000 (14:52 +0200)
modules/policy/policy.lua

index 2279a73451d8b25854d89dedec26c79022e36abb..c65642429bdd3e202acbacd9ae7d0c54232f4d02 100644 (file)
@@ -444,6 +444,69 @@ function policy.rpz(action, path, watch)
        end
 end
 
+-- Apply an action when query belongs to a slice (determined by slice_func())
+function policy.slice(slice_func, ...)
+       local actions = {...}
+       if #actions <= 0 then
+               error('[poli] at least one action must be provided to policy.slice()')
+       end
+
+       return function(_, query)
+               local index = slice_func(query, #actions)
+               return actions[index]
+       end
+end
+
+-- Initializes slicing function that randomly assigns queries to a slice based on their registrable domain
+function policy.slice_randomize_psl(seed)
+       local has_psl, psl_lib = pcall(require, 'psl')
+       if not has_psl then
+               error('[poli] lua-psl is required for policy.slice_randomize_psl()')
+       end
+       -- load psl
+       local has_latest, psl = pcall(psl_lib.latest)
+       if not has_latest then -- compatiblity with lua-psl < 0.15
+               psl = psl_lib.builtin()
+       end
+
+       if seed == nil then
+               seed = os.time() / (3600 * 24 * 7)
+       end
+       seed = math.floor(seed)  -- convert to int
+
+       return function(query, length)
+               assert(length > 0)
+
+               local domain = kres.dname2str(query:name())
+               if domain == nil then -- invalid data: auto-select first action
+                       return 1
+               end
+               if domain:len() > 1 then  --remove trailing dot
+                       domain = domain:sub(0, -2)
+               end
+
+               -- do psl lookup for registrable domain
+               local reg_domain = psl:registrable_domain(domain)
+               if reg_domain == nil then  -- fallback to unreg. domain
+                       reg_domain = psl:unregistrable_domain(domain)
+                       if reg_domain == nil then  -- shouldn't happen: safe fallback
+                               return 1
+                       end
+               end
+
+               local rand_seed = seed
+               -- create deterministic seed for pseudo-random slice assignment
+               for i = 1, #reg_domain do
+                       rand_seed = rand_seed + reg_domain:byte(i)
+               end
+
+               -- use lineral congruential generator with values from ANSI C
+               rand_seed = rand_seed % 0x80000000  -- ensure seed is positive 32b int
+               local rand = (1103515245 * rand_seed + 12345) % 0x10000
+               return 1 + rand % length
+       end
+end
+
 function policy.DENY_MSG(msg)
        if msg and (type(msg) ~= 'string' or #msg >= 255) then
                error('DENY_MSG: optional msg must be string shorter than 256 characters')