]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
policy: apply filters on outgoing queries as well
authorMarek Vavruša <mvavrusa@cloudflare.com>
Fri, 15 Jun 2018 22:28:54 +0000 (15:28 -0700)
committerMarek Vavruša <mvavrusa@cloudflare.com>
Fri, 7 Sep 2018 17:45:21 +0000 (10:45 -0700)
This allows blocking names with intermediate CNAMEs, e.g.

```
example.com CNAME invalid
```

Before, the policies were only applied on query name,
which can be circumvented by a layer of indirection like this.

daemon/lua/sandbox.lua
modules/daf/daf.lua
modules/policy/policy.lua
modules/renumber/renumber.lua
modules/view/view.lua
modules/workarounds/workarounds.lua

index 22db6aac5ecc7abd920bea66c3571cdad8df6f9c..5002baa926bc5df76fb8b53a16b3f863b2d1b2d1 100644 (file)
@@ -317,7 +317,7 @@ end
 -- Load embedded modules
 trust_anchors = require('trust_anchors')
 modules.load('ta_signal_query')
-modules.load('policy')
+modules.load('policy < cache')
 modules.load('priming')
 modules.load('detect_time_skew')
 modules.load('detect_time_jump')
index 0bb39050603bfd08e48ec74c43297e709823b940..349100387056ca4aebf2f24607bc974f0ffe1f62 100644 (file)
@@ -1,6 +1,5 @@
 -- Load dependent modules
 if not view then modules.load('view') end
