]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
datasets: factorize dataset creation
authorEric Leblond <el@stamus-networks.com>
Fri, 2 May 2025 18:48:59 +0000 (20:48 +0200)
committerVictor Julien <victor@inliniac.net>
Wed, 11 Jun 2025 18:49:18 +0000 (20:49 +0200)
Factorize DatasetGet and DatajsonGet to only have the difference
between the two in the respective function.

src/datasets-context-json.c
src/datasets.c
src/datasets.h

index 11b0035ab60d585b568884ba9ddae07a405c8d10..0317fd830471038c610af947fa22f5f25d828263 100644 (file)
@@ -360,6 +360,7 @@ static uint32_t DatajsonAddStringElement(Dataset *set, json_t *value, char *json
     if (set->remove_key) {
         json_object_del(value, json_key);
     }
+
     elt.value = json_dumps(value, JSON_COMPACT);
     elt.len = strlen(elt.value);
 
@@ -642,74 +643,34 @@ Dataset *DatajsonGet(const char *name, enum DatasetTypes type, const char *load,
         uint32_t hashsize, char *json_key_value, char *json_array_key, DatasetFormats format,
         bool remove_key)
 {
-    uint64_t default_memcap = 0;
-    uint32_t default_hashsize = 0;
-    if (strlen(name) > DATASET_NAME_MAX_LEN) {
-        SCLogError("dataset name too long");
+    Dataset *set = NULL;
+
+    int ret = DatasetCreateOrGet(name, type, NULL, load, &memcap, &hashsize, &set);
+    if (ret < 0) {
+        SCLogError("dataset with JSON %s creation failed", name);
         return NULL;
     }
-
-    DatasetLock();
-    Dataset *set = DatasetSearchByName(name);
-    if (set) {
-        if (type != DATASET_TYPE_NOTSET && set->type != type) {
-            SCLogError("dataset %s already "
-                       "exists and is of type %u",
-                    set->name, set->type);
+    if (ret == 1) {
+        SCLogDebug("dataset %s already exists", name);
+        if (set->remove_key != remove_key) {
+            SCLogError("dataset %s remove_key mismatch: %b != %b", set->name, set->remove_key,
+                    remove_key);
             DatasetUnlock();
             return NULL;
         }
-
-        if (load == NULL || strlen(load) == 0) {
-            // OK, rule keyword doesn't have to set state/load,
-            // even when yaml set has set it.
-        } else {
-            if ((load == NULL && strlen(set->load) > 0) ||
-                    (load != NULL && strcmp(set->load, load) != 0)) {
-                SCLogError("dataset %s load mismatch: %s != %s", set->name, set->load, load);
-                DatasetUnlock();
-                return NULL;
-            }
-        }
-
         DatasetUnlock();
         return set;
     }
 
-    if (type == DATASET_TYPE_NOTSET) {
-        SCLogError("dataset %s not defined", name);
-        goto out_err;
-    }
-
-    set = DatasetAlloc(name);
-    if (set == NULL) {
-        SCLogError("dataset %s allocation failed", name);
-        goto out_err;
-    }
-
-    strlcpy(set->name, name, sizeof(set->name));
-    set->type = type;
     set->remove_key = remove_key;
-    if (load && strlen(load)) {
-        strlcpy(set->load, load, sizeof(set->load));
-        SCLogDebug("set \'%s\' loading \'%s\' from \'%s\'", set->name, load, set->load);
-    }
 
-    static const char conf_format_str[] = "datasets.%s.hash";
-    char cnf_name[DATASET_NAME_MAX_LEN + (sizeof(conf_format_str) / sizeof(char))];
-    int p_ret = snprintf(cnf_name, sizeof(cnf_name), conf_format_str, name);
-    if (p_ret == 0) {
-        SCLogError("Can't build configuration variable for set: '%s'", name);
-        goto out_err;
-    }
-
-    DatasetGetDefaultMemcap(&default_memcap, &default_hashsize);
+    char cnf_name[128];
+    snprintf(cnf_name, sizeof(cnf_name), "datasets.%s.hash", name);
     switch (type) {
         case DATASET_TYPE_MD5:
             set->hash = THashInit(cnf_name, sizeof(Md5Type), Md5StrJsonSet, Md5StrJsonFree,
                     Md5StrHash, Md5StrCompare, NULL, Md5StrJsonGetLength, load != NULL ? 1 : 0,
-                    memcap > 0 ? memcap : default_memcap,
-                    hashsize > 0 ? hashsize : default_hashsize);
+                    memcap, hashsize);
             if (set->hash == NULL)
                 goto out_err;
             if (DatajsonLoadMd5(set, json_key_value, json_array_key, format) < 0)
@@ -718,8 +679,7 @@ Dataset *DatajsonGet(const char *name, enum DatasetTypes type, const char *load,
         case DATASET_TYPE_STRING:
             set->hash = THashInit(cnf_name, sizeof(StringType), StringJsonSet, StringJsonFree,
                     StringHash, StringCompare, NULL, StringJsonGetLength, load != NULL ? 1 : 0,
-                    memcap > 0 ? memcap : default_memcap,
-                    hashsize > 0 ? hashsize : default_hashsize);
+                    memcap, hashsize);
             if (set->hash == NULL)
                 goto out_err;
             if (DatajsonLoadString(set, json_key_value, json_array_key, format) < 0) {
@@ -730,8 +690,7 @@ Dataset *DatajsonGet(const char *name, enum DatasetTypes type, const char *load,
         case DATASET_TYPE_SHA256:
             set->hash = THashInit(cnf_name, sizeof(Sha256Type), Sha256StrJsonSet, Sha256StrJsonFree,
                     Sha256StrHash, Sha256StrCompare, NULL, Sha256StrJsonGetLength,
-                    load != NULL ? 1 : 0, memcap > 0 ? memcap : default_memcap,
-                    hashsize > 0 ? hashsize : default_hashsize);
+                    load != NULL ? 1 : 0, memcap, hashsize);
             if (set->hash == NULL)
                 goto out_err;
             if (DatajsonLoadSha256(set, json_key_value, json_array_key, format) < 0)
@@ -739,9 +698,7 @@ Dataset *DatajsonGet(const char *name, enum DatasetTypes type, const char *load,
             break;
         case DATASET_TYPE_IPV4:
             set->hash = THashInit(cnf_name, sizeof(IPv4Type), IPv4JsonSet, IPv4JsonFree, IPv4Hash,
-                    IPv4Compare, NULL, IPv4JsonGetLength, load != NULL ? 1 : 0,
-                    memcap > 0 ? memcap : default_memcap,
-                    hashsize > 0 ? hashsize : default_hashsize);
+                    IPv4Compare, NULL, IPv4JsonGetLength, load != NULL ? 1 : 0, memcap, hashsize);
             if (set->hash == NULL)
                 goto out_err;
             if (DatajsonLoadIPv4(set, json_key_value, json_array_key, format) < 0)
@@ -749,9 +706,7 @@ Dataset *DatajsonGet(const char *name, enum DatasetTypes type, const char *load,
             break;
         case DATASET_TYPE_IPV6:
             set->hash = THashInit(cnf_name, sizeof(IPv6Type), IPv6JsonSet, IPv6JsonFree, IPv6Hash,
-                    IPv6Compare, NULL, IPv6JsonGetLength, load != NULL ? 1 : 0,
-                    memcap > 0 ? memcap : default_memcap,
-                    hashsize > 0 ? hashsize : default_hashsize);
+                    IPv6Compare, NULL, IPv6JsonGetLength, load != NULL ? 1 : 0, memcap, hashsize);
             if (set->hash == NULL)
                 goto out_err;
             if (DatajsonLoadIPv6(set, json_key_value, json_array_key, format) < 0)
@@ -762,7 +717,10 @@ Dataset *DatajsonGet(const char *name, enum DatasetTypes type, const char *load,
     SCLogDebug(
             "set %p/%s type %u save %s load %s", set, set->name, set->type, set->save, set->load);
 
-    DatasetAppendSet(set);
+    if (DatasetAppendSet(set) < 0) {
+        SCLogError("dataset %s append failed", name);
+        goto out_err;
+    }
 
     DatasetUnlock();
     return set;
index 566d2198e51f5dbcc4f0efbf0f2322a9d9aaf32a..1d6bafeb0eb64786ebab9325f748e801609e4ac1 100644 (file)
@@ -52,6 +52,7 @@ uint32_t dataset_max_total_hashsize = 16777216;
 uint32_t dataset_used_hashsize = 0;
 
 int DatasetAddwRep(Dataset *set, const uint8_t *data, const uint32_t data_len, DataRepType *rep);
+static void DatasetUpdateHashsize(const char *name, uint32_t hash_size);
 
 static inline void DatasetUnlockData(THashData *d)
 {
@@ -75,10 +76,27 @@ enum DatasetTypes DatasetGetTypeFromString(const char *s)
     return DATASET_TYPE_NOTSET;
 }
 
-void DatasetAppendSet(Dataset *set)
+int DatasetAppendSet(Dataset *set)
 {
+
+    if (set->hash == NULL) {
+        return -1;
+    }
+
+    if (SC_ATOMIC_GET(set->hash->memcap_reached)) {
+        SCLogError("dataset too large for set memcap");
+        return -1;
+    }
+
+    SCLogDebug(
+            "set %p/%s type %u save %s load %s", set, set->name, set->type, set->save, set->load);
+
     set->next = sets;
     sets = set;
+
+    /* hash size accounting */
+    DatasetUpdateHashsize(set->name, set->hash->config.hash_size);
+    return 0;
 }
 
 void DatasetLock(void)
@@ -341,24 +359,31 @@ static void DatasetUpdateHashsize(const char *name, uint32_t hash_size)
     }
 }
 
-Dataset *DatasetCreateOrGet(const char *name, enum DatasetTypes type, const char *save,
-        const char *load, uint64_t *memcap, uint32_t *hashsize)
+/**
+ * \return -1 on error
+ * \return 0 on successful creation
+ * \return 1 if the dataset already exists
+ *
+ * dataset global lock is held after return if set is found or created
+ */
+int DatasetCreateOrGet(const char *name, enum DatasetTypes type, const char *save, const char *load,
+        uint64_t *memcap, uint32_t *hashsize, Dataset **ret_set)
 {
     uint64_t default_memcap = 0;
     uint32_t default_hashsize = 0;
     if (strlen(name) > DATASET_NAME_MAX_LEN) {
-        return NULL;
+        return -1;
     }
 
-    SCMutexLock(&sets_lock);
+    DatasetLock();
     Dataset *set = DatasetSearchByName(name);
     if (set) {
         if (type != DATASET_TYPE_NOTSET && set->type != type) {
             SCLogError("dataset %s already "
                        "exists and is of type %u",
                     set->name, set->type);
-            SCMutexUnlock(&sets_lock);
-            return NULL;
+            DatasetUnlock();
+            return -1;
         }
 
         if ((save == NULL || strlen(save) == 0) &&
@@ -369,24 +394,24 @@ Dataset *DatasetCreateOrGet(const char *name, enum DatasetTypes type, const char
             if ((save == NULL && strlen(set->save) > 0) ||
                     (save != NULL && strcmp(set->save, save) != 0)) {
                 SCLogError("dataset %s save mismatch: %s != %s", set->name, set->save, save);
-                SCMutexUnlock(&sets_lock);
-                return NULL;
+                DatasetUnlock();
+                return -1;
             }
             if ((load == NULL && strlen(set->load) > 0) ||
                     (load != NULL && strcmp(set->load, load) != 0)) {
                 SCLogError("dataset %s load mismatch: %s != %s", set->name, set->load, load);
-                SCMutexUnlock(&sets_lock);
-                return NULL;
+                DatasetUnlock();
+                return -1;
             }
         }
 
-        SCMutexUnlock(&sets_lock);
-        return set;
-    } else {
-        if (type == DATASET_TYPE_NOTSET) {
-            SCLogError("dataset %s not defined", name);
-            goto out_err;
-        }
+        *ret_set = set;
+        return 1;
+    }
+
+    if (type == DATASET_TYPE_NOTSET) {
+        SCLogError("dataset %s not defined", name);
+        goto out_err;
     }
 
     DatasetGetDefaultMemcap(&default_memcap, &default_hashsize);
@@ -416,7 +441,9 @@ Dataset *DatasetCreateOrGet(const char *name, enum DatasetTypes type, const char
         strlcpy(set->load, load, sizeof(set->load));
         SCLogDebug("set \'%s\' loading \'%s\' from \'%s\'", set->name, load, set->load);
     }
-    return set;
+
+    *ret_set = set;
+    return 0;
 out_err:
     if (set) {
         if (set->hash) {
@@ -424,18 +451,25 @@ out_err:
         }
         SCFree(set);
     }
-    SCMutexUnlock(&sets_lock);
-    return NULL;
+    DatasetUnlock();
+    return -1;
 }
 
 Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save, const char *load,
         uint64_t memcap, uint32_t hashsize)
 {
-    Dataset *set = DatasetCreateOrGet(name, type, save, load, &memcap, &hashsize);
-    if (set == NULL) {
+    Dataset *set = NULL;
+
+    int ret = DatasetCreateOrGet(name, type, save, load, &memcap, &hashsize, &set);
+    if (ret < 0) {
         SCLogError("dataset %s creation failed", name);
         return NULL;
     }
+    if (ret == 1) {
+        SCLogDebug("dataset %s already exists", name);
+        DatasetUnlock();
+        return set;
+    }
 
     char cnf_name[128];
     snprintf(cnf_name, sizeof(cnf_name), "datasets.%s.hash", name);
@@ -482,26 +516,13 @@ Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save,
                 goto out_err;
             break;
     }
-    if (set->hash == NULL) {
-        goto out_err;
-    }
 
-    if (SC_ATOMIC_GET(set->hash->memcap_reached)) {
-        SCLogError("dataset too large for set memcap");
+    if (DatasetAppendSet(set) < 0) {
+        SCLogError("dataset %s append failed", name);
         goto out_err;
     }
 
-    SCLogDebug("set %p/%s type %u save %s load %s",
-            set, set->name, set->type, set->save, set->load);
-
-    set->next = sets;
-    sets = set;
-
-    /* hash size accounting */
-    DEBUG_VALIDATE_BUG_ON(set->hash->config.hash_size != hashsize);
-    DatasetUpdateHashsize(set->name, set->hash->config.hash_size);
-
-    SCMutexUnlock(&sets_lock);
+    DatasetUnlock();
     return set;
 out_err:
     if (set) {
@@ -510,7 +531,7 @@ out_err:
         }
         SCFree(set);
     }
-    SCMutexUnlock(&sets_lock);
+    DatasetUnlock();
     return NULL;
 }
 
index c3aad61c097a0bd6ccb6ba713b12aa50d177a02e..ce29ac7b3d255bf5ea5d4d63aee7da9d166b94b7 100644 (file)
@@ -60,7 +60,7 @@ typedef struct Dataset {
 } Dataset;
 
 enum DatasetTypes DatasetGetTypeFromString(const char *s);
-void DatasetAppendSet(Dataset *set);
+int DatasetAppendSet(Dataset *set);
 Dataset *DatasetAlloc(const char *name);
 void DatasetLock(void);
 void DatasetUnlock(void);
@@ -68,8 +68,8 @@ Dataset *DatasetSearchByName(const char *name);
 Dataset *DatasetFind(const char *name, enum DatasetTypes type);
 Dataset *DatasetGet(const char *name, enum DatasetTypes type, const char *save, const char *load,
         uint64_t memcap, uint32_t hashsize);
-Dataset *DatasetCreateOrGet(const char *name, enum DatasetTypes type, const char *save,
-        const char *load, uint64_t *memcap, uint32_t *hashsize);
+int DatasetCreateOrGet(const char *name, enum DatasetTypes type, const char *save, const char *load,
+        uint64_t *memcap, uint32_t *hashsize, Dataset **ret_set);
 int DatasetAdd(Dataset *set, const uint8_t *data, const uint32_t data_len);
 int DatasetRemove(Dataset *set, const uint8_t *data, const uint32_t data_len);
 int DatasetLookup(Dataset *set, const uint8_t *data, const uint32_t data_len);