]> git.ipfire.org Git - thirdparty/nftables.git/commitdiff
src: add ruleset optimization infrastructure
authorPablo Neira Ayuso <pablo@netfilter.org>
Sun, 2 Jan 2022 20:46:21 +0000 (21:46 +0100)
committerPablo Neira Ayuso <pablo@netfilter.org>
Sat, 15 Jan 2022 17:11:22 +0000 (18:11 +0100)
This patch adds a new -o/--optimize option to enable ruleset
optimization.

You can combine this option with the dry run mode (--check) to review
the proposed ruleset updates without actually loading the ruleset, e.g.

 # nft -c -o -f ruleset.test
 Merging:
 ruleset.nft:16:3-37:           ip daddr 192.168.0.1 counter accept
 ruleset.nft:17:3-37:           ip daddr 192.168.0.2 counter accept
 ruleset.nft:18:3-37:           ip daddr 192.168.0.3 counter accept
 into:
        ip daddr { 192.168.0.1, 192.168.0.2, 192.168.0.3 } counter packets 0 bytes 0 accept

This infrastructure collects the common statements that are used in
rules, then it builds a matrix of rules vs. statements. Then, it looks
for common statements in consecutive rules which allows to merge rules.

This ruleset optimization always performs an implicit dry run to
validate that the original ruleset is correct. Then, on a second pass,
it performs the ruleset optimization and add the rules into the kernel
(unless --check has been specified by the user).

From libnftables perspective, there is a new API to enable
this feature:

  uint32_t nft_ctx_get_optimize(struct nft_ctx *ctx);
  void nft_ctx_set_optimize(struct nft_ctx *ctx, uint32_t flags);

This patch adds support for the first optimization: Collapse a linear
list of rules matching on a single selector into a set as exposed in the
example above.

Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
doc/nft.txt
include/nftables.h
include/nftables/libnftables.h
src/Makefile.am
src/libnftables.c
src/libnftables.map
src/main.c
src/optimize.c [new file with mode: 0644]
tests/shell/testcases/optimizations/dumps/merge_stmts.nft [new file with mode: 0644]
tests/shell/testcases/optimizations/merge_stmts [new file with mode: 0755]

index e4ed982410185456c714ced0995c3ce5f0729bad..7240deaa810026c6015ae7bfec67d7204960257e 100644 (file)
@@ -62,6 +62,11 @@ understanding of their meaning. You can get information about options by running
 *--check*::
        Check commands validity without actually applying the changes.
 
+*-o*::
+*--optimize*::
+       Optimize your ruleset. You can combine this option with '-c' to inspect
+        the proposed optimizations.
+
 .Ruleset list output formatting that modify the output of the list ruleset command:
 
 *-a*::
