]> git.ipfire.org Git - thirdparty/nftables.git/commitdiff
src: add eval_proto_ctx()
authorPablo Neira Ayuso <pablo@netfilter.org>
Mon, 2 Jan 2023 14:36:20 +0000 (15:36 +0100)
committerPablo Neira Ayuso <pablo@netfilter.org>
Mon, 2 Jan 2023 14:36:20 +0000 (15:36 +0100)
Add eval_proto_ctx() to access protocol context (struct proto_ctx).
Rename struct proto_ctx field to _pctx to highlight that this field
is internal and the helper function should be used.

This patch comes in preparation for supporting outer and inner
protocol context.

Signed-off-by: Pablo Neira Ayuso <pablo@netfilter.org>
include/proto.h
include/rule.h
src/evaluate.c
src/payload.c

index 35e760c7e16e9b3752936946d703d13d81fbcda7..6a9289b17f0575877541e4e2bec167cc5a5bd734 100644 (file)
@@ -413,4 +413,7 @@ extern const struct datatype icmp6_type_type;
 extern const struct datatype dscp_type;
 extern const struct datatype ecn_type;
 
+struct eval_ctx;
+struct proto_ctx *eval_proto_ctx(struct eval_ctx *ctx);
+
 #endif /* NFTABLES_PROTO_H */
index 00a1bac5a7737b8b432e6b1d39432c60e1680fec..c1b46414880b83db46d590b68c6d39fb37aff0eb 100644 (file)
@@ -769,7 +769,7 @@ struct eval_ctx {
        struct set              *set;
        struct stmt             *stmt;
        struct expr_ctx         ectx;
-       struct proto_ctx        pctx;
+       struct proto_ctx        _pctx;
 };
 
 extern int cmd_evaluate(struct eval_ctx *ctx, struct cmd *cmd);
index 70adb847d47577f0b5559d3ef0823e4d02b3ef27..40e26467fca21705d8089ce3b174f81784cccd6f 100644 (file)
 #include <utils.h>
 #include <xt.h>
 
