]> git.ipfire.org Git - thirdparty/knot-resolver.git/commitdiff
added more functions to packet, added tests
authorMarek Vavruša <mvavrusa@cloudflare.com>
Fri, 29 Dec 2017 21:01:56 +0000 (13:01 -0800)
committerPetr Špaček <petr.spacek@nic.cz>
Thu, 4 Jan 2018 10:04:41 +0000 (11:04 +0100)
daemon/lua/kres-gen.lua
daemon/lua/kres-gen.sh
daemon/lua/kres.lua
tests/config/basic_test.lua

index f3ba3b71cbdb228a5d4428fcb74da3ecad334bf5..ba31e5179fd6ce1b632c718b3ce88975a5651abd 100644 (file)
@@ -205,6 +205,7 @@ struct kr_context {
        struct kr_zonecut root_hints;
        char _stub[];
 };
+const char *knot_strerror(int code);
 knot_dname_t *knot_dname_from_str(uint8_t *, const char *, size_t);
 _Bool knot_dname_is_equal(const knot_dname_t *, const knot_dname_t *);
 _Bool knot_dname_is_sub(const knot_dname_t *, const knot_dname_t *);
@@ -229,6 +230,9 @@ int knot_pkt_begin(knot_pkt_t *, knot_section_t);
 int knot_pkt_put_question(knot_pkt_t *, const knot_dname_t *, uint16_t, uint16_t);
 const knot_rrset_t *knot_pkt_rr(const knot_pktsection_t *, uint16_t);
 const knot_pktsection_t *knot_pkt_section(const knot_pkt_t *, knot_section_t);
+knot_pkt_t *knot_pkt_new(void *wire, uint16_t len, knot_mm_t *mm);
+void knot_pkt_free(knot_pkt_t **pkt);
+int knot_pkt_parse(knot_pkt_t *pkt, unsigned flags);
 struct kr_rplan *kr_resolve_plan(struct kr_request *);
 knot_mm_t *kr_resolve_pool(struct kr_request *);
 struct kr_query *kr_rplan_push(struct kr_rplan *, struct kr_query *, const knot_dname_t *, uint16_t, uint16_t);
index 91f1618a2845d2e5f5fc43d4547d77ea1cea91b2..432e49e8c1f07aa585fbc83876aa6ecb45c9049f 100755 (executable)
@@ -79,6 +79,8 @@ printf "\tchar _stub[];\n};\n"
 
 ## libknot API
 ./scripts/gen-cdefs.sh libknot functions <<-EOF
+# Utils
+       knot_strerror
 # Domain names
        knot_dname_from_str
        knot_dname_is_equal
@@ -106,6 +108,9 @@ printf "\tchar _stub[];\n};\n"
        knot_pkt_put_question
        knot_pkt_rr
        knot_pkt_section
+       knot_pkt_new
+       knot_pkt_free
+       knot_pkt_parse
 EOF
 
 ## libkres API
index 6c195b9cc1db3b6dae94b236e693bb69c9f257ae..205f10ba1ed36ffead1e5785ee9eff6c94d8ef56 100644 (file)
@@ -10,6 +10,24 @@ local band = bit.band
 local C = ffi.C
 local knot = ffi.load(libknot_SONAME)
 
