#include "logger.h"
#include "contrib/librdns/rdns.h"
#include "contrib/mumhash/mum.h"
+#include "heap.h"
#include <math.h>
#include <netdb.h>
struct upstream_list_watcher *next, *prev;
};
+/*
+ * Heap element for token bucket selection.
+ * One entry is pushed per alive upstream when the token heap is built; the
+ * heap is keyed on `pri`, so the entry with the fewest in-flight tokens is
+ * expected at the root.
+ */
+struct upstream_token_heap_entry {
+ unsigned int pri; /* Priority = inflight_tokens clamped to UINT_MAX (lower = better) */
+ unsigned int idx; /* Heap index (managed by heap) */
+ struct upstream *up; /* Pointer to upstream; compared against the caller's upstream to validate a cached heap_idx */
+};
+
+RSPAMD_HEAP_DECLARE(upstream_token_heap, struct upstream_token_heap_entry); /* provides the upstream_token_heap_t type stored in struct upstream_list */
+
struct upstream {
unsigned int weight;
unsigned int cur_weight;
gpointer data;
char uid[8];
ref_entry_t ref;
+
+ /* Token bucket fields for weighted load balancing */
+ gsize max_tokens; /* Maximum token capacity */
+ gsize available_tokens; /* Current available tokens */
+ gsize inflight_tokens; /* Tokens reserved by in-flight requests */
+ unsigned int heap_idx; /* Index in token heap (UINT_MAX if not in heap) */
#ifdef UPSTREAMS_THREAD_SAFE
rspamd_mutex_t *lock;
#endif
double probe_jitter;
unsigned int max_errors;
unsigned int dns_retransmits;
+
+ /* Token bucket configuration */
+ gsize token_bucket_max; /* Max tokens per upstream (default: 10000) */
+ gsize token_bucket_scale; /* Bytes per token (default: 1024) */
+ gsize token_bucket_min; /* Min tokens for selection (default: 1) */
+ gsize token_bucket_base_cost; /* Base cost per request (default: 10) */
};
struct upstream_list {
enum rspamd_upstream_flag flags;
unsigned int cur_elt;
enum rspamd_upstream_rotation rot_alg;
+
+ /* Token bucket heap for weighted selection */
+ upstream_token_heap_t token_heap;
+ gboolean token_bucket_initialized;
#ifdef UPSTREAMS_THREAD_SAFE
rspamd_mutex_t *lock;
#endif
#define DEFAULT_PROBE_JITTER 0.3
static const double default_probe_jitter = DEFAULT_PROBE_JITTER;
+/*
+ * Token bucket defaults.
+ * These seed the token_bucket_* members of default_limits below and are the
+ * values used until rspamd_upstreams_set_token_bucket() overrides them for a
+ * particular upstream list.
+ */
+#define DEFAULT_TOKEN_BUCKET_MAX 10000 /* max tokens per upstream */
+#define DEFAULT_TOKEN_BUCKET_SCALE 1024 /* bytes of message per token */
+#define DEFAULT_TOKEN_BUCKET_MIN 1 /* min tokens for selection */
+#define DEFAULT_TOKEN_BUCKET_BASE_COST 10 /* flat cost added to every request */
+
static const struct upstream_limits default_limits = {
.revive_time = DEFAULT_REVIVE_TIME,
.revive_jitter = DEFAULT_REVIVE_JITTER,
.resolve_min_interval = DEFAULT_RESOLVE_MIN_INTERVAL,
.probe_max_backoff = DEFAULT_PROBE_MAX_BACKOFF,
.probe_jitter = DEFAULT_PROBE_JITTER,
+ .token_bucket_max = DEFAULT_TOKEN_BUCKET_MAX,
+ .token_bucket_scale = DEFAULT_TOKEN_BUCKET_SCALE,
+ .token_bucket_min = DEFAULT_TOKEN_BUCKET_MIN,
+ .token_bucket_base_cost = DEFAULT_TOKEN_BUCKET_BASE_COST,
};
static void rspamd_upstream_lazy_resolve_cb(struct ev_loop *, ev_timer *, int);
g_ptr_array_add(ls->alive, upstream);
upstream->active_idx = ls->alive->len - 1;
+ /* Initialize token bucket state */
+ upstream->heap_idx = UINT_MAX;
+ if (ls->rot_alg == RSPAMD_UPSTREAM_TOKEN_BUCKET) {
+ upstream->max_tokens = ls->limits->token_bucket_max;
+ upstream->available_tokens = upstream->max_tokens;
+ upstream->inflight_tokens = 0;
+
+ /* Add to token heap if already initialized */
+ if (ls->token_bucket_initialized) {
+ struct upstream_token_heap_entry entry;
+ entry.pri = 0;
+ entry.idx = 0;
+ entry.up = upstream;
+ rspamd_heap_push_safe(upstream_token_heap, &ls->token_heap, &entry, skip_heap);
+ upstream->heap_idx = rspamd_heap_size(upstream_token_heap, &ls->token_heap) - 1;
+ skip_heap:;
+ }
+ }
+
if (upstream->ctx && upstream->ctx->configured &&
!((upstream->flags & RSPAMD_UPSTREAM_FLAG_NORESOLVE) ||
(upstream->flags & RSPAMD_UPSTREAM_FLAG_DNS))) {
g_ptr_array_remove_index(ls->alive, upstream->active_idx);
upstream->active_idx = -1;
+ /* Remove from token bucket heap if present */
+ if (ls->token_bucket_initialized && upstream->heap_idx != UINT_MAX) {
+ struct upstream_token_heap_entry *entry;
+
+ RSPAMD_UPSTREAM_LOCK(upstream);
+
+ if (upstream->heap_idx < rspamd_heap_size(upstream_token_heap, &ls->token_heap)) {
+ entry = rspamd_heap_index(upstream_token_heap, &ls->token_heap, upstream->heap_idx);
+ if (entry && entry->up == upstream) {
+ rspamd_heap_remove(upstream_token_heap, &ls->token_heap, entry);
+ }
+ }
+ upstream->heap_idx = UINT_MAX;
+
+ /*
+ * Return inflight tokens to available pool - these represent
+ * requests that were in-flight when upstream failed. The tokens
+ * should be restored so they're available when upstream comes back.
+ */
+ if (upstream->inflight_tokens > 0) {
+ upstream->available_tokens += upstream->inflight_tokens;
+ if (upstream->available_tokens > upstream->max_tokens) {
+ upstream->available_tokens = upstream->max_tokens;
+ }
+ upstream->inflight_tokens = 0;
+ }
+
+ RSPAMD_UPSTREAM_UNLOCK(upstream);
+ }
+
/* We need to update all indices */
for (i = 0; i < ls->alive->len; i++) {
cur = g_ptr_array_index(ls->alive, i);
struct upstream_list_watcher *w, *tmp;
if (ups != NULL) {
+ /* Clean up token bucket heap */
+ if (ups->token_bucket_initialized) {
+ rspamd_heap_destroy(upstream_token_heap, &ups->token_heap);
+ ups->token_bucket_initialized = FALSE;
+ }
+
g_ptr_array_free(ups->alive, TRUE);
for (i = 0; i < ups->ups->len; i++) {
case RSPAMD_UPSTREAM_MASTER_SLAVE:
up = rspamd_upstream_get_round_robin(ups, except, FALSE);
break;
+ case RSPAMD_UPSTREAM_TOKEN_BUCKET:
+ /*
+ * Token bucket requires message size, which isn't available here.
+ * Fall back to round robin. Use rspamd_upstream_get_token_bucket()
+ * for proper token bucket selection.
+ */
+ up = rspamd_upstream_get_round_robin(ups, except, TRUE);
+ break;
case RSPAMD_UPSTREAM_SEQUENTIAL:
if (ups->cur_elt >= ups->alive->len) {
ups->cur_elt = 0;
DL_APPEND(ups->watchers, nw);
}
+/**
+ * Report the rotation algorithm configured for an upstream list.
+ * @param ups upstream list, may be NULL
+ * @return ups->rot_alg, or RSPAMD_UPSTREAM_UNDEF when no list is given
+ */
+enum rspamd_upstream_rotation
+rspamd_upstreams_get_rotation(struct upstream_list *ups)
+{
+	return (ups != NULL) ? ups->rot_alg : RSPAMD_UPSTREAM_UNDEF;
+}
+
+/**
+ * Override the token bucket parameters of an upstream list.
+ *
+ * The list's current limits are copied into a freshly allocated structure,
+ * every non-zero argument is written into the copy, and the list is then
+ * switched to the copy. A zero argument keeps the value already in effect.
+ *
+ * @param ups upstream list (asserted to be non-NULL)
+ * @param max_tokens new token_bucket_max, 0 to keep the current value
+ * @param scale_factor new token_bucket_scale, 0 to keep the current value
+ * @param min_tokens new token_bucket_min, 0 to keep the current value
+ * @param base_cost new token_bucket_base_cost, 0 to keep the current value
+ */
+void rspamd_upstreams_set_token_bucket(struct upstream_list *ups,
+									   gsize max_tokens,
+									   gsize scale_factor,
+									   gsize min_tokens,
+									   gsize base_cost)
+{
+	struct upstream_limits *updated;
+
+	g_assert(ups != NULL);
+
+	/*
+	 * The copy lives in the context pool when one is available, otherwise
+	 * on the plain heap.
+	 * NOTE(review): nothing here releases a g_malloc'ed copy (or a copy
+	 * installed by an earlier call) - confirm who owns ups->limits.
+	 */
+	if (ups->ctx != NULL && ups->ctx->pool != NULL) {
+		updated = rspamd_mempool_alloc(ups->ctx->pool, sizeof(*updated));
+	}
+	else {
+		updated = g_malloc(sizeof(*updated));
+	}
+	memcpy(updated, ups->limits, sizeof(*updated));
+
+	/* Zero means "leave as is" for each individual parameter */
+	if (max_tokens != 0) {
+		updated->token_bucket_max = max_tokens;
+	}
+	if (scale_factor != 0) {
+		updated->token_bucket_scale = scale_factor;
+	}
+	if (min_tokens != 0) {
+		updated->token_bucket_min = min_tokens;
+	}
+	if (base_cost != 0) {
+		updated->token_bucket_base_cost = base_cost;
+	}
+
+	ups->limits = updated;
+}
+
+/*
+ * Calculate token cost for a message of given size:
+ *   cost = token_bucket_base_cost + message_size / token_bucket_scale
+ *
+ * A zero scale would make the division undefined, so it is replaced by
+ * DEFAULT_TOKEN_BUCKET_SCALE instead of being trusted blindly; limits can be
+ * installed from outside this function and are not re-validated here.
+ *
+ * @param limits limits structure holding the token bucket parameters
+ * @param message_size size of the message in bytes
+ * @return number of tokens a request of that size costs
+ */
+static inline gsize
+rspamd_upstream_calculate_tokens(const struct upstream_limits *limits,
+								 gsize message_size)
+{
+	gsize scale = limits->token_bucket_scale;
+
+	if (scale == 0) {
+		/* Never divide by zero, fall back to the compiled-in default */
+		scale = DEFAULT_TOKEN_BUCKET_SCALE;
+	}
+
+	return limits->token_bucket_base_cost + (message_size / scale);
+}
+
+/*
+ * Initialize token bucket heap for an upstream list (lazy initialization).
+ *
+ * Every alive upstream gets a full bucket (available == max, nothing in
+ * flight) and one heap entry with priority 0. Calling this on an already
+ * initialized list is a no-op.
+ *
+ * @param ups upstream list
+ * @return TRUE if the heap is ready, FALSE if pushing an entry failed; in
+ *         that case the heap is destroyed, the list stays uninitialized and
+ *         no upstream is left with a cached heap index
+ */
+static gboolean
+rspamd_upstream_token_bucket_init(struct upstream_list *ups)
+{
+	unsigned int i;
+	struct upstream *up;
+	struct upstream_token_heap_entry entry;
+
+	if (ups->token_bucket_initialized) {
+		return TRUE;
+	}
+
+	rspamd_heap_init(upstream_token_heap, &ups->token_heap);
+
+	/* Add all alive upstreams to the heap */
+	for (i = 0; i < ups->alive->len; i++) {
+		up = g_ptr_array_index(ups->alive, i);
+
+		/* Initialize token bucket state for this upstream */
+		up->max_tokens = ups->limits->token_bucket_max;
+		up->available_tokens = up->max_tokens;
+		up->inflight_tokens = 0;
+
+		/* Add to heap with priority = inflight_tokens (0 initially) */
+		entry.pri = 0;
+		entry.idx = 0;
+		entry.up = up;
+
+		rspamd_heap_push_safe(upstream_token_heap, &ups->token_heap, &entry, init_error);
+		/*
+		 * NOTE(review): assumes an entry pushed with the same priority as
+		 * all others stays in the last slot - confirm against heap.h.
+		 */
+		up->heap_idx = rspamd_heap_size(upstream_token_heap, &ups->token_heap) - 1;
+	}
+
+	ups->token_bucket_initialized = TRUE;
+	return TRUE;
+
+init_error:
+	/* Heap allocation failed, destroy what we have */
+	rspamd_heap_destroy(upstream_token_heap, &ups->token_heap);
+
+	/*
+	 * Upstreams pushed before the failure still carry an index into the
+	 * heap that has just been destroyed; mark all of them as "not in heap"
+	 * so no stale index survives a failed initialization.
+	 */
+	for (i = 0; i < ups->alive->len; i++) {
+		up = g_ptr_array_index(ups->alive, i);
+		up->heap_idx = UINT_MAX;
+	}
+
+	return FALSE;
+}
+
+/*
+ * Find upstream in heap by pointer (for removal or update after finding mismatch).
+ *
+ * Linear scan over all heap slots. On a hit the upstream's cached heap_idx
+ * is refreshed with the slot number.
+ *
+ * @param ups upstream list owning the heap
+ * @param up upstream to look for
+ * @return the heap entry referring to `up`, or NULL if it is not in the heap
+ */
+static struct upstream_token_heap_entry *
+rspamd_upstream_find_in_heap(struct upstream_list *ups, struct upstream *up)
+{
+	unsigned int i;
+	struct upstream_token_heap_entry *entry;
+
+	for (i = 0; i < rspamd_heap_size(upstream_token_heap, &ups->token_heap); i++) {
+		entry = rspamd_heap_index(upstream_token_heap, &ups->token_heap, i);
+		if (entry && entry->up == up) {
+			up->heap_idx = i;
+			return entry;
+		}
+	}
+	return NULL;
+}
+
+/*
+ * Update heap position after changing inflight_tokens.
+ *
+ * The cached up->heap_idx is tried first. Heap operations on other entries
+ * can leave that index stale or out of range, so when the slot does not hold
+ * this upstream the heap is searched by pointer instead of silently skipping
+ * the update (which would leave the heap ordered by an outdated priority).
+ * The cache is not refreshed after the update itself; a stale value is
+ * repaired by the same search on the next call.
+ *
+ * Does nothing if the heap is not initialized, the upstream is marked as not
+ * being in the heap (heap_idx == UINT_MAX), or it cannot be found.
+ *
+ * @param ups upstream list owning the heap
+ * @param up upstream whose inflight_tokens changed
+ */
+static void
+rspamd_upstream_token_heap_update(struct upstream_list *ups, struct upstream *up)
+{
+	struct upstream_token_heap_entry *entry = NULL;
+	unsigned int new_pri;
+
+	if (!ups->token_bucket_initialized || up->heap_idx == UINT_MAX) {
+		return;
+	}
+
+	/* Fast path: the cached slot still belongs to this upstream */
+	if (up->heap_idx < rspamd_heap_size(upstream_token_heap, &ups->token_heap)) {
+		entry = rspamd_heap_index(upstream_token_heap, &ups->token_heap, up->heap_idx);
+		if (entry != NULL && entry->up != up) {
+			entry = NULL;
+		}
+	}
+
+	/* Slow path: cached index is stale, locate the entry by pointer */
+	if (entry == NULL) {
+		entry = rspamd_upstream_find_in_heap(ups, up);
+		if (entry == NULL) {
+			return;
+		}
+	}
+
+	/* Heap priorities are unsigned int, so clamp the gsize counter */
+	new_pri = (unsigned int) MIN(up->inflight_tokens, UINT_MAX);
+	rspamd_heap_update(upstream_token_heap, &ups->token_heap, entry, new_pri);
+}
+
+/**
+ * Select an upstream with the token bucket algorithm and reserve tokens on it.
+ *
+ * The request cost is derived from message_size. Up to 8 eligible entries
+ * are examined in heap array order (inactive upstreams and `except` are
+ * skipped and not counted); the first one with enough available tokens wins.
+ * If none has enough, the examined upstream with the fewest in-flight tokens
+ * is used instead and its available tokens are clamped to 0.
+ *
+ * Locking follows the order used by rspamd_upstream_return_tokens(): list
+ * first, then the chosen upstream while its counters are modified.
+ *
+ * @param ups upstream list
+ * @param except upstream to avoid (may be NULL)
+ * @param message_size message size in bytes, used to compute the cost
+ * @param reserved_tokens out: tokens reserved on the returned upstream; the
+ *        caller hands them back via rspamd_upstream_return_tokens(). Set to 0
+ *        when nothing was reserved.
+ * @return selected upstream; NULL if ups or reserved_tokens is NULL, no
+ *         upstream is alive or no candidate is eligible. If the heap cannot
+ *         be initialized the round robin result is returned with
+ *         *reserved_tokens == 0.
+ */
+struct upstream *
+rspamd_upstream_get_token_bucket(struct upstream_list *ups,
+								 struct upstream *except,
+								 gsize message_size,
+								 gsize *reserved_tokens)
+{
+	struct upstream *selected = NULL;
+	struct upstream_token_heap_entry *entry;
+	gsize token_cost;
+	unsigned int i;
+	gsize min_inflight = G_MAXSIZE;
+	struct upstream *fallback = NULL;
+
+	if (ups == NULL || reserved_tokens == NULL) {
+		return NULL;
+	}
+
+	*reserved_tokens = 0;
+
+	RSPAMD_UPSTREAM_LOCK(ups);
+
+	/* Handle empty alive list same as other algorithms */
+	if (ups->alive->len == 0) {
+		RSPAMD_UPSTREAM_UNLOCK(ups);
+		return NULL;
+	}
+
+	/* Initialize token bucket if not done yet */
+	if (!ups->token_bucket_initialized) {
+		if (!rspamd_upstream_token_bucket_init(ups)) {
+			/* Fall back to round robin on init failure */
+			RSPAMD_UPSTREAM_UNLOCK(ups);
+			return rspamd_upstream_get_round_robin(ups, except, TRUE);
+		}
+	}
+
+	/* Calculate token cost for this message */
+	token_cost = rspamd_upstream_calculate_tokens(ups->limits, message_size);
+
+	/*
+	 * Use heap property: the root (index 0) has minimum inflight_tokens.
+	 * Check a few candidates from the top of the heap rather than scanning all.
+	 */
+	unsigned int heap_size = rspamd_heap_size(upstream_token_heap, &ups->token_heap);
+	unsigned int candidates_checked = 0;
+	const unsigned int max_candidates = 8; /* Check up to 8 lowest-loaded upstreams */
+
+	for (i = 0; i < heap_size && candidates_checked < max_candidates; i++) {
+		entry = rspamd_heap_index(upstream_token_heap, &ups->token_heap, i);
+
+		if (entry == NULL || entry->up == NULL) {
+			continue;
+		}
+
+		struct upstream *up = entry->up;
+
+		/* Skip inactive upstreams */
+		if (up->active_idx < 0) {
+			continue;
+		}
+
+		/* Skip excluded upstream */
+		if (except && up == except) {
+			continue;
+		}
+
+		candidates_checked++;
+
+		/* Track upstream with minimum inflight for fallback */
+		if (up->inflight_tokens < min_inflight) {
+			min_inflight = up->inflight_tokens;
+			fallback = up;
+		}
+
+		/* Check if upstream has sufficient tokens */
+		if (up->available_tokens >= token_cost) {
+			selected = up;
+			break;
+		}
+	}
+
+	/* If no upstream has sufficient tokens, use the least loaded one */
+	if (selected == NULL && fallback != NULL) {
+		selected = fallback;
+	}
+
+	if (selected != NULL) {
+		/*
+		 * The per-upstream counters are also modified under the upstream
+		 * lock elsewhere (token return, deactivation), so take it here as
+		 * well; the list lock is already held, which keeps the list ->
+		 * upstream ordering.
+		 */
+		RSPAMD_UPSTREAM_LOCK(selected);
+
+		/* Reserve tokens */
+		if (selected->available_tokens >= token_cost) {
+			selected->available_tokens -= token_cost;
+		}
+		else {
+			/* Clamp to 0 if we don't have enough */
+			selected->available_tokens = 0;
+		}
+		selected->inflight_tokens += token_cost;
+		*reserved_tokens = token_cost;
+
+		/* Update heap position */
+		rspamd_upstream_token_heap_update(ups, selected);
+
+		selected->checked++;
+
+		RSPAMD_UPSTREAM_UNLOCK(selected);
+	}
+
+	RSPAMD_UPSTREAM_UNLOCK(ups);
+
+	return selected;
+}
+
+/**
+ * Hand back tokens previously reserved on an upstream.
+ *
+ * The tokens are always removed from the in-flight counter (which bottoms
+ * out at 0). They are added back to the available pool, capped at
+ * max_tokens, only when the request succeeded. Afterwards the upstream's
+ * heap position is refreshed if its list has an initialized token heap.
+ *
+ * @param up upstream the tokens were reserved on; NULL is ignored
+ * @param tokens number of tokens to return; 0 is ignored
+ * @param success TRUE if the request completed successfully
+ */
+void rspamd_upstream_return_tokens(struct upstream *up, gsize tokens, gboolean success)
+{
+	struct upstream_list *owner;
+
+	if (up == NULL || tokens == 0) {
+		return;
+	}
+
+	owner = up->ls;
+
+	/*
+	 * Lock ordering: the list (when there is one) is locked before the
+	 * upstream and released after it.
+	 */
+	if (owner != NULL) {
+		RSPAMD_UPSTREAM_LOCK(owner);
+	}
+	RSPAMD_UPSTREAM_LOCK(up);
+
+	/* Take the tokens out of the in-flight counter without underflowing */
+	up->inflight_tokens = (up->inflight_tokens >= tokens)
+							  ? up->inflight_tokens - tokens
+							  : 0;
+
+	/* A failed request does not give its tokens back to the pool */
+	if (success) {
+		up->available_tokens = MIN(up->available_tokens + tokens, up->max_tokens);
+	}
+
+	/* Re-sort the upstream in the heap when the list maintains one */
+	if (owner != NULL && owner->token_bucket_initialized) {
+		rspamd_upstream_token_heap_update(owner, up);
+	}
+
+	RSPAMD_UPSTREAM_UNLOCK(up);
+	if (owner != NULL) {
+		RSPAMD_UPSTREAM_UNLOCK(owner);
+	}
+}
+
struct upstream *
rspamd_upstream_ref(struct upstream *up)
{
int parser_from_ref;
int parser_to_ref;
struct rspamd_task *task;
+ gsize reserved_tokens; /* Tokens reserved for this request (token bucket) */
};
enum rspamd_proxy_legacy_support {
up->log_tag_type = rspamd_proxy_parse_log_tag_type(ucl_object_tostring(elt));
}
+ /* Parse token_bucket configuration for weighted load balancing */
+ elt = ucl_object_lookup(obj, "token_bucket");
+ if (elt != NULL && ucl_object_type(elt) == UCL_OBJECT && up->u != NULL) {
+ gsize max_tokens = 10000, scale = 1024, min_tokens = 1, base_cost = 10;
+ const ucl_object_t *tb_elt;
+
+ if ((tb_elt = ucl_object_lookup(elt, "max_tokens")) != NULL) {
+ max_tokens = ucl_object_toint(tb_elt);
+ if (max_tokens == 0) {
+ msg_warn_pool("token_bucket.max_tokens must be > 0, using default 10000");
+ max_tokens = 10000;
+ }
+ }
+ if ((tb_elt = ucl_object_lookup(elt, "scale")) != NULL) {
+ scale = ucl_object_toint(tb_elt);
+ if (scale == 0) {
+ msg_warn_pool("token_bucket.scale cannot be 0 (division by zero), using default 1024");
+ scale = 1024;
+ }
+ }
+ if ((tb_elt = ucl_object_lookup(elt, "min_tokens")) != NULL) {
+ min_tokens = ucl_object_toint(tb_elt);
+ }
+ if ((tb_elt = ucl_object_lookup(elt, "base_cost")) != NULL) {
+ base_cost = ucl_object_toint(tb_elt);
+ }
+
+ /* Validate relationships */
+ if (min_tokens > max_tokens) {
+ msg_warn_pool("token_bucket.min_tokens (%zu) > max_tokens (%zu), clamping",
+ min_tokens, max_tokens);
+ min_tokens = max_tokens;
+ }
+ if (base_cost >= max_tokens) {
+ msg_warn_pool("token_bucket.base_cost (%zu) >= max_tokens (%zu), reducing to max/2",
+ base_cost, max_tokens);
+ base_cost = max_tokens / 2;
+ }
+
+ /* Enable token bucket rotation and configure parameters */
+ rspamd_upstreams_set_rotation(up->u, RSPAMD_UPSTREAM_TOKEN_BUCKET);
+ rspamd_upstreams_set_token_bucket(up->u, max_tokens, scale, min_tokens, base_cost);
+
+ msg_info_pool_check("upstream %s: token_bucket enabled (max=%zu, scale=%zu, min=%zu, base=%zu)",
+ up->name, max_tokens, scale, min_tokens, base_cost);
+ }
+
/*
* Accept lua function here in form
* fun :: String -> UCL
proxy_backend_close_connection(struct rspamd_proxy_backend_connection *conn)
{
if (conn && !(conn->flags & RSPAMD_BACKEND_CLOSED)) {
+ /* Return any reserved tokens if not already returned (safety net) */
+ if (conn->reserved_tokens > 0 && conn->up) {
+ rspamd_upstream_return_tokens(conn->up, conn->reserved_tokens, FALSE);
+ conn->reserved_tokens = 0;
+ }
+
if (conn->backend_conn) {
rspamd_http_connection_reset(conn->backend_conn);
rspamd_http_connection_unref(conn->backend_conn);
: "self-scan",
err,
session->ctx->max_retries - session->retries);
+
+ /* Return reserved tokens on error (token bucket load balancing) */
+ if (bk_conn->reserved_tokens > 0 && bk_conn->up) {
+ rspamd_upstream_return_tokens(bk_conn->up, bk_conn->reserved_tokens, FALSE);
+ bk_conn->reserved_tokens = 0;
+ }
+
rspamd_upstream_fail(bk_conn->up, FALSE, err ? err->message : "unknown");
proxy_backend_close_connection(session->master_conn);
}
}
+ /* Return reserved tokens on success (token bucket load balancing) */
+ if (bk_conn->reserved_tokens > 0 && bk_conn->up) {
+ rspamd_upstream_return_tokens(bk_conn->up, bk_conn->reserved_tokens, TRUE);
+ bk_conn->reserved_tokens = 0;
+ }
+
rspamd_upstream_ok(bk_conn->up);
/* Handle keepalive for master connection */
gpointer hash_key = rspamd_inet_address_get_hash_key(session->client_addr,
&hash_len);
- if (session->ctx->max_retries > 1 &&
- session->retries == session->ctx->max_retries) {
+ /* Initialize reserved_tokens to 0 */
+ session->master_conn->reserved_tokens = 0;
+
+ /* Check if token bucket algorithm is configured */
+ if (rspamd_upstreams_get_rotation(backend->u) == RSPAMD_UPSTREAM_TOKEN_BUCKET) {
+ /* Calculate message size for token bucket */
+ gsize message_size = 0;
+
+ if (session->map && session->map_len) {
+ message_size = session->map_len;
+ }
+ else if (session->client_message && session->client_message->body_buf.len > 0) {
+ message_size = session->client_message->body_buf.len;
+ }
+
+ /* Use token bucket selection */
+ session->master_conn->up = rspamd_upstream_get_token_bucket(
+ backend->u,
+ (session->retries > 0) ? session->master_conn->up : NULL,
+ message_size,
+ &session->master_conn->reserved_tokens);
+ }
+ else if (session->ctx->max_retries > 1 &&
+ session->retries == session->ctx->max_retries) {
session->master_conn->up = rspamd_upstream_get_except(backend->u,
session->master_conn->up,