]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Merge url query args to opts in mariadbconnector like mysqldb
authorTobias Petersen <tobias.petersen@mikrodust.com>
Fri, 13 Sep 2024 18:34:33 +0000 (14:34 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Sep 2024 17:07:29 +0000 (13:07 -0400)
Fixed issue in mariadbconnector dialect where query string arguments that
weren't checked integer or boolean arguments would be ignored, such as
string arguments like ``unix_socket``, etc.  As part of this change, the
argument parsing for particular elements such as ``client_flags``,
``compress``, ``local_infile`` has been made more consistent across all
MySQL / MariaDB dialect which accept each argument. Pull request courtesy
Tobias Alex-Petersen.

Fixes: #11870
Closes: #11869
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11869
Pull-request-sha: 8fdcabc83b548e3fc19aa1625035d43ebc0e1875

Change-Id: I3a11a0e65e118c94928027478409488b0d5e94f8

doc/build/changelog/unreleased_20/11870.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/mariadbconnector.py
lib/sqlalchemy/dialects/mysql/mysqlconnector.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/provision.py
test/dialect/mysql/test_dialect.py

diff --git a/doc/build/changelog/unreleased_20/11870.rst b/doc/build/changelog/unreleased_20/11870.rst
new file mode 100644 (file)
index 0000000..9625a20
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, mysql
+    :tickets: 11870
+
+    Fixed issue in mariadbconnector dialect where query string arguments that
+    weren't checked integer or boolean arguments would be ignored, such as
+    string arguments like ``unix_socket``, etc.  As part of this change, the
+    argument parsing for particular elements such as ``client_flags``,
+    ``compress``, ``local_infile`` has been made more consistent across all
+    MySQL / MariaDB dialect which accept each argument. Pull request courtesy
+    Tobias Alex-Petersen.
+
index c33ccd3b9332e9e7e7176e81580dddad313c4c11..361cf6ec4083ccadb2a29a8051dcf0947a736b95 100644 (file)
@@ -166,6 +166,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
 
     def create_connect_args(self, url):
         opts = url.translate_connect_args()
+        opts.update(url.query)
 
         int_params = [
             "connect_timeout",
@@ -180,6 +181,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
             "ssl_verify_cert",
             "ssl",
             "pool_reset_connection",
+            "compress",
         ]
 
         for key in int_params:
index 8a6c2da8b4f304a45230ee3ccb3d74d73ca7c741..edc63fe38657c023c379135448254d54a102cd8c 100644 (file)
@@ -97,6 +97,7 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
         util.coerce_kw_type(opts, "allow_local_infile", bool)
         util.coerce_kw_type(opts, "autocommit", bool)
         util.coerce_kw_type(opts, "buffered", bool)
+        util.coerce_kw_type(opts, "client_flag", int)
         util.coerce_kw_type(opts, "compress", bool)
         util.coerce_kw_type(opts, "connection_timeout", int)
         util.coerce_kw_type(opts, "connect_timeout", int)
index 0c632b66f3eff87cea6ae1b36ad9d8dff2f84ddf..0baf10f7056ea133e352b6d6a5fa712eb011bda2 100644 (file)
@@ -212,7 +212,7 @@ class MySQLDialect_mysqldb(MySQLDialect):
         util.coerce_kw_type(opts, "read_timeout", int)
         util.coerce_kw_type(opts, "write_timeout", int)
         util.coerce_kw_type(opts, "client_flag", int)
-        util.coerce_kw_type(opts, "local_infile", int)
+        util.coerce_kw_type(opts, "local_infile", bool)
         # Note: using either of the below will cause all strings to be
         # returned as Unicode, both in raw SQL operations and with column
         # types like String and MSString.
index 3f05bcee74d98ef49d0168b3a0603486d2e67f0f..836ffa1df43133c99d693b673d8b249d6166fd15 100644 (file)
@@ -40,6 +40,9 @@ def generate_driver_url(url, driver, query_str):
         drivername="%s+%s" % (backend, driver)
     ).update_query_string(query_str)
 
+    if driver == "mariadbconnector":
+        new_url = new_url.difference_update_query(["charset"])
+
     try:
         new_url.get_dialect()
     except exc.NoSuchModuleError:
index c50755df41430262460b1b189ded654a846f474a..cf74f17ad669ee2bb8ce73f70d3e56e288e15cc8 100644 (file)
@@ -257,21 +257,40 @@ class DialectTest(fixtures.TestBase):
         ("read_timeout", 30),
         ("write_timeout", 30),
         ("client_flag", 1234),
-        ("local_infile", 1234),
+        ("local_infile", 1),
+        ("local_infile", True),
+        ("local_infile", False),
         ("use_unicode", False),
         ("charset", "hello"),
+        ("unix_socket", "somesocket"),
+        argnames="kwarg, value",
     )
-    def test_normal_arguments_mysqldb(self, kwarg, value):
-        from sqlalchemy.dialects.mysql import mysqldb
+    @testing.combinations(
+        ("mysql+mysqldb", ()),
+        ("mysql+mariadbconnector", {"use_unicode", "charset"}),
+        ("mariadb+mariadbconnector", {"use_unicode", "charset"}),
+        ("mysql+pymysql", ()),
+        (
+            "mysql+mysqlconnector",
+            {"read_timeout", "write_timeout", "local_infile"},
+        ),
+        argnames="dialect_name,skip",
+    )
+    def test_query_arguments(self, kwarg, value, dialect_name, skip):
 
-        dialect = mysqldb.dialect()
-        connect_args = dialect.create_connect_args(
-            make_url(
-                "mysql+mysqldb://scott:tiger@localhost:3306/test"
-                "?%s=%s" % (kwarg, value)
-            )
+        if kwarg in skip:
+            return
+
+        url_value = {True: "true", False: "false"}.get(value, value)
+
+        url = make_url(
+            f"{dialect_name}://scott:tiger@"
+            f"localhost:3306/test?{kwarg}={url_value}"
         )
 
+        dialect = url.get_dialect()()
+
+        connect_args = dialect.create_connect_args(url)
         eq_(connect_args[1][kwarg], value)
 
     def test_mysqlconnector_buffered_arg(self):
@@ -320,8 +339,10 @@ class DialectTest(fixtures.TestBase):
         [
             "mysql+mysqldb",
             "mysql+pymysql",
+            "mysql+mariadbconnector",
             "mariadb+mysqldb",
             "mariadb+pymysql",
+            "mariadb+mariadbconnector",
         ]
     )
     def test_random_arg(self):