From: Petr Špaček Date: Thu, 25 Jan 2018 09:14:28 +0000 (+0100) Subject: policy: refactor policy and view modules X-Git-Tag: v2.0.0~11^2~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a141e8a84a677620809d0ab25f89c071a6d4c60c;p=thirdparty%2Fknot-resolver.git policy: refactor policy and view modules I've removed couple layers of indirection to make it easier to follow. This should make it easier to extend the policy module. --- diff --git a/modules/policy/policy.lua b/modules/policy/policy.lua index 95010d4f0..cd32181d2 100644 --- a/modules/policy/policy.lua +++ b/modules/policy/policy.lua @@ -57,8 +57,15 @@ local function addr2sock(target, default_port) return sock end +-- policy functions are defined below +local policy = {} + +function policy.PASS(state, _) + return state +end + -- Mirror request elsewhere, and continue solving -local function mirror(target) +function policy.MIRROR(target) local addr, port = addr_split_port(target, 53) local sink, err = socket_client(addr, port) if not sink then panic('MIRROR target %s is not a valid: %s', target, err) end @@ -84,7 +91,7 @@ local function set_nslist(qry, list) end -- Forward request, and solve as stub query -local function stub(target) +function policy.STUB(target) local list = {} if type(target) == 'table' then for _, v in pairs(target) do @@ -105,7 +112,7 @@ local function stub(target) end -- Forward request and all subrequests to upstream; validate answers -local function forward(target) +function policy.FORWARD(target) local list = {} if type(target) == 'table' then for _, v in pairs(target) do @@ -188,7 +195,7 @@ local function tls_forward_target_check_syntax(idx, list_entry) end -- Forward request and all subrequests to upstream over TLS; validate answers -local function tls_forward(target) +function policy.TLS_FORWARD(target) local sockaddr_c_list = {} local sockaddr_config = {} -- items: { string_addr=, auth_type= } local ca_files = {} @@ -255,7 +262,7 @@ local function tls_forward(target) end -- Rewrite records in packet -local function reroute(tbl, names) +function policy.REROUTE(tbl, names) -- Import renumbering rules local ren = require('renumber') local prefixes = {} @@ -267,7 +274,7 @@ local function reroute(tbl, names) end -- Set and clear some query flags -local function flags(opts_set, opts_clear) +function policy.FLAGS(opts_set, opts_clear) return function(_, req) local qry = req:current() ffi.C.kr_qflags_set (qry.flags, kres.mk_qflags(opts_set or {})) @@ -367,15 +374,6 @@ local function localhost_reversed(_, req) return kres.DONE end -local policy = { - -- Policies - PASS = 1, DENY = 2, DROP = 3, TC = 4, QTRACE = 5, - FORWARD = forward, TLS_FORWARD = tls_forward, - STUB = stub, REROUTE = reroute, MIRROR = mirror, FLAGS = flags, - -- Special values - ANY = 0, -} - -- All requests function policy.all(action) return function(_, _) return action end @@ -450,8 +448,9 @@ local function rpz_parse(action, path) return rules end +-- RPZ policy set -- Create RPZ from zone file -local function rpz_zonefile(action, path) +function policy.rpz(action, path) local rules = rpz_parse(action, path) collectgarbage() return function(_, query) @@ -465,9 +464,35 @@ local function rpz_zonefile(action, path) end end --- RPZ policy set -function policy.rpz(action, path) - return rpz_zonefile(action, path) +function policy.DENY(_, req) + -- Write authority information + local answer = req.answer + ffi.C.kr_pkt_make_auth_header(answer) + answer:rcode(kres.rcode.NXDOMAIN) + answer:begin(kres.section.AUTHORITY) + mkauth_soa(answer, '\7blocked\0') + return kres.DONE +end + +function policy.DROP(_, _) + return kres.FAIL +end + +function policy.TC(state, req) + local answer = req.answer + if answer.max_size ~= 65535 then + answer:tc(1) -- ^ Only UDP queries + return kres.DONE + else + return state + end +end + +function policy.QTRACE(_, req) + local qry = req:current() + req.options.TRACE = true + qry.flags.TRACE = true + return -- this allows to continue iterating over policy list end -- Evaluate packet in given rules to determine policy action @@ -478,7 +503,7 @@ function policy.evaluate(rules, req, query, state) local action = rule.cb(req, query) if action ~= nil then rule.count = rule.count + 1 - local next_state = policy.enforce(state, req, action) + local next_state = action(state, req) if next_state then -- Not a chain rule, return next_state -- stop on first match end @@ -488,35 +513,6 @@ function policy.evaluate(rules, req, query, state) return end --- Enforce policy action -function policy.enforce(state, req, action) - if action == policy.DENY then - -- Write authority information - local answer = req.answer - ffi.C.kr_pkt_make_auth_header(answer) - answer:rcode(kres.rcode.NXDOMAIN) - answer:begin(kres.section.AUTHORITY) - mkauth_soa(answer, '\7blocked\0') - return kres.DONE - elseif action == policy.DROP then - return kres.FAIL - elseif action == policy.TC then - local answer = req.answer - if answer.max_size ~= 65535 then - answer:tc(1) -- ^ Only UDP queries - return kres.DONE - end - elseif action == policy.QTRACE then - local qry = req:current() - req.options.TRACE = true - qry.flags.TRACE = true - return -- this allows to continue iterating over policy list - elseif type(action) == 'function' then - return action(state, req) - end - return state -end - -- Top-down policy list walk until we hit a match -- the caller is responsible for reordering policy list -- from most specific to least specific. diff --git a/modules/view/view.lua b/modules/view/view.lua index f2837fc59..dad097aff 100644 --- a/modules/view/view.lua +++ b/modules/view/view.lua @@ -1,5 +1,4 @@ local kres = require('kres') -local policy = require('policy') local ffi = require('ffi') local C = ffi.C @@ -91,7 +90,12 @@ view.layer = { local match_cb = evaluate(view, req) if match_cb ~= nil then local action = match_cb(req, req:current()) - return policy.enforce(state, req, action) or state + if action then + local next_state = action(state, req) + if next_state then -- Not a chain rule, + return next_state -- stop on first match + end + end end return state end