return kad_finalize_node(s);
}
+/* Create a multi-head attention pooling node (operator index 38 in kad_op_list).
+ *   x:       2-D node; its second dimension must equal n_words * q->d[1]
+ *   q:       2-D node holding the query vectors
+ *   n_words: number of word positions in x; must be positive
+ * Returns 0 when the arguments fail these checks, otherwise whatever
+ * kad_finalize_node() returns. n_words is stored in the node's ptr as a
+ * single int32_t so the operator can read it back. */
+kad_node_t *kad_attn_pool(kad_node_t *x, kad_node_t *q, int n_words)
+{
+	kad_node_t *s;
+	int32_t *aux;
+	if (x->n_d != 2 || q->n_d != 2 || n_words <= 0) return 0;
+	/* widen before multiplying: n_words is caller-supplied and the plain
+	 * int product could overflow (undefined behaviour) */
+	if ((int64_t) x->d[1] != (int64_t) n_words * q->d[1]) return 0;
+	s = kad_new_core(0, 38, 2);
+	s->child[0] = x, s->child[1] = q;
+	aux = (int32_t *) g_malloc0_n(1, sizeof(int32_t));
+	aux[0] = n_words;
+	s->ptr = aux, s->ptr_size = sizeof(int32_t);
+	return kad_finalize_node(s);
+}
+
/********** Multi-node pooling **********/
static kad_node_t *kad_pooling_general(int op, int n, kad_node_t **x)
return 0;
}
+/* Multi-head dot-product attention pooling.
+ *   child[0]: x, (batch, n_words * dim)  -- zero-padded sequence of word vectors
+ *   child[1]: q, (n_heads, dim)          -- learned query vectors
+ *   output:      (batch, n_heads * dim)  -- per-head attention-weighted sums
+ *
+ * Per head h and sample b:
+ *   s_l   = dot(q_h, x_l) / sqrt(dim)    (all-zero x_l is padding: s_l = -FLT_MAX)
+ *   alpha = softmax(s)
+ *   out_h = sum_l alpha_l * x_l
+ * Attention weights are saved in gtmp at the forward pass for the backward
+ * pass. Padded positions get alpha = 0 and contribute no gradient, except
+ * when every position of a sample is padding: alpha is then uniform
+ * (1 / n_words) and the output is all zeros.
+ * Returns -1 from KAD_SYNC_DIM on a shape mismatch, 0 otherwise. */
+int kad_op_attn_pool(kad_node_t *p, int action)
+{
+	kad_node_t *xn = p->child[0], *qn = p->child[1];
+	int L = *(int32_t *) p->ptr; /* n_words, as stored in ptr by kad_attn_pool() */
+	int H = qn->d[0], D = qn->d[1];
+	int B = xn->d[0];
+	float scale = 1.0f / sqrtf((float) D);
+	int b, h, l, d;
+
+	if (action == KAD_SYNC_DIM) {
+		if (xn->n_d != 2 || qn->n_d != 2) return -1;
+		/* compare in 64 bits so that L * D cannot overflow int */
+		if ((int64_t) xn->d[1] != (int64_t) L * D) return -1;
+		p->n_d = 2, p->d[0] = xn->d[0], p->d[1] = H * D;
+	}
+	else if (action == KAD_ALLOC) {
+		/* the saved attention weights are only read by the backward pass */
+		if (kad_is_back(xn) || kad_is_back(qn))
+			p->gtmp = g_realloc(p->gtmp, (size_t) xn->d[0] * H * L * sizeof(float));
+	}
+	else if (action == KAD_FORWARD) {
+		float *alpha_all = (float *) p->gtmp;
+		float *s = (float *) g_malloc(L * sizeof(float)); /* scores, then weights, of one (b, h) pair */
+		for (b = 0; b < B; ++b) {
+			const float *xb = &xn->x[(size_t) b * L * D];
+			for (h = 0; h < H; ++h) {
+				const float *qh = &qn->x[(size_t) h * D];
+				float *ob = &p->x[(size_t) b * H * D + (size_t) h * D];
+				float max, sum;
+				for (l = 0; l < L; ++l) {
+					const float *xl = &xb[(size_t) l * D];
+					int nonzero = 0;
+					/* a word vector that is exactly zero everywhere is padding */
+					for (d = 0; d < D; ++d)
+						if (xl[d] != 0.0f) {
+							nonzero = 1;
+							break;
+						}
+					s[l] = nonzero ? scale * kad_sdot(D, qh, xl) : -FLT_MAX;
+				}
+				for (l = 0, max = -FLT_MAX; l < L; ++l)
+					max = max > s[l] ? max : s[l];
+				for (l = 0, sum = 0.0f; l < L; ++l) {
+					/* Fully padded sample (max is still the sentinel): keep alpha
+					 * uniform, output stays 0. expf(-FLT_MAX) underflows to 0, which
+					 * would make sum == 0 and turn every weight into 0 * inf = NaN,
+					 * so assign weight 1 directly in that case. */
+					s[l] = max == -FLT_MAX ? 1.0f : expf(s[l] - max);
+					sum += s[l];
+				}
+				for (l = 0, sum = 1.0f / sum; l < L; ++l) s[l] *= sum;
+				memset(ob, 0, D * sizeof(float));
+				for (l = 0; l < L; ++l)
+					if (s[l] != 0.0f)
+						kad_saxpy(D, s[l], &xb[(size_t) l * D], ob);
+				/* gtmp is only allocated when a gradient is required */
+				if (alpha_all)
+					memcpy(&alpha_all[((size_t) b * H + h) * L], s, L * sizeof(float));
+			}
+		}
+		g_free(s);
+	}
+	else if (action == KAD_BACKWARD && (kad_is_back(xn) || kad_is_back(qn))) {
+		const float *alpha_all = (const float *) p->gtmp;
+		/* one allocation, two halves: dalpha[0..L) and ds[0..L) */
+		float *dalpha = (float *) g_malloc(2 * L * sizeof(float));
+		float *ds = dalpha + L;
+		for (b = 0; b < B; ++b) {
+			const float *xb = &xn->x[(size_t) b * L * D];
+			for (h = 0; h < H; ++h) {
+				const float *qh = &qn->x[(size_t) h * D];
+				const float *alpha = &alpha_all[((size_t) b * H + h) * L];
+				const float *gb = &p->g[(size_t) b * H * D + (size_t) h * D];
+				float sdot = 0.0f;
+				/* dL/dalpha_l = dot(g, x_l) */
+				for (l = 0; l < L; ++l)
+					dalpha[l] = kad_sdot(D, gb, &xb[(size_t) l * D]);
+				for (l = 0; l < L; ++l)
+					sdot += alpha[l] * dalpha[l];
+				/* softmax backward: dL/ds_l = alpha_l * (dL/dalpha_l - sum_k alpha_k * dL/dalpha_k) */
+				for (l = 0; l < L; ++l)
+					ds[l] = alpha[l] * (dalpha[l] - sdot);
+				if (kad_is_back(qn)) {
+					float *qg = &qn->g[(size_t) h * D];
+					for (l = 0; l < L; ++l)
+						if (ds[l] != 0.0f)
+							kad_saxpy(D, scale * ds[l], &xb[(size_t) l * D], qg);
+				}
+				if (kad_is_back(xn)) {
+					float *xg = &xn->g[(size_t) b * L * D];
+					for (l = 0; l < L; ++l) {
+						/* direct path through the weighted sum */
+						if (alpha[l] != 0.0f)
+							kad_saxpy(D, alpha[l], gb, &xg[(size_t) l * D]);
+						/* indirect path through the score s_l */
+						if (ds[l] != 0.0f)
+							kad_saxpy(D, scale * ds[l], qh, &xg[(size_t) l * D]);
+					}
+				}
+			}
+		}
+		g_free(dalpha);
+	}
+	return 0;
+}
+
/********** List of operators **********/
kad_op_f kad_op_list[KAD_MAX_OP] = {
kad_op_sin, /* 34: sin() */
kad_op_stack, /* 35: tf.stack, but on the first axis only */
kad_op_reverse, /* 36: tf.reverse, but on one axis only */
- kad_op_gelu /* 37: GELU activation */
+ kad_op_gelu, /* 37: GELU activation */
+ kad_op_attn_pool /* 38: multi-head attention pooling */
};
char *kad_op_name[KAD_MAX_OP] = {
0, "add", "mul", "cmul", "ce_bin_neg", "square", "sigm", "tanh", "relu", "matmul", "avg", "1minus", "select", "ce_multi", "softmax",
"dropout", "conv2d", "max2d", "conv1d", "max1d", "slice", "max", "ce_bin", "sub", "sample_normal", "reduce_sum", "reduce_mean", "log",
- "avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse", "gelu"};
+ "avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse", "gelu", "attn_pool"};
/**************************
*** Debugging routines ***
/* Simple macros to define behaviour */
#define KANN_LAYER_DEF(name) static int lua_kann_layer_##name(lua_State *L)
-#define KANN_LAYER_INTERFACE(name) \
- { \
- #name, lua_kann_layer_##name \
- }
+#define KANN_LAYER_INTERFACE(name) \
+ { \
+ #name, lua_kann_layer_##name}
#define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_##name(lua_State *L)
-#define KANN_TRANSFORM_INTERFACE(name) \
- { \
- #name, lua_kann_transform_##name \
- }
+#define KANN_TRANSFORM_INTERFACE(name) \
+ { \
+ #name, lua_kann_transform_##name}
#define KANN_LOSS_DEF(name) static int lua_kann_loss_##name(lua_State *L)
-#define KANN_LOSS_INTERFACE(name) \
- { \
- #name, lua_kann_loss_##name \
- }
+#define KANN_LOSS_INTERFACE(name) \
+ { \
+ #name, lua_kann_loss_##name}
#define KANN_NEW_DEF(name) static int lua_kann_new_##name(lua_State *L)
-#define KANN_NEW_INTERFACE(name) \
- { \
- #name, lua_kann_new_##name \
- }
+#define KANN_NEW_INTERFACE(name) \
+ { \
+ #name, lua_kann_new_##name}
/*
KANN_LAYER_DEF(cost);
static int lua_kann_layer_layerdropout(lua_State *L); /* forward declaration */
+static int lua_kann_layer_attn_pool(lua_State *L); /* forward declaration */
static luaL_reg rspamd_kann_layers_f[] = {
KANN_LAYER_INTERFACE(input),
KANN_LAYER_INTERFACE(dense),
KANN_LAYER_INTERFACE(layernorm),
+ {"attn_pool", lua_kann_layer_attn_pool}, /* manually registered - extra args */
{"dropout", lua_kann_layer_layerdropout}, /* manually registered - different naming */
KANN_LAYER_INTERFACE(rnn),
KANN_LAYER_INTERFACE(lstm),
return 1;
}
+/***
+ * @function kann.layer.attn_pool(in, n_words[, n_heads[, flags]])
+ * Builds a multi-head attention pooling layer on top of a flattened,
+ * zero-padded sequence of word vectors. The input dimension has to be
+ * divisible by n_words; the per-word dimension is input_dim / n_words and
+ * the output dimension is n_heads * per-word dimension.
+ * @param {kann_node} in kann node, (batch, n_words * dim)
+ * @param {int} n_words number of word positions in the sequence
+ * @param {int} n_heads number of learned attention queries (default 4)
+ * @param {table|int} flags optional flags
+ * @return {kann_node} kann node object (should be used to combine ANN)
+*/
+static int
+lua_kann_layer_attn_pool(lua_State *L)
+{
+	kad_node_t *src = lua_check_kann_node(L, 1);
+	int word_count = luaL_checkinteger(L, 2);
+	int head_count = luaL_optinteger(L, 3, 4);
+	kad_node_t *pooled;
+
+	/* Guard: an input node and strictly positive counts are mandatory */
+	if (src == NULL || word_count <= 0 || head_count <= 0) {
+		return luaL_error(L, "invalid arguments, input + n_words + n_heads required");
+	}
+
+	pooled = kann_layer_attn_pool(src, word_count, head_count);
+
+	if (pooled == NULL) {
+		/* report -1 as the dimension when the input is not 2-D */
+		int src_dim = -1;
+
+		if (src->n_d == 2) {
+			src_dim = src->d[1];
+		}
+
+		return luaL_error(L, "invalid attn_pool: input dimension %d "
+							 "is not a multiple of n_words %d",
+						  src_dim, word_count);
+	}
+
+	/* optional flags are the 4th Lua argument */
+	PROCESS_KAD_FLAGS(pooled, 4);
+	PUSH_KAD_NODE(pooled);
+
+	return 1;
+}
+
/***
* @function kann.layer.dropout(in [, flags])
* Creates a normalisation layer