]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Project] Add tensors index method
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 5 Aug 2020 20:04:32 +0000 (21:04 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 5 Aug 2020 20:05:09 +0000 (21:05 +0100)
src/lua/lua_tensor.c

index 21bdf9673428ec04d174065e2a856e075a5d6417..85aaa2e955f0829bc7909648ac3c57ce4534844f 100644 (file)
@@ -32,6 +32,7 @@ LUA_FUNCTION_DEF (tensor, fromtable);
 LUA_FUNCTION_DEF (tensor, destroy);
 LUA_FUNCTION_DEF (tensor, mul);
 LUA_FUNCTION_DEF (tensor, tostring);
+LUA_FUNCTION_DEF (tensor, index);
 
 static luaL_reg rspamd_tensor_f[] = {
                LUA_INTERFACE_DEF (tensor, load),
@@ -45,8 +46,9 @@ static luaL_reg rspamd_tensor_m[] = {
                {"__gc", lua_tensor_destroy},
                {"__mul", lua_tensor_mul},
                {"mul", lua_tensor_mul},
-               {"__tostring", lua_tensor_tostring},
                {"tostring", lua_tensor_tostring},
+               {"__tostring", lua_tensor_tostring},
+               {"__index", lua_tensor_index},
                {NULL, NULL},
 };
 
@@ -284,6 +286,52 @@ lua_tensor_tostring (lua_State *L)
        return 1;
 }
 
+static gint
+lua_tensor_index (lua_State *L)
+{
+       struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+       gint idx;
+
+       if (t) {
+               if (lua_isnumber (L, 2)) {
+                       idx = lua_tointeger (L, 2);
+
+                       if (t->ndims == 1) {
+                               /* Individual element */
+                               if (idx <= t->dim[0]) {
+                                       lua_pushnumber (L, t->data[idx - 1]);
+                               }
+                               else {
+                                       lua_pushnil (L);
+                               }
+                       }
+                       else {
+                               /* Push row */
+                               gint dim = t->dim[1];
+
+
+                               if (idx <= t->dim[0]) {
+                                       struct rspamd_lua_tensor *res =
+                                                       lua_newtensor (L, 1, &dim, false);
+                                       for (gint i = 0; i < dim; i++) {
+                                               res->data[i] = t->data[(idx - 1) * t->dim[1] + i];
+                                       }
+                               }
+                               else {
+                                       lua_pushnil (L);
+                               }
+                       }
+               }
+               else if (lua_isstring (L, 2)) {
+                       lua_getmetatable (L, 1);
+                       lua_pushvalue (L, 2);
+                       lua_rawget (L, -2);
+               }
+       }
+
+       return 1;
+}
+
 /***
  * @method tensor:mul(other, [transA, [transB]])
  * Multiply two tensors (optionally transposed) and return a new tensor