index d6d9b9cc7206e47a04babdc87eddb05186e826c7..d49eb579dc0489303f0a8540bac633cb19782dac 100644 (file)
@@ -123,6 +123,7 @@ struct nft_ctx {
        bool                    check;
        struct nft_cache        cache;
        uint32_t                flags;
+       uint32_t                optimize_flags;
        struct parser_state     *state;
        void                    *scanner;
        struct scope            *top_scope;
@@ -224,6 +225,8 @@ int nft_print(struct output_ctx *octx, const char *fmt, ...)
 int nft_gmp_print(struct output_ctx *octx, const char *fmt, ...)
        __attribute__((format(printf, 2, 0)));
 
+int nft_optimize(struct nft_ctx *nft, struct list_head *cmds);
+
 #define __NFT_OUTPUT_NOTSUPP   UINT_MAX
 
 #endif /* NFTABLES_NFTABLES_H */
index 957b5fbee243d505f72d2fb6365a0c6aa847b00b..85e08c9bc98b2a22bbce860cd6183db2199a2ab5 100644 (file)
@@ -41,6 +41,13 @@ void nft_ctx_free(struct nft_ctx *ctx);
 bool nft_ctx_get_dry_run(struct nft_ctx *ctx);
 void nft_ctx_set_dry_run(struct nft_ctx *ctx, bool dry);
 
+enum nft_optimize_flags {
+       NFT_OPTIMIZE_ENABLED            = 0x1,
+};
+
+uint32_t nft_ctx_get_optimize(struct nft_ctx *ctx);
+void nft_ctx_set_optimize(struct nft_ctx *ctx, uint32_t flags);
+
 enum {
        NFT_CTX_OUTPUT_REVERSEDNS       = (1 << 0),
        NFT_CTX_OUTPUT_SERVICE          = (1 << 1),
index 6ab0752337b28cd25d59426ffbe4ed29f29e3475..4cfba0af8bfa3d2aa154a62644ac637cb327a392 100644 (file)
@@ -68,6 +68,7 @@ libnftables_la_SOURCES =                      \
                mnl.c                           \
                iface.c                         \
                mergesort.c                     \
+               optimize.c                      \
                osf.c                           \
                nfnl_osf.c                      \
                tcpopt.c                        \
index e76f32eff7cac18e843fd5c3553aefc3291b9223..bd71ae9e704ff1859952d28723ce7774820ba2f0 100644 (file)
@@ -395,6 +395,18 @@ void nft_ctx_set_dry_run(struct nft_ctx *ctx, bool dry)
        ctx->check = dry;
 }
 
+EXPORT_SYMBOL(nft_ctx_get_optimize);
+uint32_t nft_ctx_get_optimize(struct nft_ctx *ctx)
+{
+       return ctx->optimize_flags;
+}
+
+EXPORT_SYMBOL(nft_ctx_set_optimize);
+void nft_ctx_set_optimize(struct nft_ctx *ctx, uint32_t flags)
+{
+       ctx->optimize_flags = flags;
+}
+
 EXPORT_SYMBOL(nft_ctx_output_get_flags);
 unsigned int nft_ctx_output_get_flags(struct nft_ctx *ctx)
 {
@@ -626,8 +638,7 @@ retry:
        return rc;
 }
 
-EXPORT_SYMBOL(nft_run_cmd_from_filename);
-int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename)
+static int __nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename)
 {
        struct cmd *cmd, *next;
        int rc, parser_rc;
@@ -638,13 +649,6 @@ int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename)
        if (rc < 0)
                goto err;
 
-       if (!strcmp(filename, "-"))
-               filename = "/dev/stdin";
-
-       if (!strcmp(filename, "/dev/stdin") &&
-           !nft_output_json(&nft->output))
-               nft->stdin_buf = stdin_to_buffer();
-
        rc = -EINVAL;
        if (nft_output_json(&nft->output))
                rc = nft_parse_json_filename(nft, filename, &msgs, &cmds);
@@ -653,6 +657,9 @@ int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename)
 
        parser_rc = rc;
 
+       if (nft->optimize_flags)
+               nft_optimize(nft, &cmds);
+
        rc = nft_evaluate(nft, &msgs, &cmds);
        if (rc < 0)
                goto err;
@@ -694,7 +701,51 @@ err:
        if (rc)
                nft_cache_release(&nft->cache);
 
+       return rc;
+}
+
+static int nft_run_optimized_file(struct nft_ctx *nft, const char *filename)
+{
+       uint32_t optimize_flags;
+       bool check;
+       int ret;
+
+       check = nft->check;
+       nft->check = true;
+       optimize_flags = nft->optimize_flags;
+       nft->optimize_flags = 0;
+
+       /* First check the original ruleset loads fine as is. */
+       ret = __nft_run_cmd_from_filename(nft, filename);
+       if (ret < 0)
+               return ret;
+
+       nft->check = check;
+       nft->optimize_flags = optimize_flags;
+
+       return __nft_run_cmd_from_filename(nft, filename);
+}
+
+EXPORT_SYMBOL(nft_run_cmd_from_filename);
+int nft_run_cmd_from_filename(struct nft_ctx *nft, const char *filename)
+{
+       int ret;
+
+       if (!strcmp(filename, "-"))
+               filename = "/dev/stdin";
+
+       if (!strcmp(filename, "/dev/stdin") &&
+           !nft_output_json(&nft->output))
+               nft->stdin_buf = stdin_to_buffer();
+
+       if (nft->optimize_flags) {
+               ret = nft_run_optimized_file(nft, filename);
+               xfree(nft->stdin_buf);
+               return ret;
+       }
+
+       ret = __nft_run_cmd_from_filename(nft, filename);
        xfree(nft->stdin_buf);
 
-       return rc;
+       return ret;
 }
