]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Feature] Allow to add upstream watchers to Lua API
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 5 Dec 2018 14:31:54 +0000 (14:31 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 5 Dec 2018 14:31:54 +0000 (14:31 +0000)
src/libutil/upstream.c
src/libutil/upstream.h
src/lua/lua_upstream.c

index 90f792bbeca94161b25271b631dfee362f65e8ce..eb88e501a1ca1deb15fea66db138c3936a340d59 100644 (file)
@@ -36,6 +36,7 @@ struct upstream_addr_elt {
 
 struct upstream_list_watcher {
        rspamd_upstream_watch_func func;
+       GFreeFunc dtor;
        gpointer ud;
        enum rspamd_upstreams_watch_event events_mask;
        struct upstream_list_watcher *next, *prev;
@@ -879,6 +880,9 @@ rspamd_upstreams_destroy (struct upstream_list *ups)
                }
 
                DL_FOREACH_SAFE (ups->watchers, w, tmp) {
+                       if (w->dtor) {
+                               w->dtor (w->ud);
+                       }
                        g_free (w);
                }
 
@@ -1178,6 +1182,7 @@ rspamd_upstreams_set_limits (struct upstream_list *ups,
 void rspamd_upstreams_add_watch_callback (struct upstream_list *ups,
                                                                                  enum rspamd_upstreams_watch_event events,
                                                                                  rspamd_upstream_watch_func func,
+                                                                                 GFreeFunc dtor,
                                                                                  gpointer ud)
 {
        struct upstream_list_watcher *nw;
@@ -1188,6 +1193,7 @@ void rspamd_upstreams_add_watch_callback (struct upstream_list *ups,
        nw->func = func;
        nw->events_mask = events;
        nw->ud = ud;
+       nw->dtor = dtor;
 
        DL_APPEND (ups->watchers, nw);
 }
index 56d6fa6c569c859510d735fe4c3b4ad435ea9a12..5c0c92afc1dc6cfe1e8c5fa3b6b42426c4ac928d 100644 (file)
@@ -204,6 +204,7 @@ typedef void (*rspamd_upstream_watch_func) (struct upstream *up,
 void rspamd_upstreams_add_watch_callback (struct upstream_list *ups,
                                                                                  enum rspamd_upstreams_watch_event events,
                                                                                  rspamd_upstream_watch_func func,
+                                                                                 GFreeFunc free_func,
                                                                                  gpointer ud);
 
 /**
index 854bfafd928c75c6ba582bd3ce9e0c10e5975173..1a4d6b12899b1650abb43803da7aeb54d5d1284d 100644 (file)
@@ -56,6 +56,7 @@ LUA_FUNCTION_DEF (upstream_list, all_upstreams);
 LUA_FUNCTION_DEF (upstream_list, get_upstream_by_hash);
 LUA_FUNCTION_DEF (upstream_list, get_upstream_round_robin);
 LUA_FUNCTION_DEF (upstream_list, get_upstream_master_slave);
+LUA_FUNCTION_DEF (upstream_list, add_watcher);
 
 static const struct luaL_reg upstream_list_m[] = {
 
@@ -63,6 +64,7 @@ static const struct luaL_reg upstream_list_m[] = {
        LUA_INTERFACE_DEF (upstream_list, get_upstream_round_robin),
        LUA_INTERFACE_DEF (upstream_list, get_upstream_master_slave),
        LUA_INTERFACE_DEF (upstream_list, all_upstreams),
+       LUA_INTERFACE_DEF (upstream_list, add_watcher),
        {"__tostring", rspamd_lua_class_tostring},
        {"__gc", lua_upstream_list_destroy},
        {NULL, NULL}
@@ -290,7 +292,7 @@ lua_upstream_list_get_upstream_by_hash (lua_State *L)
                }
        }
        else {
-               lua_pushnil (L);
+               return luaL_error (L, "invalid arguments");
        }
 
        return 1;
@@ -322,7 +324,7 @@ lua_upstream_list_get_upstream_round_robin (lua_State *L)
                }
        }
        else {
-               lua_pushnil (L);
+               return luaL_error (L, "invalid arguments");
        }
 
        return 1;
@@ -356,7 +358,7 @@ lua_upstream_list_get_upstream_master_slave (lua_State *L)
                }
        }
        else {
-               lua_pushnil (L);
+               return luaL_error (L, "invalid arguments");
        }
 
        return 1;
@@ -390,12 +392,173 @@ lua_upstream_list_all_upstreams (lua_State *L)
                rspamd_upstreams_foreach (upl, lua_upstream_inserter, L);
        }
        else {
-               lua_pushnil (L);
+               return luaL_error (L, "invalid arguments");
        }
 
        return 1;
 }
 
+static inline enum rspamd_upstreams_watch_event
+lua_str_to_upstream_flag (const gchar *str)
+{
+       enum rspamd_upstreams_watch_event fl = 0;
+
+       if (strcmp (str, "success") == 0) {
+               fl = RSPAMD_UPSTREAM_WATCH_SUCCESS;
+       }
+       else if (strcmp (str, "failure") == 0) {
+               fl = RSPAMD_UPSTREAM_WATCH_FAILURE;
+       }
+       else if (strcmp (str, "online") == 0) {
+               fl = RSPAMD_UPSTREAM_WATCH_ONLINE;
+       }
+       else if (strcmp (str, "offline") == 0) {
+               fl = RSPAMD_UPSTREAM_WATCH_OFFLINE;
+       }
+       else {
+               msg_err ("invalid flag: %s", str);
+       }
+
+       return fl;
+}
+
+static inline const gchar *
+lua_upstream_flag_to_str (enum rspamd_upstreams_watch_event fl)
+{
+       const gchar *res = "unknown";
+
+       /* Works with single flags, not combinations */
+       if (fl & RSPAMD_UPSTREAM_WATCH_SUCCESS) {
+               res = "success";
+       }
+       else if (fl & RSPAMD_UPSTREAM_WATCH_FAILURE) {
+               res = "failure";
+       }
+       else if (fl & RSPAMD_UPSTREAM_WATCH_ONLINE) {
+               res = "online";
+       }
+       else if (fl & RSPAMD_UPSTREAM_WATCH_OFFLINE) {
+               res = "offline";
+       }
+       else {
+               msg_err ("invalid flag: %d", fl);
+       }
+
+       return res;
+}
+
+struct rspamd_lua_upstream_watcher_cbdata {
+       lua_State *L;
+       gint cbref;
+       struct upstream_list *upl;
+};
+
+static void
+lua_upstream_watch_func (struct upstream *up,
+                                                enum rspamd_upstreams_watch_event event,
+                                                guint cur_errors,
+                                                void *ud)
+{
+       struct rspamd_lua_upstream_watcher_cbdata *cdata =
+                       (struct rspamd_lua_upstream_watcher_cbdata *)ud;
+       lua_State *L;
+       struct upstream **pup;
+       const gchar *what;
+       gint err_idx;
+
+       L = cdata->L;
+       what = lua_upstream_flag_to_str (event);
+       lua_pushcfunction (L, &rspamd_lua_traceback);
+       err_idx = lua_gettop (L);
+
+       lua_rawgeti (L, LUA_REGISTRYINDEX, cdata->cbref);
+       lua_pushstring (L, what);
+       pup = lua_newuserdata (L, sizeof (*pup));
+       *pup = up;
+       rspamd_lua_setclass (L, "rspamd{upstream}", -1);
+       lua_pushinteger (L, cur_errors);
+
+       if (lua_pcall (L, 3, 0, err_idx) != 0) {
+               GString *tb = lua_touserdata (L, -1);
+               msg_err ("cannot call watch function for upstream: %s", tb->str);
+               g_string_free (tb, TRUE);
+               lua_settop (L, 0);
+
+               return;
+       }
+
+       lua_settop (L, 0);
+}
+
+static void
+lua_upstream_watch_dtor (gpointer ud)
+{
+       struct rspamd_lua_upstream_watcher_cbdata *cdata =
+                       (struct rspamd_lua_upstream_watcher_cbdata *)ud;
+
+       luaL_unref (cdata->L, LUA_REGISTRYINDEX, cdata->cbref);
+       g_free (cdata);
+}
+
+/***
+ * @method upstream_list:add_watcher(what, cb)
+ * Add new watcher to the upstream lists events (table or a string):
+ *   - `success` - called whenever upstream successfully used
+ *   - `failure` - called on upstream error
+ *   - `online` - called when upstream is being taken online from offline
+ *   - `offline` - called when upstream is being taken offline from online
+ * Callback is a function: function(what, upstream, cur_errors) ... end
+ * @example
+ups:add_watcher('success', function(what, up, cur_errors) ... end)
+ups:add_watcher({'online', 'offline'}, function(what, up, cur_errors) ... end)
+ * @return nothing
+ */
+static gint
+lua_upstream_list_add_watcher (lua_State *L)
+{
+       LUA_TRACE_POINT;
+       struct upstream_list *upl;
+
+       upl = lua_check_upstream_list (L);
+       if (upl &&
+               (lua_type (L, 2) == LUA_TTABLE ||  lua_type (L, 2) == LUA_TSTRING) &&
+               lua_type (L, 3) == LUA_TFUNCTION) {
+
+               enum rspamd_upstreams_watch_event flags = 0;
+               struct rspamd_lua_upstream_watcher_cbdata *cdata;
+
+               if (lua_type (L, 2) == LUA_TSTRING) {
+                       flags = lua_str_to_upstream_flag (lua_tostring (L, 2));
+               }
+               else {
+                       for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+                               if (lua_isstring (L, -1)) {
+                                       flags |= lua_str_to_upstream_flag (lua_tostring (L, -1));
+                               }
+                               else {
+                                       lua_pop (L, 1);
+
+                                       return luaL_error (L, "invalid arguments");
+                               }
+                       }
+               }
+
+               cdata = g_malloc0 (sizeof (*cdata));
+               lua_pushvalue (L, 3); /* callback */
+               cdata->cbref = luaL_ref (L, LUA_REGISTRYINDEX);
+               cdata->L = L;
+               cdata->upl = upl;
+
+               rspamd_upstreams_add_watch_callback (upl, flags,
+                               lua_upstream_watch_func, lua_upstream_watch_dtor, cdata);
+       }
+       else {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       return 0;
+}
+
 static gint
 lua_load_upstream_list (lua_State * L)
 {