From: Vsevolod Stakhov Date: Mon, 17 Nov 2025 14:43:36 +0000 (+0000) Subject: [Feature] Add lua_shape validation library as tableshape replacement X-Git-Tag: 3.14.1~11^2~46 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5411b251f610c390b358e892a77ea381b6735f5b;p=thirdparty%2Frspamd.git [Feature] Add lua_shape validation library as tableshape replacement Implement comprehensive schema validation library with improved features: * Better one_of error reporting with intersection analysis * Schema-driven documentation generation with mixin tracking * Rich type constraints (ranges, lengths, Lua patterns) * First-class mixins with origin tracking for composition * JSON Schema Draft 7 export for UCL validation * Transform support with immutable semantics * Pure Lua implementation with optional lpeg support The library provides 4 core modules: - core.lua: All type constructors, validation, and utilities - registry.lua: Schema registration and reference resolution - jsonschema.lua: JSON Schema export - docs.lua: Documentation IR generation Includes comprehensive test suite (44 tests, 119 assertions). Designed to gradually replace tableshape across 22 modules. --- diff --git a/lualib/lua_shape/MIGRATION.md b/lualib/lua_shape/MIGRATION.md new file mode 100644 index 0000000000..a97b1bfe38 --- /dev/null +++ b/lualib/lua_shape/MIGRATION.md @@ -0,0 +1,394 @@ +# Migration Guide: tableshape to lua_shape + +This guide helps migrate from tableshape to the new lua_shape library. + +## Basic Concepts + +### Module Import + +**tableshape:** +```lua +local ts = require("tableshape").types +``` + +**lua_shape:** +```lua +local T = require "lua_shape.core" + +-- Only need Registry if using schema registration/refs: +local Registry = require "lua_shape.registry" -- optional +``` + +Note: All utility functions (like `format_error`) are included in the core module, so you only need one require statement for most use cases. + +## Type Constructors + +### Scalar Types + +| tableshape | rspamd_schema | Notes | +|------------|---------------|-------| +| `ts.string` | `T.string()` | | +| `ts.number` | `T.number()` | | +| `ts.integer` | `T.integer()` | | +| `ts.boolean` | `T.boolean()` | | +| `ts.literal("foo")` | `T.literal("foo")` | | +| `ts.one_of{"a","b"}` | `T.enum({"a","b"})` | For simple value enums | + +### Constraints + +**tableshape:** +```lua +ts.string:length(3, 10) -- min 3, max 10 +ts.number:range(0, 100) +``` + +**rspamd_schema:** +```lua +T.string({ min_len = 3, max_len = 10 }) +T.number({ min = 0, max = 100 }) +T.integer({ min = 0, max = 100 }) +``` + +### Arrays + +**tableshape:** +```lua +ts.array_of(ts.string) +``` + +**rspamd_schema:** +```lua +T.array(T.string()) +``` + +### Tables (Shapes) + +**tableshape:** +```lua +ts.shape({ + name = ts.string, + age = ts.number, + email = ts.string:is_optional() +}) +``` + +**rspamd_schema:** +```lua +T.table({ + name = T.string(), + age = T.number(), + email = { schema = T.string(), optional = true } +}) + +-- Or using :optional() method: +T.table({ + name = T.string(), + age = T.number(), + email = T.string():optional() +}) +``` + +### Optional Fields + +**tableshape:** +```lua +field = ts.string:is_optional() +``` + +**rspamd_schema:** +```lua +-- Method 1: inline +field = { schema = T.string(), optional = true } + +-- Method 2: using :optional() +field = T.string():optional() +``` + +### Default Values + +**tableshape:** +```lua +field = ts.string:is_optional() -- then handle defaults manually +``` + +**rspamd_schema:** +```lua +field = { schema = T.string(), optional = true, default = "default_value" } + +-- Or using :with_default() +field = T.string():with_default("default_value") +``` + +## Operators + +### Union (one_of) + +**tableshape:** +```lua +ts.string + ts.number +``` + +**rspamd_schema:** +```lua +T.one_of({ T.string(), T.number() }) +``` + +### Transform + +**tableshape:** +```lua +ts.string / tonumber +ts.string / function(v) return v:upper() end +``` + +**rspamd_schema:** +```lua +T.string():transform_with(tonumber) +T.string():transform_with(function(v) return v:upper() end) + +-- Or using T.transform: +T.transform(T.string(), tonumber) +``` + +### Chained Transforms + +**tableshape:** +```lua +(ts.string / tonumber) * ts.number +``` + +**rspamd_schema:** +```lua +T.string():transform_with(tonumber):transform_with(function(v) + return T.number():check(v) and v or error("not a number") +end) + +-- Better: validate after transform +T.transform(T.number(), function(v) + return tonumber(v) or 0 +end) +``` + +## one_of with Multiple Shapes + +**tableshape:** +```lua +ts.one_of { + ts.shape({ type = ts.literal("file"), path = ts.string }), + ts.shape({ type = ts.literal("redis"), host = ts.string }), +} +``` + +**rspamd_schema:** +```lua +T.one_of({ + { + name = "file_config", + schema = T.table({ + type = T.literal("file"), + path = T.string() + }) + }, + { + name = "redis_config", + schema = T.table({ + type = T.literal("redis"), + host = T.string() + }) + } +}) +``` + +## Documentation + +**tableshape:** +```lua +ts.string:describe("User name") +``` + +**rspamd_schema:** +```lua +T.string():doc({ + summary = "User name", + description = "Full description here", + examples = {"alice", "bob"} +}) +``` + +## Complex Example: Redis Options + +**tableshape:** +```lua +local ts = require("tableshape").types + +local db_schema = (ts.number / tostring + ts.string):is_optional() + +local common_schema = { + timeout = (ts.number + ts.string / parse_time):is_optional(), + db = db_schema, + password = ts.string:is_optional(), +} + +local servers_schema = table_merge({ + servers = ts.string + ts.array_of(ts.string), +}, common_schema) + +local redis_schema = ts.one_of { + ts.shape(common_schema), + ts.shape(servers_schema), +} +``` + +**lua_shape:** +```lua +local T = require "lua_shape.core" + +-- Accept string or number for db +local db_schema = T.one_of({ + T.number(), + T.string() +}):optional():doc({ summary = "Database number" }) + +-- Accept number or time string for timeout +local timeout_schema = T.transform(T.number({ min = 0 }), function(val) + if type(val) == "number" then return val end + if type(val) == "string" then return parse_time(val) end + error("Expected number or time string") +end):optional():doc({ summary = "Connection timeout" }) + +-- Common fields +local common_fields = { + timeout = timeout_schema, + db = db_schema, + password = T.string():optional() +} + +-- Servers field accepts string or array +local servers_field = T.one_of({ + T.string(), + T.array(T.string()) +}) + +-- Define variants +local redis_schema = T.one_of({ + { + name = "no_servers", + schema = T.table(common_fields) + }, + { + name = "with_servers", + schema = T.table(table_merge({ + servers = servers_field + }, common_fields)) + } +}) +``` + +Key improvements: +- Better error messages with intersection ("all alternatives require: db, timeout") +- Named variants for clarity +- Transform semantics explicit +- Documentation embedded in schema + +## Validation + +### Check Mode + +**tableshape:** +```lua +local ok, err = schema:transform(config) +if not ok then + logger.errx("Invalid config: %s", err) +end +``` + +**lua_shape:** +```lua +local T = require "lua_shape.core" + +local ok, val_or_err = schema:check(config) +if not ok then + logger.errx("Invalid config:\n%s", T.format_error(val_or_err)) +end +``` + +### Transform Mode + +**tableshape:** +```lua +local ok, result = schema:transform(config) +``` + +**rspamd_schema:** +```lua +local ok, result = schema:transform(config) +-- result will have defaults applied and transforms executed +``` + +## Key Differences + +1. **Explicit vs Operator-based:** + - tableshape uses operators (`+`, `/`, `*`) for composition + - rspamd_schema uses explicit methods and constructors + +2. **Error Reporting:** + - tableshape returns string errors + - rspamd_schema returns structured error trees with better messages + +3. **one_of Intersection:** + - rspamd_schema computes intersection of table variants for better error messages + - Shows "all alternatives require field X" instead of listing every variant error + +4. **Mixins:** + - rspamd_schema has first-class mixin support with origin tracking + - Can show "field from mixin redis" in docs and errors + +5. **Export:** + - rspamd_schema can export to JSON Schema + - Can generate documentation IR from schemas + +## Migration Strategy + +1. Start with standalone schemas (not referenced by other code yet) +2. Test validation and error messages +3. Gradually replace tableshape imports +4. Update schema definitions +5. Update validation call sites +6. Remove tableshape dependency when complete + +## Helper Patterns + +### Common Transform Pattern + +For fields that accept multiple types with normalization: + +**tableshape:** +```lua +(ts.string + ts.number) / normalize_fn +``` + +**rspamd_schema:** +```lua +T.one_of({ + T.string():transform_with(normalize_fn), + T.number():transform_with(normalize_fn) +}) + +-- Or apply transform after one_of: +T.one_of({ T.string(), T.number() }):transform_with(normalize_fn) +``` + +### Optional with Transform + +**tableshape:** +```lua +(ts.string / tonumber):is_optional() +``` + +**rspamd_schema:** +```lua +T.string():transform_with(tonumber):optional() + +-- Or with default: +T.string():transform_with(tonumber):with_default(0) +``` diff --git a/lualib/lua_shape/README.md b/lualib/lua_shape/README.md new file mode 100644 index 0000000000..a5944e8fd2 --- /dev/null +++ b/lualib/lua_shape/README.md @@ -0,0 +1,401 @@ +# lua_shape + +A comprehensive schema validation and transformation library for Rspamd, designed to replace tableshape with improved error reporting, documentation generation, and export capabilities. + +## Features + +1. **Better Error Reporting**: Structured error trees with intersection analysis for `one_of` types +2. **Documentation Generation**: Extract structured documentation from schemas +3. **Type Constraints**: Numeric ranges, string lengths, patterns, and more +4. **First-class Mixins**: Field composition with origin tracking +5. **JSON Schema Export**: Export schemas for UCL validation +6. **Transform Support**: Immutable transformations with validation +7. **Pure Lua**: No dependencies on external modules (except optional lpeg for patterns) + +## Quick Start + +```lua +local T = require "lua_shape.core" + +-- Define a schema +local config_schema = T.table({ + host = T.string({ min_len = 1 }), + port = T.integer({ min = 1, max = 65535 }):with_default(8080), + timeout = T.number({ min = 0 }):optional(), + ssl = T.boolean():with_default(false) +}) + +-- Validate configuration +local ok, result = config_schema:check({ + host = "localhost", + port = 3000 +}) + +if not ok then + print("Validation error:") + print(T.format_error(result)) +end + +-- Transform with defaults applied +local ok, config = config_schema:transform({ + host = "example.com" +}) +-- config.port == 8080 (default applied) +-- config.ssl == false (default applied) +``` + +## Core Types + +### Scalars + +- `T.string(opts)` - String with optional constraints + - `min_len`, `max_len` - Length constraints + - `pattern` - Lua pattern for validation (e.g., `"^%d+$"` for digits only) + - `lpeg` - Optional lpeg pattern for complex parsing +- `T.number(opts)` - Number with optional range constraints (min, max) +- `T.integer(opts)` - Integer (number with integer constraint) +- `T.boolean()` - Boolean value +- `T.enum(values)` - One of a fixed set of values +- `T.literal(value)` - Exact value match + +### Structured Types + +- `T.array(item_schema, opts)` - Array with item validation + - `min_items`, `max_items` - Size constraints +- `T.table(fields, opts)` - Table/object with field schemas + - `open = true` - Allow additional fields not defined in schema + - `open = false` (default) - Reject unknown fields + - `extra = schema` - Schema for validating extra fields + - `mixins` - Array of mixin schemas for composition +- `T.one_of(variants)` - Sum type (match exactly one alternative) + +### Composition + +- `schema:optional()` - Make schema optional +- `schema:with_default(value)` - Add default value +- `schema:doc(doc_table)` - Add documentation +- `schema:transform_with(fn)` - Apply transformation +- `T.transform(schema, fn)` - Transform wrapper +- `T.ref(id)` - Reference to registered schema +- `T.mixin(schema, opts)` - Mixin for table composition + +## Examples + +### Basic Types with Constraints + +```lua +-- String with length constraint +local name_schema = T.string({ min_len = 3, max_len = 50 }) + +-- String with Lua pattern (validates format) +local email_schema = T.string({ pattern = "^[%w%.]+@[%w%.]+$" }) +local ipv4_schema = T.string({ pattern = "^%d+%.%d+%.%d+%.%d+$" }) + +-- Integer with range +local age_schema = T.integer({ min = 0, max = 150 }) + +-- Enum +local level_schema = T.enum({"debug", "info", "warning", "error"}) +``` + +### Arrays and Tables + +```lua +-- Array of strings +local tags_schema = T.array(T.string()) + +-- Table with required and optional fields +local user_schema = T.table({ + name = T.string(), + email = T.string(), + age = T.integer():optional(), + role = T.enum({"admin", "user"}):with_default("user") +}) + +-- Closed table (default): rejects unknown fields +local strict_config = T.table({ + host = T.string(), + port = T.integer() +}, { open = false }) + +-- Open table: allows additional fields not in schema +local flexible_config = T.table({ + host = T.string(), + port = T.integer() +}, { open = true }) +-- Accepts: { host = "localhost", port = 8080, custom_field = "value" } +``` + +### one_of with Intersection + +```lua +-- Multiple config variants +local config_schema = T.one_of({ + { + name = "file_config", + schema = T.table({ + type = T.literal("file"), + path = T.string() + }) + }, + { + name = "redis_config", + schema = T.table({ + type = T.literal("redis"), + host = T.string(), + port = T.integer():with_default(6379) + }) + } +}) + +-- Error messages show intersection: +-- "all alternatives require: type (string)" +``` + +### Transforms + +```lua +-- Parse time interval string to number +local timeout_schema = T.transform(T.number({ min = 0 }), function(val) + if type(val) == "number" then + return val + elseif type(val) == "string" then + return parse_time_interval(val) -- "5s" -> 5.0 + else + error("Expected number or time interval string") + end +end) +``` + +### Schema Registry + +```lua +local Registry = require "lua_shape.registry" +local reg = Registry.global() + +-- Define reusable schemas +local redis_schema = reg:define("redis.options", T.table({ + servers = T.array(T.string()), + db = T.integer({ min = 0, max = 15 }):with_default(0) +})) + +-- Reference in other schemas +local app_schema = T.table({ + cache = T.ref("redis.options") +}) + +-- Resolve references +local resolved = reg:resolve_schema(app_schema) +``` + +### Mixins with Origin Tracking + +```lua +-- Base mixin +local redis_mixin = T.table({ + redis_host = T.string(), + redis_port = T.integer():with_default(6379) +}) + +-- Use mixin in another schema +local plugin_schema = T.table({ + enabled = T.boolean(), + plugin_option = T.string() +}, { + mixins = { + T.mixin(redis_mixin, { as = "redis" }) + } +}) + +-- Documentation will show: +-- Direct fields: enabled, plugin_option +-- Mixin "redis": redis_host, redis_port +``` + +### JSON Schema Export + +```lua +local jsonschema = require "lua_shape.jsonschema" + +-- Export single schema +local json = jsonschema.from_schema(config_schema, { + id = "https://rspamd.com/schema/config", + title = "Application Config" +}) + +-- Export all schemas from registry +local all_schemas = jsonschema.export_registry(Registry.global()) +``` + +### Documentation Generation + +```lua +local docs = require "lua_shape.docs" + +-- Generate documentation IR +local doc_tree = docs.for_schema(config_schema) + +-- Render as markdown +local markdown_lines = docs.render_markdown(doc_tree.schema_doc) +for _, line in ipairs(markdown_lines) do + print(line) +end +``` + +## Error Reporting + +### Structured Errors + +Errors are represented as trees: + +```lua +{ + kind = "table_invalid", + path = "config", + details = { + errors = { + port = { + kind = "constraint_violation", + path = "config.port", + details = { constraint = "max", expected = 65535, got = 99999 } + } + } + } +} +``` + +### Human-Readable Formatting + +```lua +local T = require "lua_shape.core" +print(T.format_error(error_tree)) +``` + +Output: +``` +table validation failed at config: + constraint violation at config.port: max (expected: 65535, got: 99999) +``` + +### one_of Intersection Errors + +When all variants of a one_of fail, the error shows common requirements: + +``` +value does not match any alternative at : + all alternatives require: + - name: string + - type: string + some alternatives also expect: + - path: string (in file_config variant) + - host: string (in redis_config variant) + tried alternatives: + - file_config: ... + - redis_config: ... +``` + +## API Reference + +### Core Module (`rspamd_schema.core`) + +#### Type Constructors + +- `T.string(opts?)` - String type + - opts: `min_len`, `max_len`, `pattern`, `lpeg`, `doc` +- `T.number(opts?)` - Number type + - opts: `min`, `max`, `doc` +- `T.integer(opts?)` - Integer type (number with integer=true) + - opts: `min`, `max`, `doc` +- `T.boolean(opts?)` - Boolean type +- `T.enum(values, opts?)` - Enum type +- `T.literal(value, opts?)` - Literal value type +- `T.array(item_schema, opts?)` - Array type + - opts: `min_items`, `max_items`, `doc` +- `T.table(fields, opts?)` - Table type + - opts: `open`, `extra`, `mixins`, `doc` +- `T.one_of(variants, opts?)` - Sum type +- `T.optional(schema, opts?)` - Optional wrapper +- `T.default(schema, value)` - Default value wrapper +- `T.transform(schema, fn, opts?)` - Transform wrapper +- `T.ref(id, opts?)` - Schema reference +- `T.mixin(schema, opts?)` - Mixin definition + +#### Schema Methods + +- `schema:check(value, ctx?)` - Validate value +- `schema:transform(value, ctx?)` - Transform and validate +- `schema:optional(opts?)` - Make optional +- `schema:with_default(value)` - Add default +- `schema:doc(doc_table)` - Add documentation +- `schema:transform_with(fn, opts?)` - Add transformation + +### Registry Module (`rspamd_schema.registry`) + +- `Registry.global()` - Get/create global registry +- `registry:define(id, schema)` - Register schema with ID +- `registry:get(id)` - Get schema by ID +- `registry:resolve_schema(schema)` - Resolve references and mixins +- `registry:list()` - List all schema IDs +- `registry:export_all()` - Export all schemas + +### Core Utilities + +The core module also includes utility functions: + +- `T.format_error(err)` - Format error tree as human-readable string +- `T.deep_clone(value)` - Deep clone value for immutable transformations + +### JSON Schema Module (`rspamd_schema.jsonschema`) + +- `jsonschema.from_schema(schema, opts?)` - Convert to JSON Schema +- `jsonschema.export_registry(registry, opts?)` - Export all schemas + +### Docs Module (`rspamd_schema.docs`) + +- `docs.for_schema(schema, opts?)` - Generate documentation IR +- `docs.for_registry(registry, opts?)` - Generate docs for all schemas +- `docs.render_markdown(doc_tree, indent?)` - Render as markdown + +## Migration from tableshape + +See [MIGRATION.md](MIGRATION.md) for detailed migration guide. + +Quick reference: + +| tableshape | rspamd_schema | +|------------|---------------| +| `ts.string` | `T.string()` | +| `ts.number` | `T.number()` | +| `ts.array_of(ts.string)` | `T.array(T.string())` | +| `ts.shape({...})` | `T.table({...})` | +| `field:is_optional()` | `field:optional()` or `{ schema = ..., optional = true }` | +| `ts.string + ts.number` | `T.one_of({ T.string(), T.number() })` | +| `ts.string / fn` | `T.string():transform_with(fn)` or `T.transform(T.number(), fn)` | +| `field:describe("...")` | `field:doc({ summary = "..." })` | + +## Files + +- `core.lua` - Core type system, validation, and utilities +- `registry.lua` - Schema registration and reference resolution +- `jsonschema.lua` - JSON Schema export +- `docs.lua` - Documentation generation +- `MIGRATION.md` - Migration guide from tableshape +- `README.md` - This file + +## Testing + +Test files are in the repository root: +- `test_rspamd_schema.lua` - Basic functionality tests +- `test_one_of_intersection.lua` - Intersection logic tests +- `test_export_and_docs.lua` - Export and documentation tests + +Run tests: +```bash +lua test_rspamd_schema.lua +lua test_one_of_intersection.lua +lua test_export_and_docs.lua +``` + +## License + +Apache License 2.0 - Same as Rspamd diff --git a/lualib/lua_shape/core.lua b/lualib/lua_shape/core.lua new file mode 100644 index 0000000000..c8b353a474 --- /dev/null +++ b/lualib/lua_shape/core.lua @@ -0,0 +1,915 @@ +--[[ +Copyright (c) 2025, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- Lua shape validation library - Core module +-- Provides type constructors and validation logic + +local T = {} + +-- Simple utility functions +local function shallowcopy(t) + local result = {} + for k, v in pairs(t) do + result[k] = v + end + return result +end + +-- Check if table is array-like +local function is_array(t) + if type(t) ~= "table" then + return false + end + local count = 0 + for k, _ in pairs(t) do + count = count + 1 + if type(k) ~= "number" or k < 1 or k ~= math.floor(k) or k > count then + return false + end + end + return count == #t +end + +-- Error tree node constructor +local function make_error(kind, path, details) + return { + kind = kind, + path = table.concat(path or {}, "."), + details = details or {} + } +end + +-- Context for validation +local function make_context(mode, path) + return { + mode = mode or "check", + path = path or {} + } +end + +-- Clone path for nested validation +local function clone_path(path) + local result = {} + for i, v in ipairs(path) do + result[i] = v + end + return result +end + +-- Forward declare schema_mt +local schema_mt + +-- Schema node methods +local schema_methods = { + -- Check if value matches schema + check = function(self, value, ctx) + ctx = ctx or make_context("check") + return self._check(self, value, ctx) + end, + + -- Transform value according to schema + transform = function(self, value, ctx) + ctx = ctx or make_context("transform") + return self._check(self, value, ctx) + end, + + -- Make schema optional + optional = function(self, opts) + opts = opts or {} + return T.optional(self, opts) + end, + + -- Add default value + with_default = function(self, value) + return T.default(self, value) + end, + + -- Add documentation + doc = function(self, doc_table) + local new_opts = shallowcopy(self.opts or {}) + new_opts.doc = doc_table + local result = shallowcopy(self) + result.opts = new_opts + return setmetatable(result, schema_mt) + end, + + -- Transform with function + transform_with = function(self, fn, opts) + return T.transform(self, fn, opts) + end +} + +-- Schema node metatable +schema_mt = { + __index = schema_methods +} + +-- Create a new schema node +local function make_node(tag, data) + local node = shallowcopy(data) + node.tag = tag + node.opts = node.opts or {} + return setmetatable(node, schema_mt) +end + +-- Scalar type validators + +local function check_string(node, value, ctx) + if type(value) ~= "string" then + return false, make_error("type_mismatch", ctx.path, { + expected = "string", + got = type(value) + }) + end + + local opts = node.opts or {} + + -- Length constraints + if opts.min_len and #value < opts.min_len then + return false, make_error("constraint_violation", ctx.path, { + constraint = "min_len", + expected = opts.min_len, + got = #value + }) + end + + if opts.max_len and #value > opts.max_len then + return false, make_error("constraint_violation", ctx.path, { + constraint = "max_len", + expected = opts.max_len, + got = #value + }) + end + + -- Pattern matching + if opts.pattern then + if not string.match(value, opts.pattern) then + return false, make_error("constraint_violation", ctx.path, { + constraint = "pattern", + pattern = opts.pattern + }) + end + end + + -- lpeg pattern (optional) + if opts.lpeg then + local lpeg = require "lpeg" + if not lpeg.match(opts.lpeg, value) then + return false, make_error("constraint_violation", ctx.path, { + constraint = "lpeg_pattern" + }) + end + end + + return true, value +end + +local function check_number(node, value, ctx) + local num = tonumber(value) + if not num then + return false, make_error("type_mismatch", ctx.path, { + expected = "number", + got = type(value) + }) + end + + local opts = node.opts or {} + + -- Integer constraint + if opts.integer and num ~= math.floor(num) then + return false, make_error("constraint_violation", ctx.path, { + constraint = "integer", + got = num + }) + end + + -- Range constraints + if opts.min and num < opts.min then + return false, make_error("constraint_violation", ctx.path, { + constraint = "min", + expected = opts.min, + got = num + }) + end + + if opts.max and num > opts.max then + return false, make_error("constraint_violation", ctx.path, { + constraint = "max", + expected = opts.max, + got = num + }) + end + + return true, num +end + +local function check_boolean(node, value, ctx) + if type(value) ~= "boolean" then + return false, make_error("type_mismatch", ctx.path, { + expected = "boolean", + got = type(value) + }) + end + + return true, value +end + +local function check_enum(node, value, ctx) + local opts = node.opts or {} + local values = opts.enum or {} + + for _, v in ipairs(values) do + if v == value then + return true, value + end + end + + return false, make_error("enum_mismatch", ctx.path, { + expected = values, + got = value + }) +end + +local function check_literal(node, value, ctx) + local opts = node.opts or {} + local expected = opts.literal + + if value ~= expected then + return false, make_error("literal_mismatch", ctx.path, { + expected = expected, + got = value + }) + end + + return true, value +end + +-- Scalar type constructors + +function T.string(opts) + return make_node("scalar", { + kind = "string", + opts = opts or {}, + _check = check_string + }) +end + +function T.number(opts) + return make_node("scalar", { + kind = "number", + opts = opts or {}, + _check = check_number + }) +end + +function T.integer(opts) + opts = opts or {} + opts.integer = true + return make_node("scalar", { + kind = "integer", + opts = opts, + _check = check_number + }) +end + +function T.boolean(opts) + return make_node("scalar", { + kind = "boolean", + opts = opts or {}, + _check = check_boolean + }) +end + +function T.enum(values, opts) + opts = opts or {} + opts.enum = values + return make_node("scalar", { + kind = "enum", + opts = opts, + _check = check_enum + }) +end + +function T.literal(value, opts) + opts = opts or {} + opts.literal = value + return make_node("scalar", { + kind = "literal", + opts = opts, + _check = check_literal + }) +end + +-- Array type + +local function check_array(node, value, ctx) + if type(value) ~= "table" then + return false, make_error("type_mismatch", ctx.path, { + expected = "array", + got = type(value) + }) + end + + -- Check if it's array-like (no string keys, sequential numeric keys) + if not is_array(value) then + return false, make_error("type_mismatch", ctx.path, { + expected = "array", + got = "table with non-array keys" + }) + end + + local opts = node.opts or {} + local item_schema = node.item_schema + local len = #value + + -- Length constraints + if opts.min_items and len < opts.min_items then + return false, make_error("constraint_violation", ctx.path, { + constraint = "min_items", + expected = opts.min_items, + got = len + }) + end + + if opts.max_items and len > opts.max_items then + return false, make_error("constraint_violation", ctx.path, { + constraint = "max_items", + expected = opts.max_items, + got = len + }) + end + + -- Check each item + local result = {} + local errors = {} + local has_errors = false + + for i, item in ipairs(value) do + local item_ctx = make_context(ctx.mode, clone_path(ctx.path)) + table.insert(item_ctx.path, "[" .. i .. "]") + + local ok, val_or_err = item_schema:_check(item, item_ctx) + if ok then + result[i] = val_or_err + else + has_errors = true + errors[i] = val_or_err + end + end + + if has_errors then + return false, make_error("array_items_invalid", ctx.path, { + errors = errors + }) + end + + return true, result +end + +function T.array(item_schema, opts) + return make_node("array", { + item_schema = item_schema, + opts = opts or {}, + _check = check_array + }) +end + +-- Table type + +local function check_table(node, value, ctx) + if type(value) ~= "table" then + return false, make_error("type_mismatch", ctx.path, { + expected = "table", + got = type(value) + }) + end + + local opts = node.opts or {} + local fields = node.fields or {} + local result = {} + local errors = {} + local has_errors = false + + -- Check declared fields + for field_name, field_spec in pairs(fields) do + local field_value = value[field_name] + local field_ctx = make_context(ctx.mode, clone_path(ctx.path)) + table.insert(field_ctx.path, field_name) + + if field_value == nil then + -- Missing field + if field_spec.optional then + -- Apply default in transform mode + if ctx.mode == "transform" and field_spec.default ~= nil then + result[field_name] = field_spec.default + end + else + has_errors = true + errors[field_name] = make_error("required_field_missing", field_ctx.path, { + field = field_name + }) + end + else + -- Validate field + local ok, val_or_err = field_spec.schema:_check(field_value, field_ctx) + if ok then + result[field_name] = val_or_err + else + has_errors = true + errors[field_name] = val_or_err + end + end + end + + -- Check for unknown fields + if not opts.open then + for key, val in pairs(value) do + if not fields[key] then + if opts.extra then + -- Validate extra field + local extra_ctx = make_context(ctx.mode, clone_path(ctx.path)) + table.insert(extra_ctx.path, key) + local ok, val_or_err = opts.extra:_check(val, extra_ctx) + if ok then + result[key] = val_or_err + else + has_errors = true + errors[key] = val_or_err + end + else + has_errors = true + local extra_ctx = make_context(ctx.mode, clone_path(ctx.path)) + table.insert(extra_ctx.path, key) + errors[key] = make_error("unknown_field", extra_ctx.path, { + field = key + }) + end + end + end + else + -- Open table: copy unknown fields as-is + for key, val in pairs(value) do + if not fields[key] then + result[key] = val + end + end + end + + if has_errors then + return false, make_error("table_invalid", ctx.path, { + errors = errors + }) + end + + return true, result +end + +function T.table(fields, opts) + opts = opts or {} + + -- Normalize fields: convert {key = schema} to {key = {schema = schema}} + local normalized_fields = {} + for key, val in pairs(fields) do + if val.schema then + -- Already normalized + normalized_fields[key] = val + else + -- Assume val is a schema + -- Check if schema is an optional wrapper + local is_optional = val.tag == "optional" + local inner_schema = is_optional and val.inner or val + local default_value = is_optional and val.default or nil + + normalized_fields[key] = { + schema = inner_schema, + optional = is_optional, + default = default_value + } + end + end + + return make_node("table", { + fields = normalized_fields, + opts = opts, + _check = check_table + }) +end + +-- Optional wrapper + +local function check_optional(node, value, ctx) + if value == nil then + if ctx.mode == "transform" and node.default ~= nil then + return true, node.default + end + return true, nil + end + + return node.inner:_check(value, ctx) +end + +function T.optional(schema, opts) + opts = opts or {} + return make_node("optional", { + inner = schema, + default = opts.default, + opts = opts, + _check = check_optional + }) +end + +function T.default(schema, value) + return T.optional(schema, { default = value }) +end + +-- Transform wrapper + +local function check_transform(node, value, ctx) + if ctx.mode == "transform" then + -- Apply transformation + local new_value = node.fn(value, ctx) + -- Validate transformed value + return node.inner:_check(new_value, ctx) + else + -- In check mode, just validate original value + return node.inner:_check(value, ctx) + end +end + +function T.transform(schema, fn, opts) + return make_node("transform", { + inner = schema, + fn = fn, + opts = opts or {}, + _check = check_transform + }) +end + +-- one_of type with intersection logic + +-- Extract constraints from a schema for intersection computation +local function extract_constraints(schema) + if not schema or not schema.tag then + return nil + end + + local tag = schema.tag + + if tag == "scalar" then + return { + type_name = schema.kind, + constraints = schema.opts + } + elseif tag == "table" then + local fields = {} + for field_name, field_spec in pairs(schema.fields or {}) do + fields[field_name] = { + required = not field_spec.optional, + type_name = field_spec.schema.tag == "scalar" and field_spec.schema.kind or field_spec.schema.tag, + constraints = field_spec.schema.opts + } + end + return { + type_name = "table", + fields = fields + } + elseif tag == "array" then + return { + type_name = "array", + item_constraints = extract_constraints(schema.item_schema) + } + end + + return { type_name = tag } +end + +-- Compute intersection of table-like variants +local function compute_intersection(variants) + if not variants or #variants == 0 then + return nil + end + + -- Check if all variants are table-like + local all_tables = true + local constraints_list = {} + + for _, variant in ipairs(variants) do + local constraints = extract_constraints(variant.schema) + if not constraints or constraints.type_name ~= "table" then + all_tables = false + break + end + table.insert(constraints_list, constraints) + end + + if not all_tables or #constraints_list == 0 then + return nil + end + + -- Find common required fields + local result = { + required_fields = {}, + optional_fields = {}, + conflicting_fields = {} + } + + -- Collect all field names + local all_fields = {} + for _, c in ipairs(constraints_list) do + for field_name, _ in pairs(c.fields or {}) do + all_fields[field_name] = (all_fields[field_name] or 0) + 1 + end + end + + -- Analyze each field + for field_name, count in pairs(all_fields) do + if count == #constraints_list then + -- Field present in all variants + local field_types = {} + local all_required = true + + for _, c in ipairs(constraints_list) do + local field = c.fields[field_name] + if field then + table.insert(field_types, field.type_name) + if not field.required then + all_required = false + end + end + end + + -- Check if types are compatible + local first_type = field_types[1] + local compatible = true + for _, ftype in ipairs(field_types) do + if ftype ~= first_type then + compatible = false + break + end + end + + if compatible and all_required then + result.required_fields[field_name] = first_type + elseif compatible then + result.optional_fields[field_name] = first_type + else + result.conflicting_fields[field_name] = field_types + end + end + end + + return result +end + +local function check_one_of(node, value, ctx) + local variants = node.variants or {} + local matching = {} + local errors = {} + + for i, variant in ipairs(variants) do + local variant_ctx = make_context(ctx.mode, clone_path(ctx.path)) + local ok, val_or_err = variant.schema:_check(value, variant_ctx) + + if ok then + table.insert(matching, { + index = i, + name = variant.name or ("variant_" .. i), + value = val_or_err + }) + else + errors[i] = { + name = variant.name or ("variant_" .. i), + error = val_or_err + } + end + end + + if #matching == 0 then + -- No variant matched - compute intersection for better error + local intersection = compute_intersection(variants) + return false, make_error("one_of_mismatch", ctx.path, { + variants = errors, + intersection = intersection + }) + elseif #matching == 1 then + -- Exactly one match - success + return true, matching[1].value + else + -- Multiple matches - take first by default + -- Could make this configurable (first wins vs ambiguity error) + return true, matching[1].value + end +end + +function T.one_of(variants, opts) + opts = opts or {} + + -- Normalize variants: allow bare schemas or {name=..., schema=...} + local normalized_variants = {} + for i, variant in ipairs(variants) do + if variant.schema then + normalized_variants[i] = variant + else + normalized_variants[i] = { + name = opts.names and opts.names[i] or ("variant_" .. i), + schema = variant + } + end + end + + return make_node("one_of", { + variants = normalized_variants, + opts = opts, + _check = check_one_of + }) +end + +-- Reference placeholder (will be resolved by registry) + +function T.ref(id, opts) + return make_node("ref", { + ref_id = id, + opts = opts or {}, + _check = function(node, value, ctx) + error("Unresolved reference: " .. id .. ". Use registry to resolve references.") + end + }) +end + +-- Mixin constructor + +function T.mixin(schema, opts) + opts = opts or {} + return { + _is_mixin = true, + schema = schema, + as = opts.as, + doc = opts.doc + } +end + +-- Utility functions + +-- Format error tree for human-readable output +local function format_error_impl(err, indent, lines) + indent = indent or 0 + lines = lines or {} + + local prefix = string.rep(" ", indent) + + if err.kind == "type_mismatch" then + local msg = string.format("%stype mismatch at %s: expected %s, got %s", + prefix, err.path or "(root)", + err.details.expected or "?", + err.details.got or "?") + table.insert(lines, msg) + + elseif err.kind == "constraint_violation" then + local constraint = err.details.constraint or "?" + local msg = string.format("%sconstraint violation at %s: %s", + prefix, err.path or "(root)", constraint) + if err.details.expected then + msg = msg .. string.format(" (expected: %s, got: %s)", + tostring(err.details.expected), + tostring(err.details.got)) + end + table.insert(lines, msg) + + elseif err.kind == "required_field_missing" then + local msg = string.format("%srequired field missing: %s", + prefix, err.path or err.details.field or "?") + table.insert(lines, msg) + + elseif err.kind == "unknown_field" then + local msg = string.format("%sunknown field: %s", + prefix, err.path or err.details.field or "?") + table.insert(lines, msg) + + elseif err.kind == "enum_mismatch" then + local expected_str = table.concat(err.details.expected or {}, ", ") + local msg = string.format("%senum mismatch at %s: expected one of [%s], got %s", + prefix, err.path or "(root)", + expected_str, tostring(err.details.got)) + table.insert(lines, msg) + + elseif err.kind == "literal_mismatch" then + local msg = string.format("%sliteral mismatch at %s: expected %s, got %s", + prefix, err.path or "(root)", + tostring(err.details.expected), + tostring(err.details.got)) + table.insert(lines, msg) + + elseif err.kind == "array_items_invalid" then + local msg = string.format("%sarray items invalid at %s:", prefix, err.path or "(root)") + table.insert(lines, msg) + for _, item_err in pairs(err.details.errors or {}) do + format_error_impl(item_err, indent + 1, lines) + end + + elseif err.kind == "table_invalid" then + local msg = string.format("%stable validation failed at %s:", prefix, err.path or "(root)") + table.insert(lines, msg) + for _, field_err in pairs(err.details.errors or {}) do + format_error_impl(field_err, indent + 1, lines) + end + + elseif err.kind == "one_of_mismatch" then + local msg = string.format("%svalue does not match any alternative at %s:", + prefix, err.path or "(root)") + table.insert(lines, msg) + + -- Add intersection summary + if err.details.intersection then + local inter = err.details.intersection + + -- Show common required fields + local req_fields = {} + for field_name, field_type in pairs(inter.required_fields or {}) do + table.insert(req_fields, string.format("%s: %s", field_name, field_type)) + end + if #req_fields > 0 then + table.insert(lines, prefix .. " all alternatives require:") + for _, field_desc in ipairs(req_fields) do + table.insert(lines, prefix .. " - " .. field_desc) + end + end + + -- Show optional common fields + local opt_fields = {} + for field_name, field_type in pairs(inter.optional_fields or {}) do + table.insert(opt_fields, string.format("%s: %s", field_name, field_type)) + end + if #opt_fields > 0 then + table.insert(lines, prefix .. " some alternatives also expect:") + for _, field_desc in ipairs(opt_fields) do + table.insert(lines, prefix .. " - " .. field_desc) + end + end + + -- Show conflicting fields + local conflicts = {} + for field_name, field_types in pairs(inter.conflicting_fields or {}) do + table.insert(conflicts, string.format("%s (conflicting types: %s)", + field_name, table.concat(field_types, ", "))) + end + if #conflicts > 0 then + table.insert(lines, prefix .. " conflicting field requirements:") + for _, conflict_desc in ipairs(conflicts) do + table.insert(lines, prefix .. " - " .. conflict_desc) + end + end + end + + table.insert(lines, prefix .. " tried alternatives:") + for idx, variant_err in ipairs(err.details.variants or {}) do + local variant_name = variant_err.name or ("variant " .. idx) + table.insert(lines, string.format("%s - %s:", prefix, variant_name)) + format_error_impl(variant_err.error, indent + 3, lines) + end + + else + -- Unknown error kind + local msg = string.format("%sunknown error at %s: %s", + prefix, err.path or "(root)", err.kind or "?") + table.insert(lines, msg) + end + + return lines +end + +function T.format_error(err) + if not err then + return "no error" + end + + local lines = format_error_impl(err, 0, {}) + return table.concat(lines, "\n") +end + +-- Deep clone a value (for immutable transformations) +function T.deep_clone(value) + if type(value) ~= "table" then + return value + end + + local result = {} + for k, v in pairs(value) do + result[k] = T.deep_clone(v) + end + return result +end + +return T diff --git a/lualib/lua_shape/docs.lua b/lualib/lua_shape/docs.lua new file mode 100644 index 0000000000..04071cc3a7 --- /dev/null +++ b/lualib/lua_shape/docs.lua @@ -0,0 +1,282 @@ +--[[ +Copyright (c) 2025, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- Lua shape validation library - Documentation IR generator +-- Generates structured documentation from schemas + +local exports = {} + +-- Extract documentation from opts +local function get_doc(opts) + if not opts or not opts.doc then + return {} + end + return opts.doc +end + +-- Generate doc IR for a schema node +local function generate_doc_impl(schema, path) + path = path or "(root)" + + if not schema or not schema.tag then + return { + type = "unknown", + path = path + } + end + + local doc = get_doc(schema.opts) + local result = { + type = schema.tag, + path = path, + summary = doc.summary, + description = doc.description, + examples = doc.examples + } + + local tag = schema.tag + + -- Scalar types + if tag == "scalar" then + result.kind = schema.kind + result.constraints = {} + + local opts = schema.opts or {} + + if schema.kind == "string" then + if opts.min_len then result.constraints.min_length = opts.min_len end + if opts.max_len then result.constraints.max_length = opts.max_len end + if opts.pattern then result.constraints.pattern = opts.pattern end + + elseif schema.kind == "number" or schema.kind == "integer" then + if opts.min then result.constraints.minimum = opts.min end + if opts.max then result.constraints.maximum = opts.max end + if opts.integer then result.constraints.integer = true end + + elseif schema.kind == "enum" then + if opts.enum then result.constraints.values = opts.enum end + + elseif schema.kind == "literal" then + result.constraints.value = opts.literal + end + + -- Array type + elseif tag == "array" then + result.item_schema = generate_doc_impl(schema.item_schema, path .. "[]") + + local opts = schema.opts or {} + result.constraints = {} + if opts.min_items then result.constraints.min_items = opts.min_items end + if opts.max_items then result.constraints.max_items = opts.max_items end + + -- Table type + elseif tag == "table" then + result.fields = {} + result.mixin_groups = {} + + local opts = schema.opts or {} + result.open = opts.open ~= false + result.extra_schema = opts.extra and generate_doc_impl(opts.extra, path .. ".*") or nil + + -- Group fields by origin (mixins) + local origin_groups = {} + local no_origin_fields = {} + + for field_name, field_spec in pairs(schema.fields or {}) do + local field_doc = { + name = field_name, + optional = field_spec.optional or false, + default = field_spec.default, + schema = generate_doc_impl(field_spec.schema, path .. "." .. field_name) + } + + if field_spec.origin then + local origin_key = field_spec.origin.mixin_name or "unknown" + if not origin_groups[origin_key] then + origin_groups[origin_key] = { + mixin_name = field_spec.origin.mixin_name, + schema_id = field_spec.origin.schema_id, + fields = {} + } + end + table.insert(origin_groups[origin_key].fields, field_doc) + else + table.insert(no_origin_fields, field_doc) + end + end + + -- Add direct fields first + result.fields = no_origin_fields + + -- Add mixin groups + for _, group in pairs(origin_groups) do + table.insert(result.mixin_groups, group) + end + + -- one_of type + elseif tag == "one_of" then + result.variants = {} + + for i, variant in ipairs(schema.variants or {}) do + local variant_doc = generate_doc_impl(variant.schema, path .. "::variant" .. i) + variant_doc.name = variant.name or ("variant_" .. i) + table.insert(result.variants, variant_doc) + end + + -- Optional wrapper + elseif tag == "optional" then + result = generate_doc_impl(schema.inner, path) + result.optional = true + if schema.default ~= nil then + result.default = schema.default + end + + -- Transform wrapper + elseif tag == "transform" then + result = generate_doc_impl(schema.inner, path) + result.has_transform = true + + -- Reference + elseif tag == "ref" then + result.ref_id = schema.ref_id + end + + return result +end + +-- Generate documentation IR for a schema +function exports.for_schema(schema, opts) + opts = opts or {} + + local doc_tree = generate_doc_impl(schema, opts.root_path or "(root)") + + return { + schema_doc = doc_tree, + metadata = { + generated_at = os.date("%Y-%m-%d %H:%M:%S"), + generator = "rspamd_schema v1.0" + } + } +end + +-- Generate documentation for all schemas in a registry +function exports.for_registry(registry, opts) + opts = opts or {} + + local schemas = registry:export_all() + local result = { + schemas = {}, + metadata = { + generated_at = os.date("%Y-%m-%d %H:%M:%S"), + generator = "rspamd_schema v1.0" + } + } + + for id, schema in pairs(schemas) do + result.schemas[id] = generate_doc_impl(schema, id) + end + + return result +end + +-- Simple markdown renderer (optional helper) +function exports.render_markdown(doc_tree, indent) + indent = indent or 0 + local lines = {} + local prefix = string.rep(" ", indent) + + if doc_tree.summary then + table.insert(lines, prefix .. "**" .. doc_tree.summary .. "**") + end + + if doc_tree.description then + table.insert(lines, prefix .. doc_tree.description) + end + + if doc_tree.type == "scalar" then + local type_str = doc_tree.kind or "unknown" + local constraint_strs = {} + + for k, v in pairs(doc_tree.constraints or {}) do + table.insert(constraint_strs, k .. "=" .. tostring(v)) + end + + if #constraint_strs > 0 then + type_str = type_str .. " (" .. table.concat(constraint_strs, ", ") .. ")" + end + + table.insert(lines, prefix .. "Type: `" .. type_str .. "`") + + elseif doc_tree.type == "array" then + table.insert(lines, prefix .. "Type: `array`") + table.insert(lines, prefix .. "Items:") + local item_lines = exports.render_markdown(doc_tree.item_schema, indent + 1) + for _, line in ipairs(item_lines) do + table.insert(lines, line) + end + + elseif doc_tree.type == "table" then + table.insert(lines, prefix .. "Type: `table`") + + if #doc_tree.fields > 0 then + table.insert(lines, prefix .. "Fields:") + for _, field in ipairs(doc_tree.fields) do + local opt_str = field.optional and " (optional)" or " (required)" + if field.default ~= nil then + opt_str = opt_str .. ", default: " .. tostring(field.default) + end + table.insert(lines, prefix .. " - **" .. field.name .. "**" .. opt_str) + local field_lines = exports.render_markdown(field.schema, indent + 2) + for _, line in ipairs(field_lines) do + table.insert(lines, line) + end + end + end + + if #doc_tree.mixin_groups > 0 then + table.insert(lines, prefix .. "Mixins:") + for _, group in ipairs(doc_tree.mixin_groups) do + table.insert(lines, prefix .. " - **" .. (group.mixin_name or "unknown") .. "**") + for _, field in ipairs(group.fields) do + local opt_str = field.optional and " (optional)" or " (required)" + table.insert(lines, prefix .. " - **" .. field.name .. "**" .. opt_str) + end + end + end + + elseif doc_tree.type == "one_of" then + table.insert(lines, prefix .. "Type: `one_of` (must match exactly one alternative)") + table.insert(lines, prefix .. "Alternatives:") + for _, variant in ipairs(doc_tree.variants or {}) do + table.insert(lines, prefix .. " - **" .. variant.name .. "**") + local variant_lines = exports.render_markdown(variant, indent + 2) + for _, line in ipairs(variant_lines) do + table.insert(lines, line) + end + end + end + + if doc_tree.examples then + table.insert(lines, prefix .. "Examples:") + for _, example in ipairs(doc_tree.examples) do + table.insert(lines, prefix .. " - `" .. tostring(example) .. "`") + end + end + + return lines +end + +return exports diff --git a/lualib/lua_shape/jsonschema.lua b/lualib/lua_shape/jsonschema.lua new file mode 100644 index 0000000000..f7fa2f9798 --- /dev/null +++ b/lualib/lua_shape/jsonschema.lua @@ -0,0 +1,230 @@ +--[[ +Copyright (c) 2025, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- Lua shape validation library - JSON Schema exporter +-- Converts lua_shape schemas to JSON Schema format + +local exports = {} + +-- Convert a schema node to JSON Schema +local function to_jsonschema_impl(schema, opts) + opts = opts or {} + + if not schema or not schema.tag then + return {} + end + + local result = {} + local schema_opts = schema.opts or {} + + -- Add description from doc + if schema_opts.doc and schema_opts.doc.summary then + result.description = schema_opts.doc.summary + end + + local tag = schema.tag + + -- Scalar types + if tag == "scalar" then + local kind = schema.kind + + if kind == "string" then + result.type = "string" + + if schema_opts.min_len then + result.minLength = schema_opts.min_len + end + if schema_opts.max_len then + result.maxLength = schema_opts.max_len + end + if schema_opts.pattern then + result.pattern = schema_opts.pattern + end + + elseif kind == "number" or kind == "integer" then + result.type = kind == "integer" and "integer" or "number" + + if schema_opts.min then + result.minimum = schema_opts.min + end + if schema_opts.max then + result.maximum = schema_opts.max + end + + elseif kind == "boolean" then + result.type = "boolean" + + elseif kind == "enum" then + if schema_opts.enum then + result.enum = schema_opts.enum + end + + elseif kind == "literal" then + result.const = schema_opts.literal + end + + -- Array type + elseif tag == "array" then + result.type = "array" + + if schema.item_schema then + result.items = to_jsonschema_impl(schema.item_schema, opts) + end + + if schema_opts.min_items then + result.minItems = schema_opts.min_items + end + if schema_opts.max_items then + result.maxItems = schema_opts.max_items + end + + -- Table type + elseif tag == "table" then + result.type = "object" + result.properties = {} + result.required = {} + + -- Process fields + for field_name, field_spec in pairs(schema.fields or {}) do + result.properties[field_name] = to_jsonschema_impl(field_spec.schema, opts) + + -- Add to required if not optional + if not field_spec.optional then + table.insert(result.required, field_name) + end + + -- Add default if present + if field_spec.default ~= nil then + result.properties[field_name].default = field_spec.default + end + + -- Add origin metadata if present (for mixin tracking) + if field_spec.origin and opts.include_origin then + result.properties[field_name]["x-rspamd-origin"] = field_spec.origin + end + end + + -- Handle open/closed table + if schema_opts.open == false then + if schema_opts.extra then + -- Allow additional properties matching extra schema + result.additionalProperties = to_jsonschema_impl(schema_opts.extra, opts) + else + result.additionalProperties = false + end + else + result.additionalProperties = true + end + + -- Remove empty required array + if #result.required == 0 then + result.required = nil + end + + -- one_of type + elseif tag == "one_of" then + result.oneOf = {} + + for _, variant in ipairs(schema.variants or {}) do + local variant_schema = to_jsonschema_impl(variant.schema, opts) + + -- Add title if variant has a name + if variant.name and opts.include_variant_names then + variant_schema.title = variant.name + end + + table.insert(result.oneOf, variant_schema) + end + + -- Optional wrapper + elseif tag == "optional" then + result = to_jsonschema_impl(schema.inner, opts) + + -- Add null as allowed type + if result.type then + if type(result.type) == "string" then + result.type = { result.type, "null" } + else + table.insert(result.type, "null") + end + end + + if schema.default ~= nil then + result.default = schema.default + end + + -- Transform wrapper + elseif tag == "transform" then + -- For JSON Schema, just export the inner schema + -- Transform semantics don't apply to JSON Schema validation + result = to_jsonschema_impl(schema.inner, opts) + + -- Reference + elseif tag == "ref" then + local ref_id = schema.ref_id + result["$ref"] = "#/definitions/" .. ref_id + end + + return result +end + +-- Convert a schema to JSON Schema +function exports.from_schema(schema, opts) + opts = opts or {} + + local result = { + ["$schema"] = "http://json-schema.org/draft-07/schema#" + } + + -- Add schema ID if provided + if opts.id then + result["$id"] = opts.id + end + + -- Add title if provided + if opts.title then + result.title = opts.title + end + + -- Convert schema + local schema_json = to_jsonschema_impl(schema, opts) + for k, v in pairs(schema_json) do + result[k] = v + end + + return result +end + +-- Export all schemas from a registry +function exports.export_registry(registry, opts) + opts = opts or {} + -- local base_id = opts.base_id or "https://rspamd.com/schema/" + + local result = { + ["$schema"] = "http://json-schema.org/draft-07/schema#", + definitions = {} + } + + local schemas = registry:export_all() + + for id, schema in pairs(schemas) do + result.definitions[id] = to_jsonschema_impl(schema, opts) + end + + return result +end + +return exports diff --git a/lualib/lua_shape/registry.lua b/lualib/lua_shape/registry.lua new file mode 100644 index 0000000000..88eb57c82b --- /dev/null +++ b/lualib/lua_shape/registry.lua @@ -0,0 +1,224 @@ +--[[ +Copyright (c) 2025, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- Lua shape validation library - Registry module +-- Provides schema registration and reference resolution + +local Registry = {} +Registry.__index = Registry + +-- Simple utility functions +local function shallowcopy(t) + local result = {} + for k, v in pairs(t) do + result[k] = v + end + return result +end + +-- Global registry instance +local global_registry = nil + +-- Create a new registry +local function new() + return setmetatable({ + schemas = {}, + resolved_cache = {} + }, Registry) +end + +-- Get or create global registry +function Registry.global() + if not global_registry then + global_registry = new() + end + return global_registry +end + +-- Define a schema with an ID +function Registry:define(id, schema) + if self.schemas[id] then + error("Schema already defined: " .. id) + end + + -- Resolve mixins if this is a table schema + local resolved = self:resolve_schema(schema) + + self.schemas[id] = { + id = id, + original = schema, + resolved = resolved + } + + return resolved +end + +-- Get a schema by ID +function Registry:get(id) + local entry = self.schemas[id] + if not entry then + return nil + end + return entry.resolved +end + +-- Resolve references and mixins in a schema +function Registry:resolve_schema(schema) + if not schema then + return nil + end + + local tag = schema.tag + + -- If already resolved, return from cache + local cache_key = tostring(schema) + if self.resolved_cache[cache_key] then + return self.resolved_cache[cache_key] + end + + -- Handle reference nodes + if tag == "ref" then + local ref_id = schema.ref_id + local target = self.schemas[ref_id] + if not target then + error("Unresolved reference: " .. ref_id) + end + return target.resolved + end + + -- Handle table nodes with mixins + if tag == "table" then + local opts = schema.opts or {} + local mixins = opts.mixins or {} + + if #mixins > 0 then + -- Merge mixin fields into table + local merged_fields = shallowcopy(schema.fields or {}) + + for _, mixin_def in ipairs(mixins) do + if mixin_def._is_mixin then + local mixin_schema = mixin_def.schema + + -- Resolve mixin schema if it's a reference + if mixin_schema.tag == "ref" then + mixin_schema = self:resolve_schema(mixin_schema) + end + + -- Extract fields from mixin + if mixin_schema.tag == "table" then + local mixin_fields = mixin_schema.fields or {} + local mixin_name = mixin_def.as or mixin_schema.opts.doc and + mixin_schema.opts.doc.summary or + "unknown" + + for field_name, field_spec in pairs(mixin_fields) do + if merged_fields[field_name] then + -- Conflict: host field overrides mixin + merged_fields[field_name] = merged_fields[field_name] -- Keep host field + -- TODO: Add warning/logging + else + -- Add field from mixin with origin tracking + local field_copy = shallowcopy(field_spec) + field_copy.origin = { + mixin_name = mixin_name, + schema_id = mixin_schema.opts.schema_id + } + merged_fields[field_name] = field_copy + end + end + end + end + end + + -- Create new table schema with merged fields + local resolved = shallowcopy(schema) + resolved.fields = merged_fields + self.resolved_cache[cache_key] = resolved + return resolved + end + end + + -- Handle array nodes - resolve item schema + if tag == "array" then + local resolved_item = self:resolve_schema(schema.item_schema) + if resolved_item ~= schema.item_schema then + local resolved = shallowcopy(schema) + resolved.item_schema = resolved_item + self.resolved_cache[cache_key] = resolved + return resolved + end + end + + -- Handle one_of nodes - resolve variant schemas + if tag == "one_of" then + local variants = schema.variants or {} + local resolved_variants = {} + local changed = false + + for i, variant in ipairs(variants) do + local resolved_variant_schema = self:resolve_schema(variant.schema) + if resolved_variant_schema ~= variant.schema then + changed = true + end + resolved_variants[i] = { + name = variant.name, + schema = resolved_variant_schema + } + end + + if changed then + local resolved = shallowcopy(schema) + resolved.variants = resolved_variants + self.resolved_cache[cache_key] = resolved + return resolved + end + end + + -- Handle optional/transform wrappers - resolve inner schema + if tag == "optional" or tag == "transform" then + local resolved_inner = self:resolve_schema(schema.inner) + if resolved_inner ~= schema.inner then + local resolved = shallowcopy(schema) + resolved.inner = resolved_inner + self.resolved_cache[cache_key] = resolved + return resolved + end + end + + -- No changes needed + return schema +end + +-- List all registered schema IDs +function Registry:list() + local ids = {} + for id, _ in pairs(self.schemas) do + table.insert(ids, id) + end + table.sort(ids) + return ids +end + +-- Export all schemas (for documentation or JSON Schema generation) +function Registry:export_all() + local result = {} + for id, entry in pairs(self.schemas) do + result[id] = entry.resolved + end + return result +end + +return Registry diff --git a/test/lua/unit/lua_shape.lua b/test/lua/unit/lua_shape.lua new file mode 100644 index 0000000000..6c048f98f5 --- /dev/null +++ b/test/lua/unit/lua_shape.lua @@ -0,0 +1,552 @@ +context("Lua shape validation", function() + local T = require "lua_shape.core" + local Registry = require "lua_shape.registry" + + -- Scalar type tests + context("Scalar types", function() + test("String type - valid", function() + local schema = T.string() + local ok, val = schema:check("hello") + assert_true(ok) + assert_equal(val, "hello") + end) + + test("String type - invalid", function() + local schema = T.string() + local ok, err = schema:check(123) + assert_false(ok) + assert_equal(err.kind, "type_mismatch") + assert_equal(err.details.expected, "string") + assert_equal(err.details.got, "number") + end) + + test("String with length constraints", function() + local schema = T.string({ min_len = 3, max_len = 10 }) + + local ok, val = schema:check("hello") + assert_true(ok) + + ok = schema:check("hi") + assert_false(ok) + + ok = schema:check("this is too long") + assert_false(ok) + end) + + test("String with pattern", function() + local schema = T.string({ pattern = "^%d+$" }) + + local ok = schema:check("123") + assert_true(ok) + + ok = schema:check("abc") + assert_false(ok) + end) + + test("Integer type with range", function() + local schema = T.integer({ min = 0, max = 100 }) + + local ok, val = schema:check(50) + assert_true(ok) + assert_equal(val, 50) + + ok = schema:check(150) + assert_false(ok) + + ok = schema:check(-10) + assert_false(ok) + end) + + test("Integer rejects non-integer", function() + local schema = T.integer() + local ok, err = schema:check(3.14) + assert_false(ok) + assert_equal(err.kind, "constraint_violation") + assert_equal(err.details.constraint, "integer") + end) + + test("Number accepts integer and float", function() + local schema = T.number({ min = 0, max = 10 }) + + local ok = schema:check(5) + assert_true(ok) + + ok = schema:check(5.5) + assert_true(ok) + + ok = schema:check(15) + assert_false(ok) + end) + + test("Boolean type", function() + local schema = T.boolean() + + local ok, val = schema:check(true) + assert_true(ok) + assert_equal(val, true) + + ok, val = schema:check(false) + assert_true(ok) + assert_equal(val, false) + + ok = schema:check("true") + assert_false(ok) + end) + + test("Enum type", function() + local schema = T.enum({"debug", "info", "warning", "error"}) + + local ok = schema:check("info") + assert_true(ok) + + ok, err = schema:check("trace") + assert_false(ok) + assert_equal(err.kind, "enum_mismatch") + end) + + test("Literal type", function() + local schema = T.literal("exact_value") + + local ok = schema:check("exact_value") + assert_true(ok) + + ok = schema:check("other_value") + assert_false(ok) + end) + end) + + -- Array type tests + context("Array type", function() + test("Array of strings - valid", function() + local schema = T.array(T.string()) + local ok, val = schema:check({"foo", "bar", "baz"}) + assert_true(ok) + assert_rspamd_table_eq({expect = {"foo", "bar", "baz"}, actual = val}) + end) + + test("Array of strings - invalid item", function() + local schema = T.array(T.string()) + local ok, err = schema:check({"foo", 123, "baz"}) + assert_false(ok) + assert_equal(err.kind, "array_items_invalid") + end) + + test("Array with size constraints", function() + local schema = T.array(T.string(), { min_items = 2, max_items = 5 }) + + local ok = schema:check({"a", "b", "c"}) + assert_true(ok) + + ok = schema:check({"a"}) + assert_false(ok) + + ok = schema:check({"a", "b", "c", "d", "e", "f"}) + assert_false(ok) + end) + + test("Array rejects table with non-array keys", function() + local schema = T.array(T.string()) + local ok, err = schema:check({foo = "bar"}) + assert_false(ok) + assert_equal(err.kind, "type_mismatch") + end) + end) + + -- Table type tests + context("Table type", function() + test("Simple table - valid", function() + local schema = T.table({ + name = T.string(), + age = T.integer({ min = 0 }) + }) + + local ok, val = schema:check({ name = "Alice", age = 30 }) + assert_true(ok) + assert_equal(val.name, "Alice") + assert_equal(val.age, 30) + end) + + test("Table - missing required field", function() + local schema = T.table({ + name = T.string(), + age = T.integer() + }) + + local ok, err = schema:check({ name = "Bob" }) + assert_false(ok) + assert_equal(err.kind, "table_invalid") + end) + + test("Table - optional field", function() + local schema = T.table({ + name = T.string(), + email = T.string():optional() + }) + + local ok, val = schema:check({ name = "Charlie" }) + assert_true(ok) + assert_equal(val.name, "Charlie") + assert_nil(val.email) + end) + + test("Table - optional field with explicit syntax", function() + local schema = T.table({ + name = T.string(), + email = { schema = T.string(), optional = true } + }) + + local ok = schema:check({ name = "David" }) + assert_true(ok) + end) + + test("Table - default value in transform mode", function() + local schema = T.table({ + name = T.string(), + port = { schema = T.integer(), optional = true, default = 8080 } + }) + + local ok, val = schema:transform({ name = "server" }) + assert_true(ok) + assert_equal(val.port, 8080) + end) + + test("Table - closed table rejects unknown fields", function() + local schema = T.table({ + name = T.string() + }, { open = false }) + + local ok, err = schema:check({ name = "Eve", extra = "field" }) + assert_false(ok) + assert_equal(err.kind, "table_invalid") + end) + + test("Table - open table allows unknown fields", function() + local schema = T.table({ + name = T.string() + }, { open = true }) + + local ok, val = schema:check({ name = "Frank", extra = "field" }) + assert_true(ok) + assert_equal(val.extra, "field") + end) + end) + + -- Optional and default tests + context("Optional and default values", function() + test("Optional wrapper", function() + local schema = T.optional(T.string()) + + local ok, val = schema:check("hello") + assert_true(ok) + assert_equal(val, "hello") + + ok, val = schema:check(nil) + assert_true(ok) + assert_nil(val) + end) + + test("Optional with default in check mode", function() + local schema = T.string():with_default("default") + + local ok, val = schema:check(nil) + assert_true(ok) + assert_nil(val) -- check mode doesn't apply defaults + end) + + test("Optional with default in transform mode", function() + local schema = T.string():with_default("default") + + local ok, val = schema:transform(nil) + assert_true(ok) + assert_equal(val, "default") + end) + end) + + -- Transform tests + context("Transform support", function() + test("Transform string to number", function() + local schema = T.transform(T.number(), function(val) + if type(val) == "string" then + return tonumber(val) + end + return val + end) + + local ok, val = schema:transform("42") + assert_true(ok) + assert_equal(val, 42) + end) + + test("Transform with validation", function() + local schema = T.transform(T.integer({ min = 0 }), function(val) + if type(val) == "string" then + return tonumber(val) + end + return val + end) + + -- Valid transform + local ok, val = schema:transform("10") + assert_true(ok) + assert_equal(val, 10) + + -- Transform result fails validation + ok = schema:transform("-5") + assert_false(ok) + end) + + test("Transform only in transform mode", function() + local schema = T.transform(T.number(), function(val) + return val * 2 + end) + + -- Check mode: no transform + local ok, val = schema:check(5) + assert_true(ok) + assert_equal(val, 5) + + -- Transform mode: applies transform + ok, val = schema:transform(5) + assert_true(ok) + assert_equal(val, 10) + end) + + test("Chained transform using :transform_with", function() + local schema = T.string():transform_with(function(val) + return val:upper() + end) + + local ok, val = schema:transform("hello") + assert_true(ok) + assert_equal(val, "HELLO") + end) + end) + + -- one_of tests + context("one_of type", function() + test("one_of - first variant matches", function() + local schema = T.one_of({ + T.string(), + T.integer() + }) + + local ok, val = schema:check("text") + assert_true(ok) + assert_equal(val, "text") + end) + + test("one_of - second variant matches", function() + local schema = T.one_of({ + T.string(), + T.integer() + }) + + local ok, val = schema:check(42) + assert_true(ok) + assert_equal(val, 42) + end) + + test("one_of - no variant matches", function() + local schema = T.one_of({ + T.string(), + T.integer() + }) + + local ok, err = schema:check(true) + assert_false(ok) + assert_equal(err.kind, "one_of_mismatch") + assert_equal(#err.details.variants, 2) + end) + + test("one_of with named variants", function() + local schema = T.one_of({ + { name = "string_variant", schema = T.string() }, + { name = "number_variant", schema = T.integer() } + }) + + local ok = schema:check("text") + assert_true(ok) + end) + + test("one_of with table variants shows intersection", function() + local schema = T.one_of({ + { + name = "adult", + schema = T.table({ + name = T.string(), + age = T.integer({ min = 18 }) + }) + }, + { + name = "child", + schema = T.table({ + name = T.string(), + age = T.integer({ max = 17 }) + }) + } + }) + + local ok, err = schema:check({ age = 25 }) + assert_false(ok) + assert_equal(err.kind, "one_of_mismatch") + -- Should have intersection showing common fields + assert_not_nil(err.details.intersection) + assert_not_nil(err.details.intersection.required_fields.name) + assert_not_nil(err.details.intersection.required_fields.age) + end) + end) + + -- Registry tests + context("Registry", function() + test("Define and get schema", function() + local reg = Registry.global() + + local schema = reg:define("test.simple", T.string()) + assert_not_nil(schema) + + local retrieved = reg:get("test.simple") + assert_not_nil(retrieved) + end) + + test("Reference resolution", function() + -- Use global registry but with unique schema ID + local reg = Registry.global() + local unique_id = "test.ref_user_" .. tostring(os.time()) + + local user_schema = T.table({ + name = T.string(), + email = T.string() + }) + + reg:define(unique_id, user_schema) + + -- Create a simple test: resolve a ref directly + local ref_schema = T.ref(unique_id) + local resolved = reg:resolve_schema(ref_schema) + + -- Resolved schema should now be the actual table schema + local ok, val = resolved:check({ name = "Alice", email = "alice@example.com" }) + assert_true(ok) + assert_equal(val.name, "Alice") + end) + + test("List registered schemas", function() + local reg = Registry.global() + local ids = reg:list() + assert_not_nil(ids) + assert_equal(type(ids), "table") + end) + end) + + -- Error formatting tests + context("Error formatting", function() + test("Format type mismatch error", function() + local schema = T.string() + local ok, err = schema:check(123) + assert_false(ok) + + local formatted = T.format_error(err) + assert_not_nil(formatted) + assert_true(#formatted > 0) + assert_true(formatted:find("type mismatch") ~= nil) + end) + + test("Format constraint violation error", function() + local schema = T.integer({ min = 0, max = 100 }) + local ok, err = schema:check(150) + assert_false(ok) + + local formatted = T.format_error(err) + assert_not_nil(formatted) + assert_true(formatted:find("constraint violation") ~= nil) + assert_true(formatted:find("max") ~= nil) + end) + + test("Format nested table errors", function() + local schema = T.table({ + name = T.string(), + config = T.table({ + port = T.integer({ min = 1, max = 65535 }) + }) + }) + + local ok, err = schema:check({ + name = "server", + config = { port = 99999 } + }) + assert_false(ok) + + local formatted = T.format_error(err) + assert_not_nil(formatted) + assert_true(formatted:find("config.port") ~= nil) + end) + + test("Format one_of error with intersection", function() + local schema = T.one_of({ + { + name = "config_a", + schema = T.table({ type = T.literal("a"), value_a = T.string() }) + }, + { + name = "config_b", + schema = T.table({ type = T.literal("b"), value_b = T.integer() }) + } + }) + + local ok, err = schema:check({ value_a = "test" }) + assert_false(ok) + + local formatted = T.format_error(err) + assert_not_nil(formatted) + assert_true(formatted:find("alternative") ~= nil) + assert_true(formatted:find("type") ~= nil) + end) + end) + + -- Documentation support + context("Documentation", function() + test("Add documentation to schema", function() + local schema = T.string():doc({ + summary = "User name", + description = "Full name of the user", + examples = {"Alice", "Bob"} + }) + + assert_not_nil(schema.opts.doc) + assert_equal(schema.opts.doc.summary, "User name") + end) + + test("Documentation doesn't affect validation", function() + local schema = T.integer({ min = 0 }):doc({ summary = "Age" }) + + local ok = schema:check(25) + assert_true(ok) + + ok = schema:check(-5) + assert_false(ok) + end) + end) + + -- Utility functions + context("Utility functions", function() + test("Deep clone", function() + local original = { + a = 1, + b = { c = 2, d = { e = 3 } } + } + + local cloned = T.deep_clone(original) + + assert_rspamd_table_eq({expect = original, actual = cloned}) + assert_not_equal(cloned, original) -- different object + assert_not_equal(cloned.b, original.b) -- nested is cloned too + end) + + test("Deep clone handles non-tables", function() + assert_equal(T.deep_clone("string"), "string") + assert_equal(T.deep_clone(42), 42) + assert_equal(T.deep_clone(true), true) + assert_nil(T.deep_clone(nil)) + end) + end) +end)