]> git.ipfire.org Git - thirdparty/nftables.git/commitdiff
src: replace struct stmt_ops by type field in struct stmt
authorPablo Neira Ayuso <pablo@netfilter.org>
Mon, 17 Mar 2025 22:19:49 +0000 (23:19 +0100)
committerPablo Neira Ayuso <pablo@netfilter.org>
Tue, 18 Mar 2025 15:37:47 +0000 (16:37 +0100)
Shrink struct stmt in 8 bytes.

__stmt_ops_by_type() provides an operation for STMT_INVALID since this
is required by -o/--optimize.

There are many checks for stmt->ops->type, which is the most accessed
field, that can be trivially replaced.

BUG() uses statement type enum instead of name.

Similar to:

 68e76238749f ("src: expr: add and use expr_name helper").
 72931553828a ("src: expr: add expression etype")
 2cc91e6198e7 ("src: expr: add and use internal expr_ops helper")

Acked-by: Florian Westphal <fw@strlen.de>
Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
17 files changed:
include/ct.h
include/exthdr.h
include/meta.h
include/payload.h
include/statement.h
src/ct.c
src/evaluate.c
src/exthdr.c
src/json.c
src/meta.c
src/netlink_delinearize.c
src/netlink_linearize.c
src/optimize.c
src/parser_bison.y
src/payload.c
src/rule.c
src/statement.c

index 0a705fd06ee14b0a8fc2da42ac0ddb49dd5e00b5..bb9193d8fc5070b01cf56149665e5e19fd2460aa 100644 (file)
@@ -42,4 +42,8 @@ extern const struct datatype ct_status_type;
 extern const struct datatype ct_label_type;
 extern const struct datatype ct_event_type;
 
+extern const struct stmt_ops ct_stmt_ops;
+extern const struct stmt_ops notrack_stmt_ops;
+extern const struct stmt_ops flow_offload_stmt_ops;
+
 #endif /* NFTABLES_CT_H */
index 084daba5358f7416790fd17589f7fe2c96deaef6..98494e4d5bf750d0f7813def9d1b348e80434630 100644 (file)
@@ -117,4 +117,6 @@ extern const struct exthdr_desc exthdr_dst;
 extern const struct exthdr_desc exthdr_mh;
 extern const struct datatype mh_type_type;
 
+extern const struct stmt_ops exthdr_stmt_ops;
+
 #endif /* NFTABLES_EXTHDR_H */
index af2d772bb6a0d0d9057635d94316f72d794d9af7..84937ba3a1fe70019f3cec26dc6915b4fd209697 100644 (file)
@@ -47,4 +47,6 @@ extern const struct datatype day_type;
 
 bool lhs_is_meta_hour(const struct expr *meta);
 
+extern const struct stmt_ops meta_stmt_ops;
+
 #endif /* NFTABLES_META_H */
index 20304252e3f65e063249387b5cf0654dc68dc209..e14fc0f24477f18d5596a87d42f2ad944c3269b5 100644 (file)
@@ -74,4 +74,6 @@ bool payload_expr_cmp(const struct expr *e1, const struct expr *e2);
 
 const struct proto_desc *find_proto_desc(const struct nftnl_udata *ud);
 
+extern const struct stmt_ops payload_stmt_ops;
+
 #endif /* NFTABLES_PAYLOAD_H */
