]> git.ipfire.org Git - thirdparty/nftables.git/commitdiff
src: add support for multi-statement in dynamic sets and maps
authorPablo Neira Ayuso <pablo@netfilter.org>
Wed, 9 Dec 2020 15:46:26 +0000 (16:46 +0100)
committerPablo Neira Ayuso <pablo@netfilter.org>
Thu, 17 Dec 2020 16:24:49 +0000 (17:24 +0100)
This patch allows for two statements for dynamic set updates, e.g.

 nft rule x y add @y { ip daddr limit rate 1/second counter }

Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
include/statement.h
src/evaluate.c
src/netlink.c
src/netlink_delinearize.c
src/netlink_linearize.c
src/parser_bison.y
src/statement.c

index f2fc6ade7734cd0dfbb32178459fefbe15c5852a..7637a82e4e00295dcabf9541f08d2486aaab30c8 100644 (file)
@@ -201,7 +201,7 @@ uint32_t fwd_stmt_type(const char *type);
 struct set_stmt {
        struct expr             *set;
        struct expr             *key;
-       struct stmt             *stmt;
+       struct list_head        stmt_list;
        enum nft_dynset_ops     op;
 };
 
@@ -213,7 +213,7 @@ struct map_stmt {
        struct expr             *set;
        struct expr             *key;
        struct expr             *data;
-       struct stmt             *stmt;
+       struct list_head        stmt_list;
        enum nft_dynset_ops     op;
 };
 
