]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] kann: add multi-head attention pooling operator
authorVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 12 Jun 2026 15:38:19 +0000 (16:38 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 12 Jun 2026 15:38:19 +0000 (16:38 +0100)
New kad operator attn_pool (op 38, appended to preserve model
serialization compatibility): multi-head dot-product attention pooling
over a zero-padded sequence of word vectors with learned query vectors.
All-zero positions are treated as padding and masked out of the
softmax; attention weights are stashed in gtmp between the forward and
backward passes. Exposed as kann_layer_attn_pool() and
rspamd_kann.layer.attn_pool(node, n_words[, n_heads]).

Verified: converges on a needle-in-haystack task unsolvable by a flat
dense net (0.985 vs 0.715 accuracy), exact word-order invariance of the
pooled output, padding determinism and save/load roundtrip.

contrib/kann/kann.c
contrib/kann/kann.h
contrib/kann/kautodiff.c
contrib/kann/kautodiff.h
src/lua/lua_kann.c

index 86723bd9d350f455721a70dbba4801ca39495ae2..37df99c8445eca583c9132806fe2956b180929a9 100644 (file)
@@ -832,6 +832,20 @@ kad_node_t *kann_layer_layernorm(kad_node_t *in)
        return kann_layer_layernorm2(0, 0, in);
 }
 
+/* Multi-head attention pooling over a sequence of n_words zero-padded word
+ * vectors flattened into the input; the per-word dimension is derived from
+ * the input size. Output: n_heads * dim. */
+kad_node_t *kann_layer_attn_pool(kad_node_t *in, int n_words, int n_heads)
+{
+       kad_node_t *q;
+       int dim;
+       if (in->n_d != 2 || n_words <= 0 || n_heads <= 0) return 0;
+       if (in->d[1] % n_words != 0) return 0;
+       dim = in->d[1] / n_words;
+       q = kann_new_weight(n_heads, dim);
+       return kad_attn_pool(in, q, n_words);
+}
+
 kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int rnn_flag)
 {
        kad_node_t *h0;
index 313c9e9df10aa48702bea736abe8029a4a1ee320..c81814b40e6e6da929c8d9d073270d2cbd812379 100644 (file)
@@ -196,6 +196,7 @@ kad_node_t *kann_layer_input(int n1);
 kad_node_t *kann_layer_dense(kad_node_t *in, int n1);
 kad_node_t *kann_layer_dropout(kad_node_t *t, float r);
 kad_node_t *kann_layer_layernorm(kad_node_t *in);
+kad_node_t *kann_layer_attn_pool(kad_node_t *in, int n_words, int n_heads);
 kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int rnn_flag);
 kad_node_t *kann_layer_lstm(kad_node_t *in, int n1, int rnn_flag);
 kad_node_t *kann_layer_gru(kad_node_t *in, int n1, int rnn_flag);
index 8e20675ad6891245e2a7138a38401761e4a738f2..ccb51f662101bce21dbc0e6c337cac818bd2d733 100644 (file)
@@ -260,6 +260,20 @@ kad_node_t *kad_avg1d(kad_node_t *x, int kernel_size, int stride, int left_pad)
        return kad_finalize_node(s);
 }
 
+kad_node_t *kad_attn_pool(kad_node_t *x, kad_node_t *q, int n_words)
+{
+       kad_node_t *s;
+       int32_t *aux;
+       if (x->n_d != 2 || q->n_d != 2 || n_words <= 0) return 0;
+       if (x->d[1] != n_words * q->d[1]) return 0;
+       s = kad_new_core(0, 38, 2);
+       s->child[0] = x, s->child[1] = q;
+       aux = (int32_t *) g_malloc0_n(1, sizeof(int32_t));
+       aux[0] = n_words;
+       s->ptr = aux, s->ptr_size = sizeof(int32_t);
+       return kad_finalize_node(s);
+}
+
 /********** Multi-node pooling **********/
 
 static kad_node_t *kad_pooling_general(int op, int n, kad_node_t **x)
