]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
modules/policy: fixed NYIs (vararg function call)
authorMarek Vavruša <mvavrusa@cloudflare.com>
Thu, 31 May 2018 02:06:22 +0000 (19:06 -0700)
committerMarek Vavruša <mvavrusa@cloudflare.com>
Fri, 7 Sep 2018 17:45:21 +0000 (10:45 -0700)
* fixed NYI with vararg calls in policy filter
* fixed NYI with nil returns (incompatible with type pointer returned otherwise)
* fixed tail call returns exceeding trace loop counts

daemon/lua/kres.lua
modules/policy/policy.lua

index fdb0d771599cf669505f6ac9484eb9f81d3f9a65..7f3feeef649ff6046ea093b48e447289b87489c7 100644 (file)
@@ -241,6 +241,7 @@ local timeval_t = ffi.typeof('struct timeval')
 local addr_buf = ffi.new('char[16]')
 local str_addr_buf = ffi.new('char[46 + 1 + 6 + 1]') -- IPv6 + #port + \0
 local str_addr_buf_len = ffi.sizeof(str_addr_buf)
+local sockaddr_pt = ffi.typeof('struct sockaddr *')
 local sockaddr_t = ffi.typeof('struct sockaddr')
 ffi.metatype( sockaddr_t, {
        __index = {
@@ -463,6 +464,7 @@ ffi.metatype( knot_rrset_t, {
 })
 
 -- Destructor for packet accepts pointer to pointer
+local knot_pkt_pt = ffi.typeof('knot_pkt_t *')
 local knot_pkt_t = ffi.typeof('knot_pkt_t')
 
 -- Helpers for reading/writing 16-bit numbers from packet wire
@@ -717,13 +719,15 @@ ffi.metatype( kr_query_t, {
                end,
        },
 })
+
 -- Metatype for request
+local kr_request_pt = ffi.typeof('struct kr_request *')
 local kr_request_t = ffi.typeof('struct kr_request')
 ffi.metatype( kr_request_t, {
        __index = {
                current = function(req)
                        assert(ffi.istype(kr_request_t, req))
-                       if req.current_query == nil then return nil end
+                       if req.current_query == nil then return end
                        return req.current_query
                end,
                -- Return last query on the resolution plan
@@ -736,14 +740,14 @@ ffi.metatype( kr_request_t, {
                resolved = function(req)
                        assert(ffi.istype(kr_request_t, req))
                        local qry = C.kr_rplan_resolved(C.kr_resolve_plan(req))
-                       if qry == nil then return nil end
+                       if qry == nil then return end
                        return qry
                end,
                -- returns first resolved sub query for a request
                first_resolved = function(req)
                        assert(ffi.istype(kr_request_t, req))
                        local rplan = C.kr_resolve_plan(req)
-                       if not rplan or rplan.resolved.len < 1 then return nil end
+                       if not rplan or rplan.resolved.len < 1 then return end
                        return rplan.resolved.at[0]
                end,
                push = function(req, qname, qtype, qclass, flags, parent)
@@ -806,7 +810,7 @@ ffi.metatype(ranked_rr_array_t, {
        end,
        __index = {
                get = function (self, i)
-                       if i < 0 or i > self.len then return nil end
+                       if i < 0 or i > self.len then return end
                        return self.at[i][0]
                end,
        }
@@ -912,9 +916,10 @@ kres = {
        end,
 
        -- Metatypes.  Beware that any pointer will be cast silently...
-       pkt_t = function (udata) return ffi.cast('knot_pkt_t *', udata) end,
-       request_t = function (udata) return ffi.cast('struct kr_request *', udata) end,
-       sockaddr_t = function (udata) return ffi.cast('struct sockaddr *', udata) end,
+       pkt_t = function (udata) return ffi.cast(knot_pkt_pt, udata) end,
+       request_t = function (udata) return ffi.cast(kr_request_pt, udata) end,
+       sockaddr_t = function (udata) return ffi.cast(sockaddr_pt, udata) end,
+
        -- Global API functions
        str2dname = function(name)
                if type(name) ~= 'string' then return end
@@ -927,7 +932,7 @@ kres = {
        str2ip = function (ip)
                local family = C.kr_straddr_family(ip)
                local ret = C.inet_pton(family, ip, addr_buf)
-               if ret ~= 1 then return nil end
+               if ret ~= 1 then return end
                return ffi.string(addr_buf, C.kr_family_len(family))
        end,
        context = function () return ffi.cast('struct kr_context *', __engine) end,
index c3b2bec7389763be7adc264ddad23bfa43627394..08241460b17bf68e401e99a324887134db708299 100644 (file)
@@ -282,12 +282,16 @@ function policy.FLAGS(opts_set, opts_clear)
        end
 end
 
+-- Synthesized SOA RDATA for blocked answers
+local blocked_soa_rdata = '\6nobody\7invalid\0\0\0\0\1\0\0\14\16\0\0\4\176\0\9\58\128\0\0\42\48'
+local blocked_soa_rdata_mname = '\6nobody\7invalid\0' .. blocked_soa_rdata
+
+-- Synthesize SOA for blocked answers
 local function mkauth_soa(answer, dname, mname)
-       if mname == nil then
-               mname = dname
+       if mname then
+               return answer:put(dname, 10800, answer:qclass(), kres.type.SOA, mname .. blocked_soa_rdata)
        end
-       return answer:put(dname, 10800, answer:qclass(), kres.type.SOA,
-               mname .. '\6nobody\7invalid\0\0\0\0\1\0\0\14\16\0\0\4\176\0\9\58\128\0\0\42\48')
+       return answer:put(dname, 10800, answer:qclass(), kres.type.SOA, blocked_soa_rdata_mname)
 end
 
 local dname_localhost = todname('localhost.')
@@ -492,23 +496,25 @@ function policy.rpz(action, path)
 end
 
 function policy.DENY_MSG(msg)
-       if msg and (type(msg) ~= 'string' or #msg >= 255) then
-               error('DENY_MSG: optional msg must be string shorter than 256 characters')
-        end
+       local msg_wire
+       if msg then
+               if (type(msg) ~= 'string' or #msg >= 255) then
+                       error('DENY_MSG: optional msg must be string shorter than 256 characters')
+               end
+               msg_wire = string.char(#msg) .. msg
+       end
 
-       return function (_, req)
+       return function (_, req, qry)
                -- Write authority information
                local answer = req.answer
                answer:ad(false)
                answer:aa(true)
                answer:rcode(kres.rcode.NXDOMAIN)
                answer:begin(kres.section.AUTHORITY)
-               mkauth_soa(answer, answer:qname())
-               if msg then
+               mkauth_soa(answer, qry.sname)
+               if msg_wire then
                        answer:begin(kres.section.ADDITIONAL)
-                       answer:put('\11explanation\7invalid', 10800, answer:qclass(), kres.type.TXT,
-                                  string.char(#msg) .. msg)
-
+                       answer:put('\11explanation\7invalid', 10800, answer:qclass(), kres.type.TXT, msg_wire)
                end
                return kres.DONE
        end
@@ -554,14 +560,14 @@ function policy.QTRACE(_, req, qry)
 end
 
 -- Evaluate packet in given rules to determine policy action
-local function evaluate(rules, req, query, state, ...)
+local function evaluate(rules, req, query, state, pkt, addr, stream)
        for i = 1, #rules do
                local rule = rules[i]
                if not rule.suspended then
-                       local action = rule.cb(req, query, ...)
+                       local action = rule.cb(req, query, state, pkt, addr, stream)
                        if action then
                                rule.count = rule.count + 1
-                               local next_state = action(state, req, query, ...)
+                               local next_state = action(state, req, query, pkt, addr, stream)
                                if next_state then    -- Not a chain rule,
                                        return next_state -- stop on first match
                                end