]> git.ipfire.org Git - thirdparty/nftables.git/commitdiff
src: add refcount asserts
authorFlorian Westphal <fw@strlen.de>
Wed, 2 Apr 2025 22:32:29 +0000 (00:32 +0200)
committerFlorian Westphal <fw@strlen.de>
Wed, 29 Oct 2025 12:44:11 +0000 (13:44 +0100)
_get() functions must not be used when refcnt is 0, as expr_free()
releases expressions on 1 -> 0 transition.

Also, check that a refcount would not overflow from UINT_MAX to 0.
Use INT_MAX to also catch refcount leaks sooner, we don't expect
2**31 get()s on same object.

This helps catching use-after-free refcounting bugs even when nft
is built without ASAN support.

v3: use a macro + BUG to get more info without a coredump.

Signed-off-by: Florian Westphal <fw@strlen.de>
include/rule.h
include/utils.h
src/expression.c
src/rule.c

index 8d2f29d09337e52aa8976004a0098ea15baa348b..bcdc50cad59db050896bbcae7b77a022b0b0f687 100644 (file)
@@ -115,7 +115,7 @@ struct symbol {
        struct list_head        list;
        const char              *identifier;
        struct expr             *expr;
-       int                     refcnt;
+       unsigned int            refcnt;
 };
 
 extern void symbol_bind(struct scope *scope, const char *identifier,
index 474c7595f7cdc2a3fcd55511176c2af047814207..16b9426455ea033ac79c29a13cce5926b0b020cd 100644 (file)
@@ -6,6 +6,7 @@
 #include <stdio.h>
 #include <unistd.h>
 #include <assert.h>
+#include <limits.h>
 #include <list.h>
 #include <gmputil.h>
 
 #define __must_be_array(a) \
        BUILD_BUG_ON_ZERO(__builtin_types_compatible_p(typeof(a), typeof(&a[0])))
 
+#define assert_refcount_safe(refcnt) do {                                              \
+       if ((refcnt) == 0)                                                              \
+               BUG("refcount was 0");                                                  \
+       if ((refcnt) >= INT_MAX)                                                        \
+               BUG("refcount saturated");                                              \
+} while (0)
+
 #define container_of(ptr, type, member) ({                     \
        typeof( ((type *)0)->member ) *__mptr = (ptr);          \
        (type *)( (void *)__mptr - offsetof(type,member) );})
index 019c263f187b834a8de7341ecc2f6ff91c242f58..6c7bebe0a3d1262a6660000eeff4d7fc52da3783 100644 (file)
@@ -68,6 +68,7 @@ struct expr *expr_clone(const struct expr *expr)
 
 struct expr *expr_get(struct expr *expr)
 {
+       assert_refcount_safe(expr->refcnt);
        expr->refcnt++;
        return expr;
 }
@@ -84,6 +85,8 @@ void expr_free(struct expr *expr)
 {
        if (expr == NULL)
                return;
+
+       assert_refcount_safe(expr->refcnt);
        if (--expr->refcnt > 0)
                return;
 
@@ -343,11 +346,13 @@ static void variable_expr_clone(struct expr *new, const struct expr *expr)
        new->scope      = expr->scope;
        new->sym        = expr->sym;
 
+       assert_refcount_safe(expr->sym->refcnt);
        expr->sym->refcnt++;
 }
 
 static void variable_expr_destroy(struct expr *expr)
 {
+       assert_refcount_safe(expr->sym->refcnt);
        expr->sym->refcnt--;
 }
 
index d0a62a3ee002d3b199c1ec263f68649f49fa2860..f51d605cc1ad968a240af0107be3900c07df45cb 100644 (file)
@@ -181,6 +181,7 @@ struct set *set_clone(const struct set *set)
 
 struct set *set_get(struct set *set)
 {
+       assert_refcount_safe(set->refcnt);
        set->refcnt++;
        return set;
 }
@@ -189,6 +190,7 @@ void set_free(struct set *set)
 {
        struct stmt *stmt, *next;
 
+       assert_refcount_safe(set->refcnt);
        if (--set->refcnt > 0)
                return;
 
@@ -484,12 +486,14 @@ struct rule *rule_alloc(const struct location *loc, const struct handle *h)
 
 struct rule *rule_get(struct rule *rule)
 {
+       assert_refcount_safe(rule->refcnt);
        rule->refcnt++;
        return rule;
 }
 
 void rule_free(struct rule *rule)
 {
+       assert_refcount_safe(rule->refcnt);
        if (--rule->refcnt > 0)
                return;
        stmt_list_free(&rule->stmts);
@@ -606,6 +610,7 @@ struct symbol *symbol_get(const struct scope *scope, const char *identifier)
        if (!sym)
                return NULL;
 
+       assert_refcount_safe(sym->refcnt);
        sym->refcnt++;
 
        return sym;
@@ -613,6 +618,7 @@ struct symbol *symbol_get(const struct scope *scope, const char *identifier)
 
 static void symbol_put(struct symbol *sym)
 {
+       assert_refcount_safe(sym->refcnt);
        if (--sym->refcnt == 0) {
                free_const(sym->identifier);
                expr_free(sym->expr);
@@ -732,6 +738,7 @@ struct chain *chain_alloc(void)
 
 struct chain *chain_get(struct chain *chain)
 {
+       assert_refcount_safe(chain->refcnt);
        chain->refcnt++;
        return chain;
 }
@@ -741,6 +748,7 @@ void chain_free(struct chain *chain)
        struct rule *rule, *next;
        int i;
 
+       assert_refcount_safe(chain->refcnt);
        if (--chain->refcnt > 0)
                return;
        list_for_each_entry_safe(rule, next, &chain->rules, list)
@@ -1176,6 +1184,7 @@ void table_free(struct table *table)
        struct set *set, *nset;
        struct obj *obj, *nobj;
 
+       assert_refcount_safe(table->refcnt);
        if (--table->refcnt > 0)
                return;
        if (table->comment)
@@ -1214,6 +1223,7 @@ void table_free(struct table *table)
 
 struct table *table_get(struct table *table)
 {
+       assert_refcount_safe(table->refcnt);
        table->refcnt++;
        return table;
 }
@@ -1687,12 +1697,14 @@ struct obj *obj_alloc(const struct location *loc)
 
 struct obj *obj_get(struct obj *obj)
 {
+       assert_refcount_safe(obj->refcnt);
        obj->refcnt++;
        return obj;
 }
 
 void obj_free(struct obj *obj)
 {
+       assert_refcount_safe(obj->refcnt);
        if (--obj->refcnt > 0)
                return;
        free_const(obj->comment);
@@ -2270,6 +2282,7 @@ struct flowtable *flowtable_alloc(const struct location *loc)
 
 struct flowtable *flowtable_get(struct flowtable *flowtable)
 {
+       assert_refcount_safe(flowtable->refcnt);
        flowtable->refcnt++;
        return flowtable;
 }
@@ -2278,6 +2291,7 @@ void flowtable_free(struct flowtable *flowtable)
 {
        int i;
 
+       assert_refcount_safe(flowtable->refcnt);
        if (--flowtable->refcnt > 0)
                return;
        handle_free(&flowtable->handle);