]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Lua_tensor: Implement non-owning tensors (slices)
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 19 Aug 2020 11:43:06 +0000 (12:43 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 19 Aug 2020 13:03:29 +0000 (14:03 +0100)
src/lua/lua_tensor.c
src/lua/lua_tensor.h

index 91fcd763ed917b9f927c6557b5faaf02ef9d5b13..cf91006d0a07f9d632787a50dfb64335c99840b9 100644 (file)
@@ -53,7 +53,7 @@ static luaL_reg rspamd_tensor_m[] = {
 };
 
 static struct rspamd_lua_tensor *
-lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill)
+lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill, bool own)
 {
        struct rspamd_lua_tensor *res;
 
@@ -68,10 +68,16 @@ lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill)
        }
 
        /* To avoid allocating large stuff in Lua */
-       res->data = g_malloc (sizeof (rspamd_tensor_num_t) * res->size);
+       if (own) {
+               res->data = g_malloc (sizeof (rspamd_tensor_num_t) * res->size);
 
-       if (zero_fill) {
-               memset (res->data, 0, sizeof (rspamd_tensor_num_t) * res->size);
+               if (zero_fill) {
+                       memset (res->data, 0, sizeof (rspamd_tensor_num_t) * res->size);
+               }
+       }
+       else {
+               /* Mark size negative to distinguish */
+               res->size = -(res->size);
        }
 
        rspamd_lua_setclass (L, TENSOR_CLASS, -1);
@@ -96,7 +102,7 @@ lua_tensor_new (lua_State *L)
                        dims[i] = lua_tointeger (L, i + 2);
                }
 
-               (void)lua_newtensor (L, ndims, dims, true);
+               (void)lua_newtensor (L, ndims, dims, true, true);
        }
        else {
                return luaL_error (L, "incorrect dimensions number: %d", ndims);
@@ -124,7 +130,7 @@ lua_tensor_fromtable (lua_State *L)
                        dims[1] = rspamd_lua_table_size (L, 1);
 
                        struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
-                                       dims, false);
+                                       dims, false, true);
 
                        for (guint i = 0; i < dims[1]; i ++) {
                                lua_rawgeti (L, 1, i + 1);
@@ -179,7 +185,7 @@ lua_tensor_fromtable (lua_State *L)
                        dims[1] = ncols;
 
                        struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
-                                       dims, false);
+                                       dims, false, true);
 
                        for (gint i = 0; i < nrows; i ++) {
                                lua_rawgeti (L, 1, i + 1);
@@ -219,7 +225,9 @@ lua_tensor_destroy (lua_State *L)
        struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
 
        if (t) {
-               g_free (t->data);
+               if (t->size > 0) {
+                       g_free (t->data);
+               }
        }
 
        return 0;
@@ -234,19 +242,27 @@ static gint
 lua_tensor_save (lua_State *L)
 {
        struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+       gint size;
 
        if (t) {
-               gsize sz = sizeof (gint) * 4 + t->size * sizeof (rspamd_tensor_num_t);
+               if (t->size > 0) {
+                       size = t->size;
+               }
+               else {
+                       size = -(t->size);
+               }
+
+               gsize sz = sizeof (gint) * 4 + size * sizeof (rspamd_tensor_num_t);
                guchar *data;
 
                struct rspamd_lua_text *out = lua_new_text (L, NULL, 0, TRUE);
 
                data = g_malloc (sz);
                memcpy (data, &t->ndims, sizeof (int));
-               memcpy (data + sizeof (int), &t->size, sizeof (int));
+               memcpy (data + sizeof (int), &size, sizeof (int));
                memcpy (data + 2 * sizeof (int), t->dim, sizeof (int) * 2);
                memcpy (data + 4 * sizeof (int), t->data,
-                               t->size * sizeof (rspamd_tensor_num_t));
+                               size * sizeof (rspamd_tensor_num_t));
 
                out->start = (const gchar *)data;
                out->len = sz;
@@ -324,11 +340,10 @@ lua_tensor_index (lua_State *L)
 
 
                                if (idx <= t->dim[0]) {
+                                       /* Non-owning tensor */
                                        struct rspamd_lua_tensor *res =
-                                                       lua_newtensor (L, 1, &dim, false);
-                                       for (gint i = 0; i < dim; i++) {
-                                               res->data[i] = t->data[(idx - 1) * t->dim[1] + i];
-                                       }
+                                                       lua_newtensor (L, 1, &dim, false, false);
+                                       res->data = &t->data[(idx - 1) * t->dim[1]];
                                }
                                else {
                                        lua_pushnil (L);
index 554245f0beffb2275d69a76d37157d9c2bb4d34a..e4c110011d4c0e8ff8ac64590a5d4cad1c79937a 100644 (file)
@@ -23,8 +23,8 @@ typedef float rspamd_tensor_num_t;
 struct rspamd_lua_tensor {
        int ndims;
        int size; /* overall size (product of dims) */
-       rspamd_tensor_num_t *data;
        int dim[2];
+       rspamd_tensor_num_t *data;
 };
 
 struct rspamd_lua_tensor *lua_check_tensor (lua_State *L, int pos);