]> git.ipfire.org Git - thirdparty/ipxe.git/commitdiff
[refcnt] Check reference validity on each use of ref_get() and ref_put()
authorMichael Brown <mcb30@ipxe.org>
Fri, 5 Nov 2010 22:24:46 +0000 (22:24 +0000)
committerMichael Brown <mcb30@ipxe.org>
Mon, 8 Nov 2010 03:35:35 +0000 (03:35 +0000)
Check that the reference count is valid (i.e. non-negative) on each
call to ref_get() and ref_put(), using an assert() at the point of
use.

Signed-off-by: Michael Brown <mcb30@ipxe.org>
src/core/refcnt.c
src/include/ipxe/refcnt.h

index e540240fff15c16d4356cd84bc2425c08dd060d0..6117d741d6deb7817fdc175400ba57946e920520 100644 (file)
@@ -31,18 +31,16 @@ FILE_LICENCE ( GPL2_OR_LATER );
  * Increment reference count
  *
  * @v refcnt           Reference counter, or NULL
- * @ret refcnt         Reference counter
  *
  * If @c refcnt is NULL, no action is taken.
  */
-struct refcnt * ref_get ( struct refcnt *refcnt ) {
+void ref_increment ( struct refcnt *refcnt ) {
 
        if ( refcnt ) {
-               refcnt->refcnt++;
+               refcnt->count++;
                DBGC2 ( refcnt, "REFCNT %p incremented to %d\n",
-                       refcnt, refcnt->refcnt );
+                       refcnt, refcnt->count );
        }
-       return refcnt;
 }
 
 /**
@@ -55,18 +53,28 @@ struct refcnt * ref_get ( struct refcnt *refcnt ) {
  *
  * If @c refcnt is NULL, no action is taken.
  */
-void ref_put ( struct refcnt *refcnt ) {
+void ref_decrement ( struct refcnt *refcnt ) {
 
        if ( ! refcnt )
                return;
 
-       refcnt->refcnt--;
+       refcnt->count--;
        DBGC2 ( refcnt, "REFCNT %p decremented to %d\n",
-               refcnt, refcnt->refcnt );
+               refcnt, refcnt->count );
 
-       if ( refcnt->refcnt >= 0 )
+       if ( refcnt->count >= 0 )
                return;
 
+       if ( refcnt->count < -1 ) {
+               DBGC ( refcnt, "REFCNT %p decremented too far (%d)!\n",
+                      refcnt, refcnt->count );
+               /* Avoid multiple calls to free(), which typically
+                * result in memory corruption that is very hard to
+                * track down.
+                */
+               return;
+       }
+
        if ( refcnt->free ) {
                DBGC ( refcnt, "REFCNT %p being freed via method %p\n",
                       refcnt, refcnt->free );
index 49fce5044c268b0c2521455081c85fa303dd1f3e..0e8b8658c5074afcad18fcd30c242956677ac15a 100644 (file)
@@ -9,6 +9,9 @@
 
 FILE_LICENCE ( GPL2_OR_LATER );
 
+#include <stddef.h>
+#include <assert.h>
+
 /**
  * A reference counter
  *
@@ -26,7 +29,7 @@ struct refcnt {
         * When this count is decremented below zero, the free()
         * method will be called.
         */
-       int refcnt;
+       int count;
        /** Free containing object
         *
         * This method is called when the reference count is
@@ -75,8 +78,37 @@ ref_init ( struct refcnt *refcnt,
                .free = free_fn,                                        \
        }
 
-extern struct refcnt * ref_get ( struct refcnt *refcnt );
-extern void ref_put ( struct refcnt *refcnt );
+extern void ref_increment ( struct refcnt *refcnt );
+extern void ref_decrement ( struct refcnt *refcnt );
+
+/**
+ * Get additional reference to object
+ *
+ * @v refcnt           Reference counter, or NULL
+ * @ret refcnt         Reference counter
+ *
+ * If @c refcnt is NULL, no action is taken.
+ */
+#define ref_get( refcnt ) ( {                                          \
+       if ( refcnt )                                                   \
+               assert ( (refcnt)->count >= 0 );                        \
+       ref_increment ( refcnt );                                       \
+       (refcnt); } )
+
+/**
+ * Drop reference to object
+ *
+ * @v refcnt           Reference counter, or NULL
+ * @ret refcnt         Reference counter
+ *
+ * If @c refcnt is NULL, no action is taken.
+ */
+#define ref_put( refcnt ) do {                                         \
+       if ( refcnt )                                                   \
+               assert ( (refcnt)->count >= 0 );                        \
+       ref_decrement ( refcnt );                                       \
+       } while ( 0 )
+
 extern void ref_no_free ( struct refcnt *refcnt );
 
 #endif /* _IPXE_REFCNT_H */