]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
threads: add storage api, based on flow storage
authorJason Ish <jason.ish@oisf.net>
Fri, 11 Oct 2024 17:41:47 +0000 (11:41 -0600)
committerVictor Julien <victor@inliniac.net>
Wed, 13 Nov 2024 09:53:59 +0000 (10:53 +0100)
src/Makefile.am
src/thread-storage.c [new file with mode: 0644]
src/thread-storage.h [new file with mode: 0644]
src/threads.c
src/threadvars.h
src/tm-threads.c
src/util-storage.c
src/util-storage.h

index f8e1e44d4affa63c8c5818f5523a609b0a1181a8..6032c1962dccf09aeb28eef56e9f7e50c3193505 100755 (executable)
@@ -438,6 +438,7 @@ noinst_HEADERS = \
        suricata-common.h \
        suricata.h \
        suricata-plugin.h \
+       thread-storage.h \
        threads-debug.h \
        threads.h \
        threads-profile.h \
@@ -990,6 +991,7 @@ libsuricata_c_a_SOURCES = \
        stream-tcp-sack.c \
        stream-tcp-util.c \
        suricata.c \
+       thread-storage.c \
        threads.c \
        tm-modules.c \
        tmqh-flow.c \
diff --git a/src/thread-storage.c b/src/thread-storage.c
new file mode 100644 (file)
index 0000000..977f4fd
--- /dev/null
@@ -0,0 +1,212 @@
+/* Copyright (C) 2024 Open Information Security Foundation
+ *
+ * You can copy, redistribute or modify this Program under the terms of
+ * the GNU General Public License version 2 as published by the Free
+ * Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * version 2 along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+ * 02110-1301, USA.
+ */
+
+#include "suricata-common.h"
+#include "thread-storage.h"
+#include "util-storage.h"
+#include "util-unittest.h"
+
+const StorageEnum storage_type = STORAGE_THREAD;
+
+unsigned int ThreadStorageSize(void)
+{
+    return StorageGetSize(storage_type);
+}
+
+void *ThreadGetStorageById(const ThreadVars *tv, ThreadStorageId id)
+{
+    return StorageGetById(tv->storage, storage_type, id.id);
+}
+
+int ThreadSetStorageById(ThreadVars *tv, ThreadStorageId id, void *ptr)
+{
+    return StorageSetById(tv->storage, storage_type, id.id, ptr);
+}
+
+void *ThreadAllocStorageById(ThreadVars *tv, ThreadStorageId id)
+{
+    return StorageAllocByIdPrealloc(tv->storage, storage_type, id.id);
+}
+
+void ThreadFreeStorageById(ThreadVars *tv, ThreadStorageId id)
+{
+    StorageFreeById(tv->storage, storage_type, id.id);
+}
+
+void ThreadFreeStorage(ThreadVars *tv)
+{
+    if (ThreadStorageSize() > 0)
+        StorageFreeAll(tv->storage, storage_type);
+}
+
+ThreadStorageId ThreadStorageRegister(const char *name, const unsigned int size,
+        void *(*Alloc)(unsigned int), void (*Free)(void *))
+{
+    int id = StorageRegister(storage_type, name, size, Alloc, Free);
+    ThreadStorageId tsi = { .id = id };
+    return tsi;
+}
+
+#ifdef UNITTESTS
+
+static void *StorageTestAlloc(unsigned int size)
+{
+    return SCCalloc(1, size);
+}
+
+static void StorageTestFree(void *x)
+{
+    SCFree(x);
+}
+
+static int ThreadStorageTest01(void)
+{
+    StorageInit();
+
+    ThreadStorageId id1 = ThreadStorageRegister("test", 8, StorageTestAlloc, StorageTestFree);
+    FAIL_IF(id1.id < 0);
+
+    ThreadStorageId id2 = ThreadStorageRegister("variable", 24, StorageTestAlloc, StorageTestFree);
+    FAIL_IF(id2.id < 0);
+
+    ThreadStorageId id3 =
+            ThreadStorageRegister("store", sizeof(void *), StorageTestAlloc, StorageTestFree);
+    FAIL_IF(id3.id < 0);
+
+    FAIL_IF(StorageFinalize() < 0);
+
+    ThreadVars *tv = SCCalloc(1, sizeof(ThreadVars) + ThreadStorageSize());
+    FAIL_IF_NULL(tv);
+
+    void *ptr = ThreadGetStorageById(tv, id1);
+    FAIL_IF_NOT_NULL(ptr);
+
+    ptr = ThreadGetStorageById(tv, id2);
+    FAIL_IF_NOT_NULL(ptr);
+
+    ptr = ThreadGetStorageById(tv, id3);
+    FAIL_IF_NOT_NULL(ptr);
+
+    void *ptr1a = ThreadAllocStorageById(tv, id1);
+    FAIL_IF_NULL(ptr1a);
+
+    void *ptr2a = ThreadAllocStorageById(tv, id2);
+    FAIL_IF_NULL(ptr2a);
+
+    void *ptr3a = ThreadAllocStorageById(tv, id3);
+    FAIL_IF_NULL(ptr3a);
+
+    void *ptr1b = ThreadGetStorageById(tv, id1);
+    FAIL_IF(ptr1a != ptr1b);
+
+    void *ptr2b = ThreadGetStorageById(tv, id2);
+    FAIL_IF(ptr2a != ptr2b);
+
+    void *ptr3b = ThreadGetStorageById(tv, id3);
+    FAIL_IF(ptr3a != ptr3b);
+
+    ThreadFreeStorage(tv);
+    StorageCleanup();
+    SCFree(tv);
+    PASS;
+}
+
+static int ThreadStorageTest02(void)
+{
+    StorageInit();
+
+    ThreadStorageId id1 = ThreadStorageRegister("test", sizeof(void *), NULL, StorageTestFree);
+    FAIL_IF(id1.id < 0);
+
+    FAIL_IF(StorageFinalize() < 0);
+
+    ThreadVars *tv = SCCalloc(1, sizeof(ThreadVars) + ThreadStorageSize());
+    FAIL_IF_NULL(tv);
+
+    void *ptr = ThreadGetStorageById(tv, id1);
+    FAIL_IF_NOT_NULL(ptr);
+
+    void *ptr1a = SCMalloc(128);
+    FAIL_IF_NULL(ptr1a);
+
+    ThreadSetStorageById(tv, id1, ptr1a);
+
+    void *ptr1b = ThreadGetStorageById(tv, id1);
+    FAIL_IF(ptr1a != ptr1b);
+
+    ThreadFreeStorage(tv);
+    StorageCleanup();
+    PASS;
+}
+
+static int ThreadStorageTest03(void)
+{
+    StorageInit();
+
+    ThreadStorageId id1 = ThreadStorageRegister("test1", sizeof(void *), NULL, StorageTestFree);
+    FAIL_IF(id1.id < 0);
+
+    ThreadStorageId id2 = ThreadStorageRegister("test2", sizeof(void *), NULL, StorageTestFree);
+    FAIL_IF(id2.id < 0);
+
+    ThreadStorageId id3 = ThreadStorageRegister("test3", 32, StorageTestAlloc, StorageTestFree);
+    FAIL_IF(id3.id < 0);
+
+    FAIL_IF(StorageFinalize() < 0);
+
+    ThreadVars *tv = SCCalloc(1, sizeof(ThreadVars) + ThreadStorageSize());
+    FAIL_IF_NULL(tv);
+
+    void *ptr = ThreadGetStorageById(tv, id1);
+    FAIL_IF_NOT_NULL(ptr);
+
+    void *ptr1a = SCMalloc(128);
+    FAIL_IF_NULL(ptr1a);
+
+    ThreadSetStorageById(tv, id1, ptr1a);
+
+    void *ptr2a = SCMalloc(256);
+    FAIL_IF_NULL(ptr2a);
+
+    ThreadSetStorageById(tv, id2, ptr2a);
+
+    void *ptr3a = ThreadAllocStorageById(tv, id3);
+    FAIL_IF_NULL(ptr3a);
+
+    void *ptr1b = ThreadGetStorageById(tv, id1);
+    FAIL_IF(ptr1a != ptr1b);
+
+    void *ptr2b = ThreadGetStorageById(tv, id2);
+    FAIL_IF(ptr2a != ptr2b);
+
+    void *ptr3b = ThreadGetStorageById(tv, id3);
+    FAIL_IF(ptr3a != ptr3b);
+
+    ThreadFreeStorage(tv);
+    StorageCleanup();
+    PASS;
+}
+#endif
+
+void RegisterThreadStorageTests(void)
+{
+#ifdef UNITTESTS
+    UtRegisterTest("ThreadStorageTest01", ThreadStorageTest01);
+    UtRegisterTest("ThreadStorageTest02", ThreadStorageTest02);
+    UtRegisterTest("ThreadStorageTest03", ThreadStorageTest03);
+#endif
+}
diff --git a/src/thread-storage.h b/src/thread-storage.h
new file mode 100644 (file)
index 0000000..5dd2257
--- /dev/null
@@ -0,0 +1,45 @@
+/* Copyright (C) 2024 Open Information Security Foundation
+ *
+ * You can copy, redistribute or modify this Program under the terms of
+ * the GNU General Public License version 2 as published by the Free
+ * Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * version 2 along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+ * 02110-1301, USA.
+ */
+
+/**
+ * Thread wrapper around storage API.
+ */
+
+#ifndef SURICATA_THREAD_STORAGE_H
+#define SURICATA_THREAD_STORAGE_H
+
+#include "threadvars.h"
+
+typedef struct ThreadStorageId {
+    int id;
+} ThreadStorageId;
+
+unsigned int ThreadStorageSize(void);
+
+void *ThreadGetStorageById(const ThreadVars *tv, ThreadStorageId id);
+int ThreadSetStorageById(ThreadVars *tv, ThreadStorageId id, void *ptr);
+void *ThreadAllocStorageById(ThreadVars *tv, ThreadStorageId id);
+
+void ThreadFreeStorageById(ThreadVars *tv, ThreadStorageId id);
+void ThreadFreeStorage(ThreadVars *tv);
+
+void RegisterThreadStorageTests(void);
+
+ThreadStorageId ThreadStorageRegister(const char *name, const unsigned int size,
+        void *(*Alloc)(unsigned int), void (*Free)(void *));
+
+#endif /* SURICATA_THREAD_STORAGE_H */
index 1708a8f5cd3724e7e15dfc1675e20acde055363d..919e6422e32f015cce0ba215547b08d8b68f0e1b 100644 (file)
@@ -25,6 +25,7 @@
  */
 
 #include "suricata-common.h"
