]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
Add Host specific wrapper to StorageRegister()
authorVictor Julien <victor@inliniac.net>
Fri, 1 Mar 2013 13:46:47 +0000 (14:46 +0100)
committerVictor Julien <victor@inliniac.net>
Sun, 28 Jul 2013 21:41:11 +0000 (23:41 +0200)
src/host-storage.c
src/host-storage.h
src/suricata.c
src/util-storage.c
src/util-storage.h

index 0e028a739a79cd3d983c19b815886f8894b71b09..3b391adfb5ec911595073ecf71503edbbcbd387b 100644 (file)
@@ -32,29 +32,246 @@ unsigned int HostStorageSize(void) {
 }
 
 void *HostGetStorageById(Host *h, int id) {
-    return StorageGetById((Storage *)(h + HostStorageSize()), STORAGE_HOST, id);
+    return StorageGetById((Storage *)((void *)h + sizeof(Host)), STORAGE_HOST, id);
+}
+
+int HostSetStorageById(Host *h, int id, void *ptr) {
+    return StorageSetById((Storage *)((void *)h + sizeof(Host)), STORAGE_HOST, id, ptr);
 }
 
 void *HostAllocStorageById(Host *h, int id) {
-    return StorageAllocById((Storage **)&h + HostStorageSize(), STORAGE_HOST, id);
+    return StorageAllocByIdPrealloc((Storage *)((void *)h + sizeof(Host)), STORAGE_HOST, id);
 }
 
 void HostFreeStorageById(Host *h, int id) {
-    StorageFreeById((Storage *)(h + HostStorageSize()), STORAGE_HOST, id);
+    StorageFreeById((Storage *)((void *)h + sizeof(Host)), STORAGE_HOST, id);
 }
 
 void HostFreeStorage(Host *h) {
-    StorageFreeAll((Storage *)(h + HostStorageSize()), STORAGE_HOST);
+    StorageFreeAll((Storage *)((void *)h + sizeof(Host)), STORAGE_HOST);
+}
+
+int HostStorageRegister(const char *name, const unsigned int size, void *(*Init)(unsigned int), void (*Free)(void *)) {
+    return StorageRegister(STORAGE_HOST, name, size, Init, Free);
 }
 
 #ifdef UNITTESTS
