]> git.ipfire.org Git - thirdparty/linux.git/commitdiff
liveupdate: Use refcount_t for FLB reference counts
authorDavid Matlack <dmatlack@google.com>
Thu, 23 Apr 2026 17:40:28 +0000 (17:40 +0000)
committerMike Rapoport (Microsoft) <rppt@kernel.org>
Sun, 31 May 2026 23:31:38 +0000 (02:31 +0300)
Use refcount_t instead of a raw integer to keep track of references on
incoming and outgoing FLBs. Using refcount_t provides protection from
overflow, underflow, and other issues.

Fixes: cab056f2aae7 ("liveupdate: luo_flb: introduce File-Lifecycle-Bound global state")
Signed-off-by: David Matlack <dmatlack@google.com>
Reviewed-by: Samiullah Khawaja <skhawaja@google.com>
Reviewed-by: Pasha Tatashin <pasha.tatashin@soleen.com>
Link: https://lore.kernel.org/r/20260423174032.3140399-2-dmatlack@google.com
Signed-off-by: Pasha Tatashin <pasha.tatashin@soleen.com>
Signed-off-by: Mike Rapoport (Microsoft) <rppt@kernel.org>
include/linux/liveupdate.h
kernel/liveupdate/luo_flb.c

index 30c5a39ff9e9c29936cecf664e4f985b24c6eb21..8d3bbc35c828b822461cb9246344814ac8830c3c 100644 (file)
@@ -12,6 +12,7 @@
 #include <linux/kho/abi/luo.h>
 #include <linux/list.h>
 #include <linux/mutex.h>
+#include <linux/refcount.h>
 #include <linux/rwsem.h>
 #include <linux/types.h>
 #include <uapi/linux/liveupdate.h>
@@ -175,7 +176,7 @@ struct liveupdate_flb_ops {
  * @retrieved: True once the FLB's retrieve() callback has run.
  */
 struct luo_flb_private_state {
-       long count;
+       refcount_t count;
        u64 data;
        void *obj;
        struct mutex lock;
index 00f5494812c4ab79156b207e8323fb982df74495..59c5f31ab767408cabde721f9470499c4eeb939e 100644 (file)
@@ -111,7 +111,7 @@ static int luo_flb_file_preserve_one(struct liveupdate_flb *flb)
        struct luo_flb_private *private = luo_flb_get_private(flb);
 
        scoped_guard(mutex, &private->outgoing.lock) {
-               if (!private->outgoing.count) {
+               if (!refcount_read(&private->outgoing.count)) {
                        struct liveupdate_flb_op_args args = {0};
                        int err;
 
@@ -126,8 +126,10 @@ static int luo_flb_file_preserve_one(struct liveupdate_flb *flb)
                        }
                        private->outgoing.data = args.data;
                        private->outgoing.obj = args.obj;
+                       refcount_set(&private->outgoing.count, 1);
+               } else {
+                       refcount_inc(&private->outgoing.count);
                }
-               private->outgoing.count++;
        }
 
        return 0;
@@ -138,8 +140,7 @@ static void luo_flb_file_unpreserve_one(struct liveupdate_flb *flb)
        struct luo_flb_private *private = luo_flb_get_private(flb);
 
        scoped_guard(mutex, &private->outgoing.lock) {
-               private->outgoing.count--;
-               if (!private->outgoing.count) {
+               if (refcount_dec_and_test(&private->outgoing.count)) {
                        struct liveupdate_flb_op_args args = {0};
 
                        args.flb = flb;
@@ -178,7 +179,7 @@ static int luo_flb_retrieve_one(struct liveupdate_flb *flb)
        for (int i = 0; i < fh->header_ser->count; i++) {
                if (!strcmp(fh->ser[i].name, flb->compatible)) {
                        private->incoming.data = fh->ser[i].data;
-                       private->incoming.count = fh->ser[i].count;
+                       refcount_set(&private->incoming.count, fh->ser[i].count);
                        found = true;
                        break;
                }
@@ -208,12 +209,8 @@ static int luo_flb_retrieve_one(struct liveupdate_flb *flb)
 static void luo_flb_file_finish_one(struct liveupdate_flb *flb)
 {
        struct luo_flb_private *private = luo_flb_get_private(flb);
-       u64 count;
 
-       scoped_guard(mutex, &private->incoming.lock)
-               count = --private->incoming.count;
-
-       if (!count) {
+       if (refcount_dec_and_test(&private->incoming.count)) {
                struct liveupdate_flb_op_args args = {0};
 
                if (!private->incoming.retrieved) {
@@ -652,12 +649,13 @@ void luo_flb_serialize(void)
        guard(rwsem_read)(&luo_register_rwlock);
        list_private_for_each_entry(gflb, &luo_flb_global.list, private.list) {
                struct luo_flb_private *private = luo_flb_get_private(gflb);
+               long count = refcount_read(&private->outgoing.count);
 
-               if (private->outgoing.count > 0) {
+               if (count > 0) {
                        strscpy(fh->ser[i].name, gflb->compatible,
                                sizeof(fh->ser[i].name));
                        fh->ser[i].data = private->outgoing.data;
-                       fh->ser[i].count = private->outgoing.count;
+                       fh->ser[i].count = count;
                        i++;
                }
        }