index d3a795ce8567cbc1e471ad74a38965afcc832a72..a511dd789154ba78c8b883053488e21734bca874 100644 (file)
@@ -28,3 +28,8 @@ LIBNFTABLES_2 {
   nft_ctx_add_var;
   nft_ctx_clear_vars;
 } LIBNFTABLES_1;
+
+LIBNFTABLES_3 {
+  nft_set_optimize;
+  nft_get_optimize;
+} LIBNFTABLES_2;
index 5847fc4ad5146338991ff2fb99db5b3d2cb15ca2..9bd25db82343f0b7427d53691406d22c1c6b3d69 100644 (file)
@@ -36,7 +36,8 @@ enum opt_indices {
        IDX_INTERACTIVE,
         IDX_INCLUDEPATH,
        IDX_CHECK,
-#define IDX_RULESET_INPUT_END  IDX_CHECK
+       IDX_OPTIMIZE,
+#define IDX_RULESET_INPUT_END  IDX_OPTIMIZE
         /* Ruleset list formatting */
         IDX_HANDLE,
 #define IDX_RULESET_LIST_START IDX_HANDLE
@@ -80,6 +81,7 @@ enum opt_vals {
        OPT_NUMERIC_PROTO       = 'p',
        OPT_NUMERIC_TIME        = 'T',
        OPT_TERSE               = 't',
+       OPT_OPTIMIZE            = 'o',
        OPT_INVALID             = '?',
 };
 
@@ -136,6 +138,8 @@ static const struct nft_opt nft_options[] = {
                                     "Format output in JSON"),
        [IDX_DEBUG]         = NFT_OPT("debug",                  OPT_DEBUG,              "<level [,level...]>",
                                     "Specify debugging level (scanner, parser, eval, netlink, mnl, proto-ctx, segtree, all)"),
+       [IDX_OPTIMIZE]      = NFT_OPT("optimize",               OPT_OPTIMIZE,           NULL,
+                                    "Optimize ruleset"),
 };
 
 #define NR_NFT_OPTIONS (sizeof(nft_options) / sizeof(nft_options[0]))
@@ -484,6 +488,9 @@ int main(int argc, char * const *argv)
                case OPT_TERSE:
                        output_flags |= NFT_CTX_OUTPUT_TERSE;
                        break;
+               case OPT_OPTIMIZE:
+                       nft_ctx_set_optimize(nft, 0x1);
+                       break;
                case OPT_INVALID:
                        exit(EXIT_FAILURE);
                }