+
+static void *StorageTestInit(unsigned int size) {
+    void *x = SCMalloc(size);
+    return x;
+}
+static void StorageTestFree(void *x) {
+    if (x)
+        SCFree(x);
+}
+
 static int HostStorageTest01(void) {
+    StorageInit();
+
+    int id1 = HostStorageRegister("test", 8, StorageTestInit, StorageTestFree);
+    if (id1 < 0)
+        goto error;
+    int id2 = HostStorageRegister("variable", 24, StorageTestInit, StorageTestFree);
+    if (id2 < 0)
+        goto error;
+    int id3 = HostStorageRegister("store", sizeof(void *), StorageTestInit, StorageTestFree);
+    if (id3 < 0)
+        goto error;
+
+    if (StorageFinalize() < 0)
+        goto error;
+
+    HostInitConfig(1);
+
+    Address a;
+    memset(&a, 0x00, sizeof(a));
+    a.addr_data32[0] = 0x01020304;
+    a.family = AF_INET;
+    Host *h = HostGetHostFromHash(&a);
+    if (h == NULL) {
+        printf("failed to get host: ");
+        goto error;
+    }
+
+    void *ptr = HostGetStorageById(h, id1);
+    if (ptr != NULL) {
+        goto error;
+    }
+    ptr = HostGetStorageById(h, id2);
+    if (ptr != NULL) {
+        goto error;
+    }
+    ptr = HostGetStorageById(h, id3);
+    if (ptr != NULL) {
+        goto error;
+    }
+
+    void *ptr1a = HostAllocStorageById(h, id1);
+    if (ptr1a == NULL) {
+        goto error;
+    }
+    void *ptr2a = HostAllocStorageById(h, id2);
+    if (ptr2a == NULL) {
+        goto error;
+    }
+    void *ptr3a = HostAllocStorageById(h, id3);
+    if (ptr3a == NULL) {
+        goto error;
+    }
+
+    void *ptr1b = HostGetStorageById(h, id1);
+    if (ptr1a != ptr1b) {
+        goto error;
+    }
+    void *ptr2b = HostGetStorageById(h, id2);
+    if (ptr2a != ptr2b) {
+        goto error;
+    }
+    void *ptr3b = HostGetStorageById(h, id3);
+    if (ptr3a != ptr3b) {
+        goto error;
+    }
+
+    HostRelease(h);
+
+    HostShutdown();
+    StorageCleanup();
+    return 1;
+error:
+    HostShutdown();
+    StorageCleanup();
+    return 0;
+}
+
+static int HostStorageTest02(void) {
+    StorageInit();
+
+    int id1 = HostStorageRegister("test", sizeof(void *), NULL, StorageTestFree);
+    if (id1 < 0)
+        goto error;
+
+    if (StorageFinalize() < 0)
+        goto error;
+
+    HostInitConfig(1);
+
+    Address a;
+    memset(&a, 0x00, sizeof(a));
+    a.addr_data32[0] = 0x01020304;
+    a.family = AF_INET;
+    Host *h = HostGetHostFromHash(&a);
+    if (h == NULL) {
+        printf("failed to get host: ");
+        goto error;
+    }
+
+    void *ptr = HostGetStorageById(h, id1);
+    if (ptr != NULL) {
+        goto error;
+    }
+
+    void *ptr1a = SCMalloc(128);
+    if (ptr1a == NULL) {
+        goto error;
+    }
+    HostSetStorageById(h, id1, ptr1a);
+
+    void *ptr1b = HostGetStorageById(h, id1);
+    if (ptr1a != ptr1b) {
+        goto error;
+    }
+
+    HostRelease(h);
+
+    HostShutdown();
+    StorageCleanup();
+    return 1;
+error:
+    HostShutdown();
+    StorageCleanup();
+    return 0;
+}
+
+static int HostStorageTest03(void) {
+    StorageInit();
+
+    int id1 = HostStorageRegister("test1", sizeof(void *), NULL, StorageTestFree);
+    if (id1 < 0)
+        goto error;
+    int id2 = HostStorageRegister("test2", sizeof(void *), NULL, StorageTestFree);
+    if (id2 < 0)
+        goto error;
+    int id3 = HostStorageRegister("test3", 32, StorageTestInit, StorageTestFree);
+    if (id3 < 0)
+        goto error;
+
+    if (StorageFinalize() < 0)
+        goto error;
+
+    HostInitConfig(1);
+
+    Address a;
+    memset(&a, 0x00, sizeof(a));
+    a.addr_data32[0] = 0x01020304;
+    a.family = AF_INET;
+    Host *h = HostGetHostFromHash(&a);
+    if (h == NULL) {
+        printf("failed to get host: ");
+        goto error;
+    }
+
+    void *ptr = HostGetStorageById(h, id1);
+    if (ptr != NULL) {
+        goto error;
+    }
+
+    void *ptr1a = SCMalloc(128);
+    if (ptr1a == NULL) {
+        goto error;
+    }
+    HostSetStorageById(h, id1, ptr1a);
+
+    void *ptr2a = SCMalloc(256);
+    if (ptr2a == NULL) {
+        goto error;
+    }
+    HostSetStorageById(h, id2, ptr2a);
+
+    void *ptr3a = HostAllocStorageById(h, id3);
+    if (ptr3a == NULL) {
+        goto error;
+    }
+
+    void *ptr1b = HostGetStorageById(h, id1);
+    if (ptr1a != ptr1b) {
+        goto error;
+    }
+    void *ptr2b = HostGetStorageById(h, id2);
+    if (ptr2a != ptr2b) {
+        goto error;
+    }
+    void *ptr3b = HostGetStorageById(h, id3);
+    if (ptr3a != ptr3b) {
+        goto error;
+    }
+
+    HostRelease(h);
+
+    HostShutdown();
+    StorageCleanup();
     return 1;
+error:
+    HostShutdown();
+    StorageCleanup();
+    return 0;
 }
 #endif
 
 void RegisterHostStorageTests(void) {
 #ifdef UNITTESTS
     UtRegisterTest("HostStorageTest01", HostStorageTest01, 1);
+    UtRegisterTest("HostStorageTest02", HostStorageTest02, 1);
+    UtRegisterTest("HostStorageTest03", HostStorageTest03, 1);
 #endif
 }
index bc43d31566949609b12015988379ec66cdc44c63..e5fe4f4b4c6853a6307c342c49b17ba4fbe35113 100644 (file)
@@ -32,6 +32,7 @@
 unsigned int HostStorageSize(void);
 
 void *HostGetStorageById(Host *h, int id);
