From 2552801d2c9f6b906cb8f13f2f5061de4383476b Mon Sep 17 00:00:00 2001 From: Tobias Petersen Date: Fri, 13 Sep 2024 14:34:33 -0400 Subject: [PATCH] Merge url query args to opts in mariadbconnector like mysqldb Fixed issue in mariadbconnector dialect where query string arguments that weren't checked integer or boolean arguments would be ignored, such as string arguments like ``unix_socket``, etc. As part of this change, the argument parsing for particular elements such as ``client_flags``, ``compress``, ``local_infile`` has been made more consistent across all MySQL / MariaDB dialect which accept each argument. Pull request courtesy Tobias Alex-Petersen. Fixes: #11870 Closes: #11869 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11869 Pull-request-sha: 8fdcabc83b548e3fc19aa1625035d43ebc0e1875 Change-Id: I3a11a0e65e118c94928027478409488b0d5e94f8 (cherry picked from commit 5e16d25cc7c32e6cfaea44ceec5a2730d766952c) --- doc/build/changelog/unreleased_20/11870.rst | 12 ++++++ .../dialects/mysql/mariadbconnector.py | 2 + .../dialects/mysql/mysqlconnector.py | 1 + lib/sqlalchemy/dialects/mysql/mysqldb.py | 2 +- lib/sqlalchemy/dialects/mysql/provision.py | 3 ++ test/dialect/mysql/test_dialect.py | 39 ++++++++++++++----- 6 files changed, 49 insertions(+), 10 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/11870.rst diff --git a/doc/build/changelog/unreleased_20/11870.rst b/doc/build/changelog/unreleased_20/11870.rst new file mode 100644 index 0000000000..9625a20f8c --- /dev/null +++ b/doc/build/changelog/unreleased_20/11870.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, mysql + :tickets: 11870 + + Fixed issue in mariadbconnector dialect where query string arguments that + weren't checked integer or boolean arguments would be ignored, such as + string arguments like ``unix_socket``, etc. As part of this change, the + argument parsing for particular elements such as ``client_flags``, + ``compress``, ``local_infile`` has been made more consistent across all + MySQL / MariaDB dialect which accept each argument. Pull request courtesy + Tobias Alex-Petersen. + diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py index 9bb3fa4d75..1730c1a6f2 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -165,6 +165,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): def create_connect_args(self, url): opts = url.translate_connect_args() + opts.update(url.query) int_params = [ "connect_timeout", @@ -179,6 +180,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): "ssl_verify_cert", "ssl", "pool_reset_connection", + "compress", ] for key in int_params: diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index b1523392d8..8f4b417418 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -96,6 +96,7 @@ class MySQLDialect_mysqlconnector(MySQLDialect): util.coerce_kw_type(opts, "allow_local_infile", bool) util.coerce_kw_type(opts, "autocommit", bool) util.coerce_kw_type(opts, "buffered", bool) + util.coerce_kw_type(opts, "client_flag", int) util.coerce_kw_type(opts, "compress", bool) util.coerce_kw_type(opts, "connection_timeout", int) util.coerce_kw_type(opts, "connect_timeout", int) diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 0c632b66f3..0baf10f705 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -212,7 +212,7 @@ class MySQLDialect_mysqldb(MySQLDialect): util.coerce_kw_type(opts, "read_timeout", int) util.coerce_kw_type(opts, "write_timeout", int) util.coerce_kw_type(opts, "client_flag", int) - util.coerce_kw_type(opts, "local_infile", int) + util.coerce_kw_type(opts, "local_infile", bool) # Note: using either of the below will cause all strings to be # returned as Unicode, both in raw SQL operations and with column # types like String and MSString. diff --git a/lib/sqlalchemy/dialects/mysql/provision.py b/lib/sqlalchemy/dialects/mysql/provision.py index 3f05bcee74..836ffa1df4 100644 --- a/lib/sqlalchemy/dialects/mysql/provision.py +++ b/lib/sqlalchemy/dialects/mysql/provision.py @@ -40,6 +40,9 @@ def generate_driver_url(url, driver, query_str): drivername="%s+%s" % (backend, driver) ).update_query_string(query_str) + if driver == "mariadbconnector": + new_url = new_url.difference_update_query(["charset"]) + try: new_url.get_dialect() except exc.NoSuchModuleError: diff --git a/test/dialect/mysql/test_dialect.py b/test/dialect/mysql/test_dialect.py index c50755df41..cf74f17ad6 100644 --- a/test/dialect/mysql/test_dialect.py +++ b/test/dialect/mysql/test_dialect.py @@ -257,21 +257,40 @@ class DialectTest(fixtures.TestBase): ("read_timeout", 30), ("write_timeout", 30), ("client_flag", 1234), - ("local_infile", 1234), + ("local_infile", 1), + ("local_infile", True), + ("local_infile", False), ("use_unicode", False), ("charset", "hello"), + ("unix_socket", "somesocket"), + argnames="kwarg, value", ) - def test_normal_arguments_mysqldb(self, kwarg, value): - from sqlalchemy.dialects.mysql import mysqldb + @testing.combinations( + ("mysql+mysqldb", ()), + ("mysql+mariadbconnector", {"use_unicode", "charset"}), + ("mariadb+mariadbconnector", {"use_unicode", "charset"}), + ("mysql+pymysql", ()), + ( + "mysql+mysqlconnector", + {"read_timeout", "write_timeout", "local_infile"}, + ), + argnames="dialect_name,skip", + ) + def test_query_arguments(self, kwarg, value, dialect_name, skip): - dialect = mysqldb.dialect() - connect_args = dialect.create_connect_args( - make_url( - "mysql+mysqldb://scott:tiger@localhost:3306/test" - "?%s=%s" % (kwarg, value) - ) + if kwarg in skip: + return + + url_value = {True: "true", False: "false"}.get(value, value) + + url = make_url( + f"{dialect_name}://scott:tiger@" + f"localhost:3306/test?{kwarg}={url_value}" ) + dialect = url.get_dialect()() + + connect_args = dialect.create_connect_args(url) eq_(connect_args[1][kwarg], value) def test_mysqlconnector_buffered_arg(self): @@ -320,8 +339,10 @@ class DialectTest(fixtures.TestBase): [ "mysql+mysqldb", "mysql+pymysql", + "mysql+mariadbconnector", "mariadb+mysqldb", "mariadb+pymysql", + "mariadb+mariadbconnector", ] ) def test_random_arg(self): -- 2.47.3