]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
policy: refactor policy and view modules
authorPetr Špaček <petr.spacek@nic.cz>
Thu, 25 Jan 2018 09:14:28 +0000 (10:14 +0100)
committerPetr Špaček <petr.spacek@nic.cz>
Thu, 25 Jan 2018 16:35:39 +0000 (17:35 +0100)
I've removed couple layers of indirection to make it easier to follow.
This should make it easier to extend the policy module.

modules/policy/policy.lua
modules/view/view.lua

index 95010d4f07ed35576279ae607537ece685c5ff7a..cd32181d236828044d800ce9257ce4c241463b2b 100644 (file)
@@ -57,8 +57,15 @@ local function addr2sock(target, default_port)
        return sock
 end
 
+-- policy functions are defined below
+local policy = {}
+
+function policy.PASS(state, _)
+       return state
+end
+
 -- Mirror request elsewhere, and continue solving
-local function mirror(target)
+function policy.MIRROR(target)
        local addr, port = addr_split_port(target, 53)
        local sink, err = socket_client(addr, port)
        if not sink then panic('MIRROR target %s is not a valid: %s', target, err) end
@@ -84,7 +91,7 @@ local function set_nslist(qry, list)
 end
 
 -- Forward request, and solve as stub query
-local function stub(target)
+function policy.STUB(target)
        local list = {}
        if type(target) == 'table' then
                for _, v in pairs(target) do
@@ -105,7 +112,7 @@ local function stub(target)
 end
 
 -- Forward request and all subrequests to upstream; validate answers
-local function forward(target)
+function policy.FORWARD(target)
        local list = {}
        if type(target) == 'table' then
                for _, v in pairs(target) do
@@ -188,7 +195,7 @@ local function tls_forward_target_check_syntax(idx, list_entry)
 end
 
 -- Forward request and all subrequests to upstream over TLS; validate answers
-local function tls_forward(target)
+function policy.TLS_FORWARD(target)
        local sockaddr_c_list = {}
        local sockaddr_config = {}  -- items: { string_addr=<addr string>, auth_type=<auth type> }
        local ca_files = {}
@@ -255,7 +262,7 @@ local function tls_forward(target)
 end
 
 -- Rewrite records in packet
-local function reroute(tbl, names)
+function policy.REROUTE(tbl, names)
        -- Import renumbering rules
        local ren = require('renumber')
        local prefixes = {}
@@ -267,7 +274,7 @@ local function reroute(tbl, names)
 end
 
 -- Set and clear some query flags
-local function flags(opts_set, opts_clear)
+function policy.FLAGS(opts_set, opts_clear)
        return function(_, req)
                local qry = req:current()
                ffi.C.kr_qflags_set  (qry.flags, kres.mk_qflags(opts_set   or {}))
@@ -367,15 +374,6 @@ local function localhost_reversed(_, req)
        return kres.DONE
 end
 
-local policy = {
-       -- Policies
-       PASS = 1, DENY = 2, DROP = 3, TC = 4, QTRACE = 5,
-       FORWARD = forward, TLS_FORWARD = tls_forward,
-       STUB = stub, REROUTE = reroute, MIRROR = mirror, FLAGS = flags,
-       -- Special values
-       ANY = 0,
-}
-
 -- All requests
 function policy.all(action)
        return function(_, _) return action end
@@ -450,8 +448,9 @@ local function rpz_parse(action, path)
        return rules
 end
 
+-- RPZ policy set
 -- Create RPZ from zone file
-local function rpz_zonefile(action, path)
+function policy.rpz(action, path)
        local rules = rpz_parse(action, path)
        collectgarbage()
        return function(_, query)
@@ -465,9 +464,35 @@ local function rpz_zonefile(action, path)
        end
 end
 
--- RPZ policy set
-function policy.rpz(action, path)
-       return rpz_zonefile(action, path)
+function policy.DENY(_, req)
+       -- Write authority information
+       local answer = req.answer
+       ffi.C.kr_pkt_make_auth_header(answer)
+       answer:rcode(kres.rcode.NXDOMAIN)
+       answer:begin(kres.section.AUTHORITY)
+       mkauth_soa(answer, '\7blocked\0')
+       return kres.DONE
+end
+
+function policy.DROP(_, _)
+       return kres.FAIL
+end
+
+function policy.TC(state, req)
+       local answer = req.answer
+       if answer.max_size ~= 65535 then
+               answer:tc(1) -- ^ Only UDP queries
+               return kres.DONE
+       else
+               return state
+       end
+end
+
+function policy.QTRACE(_, req)
+       local qry = req:current()
+       req.options.TRACE = true
+       qry.flags.TRACE = true
+       return -- this allows to continue iterating over policy list
 end
 
 -- Evaluate packet in given rules to determine policy action
@@ -478,7 +503,7 @@ function policy.evaluate(rules, req, query, state)
                        local action = rule.cb(req, query)
                        if action ~= nil then
                                rule.count = rule.count + 1
-                               local next_state = policy.enforce(state, req, action)
+                               local next_state = action(state, req)
                                if next_state then    -- Not a chain rule,
                                        return next_state -- stop on first match
                                end
@@ -488,35 +513,6 @@ function policy.evaluate(rules, req, query, state)
        return
 end
 
--- Enforce policy action
-function policy.enforce(state, req, action)
-       if action == policy.DENY then
-               -- Write authority information
-               local answer = req.answer
-               ffi.C.kr_pkt_make_auth_header(answer)
-               answer:rcode(kres.rcode.NXDOMAIN)
-               answer:begin(kres.section.AUTHORITY)
-               mkauth_soa(answer, '\7blocked\0')
-               return kres.DONE
-       elseif action == policy.DROP then
-               return kres.FAIL
-       elseif action == policy.TC then
-               local answer = req.answer
-               if answer.max_size ~= 65535 then
-                       answer:tc(1) -- ^ Only UDP queries
-                       return kres.DONE
-               end
-       elseif action == policy.QTRACE then
-               local qry = req:current()
-               req.options.TRACE = true
-               qry.flags.TRACE = true
-               return -- this allows to continue iterating over policy list
-       elseif type(action) == 'function' then
-               return action(state, req)
-       end
-       return state
-end
-
 -- Top-down policy list walk until we hit a match
 -- the caller is responsible for reordering policy list
 -- from most specific to least specific.
index f2837fc595661755c40c76319c8b903fabbd7b58..dad097aff25962d0d3f4333fb660fb4da26de19b 100644 (file)
@@ -1,5 +1,4 @@
 local kres = require('kres')
-local policy = require('policy')
 local ffi = require('ffi')
 local C = ffi.C
 
@@ -91,7 +90,12 @@ view.layer = {
                local match_cb = evaluate(view, req)
                if match_cb ~= nil then
                        local action = match_cb(req, req:current())
-                       return policy.enforce(state, req, action) or state
+                       if action then
+                               local next_state = action(state, req)
+                               if next_state then    -- Not a chain rule,
+                                       return next_state -- stop on first match
+                               end
+                       end
                end
                return state
        end