]> git.ipfire.org Git - thirdparty/ldns.git/commitdiff
Added functionality to have multiple packet answers to a query.
authorWouter Wijngaards <wouter@NLnetLabs.nl>
Mon, 31 Jul 2006 13:42:33 +0000 (13:42 +0000)
committerWouter Wijngaards <wouter@NLnetLabs.nl>
Mon, 31 Jul 2006 13:42:33 +0000 (13:42 +0000)
examples/nsd-test/ldns-testns.c

index ca6d699daa9ce218b995ce8a0659d3ea99af7b83..1a186060217674b0314c75540caa870261b7fb0b 100644 (file)
@@ -47,6 +47,7 @@
        <RRs, one per line>
        SECTION ADDITIONAL
        <RRs, one per line>
+       EXTRA_PACKET            ; follow with SECTION, REPLY for more packets.
        ENTRY_END
 */
 
@@ -73,6 +74,12 @@ static int verbose = 0;
 
 enum transport_type {transport_any = 0, transport_udp, transport_tcp };
 
+/* struct to keep a linked list of reply packets for a query */
+struct reply_packet {
+       struct reply_packet* next;
+       ldns_pkt* reply;
+};
+
 /* data structure to keep the canned queries in */
 /* format is the 'matching query' and the 'canned answer' */
 struct entry {
@@ -86,7 +93,7 @@ struct entry {
        enum transport_type match_transport; /* match on UDP/TCP */
 
        /* pre canned reply */
-       ldns_pkt *reply;
+       struct reply_packet *reply_list;
 
        /* how to adjust the reply packet */
        bool copy_id; /* copy over the ID from the query into the answer */
@@ -156,6 +163,21 @@ static bool str_keyword(const char** str, const char* keyword)
        return true;
 }
 
+static struct reply_packet*
+entry_add_reply(struct entry* entry) 
+{
+       struct reply_packet* pkt = (struct reply_packet*)malloc(
+               sizeof(struct reply_packet));
+       struct reply_packet ** p = &entry->reply_list;
+       pkt->next = NULL;
+       pkt->reply = ldns_pkt_new();
+       /* link at end */
+       while(*p)
+               p = &((*p)->next);
+       *p = pkt;
+       return pkt;
+}
+
 static void matchline(const char* line, struct entry* e)
 {
        const char* parse = line;
@@ -186,7 +208,7 @@ static void matchline(const char* line, struct entry* e)
        }
 }
 
