From: Vsevolod Stakhov
Date: Tue, 20 Jan 2026 14:20:51 +0000 (+0000)
Subject: [Feature] Add GELU activation and expose dropout in KANN bindings
X-Git-Tag: 4.0.0~179^2~12
X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=24fcf1f82beb35638e302aac080a1db3e43b14cf;p=thirdparty%2Frspamd.git

[Feature] Add GELU activation and expose dropout in KANN bindings

- Implement GELU (Gaussian Error Linear Unit) activation function using erf:
  GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
- Add proper forward and backward passes for GELU
- Register GELU as operation #37 in kad_op_list
- Expose dropout layer to Lua (function existed but wasn't registered)
- Add Lua bindings for rspamd_kann.transform.gelu

GELU is often better than ReLU for transformer-like architectures and
high-dimensional embedding inputs.
---

diff --git a/contrib/kann/kautodiff.c b/contrib/kann/kautodiff.c
index 551d548616..8e20675ad6 100644
--- a/contrib/kann/kautodiff.c
+++ b/contrib/kann/kautodiff.c
@@ -152,6 +152,7 @@ KAD_FUNC_OP1(kad_square, 5)
 KAD_FUNC_OP1(kad_sigm, 6)
 KAD_FUNC_OP1(kad_tanh, 7)
 KAD_FUNC_OP1(kad_relu, 8)
+KAD_FUNC_OP1(kad_gelu, 37)
 KAD_FUNC_OP1(kad_1minus, 11)
 KAD_FUNC_OP1(kad_softmax, 14)
 KAD_FUNC_OP1(kad_stdnorm, 32)
@@ -1918,6 +1919,46 @@ int kad_op_relu(kad_node_t *p, int action)
 	return 0;
 }
 
+/* GELU: Gaussian Error Linear Unit
+ * Forward: GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+ * Backward: GELU'(x) = 0.5 * (1 + erf(x/sqrt(2))) + x * exp(-x^2/2) / sqrt(2*pi)
+ */
+#ifndef M_SQRT1_2
+#define M_SQRT1_2 0.70710678118654752440f /* 1/sqrt(2) */
+#endif
+#ifndef M_2_SQRTPI
+#define M_2_SQRTPI 1.12837916709551257390f /* 2/sqrt(pi) */
+#endif
+
+int kad_op_gelu(kad_node_t *p, int action)
+{
+	int i, n;
+	kad_node_t *q = p->child[0];
+	n = kad_len(q);
+	if (action == KAD_SYNC_DIM) {
+		kad_copy_dim1(p, q);
+	}
+	else if (action == KAD_FORWARD) {
+		for (i = 0; i < n; ++i) {
+			float x = q->x[i];
+			/* GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) */
+			p->x[i] = 0.5f * x * (1.0f + erff(x * (float) M_SQRT1_2));
+		}
+	}
+	else if (action == KAD_BACKWARD && kad_is_back(q)) {
+		for (i = 0; i < n; ++i) {
+			float x = q->x[i];
+			/* GELU'(x) = 0.5 * (1 + erf(x/sqrt(2))) + x * exp(-x^2/2) / sqrt(2*pi)
+			 *          = 0.5 * (1 + erf(x/sqrt(2))) + x * 0.5 * M_2_SQRTPI * M_SQRT1_2 * exp(-x^2/2)
+			 */
+			float cdf = 0.5f * (1.0f + erff(x * (float) M_SQRT1_2));
+			float pdf = 0.5f * (float) M_2_SQRTPI * (float) M_SQRT1_2 * expf(-0.5f * x * x);
+			q->g[i] += p->g[i] * (cdf + x * pdf);
+		}
+	}
+	return 0;
+}
+
 int kad_op_sin(kad_node_t *p, int action)
 {
 	int i, n;
@@ -2189,7 +2230,7 @@ int kad_op_conv2d(kad_node_t *p, int action) /* in the number-channel-height-wid
 			} \
 			_row_func(_xx, _ww, _yy, w->d[3], p->d[3], aux[1].stride, aux[1].pad[0], (_tmp)); \
 		} /* ~i */ \
-	} /* ~k, c0, c1, n */ \
+	}     /* ~k, c0, c1, n */ \
 	} while (0)
 
 #define conv2d_loop2(_x, _w, _y, _code) \
@@ -2207,8 +2248,8 @@ int kad_op_conv2d(kad_node_t *p, int action) /* in the number-channel-height-wid
 				_xx = x_padded; \
 			} \
 			for (j = 0; j < p->d[3]; ++j, _xx += j_skip, ++_yy) _code; /* output and input column */ \
-		} /* ~i */ \
-	} /* ~k, c1, n */ \
+		}     /* ~i */ \
+	}     /* ~k, c1, n */ \
 	} while (0)
 
 	conv_conf_t *aux = (conv_conf_t *) p->ptr;
@@ -2314,7 +2355,7 @@ int kad_op_max2d(kad_node_t *p, int action)
 					if (p->x[u + j] < q->x[v]) p->x[u + j] = q->x[v], f[u + j] = v;
 				} /* ~k */
-			} /* ~i */
+			}     /* ~i */
 		}
 	}
 	else if (action == KAD_BACKWARD) {
@@ -2567,13 +2608,14 @@ kad_op_f kad_op_list[KAD_MAX_OP] = {
 	kad_op_exp,     /* 33: exp() */
 	kad_op_sin,     /* 34: sin() */
 	kad_op_stack,   /* 35: tf.stack, but on the first axis only */
-	kad_op_reverse  /* 36: tf.reverse, but on one axis only */
+	kad_op_reverse, /* 36: tf.reverse, but on one axis only */
+	kad_op_gelu     /* 37: GELU activation */
 };
 
 char *kad_op_name[KAD_MAX_OP] = {
 	0, "add", "mul", "cmul", "ce_bin_neg", "square", "sigm", "tanh", "relu", "matmul", "avg", "1minus", "select", "ce_multi", "softmax",
 	"dropout", "conv2d", "max2d", "conv1d", "max1d", "slice", "max", "ce_bin", "sub", "sample_normal", "reduce_sum", "reduce_mean", "log",
-	"avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse"};
+	"avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse", "gelu"};
 
 /**************************
  *** Debugging routines ***
diff --git a/contrib/kann/kautodiff.h b/contrib/kann/kautodiff.h
index d7e7133041..9723201e3c 100644
--- a/contrib/kann/kautodiff.h
+++ b/contrib/kann/kautodiff.h
@@ -176,6 +176,7 @@ kad_node_t *kad_square(kad_node_t *x); /* f(x) = x^2 (el
 kad_node_t *kad_sigm(kad_node_t *x);   /* f(x) = 1/(1+exp(-x)) (element-wise sigmoid) */
 kad_node_t *kad_tanh(kad_node_t *x);   /* f(x) = (1-exp(-2x)) / (1+exp(-2x)) (element-wise tanh) */
 kad_node_t *kad_relu(kad_node_t *x);   /* f(x) = max{0,x} (element-wise rectifier, aka ReLU) */
+kad_node_t *kad_gelu(kad_node_t *x);   /* f(x) = 0.5*x*(1+erf(x/sqrt(2))) (element-wise GELU) */
 kad_node_t *kad_softmax(kad_node_t *x);/* f_i(x_1,...,x_n) = exp(x_i) / \sum_j exp(x_j) (softmax: tf.nn.softmax(x,dim=-1)) */
 kad_node_t *kad_1minus(kad_node_t *x); /* f(x) = 1 - x */
 kad_node_t *kad_exp(kad_node_t *x);    /* f(x) = exp(x) */
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index eadc8b06cd..9772aecd39 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -68,10 +68,13 @@ KANN_LAYER_DEF(conv2d);
 KANN_LAYER_DEF(conv1d);
 KANN_LAYER_DEF(cost);
 
+static int lua_kann_layer_layerdropout(lua_State *L); /* forward declaration */
+
 static luaL_reg rspamd_kann_layers_f[] = {
 	KANN_LAYER_INTERFACE(input),
 	KANN_LAYER_INTERFACE(dense),
 	KANN_LAYER_INTERFACE(layernorm),
+	{"dropout", lua_kann_layer_layerdropout}, /* manually registered - different naming */
 	KANN_LAYER_INTERFACE(rnn),
 	KANN_LAYER_INTERFACE(lstm),
 	KANN_LAYER_INTERFACE(gru),
@@ -94,6 +97,7 @@ KANN_TRANSFORM_DEF(square);
 KANN_TRANSFORM_DEF(sigm);
 KANN_TRANSFORM_DEF(tanh);
 KANN_TRANSFORM_DEF(relu);
+KANN_TRANSFORM_DEF(gelu);
 KANN_TRANSFORM_DEF(softmax);
 KANN_TRANSFORM_DEF(1minus);
 KANN_TRANSFORM_DEF(exp);
@@ -110,6 +114,7 @@ static luaL_reg rspamd_kann_transform_f[] = {
 	KANN_TRANSFORM_INTERFACE(sigm),
 	KANN_TRANSFORM_INTERFACE(tanh),
 	KANN_TRANSFORM_INTERFACE(relu),
+	KANN_TRANSFORM_INTERFACE(gelu),
 	KANN_TRANSFORM_INTERFACE(softmax),
 	KANN_TRANSFORM_INTERFACE(1minus),
 	KANN_TRANSFORM_INTERFACE(exp),
@@ -706,6 +711,7 @@ LUA_UNARY_TRANSFORM_FUNC_IMPL(square)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(sigm)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(tanh)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(relu)
+LUA_UNARY_TRANSFORM_FUNC_IMPL(gelu)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(softmax)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(1minus)
 LUA_UNARY_TRANSFORM_FUNC_IMPL(exp)
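Usage note (not part of the patch): a minimal Lua sketch of how the newly exposed pieces could
be combined. Only transform.gelu and the dropout layer registration are introduced by this
change; layer.input, layer.dense and layer.cost come from the tables registered above, while
new.kann, the cost constant and the exact argument shapes are assumptions based on the existing
bindings and may differ.

    local kann = require "rspamd_kann"

    -- 128-dimensional embedding input (dimension chosen arbitrarily for the example)
    local t = kann.layer.input(128)
    t = kann.layer.dense(t, 64)                  -- hidden layer
    t = kann.transform.gelu(t)                   -- GELU activation added by this patch (op #37)
    t = kann.layer.dropout(t, 0.3)               -- dropout layer now registered for Lua
    t = kann.layer.cost(t, 1, kann.cost.ceb_neg) -- assumed binary cross-entropy cost constant
    local net = kann.new.kann(t)                 -- assumed network constructor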