]> git.ipfire.org Git - thirdparty/rspamd.git/commitdiff
[Rework] Refactor T.transform to validate input first
authorVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 19 Nov 2025 11:26:43 +0000 (11:26 +0000)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Wed, 19 Nov 2025 11:27:24 +0000 (11:27 +0000)
Changed T.transform to validate input type before applying transformer.
If transformer returns nil, treat as error. Output is not type-checked.
Updated all usages and tests.

lualib/lua_maps.lua
lualib/lua_redis.lua
lualib/lua_selectors/extractors.lua
lualib/lua_shape/README.md
lualib/lua_shape/core.lua
lualib/plugins/ratelimit.lua
lualib/plugins/rbl.lua
test/lua/unit/lua_shape.lua

index ccfa0ce088fd9caa7572a23cdeefa53a28b743d9..9cdee089cc08fbfc190f3cd26deccd0f1f8fa310 100644 (file)
@@ -92,14 +92,10 @@ local external_map_schema = T.table({
   cdb = T.string():optional(), -- path to CDB file, required for CDB
   method = T.enum({ "body", "header", "query" }):optional(), -- how to pass input
   encode = T.enum({ "json", "messagepack" }):optional(), -- how to encode input (if relevant)
-  timeout = T.transform(T.number({ min = 0 }), function(val)
-    if type(val) == "number" then
-      return val
-    elseif type(val) == "string" then
-      return lua_util.parse_time_interval(val)
-    end
-    return val
-  end):optional(),
+  timeout = T.one_of({
+    T.number({ min = 0 }),
+    T.transform(T.string(), lua_util.parse_time_interval)
+  }):optional(),
 })
 
 -- Storage for CDB instances
index 088965c9736b18a0ebf04ffbe7cf2944284756fe..b262777baf39ac6a432a1743a3e6c3bd0aab9871 100644 (file)
@@ -32,13 +32,8 @@ local db_schema = T.one_of({
 
 local common_schema = T.table({
   timeout = T.one_of({
-    T.number(),
-    T.transform(T.number({ min = 0 }), function(val)
-      if type(val) == "string" then
-        return lutil.parse_time_interval(val)
-      end
-      return val
-    end)
+    T.number({ min = 0 }),
+    T.transform(T.string(), lutil.parse_time_interval)
   }):optional():doc({ summary = "Connection timeout (seconds)" }),
   db = db_schema,
   database = db_schema,
@@ -60,34 +55,19 @@ local common_schema = T.table({
     T.array(T.string())
   }):optional():doc({ summary = "Sentinel servers" }),
   sentinel_watch_time = T.one_of({
-    T.number(),
-    T.transform(T.number({ min = 0 }), function(val)
-      if type(val) == "string" then
-        return lutil.parse_time_interval(val)
-      end
-      return val
-    end)
+    T.number({ min = 0 }),
+    T.transform(T.string(), lutil.parse_time_interval)
   }):optional():doc({ summary = "Sentinel watch time" }),
   sentinel_masters_pattern = T.string():optional():doc({ summary = "Sentinel masters pattern" }),
   sentinel_master_maxerrors = T.one_of({
     T.number(),
-    T.transform(T.number(), function(val)
-      if type(val) == "string" then
-        return tonumber(val)
-      end
-      return val
-    end)
+    T.transform(T.string(), tonumber)
   }):optional():doc({ summary = "Sentinel master max errors" }),
   sentinel_username = T.string():optional():doc({ summary = "Sentinel username" }),
   sentinel_password = T.string():optional():doc({ summary = "Sentinel password" }),
   redis_version = T.one_of({
     T.number(),
-    T.transform(T.number(), function(val)
-      if type(val) == "string" then
-        return tonumber(val)
-      end
-      return val
-    end)
+    T.transform(T.string(), tonumber)
   }):optional():doc({ summary = "Redis server version (6 or 7)" }),
 }, { open = true })
 
index eb18082080e5b4d0de803a9f19f26b5cc941572e..29a05854ce4aacd247428fefce9775512b8f2769 100644 (file)
@@ -320,18 +320,22 @@ e.g. `get_tld`]],
       flags = url_flags_ts,
       flags_mode = T.enum { 'explicit' }:optional(),
       prefix = T.string():optional(),
