]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
modules/policy: chain rules, postrules, mirror, doc
authorMarek Vavrusa <marek@vavrusa.com>
Thu, 16 Jun 2016 17:50:41 +0000 (10:50 -0700)
committerMarek Vavrusa <marek@vavrusa.com>
Wed, 6 Jul 2016 06:33:38 +0000 (23:33 -0700)
* rules may now be chained if the rule action
  doesn't return next state. in this case, next
  matching rule will be executed. this is useful
  for snooping actions
* rules now may be paused/deleted
* implemented a new action for query mirroring to
  given destination

modules/policy/README.rst
modules/policy/policy.lua
modules/renumber/renumber.lua
modules/view/view.lua

index ca2675a4836481d7d462627189a3ba8bd8536634..b5a83e1aae0883f41de5808409b569a3bd68049b 100644 (file)
@@ -25,6 +25,7 @@ There are several defined actions:
 * ``DROP`` - terminate query resolution, returns SERVFAIL to requestor
 * ``TC`` - set TC=1 if the request came through UDP, forcing client to retry with TCP
 * ``FORWARD(ip)`` - forward query to given IP and proxy back response (stub mode)
+* ``MIRROR(ip)`` - mirror query to given IP and continue solving it (useful for partial snooping)
 * ``REROUTE({{subnet,target}, ...})`` - reroute addresses in response matching given subnet to given target, e.g. ``{'192.0.2.0/24', '127.0.0.0'}`` will rewrite '192.0.2.55' to '127.0.0.55', see :ref:`renumber module <mod-renumber>` for more information.
 
 .. note:: The module (and ``kres``) expects domain names in wire format, not textual representation. So each label in name is prefixed with its length, e.g. "example.com" equals to ``"\7example\3com"``. You can use convenience function ``todname('example.com')`` for automatic conversion.
@@ -60,6 +61,15 @@ Example configuration
        policy:add(policy.pattern(policy.FORWARD('2001:DB8::1'), '\4bad[0-9]\2cz'))
        -- Forward all queries (complete stub mode)
        policy:add(policy.all(policy.FORWARD('2001:DB8::1')))
