From: Vsevolod Stakhov Date: Wed, 19 Nov 2025 11:26:43 +0000 (+0000) Subject: [Rework] Refactor T.transform to validate input first X-Git-Tag: 3.14.1~11^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=075a664d7cf87b1299bef6d890f9e5d4276b8dcb;p=thirdparty%2Frspamd.git [Rework] Refactor T.transform to validate input first Changed T.transform to validate input type before applying transformer. If transformer returns nil, treat as error. Output is not type-checked. Updated all usages and tests. --- diff --git a/lualib/lua_maps.lua b/lualib/lua_maps.lua index ccfa0ce088..9cdee089cc 100644 --- a/lualib/lua_maps.lua +++ b/lualib/lua_maps.lua @@ -92,14 +92,10 @@ local external_map_schema = T.table({ cdb = T.string():optional(), -- path to CDB file, required for CDB method = T.enum({ "body", "header", "query" }):optional(), -- how to pass input encode = T.enum({ "json", "messagepack" }):optional(), -- how to encode input (if relevant) - timeout = T.transform(T.number({ min = 0 }), function(val) - if type(val) == "number" then - return val - elseif type(val) == "string" then - return lua_util.parse_time_interval(val) - end - return val - end):optional(), + timeout = T.one_of({ + T.number({ min = 0 }), + T.transform(T.string(), lua_util.parse_time_interval) + }):optional(), }) -- Storage for CDB instances diff --git a/lualib/lua_redis.lua b/lualib/lua_redis.lua index 088965c973..b262777baf 100644 --- a/lualib/lua_redis.lua +++ b/lualib/lua_redis.lua @@ -32,13 +32,8 @@ local db_schema = T.one_of({ local common_schema = T.table({ timeout = T.one_of({ - T.number(), - T.transform(T.number({ min = 0 }), function(val) - if type(val) == "string" then - return lutil.parse_time_interval(val) - end - return val - end) + T.number({ min = 0 }), + T.transform(T.string(), lutil.parse_time_interval) }):optional():doc({ summary = "Connection timeout (seconds)" }), db = db_schema, database = db_schema, @@ -60,34 +55,19 @@ local common_schema = T.table({ T.array(T.string()) }):optional():doc({ summary = "Sentinel servers" }), sentinel_watch_time = T.one_of({ - T.number(), - T.transform(T.number({ min = 0 }), function(val) - if type(val) == "string" then - return lutil.parse_time_interval(val) - end - return val - end) + T.number({ min = 0 }), + T.transform(T.string(), lutil.parse_time_interval) }):optional():doc({ summary = "Sentinel watch time" }), sentinel_masters_pattern = T.string():optional():doc({ summary = "Sentinel masters pattern" }), sentinel_master_maxerrors = T.one_of({ T.number(), - T.transform(T.number(), function(val) - if type(val) == "string" then - return tonumber(val) - end - return val - end) + T.transform(T.string(), tonumber) }):optional():doc({ summary = "Sentinel master max errors" }), sentinel_username = T.string():optional():doc({ summary = "Sentinel username" }), sentinel_password = T.string():optional():doc({ summary = "Sentinel password" }), redis_version = T.one_of({ T.number(), - T.transform(T.number(), function(val) - if type(val) == "string" then - return tonumber(val) - end - return val - end) + T.transform(T.string(), tonumber) }):optional():doc({ summary = "Redis server version (6 or 7)" }), }, { open = true }) diff --git a/lualib/lua_selectors/extractors.lua b/lualib/lua_selectors/extractors.lua index eb18082080..29a05854ce 100644 --- a/lualib/lua_selectors/extractors.lua +++ b/lualib/lua_selectors/extractors.lua @@ -320,18 +320,22 @@ e.g. `get_tld`]], flags = url_flags_ts, flags_mode = T.enum { 'explicit' }:optional(), prefix = T.string():optional(), - need_content = T.transform(T.boolean(), function(v) - return type(v) == "string" and lua_util.toboolean(v) or v - end):optional(), - need_emails = T.transform(T.boolean(), function(v) - return type(v) == "string" and lua_util.toboolean(v) or v - end):optional(), - need_images = T.transform(T.boolean(), function(v) - return type(v) == "string" and lua_util.toboolean(v) or v - end):optional(), - ignore_redirected = T.transform(T.boolean(), function(v) - return type(v) == "string" and lua_util.toboolean(v) or v - end):optional(), + need_content = T.one_of({ + T.boolean(), + T.transform(T.string(), lua_util.toboolean) + }):optional(), + need_emails = T.one_of({ + T.boolean(), + T.transform(T.string(), lua_util.toboolean) + }):optional(), + need_images = T.one_of({ + T.boolean(), + T.transform(T.string(), lua_util.toboolean) + }):optional(), + ignore_redirected = T.one_of({ + T.boolean(), + T.transform(T.string(), lua_util.toboolean) + }):optional(), } } }, ['specific_urls_filter_map'] = { @@ -362,18 +366,22 @@ e.g. `get_tld`]], flags = url_flags_ts, flags_mode = T.enum { 'explicit' }:optional(), prefix = T.string():optional(), - need_content = T.transform(T.boolean(), function(v) - return type(v) == "string" and lua_util.toboolean(v) or v - end):optional(), - need_emails = T.transform(T.boolean(), function(v) - return type(v) == "string" and lua_util.toboolean(v) or v - end):optional(), - need_images = T.transform(T.boolean(), function(v) - return type(v) == "string" and lua_util.toboolean(v) or v - end):optional(), - ignore_redirected = T.transform(T.boolean(), function(v) - return type(v) == "string" and lua_util.toboolean(v) or v - end):optional(), + need_content = T.one_of({ + T.boolean(), + T.transform(T.string(), lua_util.toboolean) + }):optional(), + need_emails = T.one_of({ + T.boolean(), + T.transform(T.string(), lua_util.toboolean) + }):optional(), + need_images = T.one_of({ + T.boolean(), + T.transform(T.string(), lua_util.toboolean) + }):optional(), + ignore_redirected = T.one_of({ + T.boolean(), + T.transform(T.string(), lua_util.toboolean) + }):optional(), } } }, -- URLs filtered by flags diff --git a/lualib/lua_shape/README.md b/lualib/lua_shape/README.md index f93a87a5c7..9891d1b9ca 100644 --- a/lualib/lua_shape/README.md +++ b/lualib/lua_shape/README.md @@ -157,20 +157,32 @@ local config_schema = T.one_of({ ### Transforms +`T.transform(accepted_type, transformer)` validates input against `accepted_type`, then applies `transformer` function. + ```lua --- Parse time interval string to number -local timeout_schema = T.transform(T.number({ min = 0 }), function(val) - if type(val) == "number" then - return val - elseif type(val) == "string" then - return parse_time_interval(val) -- "5s" -> 5.0 - else - error("Expected number or time interval string") - end -end) +-- Accept string, convert to number +local num_from_string = T.transform(T.string(), tonumber) + +-- Accept number or string, convert both to number +local flexible_number = T.one_of({ + T.number(), + T.transform(T.string(), tonumber) +}) + +-- Accept string, parse time interval to number +local timeout_schema = T.one_of({ + T.number({ min = 0 }), + T.transform(T.string(), parse_time_interval) -- "5s" -> 5.0 +}) ``` -> **Note:** transform functions are evaluated only when you call `schema:transform(...)`. A plain `schema:check(...)` validates the original input without invoking the transform, matching tableshape semantics. +**Semantics:** +1. Input is validated against accepted type (first argument) +2. If valid, transformer function is called with pcall +3. If transformer returns `nil` or errors, validation fails +4. Otherwise, result is accepted without type checking + +> **Note:** Transform functions run only in `schema:transform(...)` mode. In `schema:check(...)` mode, only the input type is validated. ### Callable Defaults @@ -349,7 +361,7 @@ value does not match any alternative at : - `T.one_of(variants, opts?)` - Sum type - `T.optional(schema, opts?)` - Optional wrapper - `T.default(schema, value)` - Default value wrapper -- `T.transform(schema, fn, opts?)` - Transform wrapper +- `T.transform(accepted_type, transformer, opts?)` - Transform wrapper (validates input against accepted_type, then applies transformer) - `T.ref(id, opts?)` - Schema reference placeholder (must be resolved via the registry before validation) - `T.mixin(schema, opts?)` - Mixin definition @@ -403,7 +415,7 @@ Quick reference: | `ts.shape({...})` | `T.table({...})` | | `field:is_optional()` | `field:optional()` or `{ schema = ..., optional = true }` | | `ts.string + ts.number` | `T.one_of({ T.string(), T.number() })` | -| `ts.string / fn` | `T.string():transform_with(fn)` or `T.transform(T.number(), fn)` | +| `ts.string / fn` | `T.string():transform_with(fn)` or `T.transform(T.string(), fn)` | | `field:describe("...")` | `field:doc({ summary = "..." })` | ## Files diff --git a/lualib/lua_shape/core.lua b/lualib/lua_shape/core.lua index 1a07b25944..528f08f1be 100644 --- a/lualib/lua_shape/core.lua +++ b/lualib/lua_shape/core.lua @@ -610,21 +610,34 @@ end -- Transform wrapper local function check_transform(node, value, ctx) - if ctx.mode == "transform" then - -- Apply transformation (protect against errors in user-provided function) - local ok_transform, new_value = pcall(node.fn, value) - if not ok_transform then - return false, make_error("transform_error", ctx.path, { - error = tostring(new_value) - }) - end + -- First, validate the input value against the accepted type + local ok_input, err = node.inner:_check(value, make_context("check", ctx.path)) + if not ok_input then + return false, err + end - -- Validate transformed value against inner schema - return node.inner:_check(new_value, ctx) - else - -- In check mode, validate original value against inner schema - return node.inner:_check(value, ctx) + -- In check mode, we're done - input is valid + if ctx.mode ~= "transform" then + return true, value end + + -- In transform mode, apply the functor (protect against errors) + local ok_transform, new_value = pcall(node.fn, value) + if not ok_transform then + return false, make_error("transform_error", ctx.path, { + error = tostring(new_value) + }) + end + + -- Check if transformation returned nil (transformation failed) + if new_value == nil then + return false, make_error("transform_error", ctx.path, { + error = "transformation function returned nil" + }) + end + + -- Accept the transformed value without type checking the output + return true, new_value end function T.transform(schema, fn, opts) diff --git a/lualib/plugins/ratelimit.lua b/lualib/plugins/ratelimit.lua index 265b5fce2c..0b674f9c1e 100644 --- a/lualib/plugins/ratelimit.lua +++ b/lualib/plugins/ratelimit.lua @@ -100,21 +100,11 @@ end local bucket_schema = T.table({ burst = T.one_of({ T.number(), - T.transform(T.number(), function(val) - if type(val) == "string" then - return lua_util.dehumanize_number(val) - end - return val - end) + T.transform(T.string(), lua_util.dehumanize_number) }):doc({ summary = "Burst size (number of messages)" }), rate = T.one_of({ T.number(), - T.transform(T.number(), function(val) - if type(val) == "string" then - return str_to_rate(val) - end - return val - end) + T.transform(T.string(), str_to_rate) }):doc({ summary = "Rate limit (messages per time unit)" }), skip_recipients = T.boolean():optional():doc({ summary = "Skip per-recipient limits" }), symbol = T.string():optional():doc({ summary = "Custom symbol name" }), diff --git a/lualib/plugins/rbl.lua b/lualib/plugins/rbl.lua index b7cc8fde79..f0eb48503c 100644 --- a/lualib/plugins/rbl.lua +++ b/lualib/plugins/rbl.lua @@ -59,51 +59,33 @@ local default_options = { local return_codes_schema = T.table({}, { open = true, - key = T.transform(T.string(), function(val) - if type(val) == "string" then - return string.upper(val) - end - return val - end), + key = T.transform(T.string(), string.upper), extra = T.one_of({ T.array(T.string()), - -- Transform string to array, inner schema validates the result - T.transform(T.array(T.string()), function(val) - if type(val) == "string" then - return { val } - end - return val + -- Transform string to array + T.transform(T.string(), function(val) + return { val } end) }) }):doc({ summary = "Map of symbol names to IP patterns" }) local return_bits_schema = T.table({}, { open = true, - key = T.transform(T.string(), function(val) - if type(val) == "string" then - return string.upper(val) - end - return val - end), + key = T.transform(T.string(), string.upper), extra = T.one_of({ T.array(T.one_of({ T.number(), - T.transform(T.number(), function(val) - if type(val) == "string" then - return tonumber(val) - end - return val - end) + T.transform(T.string(), tonumber) })), - -- Transform string or number to array, inner schema validates the result - T.transform(T.array(T.number()), function(val) - if type(val) == "string" then + -- Transform string or number to array + T.one_of({ + T.transform(T.string(), function(val) return { tonumber(val) } - elseif type(val) == "number" then + end), + T.transform(T.number(), function(val) return { val } - end - return val - end) + end) + }) }) }):doc({ summary = "Map of symbol names to bit numbers" }) diff --git a/test/lua/unit/lua_shape.lua b/test/lua/unit/lua_shape.lua index e564078f91..0e2b3a1eb5 100644 --- a/test/lua/unit/lua_shape.lua +++ b/test/lua/unit/lua_shape.lua @@ -265,35 +265,46 @@ context("Lua shape validation", function() -- Transform tests context("Transform support", function() test("Transform string to number", function() - local schema = T.transform(T.number(), function(val) - if type(val) == "string" then - return tonumber(val) - end - return val - end) + -- New semantics: first arg is accepted type, second is transformer + local schema = T.transform(T.string(), tonumber) + -- Transform mode: converts string to number local val, err = schema:transform("42") assert_nil(err) assert_equal(val, 42) - end) - test("Transform with validation", function() - local schema = T.transform(T.integer({ min = 0 }), function(val) - if type(val) == "string" then - return tonumber(val) - end - return val - end) + -- Invalid string returns nil, which is caught as error + val, err = schema:transform("not a number") + assert_nil(val) + assert_not_nil(err) + assert_equal(err.kind, "transform_error") + end) - -- Valid transform - local val, err = schema:transform("10") - assert_nil(err) - assert_equal(val, 10) + test("Transform validates input type first", function() + local schema = T.transform(T.string(), tonumber) - -- Transform result fails validation - val, err = schema:transform("-5") + -- Number input fails because accepted type is string + local val, err = schema:transform(42) assert_nil(val) assert_not_nil(err) + assert_equal(err.kind, "type_mismatch") + end) + + test("Transform accepts number or string using one_of", function() + local schema = T.one_of({ + T.number(), + T.transform(T.string(), tonumber) + }) + + -- Number passes through + local val, err = schema:transform(42) + assert_nil(err) + assert_equal(val, 42) + + -- String is converted + val, err = schema:transform("123") + assert_nil(err) + assert_equal(val, 123) end) test("Transform only in transform mode", function() @@ -301,7 +312,7 @@ context("Lua shape validation", function() return val * 2 end) - -- Check mode: no transform + -- Check mode: no transform, just validates input is number local ok, val = schema:check(5) assert_true(ok) assert_equal(val, 5) @@ -321,6 +332,18 @@ context("Lua shape validation", function() assert_nil(err) assert_equal(val, "HELLO") end) + + test("Transform result is not type-checked", function() + -- Transform string to table - result type is not validated + local schema = T.transform(T.string(), function(val) + return { value = val } + end) + + local val, err = schema:transform("test") + assert_nil(err) + assert_equal(type(val), "table") + assert_equal(val.value, "test") + end) end) -- one_of tests