]> git.ipfire.org Git - thirdparty/freeradius-server.git/commitdiff
Reinstate running open_query for MySQL
authorNick Porter <nick@portercomputing.co.uk>
Wed, 26 Jun 2024 12:17:15 +0000 (13:17 +0100)
committerNick Porter <nick@portercomputing.co.uk>
Wed, 26 Jun 2024 12:17:15 +0000 (13:17 +0100)
src/modules/rlm_sql/drivers/rlm_sql_mysql/rlm_sql_mysql.c

index b92588fc0f0c35c90f27273088bc3fc8f0710b8b..4043b4070f030956dbd9e6236ae01db34b3d3ba2 100644 (file)
@@ -217,6 +217,37 @@ connected:
        connection_signal_connected(c->conn);
 }
 
+static void _sql_connect_query_run(connection_t *conn, UNUSED connection_state_t prev,
+                                  UNUSED connection_state_t state, void *uctx)
+{
+       rlm_sql_t const         *sql = talloc_get_type_abort_const(uctx, rlm_sql_t);
+       rlm_sql_mysql_conn_t    *sql_conn = talloc_get_type_abort(conn->h, rlm_sql_mysql_conn_t);
+       int                     ret;
+       MYSQL_RES               *result;
+
+       DEBUG2("Executing \"%s\" on connection %s", sql->config.connect_query, conn->name);
+
+       ret = mysql_real_query(sql_conn->sock, sql->config.connect_query, strlen(sql->config.connect_query));
+       if (ret != 0) {
+               char const *info;
+               ERROR("Failed running \"open_query\"");
+               info = mysql_info(sql_conn->sock);
+               if (info) ERROR("%s", info);
+               connection_signal_reconnect(conn, CONNECTION_FAILED);
+               return;
+       }
+
+       /*
+        *      These queries should not return any results - but let's be safe
+        */
+       result = mysql_store_result(sql_conn->sock);
+       if (result) mysql_free_result(result);
+       while ((mysql_next_result(sql_conn->sock) == 0) &&
+              (result = mysql_store_result(sql_conn->sock))) {
+               mysql_free_result(result);
+       }
+}
+
 static connection_state_t _sql_connection_init(void **h, connection_t *conn, void *uctx)
 {
        rlm_sql_t const         *sql = talloc_get_type_abort_const(uctx, rlm_sql_t);
@@ -333,6 +364,9 @@ static connection_state_t _sql_connection_init(void **h, connection_t *conn, voi
 
        *h = c;
 
+       if (config->connect_query) connection_add_watch_post(conn, CONNECTION_STATE_CONNECTED,
+                                                            _sql_connect_query_run, true, sql);
+
        return CONNECTION_STATE_CONNECTING;
 }