]> git.ipfire.org Git - thirdparty/nftables.git/commitdiff
optimize: merge nat rules with same selectors into map
authorPablo Neira Ayuso <pablo@netfilter.org>
Tue, 3 May 2022 15:51:36 +0000 (17:51 +0200)
committerPablo Neira Ayuso <pablo@netfilter.org>
Tue, 3 May 2022 21:45:21 +0000 (23:45 +0200)
Verdict and nat are mutually exclusive, no need to support for this
combination.

 # cat ruleset.nft
 table ip x {
        chain y {
type nat hook postrouting priority srcnat; policy drop;
                ip saddr 1.1.1.1 tcp dport 8000 snat to 4.4.4.4:80
                ip saddr 2.2.2.2 tcp dport 8001 snat to 5.5.5.5:90
        }
 }

 # nft -o -c -f ruleset.nft
 Merging:
 ruleset.nft:4:3-52:                ip saddr 1.1.1.1 tcp dport 8000 snat to 4.4.4.4:80
 ruleset.nft:5:3-52:                ip saddr 2.2.2.2 tcp dport 8001 snat to 5.5.5.5:90
 into:
        snat to ip saddr . tcp dport map { 1.1.1.1 . 8000 : 4.4.4.4 . 80, 2.2.2.2 . 8001 : 5.5.5.5 . 90 }

Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
src/optimize.c
tests/shell/testcases/optimizations/dumps/merge_nat.nft [new file with mode: 0644]
tests/shell/testcases/optimizations/merge_nat [new file with mode: 0755]

index 1308e14285dc837b5c7536d712a0ed6178175984..cb3fff21b1256706d76b785f2ab8ce009198159e 100644 (file)
@@ -173,6 +173,26 @@ static bool __stmt_type_eq(const struct stmt *stmt_a, const struct stmt *stmt_b)
                    stmt_a->reject.icmp_code != stmt_b->reject.icmp_code)
                        return false;
                break;
+       case STMT_NAT:
+               if (stmt_a->nat.type != stmt_b->nat.type ||
+                   stmt_a->nat.flags != stmt_b->nat.flags ||
+                   stmt_a->nat.family != stmt_b->nat.family ||
+                   stmt_a->nat.type_flags != stmt_b->nat.type_flags ||
+                   (stmt_a->nat.addr &&
+                    stmt_a->nat.addr->etype != EXPR_SYMBOL &&
+                    stmt_a->nat.addr->etype != EXPR_RANGE) ||
+                   (stmt_b->nat.addr &&
+                    stmt_b->nat.addr->etype != EXPR_SYMBOL &&
+                    stmt_b->nat.addr->etype != EXPR_RANGE) ||
+                   (stmt_a->nat.proto &&
+                    stmt_a->nat.proto->etype != EXPR_SYMBOL &&
+                    stmt_a->nat.proto->etype != EXPR_RANGE) ||
+                   (stmt_b->nat.proto &&
+                    stmt_b->nat.proto->etype != EXPR_SYMBOL &&
+                    stmt_b->nat.proto->etype != EXPR_RANGE))
+                       return false;
+
+               return true;
        default:
                /* ... Merging anything else is yet unsupported. */
                return false;
@@ -273,6 +293,16 @@ static int rule_collect_stmts(struct optimize_ctx *ctx, struct rule *rule)
                        if (stmt->log.prefix)
                                clone->log.prefix = expr_get(stmt->log.prefix);
                        break;
+               case STMT_NAT:
+                       clone->nat.type = stmt->nat.type;
+                       clone->nat.family = stmt->nat.family;
+                       if (stmt->nat.addr)
+                               clone->nat.addr = expr_clone(stmt->nat.addr);
+                       if (stmt->nat.proto)
+                               clone->nat.proto = expr_clone(stmt->nat.proto);
+                       clone->nat.flags = stmt->nat.flags;
+                       clone->nat.type_flags = stmt->nat.type_flags;
+                       break;
                default:
                        continue;
                }
@@ -632,6 +662,129 @@ static bool stmt_verdict_cmp(const struct optimize_ctx *ctx,
        return true;
 }
 