+int HostSetStorageById(Host *h, int id, void *ptr);
 void *HostAllocStorageById(Host *h, int id);
 
 void HostFreeStorageById(Host *h, int id);
@@ -39,4 +40,6 @@ void HostFreeStorage(Host *h);
 
 void RegisterHostStorageTests(void);
 
+int HostStorageRegister(const char *name, const unsigned int size, void *(*Init)(unsigned int), void (*Free)(void *));
+
 #endif /* __HOST_STORAGE_H__ */
index d01330583f609f6ec21d18ebd36510c926168865..0599f4c7a267f8f1468a9c3a2b2c39d1bce1304b 100644 (file)
 #include "util-mpm-ac.h"
 #endif
 #include "util-storage.h"
+#include "host-storage.h"
 
 /*
  * we put this here, because we only use it here in main.
@@ -1748,6 +1749,7 @@ int main(int argc, char **argv)
         CudaBufferRegisterUnittests();
 #endif
         StorageRegisterTests();
+        RegisterHostStorageTests();
 
         if (list_unittests) {
             UtListTests(regex_arg);
index 934ff799ee2022b26b258db213b4ced112c6113c..31ba17fdbb7d0b872f9b165c2aacbced1150cb46 100644 (file)
@@ -94,7 +94,7 @@ int StorageRegister(const StorageEnum type, const char *name, const unsigned int
         return -1;
 
     if (type >= STORAGE_MAX || name == NULL || strlen(name) == 0 ||
-            size == 0 || Init == NULL || Free == NULL)
+            size == 0 || (size != sizeof(void *) && Init == NULL) || Free == NULL)
         return -1;
 
     StorageList *list = storage_list;
@@ -206,6 +206,28 @@ void *StorageGetById(const Storage *storage, const StorageEnum type, const int i
     return storage[id];
 }
 
+int StorageSetById(Storage *storage, const StorageEnum type, const int id, void *ptr) {
+    SCLogDebug("storage %p id %d", storage, id);
+    if (storage == NULL)
+        return -1;
+    storage[id] = ptr;
+    return 0;
+}
+
+void *StorageAllocByIdPrealloc(Storage *storage, StorageEnum type, int id) {
+    SCLogDebug("storage %p id %d", storage, id);
+
+    StorageMapping *map = &storage_map[type][id];
+    if (storage[id] == NULL && map->Init != NULL) {
+        storage[id] = map->Init(map->size);
+        if (storage[id] == NULL) {
+            return NULL;
+        }
+    }
+
+    return storage[id];
+}
+
 void *StorageAllocById(Storage **storage, StorageEnum type, int id) {
     SCLogDebug("storage %p id %d", storage, id);
 
@@ -219,7 +241,7 @@ void *StorageAllocById(Storage **storage, StorageEnum type, int id) {
     }
     SCLogDebug("store %p", store);
 
-    if (store[id] == NULL) {
+    if (store[id] == NULL && map->Init != NULL) {
         store[id] = map->Init(map->size);
         if (store[id] == NULL) {
             SCFree(store);
@@ -284,7 +306,7 @@ static void *StorageTestInit(unsigned int size) {
     void *x = SCMalloc(size);
     return x;
 }
-void StorageTestFree(void *x) {
+static void StorageTestFree(void *x) {
     if (x)
         SCFree(x);
 }
@@ -403,7 +425,7 @@ static int StorageTest03(void) {
         goto error;
     }
 
-    id = StorageRegister(STORAGE_HOST, "test1", 8, NULL, StorageTestFree);
+    id = StorageRegister(STORAGE_HOST, "test1", 6, NULL, StorageTestFree);
     if (id != -1) {
         printf("duplicate registration should have failed (2): ");
         goto error;
index 2e4061beaaecc9a9a0c0ff0f0e90ff9254117c1a..880cb129f373a4065e941b51b9522326e631e483 100644 (file)
@@ -45,6 +45,8 @@ unsigned int StorageGetCnt(const StorageEnum type);
 unsigned int StorageGetSize(const StorageEnum type);
 
 void *StorageGetById(const Storage *storage, const StorageEnum type, const int id);
+int StorageSetById(Storage *storage, const StorageEnum type, const int id, void *ptr);
+void *StorageAllocByIdPrealloc(Storage *storage, StorageEnum type, int id);
 void *StorageAllocById(Storage **storage, const StorageEnum type, const int id);
 void StorageFreeById(Storage *storage, const StorageEnum type, const int id);
 void StorageFreeAll(Storage *storage, const StorageEnum type);