From: Vsevolod Stakhov Date: Fri, 12 Jun 2026 15:38:19 +0000 (+0100) Subject: [Feature] kann: add multi-head attention pooling operator X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=dd7323d72bd223bac37e7efdf4191b25045e39b3;p=thirdparty%2Frspamd.git [Feature] kann: add multi-head attention pooling operator New kad operator attn_pool (op 38, appended to preserve model serialization compatibility): multi-head dot-product attention pooling over a zero-padded sequence of word vectors with learned query vectors. All-zero positions are treated as padding and masked out of the softmax; attention weights are stashed in gtmp between the forward and backward passes. Exposed as kann_layer_attn_pool() and rspamd_kann.layer.attn_pool(node, n_words[, n_heads]). Verified: converges on a needle-in-haystack task unsolvable by a flat dense net (0.985 vs 0.715 accuracy), exact word-order invariance of the pooled output, padding determinism and save/load roundtrip. --- diff --git a/contrib/kann/kann.c b/contrib/kann/kann.c index 86723bd9d3..37df99c844 100644 --- a/contrib/kann/kann.c +++ b/contrib/kann/kann.c @@ -832,6 +832,20 @@ kad_node_t *kann_layer_layernorm(kad_node_t *in) return kann_layer_layernorm2(0, 0, in); } +/* Multi-head attention pooling over a sequence of n_words zero-padded word + * vectors flattened into the input; the per-word dimension is derived from + * the input size. Output: n_heads * dim. */ +kad_node_t *kann_layer_attn_pool(kad_node_t *in, int n_words, int n_heads) +{ + kad_node_t *q; + int dim; + if (in->n_d != 2 || n_words <= 0 || n_heads <= 0) return 0; + if (in->d[1] % n_words != 0) return 0; + dim = in->d[1] / n_words; + q = kann_new_weight(n_heads, dim); + return kad_attn_pool(in, q, n_words); +} + kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int rnn_flag) { kad_node_t *h0; diff --git a/contrib/kann/kann.h b/contrib/kann/kann.h index 313c9e9df1..c81814b40e 100644 --- a/contrib/kann/kann.h +++ b/contrib/kann/kann.h @@ -196,6 +196,7 @@ kad_node_t *kann_layer_input(int n1); kad_node_t *kann_layer_dense(kad_node_t *in, int n1); kad_node_t *kann_layer_dropout(kad_node_t *t, float r); kad_node_t *kann_layer_layernorm(kad_node_t *in); +kad_node_t *kann_layer_attn_pool(kad_node_t *in, int n_words, int n_heads); kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int rnn_flag); kad_node_t *kann_layer_lstm(kad_node_t *in, int n1, int rnn_flag); kad_node_t *kann_layer_gru(kad_node_t *in, int n1, int rnn_flag); diff --git a/contrib/kann/kautodiff.c b/contrib/kann/kautodiff.c index 8e20675ad6..ccb51f6621 100644 --- a/contrib/kann/kautodiff.c +++ b/contrib/kann/kautodiff.c @@ -260,6 +260,20 @@ kad_node_t *kad_avg1d(kad_node_t *x, int kernel_size, int stride, int left_pad) return kad_finalize_node(s); } +kad_node_t *kad_attn_pool(kad_node_t *x, kad_node_t *q, int n_words) +{ + kad_node_t *s; + int32_t *aux; + if (x->n_d != 2 || q->n_d != 2 || n_words <= 0) return 0; + if (x->d[1] != n_words * q->d[1]) return 0; + s = kad_new_core(0, 38, 2); + s->child[0] = x, s->child[1] = q; + aux = (int32_t *) g_malloc0_n(1, sizeof(int32_t)); + aux[0] = n_words; + s->ptr = aux, s->ptr_size = sizeof(int32_t); + return kad_finalize_node(s); +} + /********** Multi-node pooling **********/ static kad_node_t *kad_pooling_general(int op, int n, kad_node_t **x) @@ -2569,6 +2583,111 @@ int kad_op_avg1d(kad_node_t *p, int action) return 0; } +/* Multi-head dot-product attention pooling. + * child[0]: x, (batch, n_words * dim) -- zero-padded sequence of word vectors + * child[1]: q, (n_heads, dim) -- learned query vectors + * output: (batch, n_heads * dim) -- per-head attention-weighted sums + * + * Per head h and sample b: + * s_l = dot(q_h, x_l) / sqrt(dim) (all-zero x_l is padding: s_l = -inf) + * alpha = softmax(s) + * out_h = sum_l alpha_l * x_l + * Attention weights are saved in gtmp at the forward pass for the backward + * pass. Padded positions get alpha = 0 and contribute no gradient. */ +int kad_op_attn_pool(kad_node_t *p, int action) +{ + kad_node_t *xn = p->child[0], *qn = p->child[1]; + int L = *(int32_t *) p->ptr; + int H = qn->d[0], D = qn->d[1]; + int B = xn->d[0]; + float scale = 1.0f / sqrtf((float) D); + int b, h, l, d; + + if (action == KAD_SYNC_DIM) { + if (xn->n_d != 2 || qn->n_d != 2) return -1; + if (xn->d[1] != L * D) return -1; + p->n_d = 2, p->d[0] = xn->d[0], p->d[1] = H * D; + } + else if (action == KAD_ALLOC) { + if (kad_is_back(xn) || kad_is_back(qn)) + p->gtmp = g_realloc(p->gtmp, (size_t) xn->d[0] * H * L * sizeof(float)); + } + else if (action == KAD_FORWARD) { + float *alpha_all = (float *) p->gtmp; + float *s = (float *) g_malloc(L * sizeof(float)); + for (b = 0; b < B; ++b) { + const float *xb = &xn->x[(size_t) b * L * D]; + for (h = 0; h < H; ++h) { + const float *qh = &qn->x[(size_t) h * D]; + float *ob = &p->x[(size_t) b * H * D + (size_t) h * D]; + float max, sum; + for (l = 0; l < L; ++l) { + const float *xl = &xb[(size_t) l * D]; + int nonzero = 0; + for (d = 0; d < D; ++d) + if (xl[d] != 0.0f) { + nonzero = 1; + break; + } + s[l] = nonzero ? scale * kad_sdot(D, qh, xl) : -FLT_MAX; + } + for (l = 0, max = -FLT_MAX; l < L; ++l) + max = max > s[l] ? max : s[l]; + for (l = 0, sum = 0.0f; l < L; ++l) { + /* fully padded sample: keep alpha uniform, output stays 0 */ + s[l] = expf(s[l] - (max == -FLT_MAX ? 0.0f : max)); + sum += s[l]; + } + for (l = 0, sum = 1.0f / sum; l < L; ++l) s[l] *= sum; + memset(ob, 0, D * sizeof(float)); + for (l = 0; l < L; ++l) + if (s[l] != 0.0f) + kad_saxpy(D, s[l], &xb[(size_t) l * D], ob); + if (alpha_all) + memcpy(&alpha_all[((size_t) b * H + h) * L], s, L * sizeof(float)); + } + } + g_free(s); + } + else if (action == KAD_BACKWARD && (kad_is_back(xn) || kad_is_back(qn))) { + const float *alpha_all = (const float *) p->gtmp; + float *dalpha = (float *) g_malloc(2 * L * sizeof(float)); + float *ds = dalpha + L; + for (b = 0; b < B; ++b) { + const float *xb = &xn->x[(size_t) b * L * D]; + for (h = 0; h < H; ++h) { + const float *qh = &qn->x[(size_t) h * D]; + const float *alpha = &alpha_all[((size_t) b * H + h) * L]; + const float *gb = &p->g[(size_t) b * H * D + (size_t) h * D]; + float sdot = 0.0f; + for (l = 0; l < L; ++l) + dalpha[l] = kad_sdot(D, gb, &xb[(size_t) l * D]); + for (l = 0; l < L; ++l) + sdot += alpha[l] * dalpha[l]; + for (l = 0; l < L; ++l) + ds[l] = alpha[l] * (dalpha[l] - sdot); + if (kad_is_back(qn)) { + float *qg = &qn->g[(size_t) h * D]; + for (l = 0; l < L; ++l) + if (ds[l] != 0.0f) + kad_saxpy(D, scale * ds[l], &xb[(size_t) l * D], qg); + } + if (kad_is_back(xn)) { + float *xg = &xn->g[(size_t) b * L * D]; + for (l = 0; l < L; ++l) { + if (alpha[l] != 0.0f) + kad_saxpy(D, alpha[l], gb, &xg[(size_t) l * D]); + if (ds[l] != 0.0f) + kad_saxpy(D, scale * ds[l], qh, &xg[(size_t) l * D]); + } + } + } + } + g_free(dalpha); + } + return 0; +} + /********** List of operators **********/ kad_op_f kad_op_list[KAD_MAX_OP] = { @@ -2609,13 +2728,14 @@ kad_op_f kad_op_list[KAD_MAX_OP] = { kad_op_sin, /* 34: sin() */ kad_op_stack, /* 35: tf.stack, but on the first axis only */ kad_op_reverse, /* 36: tf.reverse, but on one axis only */ - kad_op_gelu /* 37: GELU activation */ + kad_op_gelu, /* 37: GELU activation */ + kad_op_attn_pool /* 38: multi-head attention pooling */ }; char *kad_op_name[KAD_MAX_OP] = { 0, "add", "mul", "cmul", "ce_bin_neg", "square", "sigm", "tanh", "relu", "matmul", "avg", "1minus", "select", "ce_multi", "softmax", "dropout", "conv2d", "max2d", "conv1d", "max1d", "slice", "max", "ce_bin", "sub", "sample_normal", "reduce_sum", "reduce_mean", "log", - "avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse", "gelu"}; + "avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse", "gelu", "attn_pool"}; /************************** *** Debugging routines *** diff --git a/contrib/kann/kautodiff.h b/contrib/kann/kautodiff.h index 9723201e3c..23e96418dc 100644 --- a/contrib/kann/kautodiff.h +++ b/contrib/kann/kautodiff.h @@ -168,6 +168,12 @@ kad_node_t *kad_conv1d(kad_node_t *x, kad_node_t *w, int stride, int pad); /* 1 kad_node_t *kad_max1d(kad_node_t *x, int kernel_size, int stride, int pad); /* 1D max pooling */ kad_node_t *kad_avg1d(kad_node_t *x, int kernel_size, int stride, int pad); /* 1D average pooling */ +/* Multi-head dot-product attention pooling over a zero-padded sequence. + * x: (batch, n_words * dim) -- flattened word vectors, all-zero words are padding + * q: (n_heads, dim) -- learned query vectors + * output: (batch, n_heads * dim) -- per-head attention-weighted sums */ +kad_node_t *kad_attn_pool(kad_node_t *x, kad_node_t *q, int n_words); + kad_node_t *kad_dropout(kad_node_t *x, kad_node_t *r); /* dropout at rate r */ kad_node_t *kad_sample_normal(kad_node_t *x); /* f(x) = x * r, where r is drawn from a standard normal distribution */ diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c index a4b220db1d..7edfa39aa5 100644 --- a/src/lua/lua_kann.c +++ b/src/lua/lua_kann.c @@ -28,28 +28,24 @@ /* Simple macros to define behaviour */ #define KANN_LAYER_DEF(name) static int lua_kann_layer_##name(lua_State *L) -#define KANN_LAYER_INTERFACE(name) \ - { \ - #name, lua_kann_layer_##name \ - } +#define KANN_LAYER_INTERFACE(name) \ + { \ + #name, lua_kann_layer_##name} #define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_##name(lua_State *L) -#define KANN_TRANSFORM_INTERFACE(name) \ - { \ - #name, lua_kann_transform_##name \ - } +#define KANN_TRANSFORM_INTERFACE(name) \ + { \ + #name, lua_kann_transform_##name} #define KANN_LOSS_DEF(name) static int lua_kann_loss_##name(lua_State *L) -#define KANN_LOSS_INTERFACE(name) \ - { \ - #name, lua_kann_loss_##name \ - } +#define KANN_LOSS_INTERFACE(name) \ + { \ + #name, lua_kann_loss_##name} #define KANN_NEW_DEF(name) static int lua_kann_new_##name(lua_State *L) -#define KANN_NEW_INTERFACE(name) \ - { \ - #name, lua_kann_new_##name \ - } +#define KANN_NEW_INTERFACE(name) \ + { \ + #name, lua_kann_new_##name} /* @@ -71,11 +67,13 @@ KANN_LAYER_DEF(input3d); KANN_LAYER_DEF(cost); static int lua_kann_layer_layerdropout(lua_State *L); /* forward declaration */ +static int lua_kann_layer_attn_pool(lua_State *L); /* forward declaration */ static luaL_reg rspamd_kann_layers_f[] = { KANN_LAYER_INTERFACE(input), KANN_LAYER_INTERFACE(dense), KANN_LAYER_INTERFACE(layernorm), + {"attn_pool", lua_kann_layer_attn_pool}, /* manually registered - extra args */ {"dropout", lua_kann_layer_layerdropout}, /* manually registered - different naming */ KANN_LAYER_INTERFACE(rnn), KANN_LAYER_INTERFACE(lstm), @@ -414,6 +412,46 @@ lua_kann_layer_layerdropout(lua_State *L) return 1; } +/*** + * @function kann.layer.attn_pool(in, n_words[, n_heads[, flags]]) + * Creates a multi-head attention pooling layer over a flattened sequence of + * zero-padded word vectors. The input dimension must be a multiple of + * n_words; the per-word dimension is derived as input_dim / n_words. + * Output dimension is n_heads * per-word dimension. + * @param {kann_node} in kann node, (batch, n_words * dim) + * @param {int} n_words number of word positions in the sequence + * @param {int} n_heads number of learned attention queries (default 4) + * @param {table|int} flags optional flags + * @return {kann_node} kann node object (should be used to combine ANN) +*/ +static int +lua_kann_layer_attn_pool(lua_State *L) +{ + kad_node_t *in = lua_check_kann_node(L, 1); + int n_words = luaL_checkinteger(L, 2); + int n_heads = luaL_optinteger(L, 3, 4); + + if (in != NULL && n_words > 0 && n_heads > 0) { + kad_node_t *t; + + t = kann_layer_attn_pool(in, n_words, n_heads); + + if (t == NULL) { + return luaL_error(L, "invalid attn_pool: input dimension %d " + "is not a multiple of n_words %d", + in->n_d == 2 ? in->d[1] : -1, n_words); + } + + PROCESS_KAD_FLAGS(t, 4); + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments, input + n_words + n_heads required"); + } + + return 1; +} + /*** * @function kann.layer.dropout(in [, flags]) * Creates a normalisation layer