]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
modules/daf,renumber: fixed the modules and added tests
authorMarek Vavruša <mvavrusa@cloudflare.com>
Fri, 27 Apr 2018 06:27:33 +0000 (23:27 -0700)
committerMarek Vavruša <mvavrusa@cloudflare.com>
Fri, 7 Sep 2018 17:45:21 +0000 (10:45 -0700)
This fixes most of the rules in DAF that were broken in 2.0 and adds tests.
It also allows policy filter to evaluate policies in the checkout layer,
before the subrequest is sent to authoritative. This is used primarily for
negotiating features between resolver and authoritatives, or disabling transports.

The policy filter can now match on:
* NS suffix - to apply policies on any zone on given nameservers
* Query type

New actions:
* REFUSE - block query with an RCODE=REFUSED, fixes #337

The DAF can now toggle features between resolver and authoritatives.

fixes #322

modules/daf/README.rst
modules/daf/daf.lua
modules/daf/daf.test.lua [new file with mode: 0644]
modules/policy/README.rst
modules/policy/policy.lua
modules/renumber/renumber.lua
modules/view/view.lua
tests/config/test.cfg

index 24ca5040d0ec9c2f528d82933768ee3f3129ad51..ea58a514bddf7c4ce5c6b524adb3ee22213542f1 100644 (file)
@@ -18,6 +18,9 @@ Firewall rules are declarative and consist of filters and actions. Filters have
     -- Block all queries with QNAME = example.com
     daf.add 'qname = example.com deny'
 
+    -- Refuse all queries with QTYPE = ANY
+    daf.add 'qtype = ANY refuse'
+
     -- Filters can be combined using AND/OR...
     -- Block all queries with QNAME match regex and coming from given subnet
     daf.add 'qname ~ %w+.example.com AND src = 192.0.2.0/24 deny'
@@ -46,6 +49,21 @@ Firewall rules are declarative and consist of filters and actions. Filters have
     -- Truncate queries based on destination IPs
     daf.add 'dst = 192.0.2.51 truncate'
 
+    -- You can set features on specific zones
+    daf.add 'qname = dnssec-failed.org features -dnssec'
+
+    -- You can also set features used between the resolver and the nameservers
+    -- Each features is prefixed with either '+' to enable, or '-' to disable
+    -- The possible features are:
+    --  -edns .. disables EDNS
+    --  -tcp .. disables TCP
+    --  -0x20 .. disabled QNAME randomization (0x20)
+    --  -minimize .. disabled QNAME minimization
+    --  -throttle .. disables throttling of unresponsive NSs
+    --  -dnssec .. disables DNSSEC
+    --  +permissive .. enabled permissive mode
+    daf.add 'ns = ns1.example.com features -tcp -0x20 +dnssec'
+
     -- Disable a rule
     daf.disable 2
     -- Enable a rule
@@ -111,11 +129,11 @@ for testing.
     {}
 
     # Create new rule
-    $ curl -s -X POST -d "src = 127.0.0.1 pass" http://localhost:8053/daf | jq .
+    $ curl -s -X POST -d "src = 127.0.0.1 refuse" http://localhost:8053/daf | jq .
     {
       "count": 0,
       "active": true,
-      "info": "src = 127.0.0.1 pass",
+      "info": "src = 127.0.0.1 refuse",
       "id": 1
     }
 
@@ -128,7 +146,7 @@ for testing.
     {
       "count": 4,
       "active": true,
-      "info": "src = 127.0.0.1 pass",
+      "info": "src = 127.0.0.1 refuse",
       "id": 1
     }
 
index 28b6342bfb0b745d5732b8979b25407c02db49d2..f772693446b5db6f7eff57f4e77b173994e66e76 100644 (file)
@@ -2,9 +2,36 @@
 if not view then modules.load('view') end
 if not policy then modules.load('policy') end
 
