From: Arran Cudbard-Bell Date: Wed, 13 May 2026 01:47:26 +0000 (-0600) Subject: Make SQL escape functions binary safe, and check escaping functions correctly X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ee5ee49afe2bf88e25cd2816cc5277b24fb66371;p=thirdparty%2Ffreeradius-server.git Make SQL escape functions binary safe, and check escaping functions correctly --- diff --git a/src/modules/rlm_sql/drivers/rlm_sql_mysql/rlm_sql_mysql.c b/src/modules/rlm_sql/drivers/rlm_sql_mysql/rlm_sql_mysql.c index a704857c190..172bbd292dc 100644 --- a/src/modules/rlm_sql/drivers/rlm_sql_mysql/rlm_sql_mysql.c +++ b/src/modules/rlm_sql/drivers/rlm_sql_mysql/rlm_sql_mysql.c @@ -848,11 +848,13 @@ static int sql_affected_rows(fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t return mysql_affected_rows(conn->sock); } -static ssize_t sql_escape_func(request_t *request, char *out, size_t outlen, char const *in, void *arg) +static int sql_escape_func(request_t *request, fr_value_box_t *vb, void *arg) { - size_t inlen; connection_t *conn = talloc_get_type_abort(arg, connection_t); rlm_sql_mysql_conn_t *c; + char *out; + size_t inlen = vb->vb_length; + unsigned long real_len; char const *log_prefix = conn->name; if (!((conn->state == CONNECTION_STATE_CONNECTING) || (conn->state == CONNECTION_STATE_CONNECTED))) { @@ -862,13 +864,19 @@ static ssize_t sql_escape_func(request_t *request, char *out, size_t outlen, cha c = talloc_get_type_abort(conn->h, rlm_sql_mysql_conn_t); - /* Check for potential buffer overflow */ - inlen = strlen(in); - if ((inlen * 2 + 1) > outlen) return 0; - /* Prevent integer overflow */ - if ((inlen * 2 + 1) <= inlen) return 0; + /* Prevent integer overflow on (inlen * 2 + 1) */ + if (inlen > (SIZE_MAX - 1) / 2) { + ROPTIONAL(RERROR, ERROR, "Input too large to escape"); + return -1; + } + + MEM(out = talloc_array(vb, char, inlen * 2 + 1)); + real_len = mysql_real_escape_string(&c->db, out, vb->vb_strvalue, inlen); - return mysql_real_escape_string(&c->db, out, in, inlen); + if ((size_t)real_len + 1 < inlen * 2 + 1) MEM(out = talloc_realloc(vb, out, char, real_len + 1)); + fr_value_box_strdup_shallow_replace(vb, out, real_len); + + return 0; } SQL_TRUNK_CONNECTION_ALLOC diff --git a/src/modules/rlm_sql/drivers/rlm_sql_postgresql/rlm_sql_postgresql.c b/src/modules/rlm_sql/drivers/rlm_sql_postgresql/rlm_sql_postgresql.c index 6942445270e..05b6145df7d 100644 --- a/src/modules/rlm_sql/drivers/rlm_sql_postgresql/rlm_sql_postgresql.c +++ b/src/modules/rlm_sql/drivers/rlm_sql_postgresql/rlm_sql_postgresql.c @@ -698,11 +698,13 @@ static int sql_affected_rows(fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t return conn->affected_rows; } -static ssize_t sql_escape_func(request_t *request, char *out, size_t outlen, char const *in, void *arg) +static int sql_escape_func(request_t *request, fr_value_box_t *vb, void *arg) { - size_t inlen, ret; + size_t inlen = vb->vb_length; + size_t real_len; connection_t *conn = talloc_get_type_abort(arg, connection_t); rlm_sql_postgres_conn_t *c; + char *out; int err; if (!((conn->state == CONNECTION_STATE_CONNECTING) || (conn->state == CONNECTION_STATE_CONNECTED))) { @@ -712,19 +714,24 @@ static ssize_t sql_escape_func(request_t *request, char *out, size_t outlen, cha c = talloc_get_type_abort(conn->h, rlm_sql_postgres_conn_t); - /* Check for potential buffer overflow */ - inlen = strlen(in); - if ((inlen * 2 + 1) > outlen) return 0; - /* Prevent integer overflow */ - if ((inlen * 2 + 1) <= inlen) return 0; + /* Prevent integer overflow on (inlen * 2 + 1) */ + if (inlen > (SIZE_MAX - 1) / 2) { + ROPTIONAL(RERROR, ERROR, "Input too large to escape"); + return -1; + } - ret = PQescapeStringConn(c->db, out, in, inlen, &err); + MEM(out = talloc_array(vb, char, inlen * 2 + 1)); + real_len = PQescapeStringConn(c->db, out, vb->vb_strvalue, inlen, &err); if (err) { - ROPTIONAL(REDEBUG, ERROR, "Error escaping string \"%s\": %s", in, PQerrorMessage(c->db)); - return 0; + ROPTIONAL(REDEBUG, ERROR, "Error escaping string: %s", PQerrorMessage(c->db)); + talloc_free(out); + return -1; } - return ret; + if (real_len + 1 < inlen * 2 + 1) MEM(out = talloc_realloc(vb, out, char, real_len + 1)); + fr_value_box_strdup_shallow_replace(vb, out, real_len); + + return 0; } static int mod_instantiate(module_inst_ctx_t const *mctx) diff --git a/src/modules/rlm_sql/rlm_sql.c b/src/modules/rlm_sql/rlm_sql.c index 40caf5014ea..8ee19d3fde4 100644 --- a/src/modules/rlm_sql/rlm_sql.c +++ b/src/modules/rlm_sql/rlm_sql.c @@ -322,20 +322,11 @@ static sql_fall_through_t fall_through(map_list_t *maps) return FALL_THROUGH_DEFAULT; } -/* - * Yucky prototype. - */ -static ssize_t sql_escape_func(request_t *, char *out, size_t outlen, char const *in, void *arg); - /** Escape a tainted VB used as an xlat argument * */ static int CC_HINT(nonnull(2,3)) sql_xlat_escape(request_t *request, fr_value_box_t *vb, void *uctx) { - fr_sbuff_t sbuff; - fr_sbuff_uctx_talloc_t sbuff_ctx; - - ssize_t len; void *arg = NULL; rlm_sql_escape_uctx_t *ctx = uctx; rlm_sql_t const *inst = talloc_get_type_abort_const(ctx->sql, rlm_sql_t); @@ -399,23 +390,14 @@ check_escape_arg: } /* - * Escaping functions work on strings - ensure the box is a string + * The escape function works with the bytes already in the box + * (using vb_length, not strlen()) so binary data containing NUL + * bytes survives the escape. Cast non-string boxes to a string + * first - the cast preserves the underlying bytes. */ if ((vb->type != FR_TYPE_STRING) && (fr_value_box_cast_in_place(vb, vb, FR_TYPE_STRING, NULL) < 0)) goto error; - /* - * Maximum escaped length is 3 * original - if every character needs escaping - */ - if (!fr_sbuff_init_talloc(vb, &sbuff, &sbuff_ctx, vb->vb_length * 3, vb->vb_length * 3)) { - fr_strerror_printf_push("Failed to allocate buffer for escaped sql argument"); - return -1; - } - - len = inst->sql_escape_func(request, fr_sbuff_buff(&sbuff), vb->vb_length * 3 + 1, vb->vb_strvalue, arg); - if (len < 0) goto error; - - fr_sbuff_trim_talloc(&sbuff, len); - fr_value_box_strdup_shallow_replace(vb, fr_sbuff_buff(&sbuff), len); + if (inst->sql_escape_func(request, vb, arg) < 0) goto error; /* * Different databases have slightly different ideas as @@ -932,29 +914,37 @@ static unlang_action_t mod_map_proc(unlang_result_t *p_result, map_ctx_t const * query_ctx); } -/** xlat escape function for drivers which do not provide their own +/** Default escape function for drivers which do not provide their own * + * Multi-byte UTF-8 characters pass through unchanged, `\n` / `\r` / `\t` become + * their backslash-escaped equivalents, and anything else outside + * `allowed_chars` is replaced with `=XX` (mime-style hex). */ -static ssize_t sql_escape_func(UNUSED request_t *request, char *out, size_t outlen, char const *in, void *arg) +static int sql_escape_func(UNUSED request_t *request, fr_value_box_t *vb, void *arg) { rlm_sql_t const *inst = talloc_get_type_abort_const(arg, rlm_sql_t); + char const *in = vb->vb_strvalue; + char const *end = in + vb->vb_length; + char *out, *p; + size_t outlen = vb->vb_length * 3 + 1; size_t len = 0; - while (*in) { + /* + * Worst case: every input byte expands to `=XX` (3 chars), plus a NUL. + */ + MEM(p = out = talloc_array(vb, char, outlen)); + + while (in < end) { size_t utf8_len; /* * Allow all multi-byte UTF8 characters. */ - utf8_len = fr_utf8_char((uint8_t const *) in, -1); + utf8_len = fr_utf8_char((uint8_t const *) in, end - in); if (utf8_len > 1) { - if (outlen <= utf8_len) break; - - memcpy(out, in, utf8_len); + memcpy(p, in, utf8_len); in += utf8_len; - out += utf8_len; - - outlen -= utf8_len; + p += utf8_len; len += utf8_len; continue; } @@ -966,69 +956,55 @@ static ssize_t sql_escape_func(UNUSED request_t *request, char *out, size_t outl */ switch (*in) { case '\n': - if (outlen <= 2) break; - out[0] = '\\'; - out[1] = 'n'; + p[0] = '\\'; + p[1] = 'n'; goto next; case '\r': - if (outlen <= 2) break; - out[0] = '\\'; - out[1] = 'r'; + p[0] = '\\'; + p[1] = 'r'; goto next; case '\t': - if (outlen <= 2) break; - out[0] = '\\'; - out[1] = 't'; + p[0] = '\\'; + p[1] = 't'; next: in++; - out += 2; - outlen -= 2; + p += 2; len += 2; continue; } /* - * Non-printable characters get replaced with their - * mime-encoded equivalents. + * Non-printable characters (including embedded NULs) get + * replaced with their mime-encoded equivalents. */ - if ((*in < 32) || + if (((unsigned char)*in < 32) || strchr(inst->config.allowed_chars, *in) == NULL) { - /* - * Only 3 or less bytes available. - */ - if (outlen <= 3) { - break; - } - - snprintf(out, outlen, "=%02X", (unsigned char) in[0]); + snprintf(p, 4, "=%02X", (unsigned char) in[0]); in++; - out += 3; - outlen -= 3; + p += 3; len += 3; continue; } - /* - * Only one byte left. - */ - if (outlen <= 1) { - break; - } - /* * Allowed character. */ - *out = *in; - out++; - in++; - outlen--; + *p++ = *in++; len++; } - *out = '\0'; - return len; + *p = '\0'; + + /* + * Shrink the buffer to fit the actual escaped length, then + * hand it to the box. + */ + if (len + 1 < outlen) MEM(out = talloc_realloc(vb, out, char, len + 1)); + fr_value_box_strdup_shallow_replace(vb, out, len); + + return 0; } /* diff --git a/src/modules/rlm_sql/rlm_sql.h b/src/modules/rlm_sql/rlm_sql.h index f60a9b1a6de..c8f0f3e4f42 100644 --- a/src/modules/rlm_sql/rlm_sql.h +++ b/src/modules/rlm_sql/rlm_sql.h @@ -217,7 +217,7 @@ typedef struct { sql_rcode_t (*sql_finish_query)(fr_sql_query_t *query_ctx, rlm_sql_config_t const *config); sql_rcode_t (*sql_finish_select_query)(fr_sql_query_t *query_ctx, rlm_sql_config_t const *config); - xlat_escape_legacy_t sql_escape_func; + xlat_escape_func_t sql_escape_func; void *(*sql_escape_arg_alloc)(TALLOC_CTX *ctx, fr_event_list_t *el, void *uctx); void (*sql_escape_arg_free)(void *uctx); @@ -234,7 +234,7 @@ struct sql_inst { module_instance_t *driver_submodule; //!< Driver's submodule. rlm_sql_driver_t const *driver; //!< Driver's exported interface. - xlat_escape_legacy_t sql_escape_func; + xlat_escape_func_t sql_escape_func; fr_value_box_escape_t box_escape; void *sql_escape_arg; //!< Instance specific argument to be passed to escape function. diff --git a/src/tests/modules/sql/escape.attrs b/src/tests/modules/sql/escape.attrs new file mode 100644 index 00000000000..8e58395db26 --- /dev/null +++ b/src/tests/modules/sql/escape.attrs @@ -0,0 +1,11 @@ +# +# Input packet +# +Packet-Type = Access-Request +User-Name = "escape_generic" +NAS-IP-Address = "1.2.3.4" + +# +# Expected answer +# +Packet-Type == Access-Accept diff --git a/src/tests/modules/sql/escape.unlang b/src/tests/modules/sql/escape.unlang new file mode 100644 index 00000000000..1bb14bd22ae --- /dev/null +++ b/src/tests/modules/sql/escape.unlang @@ -0,0 +1,58 @@ +# +# Generic SQL escape function: +# - `\n`, `\r`, `\t` become backslash + letter +# - anything < 32 or not in `safe_characters` is mime-encoded as `=XX` +# - multi-byte UTF-8 is passed through +# +# Drivers without their own escape (sqlite, oracle, etc.) all use this path. +# +# String literals in unlang are marked "safe for any escape" so we wrap each +# input in %taint(...) to force the escape to actually run. +# + +# +# Characters in the default safe set pass through unchanged. +# +if (%sql.escape(%taint("Hello world.test")) != "Hello world.test") { + test_fail +} + +# +# Single quote is not in the safe set -> =27 +# +if (%sql.escape(%taint("it's")) != "it=27s") { + test_fail +} + +# +# Backslash -> =5C +# +if (%sql.escape(%taint("back\\slash")) != "back=5Cslash") { + test_fail +} + +# +# Newline -> two-byte `\n` literal (backslash + the letter n), +# NOT mime-encoded as =0A. +# +if (%sql.escape(%taint("a\nb")) != "a\\nb") { + test_fail +} + +# +# Regression: embedded NUL byte must NOT truncate the input. +# Before the byte-length fix the escape stopped at the first NUL +# and returned just "a". +# +if (%sql.escape(%taint("a\000b")) != "a=00b") { + test_fail +} + +# +# Multi-byte UTF-8 is passed through unchanged. +# +if (%sql.escape(%taint("café")) != "café") { + test_fail +} + +test_pass diff --git a/src/tests/modules/sql_mysql/escape.attrs b/src/tests/modules/sql_mysql/escape.attrs new file mode 100644 index 00000000000..37d6f0b455b --- /dev/null +++ b/src/tests/modules/sql_mysql/escape.attrs @@ -0,0 +1,11 @@ +# +# Input packet +# +Packet-Type = Access-Request +User-Name = "escape_mysql" +NAS-IP-Address = "1.2.3.4" + +# +# Expected answer +# +Packet-Type == Access-Accept diff --git a/src/tests/modules/sql_mysql/escape.unlang b/src/tests/modules/sql_mysql/escape.unlang new file mode 100644 index 00000000000..bd38231fc22 --- /dev/null +++ b/src/tests/modules/sql_mysql/escape.unlang @@ -0,0 +1,62 @@ +# +# MySQL/MariaDB escape function (via mysql_real_escape_string): +# ' -> \' +# " -> \" +# \ -> \\ +# \n -> backslash + n (literal two bytes) +# \r -> backslash + r +# \0 -> backslash + 0 +# +# Each escape sequence is exactly two ASCII bytes; everything else +# passes through unchanged. Behaviour is character-set sensitive +# (the connection's charset is set up by the driver). +# +# String literals in unlang are marked "safe for any escape" so we wrap +# each input in %taint(...) to force the escape to actually run. +# + +# +# Plain ASCII passes through unchanged. +# +if (%sql.escape(%taint("Hello world")) != "Hello world") { + test_fail +} + +# +# Single quote -> \' +# +if (%sql.escape(%taint("it's")) != "it\\'s") { + test_fail +} + +# +# Double quote -> \" +# +if (%sql.escape(%taint("say \"hi\"")) != "say \\\"hi\\\"") { + test_fail +} + +# +# Backslash -> \\ (input has one backslash, output has two) +# +if (%sql.escape(%taint("back\\slash")) != "back\\\\slash") { + test_fail +} + +# +# Newline -> two-byte `\n` literal. +# +if (%sql.escape(%taint("a\nb")) != "a\\nb") { + test_fail +} + +# +# Regression: embedded NUL byte must round-trip as `\0`. +# Before the byte-length fix the driver called strlen() on the input +# and the NUL truncated the rest of the value. +# +if (%sql.escape(%taint("a\000b")) != "a\\0b") { + test_fail +} + +test_pass diff --git a/src/tests/modules/sql_postgresql/escape.attrs b/src/tests/modules/sql_postgresql/escape.attrs new file mode 100644 index 00000000000..97fa5790285 --- /dev/null +++ b/src/tests/modules/sql_postgresql/escape.attrs @@ -0,0 +1,11 @@ +# +# Input packet +# +Packet-Type = Access-Request +User-Name = "escape_postgresql" +NAS-IP-Address = "1.2.3.4" + +# +# Expected answer +# +Packet-Type == Access-Accept diff --git a/src/tests/modules/sql_postgresql/escape.unlang b/src/tests/modules/sql_postgresql/escape.unlang new file mode 100644 index 00000000000..fa9862f7283 --- /dev/null +++ b/src/tests/modules/sql_postgresql/escape.unlang @@ -0,0 +1,46 @@ +# +# PostgreSQL escape function (via PQescapeStringConn). +# +# With standard_conforming_strings = on (the default since 9.1) +# PQescapeStringConn only doubles single quotes. Backslashes and +# other byte values pass through unchanged - the SQL string literal +# grammar treats them as literal characters. +# +# String literals in unlang are marked "safe for any escape" so we wrap +# each input in %taint(...) to force the escape to actually run. +# + +# +# Plain ASCII passes through unchanged. +# +if (%sql.escape(%taint("Hello world")) != "Hello world") { + test_fail +} + +# +# Single quote is doubled. +# +if (%sql.escape(%taint("it's")) != "it''s") { + test_fail +} + +if (%sql.escape(%taint("''")) != "''''") { + test_fail +} + +# +# Backslash passes through (assumes standard_conforming_strings = on +# on the test server, which is the default in modern PostgreSQL). +# +if (%sql.escape(%taint("back\\slash")) != "back\\slash") { + test_fail +} + +# +# Multi-byte UTF-8 is passed through unchanged. +# +if (%sql.escape(%taint("café")) != "café") { + test_fail +} + +test_pass diff --git a/src/tests/modules/sql_sqlite/escape.attrs b/src/tests/modules/sql_sqlite/escape.attrs new file mode 120000 index 00000000000..dc1637f2b91 --- /dev/null +++ b/src/tests/modules/sql_sqlite/escape.attrs @@ -0,0 +1 @@ +../sql/escape.attrs \ No newline at end of file diff --git a/src/tests/modules/sql_sqlite/escape.unlang b/src/tests/modules/sql_sqlite/escape.unlang new file mode 120000 index 00000000000..748679ba2af --- /dev/null +++ b/src/tests/modules/sql_sqlite/escape.unlang @@ -0,0 +1 @@ +../sql/escape.unlang \ No newline at end of file