]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
datasets: separate DatasetGet in 2 functions
authorEric Leblond <el@stamus-networks.com>
Fri, 2 May 2025 15:30:14 +0000 (17:30 +0200)
committerVictor Julien <victor@inliniac.net>
Wed, 11 Jun 2025 18:49:18 +0000 (20:49 +0200)
This will be used to factorize the code with datajson.

src/datasets.c
src/datasets.h

index d14753811387fc6de2f3d193f2466f268ad87d79..566d2198e51f5dbcc4f0efbf0f2322a9d9aaf32a 100644 (file)
@@ -341,8 +341,8 @@ static void DatasetUpdateHashsize(const char *name, uint32_t hash_size)
     }
 }
 
-Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save, const char *load,
-        uint64_t memcap, uint32_t hashsize)
+Dataset *DatasetCreateOrGet(const char *name, enum DatasetTypes type, const char *save,
+        const char *load, uint64_t *memcap, uint32_t *hashsize)
 {
     uint64_t default_memcap = 0;
     uint32_t default_hashsize = 0;
@@ -390,11 +390,14 @@ Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save,
     }
 
     DatasetGetDefaultMemcap(&default_memcap, &default_hashsize);
-    if (hashsize == 0) {
-        hashsize = default_hashsize;
+    if (*hashsize == 0) {
+        *hashsize = default_hashsize;
+    }
+    if (*memcap == 0) {
+        *memcap = default_memcap;
     }
 
-    if (!DatasetCheckHashsize(name, hashsize)) {
+    if (!DatasetCheckHashsize(name, *hashsize)) {
         goto out_err;
     }
 
@@ -413,15 +416,33 @@ Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save,
         strlcpy(set->load, load, sizeof(set->load));
         SCLogDebug("set \'%s\' loading \'%s\' from \'%s\'", set->name, load, set->load);
     }
+    return set;
+out_err:
+    if (set) {
+        if (set->hash) {
+            THashShutdown(set->hash);
+        }
+        SCFree(set);
+    }
+    SCMutexUnlock(&sets_lock);
+    return NULL;
+}
+
+Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save, const char *load,
+        uint64_t memcap, uint32_t hashsize)
+{
+    Dataset *set = DatasetCreateOrGet(name, type, save, load, &memcap, &hashsize);
+    if (set == NULL) {
+        SCLogError("dataset %s creation failed", name);
+        return NULL;
+    }
 
     char cnf_name[128];
     snprintf(cnf_name, sizeof(cnf_name), "datasets.%s.hash", name);
-
     switch (type) {
         case DATASET_TYPE_MD5:
             set->hash = THashInit(cnf_name, sizeof(Md5Type), Md5StrSet, Md5StrFree, Md5StrHash,
-                    Md5StrCompare, NULL, NULL, load != NULL ? 1 : 0,
-                    memcap > 0 ? memcap : default_memcap, hashsize);
+                    Md5StrCompare, NULL, NULL, load != NULL ? 1 : 0, memcap, hashsize);
             if (set->hash == NULL)
                 goto out_err;
             if (DatasetLoadMd5(set) < 0)
@@ -429,8 +450,7 @@ Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save,
             break;
         case DATASET_TYPE_STRING:
             set->hash = THashInit(cnf_name, sizeof(StringType), StringSet, StringFree, StringHash,
-                    StringCompare, NULL, StringGetLength, load != NULL ? 1 : 0,
-                    memcap > 0 ? memcap : default_memcap, hashsize);
+                    StringCompare, NULL, StringGetLength, load != NULL ? 1 : 0, memcap, hashsize);
             if (set->hash == NULL)
                 goto out_err;
             if (DatasetLoadString(set) < 0)
@@ -438,8 +458,8 @@ Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save,
             break;
         case DATASET_TYPE_SHA256:
             set->hash = THashInit(cnf_name, sizeof(Sha256Type), Sha256StrSet, Sha256StrFree,
-                    Sha256StrHash, Sha256StrCompare, NULL, NULL, load != NULL ? 1 : 0,
-                    memcap > 0 ? memcap : default_memcap, hashsize);
+                    Sha256StrHash, Sha256StrCompare, NULL, NULL, load != NULL ? 1 : 0, memcap,
+                    hashsize);
             if (set->hash == NULL)
                 goto out_err;
             if (DatasetLoadSha256(set) < 0)
@@ -447,8 +467,7 @@ Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save,
             break;
         case DATASET_TYPE_IPV4:
             set->hash = THashInit(cnf_name, sizeof(IPv4Type), IPv4Set, IPv4Free, IPv4Hash,
-                    IPv4Compare, NULL, NULL, load != NULL ? 1 : 0,
-                    memcap > 0 ? memcap : default_memcap, hashsize);
+                    IPv4Compare, NULL, NULL, load != NULL ? 1 : 0, memcap, hashsize);
             if (set->hash == NULL)
                 goto out_err;
             if (DatasetLoadIPv4(set) < 0)
@@ -456,8 +475,7 @@ Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save,
             break;
         case DATASET_TYPE_IPV6:
             set->hash = THashInit(cnf_name, sizeof(IPv6Type), IPv6Set, IPv6Free, IPv6Hash,
-                    IPv6Compare, NULL, NULL, load != NULL ? 1 : 0,
-                    memcap > 0 ? memcap : default_memcap, hashsize);
+                    IPv6Compare, NULL, NULL, load != NULL ? 1 : 0, memcap, hashsize);
             if (set->hash == NULL)
                 goto out_err;
             if (DatasetLoadIPv6(set) < 0)
index e2a7a4725a142e27aeff850ae502133c37824e93..c3aad61c097a0bd6ccb6ba713b12aa50d177a02e 100644 (file)
@@ -68,6 +68,8 @@ Dataset *DatasetSearchByName(const char *name);
 Dataset *DatasetFind(const char *name, enum DatasetTypes type);
 Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save, const char *load,
         uint64_t memcap, uint32_t hashsize);
+Dataset *DatasetCreateOrGet(const char *name, enum DatasetTypes type, const char *save,
+        const char *load, uint64_t *memcap, uint32_t *hashsize);
 int DatasetAdd(Dataset *set, const uint8_t *data, const uint32_t data_len);
 int DatasetRemove(Dataset *set, const uint8_t *data, const uint32_t data_len);
 int DatasetLookup(Dataset *set, const uint8_t *data, const uint32_t data_len);