+-- Module declaration
+local M = {
+       rules = {},
+       phases = {},
+       actions = {},
+       filters = {},
+}
+
+-- Phases for actions (e.g. when does the action execute)
+-- The default phase is 'begin'
+M.phases = {
+       reroute = 'finish',
+       rewrite = 'finish',
+       features = 'checkout',
+}
+
 -- Actions
-local actions = {
-       pass = 1, deny = 2, drop = 3, tc = 4, truncate = 4,
+M.actions = {
+       deny = function (_)
+               return policy.DENY_MSG()
+       end,
+       drop = function (_)
+               return policy.DROP
+       end,
+       refuse = function (_)
+               return policy.REFUSE
+       end,
+       truncate = function (_)
+               return policy.TC
+       end,
        forward = function (g)
                local addrs = {}
                local tok = g()
@@ -38,10 +65,82 @@ local actions = {
                end
                return policy.REROUTE(rules, true)
        end,
+       features = function (g)
+               local set_flags, clear_flags = {}, {}
+               local allow_tcp = true
+               -- Parse feature flag toggles
+               -- Each feature can be prefixed with a symbol '+' or '-' (enable / disable)
+               -- e.g. -dnssec +tcp .. disable DNSSEC, enable TCP
+               local tok = g()
+               while tok do
+                       local sign, o = tok:match '([+-])(%S+)'
+                       local enable = (sign ~= '-')
+                       if o == '0x20' then
+                               table.insert(enable and clear_flags or set_flags, 'NO_0X20')
+                       elseif o == 'tcp' then
+                               allow_tcp = enable
+                       elseif o == 'minimize' then
+                               table.insert(enable and clear_flags or set_flags, 'NO_MINIMIZE')
+                       elseif o == 'throttle' then
+                               table.insert(enable and clear_flags or set_flags, 'NO_THROTTLE')
+                       elseif o == 'edns' then
+                               table.insert(enable and clear_flags or set_flags, 'SAFEMODE')
+                       elseif o == 'dnssec' then
+                               -- This is a positive flag, so the the tables are interposed
+                               table.insert(enable and set_flags or clear_flags, 'DNSSEC_WANT')
+                       elseif o == 'permissive' then
+                               -- This is a positive flag, so the the tables are interposed
+                               table.insert(enable and set_flags or clear_flags, 'PERMISSIVE')
+                       else
+                               error('unknown feature: ' .. o)
+                       end
+                       tok = g()
+               end
+               -- Construct the action
+               local set_flag_action = policy.FLAGS(set_flags, clear_flags)
+               return function(state, req, qry, pkt, _ --[[addr]], is_stream)
+                       -- Track whether the minimization or 0x20 flag changes
+                       local had_0x20 = qry.flags.NO_0X20
+                       local had_minimize = qry.flags.NO_MINIMIZE
+                       set_flag_action(state, req, qry)
+                       -- Block outgoing TCP if disabled
+                       if not allow_tcp and is_stream then
+                               return kres.FAIL
+                       end
+                       -- Update outgoing message
+                       if qry.flags.NO_0X20 ~= had_0x20 or
+                          qry.flags.NO_MINIMIZE ~= had_minimize then
+                               -- Update 0x20 secret to regenerate the QNAME randomization
+                               if qry.flags.NO_0X20 or qry.flags.SAFEMODE then
+                                       qry.secret = 0
+                               else
+                                       qry.secret = qry.secret + 1
+                               end
+                               local reserved = pkt.reserved
+                               local opt_rr = pkt.opt_rr
+                               qry:write(pkt)
+                               -- Restore space reservation and OPT
+                               pkt.reserved = reserved
+                               pkt.opt_rr = opt_rr
+                               pkt:begin(kres.section.ADDITIONAL)
+                       end
+                       return nil
+               end
+       end,
 }
 
 -- Filter rules per column
-local filters = {
+M.filters = {
+       -- Filter on QTYPE
+       qtype = function (g)
+               local op, val = g(), g()
+               local qtype = kres.type[val]
+               if not qtype then
+                       error(string.format('invalid query type "%s"', val))
+               end
+               if op == '=' then return policy.query_type(true, {qtype})
+               else error(string.format('invalid operator "%s" on qtype', op)) end
+       end,
        -- Filter on QNAME (either pattern or suffix match)
        qname = function (g)
                local op, val = g(), todname(g())
@@ -49,6 +148,12 @@ local filters = {
                elseif op == '=' then return policy.suffix(true, {val})
                else error(string.format('invalid operator "%s" on qname', op)) end
        end,
+       -- Filter on NS
+       ns = function (g)
+               local op, val = g(), todname(g())
+               if op == '=' then return policy.ns_suffix(true, {val})
+               else error(string.format('invalid operator "%s" on ns', op)) end
+       end,
        -- Filter on source address
        src = function (g)
                local op = g()
@@ -65,7 +170,7 @@ local filters = {
 
 local function parse_filter(tok, g, prev)
        if not tok then error(string.format('expected filter after "%s"', prev)) end
-       local filter = filters[tok:lower()]
+       local filter = M.filters[tok:lower()]
        if not filter then error(string.format('invalid filter "%s"', tok)) end
        return filter(g)
 end
@@ -73,7 +178,7 @@ end
 local function parse_rule(g)
        -- Allow action without filter
        local tok = g()
-       if not filters[tok:lower()] then
+       if not M.filters[tok:lower()] then
                return tok, nil
        end
        local f = parse_filter(tok, g)
@@ -99,9 +204,11 @@ local function parse_query(g)
        local ok, actid, filter = pcall(parse_rule, g)
        if not ok then return nil, actid end
        actid = actid:lower()
-       if not actions[actid] then return nil, string.format('invalid action "%s"', actid) end
+       if not M.actions[actid] then
+               return nil, string.format('invalid action "%s"', actid)
+       end
        -- Parse and interpret action
-       local action = actions[actid]
+       local action = M.actions[actid]
        if type(action) == 'function' then
                action = action(g)
        end
@@ -126,11 +233,6 @@ local function rule_info(r)
        return {info=r.info, id=r.rule.id, active=(r.rule.suspended ~= true), count=r.rule.count}
 end
 
--- Module declaration
-local M = {
-       rules = {}
-}
-
 -- @function Remove a rule
 
 -- @function Cleanup module
@@ -162,12 +264,9 @@ function M.add(rule)
                end
        end
        local desc = {info=rule, policy=p}
-       -- Enforce in policy module, special actions are postrules
-       if id == 'reroute' or id == 'rewrite' then
-               desc.rule = policy.add(p, true)
-       else
-               desc.rule = policy.add(p)
-       end
+       -- Enforce in policy module in given phase
+       local phase = M.phases[id] or 'begin'
+       desc.rule = policy.add(p, phase)
        table.insert(M.rules, desc)
        return desc
 end
diff --git a/modules/daf/daf.test.lua b/modules/daf/daf.test.lua
new file mode 100644 (file)
index 0000000..df2674b
--- /dev/null
@@ -0,0 +1,151 @@
+local ffi = require('ffi')
+local condition = require('cqueues.condition')
+
+-- setup resolver
+modules = { 'daf', 'hints' }
+
+-- mock values
+local mock_address = ffi.C.kr_straddr_socket('127.0.0.1', 0)
+local mock_src_address = ffi.C.kr_straddr_socket('127.0.0.2', 0)
+
+-- helper to wait for query resolution
+local function wait_resolve(qname, qtype, proto)
+       local waiting, done, cond = false, false, condition.new()
+       local rcode, answers, aa, tc, flags = kres.rcode.SERVFAIL, {}, false, false, {}
+       resolve {
+               name = qname,
+               type = qtype,
+               init = function (req)
+                       req = kres.request_t(req)
+                       req.qsource.dst_addr = mock_address
+                       req.qsource.addr = mock_src_address
+                       req.qsource.tcp = proto ~= 'udp'
+               end,
+               finish = function (answer, req)
+                       answer = kres.pkt_t(answer)
+                       aa = answer:aa()
+                       tc = answer:tc()
+                       rcode = answer:rcode()
+                       answers = answer:section(kres.section.ANSWER)
+                       local qry = req:last()
+                       if qry ~= nil then
+                               if qry.flags.NO_0X20 then flags.NO_0X20 = true end
+                               if qry.flags.NO_MINIMIZE then flags.NO_MINIMIZE = true end
+                               if qry.flags.NO_THROTTLE then flags.NO_THROTTLE = true end
+                               if qry.flags.SAFEMODE then flags.SAFEMODE = true end
+                               if qry.flags.DNSSEC_WANT then flags.DNSSEC_WANT = false end
+                               if qry.flags.PERMISSIVE then flags.PERMISSIVE = true end
+                       end
+                       -- Signal as completed
+                       if waiting then
+                               cond:signal()
+                       end
+                       done = true
+               end,
+       }
+       -- Wait if it didn't finish immediately
+       if not done then
+               waiting = true
+               cond:wait()
+       end
+       return rcode, answers, aa, tc, flags
+end
+
+local function wait_flags(qname, qtype, proto)
+       return select(5, wait_resolve(qname, qtype, proto))
+end
+
+-- test builtin rules
+local function test_builtin_rules()
+       -- rule for localhost name
+       local rcode, answers, aa = wait_resolve('localhost', kres.type.A)
+       same(rcode, kres.rcode.NOERROR, 'localhost returns NOERROR')
+       same(#answers, 1, 'localhost returns a result')
+       same(answers[1].rdata, '\127\0\0\1', 'localhost returns local address')
+       same(aa, true, 'localhost returns authoritative answer')
+
+       -- rule for reverse localhost name
+       rcode, _ = wait_resolve('127.in-addr.arpa', kres.type.PTR)
+       same(rcode, kres.rcode.NXDOMAIN, '127.in-addr.arpa returns NOERROR')
+       rcode, answers = wait_resolve('1.0.0.127.in-addr.arpa', kres.type.PTR)
+       same(rcode, kres.rcode.NOERROR, '1.0.0.127.in-addr.arpa returns NOERROR')
+       same(#answers, 1, '1.0.0.127.in-addr.arpa returns a result')
+       same(answers[1].rdata, '\9localhost\0', '1.0.0.127.in-addr.arpa returns localhost')
+
+       -- test blocking of invalid names
+       rcode, _ = wait_resolve('test', kres.type.A)
+       same(rcode, kres.rcode.NXDOMAIN, 'test. returns NXDOMAIN')
+
+       -- test blocking of private reverse zones
+       rcode, _ = wait_resolve('0.0.0.0.in-addr.arpa.', kres.type.PTR)
+       same(rcode, kres.rcode.NXDOMAIN, '0.0.0.0.in-addr.arpa. returns NXDOMAIN')
+       rcode, _ = wait_resolve('0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa.', kres.type.PTR)
+       same(rcode, kres.rcode.NXDOMAIN, '0..0.ip6.arpa. returns NXDOMAIN')
+end
+
+-- test filters running in begin phase
+local function test_actions()
+       local filters = {
+               'qtype = A',
+               'qname = localhost',
+               'dst = 127.0.0.1',
+               'src = 127.0.0.2',
+       }
+
+       local expect = {
+               deny = {rcode = kres.rcode.NXDOMAIN, aa = true },
+               drop = {rcode = kres.rcode.SERVFAIL },
+               refuse = {rcode = kres.rcode.REFUSED },
+               truncate = {rcode = kres.rcode.NOERROR, tc = true, proto = 'udp'},
+               ['reroute 127.0.0.1-192.168.1.1'] = {rcode = kres.rcode.NOERROR, aa = true, rdata = '\192\168\1\1'},
+               ['rewrite localhost A 192.168.1.1'] = {rcode = kres.rcode.NOERROR, aa = true, rdata = '\192\168\1\1'},
+       }
+
+       for _, filter in pairs(filters) do
+               for action, e in pairs(expect) do
+                       local desc = daf.add(filter .. ' ' .. action)
+                       same(type(desc), 'table', 'created a rule ' .. filter .. ' ' .. action)
+                       rcode, answer, aa, tc = wait_resolve('localhost', kres.type.A, e.proto)
+                       same(rcode, e.rcode, ' correct rcode for ' .. action)
+                       same(aa, e.aa or false, ' correct AA for ' .. action)
+                       same(tc, e.tc or false, ' correct TC for ' .. action)
+                       if e.rdata then
+                               same(answer[1].rdata, e.rdata, ' correct RDATA for ' .. action)
+                       end
+                       daf.del(desc.rule.id)
+               end
+       end
+end
+
+-- test filters setting features when talking to authoritative servers
+local function test_features()
+       local expect = {
+               -- note: the first query will be for root server which always has disabled throttling
+               ['-0x20']       = { NO_THROTTLE = true, NO_0X20 = true },
+               ['-minimize']   = { NO_THROTTLE = true, NO_MINIMIZE = true },
+               ['+throttle']   = { NO_THROTTLE = nil },
+               ['-edns']       = { NO_THROTTLE = true, SAFEMODE = true },
+               ['-dnssec']     = { NO_THROTTLE = true, DNSSEC_WANT = nil },
+               ['+permissive'] = { NO_THROTTLE = true, PERMISSIVE = true },
+       }
+       for features, e in pairs(expect) do
+               local desc = daf.add('features -tcp ' .. features)
+               -- add rule to block all outbound queries
+               local block = policy.add(policy.all(policy.DROP), 'checkout')
+               -- resolve the query and check flags set in the final query
+               same(type(desc), 'table', 'created a rule set features ' .. features)
+               local flags = wait_flags('example.com', kres.type.A)
+               daf.del(desc.rule.id)
+               policy.del(block.id)
+               same(flags, e, 'correct flag settings for ' .. features)
+       end
+end
+
+-- plan tests
+local tests = {
+       test_builtin_rules,
+       test_actions,
+       test_features,
+}
+
+return tests
\ No newline at end of file
index f635df9a2d8ff134267b0d1a779d3162c982f1c1..44192d25729552d049c1947375ec53ef251c4de0 100644 (file)
@@ -20,9 +20,13 @@ A *filter* selects which queries will be affected by specified *action*. There a
   - applies the action if QNAME matches a `regular expression <http://lua-users.org/wiki/PatternsTutorial>`_
 * ``suffix(action, table)``
   - applies the action if QNAME suffix matches one of suffixes in the table (useful for "is domain in zone" rules),
-  uses `Aho-Corasick`_ string matching algorithm `from CloudFlare <https://github.com/cloudflare/lua-aho-corasick>`_ (BSD 3-clause)
+  uses `Aho-Corasick`_ string matching algorithm `from Cloudflare <https://github.com/cloudflare/lua-aho-corasick>`_ (BSD 3-clause)
 * :any:`policy.suffix_common`
 * ``rpz(default_action, path)``
+* ``query_type(action, table)``
+  - applies the action if QTYPE matches one of the types in the table
+* ``ns_suffix(action, table)``
+  - applies the action if the NS name suffix matches one of suffixes in the table
   - implements a subset of RPZ_ in zonefile format.  See below for details: :any:`policy.rpz`.
 * custom filter function
 
index 3ef919a5d2a3d714d0ea6db047712318a8250a87..8d02a444d019fc341ea965bca7cb192299781646 100644 (file)
@@ -122,8 +122,7 @@ function policy.FORWARD(target)
        else
                table.insert(list, addr2sock(target, 53))
        end
-       return function(state, req)
-               local qry = req:current()
+       return function(state, req, qry)
                req.options.FORWARD = true
                req.options.NO_MINIMIZE = true
                qry.flags.FORWARD = true
@@ -245,8 +244,7 @@ function policy.TLS_FORWARD(target)
                end
        end
 
-       return function(state, req)
-               local qry = req:current()
+       return function(state, req, qry)
                req.options.FORWARD = true
                req.options.NO_MINIMIZE = true
                qry.flags.FORWARD = true
@@ -274,8 +272,7 @@ end
 
 -- Set and clear some query flags
 function policy.FLAGS(opts_set, opts_clear)
-       return function(_, req)
-               local qry = req:current()
+       return function(_, _, qry)
                ffi.C.kr_qflags_set  (qry.flags, kres.mk_qflags(opts_set   or {}))
                ffi.C.kr_qflags_clear(qry.flags, kres.mk_qflags(opts_clear or {}))
                return nil -- chain rule
@@ -325,8 +322,7 @@ local dname_rev4_localhost_apex = todname('127.in-addr.arpa');
 -- Answer with locally served minimal 127.in-addr.arpa domain, only having
 -- a PTR record in 1.0.0.127.in-addr.arpa, and with 1.0...0.ip6.arpa. zone.
 -- TODO: much of this would better be left to the hints module (or coordinated).
-local function localhost_reversed(_, req)
-       local qry = req:current()
+local function localhost_reversed(_, req, qry)
        local answer = req.answer
 
        -- classify qry.sname:
@@ -388,7 +384,6 @@ function policy.suffix(action, zone_list)
                if match ~= nil then
                        return action
                end
-               return nil
        end
 end
 
@@ -409,7 +404,6 @@ function policy.suffix_common(action, suffix_list, common_suffix)
                                return action
                        end
                end
-               return nil
        end
 end
 
@@ -419,7 +413,36 @@ function policy.pattern(action, pattern)
                if string.find(query:name(), pattern) then
                        return action
                end
-               return nil
+       end
+end
+
+-- Filter on NS name
+function policy.ns_suffix(action, ns_list)
+       local AC = require('ahocorasick')
+       local tree = AC.create(ns_list)
+       return function(_, query)
+               -- Check if the current NS is set
+               local ns_name = query.ns.name
+               if ns_name == nil then
+                       return
+               end
+               -- Normalize and match
+               local dname = kres.dname2wire(ns_name):lower()
+               local match = AC.match(tree, dname, false)
+               if match ~= nil then
+                       return action
+               end
+       end
+end
+
+-- Filter query type
+function policy.query_type(action, type_list)
+       return function(_, query)
+               for _, qtype in ipairs(type_list) do
+                       if query.stype == qtype then
+                               return action
+                       end
+               end
        end
 end
 
@@ -495,44 +518,51 @@ end
 function policy.REFUSE(_, req)
        local answer = req.answer
        answer:rcode(kres.rcode.REFUSED)
+       answer:aa(false)
        answer:ad(false)
        return kres.DONE
 end
 
-function policy.TC(state, req)
+function policy.TC(_, req)
        local answer = req.answer
-       if answer.max_size ~= 65535 then
+       answer:ad(false)
+       answer:aa(false)
+       answer:rcode(kres.rcode.REFUSED)
+       return kres.DONE
+end
+
+function policy.TC(_, req)
+       local answer = req.answer
+       if not req.qsource.tcp then
+               answer:aa(false)
+               answer:ad(false)
                answer:tc(1) -- ^ Only UDP queries
                answer:ad(false)
                return kres.DONE
-       else
-               return state
        end
 end
 
-function policy.QTRACE(_, req)
-       local qry = req:current()
+function policy.QTRACE(_, req, qry)
        req.options.TRACE = true
        qry.flags.TRACE = true
-       return -- this allows to continue iterating over policy list
+       -- continue iterating over policy list
 end
 
 -- Evaluate packet in given rules to determine policy action
-function policy.evaluate(rules, req, query, state)
+local function evaluate(rules, req, query, state, ...)
        for i = 1, #rules do
                local rule = rules[i]
                if not rule.suspended then
-                       local action = rule.cb(req, query)
-                       if action ~= nil then
+                       local action = rule.cb(req, query, ...)
+                       if action then
                                rule.count = rule.count + 1
-                               local next_state = action(state, req)
+                               local next_state = action(state, req, query, ...)
                                if next_state then    -- Not a chain rule,
                                        return next_state -- stop on first match
                                end
                        end
                end
        end
-       return
 end
 
 -- Top-down policy list walk until we hit a match
@@ -543,27 +573,41 @@ end
 policy.layer = {
        begin = function(state, req)
                req = kres.request_t(req)
-               return policy.evaluate(policy.rules, req, req:current(), state) or
-                      policy.evaluate(policy.special_names, req, req:current(), state) or
+               return evaluate(policy.rules, req, req:current(), state) or
+                      evaluate(policy.special_names, req, req:current(), state) or
                       state
        end,
+       checkout = function (state, req, pkt, addr, stream)
+               req = kres.request_t(req)
+               pkt = kres.pkt_t(pkt)
+               return evaluate(policy.checkout_rules, req, req:current(), state, pkt, addr, stream) or state
+       end,
        finish = function(state, req)
                req = kres.request_t(req)
-               return policy.evaluate(policy.postrules, req, req:current(), state) or state
+               return evaluate(policy.finish_rules, req, req:last(), state) or state
        end
 }
 
 -- Add rule to policy list
-function policy.add(rule, postrule)
+function policy.add(rule, phase)
        -- Compatibility with 1.0.0 API
        -- it will be dropped in 1.2.0
        if rule == policy then
-               rule = postrule
-               postrule = nil
+               rule = phase
+               phase = nil
+       end
+       if phase == true then
+               phase = 'finish'
        end
        -- End of compatibility shim
        local desc = {id=getruleid(), cb=rule, count=0}
-       table.insert(postrule and policy.postrules or policy.rules, desc)
+       if phase == 'checkout' then
+               table.insert(policy.checkout_rules, desc)
+       elseif phase == 'finish' then
+               table.insert(policy.finish_rules, desc)
+       else
+               table.insert(policy.rules, desc)
+       end
        return desc
 end
 
@@ -581,8 +625,10 @@ end
 -- Delete rule from policy list
 function policy.del(id)
        if not delrule(policy.rules, id) then
-               if not delrule(policy.postrules, id) then
-                       return false
+               if not delrule(policy.checkout_rules, id) then
+                       if not delrule(policy.finish_rules, id) then
+                               return false
+                       end
                end
        end
        return true
@@ -705,7 +751,8 @@ policy.todnames(private_zones)
 
 -- @var Default rules
 policy.rules = {}
-policy.postrules = {}
+policy.checkout_rules = {}
+policy.finish_rules = {}
 policy.special_names = {
        {
                cb=policy.suffix_common(policy.DENY_MSG(
index 6eeca4b3ea0f3eae082df74dc3e48bc59f65d936..da69346f88e00ba5f72ba239954d4f132327ec6b 100644 (file)
@@ -1,9 +1,13 @@
 -- Module interface
 local ffi = require('ffi')
-local prefixes = {}
+
+-- Export module interface
+local M = {
+       prefixes = {},
+}
 
 -- Create subnet prefix rule
-local function matchprefix(subnet, addr)
+function M.prefix(subnet, addr)
        local target = kres.str2ip(addr)
        if target == nil then error('[renumber] invalid address: '..addr) end
        local addrtype = string.find(addr, ':', 1, true) and kres.type.AAAA or kres.type.A
@@ -17,7 +21,7 @@ local function matchprefix(subnet, addr)
 end
 
 -- Create name match rule
-local function matchname(name, addr)
+function M.name(name, addr)
        local target = kres.str2ip(addr)
        if target == nil then error('[renumber] invalid address: '..addr) end
        local owner = todname(name)
@@ -28,11 +32,11 @@ end
 
 -- Add subnet prefix rewrite rule
 local function add_prefix(subnet, addr)
-       table.insert(prefixes, matchprefix(subnet, addr))
+       table.insert(M.prefixes, M.prefix(subnet, addr))
 end
 
 -- Match IP against given subnet or record owner
-local function match_subnet(subnet, bitlen, addrtype, rr)
+function M.match_subnet(subnet, bitlen, addrtype, rr)
        local addr = rr.rdata
        return addrtype == rr.type and
               ((bitlen and (#addr >= bitlen / 8) and (ffi.C.kr_bitcmp(subnet, addr, bitlen) == 0)) or subnet == rr.owner)
@@ -45,7 +49,7 @@ local function renumber_record(tbl, rr)
                local prefix = tbl[i]
                -- Match record type to address family and record address to given subnet
                -- If provided, compare record owner to prefix name
-               if match_subnet(prefix[1], prefix[2], prefix[4], rr) then
+               if M.match_subnet(prefix[1], prefix[2], prefix[4], rr) then
                        -- Replace part or whole address
                        local to_copy = prefix[2] or (#prefix[3] * 8)
                        local chunks = to_copy / 8
@@ -61,19 +65,16 @@ local function renumber_record(tbl, rr)
 end
 
 -- Renumber addresses based on config
-local function rule()
+function M.rule(prefixes)
        return function (state, req)
                if state == kres.FAIL then return state end
                req = kres.request_t(req)
                local pkt = kres.pkt_t(req.answer)
                -- Only successful answers
                local records = pkt:section(kres.section.ANSWER)
-               local ancount = #records
-               if ancount == 0 then return state end
                -- Find renumber candidates
                local changed = false
-               for i = 1, ancount do
-                       local rr = records[i]
+               for i, rr in ipairs(records) do
                        if rr.type == kres.type.A or rr.type == kres.type.AAAA then
                                local new_rr = renumber_record(prefixes, rr)
                                if new_rr ~= nil then
@@ -96,14 +97,6 @@ local function rule()
        end
 end
 
--- Export module interface
-local M = {
-       prefix = matchprefix,
-       name = matchname,
-       rule = rule,
-       match_subnet = match_subnet,
-}
-
 -- Config
 function M.config (conf)
        if conf == nil then return end
@@ -115,7 +108,7 @@ end
 
 -- Layers
 M.layer = {
-       finish = rule(),
+       finish = M.rule(M.prefixes),
 }
 
 return M
index dad097aff25962d0d3f4333fb660fb4da26de19b..f7d0be7b189b8703d1041dff97ffbc7bbae97177 100644 (file)
@@ -88,10 +88,11 @@ view.layer = {
                if state == kres.FAIL then return state end
                req = kres.request_t(req)
                local match_cb = evaluate(view, req)
-               if match_cb ~= nil then
-                       local action = match_cb(req, req:current())
+               if match_cb then
+                       local query = req:current()
+                       local action = match_cb(req, query)
                        if action then
-                               local next_state = action(state, req)
+                               local next_state = action(state, req, query)
                                if next_state then    -- Not a chain rule,
                                        return next_state -- stop on first match
                                end
index a49c0f4d17249ed0d3493da0d977af783e186b94..c339320c666dc65c99f3815fa13db44367282acc 100644 (file)
@@ -1,6 +1,9 @@
 package.path = package.path .. ';' .. env.SOURCE_PATH .. '/?.lua'
 TEST_DIR = env.TEST_FILE:match('(.*/)')
 
+-- set line buffering
+io.stdout:setvbuf 'line'
+
 -- export testing module in globals
 local tapered = require('tapered.src.tapered')
 for k, v in pairs(tapered) do
@@ -10,6 +13,9 @@ end
 -- don't send priming queries etc.
 modules.unload 'priming'
 modules.unload 'ta_signal_query'
+modules.unload 'detect_time_skew'
+modules.unload 'detect_time_jump'
+modules.unload 'ta_sentinel'
 
 -- load test
 local tests = dofile(env.TEST_FILE)