]> git.ipfire.org Git - thirdparty/freeradius-server.git/commitdiff
Make SQL escape functions binary safe, and check escaping functions correctly
authorArran Cudbard-Bell <a.cudbardb@freeradius.org>
Wed, 13 May 2026 01:47:26 +0000 (19:47 -0600)
committerArran Cudbard-Bell <a.cudbardb@freeradius.org>
Wed, 13 May 2026 01:48:00 +0000 (19:48 -0600)
12 files changed:
src/modules/rlm_sql/drivers/rlm_sql_mysql/rlm_sql_mysql.c
src/modules/rlm_sql/drivers/rlm_sql_postgresql/rlm_sql_postgresql.c
src/modules/rlm_sql/rlm_sql.c
src/modules/rlm_sql/rlm_sql.h
src/tests/modules/sql/escape.attrs [new file with mode: 0644]
src/tests/modules/sql/escape.unlang [new file with mode: 0644]
src/tests/modules/sql_mysql/escape.attrs [new file with mode: 0644]
src/tests/modules/sql_mysql/escape.unlang [new file with mode: 0644]
src/tests/modules/sql_postgresql/escape.attrs [new file with mode: 0644]
src/tests/modules/sql_postgresql/escape.unlang [new file with mode: 0644]
src/tests/modules/sql_sqlite/escape.attrs [new symlink]
src/tests/modules/sql_sqlite/escape.unlang [new symlink]

index a704857c190ecbf227341fb884b1f65c9a7c2a74..172bbd292dc452803e2da4979e7cc1edec612f46 100644 (file)
@@ -848,11 +848,13 @@ static int sql_affected_rows(fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t
        return mysql_affected_rows(conn->sock);
 }
 
