]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
view: allow multiple :tsig rules with the same key
authorVladimír Čunát <vladimir.cunat@nic.cz>
Tue, 11 Dec 2018 17:13:32 +0000 (18:13 +0100)
committerPetr Špaček <petr.spacek@nic.cz>
Thu, 13 Dec 2018 16:28:07 +0000 (17:28 +0100)
It's perhaps still confusing that there are three distinct rule chains:
policy, view:tsig and view:addr.

modules/view/view.lua

index 9dfa91902edea9f24bc0d60c370ced3d11701287..7e84ff29d158c0e41e3a542cbb91cccd7b3698bf 100644 (file)
@@ -4,14 +4,18 @@ local C = ffi.C
 
 -- Module declaration
 local view = {
-       key = {},
+       key = {}, -- map from :owner() to list of policy rules
        src = {},
        dst = {},
 }
 
 -- @function View based on TSIG key name.
-function view.tsig(_, tsig, rules)
-       view.key[tsig] = rules
+function view.tsig(_, tsig, rule)
+       if view.key[tsig] == nil then
+               view.key[tsig] = { rule }
+       else
+               table.insert(view.key[tsig], rule)
+       end
 end
 
 -- @function View based on source IP subnet.
@@ -46,16 +50,18 @@ end
 
 -- @function Try all the rules in order, until a non-chain rule gets executed.
 local function evaluate(state, req)
-       -- Try :tsig first.
+       -- Try :tsig rules first.
        local client_key = req.qsource.packet.tsig_rr
-       local match_cb = (client_key ~= nil) and view.key[client_key:owner()] or nil
-       if execute(state, req, match_cb) then return end
+       local match_cbs = (client_key ~= nil) and view.key[client_key:owner()] or {}
+       for _, match_cb in ipairs(match_cbs) do
+               if execute(state, req, match_cb) then return end
+       end
        -- Then try :addr by the source.
        if req.qsource.addr ~= nil then
                for i = 1, #view.src do
                        local pair = view.src[i]
                        if match_subnet(pair[1], pair[2], pair[3], req.qsource.addr) then
-                               match_cb = pair[4]
+                               local match_cb = pair[4]
                                if execute(state, req, match_cb) then return end
                        end
                end
@@ -64,7 +70,7 @@ local function evaluate(state, req)
                for i = 1, #view.dst do
                        local pair = view.dst[i]
                        if match_subnet(pair[1], pair[2], pair[3], req.qsource.dst_addr) then
-                               match_cb = pair[4]
+                               local match_cb = pair[4]
                                if execute(state, req, match_cb) then return end
                        end
                end