]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
modules/policy: Use module 'cqueues.socket' instead 'socket'
authorLukáš Ježek <lukas.jezek@nic.cz>
Fri, 13 Dec 2019 11:07:15 +0000 (12:07 +0100)
committerPetr Špaček <petr.spacek@nic.cz>
Fri, 20 Dec 2019 09:23:40 +0000 (10:23 +0100)
modules/policy/policy.lua
modules/policy/policy.test.lua

index af958ce6982065899f48982628ff4a25514d51eb..c5775aae72635ac64a0698f6f439333a3ce6108a 100644 (file)
@@ -12,22 +12,20 @@ local function getruleid()
 end
 
 -- Support for client sockets from inside policy actions
-local socket_client = function () return error("missing luasocket, can't create socket client") end
-local has_socket, socket = pcall(require, 'socket')
+local socket_client = function ()
+       return error("missing lua-cqueues library, can't create socket client") 
+end
+local has_socket, socket = pcall(require, 'cqueues.socket')
 if has_socket then
        socket_client = function (host, port)
                local s, err, status
-               if host:find(':') then
-                       s, err = socket.udp6()
-               else
-                       s, err = socket.udp()
-               end
-               if not s then
-                       return nil, err
-               end
-               status, err = s:setpeername(host, port)
+
+               s = socket.connect({ host = host, port = port, type = socket.SOCK_DGRAM })
+               s:setmode('bn', 'bn')
+               status, err = pcall(s.connect, s)
+
                if not status then
-                       return nil, err
+                       return status, err
                end
                return s
        end
@@ -65,13 +63,13 @@ end
 -- Mirror request elsewhere, and continue solving
 function policy.MIRROR(target)
        local addr, port = addr_split_port(target, 53)
-       local sink, err = socket_client(addr, port)
+       local sink, err = socket_client(ffi.string(addr), port)
        if not sink then panic('MIRROR target %s is not a valid: %s', target, err) end
        return function(state, req)
                if state == kres.FAIL then return state end
                local query = req.qsource.packet
                if query ~= nil then
-                       sink:send(ffi.string(query.wire, query.size))
+                       sink:send(ffi.string(query.wire, query.size), 1, tonumber(query.size))
                end
                return -- Chain action to next
        end
index d6e4d1d107167402f1777f406dc068f220dc597a..e006bbf6fa6e458372010cf1eec562f1b1c7ba1e 100644 (file)
@@ -1,6 +1,12 @@
 -- setup resolver
 -- policy module should be loaded by default, do not load it explicitly
 
+-- do not attempt to contact outside world, operate only on cache
+net.ipv4 = false
+net.ipv6 = false
+-- do not listen, test is driven by config code
+env.KRESD_NO_LISTEN = true
+
 -- test for default configuration
 local function test_tls_forward()
        boom(policy.TLS_FORWARD, {}, 'TLS_FORWARD without arguments')
@@ -61,8 +67,77 @@ local function test_slice()
        ok(policy.slice, {function() end, policy.FORWARD, policy.FORWARD})
 end
 
+local function mirror_parser(srv, cv, nqueries)
+       local ffi = require('ffi')
+       local test_end = 0
+       local TIMEOUT = 5  -- seconds
+
+       while true do
+               local input = srv:xread('*a', 'bn', TIMEOUT)
+               if not input then
+                       cv:signal()
+                       return false, 'mirror: timeout'
+               end
+               --print(#input, input)
+               -- convert query to knot_pkt_t
+               local wire = ffi.cast("void *", input)
+               local pkt = ffi.gc(ffi.C.knot_pkt_new(wire, #input, nil), ffi.C.knot_pkt_free)
+               if not pkt then
+                       cv:signal()
+                       return false, 'mirror: packet allocation error'
+               end
+
+               local result = ffi.C.knot_pkt_parse(pkt, 0)
+               if result ~= 0 then
+                       cv:signal()
+                       return false, 'mirror: packet parse error'
+               end
+               --print(pkt)
+               test_end = test_end + 1
+
+               if test_end == nqueries then
+                       cv:signal()
+                       return true, 'packet mirror pass'
+               end
+
+       end
+end
+
+local function test_mirror()
+       local socket = require("cqueues.socket")
+       local cond = require("cqueues.condition")
+       local cv = cond.new()
+       local queries = {}
+       local srv = socket.listen({
+               host = "127.0.0.1",
+               port = 36659,
+               type = socket.SOCK_DGRAM,
+       })
+       -- binary mode, no buffering
+       srv:setmode('bn', 'bn')
+
+       queries["bla.mujtest.cz."] = kres.type.AAAA
+       queries["bla.mujtest2.cz."] = kres.type.AAAA
+
+       -- UDP server for test
+       worker.bg_worker.cq:wrap(function()
+               local err, msg = mirror_parser(srv, cv, kr_table_len(queries))
+
+               ok(err, msg)
+       end)
+
+       policy.add(policy.suffix(policy.MIRROR('127.0.0.1@36659'), policy.todnames({'mujtest.cz.'})))
+       policy.add(policy.suffix(policy.MIRROR('127.0.0.1@36659'), policy.todnames({'mujtest2.cz.'})))
+
+       for name, rtype in pairs(queries) do
+               resolve(name, rtype)
+       end
+
+       cv:wait()
+end
 
 return {
        test_tls_forward,
+       test_mirror,
        test_slice,
 }