index 9376911bb377938e9631383b39fd80f6517d0478..e8724dde63d0428192f7397ef6110a5b08763eae 100644 (file)
@@ -372,16 +372,16 @@ enum stmt_flags {
  * struct stmt
  *
  * @list:      rule list node
- * @ops:       statement ops
  * @location:  location where the statement was defined
  * @flags:     statement flags
+ * @type:      statement type
  * @union:     type specific data
  */
 struct stmt {
        struct list_head                list;
-       const struct stmt_ops           *ops;
        struct location                 location;
        enum stmt_flags                 flags;
+       enum stmt_types                 type:8;
 
        union {
                struct expr             *expr;
@@ -420,6 +420,8 @@ int stmt_dependency_evaluate(struct eval_ctx *ctx, struct stmt *stmt);
 extern void stmt_free(struct stmt *stmt);
 extern void stmt_list_free(struct list_head *list);
 extern void stmt_print(const struct stmt *stmt, struct output_ctx *octx);
+const char *stmt_name(const struct stmt *stmt);
+const struct stmt_ops *stmt_ops(const struct stmt *stmt);
 
 const char *get_rate(uint64_t byte_rate, uint64_t *rate);
 const char *get_unit(uint64_t u);
index 6793464859cade676ec457ca79a792a49c8129ad..4d71a4b0103b9080be4e0af959364e2467164523 100644 (file)
--- a/src/ct.c
+++ b/src/ct.c
@@ -530,7 +530,7 @@ static void ct_stmt_destroy(struct stmt *stmt)
        expr_free(stmt->ct.expr);
 }
 
-static const struct stmt_ops ct_stmt_ops = {
+const struct stmt_ops ct_stmt_ops = {
        .type           = STMT_CT,
        .name           = "ct",
        .print          = ct_stmt_print,
@@ -557,7 +557,7 @@ static void notrack_stmt_print(const struct stmt *stmt, struct output_ctx *octx)
        nft_print(octx, "notrack");
 }
 
-static const struct stmt_ops notrack_stmt_ops = {
+const struct stmt_ops notrack_stmt_ops = {
        .type           = STMT_NOTRACK,
        .name           = "notrack",
        .print          = notrack_stmt_print,
@@ -580,7 +580,7 @@ static void flow_offload_stmt_destroy(struct stmt *stmt)
        free_const(stmt->flow.table_name);
 }
 
-static const struct stmt_ops flow_offload_stmt_ops = {
+const struct stmt_ops flow_offload_stmt_ops = {
        .type           = STMT_FLOW_OFFLOAD,
        .name           = "flow_offload",
        .print          = flow_offload_stmt_print,
index f1f7ddaab991ab29bdb77352fb8f532cef4d5356..95b9b3d547d99dc41f0301a2da0933a571e8c77c 100644 (file)
@@ -74,9 +74,9 @@ static int __fmtstring(3, 4) set_error(struct eval_ctx *ctx,
        return -1;
 }
 
-static const char *stmt_name(const struct stmt *stmt)
+const char *stmt_name(const struct stmt *stmt)
 {
-       switch (stmt->ops->type) {
+       switch (stmt->type) {
        case STMT_NAT:
                switch (stmt->nat.type) {
                case NFT_NAT_SNAT:
@@ -93,7 +93,7 @@ static const char *stmt_name(const struct stmt *stmt)
                break;
        }
 
-       return stmt->ops->name;
+       return stmt_ops(stmt)->name;
 }
 
 static int stmt_error_range(struct eval_ctx *ctx, const struct stmt *stmt, const struct expr *e)
@@ -573,7 +573,7 @@ static int expr_evaluate_bits(struct eval_ctx *ctx, struct expr **exprp)
         * require the transformations that are needed for payload matching,
         * skip this.
         */
-       if (ctx->stmt && ctx->stmt->ops->type == STMT_PAYLOAD)
+       if (ctx->stmt && ctx->stmt->type == STMT_PAYLOAD)
                return 0;
 
        switch (expr->etype) {
@@ -790,7 +790,7 @@ static int stmt_dep_conflict(struct eval_ctx *ctx, const struct stmt *nstmt)
                if (stmt == nstmt)
                        break;
 
-               if (stmt->ops->type != STMT_EXPRESSION ||
+               if (stmt->type != STMT_EXPRESSION ||
                    stmt->expr->etype != EXPR_RELATIONAL ||
                    stmt->expr->right->etype != EXPR_VALUE ||
                    stmt->expr->left->etype != EXPR_PAYLOAD ||
@@ -1841,13 +1841,13 @@ static int __expr_evaluate_set_elem(struct eval_ctx *ctx, struct expr *elem)
                set_stmt = list_first_entry(&set->stmt_list, struct stmt, list);
 
                list_for_each_entry(elem_stmt, &elem->stmt_list, list) {
-                       if (set_stmt->ops != elem_stmt->ops) {
+                       if (set_stmt->type != elem_stmt->type) {
                                return stmt_error(ctx, elem_stmt,
                                                  "statement mismatch, element expects %s, "
                                                  "but %s has type %s",
-                                                 elem_stmt->ops->name,
+                                                 stmt_name(elem_stmt),
                                                  set_is_map(set->flags) ? "map" : "set",
-                                                 set_stmt->ops->name);
+                                                 stmt_name(set_stmt));
                        }
                        set_stmt = list_next_entry(set_stmt, list);
                }
@@ -4126,7 +4126,7 @@ static int stmt_evaluate_l3proto(struct eval_ctx *ctx,
                                         "conflicting protocols specified: %s vs. %s. You must specify ip or ip6 family in %s statement",
                                         pctx->protocol[PROTO_BASE_NETWORK_HDR].desc->name,
                                         family2str(family),
-                                        stmt->ops->name);
+                                        stmt_name(stmt));
        return 0;
 }
 
@@ -4854,7 +4854,7 @@ int stmt_evaluate(struct eval_ctx *ctx, struct stmt *stmt)
        if (ctx->nft->debug_mask & NFT_DEBUG_EVALUATION) {
                struct error_record *erec;
                erec = erec_create(EREC_INFORMATIONAL, &stmt->location,
-                                  "Evaluate %s", stmt->ops->name);
+                                  "Evaluate %s", stmt_name(stmt));
                erec_print(&ctx->nft->output, erec, ctx->nft->debug_mask);
                stmt_print(stmt, &ctx->nft->output);
                nft_print(&ctx->nft->output, "\n\n");
@@ -4863,7 +4863,7 @@ int stmt_evaluate(struct eval_ctx *ctx, struct stmt *stmt)
 
        ctx->stmt_len = 0;
 
-       switch (stmt->ops->type) {
+       switch (stmt->type) {
        case STMT_CONNLIMIT:
        case STMT_COUNTER:
        case STMT_LAST:
@@ -4913,7 +4913,7 @@ int stmt_evaluate(struct eval_ctx *ctx, struct stmt *stmt)
        case STMT_OPTSTRIP:
                return stmt_evaluate_optstrip(ctx, stmt);
        default:
-               BUG("unknown statement type %s\n", stmt->ops->name);
+               BUG("unknown statement type %d\n", stmt->type);
        }
 }
 
index 1438d7e2d2dce5282561b83afbb6de3a0e8b3e11..c7d876a45aabd6e2ac2fa5462bc8baa78222100e 100644 (file)
@@ -269,7 +269,7 @@ static void exthdr_stmt_destroy(struct stmt *stmt)
        expr_free(stmt->exthdr.val);
 }
 
-static const struct stmt_ops exthdr_stmt_ops = {
+const struct stmt_ops exthdr_stmt_ops = {
        .type           = STMT_EXTHDR,
        .name           = "exthdr",
        .print          = exthdr_stmt_print,
index 64a6888f9e0ac4bc49d2896c2c1a1fa0efcf31fe..96413d70895ae6ddbab99888f4f44606a500de7a 100644 (file)
@@ -109,19 +109,20 @@ static json_t *set_key_dtype_json(const struct set *set,
 
 static json_t *stmt_print_json(const struct stmt *stmt, struct output_ctx *octx)
 {
+       const struct stmt_ops *ops = stmt_ops(stmt);
        char buf[1024];
        FILE *fp;
 
-       if (stmt->ops->json)
-               return stmt->ops->json(stmt, octx);
+       if (ops->json)
+               return ops->json(stmt, octx);
 
        fprintf(stderr, "warning: stmt ops %s have no json callback\n",
-               stmt->ops->name);
+               ops->name);
 
        fp = octx->output_fp;
        octx->output_fp = fmemopen(buf, 1024, "w");
 
-       stmt->ops->print(stmt, octx);
+       ops->print(stmt, octx);
 
        fclose(octx->output_fp);
        octx->output_fp = fp;
index a17bacf07d0e79c06053b0152e58cb0706da782a..1010209d3152526afa529a1ce818e81ca2ccfbf4 100644 (file)
@@ -952,7 +952,7 @@ static void meta_stmt_destroy(struct stmt *stmt)
        expr_free(stmt->meta.expr);
 }
 
-static const struct stmt_ops meta_stmt_ops = {
+const struct stmt_ops meta_stmt_ops = {
        .type           = STMT_META,
        .name           = "meta",
        .print          = meta_stmt_print,
index ae14065c00d6cae9616c62c8b63d4adaec2ca1a3..ae1ee53f6e7c6a2d5452c83baebe0bae1b36dd2d 100644 (file)
@@ -3121,7 +3121,7 @@ static void stmt_expr_postprocess(struct rule_pp_ctx *ctx)
        expr_postprocess(ctx, &ctx->stmt->expr);
 
        if (dl->pdctx.prev && ctx->stmt &&
-           ctx->stmt->ops->type == dl->pdctx.prev->ops->type &&
+           ctx->stmt->type == dl->pdctx.prev->type &&
            expr_may_merge_range(ctx->stmt->expr, dl->pdctx.prev->expr, &op))
                expr_postprocess_range(ctx, op);
 }
@@ -3404,7 +3404,7 @@ static struct dl_proto_ctx *rule_update_dl_proto_ctx(struct rule_pp_ctx *rctx)
        const struct stmt *stmt = rctx->stmt;
        bool inner = false;
 
-       switch (stmt->ops->type) {
+       switch (stmt->type) {
        case STMT_EXPRESSION:
                if (has_inner_desc(stmt->expr->left))
                        inner = true;
@@ -3438,7 +3438,7 @@ static void rule_parse_postprocess(struct netlink_parse_ctx *ctx, struct rule *r
        proto_ctx_init(&rctx._dl[1].pctx, NFPROTO_BRIDGE, ctx->debug_mask, true);
 
        list_for_each_entry_safe(stmt, next, &rule->stmts, list) {
-               enum stmt_types type = stmt->ops->type;
+               enum stmt_types type = stmt->type;
 
                rctx.stmt = stmt;
                dl = rule_update_dl_proto_ctx(&rctx);
index 598ddfab5827bc7baf5e1e570fd1b328cb23c73c..5f73183bf19a03f3e05a2d8e1101bd77cc2ff6ca 100644 (file)
@@ -1046,7 +1046,7 @@ static struct nftnl_expr *netlink_gen_last_stmt(const struct stmt *stmt)
 
 struct nftnl_expr *netlink_gen_stmt_stateful(const struct stmt *stmt)
 {
-       switch (stmt->ops->type) {
+       switch (stmt->type) {
        case STMT_CONNLIMIT:
                return netlink_gen_connlimit_stmt(stmt);
        case STMT_COUNTER:
@@ -1058,7 +1058,7 @@ struct nftnl_expr *netlink_gen_stmt_stateful(const struct stmt *stmt)
        case STMT_LAST:
                return netlink_gen_last_stmt(stmt);
        default:
-               BUG("unknown stateful statement type %s\n", stmt->ops->name);
+               BUG("unknown stateful statement type %d\n", stmt->type);
        }
 }
 
@@ -1694,7 +1694,7 @@ static void netlink_gen_stmt(struct netlink_linearize_ctx *ctx,
 {
        struct nftnl_expr *nle;
 
-       switch (stmt->ops->type) {
+       switch (stmt->type) {
        case STMT_EXPRESSION:
                return netlink_gen_expr(ctx, stmt->expr, NFT_REG_VERDICT);
        case STMT_VERDICT:
@@ -1748,7 +1748,7 @@ static void netlink_gen_stmt(struct netlink_linearize_ctx *ctx,
        case STMT_OPTSTRIP:
                return netlink_gen_optstrip_stmt(ctx, stmt);
        default:
-               BUG("unknown statement type %s\n", stmt->ops->name);
+               BUG("unknown statement type %d\n", stmt->type);
        }
 }
 
index 230fe4a23de378e90c84940d4785b8465f8f4b62..05d8084b2a47564de375dbee72e6940c7a306994 100644 (file)
@@ -164,10 +164,10 @@ static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b,
 {
        struct expr *expr_a, *expr_b;
 
-       if (stmt_a->ops->type != stmt_b->ops->type)
+       if (stmt_a->type != stmt_b->type)
                return false;
 
-       switch (stmt_a->ops->type) {
+       switch (stmt_a->type) {
        case STMT_EXPRESSION:
                expr_a = stmt_a->expr;
                expr_b = stmt_b->expr;
@@ -324,7 +324,7 @@ static bool stmt_verdict_eq(const struct stmt *stmt_a, const struct stmt *stmt_b
 {
        struct expr *expr_a, *expr_b;
 
-       assert (stmt_a->ops->type == STMT_VERDICT);
+       assert (stmt_a->type == STMT_VERDICT);
 
        expr_a = stmt_a->expr;
        expr_b = stmt_b->expr;
@@ -345,14 +345,14 @@ static bool stmt_type_find(struct optimize_ctx *ctx, const struct stmt *stmt)
        uint32_t i;
 
        for (i = 0; i < ctx->num_stmts; i++) {
-               if (ctx->stmt[i]->ops->type == STMT_INVALID)
+               if (ctx->stmt[i]->type == STMT_INVALID)
                        unsupported_exists = true;
 
                if (__stmt_type_eq(stmt, ctx->stmt[i], false))
                        return true;
        }
 
-       switch (stmt->ops->type) {
+       switch (stmt->type) {
        case STMT_EXPRESSION:
        case STMT_VERDICT:
        case STMT_COUNTER:
@@ -371,13 +371,9 @@ static bool stmt_type_find(struct optimize_ctx *ctx, const struct stmt *stmt)
        return false;
 }
 
-static struct stmt_ops unsupported_stmt_ops = {
-       .type   = STMT_INVALID,
-       .name   = "unsupported",
-};
-
 static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
 {
+       const struct stmt_ops *ops;
        struct stmt *stmt, *clone;
 
        list_for_each_entry(stmt, &rule->stmts, list) {
@@ -387,16 +383,17 @@ static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
                /* No refcounter available in statement objects, clone it to
                 * to store in the array of selectors.
                 */
-               clone = stmt_alloc(&internal_location, stmt->ops);
-               switch (stmt->ops->type) {
+               ops = stmt_ops(stmt);
+               clone = stmt_alloc(&internal_location, ops);
+               switch (stmt->type) {
                case STMT_EXPRESSION:
                        if (stmt->expr->op != OP_IMPLICIT &&
                            stmt->expr->op != OP_EQ) {
-                               clone->ops = &unsupported_stmt_ops;
+                               clone->type = STMT_INVALID;
                                break;
                        }
                        if (stmt->expr->left->etype == EXPR_CONCAT) {
-                               clone->ops = &unsupported_stmt_ops;
+                               clone->type = STMT_INVALID;
                                break;
                        }
                        /* fall-through */
@@ -418,7 +415,7 @@ static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
                            (stmt->nat.proto &&
                             (stmt->nat.proto->etype == EXPR_MAP ||
                              stmt->nat.proto->etype == EXPR_VARIABLE))) {
-                               clone->ops = &unsupported_stmt_ops;
+                               clone->type = STMT_INVALID;
                                break;
                        }
                        clone->nat.type = stmt->nat.type;
@@ -438,7 +435,7 @@ static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
                        clone->reject.family = stmt->reject.family;
                        break;
                default:
-                       clone->ops = &unsupported_stmt_ops;
+                       clone->type = STMT_INVALID;
                        break;
                }
 
@@ -455,7 +452,7 @@ static int unsupported_in_stmt_matrix(const struct optimize_ctx *ctx)
        uint32_t i;
 
        for (i = 0; i < ctx->num_stmts; i++) {
-               if (ctx->stmt[i]->ops->type == STMT_INVALID)
+               if (ctx->stmt[i]->type == STMT_INVALID)
                        return i;
        }
        /* this should not happen. */
@@ -475,7 +472,7 @@ static int cmd_stmt_find_in_stmt_matrix(struct optimize_ctx *ctx, struct stmt *s
 }
 
 static struct stmt unsupported_stmt = {
-       .ops    = &unsupported_stmt_ops,
+       .type   = STMT_INVALID,
 };
 
 static void rule_build_stmt_matrix_stmts(struct optimize_ctx *ctx,
@@ -502,7 +499,7 @@ static int stmt_verdict_find(const struct optimize_ctx *ctx)
        uint32_t i;
 
        for (i = 0; i < ctx->num_stmts; i++) {
-               if (ctx->stmt[i]->ops->type != STMT_VERDICT)
+               if (ctx->stmt[i]->type != STMT_VERDICT)
                        continue;
 
                return i;
@@ -569,7 +566,7 @@ static void merge_verdict_stmts(const struct optimize_ctx *ctx,
 
        for (i = from + 1; i <= to; i++) {
                stmt_b = ctx->stmt_matrix[i][merge->stmt[0]];
-               switch (stmt_b->ops->type) {
+               switch (stmt_b->type) {
                case STMT_VERDICT:
                        switch (stmt_b->expr->etype) {
                        case EXPR_MAP:
@@ -591,7 +588,7 @@ static void merge_stmts(const struct optimize_ctx *ctx,
 {
        struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]];
 
-       switch (stmt_a->ops->type) {
+       switch (stmt_a->type) {
        case STMT_EXPRESSION:
                merge_expr_stmts(ctx, from, to, merge, stmt_a);
                break;
@@ -762,7 +759,7 @@ static void remove_counter(const struct optimize_ctx *ctx, uint32_t from)
                if (!stmt)
                        continue;
 
-               if (stmt->ops->type == STMT_COUNTER) {
+               if (stmt->type == STMT_COUNTER) {
                        list_del(&stmt->list);
                        stmt_free(stmt);
                }
@@ -780,7 +777,7 @@ static struct stmt *zap_counter(const struct optimize_ctx *ctx, uint32_t from)
                if (!stmt)
                        continue;
 
-               if (stmt->ops->type == STMT_COUNTER) {
+               if (stmt->type == STMT_COUNTER) {
                        list_del(&stmt->list);
                        return stmt;
                }
@@ -937,7 +934,7 @@ static int stmt_nat_type(const struct optimize_ctx *ctx, int from,
                if (!ctx->stmt_matrix[from][j])
                        continue;
 
-               if (ctx->stmt_matrix[from][j]->ops->type == STMT_NAT) {
+               if (ctx->stmt_matrix[from][j]->type == STMT_NAT) {
                        *nat_type = ctx->stmt_matrix[from][j]->nat.type;
                        return 0;
                }
@@ -955,7 +952,7 @@ static int stmt_nat_find(const struct optimize_ctx *ctx, int from)
                return -1;
 
        for (i = 0; i < ctx->num_stmts; i++) {
-               if (ctx->stmt[i]->ops->type != STMT_NAT ||
+               if (ctx->stmt[i]->type != STMT_NAT ||
                    ctx->stmt[i]->nat.type != nat_type)
                        continue;
 
@@ -969,7 +966,7 @@ static struct expr *stmt_nat_expr(struct stmt *nat_stmt)
 {
        struct expr *nat_expr;
 
-       assert(nat_stmt->ops->type == STMT_NAT);
+       assert(nat_stmt->type == STMT_NAT);
 
        if (nat_stmt->nat.proto) {
                if (nat_stmt->nat.addr) {
@@ -1153,7 +1150,7 @@ static uint32_t merge_stmt_type(const struct optimize_ctx *ctx,
                        stmt = ctx->stmt_matrix[i][j];
                        if (!stmt)
                                continue;
-                       if (stmt->ops->type == STMT_NAT) {
+                       if (stmt->type == STMT_NAT) {
                                if ((stmt->nat.type == NFT_NAT_REDIR &&
                                     !stmt->nat.proto) ||
                                    stmt->nat.type == NFT_NAT_MASQ)
@@ -1250,7 +1247,7 @@ static bool stmt_is_mergeable(const struct stmt *stmt)
        if (!stmt)
                return false;
 
-       switch (stmt->ops->type) {
+       switch (stmt->type) {
        case STMT_VERDICT:
                if (stmt->expr->etype == EXPR_MAP)
                        return true;
@@ -1346,7 +1343,7 @@ static int chain_optimize(struct nft_ctx *nft, struct list_head *rules)
                for (m = 0; m < ctx->num_stmts; m++) {
                        if (!ctx->stmt_matrix[i][m])
                                continue;
-                       switch (ctx->stmt_matrix[i][m]->ops->type) {
+                       switch (ctx->stmt_matrix[i][m]->type) {
                        case STMT_EXPRESSION:
                                merge[k].stmt[merge[k].num_stmts++] = m;
                                break;
index e494079d63732dcb423a780d7766088c19e77959..4d4d39342bf75d1c08aa6047337e7ccb820bebc9 100644 (file)
@@ -3306,12 +3306,12 @@ counter_args            :       counter_arg
 
 counter_arg            :       PACKETS                 NUM
                        {
-                               assert($<stmt>0->ops->type == STMT_COUNTER);
+                               assert($<stmt>0->type == STMT_COUNTER);
                                $<stmt>0->counter.packets = $2;
                        }
                        |       BYTES                   NUM
                        {
-                               assert($<stmt>0->ops->type == STMT_COUNTER);
+                               assert($<stmt>0->type == STMT_COUNTER);
                                $<stmt>0->counter.bytes  = $2;
                        }
                        ;
index 50b5acc9a9271d9c9669c0b4b26f554739305393..a38f5bf730d156ec2a4c51fa26f3a16f0b5b46c8 100644 (file)
@@ -378,7 +378,7 @@ static void payload_stmt_destroy(struct stmt *stmt)
        expr_free(stmt->payload.val);
 }
 
-static const struct stmt_ops payload_stmt_ops = {
+const struct stmt_ops payload_stmt_ops = {
        .type           = STMT_PAYLOAD,
        .name           = "payload",
        .print          = payload_stmt_print,
@@ -1198,7 +1198,7 @@ bool stmt_payload_expr_trim(struct stmt *stmt, const struct proto_ctx *pctx)
        mpz_t bitmask, tmp, tmp2;
        unsigned long n;
 
-       assert(stmt->ops->type == STMT_PAYLOAD);
+       assert(stmt->type == STMT_PAYLOAD);
        assert(expr->etype == EXPR_BINOP);
 
        payload = expr->left;
index 9c317934139c3eb3c72c6484bbf13d811d409020..3edfa4715853d9fd4d25dac19b28c0e3f1dc1300 100644 (file)
@@ -494,10 +494,12 @@ void rule_free(struct rule *rule)
 
 void rule_print(const struct rule *rule, struct output_ctx *octx)
 {
+       const struct stmt_ops *ops;
        const struct stmt *stmt;
 
        list_for_each_entry(stmt, &rule->stmts, list) {
-               stmt->ops->print(stmt, octx);
+               ops = stmt_ops(stmt);
+               ops->print(stmt, octx);
                if (!list_is_last(&stmt->list, &rule->stmts))
                        nft_print(octx, " ");
        }
@@ -2741,7 +2743,7 @@ static void stmt_reduce(const struct rule *rule)
                }
 
                /* Must not merge across other statements */
-               if (stmt->ops->type != STMT_EXPRESSION) {
+               if (stmt->type != STMT_EXPRESSION) {
                        if (idx >= 2)
                                payload_do_merge(sa, idx);
                        idx = 0;
index 551cd13fa04ebd4ee03694cb3b29f7ba464c8d10..695b57a6cc650aa57172f725279e8908c4b74e91 100644 (file)
 #include <linux/netfilter/nf_log.h>
 #include <linux/netfilter/nf_synproxy.h>
 
-struct stmt *stmt_alloc(const struct location *loc,
-                       const struct stmt_ops *ops)
+struct stmt *stmt_alloc(const struct location *loc, const struct stmt_ops *ops)
 {
        struct stmt *stmt;
 
        stmt = xzalloc(sizeof(*stmt));
        init_list_head(&stmt->list);
        stmt->location = *loc;
-       stmt->ops      = ops;
+       stmt->type = ops->type;
        return stmt;
 }
 
 void stmt_free(struct stmt *stmt)
 {
+       const struct stmt_ops *ops;
+
        if (stmt == NULL)
                return;
-       if (stmt->ops->destroy)
-               stmt->ops->destroy(stmt);
+
+       ops = stmt_ops(stmt);
+       if (ops->destroy)
+               ops->destroy(stmt);
        free(stmt);
 }
 
@@ -66,7 +69,9 @@ void stmt_list_free(struct list_head *list)
 
 void stmt_print(const struct stmt *stmt, struct output_ctx *octx)
 {
-       stmt->ops->print(stmt, octx);
+       const struct stmt_ops *ops = stmt_ops(stmt);
+
+       ops->print(stmt, octx);
 }
 
 static void expr_stmt_print(const struct stmt *stmt, struct output_ctx *octx)
@@ -1079,3 +1084,59 @@ struct stmt *synproxy_stmt_alloc(const struct location *loc)
 {
        return stmt_alloc(loc, &synproxy_stmt_ops);
 }
+
+/* For src/optimize.c */
+static struct stmt_ops invalid_stmt_ops = {
+       .type   = STMT_INVALID,
+       .name   = "unsupported",
+};
+
+static const struct stmt_ops *__stmt_ops_by_type(enum stmt_types type)
+{
+       switch (type) {
+       case STMT_INVALID: return &invalid_stmt_ops;
+       case STMT_EXPRESSION: return &expr_stmt_ops;
+       case STMT_VERDICT: return &verdict_stmt_ops;
+       case STMT_METER: return &meter_stmt_ops;
+       case STMT_COUNTER: return &counter_stmt_ops;
+       case STMT_PAYLOAD: return &payload_stmt_ops;
+       case STMT_META: return &meta_stmt_ops;
+       case STMT_LIMIT: return &limit_stmt_ops;
+       case STMT_LOG: return &log_stmt_ops;
+       case STMT_REJECT: return &reject_stmt_ops;
+       case STMT_NAT: return &nat_stmt_ops;
+       case STMT_TPROXY: return &tproxy_stmt_ops;
+       case STMT_QUEUE: return &queue_stmt_ops;
+       case STMT_CT: return &ct_stmt_ops;
+       case STMT_SET: return &set_stmt_ops;
+       case STMT_DUP: return &dup_stmt_ops;
+       case STMT_FWD: return &fwd_stmt_ops;
+       case STMT_XT: return &xt_stmt_ops;
+       case STMT_QUOTA: return &quota_stmt_ops;
+       case STMT_NOTRACK: return &notrack_stmt_ops;
+       case STMT_OBJREF: return &objref_stmt_ops;
+       case STMT_EXTHDR: return &exthdr_stmt_ops;
+       case STMT_FLOW_OFFLOAD: return &flow_offload_stmt_ops;
+       case STMT_CONNLIMIT: return &connlimit_stmt_ops;
+       case STMT_MAP: return &map_stmt_ops;
+       case STMT_SYNPROXY: return &synproxy_stmt_ops;
+       case STMT_CHAIN: return &chain_stmt_ops;
+       case STMT_OPTSTRIP: return &optstrip_stmt_ops;
+       case STMT_LAST: return &last_stmt_ops;
+       default:
+               break;
+       }
+
+       return NULL;
+}
+
+const struct stmt_ops *stmt_ops(const struct stmt *stmt)
+{
+       const struct stmt_ops *ops;
+
+       ops = __stmt_ops_by_type(stmt->type);
+       if (!ops)
+               BUG("Unknown statement type %d\n", stmt->type);
+
+       return ops;
+}