local addr_buf = ffi.new('char[16]')
local str_addr_buf = ffi.new('char[46 + 1 + 6 + 1]') -- IPv6 + #port + \0
local str_addr_buf_len = ffi.sizeof(str_addr_buf)
+local sockaddr_pt = ffi.typeof('struct sockaddr *')
local sockaddr_t = ffi.typeof('struct sockaddr')
ffi.metatype( sockaddr_t, {
__index = {
})
-- Destructor for packet accepts pointer to pointer
+local knot_pkt_pt = ffi.typeof('knot_pkt_t *')
local knot_pkt_t = ffi.typeof('knot_pkt_t')
-- Helpers for reading/writing 16-bit numbers from packet wire
end,
},
})
+
-- Metatype for request
+local kr_request_pt = ffi.typeof('struct kr_request *')
local kr_request_t = ffi.typeof('struct kr_request')
ffi.metatype( kr_request_t, {
__index = {
current = function(req)
assert(ffi.istype(kr_request_t, req))
- if req.current_query == nil then return nil end
+ if req.current_query == nil then return end
return req.current_query
end,
-- Return last query on the resolution plan
resolved = function(req)
assert(ffi.istype(kr_request_t, req))
local qry = C.kr_rplan_resolved(C.kr_resolve_plan(req))
- if qry == nil then return nil end
+ if qry == nil then return end
return qry
end,
-- returns first resolved sub query for a request
first_resolved = function(req)
assert(ffi.istype(kr_request_t, req))
local rplan = C.kr_resolve_plan(req)
- if not rplan or rplan.resolved.len < 1 then return nil end
+ if not rplan or rplan.resolved.len < 1 then return end
return rplan.resolved.at[0]
end,
push = function(req, qname, qtype, qclass, flags, parent)
end,
__index = {
get = function (self, i)
- if i < 0 or i > self.len then return nil end
+ if i < 0 or i > self.len then return end
return self.at[i][0]
end,
}
end,
-- Metatypes. Beware that any pointer will be cast silently...
- pkt_t = function (udata) return ffi.cast('knot_pkt_t *', udata) end,
- request_t = function (udata) return ffi.cast('struct kr_request *', udata) end,
- sockaddr_t = function (udata) return ffi.cast('struct sockaddr *', udata) end,
+ pkt_t = function (udata) return ffi.cast(knot_pkt_pt, udata) end,
+ request_t = function (udata) return ffi.cast(kr_request_pt, udata) end,
+ sockaddr_t = function (udata) return ffi.cast(sockaddr_pt, udata) end,
+
-- Global API functions
str2dname = function(name)
if type(name) ~= 'string' then return end
str2ip = function (ip)
local family = C.kr_straddr_family(ip)
local ret = C.inet_pton(family, ip, addr_buf)
- if ret ~= 1 then return nil end
+ if ret ~= 1 then return end
return ffi.string(addr_buf, C.kr_family_len(family))
end,
context = function () return ffi.cast('struct kr_context *', __engine) end,
end
end
+-- Synthesized SOA RDATA for blocked answers
+local blocked_soa_rdata = '\6nobody\7invalid\0\0\0\0\1\0\0\14\16\0\0\4\176\0\9\58\128\0\0\42\48'
+local blocked_soa_rdata_mname = '\6nobody\7invalid\0' .. blocked_soa_rdata
+
+-- Synthesize SOA for blocked answers
local function mkauth_soa(answer, dname, mname)
- if mname == nil then
- mname = dname
+ if mname then
+ return answer:put(dname, 10800, answer:qclass(), kres.type.SOA, mname .. blocked_soa_rdata)
end
- return answer:put(dname, 10800, answer:qclass(), kres.type.SOA,
- mname .. '\6nobody\7invalid\0\0\0\0\1\0\0\14\16\0\0\4\176\0\9\58\128\0\0\42\48')
+ return answer:put(dname, 10800, answer:qclass(), kres.type.SOA, blocked_soa_rdata_mname)
end
local dname_localhost = todname('localhost.')
end
function policy.DENY_MSG(msg)
- if msg and (type(msg) ~= 'string' or #msg >= 255) then
- error('DENY_MSG: optional msg must be string shorter than 256 characters')
- end
+ local msg_wire
+ if msg then
+ if (type(msg) ~= 'string' or #msg >= 255) then
+ error('DENY_MSG: optional msg must be string shorter than 256 characters')
+ end
+ msg_wire = string.char(#msg) .. msg
+ end
- return function (_, req)
+ return function (_, req, qry)
-- Write authority information
local answer = req.answer
answer:ad(false)
answer:aa(true)
answer:rcode(kres.rcode.NXDOMAIN)
answer:begin(kres.section.AUTHORITY)
- mkauth_soa(answer, answer:qname())
- if msg then
+ mkauth_soa(answer, qry.sname)
+ if msg_wire then
answer:begin(kres.section.ADDITIONAL)
- answer:put('\11explanation\7invalid', 10800, answer:qclass(), kres.type.TXT,
- string.char(#msg) .. msg)
-
+ answer:put('\11explanation\7invalid', 10800, answer:qclass(), kres.type.TXT, msg_wire)
end
return kres.DONE
end
end
-- Evaluate packet in given rules to determine policy action
-local function evaluate(rules, req, query, state, ...)
+local function evaluate(rules, req, query, state, pkt, addr, stream)
for i = 1, #rules do
local rule = rules[i]
if not rule.suspended then
- local action = rule.cb(req, query, ...)
+ local action = rule.cb(req, query, state, pkt, addr, stream)
if action then
rule.count = rule.count + 1
- local next_state = action(state, req, query, ...)
+ local next_state = action(state, req, query, pkt, addr, stream)
if next_state then -- Not a chain rule,
return next_state -- stop on first match
end