]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Lua_tensor: Add deserialisation
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 6 Aug 2020 13:43:18 +0000 (14:43 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 6 Aug 2020 13:43:18 +0000 (14:43 +0100)
src/lua/lua_tensor.c

index 9b85779d79454dbb6b98b536c4223a0b9256d7b2..91fcd763ed917b9f927c6557b5faaf02ef9d5b13 100644 (file)
@@ -396,10 +396,62 @@ lua_tensor_mul (lua_State *L)
 static gint
 lua_tensor_load (lua_State *L)
 {
-       struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+       const guchar *data;
+       gsize sz;
 
-       if (t) {
+       if (lua_type (L, 1) == LUA_TUSERDATA) {
+               struct rspamd_lua_text *t = lua_check_text (L, 1);
+
+               if (!t) {
+                       return luaL_error (L, "invalid argument");
+               }
+
+               data = (const guchar *)t->start;
+               sz = t->len;
+       }
+       else {
+               data = (const guchar *)lua_tolstring (L, 1, &sz);
+       }
+
+       if (sz >= sizeof (gint) * 4) {
+               int ndims, nelts, dims[2];
 
+               memcpy (&ndims, data, sizeof (int));
+               memcpy (&nelts, data + sizeof (int), sizeof (int));
+               memcpy (dims, data + sizeof (int) * 2, sizeof (int) * 2);
+
+               if (sz == nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4) {
+                       if (ndims == 1) {
+                               if (nelts == dims[0]) {
+                                       struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
+                                       memcpy (t->data, data + sizeof (int) * 4, nelts *
+                                                       sizeof (rspamd_tensor_num_t));
+                               }
+                               else {
+                                       return luaL_error (L, "invalid argument: bad dims: %d x %d != %d",
+                                                       dims[0], 1, nelts);
+                               }
+                       }
+                       else if (ndims == 2) {
+                               if (nelts == dims[0] * dims[1]) {
+                                       struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
+                                       memcpy (t->data, data + sizeof (int) * 4, nelts *
+                                                       sizeof (rspamd_tensor_num_t));
+                               }
+                               else {
+                                       return luaL_error (L, "invalid argument: bad dims: %d x %d != %d",
+                                                       dims[0], dims[1], nelts);
+                               }
+                       }
+                       else {
+                               return luaL_error (L, "invalid argument: bad ndims: %d", ndims);
+                       }
+               }
+               else {
+                       return luaL_error (L, "invalid size: %d, %d required, %d elts", (int)sz,
+                                       (int)(nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4),
+                                       nelts);
+               }
        }
        else {
                return luaL_error (L, "invalid arguments");