]> git.ipfire.org Git - thirdparty/bind9.git/commitdiff
Address dns_zt_asyncload races by properly using isc_reference_*.
authorMark Andrews <marka@isc.org>
Fri, 13 Dec 2019 02:58:47 +0000 (13:58 +1100)
committerOndřej Surý <ondrej@isc.org>
Mon, 13 Jan 2020 10:33:31 +0000 (11:33 +0100)
lib/dns/zt.c

index 69e1f10e6fb5b710714cc4e8aa0228f7bbcf63fd..5b410b1831636c854ebe65c122230c55fafab8ec 100644 (file)
@@ -285,36 +285,70 @@ load(dns_zone_t *zone, void *paramsv) {
        return (result);
 }
 
+static void
+call_loaddone(dns_zt_t *zt) {
+       dns_zt_allloaded_t loaddone = zt->loaddone;
+       void *loaddone_arg = zt->loaddone_arg;
+
+       /*
+        * Set zt->loaddone, zt->loaddone_arg and zt->loadparams to NULL
+        * before calling loaddone.
+        */
+       zt->loaddone = NULL;
+       zt->loaddone_arg = NULL;
+
+       isc_mem_put(zt->mctx, zt->loadparams, sizeof(struct zt_load_params));
+       zt->loadparams = NULL;
+
+       /*
+        * Call the callback last.
+        */
+       if (loaddone != NULL) {
+               loaddone(loaddone_arg);
+       }
+}
+
 isc_result_t
 dns_zt_asyncload(dns_zt_t *zt, bool newonly,
-                dns_zt_allloaded_t alldone, void *arg) {
+                dns_zt_allloaded_t alldone, void *arg)
+{
        isc_result_t result;
-       int pending;
+       uint_fast32_t loads_pending;
 
        REQUIRE(VALID_ZT(zt));
+
+       /*
+        * Obtain a reference to zt->loads_pending so that asyncload can
+        * safely decrement both zt->references and zt->loads_pending
+        * without going to zero.
+        */
+       loads_pending = isc_refcount_increment0(&zt->loads_pending);
+       INSIST(loads_pending == 0);
+
+       /*
+        * Only one dns_zt_asyncload call at a time should be active so
+        * these pointers should be NULL.  They are set back to NULL
+        * before the zt->loaddone (alldone) is called in call_loaddone.
+        */
+       INSIST(zt->loadparams == NULL);
+       INSIST(zt->loaddone == NULL);
+       INSIST(zt->loaddone_arg == NULL);
+
        zt->loadparams = isc_mem_get(zt->mctx, sizeof(struct zt_load_params));
        zt->loadparams->dl = doneloading;
        zt->loadparams->newonly = newonly;
+       zt->loaddone = alldone;
+       zt->loaddone_arg = arg;
 
-       RWLOCK(&zt->rwlock, isc_rwlocktype_write);
-
-       INSIST(isc_refcount_current(&zt->loads_pending) == 0);
-
+       RWLOCK(&zt->rwlock, isc_rwlocktype_read);
        result = dns_zt_apply(zt, false, NULL, asyncload, zt);
+       RWUNLOCK(&zt->rwlock, isc_rwlocktype_read);
 
-       pending = isc_refcount_current(&zt->loads_pending);
-
-       if (pending != 0) {
-               zt->loaddone = alldone;
-               zt->loaddone_arg = arg;
-       }
-
-       RWUNLOCK(&zt->rwlock, isc_rwlocktype_write);
-
-       if (pending == 0) {
-               isc_mem_put(zt->mctx, zt->loadparams, sizeof(struct zt_load_params));
-               zt->loadparams = NULL;
-               alldone(arg);
+       /*
+        * Have all the loads completed?
+        */
+       if (isc_refcount_decrement(&zt->loads_pending) == 1) {
+               call_loaddone(zt);
        }
 
        return (result);
@@ -332,14 +366,20 @@ asyncload(dns_zone_t *zone, void *zt_) {
        REQUIRE(zone != NULL);
 
        isc_refcount_increment(&zt->references);
-
        isc_refcount_increment(&zt->loads_pending);
 
-       result = dns_zone_asyncload(zone, zt->loadparams->newonly, *zt->loadparams->dl, zt);
+       result = dns_zone_asyncload(zone, zt->loadparams->newonly,
+                                   *zt->loadparams->dl, zt);
        if (result != ISC_R_SUCCESS) {
-
-               isc_refcount_decrement(&zt->references);
-               isc_refcount_decrement(&zt->loads_pending);
+               uint_fast32_t oldref;
+               /*
+                * Caller is holding a reference to zt->loads_pending
+                * and zt->references so these can't decrement to zero.
+                */
+               oldref = isc_refcount_decrement(&zt->loads_pending);
+               INSIST(oldref > 1);
+               oldref = isc_refcount_decrement(&zt->references);
+               INSIST(oldref > 1);
        }
        return (ISC_R_SUCCESS);
 }
@@ -528,8 +568,6 @@ dns_zt_apply(dns_zt_t *zt, bool stop, isc_result_t *sub,
  */
 static isc_result_t
 doneloading(dns_zt_t *zt, dns_zone_t *zone, isc_task_t *task) {
-       dns_zt_allloaded_t alldone = NULL;
-       void *arg = NULL;
 
        UNUSED(zone);
        UNUSED(task);
@@ -537,15 +575,7 @@ doneloading(dns_zt_t *zt, dns_zone_t *zone, isc_task_t *task) {
        REQUIRE(VALID_ZT(zt));
 
        if (isc_refcount_decrement(&zt->loads_pending) == 1) {
-               alldone = zt->loaddone;
-               arg = zt->loaddone_arg;
-               zt->loaddone = NULL;
-               zt->loaddone_arg = NULL;
-               isc_mem_put(zt->mctx, zt->loadparams, sizeof(struct zt_load_params));
-               zt->loadparams = NULL;
-               if (alldone != NULL) {
-                       alldone(arg);
-               }
+               call_loaddone(zt);
        }
 
        if (isc_refcount_decrement(&zt->references) == 1) {