]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Add printing and fix multiplication
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 5 Aug 2020 15:05:40 +0000 (16:05 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 5 Aug 2020 20:05:09 +0000 (21:05 +0100)
src/lua/lua_tensor.c

index e8aebd180c9990b9b00a42731cbe6075e7a7601d..21bdf9673428ec04d174065e2a856e075a5d6417 100644 (file)
@@ -31,6 +31,7 @@ LUA_FUNCTION_DEF (tensor, new);
 LUA_FUNCTION_DEF (tensor, fromtable);
 LUA_FUNCTION_DEF (tensor, destroy);
 LUA_FUNCTION_DEF (tensor, mul);
+LUA_FUNCTION_DEF (tensor, tostring);
 
 static luaL_reg rspamd_tensor_f[] = {
                LUA_INTERFACE_DEF (tensor, load),
@@ -44,6 +45,8 @@ static luaL_reg rspamd_tensor_m[] = {
                {"__gc", lua_tensor_destroy},
                {"__mul", lua_tensor_mul},
                {"mul", lua_tensor_mul},
+               {"__tostring", lua_tensor_tostring},
+               {"tostring", lua_tensor_tostring},
                {NULL, NULL},
 };
 
@@ -114,12 +117,14 @@ lua_tensor_fromtable (lua_State *L)
                if (lua_isnumber (L, -1)) {
                        lua_pop (L, 1);
                        /* Input vector */
-                       gint dim = rspamd_lua_table_size (L, 1);
+                       gint dims[2];
+                       dims[0] = 1;
+                       dims[1] = rspamd_lua_table_size (L, 1);
 
-                       struct rspamd_lua_tensor *res = lua_newtensor (L, 1,
-                                       &dim, false);
+                       struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
+                                       dims, false);
 
-                       for (guint i = 0; i < dim; i ++) {
+                       for (guint i = 0; i < dims[1]; i ++) {
                                lua_rawgeti (L, 1, i + 1);
                                res->data[i] = lua_tonumber (L, -1);
                                lua_pop (L, 1);
@@ -168,8 +173,8 @@ lua_tensor_fromtable (lua_State *L)
                        }
 
                        gint dims[2];
-                       dims[0] = ncols;
-                       dims[1] = nrows;
+                       dims[0] = nrows;
+                       dims[1] = ncols;
 
                        struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
                                        dims, false);
@@ -238,6 +243,47 @@ lua_tensor_save (lua_State *L)
        return 1;
 }
 
+static gint
+lua_tensor_tostring (lua_State *L)
+{
+       struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
+
+       if (t) {
+               GString *out = g_string_sized_new (128);
+
+               if (t->ndims == 1) {
+                       /* Print as a vector */
+                       for (gint i = 0; i < t->dim[0]; i ++) {
+                               rspamd_printf_gstring (out, "%.4f ", t->data[i]);
+                       }
+                       /* Trim last space */
+                       out->len --;
+               }
+               else {
+                       for (gint i = 0; i < t->dim[0]; i ++) {
+                               for (gint j = 0; j < t->dim[1]; j ++) {
+                                       rspamd_printf_gstring (out, "%.4f ",
+                                                       t->data[i * t->dim[1] + j]);
+                               }
+                               /* Trim last space */
+                               out->len --;
+                               rspamd_printf_gstring (out, "\n");
+                       }
+                       /* Trim last ; */
+                       out->len --;
+               }
+
+               lua_pushlstring (L, out->str, out->len);
+
+               g_string_free (out, TRUE);
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       return 1;
+}
+
 /***
  * @method tensor:mul(other, [transA, [transB]])
  * Multiply two tensors (optionally transposed) and return a new tensor
@@ -259,12 +305,19 @@ lua_tensor_mul (lua_State *L)
        }
 
        if (t1 && t2) {
-               gint dims[2];
+               gint dims[2], shadow_dims[2];
                dims[0] = transA ? t1->dim[1] : t1->dim[0];
+               shadow_dims[0] = transB ? t2->dim[1] : t2->dim[0];
                dims[1] = transB ? t2->dim[0] : t2->dim[1];
+               shadow_dims[1] = transA ? t1->dim[0] : t1->dim[1];
+
+               if (shadow_dims[0] != shadow_dims[1]) {
+                       return luaL_error (L, "incompatible dimensions %d x %d * %d x %d",
+                                       dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
+               }
 
-               res = lua_newtensor (L, 2, dims, false);
-               kad_sgemm_simple (transA, transB, t1->dim[1], t2->dim[0], t1->dim[0],
+               res = lua_newtensor (L, 2, dims, true);
+               kad_sgemm_simple (transA, transB, dims[0], dims[1], shadow_dims[0],
                                t1->data, t2->data, res->data);
        }
        else {