]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Minor] Fix tensor multiplication for the vectors case
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 19 Aug 2020 12:51:10 +0000 (13:51 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 19 Aug 2020 13:03:29 +0000 (14:03 +0100)
src/lua/lua_tensor.c

index cf91006d0a07f9d632787a50dfb64335c99840b9..16bba985bf0f8fed32dbefafbd3e047ce98bdf21 100644 (file)
@@ -351,6 +351,7 @@ lua_tensor_index (lua_State *L)
                        }
                }
                else if (lua_isstring (L, 2)) {
+                       /* Access to methods */
                        lua_getmetatable (L, 1);
                        lua_pushvalue (L, 2);
                        lua_rawget (L, -2);
@@ -392,7 +393,20 @@ lua_tensor_mul (lua_State *L)
                                        dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
                }
 
-               res = lua_newtensor (L, 2, dims, true);
+               if (dims[0] == 0) {
+                       /* Column */
+                       dims[0] = 1;
+                       res = lua_newtensor (L, 2, dims, true, true);
+               }
+               else if (dims[1] == 0) {
+                       /* Row */
+                       res = lua_newtensor (L, 1, dims, true, true);
+                       dims[1] = 1;
+               }
+               else {
+                       res = lua_newtensor (L, 2, dims, true, true);
+               }
+
                kad_sgemm_simple (transA, transB, dims[0], dims[1], shadow_dims[0],
                                t1->data, t2->data, res->data);
        }
@@ -438,7 +452,7 @@ lua_tensor_load (lua_State *L)
                if (sz == nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4) {
                        if (ndims == 1) {
                                if (nelts == dims[0]) {
-                                       struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
+                                       struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);
                                        memcpy (t->data, data + sizeof (int) * 4, nelts *
                                                        sizeof (rspamd_tensor_num_t));
                                }
@@ -449,7 +463,7 @@ lua_tensor_load (lua_State *L)
                        }
                        else if (ndims == 2) {
                                if (nelts == dims[0] * dims[1]) {
-                                       struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false);
+                                       struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);
                                        memcpy (t->data, data + sizeof (int) * 4, nelts *
                                                        sizeof (rspamd_tensor_num_t));
                                }