return mysql_affected_rows(conn->sock);
}
-static ssize_t sql_escape_func(request_t *request, char *out, size_t outlen, char const *in, void *arg)
+static int sql_escape_func(request_t *request, fr_value_box_t *vb, void *arg)
{
- size_t inlen;
connection_t *conn = talloc_get_type_abort(arg, connection_t);
rlm_sql_mysql_conn_t *c;
+ char *out;
+ size_t inlen = vb->vb_length;
+ unsigned long real_len;
char const *log_prefix = conn->name;
if (!((conn->state == CONNECTION_STATE_CONNECTING) || (conn->state == CONNECTION_STATE_CONNECTED))) {
c = talloc_get_type_abort(conn->h, rlm_sql_mysql_conn_t);
- /* Check for potential buffer overflow */
- inlen = strlen(in);
- if ((inlen * 2 + 1) > outlen) return 0;
- /* Prevent integer overflow */
- if ((inlen * 2 + 1) <= inlen) return 0;
+ /* Prevent integer overflow on (inlen * 2 + 1) */
+ if (inlen > (SIZE_MAX - 1) / 2) {
+ ROPTIONAL(RERROR, ERROR, "Input too large to escape");
+ return -1;
+ }
+
+ MEM(out = talloc_array(vb, char, inlen * 2 + 1));
+ real_len = mysql_real_escape_string(&c->db, out, vb->vb_strvalue, inlen);
- return mysql_real_escape_string(&c->db, out, in, inlen);
+ if ((size_t)real_len + 1 < inlen * 2 + 1) MEM(out = talloc_realloc(vb, out, char, real_len + 1));
+ fr_value_box_strdup_shallow_replace(vb, out, real_len);
+
+ return 0;
}
SQL_TRUNK_CONNECTION_ALLOC
return conn->affected_rows;
}
-static ssize_t sql_escape_func(request_t *request, char *out, size_t outlen, char const *in, void *arg)
+static int sql_escape_func(request_t *request, fr_value_box_t *vb, void *arg)
{
- size_t inlen, ret;
+ size_t inlen = vb->vb_length;
+ size_t real_len;
connection_t *conn = talloc_get_type_abort(arg, connection_t);
rlm_sql_postgres_conn_t *c;
+ char *out;
int err;
if (!((conn->state == CONNECTION_STATE_CONNECTING) || (conn->state == CONNECTION_STATE_CONNECTED))) {
c = talloc_get_type_abort(conn->h, rlm_sql_postgres_conn_t);
- /* Check for potential buffer overflow */
- inlen = strlen(in);
- if ((inlen * 2 + 1) > outlen) return 0;
- /* Prevent integer overflow */
- if ((inlen * 2 + 1) <= inlen) return 0;
+ /* Prevent integer overflow on (inlen * 2 + 1) */
+ if (inlen > (SIZE_MAX - 1) / 2) {
+ ROPTIONAL(RERROR, ERROR, "Input too large to escape");
+ return -1;
+ }
- ret = PQescapeStringConn(c->db, out, in, inlen, &err);
+ MEM(out = talloc_array(vb, char, inlen * 2 + 1));
+ real_len = PQescapeStringConn(c->db, out, vb->vb_strvalue, inlen, &err);
if (err) {
- ROPTIONAL(REDEBUG, ERROR, "Error escaping string \"%s\": %s", in, PQerrorMessage(c->db));
- return 0;
+ ROPTIONAL(REDEBUG, ERROR, "Error escaping string: %s", PQerrorMessage(c->db));
+ talloc_free(out);
+ return -1;
}
- return ret;
+ if (real_len + 1 < inlen * 2 + 1) MEM(out = talloc_realloc(vb, out, char, real_len + 1));
+ fr_value_box_strdup_shallow_replace(vb, out, real_len);
+
+ return 0;
}
static int mod_instantiate(module_inst_ctx_t const *mctx)
return FALL_THROUGH_DEFAULT;
}
-/*
- * Yucky prototype.
- */
-static ssize_t sql_escape_func(request_t *, char *out, size_t outlen, char const *in, void *arg);
-
/** Escape a tainted VB used as an xlat argument
*
*/
static int CC_HINT(nonnull(2,3)) sql_xlat_escape(request_t *request, fr_value_box_t *vb, void *uctx)
{
- fr_sbuff_t sbuff;
- fr_sbuff_uctx_talloc_t sbuff_ctx;
-
- ssize_t len;
void *arg = NULL;
rlm_sql_escape_uctx_t *ctx = uctx;
rlm_sql_t const *inst = talloc_get_type_abort_const(ctx->sql, rlm_sql_t);
}
/*
- * Escaping functions work on strings - ensure the box is a string
+ * The escape function works with the bytes already in the box
+ * (using vb_length, not strlen()) so binary data containing NUL
+ * bytes survives the escape. Cast non-string boxes to a string
+ * first - the cast preserves the underlying bytes.
*/
if ((vb->type != FR_TYPE_STRING) && (fr_value_box_cast_in_place(vb, vb, FR_TYPE_STRING, NULL) < 0)) goto error;
- /*
- * Maximum escaped length is 3 * original - if every character needs escaping
- */
- if (!fr_sbuff_init_talloc(vb, &sbuff, &sbuff_ctx, vb->vb_length * 3, vb->vb_length * 3)) {
- fr_strerror_printf_push("Failed to allocate buffer for escaped sql argument");
- return -1;
- }
-
- len = inst->sql_escape_func(request, fr_sbuff_buff(&sbuff), vb->vb_length * 3 + 1, vb->vb_strvalue, arg);
- if (len < 0) goto error;
-
- fr_sbuff_trim_talloc(&sbuff, len);
- fr_value_box_strdup_shallow_replace(vb, fr_sbuff_buff(&sbuff), len);
+ if (inst->sql_escape_func(request, vb, arg) < 0) goto error;
/*
* Different databases have slightly different ideas as
query_ctx);
}
-/** xlat escape function for drivers which do not provide their own
+/** Default escape function for drivers which do not provide their own
*
+ * Multi-byte UTF-8 characters pass through unchanged, `\n` / `\r` / `\t` become
+ * their backslash-escaped equivalents, and anything else outside
+ * `allowed_chars` is replaced with `=XX` (mime-style hex).
*/
-static ssize_t sql_escape_func(UNUSED request_t *request, char *out, size_t outlen, char const *in, void *arg)
+static int sql_escape_func(UNUSED request_t *request, fr_value_box_t *vb, void *arg)
{
rlm_sql_t const *inst = talloc_get_type_abort_const(arg, rlm_sql_t);
+ char const *in = vb->vb_strvalue;
+ char const *end = in + vb->vb_length;
+ char *out, *p;
+ size_t outlen = vb->vb_length * 3 + 1;
size_t len = 0;
- while (*in) {
+ /*
+ * Worst case: every input byte expands to `=XX` (3 chars), plus a NUL.
+ */
+ MEM(p = out = talloc_array(vb, char, outlen));
+
+ while (in < end) {
size_t utf8_len;
/*
* Allow all multi-byte UTF8 characters.
*/
- utf8_len = fr_utf8_char((uint8_t const *) in, -1);
+ utf8_len = fr_utf8_char((uint8_t const *) in, end - in);
if (utf8_len > 1) {
- if (outlen <= utf8_len) break;
-
- memcpy(out, in, utf8_len);
+ memcpy(p, in, utf8_len);
in += utf8_len;
- out += utf8_len;
-
- outlen -= utf8_len;
+ p += utf8_len;
len += utf8_len;
continue;
}
*/
switch (*in) {
case '\n':
- if (outlen <= 2) break;
- out[0] = '\\';
- out[1] = 'n';
+ p[0] = '\\';
+ p[1] = 'n';
goto next;
case '\r':
- if (outlen <= 2) break;
- out[0] = '\\';
- out[1] = 'r';
+ p[0] = '\\';
+ p[1] = 'r';
goto next;
case '\t':
- if (outlen <= 2) break;
- out[0] = '\\';
- out[1] = 't';
+ p[0] = '\\';
+ p[1] = 't';
next:
in++;
- out += 2;
- outlen -= 2;
+ p += 2;
len += 2;
continue;
}
/*
- * Non-printable characters get replaced with their
- * mime-encoded equivalents.
+ * Non-printable characters (including embedded NULs) get
+ * replaced with their mime-encoded equivalents.
*/
- if ((*in < 32) ||
+ if (((unsigned char)*in < 32) ||
strchr(inst->config.allowed_chars, *in) == NULL) {
- /*
- * Only 3 or less bytes available.
- */
- if (outlen <= 3) {
- break;
- }
-
- snprintf(out, outlen, "=%02X", (unsigned char) in[0]);
+ snprintf(p, 4, "=%02X", (unsigned char) in[0]);
in++;
- out += 3;
- outlen -= 3;
+ p += 3;
len += 3;
continue;
}
- /*
- * Only one byte left.
- */
- if (outlen <= 1) {
- break;
- }
-
/*
* Allowed character.
*/
- *out = *in;
- out++;
- in++;
- outlen--;
+ *p++ = *in++;
len++;
}
- *out = '\0';
- return len;
+ *p = '\0';
+
+ /*
+ * Shrink the buffer to fit the actual escaped length, then
+ * hand it to the box.
+ */
+ if (len + 1 < outlen) MEM(out = talloc_realloc(vb, out, char, len + 1));
+ fr_value_box_strdup_shallow_replace(vb, out, len);
+
+ return 0;
}
/*
sql_rcode_t (*sql_finish_query)(fr_sql_query_t *query_ctx, rlm_sql_config_t const *config);
sql_rcode_t (*sql_finish_select_query)(fr_sql_query_t *query_ctx, rlm_sql_config_t const *config);
- xlat_escape_legacy_t sql_escape_func;
+ xlat_escape_func_t sql_escape_func;
void *(*sql_escape_arg_alloc)(TALLOC_CTX *ctx, fr_event_list_t *el, void *uctx);
void (*sql_escape_arg_free)(void *uctx);
module_instance_t *driver_submodule; //!< Driver's submodule.
rlm_sql_driver_t const *driver; //!< Driver's exported interface.
- xlat_escape_legacy_t sql_escape_func;
+ xlat_escape_func_t sql_escape_func;
fr_value_box_escape_t box_escape;
void *sql_escape_arg; //!< Instance specific argument to be passed to escape function.
--- /dev/null
+#
+# Input packet
+#
+Packet-Type = Access-Request
+User-Name = "escape_generic"
+NAS-IP-Address = "1.2.3.4"
+
+#
+# Expected answer
+#
+Packet-Type == Access-Accept
--- /dev/null
+#
+# Generic SQL escape function:
+# - `\n`, `\r`, `\t` become backslash + letter
+# - anything < 32 or not in `safe_characters` is mime-encoded as `=XX`
+# - multi-byte UTF-8 is passed through
+#
+# Drivers without their own escape (sqlite, oracle, etc.) all use this path.
+#
+# String literals in unlang are marked "safe for any escape" so we wrap each
+# input in %taint(...) to force the escape to actually run.
+#
+
+#
+# Characters in the default safe set pass through unchanged.
+#
+if (%sql.escape(%taint("Hello world.test")) != "Hello world.test") {
+ test_fail
+}
+
+#
+# Single quote is not in the safe set -> =27
+#
+if (%sql.escape(%taint("it's")) != "it=27s") {
+ test_fail
+}
+
+#
+# Backslash -> =5C
+#
+if (%sql.escape(%taint("back\\slash")) != "back=5Cslash") {
+ test_fail
+}
+
+#
+# Newline -> two-byte `\n` literal (backslash + the letter n),
+# NOT mime-encoded as =0A.
+#
+if (%sql.escape(%taint("a\nb")) != "a\\nb") {
+ test_fail
+}
+
+#
+# Regression: embedded NUL byte must NOT truncate the input.
+# Before the byte-length fix the escape stopped at the first NUL
+# and returned just "a".
+#
+if (%sql.escape(%taint("a\000b")) != "a=00b") {
+ test_fail
+}
+
+#
+# Multi-byte UTF-8 is passed through unchanged.
+#
+if (%sql.escape(%taint("café")) != "café") {
+ test_fail
+}
+
+test_pass
--- /dev/null
+#
+# Input packet
+#
+Packet-Type = Access-Request
+User-Name = "escape_mysql"
+NAS-IP-Address = "1.2.3.4"
+
+#
+# Expected answer
+#
+Packet-Type == Access-Accept
--- /dev/null
+#
+# MySQL/MariaDB escape function (via mysql_real_escape_string):
+# ' -> \'
+# " -> \"
+# \ -> \\
+# \n -> backslash + n (literal two bytes)
+# \r -> backslash + r
+# \0 -> backslash + 0
+#
+# Each escape sequence is exactly two ASCII bytes; everything else
+# passes through unchanged. Behaviour is character-set sensitive
+# (the connection's charset is set up by the driver).
+#
+# String literals in unlang are marked "safe for any escape" so we wrap
+# each input in %taint(...) to force the escape to actually run.
+#
+
+#
+# Plain ASCII passes through unchanged.
+#
+if (%sql.escape(%taint("Hello world")) != "Hello world") {
+ test_fail
+}
+
+#
+# Single quote -> \'
+#
+if (%sql.escape(%taint("it's")) != "it\\'s") {
+ test_fail
+}
+
+#
+# Double quote -> \"
+#
+if (%sql.escape(%taint("say \"hi\"")) != "say \\\"hi\\\"") {
+ test_fail
+}
+
+#
+# Backslash -> \\ (input has one backslash, output has two)
+#
+if (%sql.escape(%taint("back\\slash")) != "back\\\\slash") {
+ test_fail
+}
+
+#
+# Newline -> two-byte `\n` literal.
+#
+if (%sql.escape(%taint("a\nb")) != "a\\nb") {
+ test_fail
+}
+
+#
+# Regression: embedded NUL byte must round-trip as `\0`.
+# Before the byte-length fix the driver called strlen() on the input
+# and the NUL truncated the rest of the value.
+#
+if (%sql.escape(%taint("a\000b")) != "a\\0b") {
+ test_fail
+}
+
+test_pass
--- /dev/null
+#
+# Input packet
+#
+Packet-Type = Access-Request
+User-Name = "escape_postgresql"
+NAS-IP-Address = "1.2.3.4"
+
+#
+# Expected answer
+#
+Packet-Type == Access-Accept
--- /dev/null
+#
+# PostgreSQL escape function (via PQescapeStringConn).
+#
+# With standard_conforming_strings = on (the default since 9.1)
+# PQescapeStringConn only doubles single quotes. Backslashes and
+# other byte values pass through unchanged - the SQL string literal
+# grammar treats them as literal characters.
+#
+# String literals in unlang are marked "safe for any escape" so we wrap
+# each input in %taint(...) to force the escape to actually run.
+#
+
+#
+# Plain ASCII passes through unchanged.
+#
+if (%sql.escape(%taint("Hello world")) != "Hello world") {
+ test_fail
+}
+
+#
+# Single quote is doubled.
+#
+if (%sql.escape(%taint("it's")) != "it''s") {
+ test_fail
+}
+
+if (%sql.escape(%taint("''")) != "''''") {
+ test_fail
+}
+
+#
+# Backslash passes through (assumes standard_conforming_strings = on
+# on the test server, which is the default in modern PostgreSQL).
+#
+if (%sql.escape(%taint("back\\slash")) != "back\\slash") {
+ test_fail
+}
+
+#
+# Multi-byte UTF-8 is passed through unchanged.
+#
+if (%sql.escape(%taint("café")) != "café") {
+ test_fail
+}
+
+test_pass
--- /dev/null
+../sql/escape.attrs
\ No newline at end of file
--- /dev/null
+../sql/escape.unlang
\ No newline at end of file