+#include "thread-storage.h"
 #include "util-unittest.h"
 #include "util-debug.h"
 #include "threads.h"
@@ -149,5 +150,6 @@ void ThreadMacrosRegisterTests(void)
     UtRegisterTest("ThreadMacrosTest03RWLocks", ThreadMacrosTest03RWLocks);
     UtRegisterTest("ThreadMacrosTest04RWLocks", ThreadMacrosTest04RWLocks);
 //    UtRegisterTest("ThreadMacrosTest05RWLocks", ThreadMacrosTest05RWLocks);
+    RegisterThreadStorageTests();
 #endif /* UNIT TESTS */
 }
index cebcdb4e3ac1c2b77a835e52de3e2280b6246f10..6f339e9839d5f652ce17d7cb7dcd00631f03dbbf 100644 (file)
@@ -28,6 +28,7 @@
 #include "counters.h"
 #include "packet-queue.h"
 #include "util-atomic.h"
+#include "util-storage.h"
 
 struct TmSlot_;
 
@@ -135,6 +136,7 @@ typedef struct ThreadVars_ {
     struct FlowQueue_ *flow_queue;
     bool break_loop;
 
+    Storage storage[];
 } ThreadVars;
 
 /** Thread setup flags: */
index b0d0f8686ba005ae4c05dfc4b1073ceb3b79052f..c65995ad351b765de3cff0ba19cccea137641847 100644 (file)
@@ -30,6 +30,7 @@
 #include "stream.h"
 #include "runmodes.h"
 #include "threadvars.h"
