From 5e16d25cc7c32e6cfaea44ceec5a2730d766952c Mon Sep 17 00:00:00 2001 From: Tobias Petersen Date: Fri, 13 Sep 2024 14:34:33 -0400 Subject: [PATCH] Merge url query args to opts in mariadbconnector like mysqldb Fixed issue in mariadbconnector dialect where query string arguments that weren't checked integer or boolean arguments would be ignored, such as string arguments like ``unix_socket``, etc. As part of this change, the argument parsing for particular elements such as ``client_flags``, ``compress``, ``local_infile`` has been made more consistent across all MySQL / MariaDB dialect which accept each argument. Pull request courtesy Tobias Alex-Petersen. Fixes: #11870 Closes: #11869 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11869 Pull-request-sha: 8fdcabc83b548e3fc19aa1625035d43ebc0e1875 Change-Id: I3a11a0e65e118c94928027478409488b0d5e94f8 --- doc/build/changelog/unreleased_20/11870.rst | 12 ++++++ .../dialects/mysql/mariadbconnector.py | 2 + .../dialects/mysql/mysqlconnector.py | 1 + lib/sqlalchemy/dialects/mysql/mysqldb.py | 2 +- lib/sqlalchemy/dialects/mysql/provision.py | 3 ++ test/dialect/mysql/test_dialect.py | 39 ++++++++++++++----- 6 files changed, 49 insertions(+), 10 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/11870.rst diff --git a/doc/build/changelog/unreleased_20/11870.rst b/doc/build/changelog/unreleased_20/11870.rst new file mode 100644 index 0000000000..9625a20f8c --- /dev/null +++ b/doc/build/changelog/unreleased_20/11870.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, mysql + :tickets: 11870 + + Fixed issue in mariadbconnector dialect where query string arguments that + weren't checked integer or boolean arguments would be ignored, such as + string arguments like ``unix_socket``, etc. As part of this change, the + argument parsing for particular elements such as ``client_flags``, + ``compress``, ``local_infile`` has been made more consistent across all + MySQL / MariaDB dialect which accept each argument. Pull request courtesy + Tobias Alex-Petersen. + diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py index c33ccd3b93..361cf6ec40 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -166,6 +166,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): def create_connect_args(self, url): opts = url.translate_connect_args() + opts.update(url.query) int_params = [ "connect_timeout", @@ -180,6 +181,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): "ssl_verify_cert", "ssl", "pool_reset_connection", + "compress", ] for key in int_params: diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index 8a6c2da8b4..edc63fe386 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -97,6 +97,7 @@ class MySQLDialect_mysqlconnector(MySQLDialect): util.coerce_kw_type(opts, "allow_local_infile", bool) util.coerce_kw_type(opts, "autocommit", bool) util.coerce_kw_type(opts, "buffered", bool) + util.coerce_kw_type(opts, "client_flag", int) util.coerce_kw_type(opts, "compress", bool) util.coerce_kw_type(opts, "connection_timeout", int) util.coerce_kw_type(opts, "connect_timeout", int) diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index 0c632b66f3..0baf10f705 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -212,7 +212,7 @@ class MySQLDialect_mysqldb(MySQLDialect): util.coerce_kw_type(opts, "read_timeout", int) util.coerce_kw_type(opts, "write_timeout", int) util.coerce_kw_type(opts, "client_flag", int) - util.coerce_kw_type(opts, "local_infile", int) + util.coerce_kw_type(opts, "local_infile", bool) # Note: using either of the below will cause all strings to be # returned as Unicode, both in raw SQL operations and with column # types like String and MSString. diff --git a/lib/sqlalchemy/dialects/mysql/provision.py b/lib/sqlalchemy/dialects/mysql/provision.py index 3f05bcee74..836ffa1df4 100644 --- a/lib/sqlalchemy/dialects/mysql/provision.py +++ b/lib/sqlalchemy/dialects/mysql/provision.py @@ -40,6 +40,9 @@ def generate_driver_url(url, driver, query_str): drivername="%s+%s" % (backend, driver) ).update_query_string(query_str) + if driver == "mariadbconnector": + new_url = new_url.difference_update_query(["charset"]) + try: new_url.get_dialect() except exc.NoSuchModuleError: diff --git a/test/dialect/mysql/test_dialect.py b/test/dialect/mysql/test_dialect.py index c50755df41..cf74f17ad6 100644 --- a/test/dialect/mysql/test_dialect.py +++ b/test/dialect/mysql/test_dialect.py @@ -257,21 +257,40 @@ class DialectTest(fixtures.TestBase): ("read_timeout", 30), ("write_timeout", 30), ("client_flag", 1234), - ("local_infile", 1234), + ("local_infile", 1), + ("local_infile", True), + ("local_infile", False), ("use_unicode", False), ("charset", "hello"), + ("unix_socket", "somesocket"), + argnames="kwarg, value", ) - def test_normal_arguments_mysqldb(self, kwarg, value): - from sqlalchemy.dialects.mysql import mysqldb + @testing.combinations( + ("mysql+mysqldb", ()), + ("mysql+mariadbconnector", {"use_unicode", "charset"}), + ("mariadb+mariadbconnector", {"use_unicode", "charset"}), + ("mysql+pymysql", ()), + ( + "mysql+mysqlconnector", + {"read_timeout", "write_timeout", "local_infile"}, + ), + argnames="dialect_name,skip", + ) + def test_query_arguments(self, kwarg, value, dialect_name, skip): - dialect = mysqldb.dialect() - connect_args = dialect.create_connect_args( - make_url( - "mysql+mysqldb://scott:tiger@localhost:3306/test" - "?%s=%s" % (kwarg, value) - ) + if kwarg in skip: + return + + url_value = {True: "true", False: "false"}.get(value, value) + + url = make_url( + f"{dialect_name}://scott:tiger@" + f"localhost:3306/test?{kwarg}={url_value}" ) + dialect = url.get_dialect()() + + connect_args = dialect.create_connect_args(url) eq_(connect_args[1][kwarg], value) def test_mysqlconnector_buffered_arg(self): @@ -320,8 +339,10 @@ class DialectTest(fixtures.TestBase): [ "mysql+mysqldb", "mysql+pymysql", + "mysql+mariadbconnector", "mariadb+mysqldb", "mariadb+pymysql", + "mariadb+mariadbconnector", ] ) def test_random_arg(self): -- 2.47.2