+-- Inverse table
+local function itable(t, tolower)
+       local it = {}
+       for k,v in pairs(t) do it[v] = tolower and string.lower(k) or k end
+       return it
+end
+
+-- Byte order conversions
+local function htonl(x) return x end
+local htons = htonl
+if ffi.abi('le') then
+       htonl = bit.bswap
+       function htons(x) return bit.rshift(htonl(x), 16) end
+end
+
+-- Basic types
+local u16_p = ffi.typeof('uint16_t *')
+
 -- Various declarations that are very stable.
 ffi.cdef[[
 /*
@@ -37,6 +55,11 @@ int inet_pton(int af, const char *src, void *dst);
 
 require('kres-gen')
 
+-- Convert libknot error strings
+local function knot_strerror(r)
+       return ffi.string(knot.knot_strerror(r))
+end
+
 -- Constant tables
 local const_class = {
        IN         =   1,
@@ -136,6 +159,13 @@ local const_section = {
        AUTHORITY  = 1,
        ADDITIONAL = 2,
 }
+local const_opcode = {
+       QUERY      = 0,
+       IQUERY     = 1,
+       STATUS     = 2,
+       NOTIFY     = 4,
+       UPDATE     = 5,
+}
 local const_rcode = {
        NOERROR    =  0,
        FORMERR    =  1,
@@ -152,6 +182,13 @@ local const_rcode = {
        BADCOOKIE  = 23,
 }
 
+-- Constant tables
+local const_class_str = itable(const_class)
+local const_type_str = itable(const_type)
+local const_rcode_str = itable(const_rcode)
+local const_opcode_str = itable(const_opcode)
+local const_section_str = itable(const_section)
+
 -- Metatype for RR types to allow anonymous types
 setmetatable(const_type, {
        __index = function (t, k)
@@ -178,6 +215,17 @@ ffi.metatype( sockaddr_t, {
        }
 })
 
+-- Pretty print for domain name
+local function dname2str(dname)
+       return ffi.string(ffi.gc(C.knot_dname_to_str(nil, dname, 0), C.free))
+end
+
+-- Convert dname pointer to wireformat string
+local function dname2wire(name)
+       if name == nil then return end
+       return ffi.string(name, knot.knot_dname_size(name))
+end
+
 -- Metatype for RR set.  Beware, the indexing is 0-based (rdata, get, tostring).
 local rrset_buflen = (64 + 1) * 1024
 local rrset_buf = ffi.new('char[?]', rrset_buflen)
@@ -186,7 +234,7 @@ ffi.metatype( knot_rrset_t, {
        -- beware: `owner` and `rdata` are typed as a plain lua strings
        --         and not the real types they represent.
        __index = {
-               owner = function(rr) return ffi.string(rr._owner, knot.knot_dname_size(rr._owner)) end,
+               owner = function(rr) return dname2wire(rr._owner) end,
                ttl = function(rr) return tonumber(knot.knot_rrset_ttl(rr)) end,
                rdata = function(rr, i)
                        local rdata = knot.knot_rdataset_at(rr.rrs, i)
@@ -231,32 +279,138 @@ ffi.metatype( knot_rrset_t, {
        },
 })
 
+-- Destructor for packet accepts pointer to pointer
+local packet_ptr = ffi.new('knot_pkt_t *[1]')
+local function pkt_free(pkt)
+       packet_ptr[0] = pkt
+       knot.knot_pkt_free(packet_ptr)
+end
+
+-- Helpers for reading/writing 16-bit numbers from packet wire
+local function pkt_u16(pkt, off, val)
+       local ptr = ffi.cast(u16_p, pkt.wire + off)
+       if val ~= nil then ptr[0] = htons(val) end
+       return (htons(ptr[0]))
+end
+
+-- Helpers for reading/writing message header flags
+local function pkt_bit(pkt, byteoff, bitmask, val)
+       -- If the value argument is passed, set/clear the desired bit
+       if val ~= nil then
+               if val then pkt.wire[byteoff] = bit.bor(pkt.wire[byteoff], bitmask)
+               else pkt.wire[byteoff] = bit.band(pkt.wire[byteoff], bit.bnot(bitmask)) end
+               return true
+       end
+       return (bit.band(pkt.wire[byteoff], bitmask) ~= 0)
+end
+
+-- Helpers for converting packet to text
+local function section_tostring(pkt, section_id)
+       local data = {}
+       local section = knot.knot_pkt_section(pkt, section_id)
+       if section.count > 0 then
+               table.insert(data, string.format('\n;; %s\n', const_section_str[section_id]))
+               for j = 0, section.count - 1 do
+                       local rrset = knot.knot_pkt_rr(section, j)
+                       local rrtype = rrset.type
+                       if rrtype ~= const_type.OPT and rrtype ~= const_type.TSIG then
+                               table.insert(data, rrset:txt_dump())
+                       end
+               end
+       end
+       return table.concat(data, '')
+end
+
+local function packet_tostring(pkt)
+       local hdr = string.format(';; ->>HEADER<<- opcode: %s; status: %s; id: %d\n',
+               const_opcode_str[pkt:opcode()], const_rcode_str[pkt:rcode()], pkt:id())
+       local flags = {}
+       for _,v in ipairs({'rd', 'tc', 'aa', 'qr', 'cd', 'ad', 'ra'}) do
+               if(pkt[v](pkt)) then table.insert(flags, v) end
+       end
+       local info = string.format(';; Flags: %s; QUERY: %d; ANSWER: %d; AUTHORITY: %d; ADDITIONAL: %d\n',
+               table.concat(flags, ' '), pkt:qdcount(), pkt:ancount(), pkt:nscount(), pkt:arcount())
+       local data = '\n'
+       if pkt.opt_rr ~= nil then
+               data = data..string.format(';; OPT PSEUDOSECTION:\n%s', pkt.opt_rr:tostring())
+       end
+       if pkt.tsig_rr ~= nil then
+               data = data..string.format(';; TSIG PSEUDOSECTION:\n%s', pkt.tsig_rr:tostring())
+       end
+       -- Zone transfer answers may omit question
+       if pkt:qdcount() > 0 then
+               data = data..string.format(';; QUESTION\n;; %s\t%s\t%s\n',
+                       dname2str(pkt:qname()), const_type_str[pkt:qtype()], const_class_str[pkt:qclass()])
+       end
+       local data_sec = {}
+       for i = const_section.ANSWER, const_section.ADDITIONAL do
+               table.insert(data_sec, section_tostring(pkt, i))
+       end
+       return hdr..info..data..table.concat(data_sec, '')
+end
+
 -- Metatype for packet
 local knot_pkt_t = ffi.typeof('knot_pkt_t')
 ffi.metatype( knot_pkt_t, {
+       __new = function (_, size, wire)
+               if size < 12 or size > 65535 then
+                       error('packet size must be <12, 65535>')
+               end
+
+               local pkt = knot.knot_pkt_new(nil, size, nil)
+               if pkt == nil then
+                       error(string.format('failed to allocate a packet of size %d', size))
+               end
+               if wire == nil then
+                       pkt:id(tonumber(C.kr_rand_uint(65535)))
+               else
+                       assert(size <= #wire)
+                       ffi.copy(pkt.wire, wire, size)
+                       pkt.size = size
+                       pkt.parsed = 0
+               end
+
+               return ffi.gc(pkt[0], pkt_free)
+       end,
+       __tostring = function(pkt)
+               return pkt:tostring()
+       end,
+       __len = function(pkt)
+               assert(pkt ~= nil) return pkt.size
+       end,
+       __ipairs = function(self)
+               return ipairs(self:section(const_section.ANSWER))
+       end,
        __index = {
-               qname = function(pkt)
-                       local qname = knot.knot_pkt_qname(pkt)
-                       return ffi.string(qname, knot.knot_dname_size(qname))
+               -- Header
+               id      = function(pkt, val) return pkt_u16(pkt, 0,  val) end,
+               qdcount = function(pkt, val) return pkt_u16(pkt, 4,  val) end,
+               ancount = function(pkt, val) return pkt_u16(pkt, 6,  val) end,
+               nscount = function(pkt, val) return pkt_u16(pkt, 8,  val) end,
+               arcount = function(pkt, val) return pkt_u16(pkt, 10, val) end,
+               opcode = function (pkt, val)
+                       assert(pkt ~= nil)
+                       pkt.wire[2] = (val) and bit.bor(bit.band(pkt.wire[2], 0x78), 8 * val) or pkt.wire[2]
+                       return (bit.band(pkt.wire[2], 0x78) / 8)
                end,
-               qclass = function(pkt) return knot.knot_pkt_qclass(pkt) end,
-               qtype  = function(pkt) return knot.knot_pkt_qtype(pkt) end,
                rcode = function (pkt, val)
                        pkt.wire[3] = (val) and bor(band(pkt.wire[3], 0xf0), val) or pkt.wire[3]
                        return band(pkt.wire[3], 0x0f)
                end,
-               tc = function (pkt, val)
-                       pkt.wire[2] = bor(pkt.wire[2], (val) and 0x02 or 0x00)
-                       return band(pkt.wire[2], 0x02)
-               end,
-               rd = function (pkt, val)
-                       pkt.wire[2] = bor(pkt.wire[2], (val) and 0x01 or 0x00)
-                       return band(pkt.wire[2],0x01)
-               end,
-               ad = function (pkt, val)
-                       pkt.wire[3] = bor(pkt.wire[3], (val) and 0x20 or 0x00)
-                       return band(pkt.wire[3],0x20)
+               rd = function (pkt, val) return pkt_bit(pkt, 2, 0x01, val) end,
+               tc = function (pkt, val) return pkt_bit(pkt, 2, 0x02, val) end,
+               aa = function (pkt, val) return pkt_bit(pkt, 2, 0x04, val) end,
+               qr = function (pkt, val) return pkt_bit(pkt, 2, 0x80, val) end,
+               cd = function (pkt, val) return pkt_bit(pkt, 3, 0x10, val) end,
+               ad = function (pkt, val) return pkt_bit(pkt, 3, 0x20, val) end,
+               ra = function (pkt, val) return pkt_bit(pkt, 3, 0x80, val) end,
+               -- Question
+               qname = function(pkt)
+                       local qname = knot.knot_pkt_qname(pkt)
+                       return dname2wire(qname)
                end,
+               qclass = function(pkt) return knot.knot_pkt_qclass(pkt) end,
+               qtype  = function(pkt) return knot.knot_pkt_qtype(pkt) end,
                rrsets = function (pkt, section_id)
                        local records = {}
                        local section = knot.knot_pkt_section(pkt, section_id)
@@ -277,13 +431,41 @@ ffi.metatype( knot_pkt_t, {
                        end
                        return records
                end,
-               begin = function (pkt, section) return knot.knot_pkt_begin(pkt, section) end,
+               begin = function (pkt, section)
+                       assert(pkt ~= nil)
+                       assert(section >= pkt.current, 'cannot rewind to already written section')
+                       assert(const_section_str[section], string.format('invalid section: %s', section))
+                       local ret = knot.knot_pkt_begin(pkt, section)
+                       if ret ~= 0 then return nil, knot_strerror(ret) end
+                       return true
+               end,
                put = function (pkt, owner, ttl, rclass, rtype, rdata)
-                       return C.kr_pkt_put(pkt, owner, ttl, rclass, rtype, rdata, #rdata)
+                       assert(pkt ~= nil)
+                       local ret = C.kr_pkt_put(pkt, owner, ttl, rclass, rtype, rdata, #rdata)
+                       if ret ~= 0 then return nil, knot_strerror(ret) end
+                       return true
                end,
                clear = function (pkt) return C.kr_pkt_recycle(pkt) end,
                question = function(pkt, qname, qclass, qtype)
-                       return C.knot_pkt_put_question(pkt, qname, qclass, qtype)
+                       assert(pkt ~= nil)
+                       assert(qclass ~= nil, string.format('invalid class: %s', qclass))
+                       assert(qtype ~= nil, string.format('invalid type: %s', qtype))
+                       local ret = C.knot_pkt_put_question(pkt, qname, qclass, qtype)
+                       if ret ~= 0 then return nil, knot_strerror(ret) end
+                       return true
+               end,
+               towire = function (pkt)
+                       return ffi.string(pkt.wire, pkt.size)
+               end,
+               tostring = function(pkt)
+                       return packet_tostring(pkt)
+               end,
+               -- Packet manipulation
+               parse = function (pkt)
+                       assert(pkt ~= nil)
+                       local ret = knot.knot_pkt_parse(pkt, 0)
+                       if ret ~= 0 then return nil, knot_strerror(ret) end
+                       return true
                end,
        },
 })
@@ -291,7 +473,7 @@ ffi.metatype( knot_pkt_t, {
 local kr_query_t = ffi.typeof('struct kr_query')
 ffi.metatype( kr_query_t, {
        __index = {
-               name = function(qry) return ffi.string(qry.sname, knot.knot_dname_size(qry.sname)) end,
+               name = function(qry) return dname2wire(qry.sname) end,
        },
 })
 -- Metatype for request
@@ -333,11 +515,6 @@ ffi.metatype( kr_request_t, {
        },
 })
 
--- Pretty print for domain name
-local function dname2str(dname)
-       return ffi.string(ffi.gc(C.knot_dname_to_str(nil, dname, 0), C.free))
-end
-
 -- Pretty-print a single RR (which is a table with .owner .ttl .type .rdata)
 -- Extension: append .comment if exists.
 local function rr2str(rr, style)
@@ -371,6 +548,16 @@ kres = {
        type = const_type,
        section = const_section,
        rcode = const_rcode,
+       opcode = const_opcode,
+
+       -- Constants to strings
+       tostring = {
+               class = const_class_str,
+               type = const_type_str,
+               section = const_section_str,
+               rcode = const_rcode_str,
+               opcode = const_opcode_str,
+       },
 
        -- Create a struct kr_qflags from a single flag name or a list of names.
        mk_qflags = function (names)
@@ -391,15 +578,21 @@ kres = {
        end,
 
        CONSUME = 1, PRODUCE = 2, DONE = 4, FAIL = 8, YIELD = 16,
+
+       -- Export types
+       rrset = knot_rrset_t,
+       packet = knot_pkt_t,
+
        -- Metatypes.  Beware that any pointer will be cast silently...
        pkt_t = function (udata) return ffi.cast('knot_pkt_t *', udata) end,
        request_t = function (udata) return ffi.cast('struct kr_request *', udata) end,
        -- Global API functions
        str2dname = function(name)
                local dname = ffi.gc(C.knot_dname_from_str(nil, name, 0), C.free)
-               return ffi.string(dname, knot.knot_dname_size(dname))
+               return dname2wire(dname)
        end,
        dname2str = dname2str,
+       dname2wire = dname2wire,
        rr2str = rr2str,
        str2ip = function (ip)
                local family = C.kr_straddr_family(ip)
index 4dae147d0d65a23e215eedcc4bde3dab12ea3c3a..b929bfe5dcd206cc68439819bbfd844a11475758 100644 (file)
@@ -4,7 +4,15 @@ local function test_constants()
        same(kres.type.NS, 2, 'record type constants work')
        same(kres.type.TYPE2, 2, 'unnamed record type constants work')
        same(kres.type.BADTYPE, nil, 'non-existent type constants are checked')
+       same(kres.section.ANSWER, 0, 'section constants work')
        same(kres.rcode.SERVFAIL, 2, 'rcode constants work')
+       same(kres.opcode.UPDATE, 5, 'opcode constants work')
+       -- Test inverset tables to convert constants to text
+       same(kres.tostring.class[1], 'IN', 'text class constants work')
+       same(kres.tostring.type[2], 'NS', 'text record type constants work')
+       same(kres.tostring.section[0], 'ANSWER', 'text section constants work')
+       same(kres.tostring.rcode[2], 'SERVFAIL', 'text rcode constants work')
+       same(kres.tostring.opcode[5], 'UPDATE', 'text opcode constants work')
 end
 
 -- test globals
@@ -20,15 +28,69 @@ local function test_globals()
 end
 
 -- test if dns library functions work
-local function test_kres_functions()
+local function test_rrset_functions()
        local rr = {owner = '\3com', ttl = 1, type = kres.type.TXT, rdata = '\5hello'}
        local rr_text = tostring(kres.rr2str(rr))
        same(rr_text:gsub('%s+', ' '), 'com. 1 TXT "hello"', 'rrset to text works')
        same(kres.dname2str(todname('com.')), 'com.', 'domain name conversion works')
 end
 
+-- test dns library packet interface
+local function test_packet_functions()
+       local pkt = kres.packet(512)
+       isnt(pkt, nil, 'creating packets works')
+       -- Test manipulating header
+       ok(pkt:rcode(kres.rcode.NOERROR), 'setting rcode works')
+       same(pkt:rcode(), 0, 'getting rcode works')
+       same(pkt:opcode(), 0, 'getting opcode works')
+       is(pkt:aa(), false, 'packet is created without AA')
+       is(pkt:ra(), false, 'packet is created without RA')
+       is(pkt:ad(), false, 'packet is created without AD')
+       ok(pkt:rd(true), 'setting RD bit works')
+       is(pkt:rd(), true, 'getting RD bit works')
+       ok(pkt:tc(true), 'setting TC bit works')
+       is(pkt:tc(), true, 'getting TC bit works')
+       ok(pkt:tc(false), 'disabling TC bit works')
+       is(pkt:tc(), false, 'getting TC bit after disable works')
+       is(pkt:cd(), false, 'getting CD bit works')
+       is(pkt:id(1234), 1234, 'setting MSGID works')
+       is(pkt:id(), 1234, 'getting MSGID works')
+       -- Test manipulating question
+       is(pkt:qname(), nil, 'reading name from empty question')
+       is(pkt:qtype(), 0, 'reading type from empty question')
+       is(pkt:qclass(), 0, 'reading class from empty question')
+       ok(pkt:question(todname('hello'), kres.class.IN, kres.type.A), 'setting question section works')
+       same(pkt:qname(), todname('hello'), 'reading QNAME works')
+       same(pkt:qtype(), kres.type.A, 'reading QTYPE works')
+       same(pkt:qclass(), kres.class.IN, 'reading QCLASS works')
+       -- Test manipulating sections
+       ok(pkt:begin(kres.section.ANSWER), 'switching sections works')
+       ok(pkt:put(pkt:qname(), 900, pkt:qclass(), kres.type.A, '\1\2\3\4'), 'adding rrsets works')
+       boom(pkt.begin, {pkt, 10}, 'switching to invalid section doesnt work')
+       ok(pkt:begin(kres.section.ADDITIONAL), 'switching to different section works')
+       boom(pkt.begin, {pkt, 0}, 'rewinding sections doesnt work')
+       ok(pkt:put(pkt:qname(), 900, pkt:qclass(), kres.type.A, '\4\3\2\1'), 'adding rrsets to different section works')
+       -- Test conversions to text
+       like(pkt:tostring(), '->>HEADER<<-', 'packet to text works')
+       -- Test deserialization
+       local wire = pkt:towire()
+       same(#wire, 55, 'packet serialization works')
+       local parsed = kres.packet(#wire, wire)
+       isnt(parsed, nil, 'creating packet from wire works')
+       ok(parsed:parse(), 'parsing packet from wire works')
+       same(parsed:qname(), pkt:qname(), 'parsed packet has same QNAME')
+       same(parsed:qtype(), pkt:qtype(), 'parsed packet has same QTYPE')
+       same(parsed:qclass(), pkt:qclass(), 'parsed packet has same QCLASS')
+       same(parsed:rcode(), pkt:rcode(), 'parsed packet has same rcode')
+       same(parsed:rd(), pkt:rd(), 'parsed packet has same RD')
+       same(parsed:id(), pkt:id(), 'parsed packet has same MSGID')
+       same(parsed:ancount(), pkt:ancount(), 'parsed packet has same answer count')
+       same(parsed:tostring(), pkt:tostring(), 'parsed packet is equal to source packet')
+end
+
 return {
        test_constants,
        test_globals,
-       test_kres_functions,
+       test_rrset_functions,
+       test_packet_functions,
 }
\ No newline at end of file