]> git.ipfire.org Git - thirdparty/nftables.git/blobdiff - src/evaluate.c
datatype: fix leak and cleanup reference counting for struct datatype
[thirdparty/nftables.git] / src / evaluate.c
index b0c6919f600a574969cfbf3c4a753f1328f1baf9..1b7e0b37b61beb040682e19cb205f6da40ebd537 100644 (file)
@@ -82,7 +82,7 @@ static void key_fix_dtype_byteorder(struct expr *key)
        if (dtype->byteorder == key->byteorder)
                return;
 
-       datatype_set(key, set_datatype_alloc(dtype, key->byteorder));
+       __datatype_set(key, set_datatype_alloc(dtype, key->byteorder));
 }
 
 static int set_evaluate(struct eval_ctx *ctx, struct set *set);
@@ -1521,8 +1521,7 @@ static int expr_evaluate_concat(struct eval_ctx *ctx, struct expr **expr)
                        clone = datatype_clone(i->dtype);
                        clone->size = i->len;
                        clone->byteorder = i->byteorder;
-                       clone->refcnt = 1;
-                       i->dtype = clone;
+                       __datatype_set(i, clone);
                }
 
                if (dtype == NULL && i->dtype->size == 0)
@@ -1550,7 +1549,7 @@ static int expr_evaluate_concat(struct eval_ctx *ctx, struct expr **expr)
        }
 
        (*expr)->flags |= flags;
-       datatype_set(*expr, concat_type_alloc(ntype));
+       __datatype_set(*expr, concat_type_alloc(ntype));
        (*expr)->len   = size;
 
        if (off > 0)
@@ -1888,7 +1887,6 @@ static int expr_evaluate_map(struct eval_ctx *ctx, struct expr **expr)
 {
        struct expr *map = *expr, *mappings;
        struct expr_ctx ectx = ctx->ectx;
-       const struct datatype *dtype;
        struct expr *key, *data;
 
        if (map->map->etype == EXPR_CT &&
@@ -1930,12 +1928,16 @@ static int expr_evaluate_map(struct eval_ctx *ctx, struct expr **expr)
                                                  ctx->ectx.len, NULL);
                }
 