-static void replyline(const char* line, struct entry* e)
+static void replyline(const char* line, ldns_pkt *reply)
 {
        const char* parse = line;
        while(*parse) {
@@ -194,51 +216,51 @@ static void replyline(const char* line, struct entry* e)
                        return;
                        /* opcodes */
                if(str_keyword(&parse, "QUERY")) {
-                       ldns_pkt_set_opcode(e->reply, LDNS_PACKET_QUERY);
+                       ldns_pkt_set_opcode(reply, LDNS_PACKET_QUERY);
                } else if(str_keyword(&parse, "IQUERY")) {
-                       ldns_pkt_set_opcode(e->reply, LDNS_PACKET_IQUERY);
+                       ldns_pkt_set_opcode(reply, LDNS_PACKET_IQUERY);
                } else if(str_keyword(&parse, "STATUS")) {
-                       ldns_pkt_set_opcode(e->reply, LDNS_PACKET_STATUS);
+                       ldns_pkt_set_opcode(reply, LDNS_PACKET_STATUS);
                } else if(str_keyword(&parse, "NOTIFY")) {
-                       ldns_pkt_set_opcode(e->reply, LDNS_PACKET_NOTIFY);
+                       ldns_pkt_set_opcode(reply, LDNS_PACKET_NOTIFY);
                } else if(str_keyword(&parse, "UPDATE")) {
-                       ldns_pkt_set_opcode(e->reply, LDNS_PACKET_UPDATE);
+                       ldns_pkt_set_opcode(reply, LDNS_PACKET_UPDATE);
                        /* rcodes */
                } else if(str_keyword(&parse, "NOERROR")) {
-                       ldns_pkt_set_rcode(e->reply, LDNS_RCODE_NOERROR);
+                       ldns_pkt_set_rcode(reply, LDNS_RCODE_NOERROR);
                } else if(str_keyword(&parse, "FORMERR")) {
-                       ldns_pkt_set_rcode(e->reply, LDNS_RCODE_FORMERR);
+                       ldns_pkt_set_rcode(reply, LDNS_RCODE_FORMERR);
                } else if(str_keyword(&parse, "SERVFAIL")) {
-                       ldns_pkt_set_rcode(e->reply, LDNS_RCODE_SERVFAIL);
+                       ldns_pkt_set_rcode(reply, LDNS_RCODE_SERVFAIL);
                } else if(str_keyword(&parse, "NXDOMAIN")) {
-                       ldns_pkt_set_rcode(e->reply, LDNS_RCODE_NXDOMAIN);
+                       ldns_pkt_set_rcode(reply, LDNS_RCODE_NXDOMAIN);
                } else if(str_keyword(&parse, "NOTIMPL")) {
-                       ldns_pkt_set_rcode(e->reply, LDNS_RCODE_NOTIMPL);
+                       ldns_pkt_set_rcode(reply, LDNS_RCODE_NOTIMPL);
                } else if(str_keyword(&parse, "YXDOMAIN")) {
-                       ldns_pkt_set_rcode(e->reply, LDNS_RCODE_YXDOMAIN);
+                       ldns_pkt_set_rcode(reply, LDNS_RCODE_YXDOMAIN);
                } else if(str_keyword(&parse, "YXRRSET")) {
-                       ldns_pkt_set_rcode(e->reply, LDNS_RCODE_YXRRSET);
+                       ldns_pkt_set_rcode(reply, LDNS_RCODE_YXRRSET);
                } else if(str_keyword(&parse, "NXRRSET")) {
-                       ldns_pkt_set_rcode(e->reply, LDNS_RCODE_NXRRSET);
+                       ldns_pkt_set_rcode(reply, LDNS_RCODE_NXRRSET);
                } else if(str_keyword(&parse, "NOTAUTH")) {
-                       ldns_pkt_set_rcode(e->reply, LDNS_RCODE_NOTAUTH);
+                       ldns_pkt_set_rcode(reply, LDNS_RCODE_NOTAUTH);
                } else if(str_keyword(&parse, "NOTZONE")) {
-                       ldns_pkt_set_rcode(e->reply, LDNS_RCODE_NOTZONE);
+                       ldns_pkt_set_rcode(reply, LDNS_RCODE_NOTZONE);
                        /* flags */
                } else if(str_keyword(&parse, "QR")) {
-                       ldns_pkt_set_qr(e->reply, true);
+                       ldns_pkt_set_qr(reply, true);
                } else if(str_keyword(&parse, "AA")) {
-                       ldns_pkt_set_aa(e->reply, true);
+                       ldns_pkt_set_aa(reply, true);
                } else if(str_keyword(&parse, "TC")) {
-                       ldns_pkt_set_tc(e->reply, true);
+                       ldns_pkt_set_tc(reply, true);
                } else if(str_keyword(&parse, "RD")) {
-                       ldns_pkt_set_rd(e->reply, true);
+                       ldns_pkt_set_rd(reply, true);
                } else if(str_keyword(&parse, "CD")) {
-                       ldns_pkt_set_cd(e->reply, true);
+                       ldns_pkt_set_cd(reply, true);
                } else if(str_keyword(&parse, "RA")) {
-                       ldns_pkt_set_ra(e->reply, true);
+                       ldns_pkt_set_ra(reply, true);
                } else if(str_keyword(&parse, "AD")) {
-                       ldns_pkt_set_ad(e->reply, true);
+                       ldns_pkt_set_ad(reply, true);
                } else {
                        error("could not parse REPLY: '%s'", parse);
                }
@@ -269,7 +291,7 @@ static struct entry* new_entry()
        e->match_serial = false;
        e->ixfr_soa_serial = 0;
        e->match_transport = transport_any;
-       e->reply = ldns_pkt_new();
+       e->reply_list = NULL;
        e->copy_id = false;
        e->next = NULL;
        return e;
@@ -313,6 +335,7 @@ static struct entry* read_datafile(const char* name)
        ldns_rdf* origin = NULL;
        ldns_rdf* prev_rr = NULL;
        int entry_num = 0;
+       struct reply_packet *cur_reply = NULL;
 
        if((in=fopen(name, "r")) == NULL) {
                error("could not open file %s: %s", name, strerror(errno));
@@ -334,6 +357,7 @@ static struct entry* read_datafile(const char* name)
                                        name, lineno);
                        }
                        current = new_entry();
+                       cur_reply = entry_add_reply(current);
                        if(last)
                                last->next = current;
                        else    list = current;
@@ -355,9 +379,11 @@ static struct entry* read_datafile(const char* name)
                if(str_keyword(&parse, "MATCH")) {
                        matchline(parse, current);
                } else if(str_keyword(&parse, "REPLY")) {
-                       replyline(parse, current);
+                       replyline(parse, cur_reply->reply);
                } else if(str_keyword(&parse, "ADJUST")) {
                        adjustline(parse, current);
+               } else if(str_keyword(&parse, "EXTRA_PACKET")) {
+                       cur_reply = entry_add_reply(current);
                } else if(str_keyword(&parse, "SECTION")) {
                        if(str_keyword(&parse, "QUESTION"))
                                add_section = LDNS_SECTION_QUESTION;
@@ -379,10 +405,9 @@ static struct entry* read_datafile(const char* name)
                        if (status != LDNS_STATUS_OK)
                                error("%s line %d:\n\t%s: %s", name, lineno,
                                        ldns_get_errorstr_by_id(status), parse);
-                       ldns_pkt_push_rr(current->reply, add_section, n);
+                       ldns_pkt_push_rr(cur_reply->reply, add_section, n);
                }
 
-
        }
        log_msg("Read %d entries\n", entry_num);
 
