KAD_FUNC_OP1(kad_sigm, 6)
KAD_FUNC_OP1(kad_tanh, 7)
KAD_FUNC_OP1(kad_relu, 8)
+KAD_FUNC_OP1(kad_gelu, 37)
KAD_FUNC_OP1(kad_1minus, 11)
KAD_FUNC_OP1(kad_softmax, 14)
KAD_FUNC_OP1(kad_stdnorm, 32)
return 0;
}
+/* GELU: Gaussian Error Linear Unit, applied element by element.
+ * Forward: GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+ * Backward: GELU'(x) = 0.5 * (1 + erf(x/sqrt(2))) + x * exp(-x^2/2) / sqrt(2*pi)
+ *
+ * kad_op_gelu(p, action): p is the output node and p->child[0] its only input.
+ *   KAD_SYNC_DIM - calls kad_copy_dim1(p, q)
+ *   KAD_FORWARD  - writes GELU(q->x[i]) into p->x[i]
+ *   KAD_BACKWARD - when kad_is_back(q) is true, adds p->g[i] * GELU'(q->x[i]) to q->g[i]
+ * Any other action does nothing. The function always returns 0.
+ *
+ * The two constants below are fallbacks for when the names are not already
+ * defined at this point. Every use casts them to float, so the arithmetic is
+ * done in float whether the fallback or an existing definition is picked up.
+ */
+#ifndef M_SQRT1_2
+#define M_SQRT1_2 0.70710678118654752440f /* 1/sqrt(2); float fallback, used only if not already defined */
+#endif
+#ifndef M_2_SQRTPI
+#define M_2_SQRTPI 1.12837916709551257390f /* 2/sqrt(pi); float fallback, used only if not already defined */
+#endif
+
+int kad_op_gelu(kad_node_t *p, int action)
+{
+ int i, n;
+ kad_node_t *q = p->child[0]; /* the single input of this unary operator */
+ n = kad_len(q); /* loop bound for both passes; p->x and p->g are indexed over the same range */
+ if (action == KAD_SYNC_DIM) {
+ kad_copy_dim1(p, q); /* presumably gives p the same dimensions as q -- confirm against the helper */
+ }
+ else if (action == KAD_FORWARD) {
+ for (i = 0; i < n; ++i) {
+ float x = q->x[i];
+ /* GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))); x * M_SQRT1_2 is x / sqrt(2) */
+ p->x[i] = 0.5f * x * (1.0f + erff(x * (float) M_SQRT1_2));
+ }
+ }
+ else if (action == KAD_BACKWARD && kad_is_back(q)) { /* nothing is propagated unless kad_is_back(q) holds */
+ for (i = 0; i < n; ++i) {
+ float x = q->x[i];
+ /* GELU'(x) = 0.5 * (1 + erf(x/sqrt(2))) + x * exp(-x^2/2) / sqrt(2*pi)
+ * = 0.5 * (1 + erf(x/sqrt(2))) + x * 0.5 * M_2_SQRTPI * M_SQRT1_2 * exp(-x^2/2)
+ *
+ * 0.5 * (2/sqrt(pi)) * (1/sqrt(2)) equals 1/sqrt(2*pi), so `cdf` is the
+ * standard normal CDF at x and `pdf` the standard normal density at x.
+ * Both are recomputed from the input q->x; the stored output p->x is not used.
+ */
+ float cdf = 0.5f * (1.0f + erff(x * (float) M_SQRT1_2));
+ float pdf = 0.5f * (float) M_2_SQRTPI * (float) M_SQRT1_2 * expf(-0.5f * x * x);
+ q->g[i] += p->g[i] * (cdf + x * pdf); /* chain rule; += accumulates into the input's existing gradient */
+ }
+ }
+ return 0;
+}
+
int kad_op_sin(kad_node_t *p, int action)
{
int i, n;
} \
_row_func(_xx, _ww, _yy, w->d[3], p->d[3], aux[1].stride, aux[1].pad[0], (_tmp)); \
} /* ~i */ \
- } /* ~k, c0, c1, n */ \
+ } /* ~k, c0, c1, n */ \
} while (0)
#define conv2d_loop2(_x, _w, _y, _code) \
_xx = x_padded; \
} \
for (j = 0; j < p->d[3]; ++j, _xx += j_skip, ++_yy) _code; /* output and input column */ \
- } /* ~i */ \
- } /* ~k, c1, n */ \
+ } /* ~i */ \
+ } /* ~k, c1, n */ \
} while (0)
conv_conf_t *aux = (conv_conf_t *) p->ptr;
if (p->x[u + j] < q->x[v])
p->x[u + j] = q->x[v], f[u + j] = v;
} /* ~k */
- } /* ~i */
+ } /* ~i */
}
}
else if (action == KAD_BACKWARD) {
kad_op_exp, /* 33: exp() */
kad_op_sin, /* 34: sin() */
kad_op_stack, /* 35: tf.stack, but on the first axis only */
- kad_op_reverse /* 36: tf.reverse, but on one axis only */
+ kad_op_reverse, /* 36: tf.reverse, but on one axis only */
+ kad_op_gelu /* 37: GELU activation */
};
char *kad_op_name[KAD_MAX_OP] = {
0, "add", "mul", "cmul", "ce_bin_neg", "square", "sigm", "tanh", "relu", "matmul", "avg", "1minus", "select", "ce_multi", "softmax",
"dropout", "conv2d", "max2d", "conv1d", "max1d", "slice", "max", "ce_bin", "sub", "sample_normal", "reduce_sum", "reduce_mean", "log",
- "avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse"};
+ "avg1d", "mse", "reshape", "concat", "stdnorm", "exp", "sin", "stack", "reverse", "gelu"}; /* names are positional: slot 0 is a null entry and "gelu" is slot 37, which must equal the operator's numeric id. NOTE(review): KAD_MAX_OP is not visible here -- confirm it is at least 38 */
/**************************
*** Debugging routines ***
KANN_LAYER_DEF(conv1d);
KANN_LAYER_DEF(cost);
+static int lua_kann_layer_layerdropout(lua_State *L); /* forward declaration */
+
static luaL_reg rspamd_kann_layers_f[] = {
KANN_LAYER_INTERFACE(input),
KANN_LAYER_INTERFACE(dense),
KANN_LAYER_INTERFACE(layernorm),
+ {"dropout", lua_kann_layer_layerdropout}, /* manually registered - different naming */
KANN_LAYER_INTERFACE(rnn),
KANN_LAYER_INTERFACE(lstm),
KANN_LAYER_INTERFACE(gru),
KANN_TRANSFORM_DEF(sigm);
KANN_TRANSFORM_DEF(tanh);
KANN_TRANSFORM_DEF(relu);
+KANN_TRANSFORM_DEF(gelu);
KANN_TRANSFORM_DEF(softmax);
KANN_TRANSFORM_DEF(1minus);
KANN_TRANSFORM_DEF(exp);
KANN_TRANSFORM_INTERFACE(sigm),
KANN_TRANSFORM_INTERFACE(tanh),
KANN_TRANSFORM_INTERFACE(relu),
+ KANN_TRANSFORM_INTERFACE(gelu),
KANN_TRANSFORM_INTERFACE(softmax),
KANN_TRANSFORM_INTERFACE(1minus),
KANN_TRANSFORM_INTERFACE(exp),
LUA_UNARY_TRANSFORM_FUNC_IMPL(sigm)
LUA_UNARY_TRANSFORM_FUNC_IMPL(tanh)
LUA_UNARY_TRANSFORM_FUNC_IMPL(relu)
+LUA_UNARY_TRANSFORM_FUNC_IMPL(gelu)
LUA_UNARY_TRANSFORM_FUNC_IMPL(softmax)
LUA_UNARY_TRANSFORM_FUNC_IMPL(1minus)
LUA_UNARY_TRANSFORM_FUNC_IMPL(exp)