else
table.insert(list, addr2sock(target, 53))
end
- return function(state, req)
- local qry = req:current()
+ return function(state, _, qry)
+ if not qry then return end
-- Switch mode to stub resolver, do not track origin zone cut since it's not real authority NS
qry.flags.STUB = true
qry.flags.ALWAYS_CUT = false
if not qry then return end
ffi.C.kr_qflags_set (qry.flags, kres.mk_qflags(opts_set or {}))
ffi.C.kr_qflags_clear(qry.flags, kres.mk_qflags(opts_clear or {}))
- return nil -- chain rule
+ -- chain rule
end
end
local dname_localhost = todname('localhost.')
-- Rule for localhost. zone; see RFC6303, sec. 3
-local function localhost(_, req, qry)
- local answer = req.answer
+local function localhost(_, _, qry, answer)
answer:ad(false)
answer:aa(true)
+ answer:qr(true)
local is_exact = ffi.C.knot_dname_is_equal(qry.sname, dname_localhost)
-- Answer with locally served minimal 127.in-addr.arpa domain, only having
-- a PTR record in 1.0.0.127.in-addr.arpa, and with 1.0...0.ip6.arpa. zone.
-- TODO: much of this would better be left to the hints module (or coordinated).
-local function localhost_reversed(_, req, qry)
- local answer = req.answer
-
+local function localhost_reversed(_, _, qry, answer)
-- classify qry.sname:
local is_exact -- exact dname for localhost
local is_apex -- apex of a locally-served localhost zone
answer:ad(false)
answer:aa(true)
+ answer:qr(true)
answer:rcode(kres.rcode.NOERROR)
answer:begin(kres.section.ANSWER)
if is_exact and qry.stype == kres.type.PTR then
msg_wire = string.char(#msg) .. msg
end
- return function (_, req, qry)
+ return function (_, _, qry, answer)
-- Write authority information
- local answer = req.answer
answer:ad(false)
answer:aa(true)
+ answer:qr(true)
answer:rcode(kres.rcode.NXDOMAIN)
answer:begin(kres.section.AUTHORITY)
mkauth_soa(answer, qry.sname)
answer:begin(kres.section.ADDITIONAL)
answer:put('\11explanation\7invalid', 10800, answer:qclass(), kres.type.TXT, msg_wire)
end
+ -- Treat the answer as cached
+ qry.flags.CACHED = true
return kres.DONE
end
end
policy.DENY = policy.DENY_MSG() -- compatibility with < 2.0
-function policy.DROP(_, _)
+function policy.DROP()
return kres.FAIL
end
-function policy.REFUSE(_, req)
- local answer = req.answer
+function policy.REFUSE(_, _, qry, answer)
+ if not qry then return end
answer:rcode(kres.rcode.REFUSED)
answer:aa(false)
answer:ad(false)
+ -- Treat the answer as cached
+ qry.flags.CACHED = true
return kres.DONE
end
-function policy.TC(_, req)
- local answer = req.answer
- answer:ad(false)
- answer:aa(false)
- answer:rcode(kres.rcode.REFUSED)
- return kres.DONE
-end
-
-function policy.TC(_, req)
- local answer = req.answer
+function policy.TC(_, req, _, answer)
if not req.qsource.tcp then
answer:aa(false)
answer:ad(false)
+ answer:qr(true)
answer:tc(1) -- ^ Only UDP queries
answer:ad(false)
return kres.DONE
policy.layer = {
begin = function(state, req)
req = kres.request_t(req)
- return evaluate(policy.rules, req, req:current(), state) or
- evaluate(policy.special_names, req, req:current(), state) or
+ return evaluate(policy.rules, req, req:current(), state, req.answer) or
+ evaluate(policy.special_names, req, req:current(), state, req.answer) or
state
end,
+ produce = function(state, req, pkt)
+ req = kres.request_t(req)
+ pkt = kres.pkt_t(pkt)
+ return evaluate(policy.produce_rules, req, req:current(), state, pkt) or state
+ end,
checkout = function (state, req, pkt, addr, stream)
req = kres.request_t(req)
pkt = kres.pkt_t(pkt)
end,
finish = function(state, req)
req = kres.request_t(req)
- return evaluate(policy.finish_rules, req, req:last(), state) or state
+ return evaluate(policy.finish_rules, req, req:last(), state, req.answer) or state
end
}
end
-- End of compatibility shim
local desc = {id=getruleid(), cb=rule, count=0}
- if phase == 'checkout' then
- table.insert(policy.checkout_rules, desc)
- elseif phase == 'finish' then
- table.insert(policy.finish_rules, desc)
- else
- table.insert(policy.rules, desc)
+ if type(phase) ~= 'table' then
+ phase = {phase}
+ end
+ -- Allow multiple phases for the same rule
+ for _, p in ipairs(phase) do
+ if p == 'checkout' then
+ table.insert(policy.checkout_rules, desc)
+ elseif p == 'produce' then
+ table.insert(policy.produce_rules, desc)
+ elseif p == 'finish' then
+ table.insert(policy.finish_rules, desc)
+ else
+ table.insert(policy.rules, desc)
+ end
end
return desc
end
-- Remove rule from a list
local function delrule(rules, id)
+ local deleted = 0
for i, r in ipairs(rules) do
if r.id == id then
table.remove(rules, i)
- return true
+ deleted = deleted + 1
end
end
- return false
+ return deleted
end
-- Delete rule from policy list
function policy.del(id)
- if not delrule(policy.rules, id) then
- if not delrule(policy.checkout_rules, id) then
- if not delrule(policy.finish_rules, id) then
- return false
- end
- end
- end
- return true
+ local deleted = delrule(policy.rules, id)
+ + delrule(policy.checkout_rules, id)
+ + delrule(policy.produce_rules, id)
+ + delrule(policy.finish_rules, id)
+ return deleted > 0
end
-- Convert list of string names to domain names
-- @var Default rules
policy.rules = {}
+policy.produce_rules = {}
policy.checkout_rules = {}
policy.finish_rules = {}
policy.special_names = {