-static ssize_t sql_escape_func(request_t *request, char *out, size_t outlen, char const *in, void *arg)
+static int sql_escape_func(request_t *request, fr_value_box_t *vb, void *arg)
 {
-       size_t                  inlen;
        connection_t            *conn = talloc_get_type_abort(arg, connection_t);
        rlm_sql_mysql_conn_t    *c;
+       char                    *out;
+       size_t                  inlen = vb->vb_length;
+       unsigned long           real_len;
        char const              *log_prefix = conn->name;
 
        if (!((conn->state == CONNECTION_STATE_CONNECTING) || (conn->state == CONNECTION_STATE_CONNECTED))) {
@@ -862,13 +864,19 @@ static ssize_t sql_escape_func(request_t *request, char *out, size_t outlen, cha
 
        c = talloc_get_type_abort(conn->h, rlm_sql_mysql_conn_t);
 
-       /* Check for potential buffer overflow */
-       inlen = strlen(in);
-       if ((inlen * 2 + 1) > outlen) return 0;
-       /* Prevent integer overflow */
-       if ((inlen * 2 + 1) <= inlen) return 0;
+       /* Prevent integer overflow on (inlen * 2 + 1) */
+       if (inlen > (SIZE_MAX - 1) / 2) {
+               ROPTIONAL(RERROR, ERROR, "Input too large to escape");
+               return -1;
+       }
+
+       MEM(out = talloc_array(vb, char, inlen * 2 + 1));
+       real_len = mysql_real_escape_string(&c->db, out, vb->vb_strvalue, inlen);
 
-       return mysql_real_escape_string(&c->db, out, in, inlen);
+       if ((size_t)real_len + 1 < inlen * 2 + 1) MEM(out = talloc_realloc(vb, out, char, real_len + 1));
+       fr_value_box_strdup_shallow_replace(vb, out, real_len);
+
+       return 0;
 }
 
 SQL_TRUNK_CONNECTION_ALLOC
index 6942445270ee7112b79bd6e0de4f6340768df450..05b6145df7d010a96dafaaa42087397bb29d4f2f 100644 (file)
@@ -698,11 +698,13 @@ static int sql_affected_rows(fr_sql_query_t *query_ctx, UNUSED rlm_sql_config_t
        return conn->affected_rows;
 }
 
-static ssize_t sql_escape_func(request_t *request, char *out, size_t outlen, char const *in, void *arg)
+static int sql_escape_func(request_t *request, fr_value_box_t *vb, void *arg)
 {
-       size_t                  inlen, ret;
+       size_t                  inlen = vb->vb_length;
+       size_t                  real_len;
        connection_t            *conn = talloc_get_type_abort(arg, connection_t);
        rlm_sql_postgres_conn_t *c;
+       char                    *out;
        int                     err;
 
        if (!((conn->state == CONNECTION_STATE_CONNECTING) || (conn->state == CONNECTION_STATE_CONNECTED))) {
@@ -712,19 +714,24 @@ static ssize_t sql_escape_func(request_t *request, char *out, size_t outlen, cha
 
        c = talloc_get_type_abort(conn->h, rlm_sql_postgres_conn_t);
 
-       /* Check for potential buffer overflow */
-       inlen = strlen(in);
-       if ((inlen * 2 + 1) > outlen) return 0;
-       /* Prevent integer overflow */
-       if ((inlen * 2 + 1) <= inlen) return 0;
+       /* Prevent integer overflow on (inlen * 2 + 1) */
+       if (inlen > (SIZE_MAX - 1) / 2) {
+               ROPTIONAL(RERROR, ERROR, "Input too large to escape");
+               return -1;
+       }
 
-       ret = PQescapeStringConn(c->db, out, in, inlen, &err);
+       MEM(out = talloc_array(vb, char, inlen * 2 + 1));
+       real_len = PQescapeStringConn(c->db, out, vb->vb_strvalue, inlen, &err);
        if (err) {
-               ROPTIONAL(REDEBUG, ERROR, "Error escaping string \"%s\": %s", in, PQerrorMessage(c->db));
-               return 0;
+               ROPTIONAL(REDEBUG, ERROR, "Error escaping string: %s", PQerrorMessage(c->db));
+               talloc_free(out);
+               return -1;
        }
 
-       return ret;
+       if (real_len + 1 < inlen * 2 + 1) MEM(out = talloc_realloc(vb, out, char, real_len + 1));
+       fr_value_box_strdup_shallow_replace(vb, out, real_len);
+
+       return 0;
 }
 
 static int mod_instantiate(module_inst_ctx_t const *mctx)
index 40caf5014eaf19b69fc529410ebcbd9a75d8d4d0..8ee19d3fde4e2a708124c44bc764d82ff6ef9de7 100644 (file)
@@ -322,20 +322,11 @@ static sql_fall_through_t fall_through(map_list_t *maps)
        return  FALL_THROUGH_DEFAULT;
 }
 
-/*
- *     Yucky prototype.
- */
-static ssize_t sql_escape_func(request_t *, char *out, size_t outlen, char const *in, void *arg);
-
 /** Escape a tainted VB used as an xlat argument
  *
  */
 static int CC_HINT(nonnull(2,3)) sql_xlat_escape(request_t *request, fr_value_box_t *vb, void *uctx)
 {
-       fr_sbuff_t                      sbuff;
-       fr_sbuff_uctx_talloc_t          sbuff_ctx;
-
-       ssize_t                         len;
        void                            *arg = NULL;
        rlm_sql_escape_uctx_t           *ctx = uctx;
        rlm_sql_t const                 *inst = talloc_get_type_abort_const(ctx->sql, rlm_sql_t);
@@ -399,23 +390,14 @@ check_escape_arg:
        }
 
        /*
-        *      Escaping functions work on strings - ensure the box is a string
+        *      The escape function works with the bytes already in the box
+        *      (using vb_length, not strlen()) so binary data containing NUL
+        *      bytes survives the escape.  Cast non-string boxes to a string
+        *      first - the cast preserves the underlying bytes.
         */
        if ((vb->type != FR_TYPE_STRING) && (fr_value_box_cast_in_place(vb, vb, FR_TYPE_STRING, NULL) < 0)) goto error;
 
-       /*
-        *      Maximum escaped length is 3 * original - if every character needs escaping
-        */
-       if (!fr_sbuff_init_talloc(vb, &sbuff, &sbuff_ctx, vb->vb_length * 3, vb->vb_length * 3)) {
-               fr_strerror_printf_push("Failed to allocate buffer for escaped sql argument");
-               return -1;
-       }
-
-       len = inst->sql_escape_func(request, fr_sbuff_buff(&sbuff), vb->vb_length * 3 + 1, vb->vb_strvalue, arg);
-       if (len < 0) goto error;
-
-       fr_sbuff_trim_talloc(&sbuff, len);
-       fr_value_box_strdup_shallow_replace(vb, fr_sbuff_buff(&sbuff), len);
+       if (inst->sql_escape_func(request, vb, arg) < 0) goto error;
 
        /*
         *      Different databases have slightly different ideas as
@@ -932,29 +914,37 @@ static unlang_action_t mod_map_proc(unlang_result_t *p_result, map_ctx_t const *
                                                query_ctx);
 }
 
-/** xlat escape function for drivers which do not provide their own
+/** Default escape function for drivers which do not provide their own
  *
+ * Multi-byte UTF-8 characters pass through unchanged, `\n` / `\r` / `\t` become
+ * their backslash-escaped equivalents, and anything else outside
+ * `allowed_chars` is replaced with `=XX` (mime-style hex).
  */
-static ssize_t sql_escape_func(UNUSED request_t *request, char *out, size_t outlen, char const *in, void *arg)
+static int sql_escape_func(UNUSED request_t *request, fr_value_box_t *vb, void *arg)
 {
        rlm_sql_t const         *inst = talloc_get_type_abort_const(arg, rlm_sql_t);
+       char const              *in = vb->vb_strvalue;
+       char const              *end = in + vb->vb_length;
+       char                    *out, *p;
+       size_t                  outlen = vb->vb_length * 3 + 1;
        size_t                  len = 0;
 
-       while (*in) {
+       /*
+        *      Worst case: every input byte expands to `=XX` (3 chars), plus a NUL.
+        */
+       MEM(p = out = talloc_array(vb, char, outlen));
+
+       while (in < end) {
                size_t utf8_len;
 
                /*
                 *      Allow all multi-byte UTF8 characters.
                 */
-               utf8_len = fr_utf8_char((uint8_t const *) in, -1);
+               utf8_len = fr_utf8_char((uint8_t const *) in, end - in);
                if (utf8_len > 1) {
-                       if (outlen <= utf8_len) break;
-
-                       memcpy(out, in, utf8_len);
+                       memcpy(p, in, utf8_len);
                        in += utf8_len;
-                       out += utf8_len;
-
-                       outlen -= utf8_len;
+                       p += utf8_len;
                        len += utf8_len;
                        continue;
                }
@@ -966,69 +956,55 @@ static ssize_t sql_escape_func(UNUSED request_t *request, char *out, size_t outl
                 */
                switch (*in) {
                case '\n':
-                       if (outlen <= 2) break;
-                       out[0] = '\\';
-                       out[1] = 'n';
+                       p[0] = '\\';
+                       p[1] = 'n';
                        goto next;
 
                case '\r':
-                       if (outlen <= 2) break;
-                       out[0] = '\\';
-                       out[1] = 'r';
+                       p[0] = '\\';
+                       p[1] = 'r';
                        goto next;
 
                case '\t':
-                       if (outlen <= 2) break;
-                       out[0] = '\\';
-                       out[1] = 't';
+                       p[0] = '\\';
+                       p[1] = 't';
 
                next:
                        in++;
-                       out += 2;
-                       outlen -= 2;
+                       p += 2;
                        len += 2;
                        continue;
                }
 
                /*
-                *      Non-printable characters get replaced with their
-                *      mime-encoded equivalents.
+                *      Non-printable characters (including embedded NULs) get
+                *      replaced with their mime-encoded equivalents.
                 */
-               if ((*in < 32) ||
+               if (((unsigned char)*in < 32) ||
                    strchr(inst->config.allowed_chars, *in) == NULL) {
-                       /*
-                        *      Only 3 or less bytes available.
-                        */
-                       if (outlen <= 3) {
-                               break;
-                       }
-
-                       snprintf(out, outlen, "=%02X", (unsigned char) in[0]);
+                       snprintf(p, 4, "=%02X", (unsigned char) in[0]);
                        in++;
-                       out += 3;
-                       outlen -= 3;
+                       p += 3;
                        len += 3;
                        continue;
                }
 
-               /*
-                *      Only one byte left.
-                */
-               if (outlen <= 1) {
-                       break;
-               }
-
                /*
                 *      Allowed character.
                 */
-               *out = *in;
-               out++;
-               in++;
-               outlen--;
+               *p++ = *in++;
                len++;
        }
-       *out = '\0';
-       return len;
+       *p = '\0';
+
+       /*
+        *      Shrink the buffer to fit the actual escaped length, then
+        *      hand it to the box.
+        */
+       if (len + 1 < outlen) MEM(out = talloc_realloc(vb, out, char, len + 1));
+       fr_value_box_strdup_shallow_replace(vb, out, len);
+
+       return 0;
 }
 
 /*
index f60a9b1a6de7dcbc41f15941c42a17d67f348be0..c8f0f3e4f42f173f0133312be07843cee3adbde6 100644 (file)
@@ -217,7 +217,7 @@ typedef struct {
        sql_rcode_t     (*sql_finish_query)(fr_sql_query_t *query_ctx, rlm_sql_config_t const *config);
        sql_rcode_t     (*sql_finish_select_query)(fr_sql_query_t *query_ctx, rlm_sql_config_t const *config);
 
-       xlat_escape_legacy_t    sql_escape_func;
+       xlat_escape_func_t      sql_escape_func;
        void            *(*sql_escape_arg_alloc)(TALLOC_CTX *ctx, fr_event_list_t *el, void *uctx);
        void            (*sql_escape_arg_free)(void *uctx);
 
@@ -234,7 +234,7 @@ struct sql_inst {
        module_instance_t               *driver_submodule;      //!< Driver's submodule.
        rlm_sql_driver_t const          *driver;                //!< Driver's exported interface.
 
-       xlat_escape_legacy_t            sql_escape_func;
+       xlat_escape_func_t              sql_escape_func;
        fr_value_box_escape_t           box_escape;
        void                            *sql_escape_arg;        //!< Instance specific argument to be passed to escape function.
 
diff --git a/src/tests/modules/sql/escape.attrs b/src/tests/modules/sql/escape.attrs
new file mode 100644 (file)
index 0000000..8e58395
--- /dev/null
@@ -0,0 +1,11 @@
+#
+#  Input packet
+#
+Packet-Type = Access-Request
+User-Name = "escape_generic"
+NAS-IP-Address = "1.2.3.4"
+
+#
+#  Expected answer
+#
+Packet-Type == Access-Accept
diff --git a/src/tests/modules/sql/escape.unlang b/src/tests/modules/sql/escape.unlang
new file mode 100644 (file)
index 0000000..1bb14bd
--- /dev/null
@@ -0,0 +1,58 @@
+#
+#  Generic SQL escape function:
+#    - `\n`, `\r`, `\t` become backslash + letter
+#    - anything < 32 or not in `safe_characters` is mime-encoded as `=XX`
+#    - multi-byte UTF-8 is passed through
+#
+#  Drivers without their own escape (sqlite, oracle, etc.) all use this path.
+#
+#  String literals in unlang are marked "safe for any escape" so we wrap each
+#  input in %taint(...) to force the escape to actually run.
+#
+
+#
+#  Characters in the default safe set pass through unchanged.
+#
+if (%sql.escape(%taint("Hello world.test")) != "Hello world.test") {
+       test_fail
+}
+
+#
+#  Single quote is not in the safe set -> =27
+#
+if (%sql.escape(%taint("it's")) != "it=27s") {
+       test_fail
+}
+
+#
+#  Backslash -> =5C
+#
+if (%sql.escape(%taint("back\\slash")) != "back=5Cslash") {
+       test_fail
+}
+
+#
+#  Newline -> two-byte `\n` literal (backslash + the letter n),
+#  NOT mime-encoded as =0A.
+#
+if (%sql.escape(%taint("a\nb")) != "a\\nb") {
+       test_fail
+}
+
+#
+#  Regression: embedded NUL byte must NOT truncate the input.
+#  Before the byte-length fix the escape stopped at the first NUL
+#  and returned just "a".
+#
+if (%sql.escape(%taint("a\000b")) != "a=00b") {
+       test_fail
+}
+
+#
+#  Multi-byte UTF-8 is passed through unchanged.
+#
+if (%sql.escape(%taint("café")) != "café") {
+       test_fail
+}
+
+test_pass
diff --git a/src/tests/modules/sql_mysql/escape.attrs b/src/tests/modules/sql_mysql/escape.attrs
new file mode 100644 (file)
index 0000000..37d6f0b
--- /dev/null
@@ -0,0 +1,11 @@
+#
+#  Input packet
+#
+Packet-Type = Access-Request
+User-Name = "escape_mysql"
+NAS-IP-Address = "1.2.3.4"
+
+#
+#  Expected answer
+#
+Packet-Type == Access-Accept
diff --git a/src/tests/modules/sql_mysql/escape.unlang b/src/tests/modules/sql_mysql/escape.unlang
new file mode 100644 (file)
index 0000000..bd38231
--- /dev/null
@@ -0,0 +1,62 @@
+#
+#  MySQL/MariaDB escape function (via mysql_real_escape_string):
+#    '  -> \'
+#    "  -> \"
+#    \  -> \\
+#    \n -> backslash + n  (literal two bytes)
+#    \r -> backslash + r
+#    \0 -> backslash + 0
+#
+#  Each escape sequence is exactly two ASCII bytes; everything else
+#  passes through unchanged.  Behaviour is character-set sensitive
+#  (the connection's charset is set up by the driver).
+#
+#  String literals in unlang are marked "safe for any escape" so we wrap
+#  each input in %taint(...) to force the escape to actually run.
+#
+
+#
+#  Plain ASCII passes through unchanged.
+#
+if (%sql.escape(%taint("Hello world")) != "Hello world") {
+       test_fail
+}
+
+#
+#  Single quote -> \'
+#
+if (%sql.escape(%taint("it's")) != "it\\'s") {
+       test_fail
+}
+
+#
+#  Double quote -> \"
+#
+if (%sql.escape(%taint("say \"hi\"")) != "say \\\"hi\\\"") {
+       test_fail
+}
+
+#
+#  Backslash -> \\  (input has one backslash, output has two)
+#
+if (%sql.escape(%taint("back\\slash")) != "back\\\\slash") {
+       test_fail
+}
+
+#
+#  Newline -> two-byte `\n` literal.
+#
+if (%sql.escape(%taint("a\nb")) != "a\\nb") {
+       test_fail
+}
+
+#
+#  Regression: embedded NUL byte must round-trip as `\0`.
+#  Before the byte-length fix the driver called strlen() on the input
+#  and the NUL truncated the rest of the value.
+#
+if (%sql.escape(%taint("a\000b")) != "a\\0b") {
+       test_fail
+}
+
+test_pass
diff --git a/src/tests/modules/sql_postgresql/escape.attrs b/src/tests/modules/sql_postgresql/escape.attrs
new file mode 100644 (file)
index 0000000..97fa579
--- /dev/null
@@ -0,0 +1,11 @@
+#
+#  Input packet
+#
+Packet-Type = Access-Request
+User-Name = "escape_postgresql"
+NAS-IP-Address = "1.2.3.4"
+
+#
+#  Expected answer
+#
+Packet-Type == Access-Accept
diff --git a/src/tests/modules/sql_postgresql/escape.unlang b/src/tests/modules/sql_postgresql/escape.unlang
new file mode 100644 (file)
index 0000000..fa9862f
--- /dev/null
@@ -0,0 +1,46 @@
+#
+#  PostgreSQL escape function (via PQescapeStringConn).
+#
+#  With standard_conforming_strings = on (the default since 9.1)
+#  PQescapeStringConn only doubles single quotes.  Backslashes and
+#  other byte values pass through unchanged - the SQL string literal
+#  grammar treats them as literal characters.
+#
+#  String literals in unlang are marked "safe for any escape" so we wrap
+#  each input in %taint(...) to force the escape to actually run.
+#
+
+#
+#  Plain ASCII passes through unchanged.
+#
+if (%sql.escape(%taint("Hello world")) != "Hello world") {
+       test_fail
+}
+
+#
+#  Single quote is doubled.
+#
+if (%sql.escape(%taint("it's")) != "it''s") {
+       test_fail
+}
+
+if (%sql.escape(%taint("''")) != "''''") {
+       test_fail
+}
+
+#
+#  Backslash passes through (assumes standard_conforming_strings = on
+#  on the test server, which is the default in modern PostgreSQL).
+#
+if (%sql.escape(%taint("back\\slash")) != "back\\slash") {
+       test_fail
+}
+
+#
+#  Multi-byte UTF-8 is passed through unchanged.
+#
+if (%sql.escape(%taint("café")) != "café") {
+       test_fail
+}
+
+test_pass
diff --git a/src/tests/modules/sql_sqlite/escape.attrs b/src/tests/modules/sql_sqlite/escape.attrs
new file mode 120000 (symlink)
index 0000000..dc1637f
--- /dev/null
@@ -0,0 +1 @@
+../sql/escape.attrs
\ No newline at end of file
diff --git a/src/tests/modules/sql_sqlite/escape.unlang b/src/tests/modules/sql_sqlite/escape.unlang
new file mode 120000 (symlink)
index 0000000..748679b
--- /dev/null
@@ -0,0 +1 @@
+../sql/escape.unlang
\ No newline at end of file