[Feature] Add GELU activation and expose dropout in KANN bindings
author     Vsevolod Stakhov <vsevolod@rspamd.com>
           Tue, 20 Jan 2026 14:20:51 +0000 (14:20 +0000)
committer  Vsevolod Stakhov <vsevolod@rspamd.com>
           Tue, 20 Jan 2026 14:20:51 +0000 (14:20 +0000)
- Implement GELU (Gaussian Error Linear Unit) activation function
  using erf: GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
- Add proper forward and backward passes for GELU
- Register GELU as operation #37 in kad_op_list
- Expose dropout layer to Lua (function existed but wasn't registered)
- Add Lua bindings for rspamd_kann.transform.gelu

GELU often works better than ReLU for transformer-like architectures
and high-dimensional embedding inputs: it is smooth everywhere and,
unlike ReLU, keeps a small nonzero gradient for moderately negative
inputs.
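
For reference, GELU(x) = x * Phi(x), where Phi is the standard normal
CDF, so the forward formula can be spot-checked with a few values. A
minimal standalone sketch using only libm (it does not touch the kann
sources changed below):

    #include <math.h>
    #include <stdio.h>

    /* Spot-check of GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) = x * Phi(x) */
    int main(void)
    {
            double xs[] = {-1.0, 0.0, 1.0};
            for (int i = 0; i < 3; i++) {
                    double x = xs[i];
                    printf("GELU(%+.1f) = %+.4f\n", x,
                           0.5 * x * (1.0 + erf(x / sqrt(2.0))));
            }
            /* Expected output (approximately):
             * GELU(-1.0) = -0.1587
             * GELU(+0.0) = +0.0000
             * GELU(+1.0) = +0.8413 */
            return 0;
    }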

contrib/kann/kautodiff.c
contrib/kann/kautodiff.h
src/lua/lua_kann.c

contrib/kann/kautodiff.c
index 551d5486166611e6345a38d5910e423960d77722..8e20675ad6891245e2a7138a38401761e4a738f2 100644 (file)
@@ -152,6 +152,7 @@ KAD_FUNC_OP1(kad_square, 5)
 KAD_FUNC_OP1(kad_sigm, 6)
 KAD_FUNC_OP1(kad_tanh, 7)
 KAD_FUNC_OP1(kad_relu, 8)
+KAD_FUNC_OP1(kad_gelu, 37)
 KAD_FUNC_OP1(kad_1minus, 11)
 KAD_FUNC_OP1(kad_softmax, 14)
 KAD_FUNC_OP1(kad_stdnorm, 32)
@@ -1918,6 +1919,46 @@ int kad_op_relu(kad_node_t *p, int action)
        return 0;
 }
 
+/* GELU: Gaussian Error Linear Unit
+ * Forward:  GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+ * Backward: GELU'(x) = 0.5 * (1 + erf(x/sqrt(2))) + x * exp(-x^2/2) / sqrt(2*pi)
+ */
+#ifndef M_SQRT1_2
+#define M_SQRT1_2 0.70710678118654752440f /* 1/sqrt(2) */
+#endif
+#ifndef M_2_SQRTPI
+#define M_2_SQRTPI 1.12837916709551257390f /* 2/sqrt(pi) */
+#endif
+
+int kad_op_gelu(kad_node_t *p, int action)
+{
+       int i, n;
+       kad_node_t *q = p->child[0];
+       n = kad_len(q);
+       if (action == KAD_SYNC_DIM) {
+               kad_copy_dim1(p, q);
+       }
+       else if (action == KAD_FORWARD) {
+               for (i = 0; i < n; ++i) {
+                       float x = q->x[i];
+                       /* GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) */
+                       p->x[i] = 0.5f * x * (1.0f + erff(x * (float) M_SQRT1_2));
+               }
+       }
+       else if (action == KAD_BACKWARD && kad_is_back(q)) {
+               for (i = 0; i < n; ++i) {
+                       float x = q->x[i];
+                       /* GELU'(x) = 0.5 * (1 + erf(x/sqrt(2))) + x * exp(-x^2/2) / sqrt(2*pi)
+                        *          = 0.5 * (1 + erf(x/sqrt(2))) + x * 0.5 * M_2_SQRTPI * M_SQRT1_2 * exp(-x^2/2)
+                        */
+                       float cdf = 0.5f * (1.0f + erff(x * (float) M_SQRT1_2));
+                       float pdf = 0.5f * (float) M_2_SQRTPI * (float) M_SQRT1_2 * expf(-0.5f * x * x);
+                       q->g[i] += p->g[i] * (cdf + x * pdf);
+               }
+       }
+       return 0;
+}
+
 int kad_op_sin(kad_node_t *p, int action)
 {
        int i, n;
@@ -2189,7 +2230,7 @@ int kad_op_conv2d(kad_node_t *p, int action) /* in the number-channel-height-wid
                                                        }                                                                                                                    \
                                                        _row_func(_xx, _ww, _yy, w->d[3], p->d[3], aux[1].stride, aux[1].pad[0], (_tmp));                                    \
                                                } /* ~i */                                                                                                               \
-                                       }     /* ~k, c0, c1, n */                                                                                                    \
+                                       } /* ~k, c0, c1, n */                                                                                                        \
        } while (0)
 
 #define conv2d_loop2(_x, _w, _y, _code)                                                                                                                \
@@ -2207,8 +2248,8 @@ int kad_op_conv2d(kad_node_t *p, int action) /* in the number-channel-height-wid
                                                        _xx = x_padded;                                                                                                            \
                                                }                                                                                                                              \
                                                for (j = 0; j < p->d[3]; ++j, _xx += j_skip, ++_yy) _code; /* output and input column */                                       \
-                                       }                                                              /* ~i */                                                            \
-                               }                                                                  /* ~k, c1, n */                                                     \
+                                       } /* ~i */                                                                                                                         \
+                               } /* ~k, c1, n */                                                                                                                      \
        } while (0)
 
        conv_conf_t *aux = (conv_conf_t *) p->ptr;
@@ -2314,7 +2355,7 @@ int kad_op_max2d(kad_node_t *p, int action)
                                                        if (p->x[u + j] < q->x[v])
                                                                p->x[u + j] = q->x[v], f[u + j] = v;
                                } /* ~k */
-                       }     /* ~i */
+                       } /* ~i */
                }
        }
        else if (action == KAD_BACKWARD) {
@@ -2567,13 +2608,14 @@ kad_op_f kad_op_list[KAD_MAX_OP] = {
        kad_op_exp,           /* 33: exp() */
        kad_op_sin,           /* 34: sin() */
        kad_op_stack,         /* 35: tf.stack, but on the first axis only */
-       kad_op_reverse        /* 36: tf.reverse, but on one axis only */
+       kad_op_reverse,       /* 36: tf.reverse, but on one axis only */
+       kad_op_gelu           /* 37: GELU activation */
 };
 
 char *kad_op_name[KAD_MAX_OP] = {
        0, "add", "mul", "cmul", "ce_bin_neg", "square", "sigm", "tanh", "relu", "matmul", "avg", "1minus", "select", "ce_multi", "softmax",
        "dropout", "conv2d", "max2d", "conv1d", "max1d", "slice", "max", "ce_bin", "sub", "sample_normal", "reduce_sum", "reduce_mean", "log",
-       "avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse"};
+       "avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse", "gelu"};
 
 /**************************
  *** Debugging routines ***
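
The backward pass in kad_op_gelu() above can be verified against a
central finite difference. A minimal standalone sketch that restates
the same forward/backward expressions with the constants inlined
(0.70710678 is 1/sqrt(2), 0.3989423 is 1/sqrt(2*pi)); it does not call
into kautodiff:

    #include <math.h>
    #include <stdio.h>

    /* gelu()/gelu_grad() restate the expressions from kad_op_gelu() */
    static float gelu(float x)
    {
            return 0.5f * x * (1.0f + erff(x * 0.70710678f));
    }

    static float gelu_grad(float x)
    {
            float cdf = 0.5f * (1.0f + erff(x * 0.70710678f));
            float pdf = 0.3989423f * expf(-0.5f * x * x);
            return cdf + x * pdf;
    }

    int main(void)
    {
            float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
            float h = 1e-3f;
            for (int i = 0; i < 5; i++) {
                    float x = xs[i];
                    /* central difference approximates the derivative to O(h^2) */
                    float num = (gelu(x + h) - gelu(x - h)) / (2.0f * h);
                    printf("x=%+5.2f analytic=%+.6f numeric=%+.6f\n",
                           x, gelu_grad(x), num);
            }
            return 0;
    }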
contrib/kann/kautodiff.h
index d7e713304149f69af0c9fb605994afbc08d14a2b..9723201e3cc67716b03b34d13427765c822b6c60 100644 (file)
@@ -176,6 +176,7 @@ kad_node_t *kad_square(kad_node_t *x); /* f(x) = x^2                         (el
 kad_node_t *kad_sigm(kad_node_t *x);   /* f(x) = 1/(1+exp(-x))               (element-wise sigmoid) */
 kad_node_t *kad_tanh(kad_node_t *x);   /* f(x) = (1-exp(-2x)) / (1+exp(-2x)) (element-wise tanh) */
 kad_node_t *kad_relu(kad_node_t *x);   /* f(x) = max{0,x}                    (element-wise rectifier, aka ReLU) */
+kad_node_t *kad_gelu(kad_node_t *x);   /* f(x) = 0.5*x*(1+erf(x/sqrt(2)))    (element-wise GELU) */
 kad_node_t *kad_softmax(kad_node_t *x);/* f_i(x_1,...,x_n) = exp(x_i) / \sum_j exp(x_j) (softmax: tf.nn.softmax(x,dim=-1)) */
 kad_node_t *kad_1minus(kad_node_t *x); /* f(x) = 1 - x */
 kad_node_t *kad_exp(kad_node_t *x);    /* f(x) = exp(x) */
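
With this declaration in place, kad_gelu() can be dropped in wherever
kad_relu() was used before. A minimal sketch, assuming the upstream
kann layer API (kann_layer_input, kann_layer_dense, kann_layer_dropout,
kann_layer_cost, kann_new); the layer sizes and dropout rate are
illustrative, not from this commit:

    #include "kann.h"

    /* Illustrative only: a small MLP with a GELU hidden layer and the
     * dropout layer that this commit also exposes to Lua. */
    static kann_t *build_gelu_mlp(int n_in, int n_hidden, int n_out)
    {
            kad_node_t *t;

            t = kann_layer_input(n_in);
            t = kad_gelu(kann_layer_dense(t, n_hidden)); /* GELU where kad_relu() would go */
            t = kann_layer_dropout(t, 0.1f);             /* example rate of 0.1 */
            t = kann_layer_cost(t, n_out, KANN_C_CEB);   /* binary cross-entropy cost */
            return kann_new(t, 0);
    }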
src/lua/lua_kann.c
index eadc8b06cdcd8ad2be6849c60ccf528d11d9e3f1..9772aecd39164cc91b6e24db3cd4192a31ebdb55 100644 (file)
@@ -68,10 +68,13 @@ KANN_LAYER_DEF(conv2d);
 KANN_LAYER_DEF(conv1d);
 KANN_LAYER_DEF(cost);
 
+static int lua_kann_layer_layerdropout(lua_State *L); /* forward declaration */
+
 static luaL_reg rspamd_kann_layers_f[] = {
        KANN_LAYER_INTERFACE(input),
        KANN_LAYER_INTERFACE(dense),
        KANN_LAYER_INTERFACE(layernorm),
+       {"dropout", lua_kann_layer_layerdropout}, /* manually registered - different naming */
        KANN_LAYER_INTERFACE(rnn),
        KANN_LAYER_INTERFACE(lstm),
        KANN_LAYER_INTERFACE(gru),
@@ -94,6 +97,7 @@ KANN_TRANSFORM_DEF(square);
 KANN_TRANSFORM_DEF(sigm);
 KANN_TRANSFORM_DEF(tanh);
 KANN_TRANSFORM_DEF(relu);
+KANN_TRANSFORM_DEF(gelu);
 KANN_TRANSFORM_DEF(softmax);
 KANN_TRANSFORM_DEF(1minus);
 KANN_TRANSFORM_DEF(exp);
@@ -110,6 +114,7 @@ static luaL_reg rspamd_kann_transform_f[] = {
        KANN_TRANSFORM_INTERFACE(sigm),
        KANN_TRANSFORM_INTERFACE(tanh),
        KANN_TRANSFORM_INTERFACE(relu),
+       KANN_TRANSFORM_INTERFACE(gelu),
        KANN_TRANSFORM_INTERFACE(softmax),
        KANN_TRANSFORM_INTERFACE(1minus),
        KANN_TRANSFORM_INTERFACE(exp),
@@ -706,6 +711,7 @@ LUA_UNARY_TRANSFORM_FUNC_IMPL(square)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(sigm)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(tanh)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(relu)
+LUA_UNARY_TRANSFORM_FUNC_IMPL(gelu)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(softmax)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(1minus)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(exp)
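
For orientation, a unary transform binding such as
rspamd_kann.transform.gelu boils down to a small Lua C API wrapper
around the kautodiff constructor. The sketch below is illustrative,
not the actual LUA_UNARY_TRANSFORM_FUNC_IMPL expansion; in particular,
the metatable name KANN_NODE_CLASS is an assumption:

    #include <lua.h>
    #include <lauxlib.h>
    #include "kautodiff.h"

    #define KANN_NODE_CLASS "rspamd{kann_node}" /* assumed metatable name */

    static int
    lua_kann_transform_gelu_sketch(lua_State *L)
    {
            /* fetch the wrapped input node, build a GELU node, wrap and return it */
            kad_node_t **pn = luaL_checkudata(L, 1, KANN_NODE_CLASS);
            kad_node_t **pout = lua_newuserdata(L, sizeof(kad_node_t *));

            *pout = kad_gelu(*pn); /* the graph constructor added by this commit */
            luaL_getmetatable(L, KANN_NODE_CLASS);
            lua_setmetatable(L, -2);
            return 1;
    }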