-      need_content = T.transform(T.boolean(), function(v)
-        return type(v) == "string" and lua_util.toboolean(v) or v
-      end):optional(),
-      need_emails = T.transform(T.boolean(), function(v)
-        return type(v) == "string" and lua_util.toboolean(v) or v
-      end):optional(),
-      need_images = T.transform(T.boolean(), function(v)
-        return type(v) == "string" and lua_util.toboolean(v) or v
-      end):optional(),
-      ignore_redirected = T.transform(T.boolean(), function(v)
-        return type(v) == "string" and lua_util.toboolean(v) or v
-      end):optional(),
+      need_content = T.one_of({
+        T.boolean(),
+        T.transform(T.string(), lua_util.toboolean)
+      }):optional(),
+      need_emails = T.one_of({
+        T.boolean(),
+        T.transform(T.string(), lua_util.toboolean)
+      }):optional(),
+      need_images = T.one_of({
+        T.boolean(),
+        T.transform(T.string(), lua_util.toboolean)
+      }):optional(),
+      ignore_redirected = T.one_of({
+        T.boolean(),
+        T.transform(T.string(), lua_util.toboolean)
+      }):optional(),
     } }
   },
   ['specific_urls_filter_map'] = {
@@ -362,18 +366,22 @@ e.g. `get_tld`]],
       flags = url_flags_ts,
       flags_mode = T.enum { 'explicit' }:optional(),
       prefix = T.string():optional(),
-      need_content = T.transform(T.boolean(), function(v)
-        return type(v) == "string" and lua_util.toboolean(v) or v
-      end):optional(),
-      need_emails = T.transform(T.boolean(), function(v)
-        return type(v) == "string" and lua_util.toboolean(v) or v
-      end):optional(),
-      need_images = T.transform(T.boolean(), function(v)
-        return type(v) == "string" and lua_util.toboolean(v) or v
-      end):optional(),
-      ignore_redirected = T.transform(T.boolean(), function(v)
-        return type(v) == "string" and lua_util.toboolean(v) or v
-      end):optional(),
+      need_content = T.one_of({
+        T.boolean(),
+        T.transform(T.string(), lua_util.toboolean)
+      }):optional(),
+      need_emails = T.one_of({
+        T.boolean(),
+        T.transform(T.string(), lua_util.toboolean)
+      }):optional(),
+      need_images = T.one_of({
+        T.boolean(),
+        T.transform(T.string(), lua_util.toboolean)
+      }):optional(),
+      ignore_redirected = T.one_of({
+        T.boolean(),
+        T.transform(T.string(), lua_util.toboolean)
+      }):optional(),
     } }
   },
   -- URLs filtered by flags