@@ -2569,6 +2583,111 @@ int kad_op_avg1d(kad_node_t *p, int action)
        return 0;
 }
 
+/* Multi-head dot-product attention pooling.
+ * child[0]: x, (batch, n_words * dim) -- zero-padded sequence of word vectors
+ * child[1]: q, (n_heads, dim)         -- learned query vectors
+ * output:   (batch, n_heads * dim)    -- per-head attention-weighted sums
+ *
+ * Per head h and sample b:
+ *   s_l   = dot(q_h, x_l) / sqrt(dim)   (all-zero x_l is padding: s_l = -inf)
+ *   alpha = softmax(s)
+ *   out_h = sum_l alpha_l * x_l
+ * Attention weights are saved in gtmp at the forward pass for the backward
+ * pass. Padded positions get alpha = 0 and contribute no gradient. */
+int kad_op_attn_pool(kad_node_t *p, int action)
+{
+       kad_node_t *xn = p->child[0], *qn = p->child[1];
+       int L = *(int32_t *) p->ptr;
+       int H = qn->d[0], D = qn->d[1];
+       int B = xn->d[0];
+       float scale = 1.0f / sqrtf((float) D);
+       int b, h, l, d;
+
+       if (action == KAD_SYNC_DIM) {
+               if (xn->n_d != 2 || qn->n_d != 2) return -1;
+               if (xn->d[1] != L * D) return -1;
+               p->n_d = 2, p->d[0] = xn->d[0], p->d[1] = H * D;
+       }
+       else if (action == KAD_ALLOC) {
+               if (kad_is_back(xn) || kad_is_back(qn))
+                       p->gtmp = g_realloc(p->gtmp, (size_t) xn->d[0] * H * L * sizeof(float));
+       }
+       else if (action == KAD_FORWARD) {
+               float *alpha_all = (float *) p->gtmp;
+               float *s = (float *) g_malloc(L * sizeof(float));
+               for (b = 0; b < B; ++b) {
+                       const float *xb = &xn->x[(size_t) b * L * D];
+                       for (h = 0; h < H; ++h) {
+                               const float *qh = &qn->x[(size_t) h * D];
+                               float *ob = &p->x[(size_t) b * H * D + (size_t) h * D];
+                               float max, sum;
+                               for (l = 0; l < L; ++l) {
+                                       const float *xl = &xb[(size_t) l * D];
+                                       int nonzero = 0;
+                                       for (d = 0; d < D; ++d)
+                                               if (xl[d] != 0.0f) {
+                                                       nonzero = 1;
+                                                       break;
+                                               }
+                                       s[l] = nonzero ? scale * kad_sdot(D, qh, xl) : -FLT_MAX;
+                               }
+                               for (l = 0, max = -FLT_MAX; l < L; ++l)
+                                       max = max > s[l] ? max : s[l];
+                               for (l = 0, sum = 0.0f; l < L; ++l) {
+                                       /* fully padded sample: keep alpha uniform, output stays 0 */
+                                       s[l] = expf(s[l] - (max == -FLT_MAX ? 0.0f : max));
+                                       sum += s[l];
+                               }
+                               for (l = 0, sum = 1.0f / sum; l < L; ++l) s[l] *= sum;
+                               memset(ob, 0, D * sizeof(float));
+                               for (l = 0; l < L; ++l)
+                                       if (s[l] != 0.0f)
+                                               kad_saxpy(D, s[l], &xb[(size_t) l * D], ob);
+                               if (alpha_all)
+                                       memcpy(&alpha_all[((size_t) b * H + h) * L], s, L * sizeof(float));
+                       }
+               }
+               g_free(s);
+       }
+       else if (action == KAD_BACKWARD && (kad_is_back(xn) || kad_is_back(qn))) {
+               const float *alpha_all = (const float *) p->gtmp;
+               float *dalpha = (float *) g_malloc(2 * L * sizeof(float));
+               float *ds = dalpha + L;
+               for (b = 0; b < B; ++b) {
+                       const float *xb = &xn->x[(size_t) b * L * D];
+                       for (h = 0; h < H; ++h) {
+                               const float *qh = &qn->x[(size_t) h * D];
+                               const float *alpha = &alpha_all[((size_t) b * H + h) * L];
+                               const float *gb = &p->g[(size_t) b * H * D + (size_t) h * D];
+                               float sdot = 0.0f;
+                               for (l = 0; l < L; ++l)
+                                       dalpha[l] = kad_sdot(D, gb, &xb[(size_t) l * D]);
+                               for (l = 0; l < L; ++l)
+                                       sdot += alpha[l] * dalpha[l];
+                               for (l = 0; l < L; ++l)
+                                       ds[l] = alpha[l] * (dalpha[l] - sdot);
+                               if (kad_is_back(qn)) {
+                                       float *qg = &qn->g[(size_t) h * D];
+                                       for (l = 0; l < L; ++l)
+                                               if (ds[l] != 0.0f)
+                                                       kad_saxpy(D, scale * ds[l], &xb[(size_t) l * D], qg);
+                               }
+                               if (kad_is_back(xn)) {
+                                       float *xg = &xn->g[(size_t) b * L * D];
+                                       for (l = 0; l < L; ++l) {
+                                               if (alpha[l] != 0.0f)
+                                                       kad_saxpy(D, alpha[l], gb, &xg[(size_t) l * D]);
+                                               if (ds[l] != 0.0f)
+                                                       kad_saxpy(D, scale * ds[l], qh, &xg[(size_t) l * D]);
+                                       }
+                               }
+                       }
+               }
+               g_free(dalpha);
+       }
+       return 0;
+}
+
 /********** List of operators **********/
 
 kad_op_f kad_op_list[KAD_MAX_OP] = {
@@ -2609,13 +2728,14 @@ kad_op_f kad_op_list[KAD_MAX_OP] = {
        kad_op_sin,           /* 34: sin() */
        kad_op_stack,         /* 35: tf.stack, but on the first axis only */
        kad_op_reverse,       /* 36: tf.reverse, but on one axis only */
-       kad_op_gelu           /* 37: GELU activation */
+       kad_op_gelu,          /* 37: GELU activation */
+       kad_op_attn_pool      /* 38: multi-head attention pooling */
 };
 
 char *kad_op_name[KAD_MAX_OP] = {
        0, "add", "mul", "cmul", "ce_bin_neg", "square", "sigm", "tanh", "relu", "matmul", "avg", "1minus", "select", "ce_multi", "softmax",
        "dropout", "conv2d", "max2d", "conv1d", "max1d", "slice", "max", "ce_bin", "sub", "sample_normal", "reduce_sum", "reduce_mean", "log",
-       "avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse", "gelu"};
+       "avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse", "gelu", "attn_pool"};
 
 /**************************
  *** Debugging routines ***
index 9723201e3cc67716b03b34d13427765c822b6c60..23e96418dc7d9f36996fef48b043d8beb95292e6 100644 (file)
@@ -168,6 +168,12 @@ kad_node_t *kad_conv1d(kad_node_t *x, kad_node_t *w, int stride, int pad);  /* 1
 kad_node_t *kad_max1d(kad_node_t *x, int kernel_size, int stride, int pad); /* 1D max pooling */
 kad_node_t *kad_avg1d(kad_node_t *x, int kernel_size, int stride, int pad); /* 1D average pooling */
 