-if not policy then modules.load('policy') end
 
 -- Module declaration
 local M = {
@@ -17,6 +16,7 @@ M.phases = {
        rewrite = 'finish',
        features = 'checkout',
        nsset = 'checkout',
+       deny = {'begin', 'produce'},
 }
 
 -- Actions
index 24aeb3a2bf5cf144f9122d362f1b5a192975f1df..b1f55e16a0b9a8a4809c3b524da02d26ea560275 100644 (file)
@@ -99,8 +99,8 @@ function policy.STUB(target)
        else
                table.insert(list, addr2sock(target, 53))
        end
-       return function(state, req)
-               local qry = req:current()
+       return function(state, _, qry)
+               if not qry then return end
                -- Switch mode to stub resolver, do not track origin zone cut since it's not real authority NS
                qry.flags.STUB = true
                qry.flags.ALWAYS_CUT = false
@@ -300,7 +300,7 @@ function policy.FLAGS(opts_set, opts_clear)
                if not qry then return end
                ffi.C.kr_qflags_set  (qry.flags, kres.mk_qflags(opts_set   or {}))
                ffi.C.kr_qflags_clear(qry.flags, kres.mk_qflags(opts_clear or {}))
-               return nil -- chain rule
+               -- chain rule
        end
 end
 
@@ -319,10 +319,10 @@ end
 local dname_localhost = todname('localhost.')
 
 -- Rule for localhost. zone; see RFC6303, sec. 3
-local function localhost(_, req, qry)
-       local answer = req.answer
+local function localhost(_, _, qry, answer)
        answer:ad(false)
        answer:aa(true)
+       answer:qr(true)
 
        local is_exact = ffi.C.knot_dname_is_equal(qry.sname, dname_localhost)
 
@@ -351,9 +351,7 @@ local dname_rev4_localhost_apex = todname('127.in-addr.arpa');
 -- Answer with locally served minimal 127.in-addr.arpa domain, only having
 -- a PTR record in 1.0.0.127.in-addr.arpa, and with 1.0...0.ip6.arpa. zone.
 -- TODO: much of this would better be left to the hints module (or coordinated).
-local function localhost_reversed(_, req, qry)
-       local answer = req.answer
-
+local function localhost_reversed(_, _, qry, answer)
        -- classify qry.sname:
        local is_exact   -- exact dname for localhost
        local is_apex    -- apex of a locally-served localhost zone
@@ -380,6 +378,7 @@ local function localhost_reversed(_, req, qry)
 
        answer:ad(false)
        answer:aa(true)
+       answer:qr(true)
        answer:rcode(kres.rcode.NOERROR)
        answer:begin(kres.section.ANSWER)
        if is_exact and qry.stype == kres.type.PTR then
@@ -526,11 +525,11 @@ function policy.DENY_MSG(msg)
                msg_wire = string.char(#msg) .. msg
        end
 
-       return function (_, req, qry)
+       return function (_, _, qry, answer)
                -- Write authority information
-               local answer = req.answer
                answer:ad(false)
                answer:aa(true)
+               answer:qr(true)
                answer:rcode(kres.rcode.NXDOMAIN)
                answer:begin(kres.section.AUTHORITY)
                mkauth_soa(answer, qry.sname)
@@ -538,36 +537,32 @@ function policy.DENY_MSG(msg)
                        answer:begin(kres.section.ADDITIONAL)
                        answer:put('\11explanation\7invalid', 10800, answer:qclass(), kres.type.TXT, msg_wire)
                end
+               -- Treat the answer as cached
+               qry.flags.CACHED = true
                return kres.DONE
        end
 end
 policy.DENY = policy.DENY_MSG() -- compatibility with < 2.0
 
-function policy.DROP(_, _)
+function policy.DROP()
        return kres.FAIL
 end
 
-function policy.REFUSE(_, req)
-       local answer = req.answer
+function policy.REFUSE(_, _, qry, answer)
+       if not qry then return end
        answer:rcode(kres.rcode.REFUSED)
        answer:aa(false)
        answer:ad(false)
+       -- Treat the answer as cached
+       qry.flags.CACHED = true
        return kres.DONE
 end
 
-function policy.TC(_, req)
-       local answer = req.answer
-       answer:ad(false)
-       answer:aa(false)
-       answer:rcode(kres.rcode.REFUSED)
-       return kres.DONE
-end
-
-function policy.TC(_, req)
-       local answer = req.answer
+function policy.TC(_, req, _, answer)
        if not req.qsource.tcp then
                answer:aa(false)
                answer:ad(false)
+               answer:qr(true)
                answer:tc(1) -- ^ Only UDP queries
                answer:ad(false)
                return kres.DONE
@@ -606,10 +601,15 @@ end
 policy.layer = {
        begin = function(state, req)
                req = kres.request_t(req)
-               return evaluate(policy.rules, req, req:current(), state) or
-                      evaluate(policy.special_names, req, req:current(), state) or
+               return evaluate(policy.rules, req, req:current(), state, req.answer) or
+                      evaluate(policy.special_names, req, req:current(), state, req.answer) or
                       state
        end,
+       produce = function(state, req, pkt)
+               req = kres.request_t(req)
+               pkt = kres.pkt_t(pkt)
+               return evaluate(policy.produce_rules, req, req:current(), state, pkt) or state
+       end,
        checkout = function (state, req, pkt, addr, stream)
                req = kres.request_t(req)
                pkt = kres.pkt_t(pkt)
@@ -617,7 +617,7 @@ policy.layer = {
        end,
        finish = function(state, req)
                req = kres.request_t(req)
-               return evaluate(policy.finish_rules, req, req:last(), state) or state
+               return evaluate(policy.finish_rules, req, req:last(), state, req.answer) or state
        end
 }
 
@@ -634,37 +634,43 @@ function policy.add(rule, phase)
        end
        -- End of compatibility shim
        local desc = {id=getruleid(), cb=rule, count=0}
-       if phase == 'checkout' then
-               table.insert(policy.checkout_rules, desc)
-       elseif phase == 'finish' then
-               table.insert(policy.finish_rules, desc)
-       else
-               table.insert(policy.rules, desc)
+       if type(phase) ~= 'table' then
+               phase = {phase}
+       end
+       -- Allow multiple phases for the same rule
+       for _, p in ipairs(phase) do
+               if p == 'checkout' then
+                       table.insert(policy.checkout_rules, desc)
+               elseif p == 'produce' then
+                       table.insert(policy.produce_rules, desc)
+               elseif p == 'finish' then
+                       table.insert(policy.finish_rules, desc)
+               else
+                       table.insert(policy.rules, desc)
+               end
        end
        return desc
 end
 
 -- Remove rule from a list
 local function delrule(rules, id)
+       local deleted = 0
        for i, r in ipairs(rules) do
                if r.id == id then
                        table.remove(rules, i)
-                       return true
+                       deleted = deleted + 1
                end
        end
-       return false
+       return deleted
 end
 
 -- Delete rule from policy list
 function policy.del(id)
-       if not delrule(policy.rules, id) then
-               if not delrule(policy.checkout_rules, id) then
-                       if not delrule(policy.finish_rules, id) then
-                               return false
-                       end
-               end
-       end
-       return true
+       local deleted = delrule(policy.rules, id)
+               + delrule(policy.checkout_rules, id)
+               + delrule(policy.produce_rules, id)
+               + delrule(policy.finish_rules, id)
+       return deleted > 0
 end
 
 -- Convert list of string names to domain names
@@ -784,6 +790,7 @@ policy.todnames(private_zones)
 
 -- @var Default rules
 policy.rules = {}
+policy.produce_rules = {}
 policy.checkout_rules = {}
 policy.finish_rules = {}
 policy.special_names = {
index da69346f88e00ba5f72ba239954d4f132327ec6b..afea1e239b97e3164a8549910c93f2764445765a 100644 (file)
@@ -66,10 +66,8 @@ end
 
 -- Renumber addresses based on config
 function M.rule(prefixes)
-       return function (state, req)
+       return function (state, _, _, pkt)
                if state == kres.FAIL then return state end
-               req = kres.request_t(req)
-               local pkt = kres.pkt_t(req.answer)
                -- Only successful answers
                local records = pkt:section(kres.section.ANSWER)
                -- Find renumber candidates
index f7d0be7b189b8703d1041dff97ffbc7bbae97177..d71ccc777f1b1858af8dd39540e6bba9f276a3f6 100644 (file)
@@ -92,7 +92,7 @@ view.layer = {
                        local query = req:current()
                        local action = match_cb(req, query)
                        if action then
-                               local next_state = action(state, req, query)
+                               local next_state = action(state, req, query, req.answer)
                                if next_state then    -- Not a chain rule,
                                        return next_state -- stop on first match
                                end
index 97667821998ee1f490a8f3fef1e8515b1e650e71..3e5208425e8004c7f9f6466f6cd7855519de78ca 100644 (file)
@@ -1,6 +1,3 @@
--- Load dependent module
-if not policy then modules.load('policy') end
-
 local M = {} -- the module
 
 function M.config()