]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
modules/daf: allow multiple argument matching in filter
authorMarek Vavruša <mvavrusa@cloudflare.com>
Tue, 1 May 2018 03:17:42 +0000 (20:17 -0700)
committerMarek Vavruša <mvavrusa@cloudflare.com>
Fri, 7 Sep 2018 17:45:21 +0000 (10:45 -0700)
This looks like in nftables, e.g. `src { 127.0.0.1 192.168.1.1 }`

modules/daf/daf.lua
modules/daf/daf.test.lua
modules/policy/lua-aho-corasick

index 35de8e7d09481223598aa87e542df87655ff1ea8..95c4e7f480b284942b8a4f91e27465d9abff6b7d 100644 (file)
@@ -133,47 +133,86 @@ M.actions = {
 -- Filter rules per column
 M.filters = {
        -- Filter on QTYPE
-       qtype = function (g)
-               local op, val = g(), g()
-               local qtype = kres.type[val]
-               if not qtype then
-                       error(string.format('invalid query type "%s"', val))
+       qtype = function (op, arg)
+               for i, v in ipairs(arg) do
+                       arg[i] = kres.type[v]
+                       if not arg[i] then
+                               panic('invalid query type "%s"', v)
+                       end
                end
-               if op == '=' then return policy.query_type(true, {qtype})
-               else error(string.format('invalid operator "%s" on qtype', op)) end
+               if op == '=' then return policy.query_type(true, arg)
+               else panic('invalid operator "%s" on qtype', op) end
        end,
        -- Filter on QNAME (either pattern or suffix match)
-       qname = function (g)
-               local op, val = g(), todname(g())
-               if     op == '~' then return policy.pattern(true, val:sub(2)) -- Skip leading label length
-               elseif op == '=' then return policy.suffix(true, {val})
-               else error(string.format('invalid operator "%s" on qname', op)) end
+       qname = function (op, arg)
+               if op == '~' then
+                       local name = todname(arg[1])
+                       if name == nil or #arg ~= 1 then
+                               error('operator "~"" on qname must have exactly one domain name as an argument')
+                       end
+                       return policy.pattern(true, name:sub(2)) -- Skip leading label length
+               elseif op == '=' then
+                       return policy.suffix(true, policy.todnames(arg))
+               else
+                       panic('invalid operator "%s" on qname', op)
+               end
        end,
        -- Filter on NS
-       ns = function (g)
-               local op, val = g(), todname(g())
-               if op == '=' then return policy.ns_suffix(true, {val})
-               else error(string.format('invalid operator "%s" on ns', op)) end
+       ns = function (op, arg)
+               if op == '=' then return policy.ns_suffix(true, policy.todnames(arg))
+               else panic('invalid operator "%s" on ns', op) end
        end,
        -- Filter on source address
-       src = function (g)
-               local op = g()
-               if op ~= '=' then error('address supports only "=" operator') end
-               return view.rule_src(true, g())
+       src = function (op, arg)
+               if op ~= '=' or #arg ~= 1 then error('address supports only "=" operator with single argument') end
+               return view.rule_src(true, arg[1])
        end,
        -- Filter on destination address
-       dst = function (g)
-               local op = g()
-               if op ~= '=' then error('address supports only "=" operator') end
-               return view.rule_dst(true, g())
+       dst = function (op, arg)
+               if op ~= '=' or #arg ~= 1 then error('address supports only "=" operator with single argument') end
+               return view.rule_dst(true, arg[1])
        end,
 }
 
+-- Allowed operators
+local operators = {
+       ['='] = '=',
+       ['~'] = '~',
+}
+
 local function parse_filter(tok, g, prev)
-       if not tok then error(string.format('expected filter after "%s"', prev)) end
+       if not tok then panic('expected filter after "%s"', prev) end
        local filter = M.filters[tok:lower()]
-       if not filter then error(string.format('invalid filter "%s"', tok)) end
-       return filter(g)
+       if not filter then panic('invalid filter "%s"', tok) end
+       -- Parse operator (if not exists, defaults to equality like nftables)
+       -- e.g. qname = example.com
+       --      qname example.com
+       local op = g()
+       local arg
+       if not operators[op] then
+               arg = op
+               op = '='
+       else
+               arg = g()
+       end
+       if not arg then
+               panic('expected argument after filter "%s %s"', tok, op)
+       end
+       -- Parse argument table
+       -- e.g. src {192.168.1.0 127.0.0.1}
+       local res = {}
+       if arg:find('^[{]') then
+               while arg do
+                       table.insert(res, arg:match('[^{}%s]+'))
+                       if arg:find('[}]$') then
+                               break
+                       end
+                       arg = g()
+               end
+       else
+               table.insert(res, arg)
+       end
+       return filter(op, res)
 end
 
 local function parse_rule(g)
@@ -203,8 +242,12 @@ end
 
 local function parse_query(g)
        local ok, actid, filter = pcall(parse_rule, g)
-       if not ok then return nil, actid end
-       actid = actid:lower()
+       if not ok then
+               return nil, actid
+       end
+       if actid then
+               actid = actid:lower()
+       end
        if not M.actions[actid] then
                return nil, string.format('invalid action "%s"', actid)
        end
@@ -219,7 +262,7 @@ end
 -- Compile a rule described by query language
 -- The query language is modelled by iptables/nftables
 -- conj = AND | OR
--- op = IS | NOT | LIKE | IN
+-- op = = | ~
 -- filter = <key> <op> <expr>
 -- rule = <filter> | <filter> <conj> <rule>
 -- action = PASS | DENY | DROP | TC | FORWARD
@@ -234,7 +277,8 @@ local function rule_info(r)
        return {info=r.info, id=r.rule.id, active=(r.rule.suspended ~= true), count=r.rule.count}
 end
 
--- @function Remove a rule
+-- @function Parse and compile a rule
+M.compile = compile
 
 -- @function Cleanup module
 function M.deinit()
index df2674b11fd0de0aab0f81eec30c64699e2d8d56..8aeb045b55ae9dad5bafe75dda1ef4d5e88b4444 100644 (file)
@@ -83,6 +83,49 @@ local function test_builtin_rules()
        same(rcode, kres.rcode.NXDOMAIN, '0..0.ip6.arpa. returns NXDOMAIN')
 end
 
+local function get_filter(rule)
+       local _, _, filter = daf.compile(rule)
+       return filter or function () return true end
+end
+
+-- test rules parser
+local function test_parser()
+       local a_query = {stype = kres.type.A}
+       local aaaa_query = {stype = kres.type.AAAA}
+       local txt_query = {stype = kres.type.TXT}
+
+       -- invalid rules
+       nok(daf.compile('qname'), 'rejects "qname"')
+       nok(daf.compile('qname '), 'rejects "qname "')
+       nok(daf.compile('qname {'), 'rejects "qname {"')
+       nok(daf.compile('qname {A'), 'rejects "qname {A"')
+       nok(daf.compile('qname A}'), 'rejects "qname A}"')
+       nok(daf.compile('qname @ {A AAAA} deny'), 'rejects "qname @ {A AAAA} deny"')
+       nok(daf.compile('qname ~ {A AAAA} deny'), 'rejects "qname ~ {A AAAA} deny"')
+       nok(daf.compile('qname and'), 'rejects "qname and"')
+       nok(daf.compile('qname A or'), 'rejects "qname A or"')
+
+       local filters = {
+               -- test catch all
+               ['deny'] = {true, true, true},
+               -- test explicit operator '='
+               ['qtype = A deny'] = {true, nil, nil},
+               -- test implicit operator '='
+               ['qtype A deny'] = {true, nil, nil},
+               -- test multiple arguments
+               ['qtype { A TXT } deny'] = {true, true, nil},
+               ['qtype {A TXT } deny'] = {true, true, nil},
+               ['qtype {A TXT} deny'] = {true, true, nil},
+       }
+
+       for filter, e in pairs(filters) do
+               local match = get_filter(filter)
+               same(e[1], match(nil, a_query), 'matches ' .. filter .. ' (A query)')
+               same(e[2], match(nil, txt_query), 'matches ' .. filter .. ' (TXT query)')
+               same(e[3], match(nil, aaaa_query), 'matches ' .. filter .. ' (AAAA query)')
+       end
+end
+
 -- test filters running in begin phase
 local function test_actions()
        local filters = {
@@ -144,6 +187,7 @@ end
 -- plan tests
 local tests = {
        test_builtin_rules,
+       test_parser,
        test_actions,
        test_features,
 }
index 5beaa28f4ef5ec20aa0adb75b54fabae556ec96d..47df012b383a0e79b9f40c1839b1581b00ce2989 160000 (submodule)
@@ -1 +1 @@
-Subproject commit 5beaa28f4ef5ec20aa0adb75b54fabae556ec96d
+Subproject commit 47df012b383a0e79b9f40c1839b1581b00ce2989