+/* Multi-head dot-product attention pooling over a zero-padded sequence.
+ * x: (batch, n_words * dim) -- flattened word vectors, all-zero words are padding
+ * q: (n_heads, dim)         -- learned query vectors
+ * output: (batch, n_heads * dim) -- per-head attention-weighted sums */
+kad_node_t *kad_attn_pool(kad_node_t *x, kad_node_t *q, int n_words);
+
 kad_node_t *kad_dropout(kad_node_t *x, kad_node_t *r);                      /* dropout at rate r */
 kad_node_t *kad_sample_normal(kad_node_t *x);                               /* f(x) = x * r, where r is drawn from a standard normal distribution */
 
index a4b220db1d938ab40faeece4238a453285afc52c..7edfa39aa5afc60086f95cce7a13a24cd158080b 100644 (file)
 
 /* Simple macros to define behaviour */
 #define KANN_LAYER_DEF(name) static int lua_kann_layer_##name(lua_State *L)
-#define KANN_LAYER_INTERFACE(name)   \
-       {                                \
-               #name, lua_kann_layer_##name \
-       }
+#define KANN_LAYER_INTERFACE(name) \
+       {                              \
+               #name, lua_kann_layer_##name}
 
 #define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_##name(lua_State *L)
-#define KANN_TRANSFORM_INTERFACE(name)   \
-       {                                    \
-               #name, lua_kann_transform_##name \
-       }
+#define KANN_TRANSFORM_INTERFACE(name) \
+       {                                  \
+               #name, lua_kann_transform_##name}
 
 #define KANN_LOSS_DEF(name) static int lua_kann_loss_##name(lua_State *L)
-#define KANN_LOSS_INTERFACE(name)   \
-       {                               \
-               #name, lua_kann_loss_##name \
-       }
+#define KANN_LOSS_INTERFACE(name) \
+       {                             \
+               #name, lua_kann_loss_##name}
 
 #define KANN_NEW_DEF(name) static int lua_kann_new_##name(lua_State *L)
-#define KANN_NEW_INTERFACE(name)   \
-       {                              \
-               #name, lua_kann_new_##name \
-       }
+#define KANN_NEW_INTERFACE(name) \
+       {                            \
+               #name, lua_kann_new_##name}
 
 
 /*
@@ -71,11 +67,13 @@ KANN_LAYER_DEF(input3d);
 KANN_LAYER_DEF(cost);
 
 static int lua_kann_layer_layerdropout(lua_State *L); /* forward declaration */
+static int lua_kann_layer_attn_pool(lua_State *L);    /* forward declaration */
 
 static luaL_reg rspamd_kann_layers_f[] = {
        KANN_LAYER_INTERFACE(input),
        KANN_LAYER_INTERFACE(dense),
        KANN_LAYER_INTERFACE(layernorm),
+       {"attn_pool", lua_kann_layer_attn_pool},  /* manually registered - extra args */
        {"dropout", lua_kann_layer_layerdropout}, /* manually registered - different naming */
        KANN_LAYER_INTERFACE(rnn),
        KANN_LAYER_INTERFACE(lstm),
@@ -414,6 +412,46 @@ lua_kann_layer_layerdropout(lua_State *L)
        return 1;
 }
 
+/***
+ * @function kann.layer.attn_pool(in, n_words[, n_heads[, flags]])
+ * Creates a multi-head attention pooling layer over a flattened sequence of
+ * zero-padded word vectors. The input dimension must be a multiple of
+ * n_words; the per-word dimension is derived as input_dim / n_words.
+ * Output dimension is n_heads * per-word dimension.
+ * @param {kann_node} in kann node, (batch, n_words * dim)
+ * @param {int} n_words number of word positions in the sequence
+ * @param {int} n_heads number of learned attention queries (default 4)
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_attn_pool(lua_State *L)
+{
+       kad_node_t *in = lua_check_kann_node(L, 1);
+       int n_words = luaL_checkinteger(L, 2);
+       int n_heads = luaL_optinteger(L, 3, 4);
+
+       if (in != NULL && n_words > 0 && n_heads > 0) {
+               kad_node_t *t;
+
+               t = kann_layer_attn_pool(in, n_words, n_heads);
+
+               if (t == NULL) {
+                       return luaL_error(L, "invalid attn_pool: input dimension %d "
+                                                                "is not a multiple of n_words %d",
+                                                         in->n_d == 2 ? in->d[1] : -1, n_words);
+               }
+
+               PROCESS_KAD_FLAGS(t, 4);
+               PUSH_KAD_NODE(t);
+       }
+       else {
+               return luaL_error(L, "invalid arguments, input + n_words + n_heads required");
+       }
+
+       return 1;
+}
+
 /***
  * @function kann.layer.dropout(in [, flags])
  * Creates a normalisation layer