+struct proto_ctx *eval_proto_ctx(struct eval_ctx *ctx)
+{
+       return &ctx->_pctx;
+}
+
 static int expr_evaluate(struct eval_ctx *ctx, struct expr **expr);
 
 static const char * const byteorder_names[] = {
@@ -448,11 +453,13 @@ conflict_resolution_gen_dependency(struct eval_ctx *ctx, int protocol,
        const struct proto_hdr_template *tmpl;
        const struct proto_desc *desc = NULL;
        struct expr *dep, *left, *right;
+       struct proto_ctx *pctx;
        struct stmt *stmt;
 
        assert(expr->payload.base == PROTO_BASE_LL_HDR);
 
-       desc = ctx->pctx.protocol[base].desc;
+       pctx = eval_proto_ctx(ctx);
+       desc = pctx->protocol[base].desc;
        tmpl = &desc->templates[desc->protocol_key];
        left = payload_expr_alloc(&expr->location, desc, desc->protocol_key);
 
@@ -598,6 +605,7 @@ static int expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **exprp)
        const struct proto_desc *base, *dependency = NULL;
        enum proto_bases pb = PROTO_BASE_NETWORK_HDR;
        struct expr *expr = *exprp;
+       struct proto_ctx *pctx;
        struct stmt *nstmt;
 
        switch (expr->exthdr.op) {
@@ -615,7 +623,8 @@ static int expr_evaluate_exthdr(struct eval_ctx *ctx, struct expr **exprp)
 
        assert(dependency);
 
-       base = ctx->pctx.protocol[pb].desc;
+       pctx = eval_proto_ctx(ctx);
+       base = pctx->protocol[pb].desc;
        if (base == dependency)
                return __expr_evaluate_exthdr(ctx, exprp);
 
@@ -678,8 +687,11 @@ static int resolve_protocol_conflict(struct eval_ctx *ctx,
 {
        enum proto_bases base = payload->payload.base;
        struct stmt *nstmt = NULL;
+       struct proto_ctx *pctx;
        int link, err;
 
+       pctx = eval_proto_ctx(ctx);
+
        if (payload->payload.base == PROTO_BASE_LL_HDR) {
                if (proto_is_dummy(desc)) {
                        err = meta_iiftype_gen_dependency(ctx, payload, &nstmt);
@@ -692,8 +704,8 @@ static int resolve_protocol_conflict(struct eval_ctx *ctx,
                        unsigned int i;
 
                        /* payload desc stored in the L2 header stack? No conflict. */
-                       for (i = 0; i < ctx->pctx.stacked_ll_count; i++) {
-                               if (ctx->pctx.stacked_ll[i] == payload->payload.desc)
+                       for (i = 0; i < pctx->stacked_ll_count; i++) {
+                               if (pctx->stacked_ll[i] == payload->payload.desc)
                                        return 0;
                        }
                }
@@ -701,7 +713,7 @@ static int resolve_protocol_conflict(struct eval_ctx *ctx,
 
        assert(base <= PROTO_BASE_MAX);
        /* This payload and the existing context don't match, conflict. */
-       if (ctx->pctx.protocol[base + 1].desc != NULL)
+       if (pctx->protocol[base + 1].desc != NULL)
                return 1;
 
        link = proto_find_num(desc, payload->payload.desc);
@@ -712,8 +724,8 @@ static int resolve_protocol_conflict(struct eval_ctx *ctx,
        if (base == PROTO_BASE_LL_HDR) {
                unsigned int i;
 
-               for (i = 0; i < ctx->pctx.stacked_ll_count; i++)
-                       payload->payload.offset += ctx->pctx.stacked_ll[i]->length;
+               for (i = 0; i < pctx->stacked_ll_count; i++)
+                       payload->payload.offset += pctx->stacked_ll[i]->length;
        }
 
        rule_stmt_insert_at(ctx->rule, nstmt, ctx->stmt);
@@ -731,19 +743,22 @@ static int __expr_evaluate_payload(struct eval_ctx *ctx, struct expr *expr)
        struct expr *payload = expr;
        enum proto_bases base = payload->payload.base;
        const struct proto_desc *desc;
+       struct proto_ctx *pctx;
        struct stmt *nstmt;
        int err;
 
        if (expr->etype == EXPR_PAYLOAD && expr->payload.is_raw)
                return 0;
 
-       desc = ctx->pctx.protocol[base].desc;
+       pctx = eval_proto_ctx(ctx);
+       desc = pctx->protocol[base].desc;
        if (desc == NULL) {
                if (payload_gen_dependency(ctx, payload, &nstmt) < 0)
                        return -1;
 
                rule_stmt_insert_at(ctx->rule, nstmt, ctx->stmt);
-               desc = ctx->pctx.protocol[base].desc;
+
+               desc = pctx->protocol[base].desc;
 
                if (desc == expr->payload.desc)
                        goto check_icmp;
@@ -759,15 +774,16 @@ static int __expr_evaluate_payload(struct eval_ctx *ctx, struct expr *expr)
                                                  desc->name,
                                                  payload->payload.desc->name);
 
-                       payload->payload.offset += ctx->pctx.stacked_ll[0]->length;
+                       payload->payload.offset += pctx->stacked_ll[0]->length;
                        rule_stmt_insert_at(ctx->rule, nstmt, ctx->stmt);
                        return 1;
                }
+               goto check_icmp;
        }
 
        if (payload->payload.base == desc->base &&
-           proto_ctx_is_ambiguous(&ctx->pctx, base)) {
-               desc = proto_ctx_find_conflict(&ctx->pctx, base, payload->payload.desc);
+           proto_ctx_is_ambiguous(pctx, base)) {
+               desc = proto_ctx_find_conflict(pctx, base, payload->payload.desc);
                assert(desc);
 
                return expr_error(ctx->msgs, payload,
@@ -785,8 +801,8 @@ static int __expr_evaluate_payload(struct eval_ctx *ctx, struct expr *expr)
                if (desc->base == PROTO_BASE_LL_HDR) {
                        unsigned int i;
 
-                       for (i = 0; i < ctx->pctx.stacked_ll_count; i++)
-                               payload->payload.offset += ctx->pctx.stacked_ll[i]->length;
+                       for (i = 0; i < pctx->stacked_ll_count; i++)
+                               payload->payload.offset += pctx->stacked_ll[i]->length;
                }
 check_icmp:
                if (desc != &proto_icmp && desc != &proto_icmp6)
@@ -813,13 +829,13 @@ check_icmp:
                if (err <= 0)
                        return err;
 
-               desc = ctx->pctx.protocol[base].desc;
+               desc = pctx->protocol[base].desc;
                if (desc == payload->payload.desc)
                        return 0;
        }
        return expr_error(ctx->msgs, payload,
                          "conflicting protocols specified: %s vs. %s",
-                         ctx->pctx.protocol[base].desc->name,
+                         pctx->protocol[base].desc->name,
                          payload->payload.desc->name);
 }
 
@@ -857,20 +873,22 @@ static int expr_evaluate_rt(struct eval_ctx *ctx, struct expr **expr)
 {
        static const char emsg[] = "cannot determine ip protocol version, use \"ip nexthop\" or \"ip6 nexthop\" instead";
        struct expr *rt = *expr;
+       struct proto_ctx *pctx;
 
-       rt_expr_update_type(&ctx->pctx, rt);
+       pctx = eval_proto_ctx(ctx);
+       rt_expr_update_type(pctx, rt);
 
        switch (rt->rt.key) {
        case NFT_RT_NEXTHOP4:
                if (rt->dtype != &ipaddr_type)
                        return expr_error(ctx->msgs, rt, "%s", emsg);
-               if (ctx->pctx.family == NFPROTO_IPV6)
+               if (pctx->family == NFPROTO_IPV6)
                        return expr_error(ctx->msgs, rt, "%s nexthop will not match", "ip");
                break;
        case NFT_RT_NEXTHOP6:
                if (rt->dtype != &ip6addr_type)
                        return expr_error(ctx->msgs, rt, "%s", emsg);
-               if (ctx->pctx.family == NFPROTO_IPV4)
+               if (pctx->family == NFPROTO_IPV4)
                        return expr_error(ctx->msgs, rt, "%s nexthop will not match", "ip6");
                break;
        default:
@@ -885,8 +903,10 @@ static int ct_gen_nh_dependency(struct eval_ctx *ctx, struct expr *ct)
        const struct proto_desc *base, *base_now;
        struct expr *left, *right, *dep;
        struct stmt *nstmt = NULL;
+       struct proto_ctx *pctx;
 
-       base_now = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc;
+       pctx = eval_proto_ctx(ctx);
+       base_now = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc;
 
        switch (ct->ct.nfproto) {
        case NFPROTO_IPV4:
@@ -896,7 +916,7 @@ static int ct_gen_nh_dependency(struct eval_ctx *ctx, struct expr *ct)
                base = &proto_ip6;
                break;
        default:
-               base = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc;
+               base = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc;
                if (base == &proto_ip)
                        ct->ct.nfproto = NFPROTO_IPV4;
                else if (base == &proto_ip)
@@ -918,8 +938,8 @@ static int ct_gen_nh_dependency(struct eval_ctx *ctx, struct expr *ct)
                return expr_error(ctx->msgs, ct,
                                  "conflicting dependencies: %s vs. %s\n",
                                  base->name,
-                                 ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc->name);
-       switch (ctx->pctx.family) {
+                                 pctx->protocol[PROTO_BASE_NETWORK_HDR].desc->name);
+       switch (pctx->family) {
        case NFPROTO_IPV4:
        case NFPROTO_IPV6:
                return 0;
@@ -932,7 +952,7 @@ static int ct_gen_nh_dependency(struct eval_ctx *ctx, struct expr *ct)
                                    constant_data_ptr(ct->ct.nfproto, left->len));
        dep = relational_expr_alloc(&ct->location, OP_EQ, left, right);
 
-       relational_expr_pctx_update(&ctx->pctx, dep);
+       relational_expr_pctx_update(pctx, dep);
 
        nstmt = expr_stmt_alloc(&dep->location, dep);
        rule_stmt_insert_at(ctx->rule, nstmt, ctx->stmt);
@@ -948,8 +968,10 @@ static int expr_evaluate_ct(struct eval_ctx *ctx, struct expr **expr)
 {
        const struct proto_desc *base, *error;
        struct expr *ct = *expr;
+       struct proto_ctx *pctx;
 
-       base = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc;
+       pctx = eval_proto_ctx(ctx);
+       base = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc;
 
        switch (ct->ct.key) {
        case NFT_CT_SRC:
@@ -974,13 +996,13 @@ static int expr_evaluate_ct(struct eval_ctx *ctx, struct expr **expr)
                break;
        }
 
-       ct_expr_update_type(&ctx->pctx, ct);
+       ct_expr_update_type(pctx, ct);
 
        return expr_evaluate_primary(ctx, expr);
 
 err_conflict:
        return stmt_binary_error(ctx, ct,
-                                &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR],
+                                &pctx->protocol[PROTO_BASE_NETWORK_HDR],
                                 "conflicting protocols specified: %s vs. %s",
                                 base->name, error->name);
 }
@@ -2146,6 +2168,7 @@ static bool range_needs_swap(const struct expr *range)
 static int expr_evaluate_relational(struct eval_ctx *ctx, struct expr **expr)
 {
        struct expr *rel = *expr, *left, *right;
+       struct proto_ctx *pctx;
        struct expr *range;
        int ret;
 
@@ -2153,6 +2176,8 @@ static int expr_evaluate_relational(struct eval_ctx *ctx, struct expr **expr)
                return -1;
        left = rel->left;
 
+       pctx = eval_proto_ctx(ctx);
+
        if (rel->right->etype == EXPR_RANGE && lhs_is_meta_hour(rel->left)) {
                ret = __expr_evaluate_range(ctx, &rel->right);
                if (ret)
@@ -2220,7 +2245,7 @@ static int expr_evaluate_relational(struct eval_ctx *ctx, struct expr **expr)
                 * Update protocol context for payload and meta iiftype
                 * equality expressions.
                 */
-               relational_expr_pctx_update(&ctx->pctx, rel);
+               relational_expr_pctx_update(pctx, rel);
 
                /* fall through */
        case OP_NEQ:
@@ -2332,11 +2357,12 @@ static int expr_evaluate_fib(struct eval_ctx *ctx, struct expr **exprp)
 
 static int expr_evaluate_meta(struct eval_ctx *ctx, struct expr **exprp)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        struct expr *meta = *exprp;
 
        switch (meta->meta.key) {
        case NFT_META_NFPROTO:
-               if (ctx->pctx.family != NFPROTO_INET &&
+               if (pctx->family != NFPROTO_INET &&
                    meta->flags & EXPR_F_PROTOCOL)
                        return expr_error(ctx->msgs, meta,
                                          "meta nfproto is only useful in the inet family");
@@ -2403,9 +2429,10 @@ static int expr_evaluate_variable(struct eval_ctx *ctx, struct expr **exprp)
 
 static int expr_evaluate_xfrm(struct eval_ctx *ctx, struct expr **exprp)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        struct expr *expr = *exprp;
 
-       switch (ctx->pctx.family) {
+       switch (pctx->family) {
        case NFPROTO_IPV4:
        case NFPROTO_IPV6:
        case NFPROTO_INET:
@@ -2848,9 +2875,10 @@ static int reject_payload_gen_dependency_tcp(struct eval_ctx *ctx,
                                             struct stmt *stmt,
                                             struct expr **payload)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct proto_desc *desc;
 
-       desc = ctx->pctx.protocol[PROTO_BASE_TRANSPORT_HDR].desc;
+       desc = pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc;
        if (desc != NULL)
                return 0;
        *payload = payload_expr_alloc(&stmt->location, &proto_tcp,
@@ -2862,9 +2890,10 @@ static int reject_payload_gen_dependency_family(struct eval_ctx *ctx,
                                                struct stmt *stmt,
                                                struct expr **payload)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct proto_desc *base;
 
-       base = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc;
+       base = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc;
        if (base != NULL)
                return 0;
 
@@ -2931,6 +2960,7 @@ static int stmt_evaluate_reject_inet_family(struct eval_ctx *ctx,
                                            struct stmt *stmt,
                                            const struct proto_desc *desc)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct proto_desc *base;
        int protocol;
 
@@ -2940,7 +2970,7 @@ static int stmt_evaluate_reject_inet_family(struct eval_ctx *ctx,
        case NFT_REJECT_ICMPX_UNREACH:
                break;
        case NFT_REJECT_ICMP_UNREACH:
-               base = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc;
+               base = pctx->protocol[PROTO_BASE_LL_HDR].desc;
                protocol = proto_find_num(base, desc);
                switch (protocol) {
                case NFPROTO_IPV4:
@@ -2948,14 +2978,14 @@ static int stmt_evaluate_reject_inet_family(struct eval_ctx *ctx,
                        if (stmt->reject.family == NFPROTO_IPV4)
                                break;
                        return stmt_binary_error(ctx, stmt->reject.expr,
-                                 &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR],
+                                 &pctx->protocol[PROTO_BASE_NETWORK_HDR],
                                  "conflicting protocols specified: ip vs ip6");
                case NFPROTO_IPV6:
                case __constant_htons(ETH_P_IPV6):
                        if (stmt->reject.family == NFPROTO_IPV6)
                                break;
                        return stmt_binary_error(ctx, stmt->reject.expr,
-                                 &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR],
+                                 &pctx->protocol[PROTO_BASE_NETWORK_HDR],
                                  "conflicting protocols specified: ip vs ip6");
                default:
                        return stmt_error(ctx, stmt,
@@ -2970,9 +3000,10 @@ static int stmt_evaluate_reject_inet_family(struct eval_ctx *ctx,
 static int stmt_evaluate_reject_inet(struct eval_ctx *ctx, struct stmt *stmt,
                                     struct expr *expr)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct proto_desc *desc;
 
-       desc = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc;
+       desc = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc;
        if (desc != NULL &&
            stmt_evaluate_reject_inet_family(ctx, stmt, desc) < 0)
                return -1;
@@ -2987,13 +3018,14 @@ static int stmt_evaluate_reject_bridge_family(struct eval_ctx *ctx,
                                              struct stmt *stmt,
                                              const struct proto_desc *desc)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct proto_desc *base;
        int protocol;
 
        switch (stmt->reject.type) {
        case NFT_REJECT_ICMPX_UNREACH:
        case NFT_REJECT_TCP_RST:
-               base = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc;
+               base = pctx->protocol[PROTO_BASE_LL_HDR].desc;
                protocol = proto_find_num(base, desc);
                switch (protocol) {
                case __constant_htons(ETH_P_IP):
@@ -3001,29 +3033,29 @@ static int stmt_evaluate_reject_bridge_family(struct eval_ctx *ctx,
                        break;
                default:
                        return stmt_binary_error(ctx, stmt,
-                                   &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR],
+                                   &pctx->protocol[PROTO_BASE_NETWORK_HDR],
                                    "cannot reject this network family");
                }
                break;
        case NFT_REJECT_ICMP_UNREACH:
-               base = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc;
+               base = pctx->protocol[PROTO_BASE_LL_HDR].desc;
                protocol = proto_find_num(base, desc);
                switch (protocol) {
                case __constant_htons(ETH_P_IP):
                        if (NFPROTO_IPV4 == stmt->reject.family)
                                break;
                        return stmt_binary_error(ctx, stmt->reject.expr,
-                                 &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR],
+                                 &pctx->protocol[PROTO_BASE_NETWORK_HDR],
                                  "conflicting protocols specified: ip vs ip6");
                case __constant_htons(ETH_P_IPV6):
                        if (NFPROTO_IPV6 == stmt->reject.family)
                                break;
                        return stmt_binary_error(ctx, stmt->reject.expr,
-                                 &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR],
+                                 &pctx->protocol[PROTO_BASE_NETWORK_HDR],
                                  "conflicting protocols specified: ip vs ip6");
                default:
                        return stmt_binary_error(ctx, stmt,
-                                   &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR],
+                                   &pctx->protocol[PROTO_BASE_NETWORK_HDR],
                                    "cannot reject this network family");
                }
                break;
@@ -3035,14 +3067,15 @@ static int stmt_evaluate_reject_bridge_family(struct eval_ctx *ctx,
 static int stmt_evaluate_reject_bridge(struct eval_ctx *ctx, struct stmt *stmt,
                                       struct expr *expr)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct proto_desc *desc;
 
-       desc = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc;
+       desc = pctx->protocol[PROTO_BASE_LL_HDR].desc;
        if (desc != &proto_eth && desc != &proto_vlan && desc != &proto_netdev)
                return __stmt_binary_error(ctx, &stmt->location, NULL,
                                           "cannot reject from this link layer protocol");
 
-       desc = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc;
+       desc = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc;
        if (desc != NULL &&
            stmt_evaluate_reject_bridge_family(ctx, stmt, desc) < 0)
                return -1;
@@ -3056,7 +3089,9 @@ static int stmt_evaluate_reject_bridge(struct eval_ctx *ctx, struct stmt *stmt,
 static int stmt_evaluate_reject_family(struct eval_ctx *ctx, struct stmt *stmt,
                                       struct expr *expr)
 {
-       switch (ctx->pctx.family) {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
+
+       switch (pctx->family) {
        case NFPROTO_ARP:
                return stmt_error(ctx, stmt, "cannot use reject with arp");
        case NFPROTO_IPV4:
@@ -3070,7 +3105,7 @@ static int stmt_evaluate_reject_family(struct eval_ctx *ctx, struct stmt *stmt,
                        return stmt_binary_error(ctx, stmt->reject.expr, stmt,
                                   "abstracted ICMP unreachable not supported");
                case NFT_REJECT_ICMP_UNREACH:
-                       if (stmt->reject.family == ctx->pctx.family)
+                       if (stmt->reject.family == pctx->family)
                                break;
                        return stmt_binary_error(ctx, stmt->reject.expr, stmt,
                                  "conflicting protocols specified: ip vs ip6");
@@ -3094,28 +3129,29 @@ static int stmt_evaluate_reject_family(struct eval_ctx *ctx, struct stmt *stmt,
 static int stmt_evaluate_reject_default(struct eval_ctx *ctx,
                                          struct stmt *stmt)
 {
-       int protocol;
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct proto_desc *desc, *base;
+       int protocol;
 
-       switch (ctx->pctx.family) {
+       switch (pctx->family) {
        case NFPROTO_IPV4:
        case NFPROTO_IPV6:
                stmt->reject.type = NFT_REJECT_ICMP_UNREACH;
-               stmt->reject.family = ctx->pctx.family;
-               if (ctx->pctx.family == NFPROTO_IPV4)
+               stmt->reject.family = pctx->family;
+               if (pctx->family == NFPROTO_IPV4)
                        stmt->reject.icmp_code = ICMP_PORT_UNREACH;
                else
                        stmt->reject.icmp_code = ICMP6_DST_UNREACH_NOPORT;
                break;
        case NFPROTO_INET:
-               desc = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc;
+               desc = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc;
                if (desc == NULL) {
                        stmt->reject.type = NFT_REJECT_ICMPX_UNREACH;
                        stmt->reject.icmp_code = NFT_REJECT_ICMPX_PORT_UNREACH;
                        break;
                }
                stmt->reject.type = NFT_REJECT_ICMP_UNREACH;
-               base = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc;
+               base = pctx->protocol[PROTO_BASE_LL_HDR].desc;
                protocol = proto_find_num(base, desc);
                switch (protocol) {
                case NFPROTO_IPV4:
@@ -3132,14 +3168,14 @@ static int stmt_evaluate_reject_default(struct eval_ctx *ctx,
                break;
        case NFPROTO_BRIDGE:
        case NFPROTO_NETDEV:
-               desc = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc;
+               desc = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc;
                if (desc == NULL) {
                        stmt->reject.type = NFT_REJECT_ICMPX_UNREACH;
                        stmt->reject.icmp_code = NFT_REJECT_ICMPX_PORT_UNREACH;
                        break;
                }
                stmt->reject.type = NFT_REJECT_ICMP_UNREACH;
-               base = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc;
+               base = pctx->protocol[PROTO_BASE_LL_HDR].desc;
                protocol = proto_find_num(base, desc);
                switch (protocol) {
                case __constant_htons(ETH_P_IP):
@@ -3175,9 +3211,9 @@ static int stmt_evaluate_reject_icmp(struct eval_ctx *ctx, struct stmt *stmt)
 
 static int stmt_evaluate_reset(struct eval_ctx *ctx, struct stmt *stmt)
 {
-       int protonum;
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct proto_desc *desc, *base;
-       struct proto_ctx *pctx = &ctx->pctx;
+       int protonum;
 
        desc = pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc;
        if (desc == NULL)
@@ -3194,7 +3230,7 @@ static int stmt_evaluate_reset(struct eval_ctx *ctx, struct stmt *stmt)
        default:
                if (stmt->reject.type == NFT_REJECT_TCP_RST) {
                        return stmt_binary_error(ctx, stmt,
-                                &ctx->pctx.protocol[PROTO_BASE_TRANSPORT_HDR],
+                                &pctx->protocol[PROTO_BASE_TRANSPORT_HDR],
                                 "you cannot use tcp reset with this protocol");
                }
                break;
@@ -3222,13 +3258,14 @@ static int stmt_evaluate_reject(struct eval_ctx *ctx, struct stmt *stmt)
 
 static int nat_evaluate_family(struct eval_ctx *ctx, struct stmt *stmt)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct proto_desc *nproto;
 
-       switch (ctx->pctx.family) {
+       switch (pctx->family) {
        case NFPROTO_IPV4:
        case NFPROTO_IPV6:
                if (stmt->nat.family == NFPROTO_UNSPEC)
-                       stmt->nat.family = ctx->pctx.family;
+                       stmt->nat.family = pctx->family;
                return 0;
        case NFPROTO_INET:
                if (!stmt->nat.addr) {
@@ -3238,7 +3275,7 @@ static int nat_evaluate_family(struct eval_ctx *ctx, struct stmt *stmt)
                if (stmt->nat.family != NFPROTO_UNSPEC)
                        return 0;
 
-               nproto = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc;
+               nproto = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc;
 
                if (nproto == &proto_ip)
                        stmt->nat.family = NFPROTO_IPV4;
@@ -3267,7 +3304,7 @@ static const struct datatype *get_addr_dtype(uint8_t family)
 static int evaluate_addr(struct eval_ctx *ctx, struct stmt *stmt,
                             struct expr **expr)
 {
-       struct proto_ctx *pctx = &ctx->pctx;
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct datatype *dtype;
 
        dtype = get_addr_dtype(pctx->family);
@@ -3320,7 +3357,7 @@ static bool nat_evaluate_addr_has_th_expr(const struct expr *map)
 static int nat_evaluate_transport(struct eval_ctx *ctx, struct stmt *stmt,
                                  struct expr **expr)
 {
-       struct proto_ctx *pctx = &ctx->pctx;
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
 
        if (pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc == NULL &&
            !nat_evaluate_addr_has_th_expr(stmt->nat.addr))
@@ -3336,16 +3373,17 @@ static int nat_evaluate_transport(struct eval_ctx *ctx, struct stmt *stmt,
 static int stmt_evaluate_l3proto(struct eval_ctx *ctx,
                                 struct stmt *stmt, uint8_t family)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct proto_desc *nproto;
 
-       nproto = ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc;
+       nproto = pctx->protocol[PROTO_BASE_NETWORK_HDR].desc;
 
        if ((nproto == &proto_ip && family != NFPROTO_IPV4) ||
            (nproto == &proto_ip6 && family != NFPROTO_IPV6))
                return stmt_binary_error(ctx, stmt,
-                                        &ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR],
+                                        &pctx->protocol[PROTO_BASE_NETWORK_HDR],
                                         "conflicting protocols specified: %s vs. %s. You must specify ip or ip6 family in %s statement",
-                                        ctx->pctx.protocol[PROTO_BASE_NETWORK_HDR].desc->name,
+                                        pctx->protocol[PROTO_BASE_NETWORK_HDR].desc->name,
                                         family2str(family),
                                         stmt->ops->name);
        return 0;
@@ -3355,10 +3393,11 @@ static int stmt_evaluate_addr(struct eval_ctx *ctx, struct stmt *stmt,
                              uint8_t family,
                              struct expr **addr)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct datatype *dtype;
        int err;
 
-       if (ctx->pctx.family == NFPROTO_INET) {
+       if (pctx->family == NFPROTO_INET) {
                dtype = get_addr_dtype(family);
                if (dtype->size == 0)
                        return stmt_error(ctx, stmt,
@@ -3375,7 +3414,7 @@ static int stmt_evaluate_addr(struct eval_ctx *ctx, struct stmt *stmt,
 
 static int stmt_evaluate_nat_map(struct eval_ctx *ctx, struct stmt *stmt)
 {
-       struct proto_ctx *pctx = &ctx->pctx;
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        struct expr *one, *two, *data, *tmp;
        const struct datatype *dtype;
        int addr_type, err;
@@ -3524,13 +3563,14 @@ static int stmt_evaluate_nat(struct eval_ctx *ctx, struct stmt *stmt)
 
 static int stmt_evaluate_tproxy(struct eval_ctx *ctx, struct stmt *stmt)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        int err;
 
-       switch (ctx->pctx.family) {
+       switch (pctx->family) {
        case NFPROTO_IPV4:
        case NFPROTO_IPV6: /* fallthrough */
                if (stmt->tproxy.family == NFPROTO_UNSPEC)
-                       stmt->tproxy.family = ctx->pctx.family;
+                       stmt->tproxy.family = pctx->family;
                break;
        case NFPROTO_INET:
                break;
@@ -3539,7 +3579,7 @@ static int stmt_evaluate_tproxy(struct eval_ctx *ctx, struct stmt *stmt)
                                  "tproxy is only supported for IPv4/IPv6/INET");
        }
 
-       if (ctx->pctx.protocol[PROTO_BASE_TRANSPORT_HDR].desc == NULL)
+       if (pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc == NULL)
                return stmt_error(ctx, stmt, "Transparent proxy support requires"
                                             " transport protocol match");
 
@@ -3649,9 +3689,10 @@ static int stmt_evaluate_optstrip(struct eval_ctx *ctx, struct stmt *stmt)
 
 static int stmt_evaluate_dup(struct eval_ctx *ctx, struct stmt *stmt)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        int err;
 
-       switch (ctx->pctx.family) {
+       switch (pctx->family) {
        case NFPROTO_IPV4:
        case NFPROTO_IPV6:
                if (stmt->dup.to == NULL)
@@ -3691,10 +3732,11 @@ static int stmt_evaluate_dup(struct eval_ctx *ctx, struct stmt *stmt)
 
 static int stmt_evaluate_fwd(struct eval_ctx *ctx, struct stmt *stmt)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct datatype *dtype;
        int err, len;
 
-       switch (ctx->pctx.family) {
+       switch (pctx->family) {
        case NFPROTO_NETDEV:
                if (stmt->fwd.dev == NULL)
                        return stmt_error(ctx, stmt,
@@ -4523,7 +4565,7 @@ static int rule_evaluate(struct eval_ctx *ctx, struct rule *rule,
        struct stmt *stmt, *tstmt = NULL;
        struct error_record *erec;
 
-       proto_ctx_init(&ctx->pctx, rule->handle.family, ctx->nft->debug_mask);
+       proto_ctx_init(&ctx->_pctx, rule->handle.family, ctx->nft->debug_mask);
        memset(&ctx->ectx, 0, sizeof(ctx->ectx));
 
        ctx->rule = rule;
index 101bfbda587895b681552213e3ad1f159fdece0c..13962ef434801717f6ddf9435edc457c06612db1 100644 (file)
@@ -391,9 +391,11 @@ static int payload_add_dependency(struct eval_ctx *ctx,
 {
        const struct proto_hdr_template *tmpl;
        struct expr *dep, *left, *right;
+       struct proto_ctx *pctx;
        struct stmt *stmt;
-       int protocol = proto_find_num(desc, upper);
+       int protocol;
 
+       protocol = proto_find_num(desc, upper);
        if (protocol < 0)
                return expr_error(ctx->msgs, expr,
                                  "conflicting protocols specified: %s vs. %s",
@@ -415,15 +417,17 @@ static int payload_add_dependency(struct eval_ctx *ctx,
                return expr_error(ctx->msgs, expr,
                                          "dependency statement is invalid");
        }
-       relational_expr_pctx_update(&ctx->pctx, dep);
+
+       pctx = eval_proto_ctx(ctx);
+       relational_expr_pctx_update(pctx, dep);
        *res = stmt;
        return 0;
 }
 
 static const struct proto_desc *
-payload_get_get_ll_hdr(const struct eval_ctx *ctx)
+payload_get_get_ll_hdr(const struct proto_ctx *pctx)
 {
-       switch (ctx->pctx.family) {
+       switch (pctx->family) {
        case NFPROTO_INET:
                return &proto_inet;
        case NFPROTO_BRIDGE:
@@ -440,9 +444,11 @@ payload_get_get_ll_hdr(const struct eval_ctx *ctx)
 static const struct proto_desc *
 payload_gen_special_dependency(struct eval_ctx *ctx, const struct expr *expr)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
+
        switch (expr->payload.base) {
        case PROTO_BASE_LL_HDR:
-               return payload_get_get_ll_hdr(ctx);
+               return payload_get_get_ll_hdr(pctx);
        case PROTO_BASE_TRANSPORT_HDR:
                if (expr->payload.desc == &proto_icmp ||
                    expr->payload.desc == &proto_icmp6 ||
@@ -450,9 +456,9 @@ payload_gen_special_dependency(struct eval_ctx *ctx, const struct expr *expr)
                        const struct proto_desc *desc, *desc_upper;
                        struct stmt *nstmt;
 
-                       desc = ctx->pctx.protocol[PROTO_BASE_LL_HDR].desc;
+                       desc = pctx->protocol[PROTO_BASE_LL_HDR].desc;
                        if (!desc) {
-                               desc = payload_get_get_ll_hdr(ctx);
+                               desc = payload_get_get_ll_hdr(pctx);
                                if (!desc)
                                        break;
                        }
@@ -502,11 +508,14 @@ payload_gen_special_dependency(struct eval_ctx *ctx, const struct expr *expr)
 int payload_gen_dependency(struct eval_ctx *ctx, const struct expr *expr,
                           struct stmt **res)
 {
-       const struct hook_proto_desc *h = &hook_proto_desc[ctx->pctx.family];
+       const struct hook_proto_desc *h;
        const struct proto_desc *desc;
+       struct proto_ctx *pctx;
        struct stmt *stmt;
        uint16_t type;
 
+       pctx = eval_proto_ctx(ctx);
+       h = &hook_proto_desc[pctx->family];
        if (expr->payload.base < h->base) {
                if (expr->payload.base < h->base - 1)
                        return expr_error(ctx->msgs, expr,
@@ -527,7 +536,7 @@ int payload_gen_dependency(struct eval_ctx *ctx, const struct expr *expr,
                return 0;
        }
 
-       desc = ctx->pctx.protocol[expr->payload.base - 1].desc;
+       desc = pctx->protocol[expr->payload.base - 1].desc;
        /* Special case for mixed IPv4/IPv6 and bridge tables */
        if (desc == NULL)
                desc = payload_gen_special_dependency(ctx, expr);
@@ -538,7 +547,7 @@ int payload_gen_dependency(struct eval_ctx *ctx, const struct expr *expr,
                                  "no %s protocol specified",
                                  proto_base_names[expr->payload.base - 1]);
 
-       if (ctx->pctx.family == NFPROTO_BRIDGE && desc == &proto_eth) {
+       if (pctx->family == NFPROTO_BRIDGE && desc == &proto_eth) {
                /* prefer netdev proto, which adds dependencies based
                 * on skb->protocol.
                 *
@@ -563,11 +572,13 @@ int exthdr_gen_dependency(struct eval_ctx *ctx, const struct expr *expr,
                          enum proto_bases pb, struct stmt **res)
 {
        const struct proto_desc *desc;
+       struct proto_ctx *pctx;
 
-       desc = ctx->pctx.protocol[pb].desc;
+       pctx = eval_proto_ctx(ctx);
+       desc = pctx->protocol[pb].desc;
        if (desc == NULL) {
                if (expr->exthdr.op == NFT_EXTHDR_OP_TCPOPT) {
-                       switch (ctx->pctx.family) {
+                       switch (pctx->family) {
                        case NFPROTO_NETDEV:
                        case NFPROTO_BRIDGE:
                        case NFPROTO_INET:
@@ -1228,6 +1239,7 @@ __payload_gen_icmp_echo_dependency(struct eval_ctx *ctx, const struct expr *expr
 int payload_gen_icmp_dependency(struct eval_ctx *ctx, const struct expr *expr,
                                struct stmt **res)
 {
+       struct proto_ctx *pctx = eval_proto_ctx(ctx);
        const struct proto_hdr_template *tmpl;
        const struct proto_desc *desc;
        struct stmt *stmt = NULL;
@@ -1244,11 +1256,11 @@ int payload_gen_icmp_dependency(struct eval_ctx *ctx, const struct expr *expr,
                break;
        case PROTO_ICMP_ECHO:
                /* do not test ICMP_ECHOREPLY here: its 0 */
-               if (ctx->pctx.th_dep.icmp.type == ICMP_ECHO)
+               if (pctx->th_dep.icmp.type == ICMP_ECHO)
                        goto done;
 
                type = ICMP_ECHO;
-               if (ctx->pctx.th_dep.icmp.type)
+               if (pctx->th_dep.icmp.type)
                        goto bad_proto;
 
                stmt = __payload_gen_icmp_echo_dependency(ctx, expr,
@@ -1259,21 +1271,21 @@ int payload_gen_icmp_dependency(struct eval_ctx *ctx, const struct expr *expr,
        case PROTO_ICMP_MTU:
        case PROTO_ICMP_ADDRESS:
                type = icmp_dep_to_type(tmpl->icmp_dep);
-               if (ctx->pctx.th_dep.icmp.type == type)
+               if (pctx->th_dep.icmp.type == type)
                        goto done;
-               if (ctx->pctx.th_dep.icmp.type)
+               if (pctx->th_dep.icmp.type)
                        goto bad_proto;
                stmt = __payload_gen_icmp_simple_dependency(ctx, expr,
                                                            &icmp_type_type,
                                                            desc, type);
                break;
        case PROTO_ICMP6_ECHO:
-               if (ctx->pctx.th_dep.icmp.type == ICMP6_ECHO_REQUEST ||
-                   ctx->pctx.th_dep.icmp.type == ICMP6_ECHO_REPLY)
+               if (pctx->th_dep.icmp.type == ICMP6_ECHO_REQUEST ||
+                   pctx->th_dep.icmp.type == ICMP6_ECHO_REPLY)
                        goto done;
 
                type = ICMP6_ECHO_REQUEST;
-               if (ctx->pctx.th_dep.icmp.type)
+               if (pctx->th_dep.icmp.type)
                        goto bad_proto;
 
                stmt = __payload_gen_icmp_echo_dependency(ctx, expr,
@@ -1286,9 +1298,9 @@ int payload_gen_icmp_dependency(struct eval_ctx *ctx, const struct expr *expr,
        case PROTO_ICMP6_MGMQ:
        case PROTO_ICMP6_PPTR:
                type = icmp_dep_to_type(tmpl->icmp_dep);
-               if (ctx->pctx.th_dep.icmp.type == type)
+               if (pctx->th_dep.icmp.type == type)
                        goto done;
-               if (ctx->pctx.th_dep.icmp.type)
+               if (pctx->th_dep.icmp.type)
                        goto bad_proto;
                stmt = __payload_gen_icmp_simple_dependency(ctx, expr,
                                                            &icmp6_type_type,
@@ -1299,7 +1311,7 @@ int payload_gen_icmp_dependency(struct eval_ctx *ctx, const struct expr *expr,
                BUG("Unhandled icmp dependency code");
        }
 
-       ctx->pctx.th_dep.icmp.type = type;
+       pctx->th_dep.icmp.type = type;
 
        if (stmt_evaluate(ctx, stmt) < 0)
                return expr_error(ctx->msgs, expr,
@@ -1310,5 +1322,5 @@ done:
 
 bad_proto:
        return expr_error(ctx->msgs, expr, "incompatible icmp match: rule has %d, need %u",
-                         ctx->pctx.th_dep.icmp.type, type);
+                         pctx->th_dep.icmp.type, type);
 }