index e776cd01805113940a14649fa37019d11883e68e..03f060eb465a33b32d036bcb8891a4c44dd7e839 100644 (file)
@@ -3370,6 +3370,8 @@ static int stmt_evaluate_log(struct eval_ctx *ctx, struct stmt *stmt)
 
 static int stmt_evaluate_set(struct eval_ctx *ctx, struct stmt *stmt)
 {
+       struct stmt *this;
+
        expr_set_context(&ctx->ectx, NULL, 0);
        if (expr_evaluate(ctx, &stmt->set.set) < 0)
                return -1;
@@ -3389,12 +3391,12 @@ static int stmt_evaluate_set(struct eval_ctx *ctx, struct stmt *stmt)
        if (stmt->set.key->comment != NULL)
                return expr_error(ctx->msgs, stmt->set.key,
                                  "Key expression comments are not supported");
-       if (stmt->set.stmt) {
-               if (stmt_evaluate(ctx, stmt->set.stmt) < 0)
+       list_for_each_entry(this, &stmt->set.stmt_list, list) {
+               if (stmt_evaluate(ctx, this) < 0)
                        return -1;
-               if (!(stmt->set.stmt->flags & STMT_F_STATEFUL))
-                       return stmt_binary_error(ctx, stmt->set.stmt, stmt,
-                                                "meter statement must be stateful");
+               if (!(this->flags & STMT_F_STATEFUL))
+                       return stmt_error(ctx, this,
+                                         "statement must be stateful");
        }
 
        return 0;
@@ -3402,6 +3404,8 @@ static int stmt_evaluate_set(struct eval_ctx *ctx, struct stmt *stmt)
 
 static int stmt_evaluate_map(struct eval_ctx *ctx, struct stmt *stmt)
 {
+       struct stmt *this;
+
        expr_set_context(&ctx->ectx, NULL, 0);
        if (expr_evaluate(ctx, &stmt->map.set) < 0)
                return -1;
@@ -3435,12 +3439,12 @@ static int stmt_evaluate_map(struct eval_ctx *ctx, struct stmt *stmt)
                return expr_error(ctx->msgs, stmt->map.data,
                                  "Data expression comments are not supported");
 
-       if (stmt->map.stmt) {
-               if (stmt_evaluate(ctx, stmt->map.stmt) < 0)
+       list_for_each_entry(this, &stmt->map.stmt_list, list) {
+               if (stmt_evaluate(ctx, this) < 0)
                        return -1;
-               if (!(stmt->map.stmt->flags & STMT_F_STATEFUL))
-                       return stmt_binary_error(ctx, stmt->map.stmt, stmt,
-                                                "meter statement must be stateful");
+               if (!(this->flags & STMT_F_STATEFUL))
+                       return stmt_error(ctx, this,
+                                         "statement must be stateful");
        }
 
        return 0;
index 8098b9746c95309fc981993c87d101e99aee72c3..ab0290926eaf3246074c3a73c84dd00cde03a9ec 100644 (file)
@@ -1132,6 +1132,7 @@ key_end:
                key = bitmask_expr_to_binops(key);
 
        expr = set_elem_expr_alloc(&netlink_location, key);
+
        if (nftnl_set_elem_is_set(nlse, NFTNL_SET_ELEM_TIMEOUT))
                expr->timeout    = nftnl_set_elem_get_u64(nlse, NFTNL_SET_ELEM_TIMEOUT);
        if (nftnl_set_elem_is_set(nlse, NFTNL_SET_ELEM_EXPIRATION))
index 8b06c4c0985f524103493fdab8e901d3dd07854a..73150722841193b19e460f3880bd2108cd4d9672 100644 (file)
@@ -1489,17 +1489,47 @@ static void netlink_parse_queue(struct netlink_parse_ctx *ctx,
        ctx->stmt = stmt;
 }
 
+struct dynset_parse_ctx {
+       struct netlink_parse_ctx        *nlctx;
+       const struct location           *loc;
+       struct list_head                stmt_list;
+};
+
+static int dynset_parse_expressions(struct nftnl_expr *e, void *data)
+{
+       struct dynset_parse_ctx *dynset_parse_ctx = data;
+       struct netlink_parse_ctx *ctx = dynset_parse_ctx->nlctx;
+       const struct location *loc = dynset_parse_ctx->loc;
+       struct stmt *stmt;
+
+       if (netlink_parse_expr(e, ctx) < 0 || !ctx->stmt) {
+               netlink_error(ctx, loc, "Could not parse dynset stmt");
+               return -1;
+       }
+       stmt = ctx->stmt;
+
+       list_add_tail(&stmt->list, &dynset_parse_ctx->stmt_list);
+
+       return 0;
+}
+
 static void netlink_parse_dynset(struct netlink_parse_ctx *ctx,
                                 const struct location *loc,
                                 const struct nftnl_expr *nle)
 {
+       struct dynset_parse_ctx dynset_parse_ctx = {
+               .nlctx  = ctx,
+               .loc    = loc,
+       };
        struct expr *expr, *expr_data = NULL;
        enum nft_registers sreg, sreg_data;
+       struct stmt *stmt, *dstmt, *next;
        const struct nftnl_expr *dnle;
-       struct stmt *stmt, *dstmt;
        struct set *set;
        const char *name;
 
+       init_list_head(&dynset_parse_ctx.stmt_list);
+
        name = nftnl_expr_get_str(nle, NFTNL_EXPR_DYNSET_SET_NAME);
        set  = set_lookup(ctx->table, name);
        if (set == NULL)
@@ -1523,16 +1553,25 @@ static void netlink_parse_dynset(struct netlink_parse_ctx *ctx,
        expr = set_elem_expr_alloc(&expr->location, expr);
        expr->timeout = nftnl_expr_get_u64(nle, NFTNL_EXPR_DYNSET_TIMEOUT);
 
-       dstmt = NULL;
-       dnle = nftnl_expr_get(nle, NFTNL_EXPR_DYNSET_EXPR, NULL);
-       if (dnle != NULL) {
-               if (netlink_parse_expr(dnle, ctx) < 0)
-                       goto out_err;
-               if (ctx->stmt == NULL) {
-                       netlink_error(ctx, loc, "Could not parse dynset stmt");
-                       goto out_err;
+       if (nftnl_expr_is_set(nle, NFTNL_EXPR_DYNSET_EXPR)) {
+               dstmt = NULL;
+               dnle = nftnl_expr_get(nle, NFTNL_EXPR_DYNSET_EXPR, NULL);
+               if (dnle != NULL) {
+                       if (netlink_parse_expr(dnle, ctx) < 0)
+                               goto out_err;
+                       if (ctx->stmt == NULL) {
+                               netlink_error(ctx, loc,
+                                             "Could not parse dynset stmt");
+                               goto out_err;
+                       }
+                       dstmt = ctx->stmt;
+                       list_add_tail(&dstmt->list,
+                                     &dynset_parse_ctx.stmt_list);
                }
-               dstmt = ctx->stmt;
+       } else if (nftnl_expr_is_set(nle, NFTNL_EXPR_DYNSET_EXPRESSIONS)) {
+               if (nftnl_expr_expr_foreach(nle, dynset_parse_expressions,
+                                           &dynset_parse_ctx) < 0)
+                       goto out_err;
        }
 
        if (nftnl_expr_is_set(nle, NFTNL_EXPR_DYNSET_SREG_DATA)) {
@@ -1546,27 +1585,34 @@ static void netlink_parse_dynset(struct netlink_parse_ctx *ctx,
                stmt->map.set   = set_ref_expr_alloc(loc, set);
                stmt->map.key   = expr;
                stmt->map.data  = expr_data;
-               stmt->map.stmt  = dstmt;
                stmt->map.op    = nftnl_expr_get_u32(nle, NFTNL_EXPR_DYNSET_OP);
+               list_splice_tail(&dynset_parse_ctx.stmt_list,
+                                &stmt->map.stmt_list);
        } else {
-               if (dstmt != NULL && set->flags & NFT_SET_ANONYMOUS) {
+               if (!list_empty(&dynset_parse_ctx.stmt_list) &&
+                   set->flags & NFT_SET_ANONYMOUS) {
                        stmt = meter_stmt_alloc(loc);
                        stmt->meter.set  = set_ref_expr_alloc(loc, set);
                        stmt->meter.key  = expr;
-                       stmt->meter.stmt = dstmt;
+                       stmt->meter.stmt = list_first_entry(&dynset_parse_ctx.stmt_list,
+                                                           struct stmt, list);
                        stmt->meter.size = set->desc.size;
                } else {
                        stmt = set_stmt_alloc(loc);
                        stmt->set.set   = set_ref_expr_alloc(loc, set);
                        stmt->set.op    = nftnl_expr_get_u32(nle, NFTNL_EXPR_DYNSET_OP);
                        stmt->set.key   = expr;
-                       stmt->set.stmt  = dstmt;
+                       list_splice_tail(&dynset_parse_ctx.stmt_list,
+                                        &stmt->set.stmt_list);
                }
        }
 
        ctx->stmt = stmt;
        return;
 out_err:
+       list_for_each_entry_safe(dstmt, next, &dynset_parse_ctx.stmt_list, list)
+               stmt_free(dstmt);
+
        xfree(expr);
 }
 
index 05af8bb1b4854a6d2c4025141ba3f5fdf9ec7580..09d0c61cfcc066f304ad7bd9c073ad9b45b4c55e 100644 (file)
@@ -1397,8 +1397,10 @@ static void netlink_gen_set_stmt(struct netlink_linearize_ctx *ctx,
                                 const struct stmt *stmt)
 {
        struct set *set = stmt->meter.set->set;
-       struct nftnl_expr *nle;
        enum nft_registers sreg_key;
+       struct nftnl_expr *nle;
+       int num_stmts = 0;
+       struct stmt *this;
 
        sreg_key = get_register(ctx, stmt->set.key->key);
        netlink_gen_expr(ctx, stmt->set.key->key, sreg_key);
@@ -1414,9 +1416,20 @@ static void netlink_gen_set_stmt(struct netlink_linearize_ctx *ctx,
        nftnl_expr_set_u32(nle, NFTNL_EXPR_DYNSET_SET_ID, set->handle.set_id);
        nft_rule_add_expr(ctx, nle, &stmt->location);
 
-       if (stmt->set.stmt)
-               nftnl_expr_set(nle, NFTNL_EXPR_DYNSET_EXPR,
-                              netlink_gen_stmt_stateful(stmt->set.stmt), 0);
+       list_for_each_entry(this, &stmt->set.stmt_list, list)
+               num_stmts++;
+
+       if (num_stmts == 1) {
+               list_for_each_entry(this, &stmt->set.stmt_list, list) {
+                       nftnl_expr_set(nle, NFTNL_EXPR_DYNSET_EXPR,
+                                      netlink_gen_stmt_stateful(this), 0);
+               }
+       } else if (num_stmts > 1) {
+               list_for_each_entry(this, &stmt->set.stmt_list, list) {
+                       nftnl_expr_add_expr(nle, NFTNL_EXPR_DYNSET_EXPRESSIONS,
+                                           netlink_gen_stmt_stateful(this));
+               }
+       }
 }
 
 static void netlink_gen_map_stmt(struct netlink_linearize_ctx *ctx,
@@ -1426,6 +1439,8 @@ static void netlink_gen_map_stmt(struct netlink_linearize_ctx *ctx,
        enum nft_registers sreg_data;
        enum nft_registers sreg_key;
        struct nftnl_expr *nle;
+       int num_stmts = 0;
+       struct stmt *this;
 
        sreg_key = get_register(ctx, stmt->map.key);
        netlink_gen_expr(ctx, stmt->map.key, sreg_key);
@@ -1443,12 +1458,22 @@ static void netlink_gen_map_stmt(struct netlink_linearize_ctx *ctx,
        nftnl_expr_set_u32(nle, NFTNL_EXPR_DYNSET_OP, stmt->map.op);
        nftnl_expr_set_str(nle, NFTNL_EXPR_DYNSET_SET_NAME, set->handle.set.name);
        nftnl_expr_set_u32(nle, NFTNL_EXPR_DYNSET_SET_ID, set->handle.set_id);
+       nft_rule_add_expr(ctx, nle, &stmt->location);
 
-       if (stmt->map.stmt)
-               nftnl_expr_set(nle, NFTNL_EXPR_DYNSET_EXPR,
-                              netlink_gen_stmt_stateful(stmt->map.stmt), 0);
+       list_for_each_entry(this, &stmt->map.stmt_list, list)
+               num_stmts++;
 
-       nft_rule_add_expr(ctx, nle, &stmt->location);
+       if (num_stmts == 1) {
+               list_for_each_entry(this, &stmt->map.stmt_list, list) {
+                       nftnl_expr_set(nle, NFTNL_EXPR_DYNSET_EXPR,
+                                      netlink_gen_stmt_stateful(this), 0);
+               }
+       } else if (num_stmts > 1) {
+               list_for_each_entry(this, &stmt->map.stmt_list, list) {
+                       nftnl_expr_add_expr(nle, NFTNL_EXPR_DYNSET_EXPRESSIONS,
+                                           netlink_gen_stmt_stateful(this));
+               }
+       }
 }
 
 static void netlink_gen_meter_stmt(struct netlink_linearize_ctx *ctx,
index 08aadaa32a866d61bf77f7636b4cf150bb48dbbd..673ce4ad40802765ede8167a6df5252375b17fc7 100644 (file)
@@ -624,8 +624,8 @@ int nft_lex(void *, void *, void *);
 %type <obj>                    obj_block_alloc counter_block quota_block ct_helper_block ct_timeout_block ct_expect_block limit_block secmark_block synproxy_block
 %destructor { obj_free($$); }  obj_block_alloc
 
-%type <list>                   stmt_list
-%destructor { stmt_list_free($$); xfree($$); } stmt_list
+%type <list>                   stmt_list stateful_stmt_list
+%destructor { stmt_list_free($$); xfree($$); } stmt_list stateful_stmt_list
 %type <stmt>                   stmt match_stmt verdict_stmt
 %destructor { stmt_free($$); } stmt match_stmt verdict_stmt
 %type <stmt>                   counter_stmt counter_stmt_alloc stateful_stmt
@@ -2656,6 +2656,19 @@ stmt_list                :       stmt
                        }
                        ;
 
+stateful_stmt_list     :       stateful_stmt
+                       {
+                               $$ = xmalloc(sizeof(*$$));
+                               init_list_head($$);
+                               list_add_tail(&$1->list, $$);
+                       }
+                       |       stateful_stmt_list      stateful_stmt
+                       {
+                               $$ = $1;
+                               list_add_tail(&$2->list, $1);
+                       }
+                       ;
+
 stateful_stmt          :       counter_stmt
                        |       limit_stmt
                        |       quota_stmt
@@ -3675,13 +3688,14 @@ set_stmt                :       SET     set_stmt_op     set_elem_expr_stmt      set_ref_expr
                                $$->set.key = $4;
                                $$->set.set = $2;
                        }
-                       |       set_stmt_op     set_ref_expr '{' set_elem_expr_stmt     stateful_stmt   '}'
+                       |       set_stmt_op     set_ref_expr '{' set_elem_expr_stmt     stateful_stmt_list      '}'
                        {
                                $$ = set_stmt_alloc(&@$);
                                $$->set.op  = $1;
                                $$->set.key = $4;
                                $$->set.set = $2;
-                               $$->set.stmt = $5;
+                               list_splice_tail($5, &$$->set.stmt_list);
+                               free($5);
                        }
                        ;
 
@@ -3698,14 +3712,15 @@ map_stmt                :       set_stmt_op     set_ref_expr '{' set_elem_expr_stmt     COLON   set_elem_expr_
                                $$->map.data = $6;
                                $$->map.set = $2;
                        }
-                       |       set_stmt_op     set_ref_expr '{' set_elem_expr_stmt     stateful_stmt COLON     set_elem_expr_stmt      '}'
+                       |       set_stmt_op     set_ref_expr '{' set_elem_expr_stmt     stateful_stmt_list      COLON   set_elem_expr_stmt      '}'
                        {
                                $$ = map_stmt_alloc(&@$);
                                $$->map.op  = $1;
                                $$->map.key = $4;
                                $$->map.data = $7;
-                               $$->map.stmt = $5;
                                $$->map.set = $2;
+                               list_splice_tail($5, &$$->map.stmt_list);
+                               free($5);
                        }
                        ;
 
index 6fe8e9d9beb4665b964a62e2ddf6231b9a87b0f5..39020857ae9c3e8df24e31230653dbc37fe67450 100644 (file)
@@ -732,15 +732,16 @@ const char * const set_stmt_op_names[] = {
 static void set_stmt_print(const struct stmt *stmt, struct output_ctx *octx)
 {
        unsigned int flags = octx->flags;
+       struct stmt *this;
 
        nft_print(octx, "%s ", set_stmt_op_names[stmt->set.op]);
        expr_print(stmt->set.set, octx);
        nft_print(octx, " { ");
        expr_print(stmt->set.key, octx);
-       if (stmt->set.stmt) {
+       list_for_each_entry(this, &stmt->set.stmt_list, list) {
                nft_print(octx, " ");
                octx->flags |= NFT_CTX_OUTPUT_STATELESS;
-               stmt_print(stmt->set.stmt, octx);
+               stmt_print(this, octx);
                octx->flags = flags;
        }
        nft_print(octx, " }");
@@ -748,9 +749,12 @@ static void set_stmt_print(const struct stmt *stmt, struct output_ctx *octx)
 
 static void set_stmt_destroy(struct stmt *stmt)
 {
+       struct stmt *this, *next;
+
        expr_free(stmt->set.key);
        expr_free(stmt->set.set);
-       stmt_free(stmt->set.stmt);
+       list_for_each_entry_safe(this, next, &stmt->set.stmt_list, list)
+               stmt_free(this);
 }
 
 static const struct stmt_ops set_stmt_ops = {
@@ -763,21 +767,27 @@ static const struct stmt_ops set_stmt_ops = {
 
 struct stmt *set_stmt_alloc(const struct location *loc)
 {
-       return stmt_alloc(loc, &set_stmt_ops);
+       struct stmt *stmt;
+
+       stmt = stmt_alloc(loc, &set_stmt_ops);
+       init_list_head(&stmt->set.stmt_list);
+
+       return stmt;
 }
 
 static void map_stmt_print(const struct stmt *stmt, struct output_ctx *octx)
 {
        unsigned int flags = octx->flags;
+       struct stmt *this;
 
        nft_print(octx, "%s ", set_stmt_op_names[stmt->map.op]);
        expr_print(stmt->map.set, octx);
        nft_print(octx, " { ");
        expr_print(stmt->map.key, octx);
-       if (stmt->map.stmt) {
+       list_for_each_entry(this, &stmt->map.stmt_list, list) {
                nft_print(octx, " ");
                octx->flags |= NFT_CTX_OUTPUT_STATELESS;
-               stmt_print(stmt->map.stmt, octx);
+               stmt_print(this, octx);
                octx->flags = flags;
        }
        nft_print(octx, " : ");
@@ -787,10 +797,13 @@ static void map_stmt_print(const struct stmt *stmt, struct output_ctx *octx)
 
 static void map_stmt_destroy(struct stmt *stmt)
 {
+       struct stmt *this, *next;
+
        expr_free(stmt->map.key);
        expr_free(stmt->map.data);
        expr_free(stmt->map.set);
-       stmt_free(stmt->map.stmt);
+       list_for_each_entry_safe(this, next, &stmt->map.stmt_list, list)
+               stmt_free(this);
 }
 
 static const struct stmt_ops map_stmt_ops = {
@@ -802,7 +815,12 @@ static const struct stmt_ops map_stmt_ops = {
 
 struct stmt *map_stmt_alloc(const struct location *loc)
 {
-       return stmt_alloc(loc, &map_stmt_ops);
+       struct stmt *stmt;
+
+       stmt = stmt_alloc(loc, &map_stmt_ops);
+       init_list_head(&stmt->map.stmt_list);
+
+       return stmt;
 }
 
 static void dup_stmt_print(const struct stmt *stmt, struct output_ctx *octx)