+  -- Mirror all queries and retrieve information
+  local rule = policy:add(policy.all(policy.MIRROR('127.0.0.2')))
+  -- Print information about the rule
+  print(string.format('id: %d, matched queries: %d', rule.id, rule.count)
+  -- Reroute all addresses found in answer from 192.0.2.0/24 to 127.0.0.x
+  -- this policy is enforced on answers, therefore 'postrule'
+  local rule = policy:add(policy.REROUTE({'192.0.2.0/24', '127.0.0.0'}), true)
+  -- Delete rule that we just created
+  policy:del(rule.id)
 
 Properties
 ^^^^^^^^^^
@@ -84,12 +94,28 @@ Properties
 
    Forward query to given IP address.
 
-.. function:: policy:add(rule)
+.. envvar:: policy.MIRROR (address)
+
+   Forward query to given IP address.
+
+.. envvar:: policy.REROUTE({{subnet,target}, ...})
+
+   Reroute addresses in response matching given subnet to given target, e.g. ``{'192.0.2.0/24', '127.0.0.0'}`` will rewrite '192.0.2.55' to '127.0.0.55'.
+
+.. function:: policy:add(rule, postrule)
 
   :param rule: added rule, i.e. ``policy.pattern(policy.DENY, '[0-9]+\2cz')``
-  :param pattern: regular expression
+  :param postrule: boolean, if true the rule will be evaluated on answer instead of query
+  :return: rule description
   
-  Policy to block queries based on the QNAME regex matching.
+  Add a new policy rule that is executed either or queries or answers, depending on the ``postrule`` parameter. You can then use the returned rule description to get information and unique identifier for the rule, as well as match count.
+
+.. function:: policy:del(id)
+
+  :param id: identifier of a given rule
+  :return: boolean
+  
+  Remove a rule from policy list.
 
 .. function:: policy.all(action)
 
index 18c9ee038eab3de1262650c8d47ea18130cb8402..81df616dd798bb2d62ff425898eb3ebef98c0c50 100644 (file)
@@ -8,6 +8,49 @@ local function getruleid()
        return newid
 end
 
+-- Support for client sockets from inside policy actions
+local socket_client = function () return error("missing luasocket, can't create socket client") end
+local has_socket, socket = pcall(require, 'socket')
+if has_socket then
+       socket_client = function (host, port)
+               local s, err, status
+               if host:find(':') then
+                       s, err = socket.udp6()
+               else
+                       s, err = socket.udp()
+               end
+               if not s then
+                       return nil, err
+               end
+               status, err = s:setpeername(host, port)
+               if not status then
+                       return nil, err
+               end
+               return s
+       end
+end
+local has_ffi, ffi = pcall(require, 'ffi')
+if not has_ffi then
+       socket_client = function () return error("missing ffi library, required for this policy") end
+end
+
+-- Mirror request elsewhere, and continue solving
+local function mirror(target)
+       local addr, port = target:match '([^@]*)@?(.*)'
+       if not port or #port == 0 then port = 53 end
+       local sink, err = socket_client(addr, port)
+       if not sink then panic('MIRROR target %s is not a valid: %s', target, err) end
+       return function(state, req)
+               if state == kres.FAIL then return state end
+               req = kres.request_t(req)
+               local query = req.qsource.packet
+               if query ~= nil then
+                       sink:send(ffi.string(query.wire, query.size))
+               end
+               return -- Chain action to next
+       end
+end
+
 -- Forward request, and solve as stub query
 local function forward(target)
        local dst_ip = kres.str2ip(target)
@@ -36,7 +79,7 @@ end
 
 local policy = {
        -- Policies
-       PASS = 1, DENY = 2, DROP = 3, TC = 4, FORWARD = forward, REROUTE = reroute,
+       PASS = 1, DENY = 2, DROP = 3, TC = 4, FORWARD = forward, REROUTE = reroute, MIRROR = mirror,
        -- Special values
        ANY = 0,
 }
@@ -141,16 +184,21 @@ function policy.rpz(action, path, format)
 end
 
 -- Evaluate packet in given rules to determine policy action
-function policy.evaluate(rules, req, query)
+function policy.evaluate(rules, req, query, state)
        for i = 1, #rules do
                local rule = rules[i]
-               local action = rule.cb(req, query)
-               if action ~= nil then
-                       rule.count = rule.count + 1
-                       return action
+               if not rule.suspended then
+                       local action = rule.cb(req, query)
+                       if action ~= nil then
+                               rule.count = rule.count + 1
+                               local next_state = policy.enforce(state, req, action)
+                               if next_state then    -- Not a chain rule,
+                                       return next_state -- stop on first match
+                               end
+                       end
                end
        end
-       return policy.PASS
+       return state
 end
 
 -- Enforce policy action
@@ -177,17 +225,19 @@ function policy.enforce(state, req, action)
        return state
 end
 
--- Capture queries before processing
+-- Top-down policy list walk until we hit a match
+-- the caller is responsible for reordering policy list
+-- from most specific to least specific.
+-- Some rules may be chained, in this case they are evaluated
+-- as a dependency chain, e.g. r1,r2,r3 -> r3(r2(r1(state)))
 policy.layer = {
        begin = function(state, req)
                req = kres.request_t(req)
-               local action = policy.evaluate(policy.rules, req, req:current())
-               return policy.enforce(state, req, action)
+               return policy.evaluate(policy.rules, req, req:current(), state)
        end,
        finish = function(state, req)
                req = kres.request_t(req)
-               local action = policy.evaluate(policy.postrules, req, req:current())
-               return policy.enforce(state, req, action)
+               return policy.evaluate(policy.postrules, req, req:current(), state)
        end
 }
 
@@ -198,6 +248,27 @@ function policy.add(policy, rule, postrule)
        return desc
 end
 
+-- Remove rule from a list
+local function delrule(rules, id)
+       for i, r in ipairs(rules) do
+               if r.id == id then
+                       table.remove(rules, i)
+                       return true
+               end
+       end
+       return false
+end
+
+-- Delete rule from policy list
+function policy.del(policy, id)
+       if not delrule(policy.rules, id) then
+               if not delrule(policy.postrules, id) then
+                       return false
+               end
+       end
+       return true
+end
+
 -- Convert list of string names to domain names
 function policy.todnames(names)
        for i, v in ipairs(names) do
index c4afb7366eb565f431cbaccd7fa86a88d91b65be..e005e5648e380c8386fadb3244b2bbb21991d863 100644 (file)
@@ -86,7 +86,8 @@ local function rule(prefixes)
                                end
                        end
                end
-               if not changed then return state end
+               -- If not rewritten, chain action
+               if not changed then return end
                -- Replace section if renumbering
                local qname = pkt:qname()
                local qclass = pkt:qclass()
index 1e9c8f7cf202cb0a74e751d7c302424bd0eb64f8..dbe4d93da980b5ae62d0a578501737cce1bc12ca 100644 (file)
@@ -6,7 +6,8 @@ local C = ffi.C
 -- Module declaration
 local view = {
        key = {},
-       subnet = {},
+       src = {},
+       dst = {},
 }
 
 -- @function View based on TSIG key name.
@@ -15,12 +16,12 @@ function view.tsig(view, tsig, policy)
 end
 
 -- @function View based on source IP subnet.
-function view.addr(view, subnet, policy)
+function view.addr(view, subnet, policy, dst)
        local subnet_cd = ffi.new('char[16]')
        local family = C.kr_straddr_family(subnet)
        local bitlen = C.kr_straddr_subnet(subnet_cd, subnet)
        local t = {family, subnet_cd, bitlen, policy}
-       table.insert(view.subnet, t)
+       table.insert(dst and view.dst or view.src, t)
        return t
 end
 
@@ -34,26 +35,49 @@ local function evaluate(view, req)
        local client_key = req.qsource.key
        local match_cb = (client_key ~= nil) and view.key[client_key:owner()] or nil
        -- Search subnets otherwise
-       if match_cb == nil and req.qsource.addr ~= nil then
-               for i = 1, #view.subnet do
-                       local pair = view.subnet[i]
-                       if match_subnet(pair[1], pair[2], pair[3], req.qsource.addr) then
-                               match_cb = pair[4]
-                               break
+       if match_cb == nil then
+               if req.qsource.addr ~= nil then
+                       for i = 1, #view.src do
+                               local pair = view.src[i]
+                               if match_subnet(pair[1], pair[2], pair[3], req.qsource.addr) then
+                                       match_cb = pair[4]
+                                       break
+                               end
+                       end
+               elseif req.qsource.dst_addr ~= nil then
+                       for i = 1, #view.dst do
+                               local pair = view.dst[i]
+                               if match_subnet(pair[1], pair[2], pair[3], req.qsource.dst_addr) then
+                                       match_cb = pair[4]
+                                       break
+                               end
                        end
                end
        end
        return match_cb
 end
 
--- @function Return view policy rule
-function view.rule(action, subnet)
+-- @function Return policy based on source address
+function view.rule_src(action, subnet)
+       local subnet_cd = ffi.new('char[16]')
+       local family = C.kr_straddr_family(subnet)
+       local bitlen = C.kr_straddr_subnet(subnet_cd, subnet)
+       return function(req, _)
+               local addr = req.qsource.addr
+               if addr ~= nil and match_subnet(family, subnet_cd, bitlen, addr) then
+                       return action
+               end
+       end
+end
+
+-- @function Return policy based on destination address
+function view.rule_dst(action, subnet)
        local subnet_cd = ffi.new('char[16]')
        local family = C.kr_straddr_family(subnet)
        local bitlen = C.kr_straddr_subnet(subnet_cd, subnet)
        return function(req, _)
-               local src_addr = req.qsource.addr
-               if src_addr ~= nil and match_subnet(family, subnet_cd, bitlen, src_addr) then
+               local addr = req.qsource.dst_addr
+               if addr ~= nil and match_subnet(family, subnet_cd, bitlen, addr) then
                        return action
                end
        end
@@ -67,7 +91,7 @@ view.layer = {
                local match_cb = evaluate(view, req)
                if match_cb ~= nil then
                        local action = match_cb(req, req:current())
-                       return policy.enforce(state, req, action)
+                       return policy.enforce(state, req, action) or state
                end
                return state
        end