+static int stmt_nat_find(const struct optimize_ctx *ctx)
+{
+       uint32_t i;
+
+       for (i = 0; i < ctx->num_stmts; i++) {
+               if (ctx->stmt[i]->ops->type != STMT_NAT)
+                       continue;
+
+               return i;
+       }
+
+       return -1;
+}
+
+static struct expr *stmt_nat_expr(struct stmt *nat_stmt)
+{
+       struct expr *nat_expr;
+
+       if (nat_stmt->nat.proto) {
+               nat_expr = concat_expr_alloc(&internal_location);
+               compound_expr_add(nat_expr, expr_get(nat_stmt->nat.addr));
+               compound_expr_add(nat_expr, expr_get(nat_stmt->nat.proto));
+               expr_free(nat_stmt->nat.proto);
+               nat_stmt->nat.proto = NULL;
+       } else {
+               nat_expr = expr_get(nat_stmt->nat.addr);
+       }
+
+       return nat_expr;
+}
+
+static void merge_nat(const struct optimize_ctx *ctx,
+                     uint32_t from, uint32_t to,
+                     const struct merge *merge)
+{
+       struct expr *expr, *set, *elem, *nat_expr, *mapping, *left;
+       struct stmt *stmt, *nat_stmt;
+       uint32_t i;
+       int k;
+
+       k = stmt_nat_find(ctx);
+       assert(k >= 0);
+
+       set = set_expr_alloc(&internal_location, NULL);
+       set->set_flags |= NFT_SET_ANONYMOUS;
+
+       for (i = from; i <= to; i++) {
+               stmt = ctx->stmt_matrix[i][merge->stmt[0]];
+               expr = stmt->expr->right;
+
+               nat_stmt = ctx->stmt_matrix[i][k];
+               nat_expr = stmt_nat_expr(nat_stmt);
+
+               elem = set_elem_expr_alloc(&internal_location, expr_get(expr));
+               mapping = mapping_expr_alloc(&internal_location, elem, nat_expr);
+               compound_expr_add(set, mapping);
+       }
+
+       stmt = ctx->stmt_matrix[from][merge->stmt[0]];
+       left = expr_get(stmt->expr->left);
+       expr = map_expr_alloc(&internal_location, left, set);
+
+       nat_stmt = ctx->stmt_matrix[from][k];
+       expr_free(nat_stmt->nat.addr);
+       nat_stmt->nat.addr = expr;
+
+       remove_counter(ctx, from);
+       list_del(&stmt->list);
+       stmt_free(stmt);
+}
+
+static void merge_concat_nat(const struct optimize_ctx *ctx,
+                            uint32_t from, uint32_t to,
+                            const struct merge *merge)
+{
+       struct expr *expr, *set, *elem, *nat_expr, *mapping, *left, *concat;
+       struct stmt *stmt, *nat_stmt;
+       uint32_t i, j;
+       int k;
+
+       k = stmt_nat_find(ctx);
+       assert(k >= 0);
+
+       set = set_expr_alloc(&internal_location, NULL);
+       set->set_flags |= NFT_SET_ANONYMOUS;
+
+       for (i = from; i <= to; i++) {
+
+               concat = concat_expr_alloc(&internal_location);
+               for (j = 0; j < merge->num_stmts; j++) {
+                       stmt = ctx->stmt_matrix[i][merge->stmt[j]];
+                       expr = stmt->expr->right;
+                       compound_expr_add(concat, expr_get(expr));
+               }
+
+               nat_stmt = ctx->stmt_matrix[i][k];
+               nat_expr = stmt_nat_expr(nat_stmt);
+
+               elem = set_elem_expr_alloc(&internal_location, concat);
+               mapping = mapping_expr_alloc(&internal_location, elem, nat_expr);
+               compound_expr_add(set, mapping);
+       }
+
+       concat = concat_expr_alloc(&internal_location);
+       for (j = 0; j < merge->num_stmts; j++) {
+               stmt = ctx->stmt_matrix[from][merge->stmt[j]];
+               left = stmt->expr->left;
+               compound_expr_add(concat, expr_get(left));
+       }
+       expr = map_expr_alloc(&internal_location, concat, set);
+
+       nat_stmt = ctx->stmt_matrix[from][k];
+       expr_free(nat_stmt->nat.addr);
+       nat_stmt->nat.addr = expr;
+
+       remove_counter(ctx, from);
+       for (j = 0; j < merge->num_stmts; j++) {
+               stmt = ctx->stmt_matrix[from][merge->stmt[j]];
+               list_del(&stmt->list);
+               stmt_free(stmt);
+       }
+}
+
 static void rule_optimize_print(struct output_ctx *octx,
                                const struct rule *rule)
 {
@@ -665,26 +818,57 @@ static void rule_optimize_print(struct output_ctx *octx,
        fprintf(octx->error_fp, "%s\n", line);
 }
 
+static enum stmt_types merge_stmt_type(const struct optimize_ctx *ctx)
+{
+       uint32_t i;
+
+       for (i = 0; i < ctx->num_stmts; i++) {
+               switch (ctx->stmt[i]->ops->type) {
+               case STMT_VERDICT:
+               case STMT_NAT:
+                       return ctx->stmt[i]->ops->type;
+               default:
+                       continue;
+               }
+       }
+
+       return STMT_INVALID;
+}
+
 static void merge_rules(const struct optimize_ctx *ctx,
                        uint32_t from, uint32_t to,
                        const struct merge *merge,
                        struct output_ctx *octx)
 {
+       enum stmt_types stmt_type;
        bool same_verdict;
        uint32_t i;
 
-       same_verdict = stmt_verdict_cmp(ctx, from, to);
+       stmt_type = merge_stmt_type(ctx);
 
-       if (merge->num_stmts > 1) {
-               if (same_verdict)
-                       merge_concat_stmts(ctx, from, to, merge);
-               else
-                       merge_concat_stmts_vmap(ctx, from, to, merge);
-       } else {
-               if (same_verdict)
-                       merge_stmts(ctx, from, to, merge);
+       switch (stmt_type) {
+       case STMT_VERDICT:
+               same_verdict = stmt_verdict_cmp(ctx, from, to);
+               if (merge->num_stmts > 1) {
+                       if (same_verdict)
+                               merge_concat_stmts(ctx, from, to, merge);
+                       else
+                               merge_concat_stmts_vmap(ctx, from, to, merge);
+               } else {
+                       if (same_verdict)
+                               merge_stmts(ctx, from, to, merge);
+                       else
+                               merge_stmts_vmap(ctx, from, to, merge);
+               }
+               break;
+       case STMT_NAT:
+               if (merge->num_stmts > 1)
+                       merge_concat_nat(ctx, from, to, merge);
                else
-                       merge_stmts_vmap(ctx, from, to, merge);
+                       merge_nat(ctx, from, to, merge);
+               break;
+       default:
+               assert(0);
        }
 
        fprintf(octx->error_fp, "Merging:\n");
diff --git a/tests/shell/testcases/optimizations/dumps/merge_nat.nft b/tests/shell/testcases/optimizations/dumps/merge_nat.nft
new file mode 100644 (file)
index 0000000..7a6ecb7
--- /dev/null
@@ -0,0 +1,20 @@
+table ip test1 {
+       chain y {
+               dnat to ip saddr map { 4.4.4.4 : 1.1.1.1, 5.5.5.5 : 2.2.2.2 }
+       }
+}
+table ip test2 {
+       chain y {
+               dnat ip to tcp dport map { 80 : 1.1.1.1 . 8001, 81 : 2.2.2.2 . 9001 }
+       }
+}
+table ip test3 {
+       chain y {
+               snat to ip saddr . tcp sport map { 1.1.1.1 . 1024-65535 : 3.3.3.3, 2.2.2.2 . 1024-65535 : 4.4.4.4 }
+       }
+}
+table ip test4 {
+       chain y {
+               dnat ip to ip daddr . tcp dport map { 1.1.1.1 . 80 : 4.4.4.4 . 8000, 2.2.2.2 . 81 : 3.3.3.3 . 9000 }
+       }
+}
diff --git a/tests/shell/testcases/optimizations/merge_nat b/tests/shell/testcases/optimizations/merge_nat
new file mode 100755 (executable)
index 0000000..290cfcf
--- /dev/null
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+set -e
+
+RULESET="table ip test1 {
+        chain y {
+                ip saddr 4.4.4.4 dnat to 1.1.1.1
+                ip saddr 5.5.5.5 dnat to 2.2.2.2
+        }
+}"
+
+$NFT -o -f - <<< $RULESET
+
+RULESET="table ip test2 {
+        chain y {
+                tcp dport 80 dnat to 1.1.1.1:8001
+                tcp dport 81 dnat to 2.2.2.2:9001
+        }
+}"
+
+$NFT -o -f - <<< $RULESET
+
+RULESET="table ip test3 {
+        chain y {
+                ip saddr 1.1.1.1 tcp sport 1024-65535 snat to 3.3.3.3
+                ip saddr 2.2.2.2 tcp sport 1024-65535 snat to 4.4.4.4
+        }
+}"
+
+$NFT -o -f - <<< $RULESET
+
+RULESET="table ip test4 {
+        chain y {
+                ip daddr 1.1.1.1 tcp dport 80 dnat to 4.4.4.4:8000
+                ip daddr 2.2.2.2 tcp dport 81 dnat to 3.3.3.3:9000
+        }
+}"
+
+$NFT -o -f - <<< $RULESET