diff --git a/src/optimize.c b/src/optimize.c
new file mode 100644 (file)
index 0000000..bae36d7
--- /dev/null
@@ -0,0 +1,478 @@
+/*
+ * Copyright (c) 2021 Pablo Neira Ayuso <pablo@netfilter.org>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License version 2 (or any
+ * later) as published by the Free Software Foundation.
+ */
+
+/* Funded through the NGI0 PET Fund established by NLnet (https://nlnet.nl)
+ * with support from the European Commission's Next Generation Internet
+ * programme.
+ */
+
+#define _GNU_SOURCE
+#include <string.h>
+#include <errno.h>
+#include <inttypes.h>
+#include <nftables.h>
+#include <parser.h>
+#include <expression.h>
+#include <statement.h>
+#include <utils.h>
+#include <erec.h>
+
+#define MAX_STMTS      32
+
+struct optimize_ctx {
+       struct stmt *stmt[MAX_STMTS];
+       uint32_t num_stmts;
+
+       struct stmt ***stmt_matrix;
+       struct rule **rule;
+       uint32_t num_rules;
+};
+
+static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b)
+{
+       struct expr *expr_a, *expr_b;
+
+       if (stmt_a->ops->type != stmt_b->ops->type)
+               return false;
+
+       switch (stmt_a->ops->type) {
+       case STMT_EXPRESSION:
+               expr_a = stmt_a->expr;
+               expr_b = stmt_b->expr;
+
+               if (expr_a->left->etype != expr_b->left->etype)
+                       return false;
+
+               switch (expr_a->left->etype) {
+               case EXPR_PAYLOAD:
+                       if (expr_a->left->payload.desc != expr_b->left->payload.desc)
+                               return false;
+                       if (expr_a->left->payload.tmpl != expr_b->left->payload.tmpl)
+                               return false;
+                       break;
+               case EXPR_EXTHDR:
+                       if (expr_a->left->exthdr.desc != expr_b->left->exthdr.desc)
+                               return false;
+                       if (expr_a->left->exthdr.tmpl != expr_b->left->exthdr.tmpl)
+                               return false;
+                       break;
+               case EXPR_META:
+                       if (expr_a->left->meta.key != expr_b->left->meta.key)
+                               return false;
+                       if (expr_a->left->meta.base != expr_b->left->meta.base)
+                               return false;
+                       break;
+               case EXPR_CT:
+                       if (expr_a->left->ct.key != expr_b->left->ct.key)
+                               return false;
+                       if (expr_a->left->ct.base != expr_b->left->ct.base)
+                               return false;
+                       if (expr_a->left->ct.direction != expr_b->left->ct.direction)
+                               return false;
+                       if (expr_a->left->ct.nfproto != expr_b->left->ct.nfproto)
+                               return false;
+                       break;
+               case EXPR_RT:
+                       if (expr_a->left->rt.key != expr_b->left->rt.key)
+                               return false;
+                       break;
+               case EXPR_SOCKET:
+                       if (expr_a->left->socket.key != expr_b->left->socket.key)
+                               return false;
+                       if (expr_a->left->socket.level != expr_b->left->socket.level)
+                               return false;
+                       break;
+               default:
+                       return false;
+               }
+               break;
+       case STMT_COUNTER:
+       case STMT_NOTRACK:
+               break;
+       case STMT_VERDICT:
+               expr_a = stmt_a->expr;
+               expr_b = stmt_b->expr;
+               if (expr_a->verdict != expr_b->verdict)
+                       return false;
+               if (expr_a->chain && expr_b->chain) {
+                       if (expr_a->chain->etype != expr_b->chain->etype)
+                               return false;
+                       if (expr_a->chain->etype == EXPR_VALUE &&
+                           strcmp(expr_a->chain->identifier, expr_b->chain->identifier))
+                               return false;
+               } else if (expr_a->chain || expr_b->chain) {
+                       return false;
+               }
+               break;
+       case STMT_LIMIT:
+               if (stmt_a->limit.rate != stmt_b->limit.rate ||
+                   stmt_a->limit.unit != stmt_b->limit.unit ||
+                   stmt_a->limit.burst != stmt_b->limit.burst ||
+                   stmt_a->limit.type != stmt_b->limit.type ||
+                   stmt_a->limit.flags != stmt_b->limit.flags)
+                       return false;
+               break;
+       case STMT_LOG:
+               if (stmt_a->log.snaplen != stmt_b->log.snaplen ||
+                   stmt_a->log.group != stmt_b->log.group ||
+                   stmt_a->log.qthreshold != stmt_b->log.qthreshold ||
+                   stmt_a->log.level != stmt_b->log.level ||
+                   stmt_a->log.logflags != stmt_b->log.logflags ||
+                   stmt_a->log.flags != stmt_b->log.flags ||
+                   stmt_a->log.prefix->etype != EXPR_VALUE ||
+                   stmt_b->log.prefix->etype != EXPR_VALUE ||
+                   mpz_cmp(stmt_a->log.prefix->value, stmt_b->log.prefix->value))
+                       return false;
+               break;
+       case STMT_REJECT:
+               if (stmt_a->reject.expr || stmt_b->reject.expr)
+                       return false;
+
+               if (stmt_a->reject.family != stmt_b->reject.family ||
+                   stmt_a->reject.type != stmt_b->reject.type ||
+                   stmt_a->reject.icmp_code != stmt_b->reject.icmp_code)
+                       return false;
+               break;
+       default:
+               /* ... Merging anything else is yet unsupported. */
+               return false;
+       }
+
+       return true;
+}
+
+static bool stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b)
+{
+       if (!stmt_a && !stmt_b)
+               return true;
+       else if (!stmt_a)
+               return false;
+       else if (!stmt_b)
+               return false;
+
+       return __stmt_type_eq(stmt_a, stmt_b);
+}
+
+static bool stmt_type_find(struct optimize_ctx *ctx, const struct stmt *stmt)
+{
+       uint32_t i;
+
+       for (i = 0; i < ctx->num_stmts; i++) {
+               if (__stmt_type_eq(stmt, ctx->stmt[i]))
+                       return true;
+       }
+
+       return false;
+}
+
+static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
+{
+       struct stmt *stmt, *clone;
+
+       list_for_each_entry(stmt, &rule->stmts, list) {
+               if (stmt_type_find(ctx, stmt))
+                       continue;
+
+               /* No refcounter available in statement objects, clone it to
+                * to store in the array of selectors.
+                */
+               clone = stmt_alloc(&internal_location, stmt->ops);
+               switch (stmt->ops->type) {
+               case STMT_EXPRESSION:
+               case STMT_VERDICT:
+                       clone->expr = expr_get(stmt->expr);
+                       break;
+               case STMT_COUNTER:
+               case STMT_NOTRACK:
+                       break;
+               case STMT_LIMIT:
+                       memcpy(&clone->limit, &stmt->limit, sizeof(clone->limit));
+                       break;
+               case STMT_LOG:
+                       memcpy(&clone->log, &stmt->log, sizeof(clone->log));
+                       clone->log.prefix = expr_get(stmt->log.prefix);
+                       break;
+               default:
+                       break;
+               }
+
+               ctx->stmt[ctx->num_stmts++] = clone;
+               if (ctx->num_stmts >= MAX_STMTS)
+                       return -1;
+       }
+
+       return 0;
+}
+
+static int cmd_stmt_find_in_stmt_matrix(struct optimize_ctx *ctx, struct stmt *stmt)
+{
+       uint32_t i;
+
+       for (i = 0; i < ctx->num_stmts; i++) {
+               if (__stmt_type_eq(stmt, ctx->stmt[i]))
+                       return i;
+       }
+       /* should not ever happen. */
+       return 0;
+}
+
+static void rule_build_stmt_matrix_stmts(struct optimize_ctx *ctx,
+                                        struct rule *rule, uint32_t *i)
+{
+       struct stmt *stmt;
+       int k;
+
+       list_for_each_entry(stmt, &rule->stmts, list) {
+               k = cmd_stmt_find_in_stmt_matrix(ctx, stmt);
+               ctx->stmt_matrix[*i][k] = stmt;
+       }
+       ctx->rule[(*i)++] = rule;
+}
+
+struct merge {
+       /* interval of rules to be merged */
+       uint32_t        rule_from;
+       uint32_t        num_rules;
+       /* statements to be merged (index relative to statement matrix) */
+       uint32_t        stmt[MAX_STMTS];
+       uint32_t        num_stmts;
+};
+
+static void merge_stmts(const struct optimize_ctx *ctx,
+                       uint32_t from, uint32_t to, const struct merge *merge)
+{
+       struct stmt *stmt_a = ctx->stmt_matrix[from][merge->stmt[0]];
+       struct expr *expr_a, *expr_b, *set, *elem;
+       struct stmt *stmt_b;
+       uint32_t i;
+
+       assert (stmt_a->ops->type == STMT_EXPRESSION);
+
+       set = set_expr_alloc(&internal_location, NULL);
+       set->set_flags |= NFT_SET_ANONYMOUS;
+
+       expr_a = stmt_a->expr->right;
+       elem = set_elem_expr_alloc(&internal_location, expr_get(expr_a));
+       compound_expr_add(set, elem);
+
+       for (i = from + 1; i <= to; i++) {
+               stmt_b = ctx->stmt_matrix[i][merge->stmt[0]];
+               expr_b = stmt_b->expr->right;
+               elem = set_elem_expr_alloc(&internal_location, expr_get(expr_b));
+               compound_expr_add(set, elem);
+       }
+
+       expr_free(stmt_a->expr->right);
+       stmt_a->expr->right = set;
+}
+
+static void rule_optimize_print(struct output_ctx *octx,
+                               const struct rule *rule)
+{
+       const struct location *loc = &rule->location;
+       const struct input_descriptor *indesc = loc->indesc;
+       const char *line;
+       char buf[1024];
+
+       switch (indesc->type) {
+       case INDESC_BUFFER:
+       case INDESC_CLI:
+               line = indesc->data;
+               *strchrnul(line, '\n') = '\0';
+               break;
+       case INDESC_STDIN:
+               line = indesc->data;
+               line += loc->line_offset;
+               *strchrnul(line, '\n') = '\0';
+               break;
+       case INDESC_FILE:
+               line = line_location(indesc, loc, buf, sizeof(buf));
+               break;
+       case INDESC_INTERNAL:
+       case INDESC_NETLINK:
+               break;
+       default:
+               BUG("invalid input descriptor type %u\n", indesc->type);
+       }
+
+       print_location(octx->error_fp, indesc, loc);
+       fprintf(octx->error_fp, "%s\n", line);
+}
+
+static void merge_rules(const struct optimize_ctx *ctx,
+                       uint32_t from, uint32_t to,
+                       const struct merge *merge,
+                       struct output_ctx *octx)
+{
+       uint32_t i;
+
+       if (merge->num_stmts > 1) {
+               return;
+       } else {
+               merge_stmts(ctx, from, to, merge);
+       }
+
+       fprintf(octx->error_fp, "Merging:\n");
+       rule_optimize_print(octx, ctx->rule[from]);
+
+       for (i = from + 1; i <= to; i++) {
+               rule_optimize_print(octx, ctx->rule[i]);
+               list_del(&ctx->rule[i]->list);
+               rule_free(ctx->rule[i]);
+       }
+
+       fprintf(octx->error_fp, "into:\n\t");
+       rule_print(ctx->rule[from], octx);
+       fprintf(octx->error_fp, "\n");
+}
+
+static bool rules_eq(const struct optimize_ctx *ctx, int i, int j)
+{
+       uint32_t k;
+
+       for (k = 0; k < ctx->num_stmts; k++) {
+               if (!stmt_type_eq(ctx->stmt_matrix[i][k], ctx->stmt_matrix[j][k]))
+                       return false;
+       }
+
+       return true;
+}
+
+static int chain_optimize(struct nft_ctx *nft, struct list_head *rules)
+{
+       struct optimize_ctx *ctx;
+       uint32_t num_merges = 0;
+       struct merge *merge;
+       uint32_t i, j, m, k;
+       struct rule *rule;
+       int ret;
+
+       ctx = xzalloc(sizeof(*ctx));
+
+       /* Step 1: collect statements in rules */
+       list_for_each_entry(rule, rules, list) {
+               ret = rule_collect_stmts(ctx, rule);
+               if (ret < 0)
+                       goto err;
+
+               ctx->num_rules++;
+       }
+
+       ctx->rule = xzalloc(sizeof(ctx->rule) * ctx->num_rules);
+       ctx->stmt_matrix = xzalloc(sizeof(struct stmt *) * ctx->num_rules);
+       for (i = 0; i < ctx->num_rules; i++)
+               ctx->stmt_matrix[i] = xzalloc(sizeof(struct stmt *) * MAX_STMTS);
+
+       merge = xzalloc(sizeof(*merge) * ctx->num_rules);
+
+       /* Step 2: Build matrix of statements */
+       i = 0;
+       list_for_each_entry(rule, rules, list)
+               rule_build_stmt_matrix_stmts(ctx, rule, &i);
+
+       /* Step 3: Look for common selectors for possible rule mergers */
+       for (i = 0; i < ctx->num_rules; i++) {
+               for (j = i + 1; j < ctx->num_rules; j++) {
+                       if (!rules_eq(ctx, i, j)) {
+                               if (merge[num_merges].num_rules > 0)
+                                       num_merges++;
+
+                               i = j - 1;
+                               break;
+                       }
+                       if (merge[num_merges].num_rules > 0) {
+                               merge[num_merges].num_rules++;
+                       } else {
+                               merge[num_merges].rule_from = i;
+                               merge[num_merges].num_rules = 2;
+                       }
+               }
+               if (j == ctx->num_rules && merge[num_merges].num_rules > 0) {
+                       num_merges++;
+                       break;
+               }
+       }
+
+       /* Step 4: Infer how to merge the candidate rules */
+       for (k = 0; k < num_merges; k++) {
+               i = merge[k].rule_from;
+
+               for (m = 0; m < ctx->num_stmts; m++) {
+                       if (!ctx->stmt_matrix[i][m])
+                               continue;
+                       switch (ctx->stmt_matrix[i][m]->ops->type) {
+                       case STMT_EXPRESSION:
+                               merge[k].stmt[merge[k].num_stmts++] = m;
+                               break;
+                       default:
+                               break;
+                       }
+               }
+
+               j = merge[k].num_rules - 1;
+               merge_rules(ctx, i, i + j, &merge[k], &nft->output);
+       }
+       ret = 0;
+       for (i = 0; i < ctx->num_rules; i++)
+               xfree(ctx->stmt_matrix[i]);
+
+       xfree(ctx->stmt_matrix);
+       xfree(merge);
+err:
+       for (i = 0; i < ctx->num_stmts; i++)
+               stmt_free(ctx->stmt[i]);
+
+       xfree(ctx->rule);
+       xfree(ctx);
+
+       return ret;
+}
+
+static int cmd_optimize(struct nft_ctx *nft, struct cmd *cmd)
+{
+       struct table *table;
+       struct chain *chain;
+       int ret = 0;
+
+       switch (cmd->obj) {
+       case CMD_OBJ_TABLE:
+               table = cmd->table;
+               if (!table)
+                       break;
+
+               list_for_each_entry(chain, &table->chains, list) {
+                       if (chain->flags & CHAIN_F_HW_OFFLOAD)
+                               continue;
+
+                       chain_optimize(nft, &chain->rules);
+               }
+               break;
+       default:
+               break;
+       }
+
+       return ret;
+}
+
+int nft_optimize(struct nft_ctx *nft, struct list_head *cmds)
+{
+       struct cmd *cmd;
+       int ret;
+
+       list_for_each_entry(cmd, cmds, list) {
+               switch (cmd->op) {
+               case CMD_ADD:
+                       ret = cmd_optimize(nft, cmd);
+                       break;
+               default:
+                       break;
+               }
+       }
+
+       return ret;
+}
diff --git a/tests/shell/testcases/optimizations/dumps/merge_stmts.nft b/tests/shell/testcases/optimizations/dumps/merge_stmts.nft
new file mode 100644 (file)
index 0000000..b56ea3e
--- /dev/null
@@ -0,0 +1,5 @@
+table ip x {
+       chain y {
+               ip daddr { 192.168.0.1, 192.168.0.2, 192.168.0.3 } counter packets 0 bytes 0 accept
+       }
+}
diff --git a/tests/shell/testcases/optimizations/merge_stmts b/tests/shell/testcases/optimizations/merge_stmts
new file mode 100755 (executable)
index 0000000..0c35636
--- /dev/null
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+set -e
+
+RULESET="table ip x {
+       chain y {
+               ip daddr 192.168.0.1 counter accept
+               ip daddr 192.168.0.2 counter accept
+               ip daddr 192.168.0.3 counter accept
+       }
+}"
+
+$NFT -o -f - <<< $RULESET