@@ -423,21 +448,23 @@ static struct entry* find_match(struct entry* entries, ldns_pkt* query_pkt,
        enum transport_type transport)
 {
        struct entry* p = entries;
+       ldns_pkt* reply = NULL;
        for(p=entries; p; p=p->next) {
                if(verbose) log_msg("comparepkt: ");
+               reply = p->reply_list->reply;
                if(p->match_opcode && ldns_pkt_get_opcode(query_pkt) != 
-                       ldns_pkt_get_opcode(p->reply)) {
+                       ldns_pkt_get_opcode(reply)) {
                        if(verbose) log_msg("bad opcode\n");
                        continue;
                }
-               if(p->match_qtype && get_qtype(query_pkt) != get_qtype(p->reply)) {
+               if(p->match_qtype && get_qtype(query_pkt) != get_qtype(reply)) {
                        if(verbose) log_msg("bad qtype\n");
                        continue;
                }
                if(p->match_qname) {
-                       if(!get_owner(query_pkt) || !get_owner(p->reply) ||
+                       if(!get_owner(query_pkt) || !get_owner(reply) ||
                                ldns_dname_compare(
-                               get_owner(query_pkt), get_owner(p->reply)) != 0) {
+                               get_owner(query_pkt), get_owner(reply)) != 0) {
                                if(verbose) log_msg("bad qname\n");
                                continue;
                        }
@@ -456,38 +483,36 @@ static struct entry* find_match(struct entry* entries, ldns_pkt* query_pkt,
        return NULL;
 }
 
-static ldns_pkt* 
-get_answer(struct entry* entries, ldns_pkt* query_pkt, enum transport_type transport)
+static void
+adjust_packet(struct entry* match, ldns_pkt* answer_pkt, ldns_pkt* query_pkt)
 {
-       ldns_pkt* answer_pkt = NULL;
-       struct entry* match = find_match(entries, query_pkt, transport);
-       if(!match) 
-               return NULL;
        /* copy & adjust packet */
-       answer_pkt = ldns_pkt_clone(match->reply);
        if(match->copy_id)
                ldns_pkt_set_id(answer_pkt, ldns_pkt_id(query_pkt));
-       return answer_pkt;
 }
 
 /*
- * Parses data buffer to a query, finds the correct answer and returns
- * a buffer to data to send, or NULL. (LDNS_FREE the buffer when done).
+ * Parses data buffer to a query, finds the correct answer 
+ * and calls the given function for every packet to send.
  */
-static uint8_t*
+static void
 handle_query(uint8_t* inbuf, ssize_t inlen, struct entry* entries, int* count,
-       size_t* answer_size, enum transport_type transport)
+       enum transport_type transport, void (*sendfunc)(uint8_t*, size_t, void*),
+       void* userdata)
 {
        ldns_status status;
        ldns_pkt *query_pkt = NULL;
        ldns_pkt *answer_pkt = NULL;
+       struct reply_packet *p;
        ldns_rr *query_rr = NULL;
        uint8_t *outbuf = NULL;
+       size_t answer_size = 0;
+       struct entry* entry = NULL;
 
        status = ldns_wire2pkt(&query_pkt, inbuf, (size_t)inlen);
        if (status != LDNS_STATUS_OK) {
                log_msg("Got bad packet: %s\n", ldns_get_errorstr_by_id(status));
-               return NULL;
+               return;
        }
        
        query_rr = ldns_rr_list_rr(ldns_pkt_question(query_pkt), 0);
@@ -497,57 +522,70 @@ handle_query(uint8_t* inbuf, ssize_t inlen, struct entry* entries, int* count,
        if(verbose) ldns_pkt_print(logfile, query_pkt);
        
        /* fill up answer packet */
-       answer_pkt = get_answer(entries, query_pkt, transport);
-       if(answer_pkt) {
+       entry = find_match(entries, query_pkt, transport);
+       if(!entry || !entry->reply_list) {
+               log_msg("no answer packet for this query, no reply.\n");
+               return;
+       }
+       for(p = entry->reply_list; p; p = p->next)
+       {
                if(verbose) log_msg("Answer pkt:\n");
                if(verbose) ldns_pkt_print(logfile, answer_pkt);
-               status = ldns_pkt2wire(&outbuf, answer_pkt, answer_size);
-               log_msg("Answer packet size: %u bytes.\n", (unsigned int)*answer_size);
+               answer_pkt = ldns_pkt_clone(p->reply);
+               adjust_packet(entry, answer_pkt, query_pkt);
+               status = ldns_pkt2wire(&outbuf, answer_pkt, &answer_size);
+               log_msg("Answer packet size: %u bytes.\n", (unsigned int)answer_size);
                if (status != LDNS_STATUS_OK) {
                        log_msg("Error creating answer: %s\n", ldns_get_errorstr_by_id(status));
                        outbuf = NULL;
                }
-       } else {
-               log_msg("no answer packet for this query, no reply.\n");
-               outbuf = NULL;
+               ldns_pkt_free(query_pkt);
+               ldns_pkt_free(answer_pkt);
+
+               sendfunc(outbuf, answer_size, userdata);
+               LDNS_FREE(outbuf);
+               outbuf = 0;
+               answer_size = 0;
        }
-       ldns_pkt_free(query_pkt);
-       ldns_pkt_free(answer_pkt);
-       return outbuf;
+}
+
+struct handle_udp_userdata {
+       int udp_sock;
+       struct sockaddr_storage addr_him;
+       socklen_t hislen;
+};
+static void
+send_udp(uint8_t* buf, size_t len, void* data)
+{
+       struct handle_udp_userdata *userdata = (struct handle_udp_userdata*)data;
+       /* udp send reply */
+       ssize_t nb;
+       nb = sendto(userdata->udp_sock, buf, len, 0, 
+               (struct sockaddr*)&userdata->addr_him, userdata->hislen);
+       if(nb == -1)
+               log_msg("sendto(): %s\n", strerror(errno));
+       else if((size_t)nb != len)
+               log_msg("sendto(): only sent %d of %d octets.\n", 
+                       (int)nb, (int)len);
 }
 
 static void
 handle_udp(int udp_sock, struct entry* entries, int *count)
 {
        ssize_t nb;
-       struct sockaddr_storage addr_him;
-       socklen_t hislen;
        uint8_t inbuf[INBUF_SIZE];
-       uint8_t *outbuf;
-       size_t answer_size = 0;
+       struct handle_udp_userdata userdata;
+       userdata.udp_sock = udp_sock;
 
-       hislen = (socklen_t)sizeof(addr_him);
+       userdata.hislen = (socklen_t)sizeof(userdata.addr_him);
        /* udp recv */
        nb = recvfrom(udp_sock, inbuf, INBUF_SIZE, 0, 
-               (struct sockaddr*)&addr_him, &hislen);
+               (struct sockaddr*)&userdata.addr_him, &userdata.hislen);
        if (nb < 1) {
                log_msg("recvfrom(): %s\n", strerror(errno));
                return;
        }
-       outbuf = handle_query(inbuf, nb, entries, count, &answer_size,
-               transport_udp);
-       if(!outbuf)
-               return;
-
-       /* udp send reply */
-       nb = sendto(udp_sock, outbuf, answer_size, 0, 
-               (struct sockaddr*)&addr_him, hislen);
-       if(nb == -1)
-               log_msg("sendto(): %s\n", strerror(errno));
-       else if((size_t)nb != answer_size)
-               log_msg("sendto(): only sent %d of %d octets.\n", 
-                       (int)nb, (int)answer_size);
-       LDNS_FREE(outbuf);
+       handle_query(inbuf, nb, entries, count, transport_udp, send_udp, &userdata);
 }
 
 static void
@@ -578,6 +616,20 @@ write_n_bytes(int sock, uint8_t* buf, size_t sz)
        }
 }
 
+struct handle_tcp_userdata {
+       int s;
+};
+static void
+send_tcp(uint8_t* buf, size_t len, void* data)
+{
+       struct handle_tcp_userdata *userdata = (struct handle_tcp_userdata*)data;
+       uint16_t tcplen;
+       /* tcp send reply */
+       tcplen = htons(len);
+       write_n_bytes(userdata->s, (uint8_t*)&tcplen, sizeof(tcplen));
+       write_n_bytes(userdata->s, buf, len);
+}
+
 static void
 handle_tcp(int tcp_sock, struct entry* entries, int *count)
 {
@@ -585,9 +637,8 @@ handle_tcp(int tcp_sock, struct entry* entries, int *count)
        struct sockaddr_storage addr_him;
        socklen_t hislen;
        uint8_t inbuf[INBUF_SIZE];
-       uint8_t *outbuf;
-       size_t answer_size = 0;
        uint16_t tcplen;
+       struct handle_tcp_userdata userdata;
 
        /* accept */
        hislen = (socklen_t)sizeof(addr_him);
@@ -595,6 +646,7 @@ handle_tcp(int tcp_sock, struct entry* entries, int *count)
                log_msg("accept(): %s\n", strerror(errno));
                return;
        }
+       userdata.s = s;
 
        /* tcp recv */
        read_n_bytes(s, (uint8_t*)&tcplen, sizeof(tcplen));
@@ -607,19 +659,9 @@ handle_tcp(int tcp_sock, struct entry* entries, int *count)
        }
        read_n_bytes(s, inbuf, tcplen);
 
-       outbuf = handle_query(inbuf, tcplen, entries, count, &answer_size,
-               transport_tcp);
-       if(!outbuf) {
-               close(s);
-               return;
-       }
-
-       /* tcp send reply */
-       tcplen = htons(answer_size);
-       write_n_bytes(s, (uint8_t*)&tcplen, sizeof(tcplen));
-       write_n_bytes(s, outbuf, answer_size);
-       LDNS_FREE(outbuf);
+       handle_query(inbuf, tcplen, entries, count, transport_tcp, send_tcp, &userdata);
        close(s);
+
 }
 
 int