-               dtype = set_datatype_alloc(ectx.dtype, ectx.byteorder);
-               if (dtype->type == TYPE_VERDICT)
+               if (ectx.dtype->type == TYPE_VERDICT) {
                        data = verdict_expr_alloc(&netlink_location, 0, NULL);
-               else
+               } else {
+                       const struct datatype *dtype;
+
+                       dtype = set_datatype_alloc(ectx.dtype, ectx.byteorder);
                        data = constant_expr_alloc(&netlink_location, dtype,
                                                   dtype->byteorder, ectx.len, NULL);
+                       datatype_free(dtype);
+               }
 
                mappings = implicit_set_declaration(ctx, "__map%d",
                                                    key, data,
@@ -3765,8 +3767,10 @@ static int stmt_evaluate_nat_map(struct eval_ctx *ctx, struct stmt *stmt)
 {
        struct proto_ctx *pctx = eval_proto_ctx(ctx);
        struct expr *one, *two, *data, *tmp;
-       const struct datatype *dtype;
-       int addr_type, err;
+       const struct datatype *dtype = NULL;
+       const struct datatype *dtype2;
+       int addr_type;
+       int err;
 
        if (stmt->nat.family == NFPROTO_INET)
                expr_family_infer(pctx, stmt->nat.addr, &stmt->nat.family);
@@ -3786,18 +3790,23 @@ static int stmt_evaluate_nat_map(struct eval_ctx *ctx, struct stmt *stmt)
        dtype = concat_type_alloc((addr_type << TYPE_BITS) | TYPE_INET_SERVICE);
 
        expr_set_context(&ctx->ectx, dtype, dtype->size);
-       if (expr_evaluate(ctx, &stmt->nat.addr))
-               return -1;
+       if (expr_evaluate(ctx, &stmt->nat.addr)) {
+               err = -1;
+               goto out;
+       }
 
        if (pctx->protocol[PROTO_BASE_TRANSPORT_HDR].desc == NULL &&
            !nat_evaluate_addr_has_th_expr(stmt->nat.addr)) {
-               return stmt_binary_error(ctx, stmt->nat.addr, stmt,
+               err = stmt_binary_error(ctx, stmt->nat.addr, stmt,
                                         "transport protocol mapping is only "
                                         "valid after transport protocol match");
+               goto out;
        }
 
-       if (stmt->nat.addr->etype != EXPR_MAP)
-               return 0;
+       if (stmt->nat.addr->etype != EXPR_MAP) {
+               err = 0;
+               goto out;
+       }
 
        data = stmt->nat.addr->mappings->set->data;
        if (data->flags & EXPR_F_INTERVAL)
@@ -3805,36 +3814,42 @@ static int stmt_evaluate_nat_map(struct eval_ctx *ctx, struct stmt *stmt)
 
        datatype_set(data, dtype);
 
-       if (expr_ops(data)->type != EXPR_CONCAT)
-               return __stmt_evaluate_arg(ctx, stmt, dtype, dtype->size,
+       if (expr_ops(data)->type != EXPR_CONCAT) {
+               err = __stmt_evaluate_arg(ctx, stmt, dtype, dtype->size,
                                           BYTEORDER_BIG_ENDIAN,
                                           &stmt->nat.addr);
+               goto out;
+       }
 
        one = list_first_entry(&data->expressions, struct expr, list);
        two = list_entry(one->list.next, struct expr, list);
 
-       if (one == two || !list_is_last(&two->list, &data->expressions))
-               return __stmt_evaluate_arg(ctx, stmt, dtype, dtype->size,
+       if (one == two || !list_is_last(&two->list, &data->expressions)) {
+               err = __stmt_evaluate_arg(ctx, stmt, dtype, dtype->size,
                                           BYTEORDER_BIG_ENDIAN,
                                           &stmt->nat.addr);
+               goto out;
+       }
 
-       dtype = get_addr_dtype(stmt->nat.family);
+       dtype2 = get_addr_dtype(stmt->nat.family);
        tmp = one;
-       err = __stmt_evaluate_arg(ctx, stmt, dtype, dtype->size,
+       err = __stmt_evaluate_arg(ctx, stmt, dtype2, dtype2->size,
                                  BYTEORDER_BIG_ENDIAN,
                                  &tmp);
        if (err < 0)
-               return err;
+               goto out;
        if (tmp != one)
                BUG("Internal error: Unexpected alteration of l3 expression");
 
        tmp = two;
        err = nat_evaluate_transport(ctx, stmt, &tmp);
        if (err < 0)
-               return err;
+               goto out;
        if (tmp != two)
                BUG("Internal error: Unexpected alteration of l4 expression");
 
+out:
+       datatype_free(dtype);
        return err;
 }
 
@@ -4549,8 +4564,7 @@ static int set_expr_evaluate_concat(struct eval_ctx *ctx, struct expr **expr)
                        dtype = datatype_clone(i->dtype);
                        dtype->size = i->len;
                        dtype->byteorder = i->byteorder;
-                       dtype->refcnt = 1;
-                       i->dtype = dtype;
+                       __datatype_set(i, dtype);
                }
 
                if (i->dtype->size == 0 && i->len == 0)
@@ -4573,7 +4587,7 @@ static int set_expr_evaluate_concat(struct eval_ctx *ctx, struct expr **expr)
        }
 
        (*expr)->flags |= flags;
-       datatype_set(*expr, concat_type_alloc(ntype));
+       __datatype_set(*expr, concat_type_alloc(ntype));
        (*expr)->len   = size;
 
        expr_set_context(&ctx->ectx, (*expr)->dtype, (*expr)->len);