index f93a87a5c75210f13d4893660a449cb26d17f982..9891d1b9cac9c620da14b3c59f061128ac8ff666 100644 (file)
@@ -157,20 +157,32 @@ local config_schema = T.one_of({
 
 ### Transforms
 
+`T.transform(accepted_type, transformer)` validates input against `accepted_type`, then applies `transformer` function.
+
 ```lua
--- Parse time interval string to number
-local timeout_schema = T.transform(T.number({ min = 0 }), function(val)
-  if type(val) == "number" then
-    return val
-  elseif type(val) == "string" then
-    return parse_time_interval(val)  -- "5s" -> 5.0
-  else
-    error("Expected number or time interval string")
-  end
-end)
+-- Accept string, convert to number
+local num_from_string = T.transform(T.string(), tonumber)
+
+-- Accept number or string, convert both to number
+local flexible_number = T.one_of({
+  T.number(),
+  T.transform(T.string(), tonumber)
+})
+
+-- Accept string, parse time interval to number
+local timeout_schema = T.one_of({
+  T.number({ min = 0 }),
+  T.transform(T.string(), parse_time_interval)  -- "5s" -> 5.0
+})
 ```
 
-> **Note:** transform functions are evaluated only when you call `schema:transform(...)`. A plain `schema:check(...)` validates the original input without invoking the transform, matching tableshape semantics.
+**Semantics:**
+1. Input is validated against accepted type (first argument)
+2. If valid, transformer function is called with pcall
+3. If transformer returns `nil` or errors, validation fails
+4. Otherwise, result is accepted without type checking
+
+> **Note:** Transform functions run only in `schema:transform(...)` mode. In `schema:check(...)` mode, only the input type is validated.
 
 ### Callable Defaults
 
@@ -349,7 +361,7 @@ value does not match any alternative at :
 - `T.one_of(variants, opts?)` - Sum type
 - `T.optional(schema, opts?)` - Optional wrapper
 - `T.default(schema, value)` - Default value wrapper
-- `T.transform(schema, fn, opts?)` - Transform wrapper
+- `T.transform(accepted_type, transformer, opts?)` - Transform wrapper (validates input against accepted_type, then applies transformer)
 - `T.ref(id, opts?)` - Schema reference placeholder (must be resolved via the registry before validation)
 - `T.mixin(schema, opts?)` - Mixin definition
 
@@ -403,7 +415,7 @@ Quick reference:
 | `ts.shape({...})` | `T.table({...})` |
 | `field:is_optional()` | `field:optional()` or `{ schema = ..., optional = true }` |
 | `ts.string + ts.number` | `T.one_of({ T.string(), T.number() })` |
-| `ts.string / fn` | `T.string():transform_with(fn)` or `T.transform(T.number(), fn)` |
+| `ts.string / fn` | `T.string():transform_with(fn)` or `T.transform(T.string(), fn)` |
 | `field:describe("...")` | `field:doc({ summary = "..." })` |
 
 ## Files
index 1a07b25944c9d33d3d2585dd8fd1704eaf17e1c0..528f08f1be4183def2b0a9c1f9aaeb3842016ac5 100644 (file)
@@ -610,21 +610,34 @@ end
 -- Transform wrapper
 
 local function check_transform(node, value, ctx)
-  if ctx.mode == "transform" then
-    -- Apply transformation (protect against errors in user-provided function)
-    local ok_transform, new_value = pcall(node.fn, value)
-    if not ok_transform then
-      return false, make_error("transform_error", ctx.path, {
-        error = tostring(new_value)
-      })
-    end
+  -- First, validate the input value against the accepted type
+  local ok_input, err = node.inner:_check(value, make_context("check", ctx.path))
+  if not ok_input then
+    return false, err
+  end
 
-    -- Validate transformed value against inner schema
-    return node.inner:_check(new_value, ctx)
-  else
-    -- In check mode, validate original value against inner schema
-    return node.inner:_check(value, ctx)
+  -- In check mode, we're done - input is valid
+  if ctx.mode ~= "transform" then
+    return true, value
   end
+
+  -- In transform mode, apply the functor (protect against errors)
+  local ok_transform, new_value = pcall(node.fn, value)
+  if not ok_transform then
+    return false, make_error("transform_error", ctx.path, {
+      error = tostring(new_value)
+    })
+  end
+
+  -- Check if transformation returned nil (transformation failed)
+  if new_value == nil then
+    return false, make_error("transform_error", ctx.path, {
+      error = "transformation function returned nil"
+    })
+  end
+
+  -- Accept the transformed value without type checking the output
+  return true, new_value
 end
 
 function T.transform(schema, fn, opts)
index 265b5fce2c2570798f19b3c8be332ea303337c26..0b674f9c1e0e385a38d503774feea0e70a98ecf5 100644 (file)
@@ -100,21 +100,11 @@ end
 local bucket_schema = T.table({
   burst = T.one_of({
     T.number(),
-    T.transform(T.number(), function(val)
-      if type(val) == "string" then
-        return lua_util.dehumanize_number(val)
-      end
-      return val
-    end)
+    T.transform(T.string(), lua_util.dehumanize_number)
   }):doc({ summary = "Burst size (number of messages)" }),
   rate = T.one_of({
     T.number(),
-    T.transform(T.number(), function(val)
-      if type(val) == "string" then
-        return str_to_rate(val)
-      end
-      return val
-    end)
+    T.transform(T.string(), str_to_rate)
   }):doc({ summary = "Rate limit (messages per time unit)" }),
   skip_recipients = T.boolean():optional():doc({ summary = "Skip per-recipient limits" }),
   symbol = T.string():optional():doc({ summary = "Custom symbol name" }),
index b7cc8fde7943a5db89ce96bd1dc2c892d6d749f0..f0eb48503c9996b78cfee459476854a8501e2a02 100644 (file)
@@ -59,51 +59,33 @@ local default_options = {
 
 local return_codes_schema = T.table({}, {
   open = true,
-  key = T.transform(T.string(), function(val)
-    if type(val) == "string" then
-      return string.upper(val)
-    end
-    return val
-  end),
+  key = T.transform(T.string(), string.upper),
   extra = T.one_of({
     T.array(T.string()),
-    -- Transform string to array, inner schema validates the result
-    T.transform(T.array(T.string()), function(val)
-      if type(val) == "string" then
-        return { val }
-      end
-      return val
+    -- Transform string to array
+    T.transform(T.string(), function(val)
+      return { val }
     end)
   })
 }):doc({ summary = "Map of symbol names to IP patterns" })
 
 local return_bits_schema = T.table({}, {
   open = true,
-  key = T.transform(T.string(), function(val)
-    if type(val) == "string" then
-      return string.upper(val)
-    end
-    return val
-  end),
+  key = T.transform(T.string(), string.upper),
   extra = T.one_of({
     T.array(T.one_of({
       T.number(),
-      T.transform(T.number(), function(val)
-        if type(val) == "string" then
-          return tonumber(val)
-        end
-        return val
-      end)
+      T.transform(T.string(), tonumber)
     })),
-    -- Transform string or number to array, inner schema validates the result
-    T.transform(T.array(T.number()), function(val)
-      if type(val) == "string" then
+    -- Transform string or number to array
+    T.one_of({
+      T.transform(T.string(), function(val)
         return { tonumber(val) }
-      elseif type(val) == "number" then
+      end),
+      T.transform(T.number(), function(val)
         return { val }
-      end
-      return val
-    end)
+      end)
+    })
   })
 }):doc({ summary = "Map of symbol names to bit numbers" })
 
index e564078f91c0fdb38a95c7f9139c9a42631b8f80..0e2b3a1eb5884c2652057d15d0a963d60f05e6c3 100644 (file)
@@ -265,35 +265,46 @@ context("Lua shape validation", function()
   -- Transform tests
   context("Transform support", function()
     test("Transform string to number", function()
-      local schema = T.transform(T.number(), function(val)
-        if type(val) == "string" then
-          return tonumber(val)
-        end
-        return val
-      end)
+      -- New semantics: first arg is accepted type, second is transformer
+      local schema = T.transform(T.string(), tonumber)
 
+      -- Transform mode: converts string to number
       local val, err = schema:transform("42")
       assert_nil(err)
       assert_equal(val, 42)
-    end)
 
-    test("Transform with validation", function()
-      local schema = T.transform(T.integer({ min = 0 }), function(val)
-        if type(val) == "string" then
-          return tonumber(val)
-        end
-        return val
-      end)
+      -- Invalid string returns nil, which is caught as error
+      val, err = schema:transform("not a number")
+      assert_nil(val)
+      assert_not_nil(err)
+      assert_equal(err.kind, "transform_error")
+    end)
 
-      -- Valid transform
-      local val, err = schema:transform("10")
-      assert_nil(err)
-      assert_equal(val, 10)
+    test("Transform validates input type first", function()
+      local schema = T.transform(T.string(), tonumber)
 
-      -- Transform result fails validation
-      val, err = schema:transform("-5")
+      -- Number input fails because accepted type is string
+      local val, err = schema:transform(42)
       assert_nil(val)
       assert_not_nil(err)
+      assert_equal(err.kind, "type_mismatch")
+    end)
+
+    test("Transform accepts number or string using one_of", function()
+      local schema = T.one_of({
+        T.number(),
+        T.transform(T.string(), tonumber)
+      })
+
+      -- Number passes through
+      local val, err = schema:transform(42)
+      assert_nil(err)
+      assert_equal(val, 42)
+
+      -- String is converted
+      val, err = schema:transform("123")
+      assert_nil(err)
+      assert_equal(val, 123)
     end)
 
     test("Transform only in transform mode", function()
@@ -301,7 +312,7 @@ context("Lua shape validation", function()
         return val * 2
       end)
 
-      -- Check mode: no transform
+      -- Check mode: no transform, just validates input is number
       local ok, val = schema:check(5)
       assert_true(ok)
       assert_equal(val, 5)
@@ -321,6 +332,18 @@ context("Lua shape validation", function()
       assert_nil(err)
       assert_equal(val, "HELLO")
     end)
+
+    test("Transform result is not type-checked", function()
+      -- Transform string to table - result type is not validated
+      local schema = T.transform(T.string(), function(val)
+        return { value = val }
+      end)
+
+      local val, err = schema:transform("test")
+      assert_nil(err)
+      assert_equal(type(val), "table")
+      assert_equal(val.value, "test")
+    end)
   end)
 
   -- one_of tests