]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Provide thread safe access to dns_xfrin_t state
authorMark Andrews <marka@isc.org>
Thu, 6 Jul 2023 04:00:48 +0000 (14:00 +1000)
committerAram Sargsyan <aram@isc.org>
Fri, 22 Sep 2023 08:52:13 +0000 (08:52 +0000)
dns_xfrin_t state may be accessed from different threads when
when reporting transfer state.  Ensure access is thread safe by
using atomics and locks where appropriate.

lib/dns/include/dns/xfrin.h
lib/dns/xfrin.c

index f140c534ac35462becc4f8f3cddc3c6c155386c9..3d10bb42fe1fd879d80bd8b44cea85384bb1b311 100644 (file)
@@ -125,8 +125,8 @@ dns_xfrin_getendserial(const dns_xfrin_t *xfr);
  */
 
 void
-dns_xfrin_getstats(const dns_xfrin_t *xfr, unsigned int *nmsgp,
-                  unsigned int *nrecsp, uint64_t *nbytesp);
+dns_xfrin_getstats(dns_xfrin_t *xfr, unsigned int *nmsgp, unsigned int *nrecsp,
+                  uint64_t *nbytesp);
 /*%<
  * Get various statistics values of the xfrin object: number of the received
  * messages, number of the received records, number of the received bytes.
index 23de072e0edaf9cdb392d77ef5842c0b20a69dd6..c830533a27dae08eb5d21ff01c8e3e2a885d79ce 100644 (file)
@@ -16,6 +16,7 @@
 #include <inttypes.h>
 #include <stdbool.h>
 
+#include <isc/atomic.h>
 #include <isc/mem.h>
 #include <isc/random.h>
 #include <isc/result.h>
@@ -140,11 +141,14 @@ struct dns_xfrin {
        dns_diff_t diff; /*%< Pending database changes */
        int difflen;     /*%< Number of pending tuples */
 
-       xfrin_state_t state;
+       _Atomic xfrin_state_t state;
        uint32_t end_serial;
        uint32_t expireopt;
-       bool edns, is_ixfr, expireoptset;
+       bool edns, expireoptset;
+       atomic_bool is_ixfr;
 
+       isc_mutex_t statslock;
+       /* Locked by statslock.  */
        unsigned int nmsg;  /*%< Number of messages recvd */
        unsigned int nrecs; /*%< Number of records recvd */
        uint64_t nbytes;    /*%< Number of bytes received */