+#include "thread-storage.h"
 #include "tm-queues.h"
 #include "tm-queuehandlers.h"
 #include "tm-threads.h"
@@ -919,7 +920,7 @@ ThreadVars *TmThreadCreate(const char *name, const char *inq_name, const char *i
     SCLogDebug("creating thread \"%s\"...", name);
 
     /* XXX create separate function for this: allocate a thread container */
-    tv = SCCalloc(1, sizeof(ThreadVars));
+    tv = SCCalloc(1, sizeof(ThreadVars) + ThreadStorageSize());
     if (unlikely(tv == NULL))
         goto error;
 
@@ -1577,6 +1578,8 @@ static void TmThreadFree(ThreadVars *tv)
 
     SCLogDebug("Freeing thread '%s'.", tv->name);
 
+    ThreadFreeStorage(tv);
+
     if (tv->flow_queue) {
         BUG_ON(tv->flow_queue->qlen != 0);
         SCFree(tv->flow_queue);
index 02f69a568cd2eb0500b58c58c59bf861a25ce799..bae2514323153769de43daec279aa16819c38c6b 100644 (file)
@@ -59,6 +59,8 @@ static const char *StoragePrintType(StorageEnum type)
             return "ippair";
         case STORAGE_DEVICE:
             return "livedevice";
+        case STORAGE_THREAD:
+            return "thread";
         case STORAGE_MAX:
             return "max";
     }
index 11d64bdbecbd454445a9d36676c3f3383080140e..fce1f964eb14bcb2e73e0c0fdb1c449f2b32b768 100644 (file)
@@ -31,6 +31,7 @@ typedef enum StorageEnum_ {
     STORAGE_FLOW,
     STORAGE_IPPAIR,
     STORAGE_DEVICE,
+    STORAGE_THREAD,
 
     STORAGE_MAX,
 } StorageEnum;