@@ -269,7 +273,7 @@ static isc_result_t
 axfr_init(dns_xfrin_t *xfr) {
        isc_result_t result;
 
-       xfr->is_ixfr = false;
+       atomic_store(&xfr->is_ixfr, false);
 
        if (xfr->db != NULL) {
                dns_db_detach(&xfr->db);
@@ -385,7 +389,7 @@ ixfr_init(dns_xfrin_t *xfr) {
                return (DNS_R_FORMERR);
        }
 
-       xfr->is_ixfr = true;
+       atomic_store(&xfr->is_ixfr, true);
        INSIST(xfr->db != NULL);
        xfr->difflen = 0;
 
@@ -491,7 +495,9 @@ static isc_result_t
 xfr_rr(dns_xfrin_t *xfr, dns_name_t *name, uint32_t ttl, dns_rdata_t *rdata) {
        isc_result_t result;
 
+       LOCK(&xfr->statslock);
        xfr->nrecs++;
+       UNLOCK(&xfr->statslock);
 
        if (rdata->type == dns_rdatatype_none ||
            dns_rdatatype_ismeta(rdata->type))
@@ -519,7 +525,7 @@ xfr_rr(dns_xfrin_t *xfr, dns_name_t *name, uint32_t ttl, dns_rdata_t *rdata) {
        }
 
 redo:
-       switch (xfr->state) {
+       switch (atomic_load(&xfr->state)) {
        case XFRST_SOAQUERY:
                if (rdata->type != dns_rdatatype_soa) {
                        xfrin_log(xfr, ISC_LOG_NOTICE,
@@ -536,7 +542,7 @@ redo:
                                  xfr->ixfr.request_serial, xfr->end_serial);
                        FAIL(DNS_R_UPTODATE);
                }
-               xfr->state = XFRST_GOTSOA;
+               atomic_store(&xfr->state, XFRST_GOTSOA);
                break;
 
        case XFRST_GOTSOA:
@@ -578,7 +584,7 @@ redo:
                xfr->firstsoa_data = isc_mem_allocate(xfr->mctx, rdata->length);
                memcpy(xfr->firstsoa_data, rdata->data, rdata->length);
                xfr->firstsoa.data = xfr->firstsoa_data;
-               xfr->state = XFRST_FIRSTDATA;
+               atomic_store(&xfr->state, XFRST_FIRSTDATA);
                break;
 
        case XFRST_FIRSTDATA:
@@ -593,25 +599,25 @@ redo:
                        xfrin_log(xfr, ISC_LOG_DEBUG(3),
                                  "got incremental response");
                        CHECK(ixfr_init(xfr));
-                       xfr->state = XFRST_IXFR_DELSOA;
+                       atomic_store(&xfr->state, XFRST_IXFR_DELSOA);
                } else {
                        xfrin_log(xfr, ISC_LOG_DEBUG(3),
                                  "got nonincremental response");
                        CHECK(axfr_init(xfr));
-                       xfr->state = XFRST_AXFR;
+                       atomic_store(&xfr->state, XFRST_AXFR);
                }
                goto redo;
 
        case XFRST_IXFR_DELSOA:
                INSIST(rdata->type == dns_rdatatype_soa);
                CHECK(ixfr_putdata(xfr, DNS_DIFFOP_DEL, name, ttl, rdata));
-               xfr->state = XFRST_IXFR_DEL;
+               atomic_store(&xfr->state, XFRST_IXFR_DEL);
                break;
 
        case XFRST_IXFR_DEL:
                if (rdata->type == dns_rdatatype_soa) {
                        uint32_t soa_serial = dns_soa_getserial(rdata);
-                       xfr->state = XFRST_IXFR_ADDSOA;
+                       atomic_store(&xfr->state, XFRST_IXFR_ADDSOA);
                        xfr->ixfr.current_serial = soa_serial;
                        goto redo;
                }
@@ -621,7 +627,7 @@ redo:
        case XFRST_IXFR_ADDSOA:
                INSIST(rdata->type == dns_rdatatype_soa);
                CHECK(ixfr_putdata(xfr, DNS_DIFFOP_ADD, name, ttl, rdata));
-               xfr->state = XFRST_IXFR_ADD;
+               atomic_store(&xfr->state, XFRST_IXFR_ADD);
                break;
 
        case XFRST_IXFR_ADD:
@@ -629,7 +635,7 @@ redo:
                        uint32_t soa_serial = dns_soa_getserial(rdata);
                        if (soa_serial == xfr->end_serial) {
                                CHECK(ixfr_commit(xfr));
-                               xfr->state = XFRST_IXFR_END;
+                               atomic_store(&xfr->state, XFRST_IXFR_END);
                                break;
                        } else if (soa_serial != xfr->ixfr.current_serial) {
                                xfrin_log(xfr, ISC_LOG_NOTICE,
@@ -639,7 +645,7 @@ redo:
                                FAIL(DNS_R_FORMERR);
                        } else {
                                CHECK(ixfr_commit(xfr));
-                               xfr->state = XFRST_IXFR_DELSOA;
+                               atomic_store(&xfr->state, XFRST_IXFR_DELSOA);
                                goto redo;
                        }
                }
@@ -674,7 +680,7 @@ redo:
                                FAIL(DNS_R_FORMERR);
                        }
                        CHECK(axfr_commit(xfr));
-                       xfr->state = XFRST_AXFR_END;
+                       atomic_store(&xfr->state, XFRST_AXFR_END);
                        break;
                }
                break;
@@ -778,10 +784,10 @@ dns_xfrin_getstate(const dns_xfrin_t *xfr, const char **statestr,
        REQUIRE(statestr != NULL && *statestr == NULL);
        REQUIRE(is_ixfr != NULL);
 
-       state = xfr->state;
+       state = atomic_load(&xfr->state);
        *statestr = "";
        *is_first_data_received = (state > XFRST_FIRSTDATA);
-       *is_ixfr = xfr->is_ixfr;
+       *is_ixfr = atomic_load(&xfr->is_ixfr);
 
        switch (state) {
        case XFRST_SOAQUERY:
@@ -822,14 +828,16 @@ dns_xfrin_getendserial(const dns_xfrin_t *xfr) {
 }
 
 void
-dns_xfrin_getstats(const dns_xfrin_t *xfr, unsigned int *nmsgp,
-                  unsigned int *nrecsp, uint64_t *nbytesp) {
+dns_xfrin_getstats(dns_xfrin_t *xfr, unsigned int *nmsgp, unsigned int *nrecsp,
+                  uint64_t *nbytesp) {
        REQUIRE(VALID_XFRIN(xfr));
        REQUIRE(nmsgp != NULL && nrecsp != NULL && nbytesp != NULL);
 
+       LOCK(&xfr->statslock);
        *nmsgp = xfr->nmsg;
        *nrecsp = xfr->nrecs;
        *nbytesp = xfr->nbytes;
+       UNLOCK(&xfr->statslock);
 }
 
 const isc_sockaddr_t *
@@ -925,7 +933,7 @@ xfrin_fail(dns_xfrin_t *xfr, isc_result_t result, const char *msg) {
                {
                        xfrin_log(xfr, ISC_LOG_ERROR, "%s: %s", msg,
                                  isc_result_totext(result));
-                       if (xfr->is_ixfr) {
+                       if (atomic_load(&xfr->is_ixfr)) {
                                /*
                                 * Pass special result code to force AXFR retry
                                 */
@@ -979,7 +987,10 @@ xfrin_create(isc_mem_t *mctx, dns_zone_t *zone, dns_db_t *db,
        dns_view_weakattach(dns_zone_getview(zone), &xfr->view);
        dns_name_init(&xfr->name, NULL);
 
+       isc_mutex_init(&xfr->statslock);
+
        atomic_init(&xfr->shuttingdown, false);
+       atomic_init(&xfr->is_ixfr, false);
 
        if (db != NULL) {
                dns_db_attach(db, &xfr->db);
@@ -988,9 +999,9 @@ xfrin_create(isc_mem_t *mctx, dns_zone_t *zone, dns_db_t *db,
        dns_diff_init(xfr->mctx, &xfr->diff);
 
        if (reqtype == dns_rdatatype_soa) {
-               xfr->state = XFRST_SOAQUERY;
+               atomic_init(&xfr->state, XFRST_SOAQUERY);
        } else {
-               xfr->state = XFRST_INITIALSOA;
+               atomic_init(&xfr->state, XFRST_INITIALSOA);
        }
 
        xfr->start = isc_time_now();
@@ -1363,9 +1374,11 @@ xfrin_send_request(dns_xfrin_t *xfr) {
                CHECK(add_opt(msg, udpsize, reqnsid, reqexpire));
        }
 
+       LOCK(&xfr->statslock);
        xfr->nmsg = 0;
        xfr->nrecs = 0;
        xfr->nbytes = 0;
+       UNLOCK(&xfr->statslock);
        xfr->start = isc_time_now();
        msg->id = xfr->id;
        if (xfr->tsigctx != NULL) {
@@ -1497,9 +1510,11 @@ xfrin_recv_done(isc_result_t result, isc_region_t *region, void *arg) {
 
        dns_message_setclass(msg, xfr->rdclass);
 
+       LOCK(&xfr->statslock);
        if (xfr->nmsg > 0) {
                msg->tcp_continuation = 1;
        }
+       UNLOCK(&xfr->statslock);
 
        isc_buffer_init(&buffer, region->base, region->length);
        isc_buffer_add(&buffer, region->length);
@@ -1523,8 +1538,8 @@ xfrin_recv_done(isc_result_t result, isc_region_t *region, void *arg) {
        {
                if (result == ISC_R_SUCCESS &&
                    msg->rcode == dns_rcode_formerr && xfr->edns &&
-                   (xfr->state == XFRST_SOAQUERY ||
-                    xfr->state == XFRST_INITIALSOA))
+                   (atomic_load(&xfr->state) == XFRST_SOAQUERY ||
+                    atomic_load(&xfr->state) == XFRST_INITIALSOA))
                {
                        xfr->edns = false;
                        dns_message_detach(&msg);
@@ -1559,7 +1574,7 @@ xfrin_recv_done(isc_result_t result, isc_region_t *region, void *arg) {
                dns_message_detach(&msg);
                xfrin_reset(xfr);
                xfr->reqtype = dns_rdatatype_soa;
-               xfr->state = XFRST_SOAQUERY;
+               atomic_store(&xfr->state, XFRST_SOAQUERY);
        try_again:
                result = xfrin_start(xfr);
                if (result != ISC_R_SUCCESS) {
@@ -1583,7 +1598,8 @@ xfrin_recv_done(isc_result_t result, isc_region_t *region, void *arg) {
                goto failure;
        }
 
-       if ((xfr->state == XFRST_SOAQUERY || xfr->state == XFRST_INITIALSOA) &&
+       if ((atomic_load(&xfr->state) == XFRST_SOAQUERY ||
+            atomic_load(&xfr->state) == XFRST_INITIALSOA) &&
            msg->counts[DNS_SECTION_QUESTION] != 1)
        {
                xfrin_log(xfr, ISC_LOG_NOTICE, "missing question section");
@@ -1633,7 +1649,7 @@ xfrin_recv_done(isc_result_t result, isc_region_t *region, void *arg) {
         * if the first RR in the answer section is not a SOA record.
         */
        if (xfr->reqtype == dns_rdatatype_ixfr &&
-           xfr->state == XFRST_INITIALSOA &&
+           atomic_load(&xfr->state) == XFRST_INITIALSOA &&
            msg->counts[DNS_SECTION_ANSWER] == 0)
        {
                xfrin_log(xfr, ISC_LOG_DEBUG(3),
@@ -1700,24 +1716,29 @@ xfrin_recv_done(isc_result_t result, isc_region_t *region, void *arg) {
                CHECK(dns_message_getquerytsig(msg, xfr->mctx, &xfr->lasttsig));
        } else if (dns_message_gettsigkey(msg) != NULL) {
                xfr->sincetsig++;
+               LOCK(&xfr->statslock);
                if (xfr->sincetsig > 100 || xfr->nmsg == 0 ||
-                   xfr->state == XFRST_AXFR_END ||
-                   xfr->state == XFRST_IXFR_END)
+                   atomic_load(&xfr->state) == XFRST_AXFR_END ||
+                   atomic_load(&xfr->state) == XFRST_IXFR_END)
                {
+                       UNLOCK(&xfr->statslock);
                        result = DNS_R_EXPECTEDTSIG;
                        goto failure;
                }
+               UNLOCK(&xfr->statslock);
        }
 
        /*
         * Update the number of messages received.
         */
+       LOCK(&xfr->statslock);
        xfr->nmsg++;
 
        /*
         * Update the number of bytes received.
         */
        xfr->nbytes += buffer.used;
+       UNLOCK(&xfr->statslock);
 
        /*
         * Take the context back.
@@ -1730,10 +1751,10 @@ xfrin_recv_done(isc_result_t result, isc_region_t *region, void *arg) {
                get_edns_expire(xfr, msg);
        }
 
-       switch (xfr->state) {
+       switch (atomic_load(&xfr->state)) {
        case XFRST_GOTSOA:
                xfr->reqtype = dns_rdatatype_axfr;
-               xfr->state = XFRST_INITIALSOA;
+               atomic_store(&xfr->state, XFRST_INITIALSOA);
                CHECK(xfrin_send_request(xfr));
                break;
        case XFRST_AXFR_END:
@@ -1826,6 +1847,7 @@ xfrin_destroy(dns_xfrin_t *xfr) {
        if (msecs == 0) {
                msecs = 1;
        }
+       LOCK(&xfr->statslock);
        persec = (xfr->nbytes * 1000) / msecs;
        xfrin_log(xfr, ISC_LOG_INFO,
                  "Transfer completed: %d messages, %d records, "
@@ -1834,6 +1856,8 @@ xfrin_destroy(dns_xfrin_t *xfr) {
                  xfr->nmsg, xfr->nrecs, xfr->nbytes,
                  (unsigned int)(msecs / 1000), (unsigned int)(msecs % 1000),
                  (unsigned int)persec, xfr->end_serial);
+       UNLOCK(&xfr->statslock);
+       isc_mutex_destroy(&xfr->statslock);
 
        if (xfr->dispentry != NULL) {
                dns_dispatch_done(&xfr->dispentry);