]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Mysql ddl compiler fall back to default index args
authorTiansu Yu <tiansu.yu@icloud.com>
Fri, 20 Feb 2026 14:20:40 +0000 (09:20 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Feb 2026 18:21:47 +0000 (13:21 -0500)
Fixed issue where DDL compilation options were registered to the hard-coded
dialect name ``mysql``. This made it awkward for MySQL-derived dialects
like MariaDB, StarRocks, etc. to work with such options when different sets
of options exist for different platforms. Options are now registered under
the actual dialect name, and a fallback was added to help avoid errors when
an option does not exist for that dialect. Pull request courtesy Tiansu Yu.

Fixes: #13134
Closes: #13138
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13138
Pull-request-sha: 1bc953a2a1be97f82cdbbbc0d8961361716190fa

Change-Id: Ifa700a4e34da4d1923e9473dd8f0d2417dcfded4

doc/build/changelog/unreleased_20/13134.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/testing/fixtures/base.py
test/dialect/mysql/test_compiler.py
test/sql/test_metadata.py

diff --git a/doc/build/changelog/unreleased_20/13134.rst b/doc/build/changelog/unreleased_20/13134.rst
new file mode 100644 (file)
index 0000000..56e1fe8
--- /dev/null
@@ -0,0 +1,20 @@
+.. change::
+    :tags: bug, mysql
+    :tickets: 13134
+
+    Fixed issue where DDL compilation options were registered to the hard-coded
+    dialect name ``mysql``. This made it awkward for MySQL-derived dialects
+    like MariaDB, StarRocks, etc. to work with such options when different sets
+    of options exist for different platforms. Options are now registered under
+    the actual dialect name, and a fallback was added to help avoid errors when
+    an option does not exist for that dialect.
+
+    To maintain backwards compatibility, when using the MariaDB dialect with
+    the options ``mysql_with_parser`` or ``mysql_using`` without also specifying
+    the corresponding ``mariadb_`` prefixed options, a deprecation warning will
+    be emitted. The ``mysql_`` prefixed options will continue to work during
+    the deprecation period. Users should update their code to additionally
+    specify ``mariadb_with_parser`` and ``mariadb_using`` when using the
+    ``mariadb://`` dialect, or specify both options to support both dialects.
+
+    Pull request courtesy Tiansu Yu.
index 941aee0c490aba2708706e3193986b9041f86f6e..75ec79baac18f44caefb9e019ff3ef7f4da24165 100644 (file)
@@ -2215,7 +2215,7 @@ class MySQLDDLCompiler(
         if index.unique:
             text += "UNIQUE "
 
-        index_prefix = index.kwargs.get("%s_prefix" % self.dialect.name, None)
+        index_prefix = index.get_dialect_option(self.dialect, "prefix")
         if index_prefix:
             text += index_prefix + " "
 
@@ -2224,7 +2224,7 @@ class MySQLDDLCompiler(
             text += "IF NOT EXISTS "
         text += "%s ON %s " % (name, table)
 
-        length = index.dialect_options[self.dialect.name]["length"]
+        length = index.get_dialect_option(self.dialect, "length")
         if length is not None:
             if isinstance(length, dict):
                 # length value can be a (column_name --> integer value)
@@ -2252,11 +2252,15 @@ class MySQLDDLCompiler(
             columns_str = ", ".join(columns)
         text += "(%s)" % columns_str
 
-        parser = index.dialect_options["mysql"]["with_parser"]
+        parser = index.get_dialect_option(
+            self.dialect, "with_parser", deprecated_fallback="mysql"
+        )
         if parser is not None:
             text += " WITH PARSER %s" % (parser,)
 
-        using = index.dialect_options["mysql"]["using"]
+        using = index.get_dialect_option(
+            self.dialect, "using", deprecated_fallback="mysql"
+        )
         if using is not None:
             text += " USING %s" % (preparer.quote(using))
 
@@ -2266,7 +2270,9 @@ class MySQLDDLCompiler(
         self, constraint: sa_schema.PrimaryKeyConstraint, **kw: Any
     ) -> str:
         text = super().visit_primary_key_constraint(constraint)
-        using = constraint.dialect_options["mysql"]["using"]
+        using = constraint.get_dialect_option(
+            self.dialect, "using", deprecated_fallback="mysql"
+        )
         if using:
             text += " USING %s" % (self.preparer.quote(using))
         return text
index 47fefe239629c074f688af99a46fcefe1b138a7f..9652616cb92e9256ec3190b7347523d6699469df 100644 (file)
@@ -60,6 +60,7 @@ from .. import util
 from ..util import EMPTY_DICT
 from ..util import HasMemoized as HasMemoized
 from ..util import hybridmethod
+from ..util import warn_deprecated
 from ..util.typing import Self
 from ..util.typing import TypeVarTuple
 from ..util.typing import Unpack
@@ -495,6 +496,79 @@ class DialectKWArgs:
         ("dialect_options", InternalTraversal.dp_dialect_options)
     ]
 
+    def get_dialect_option(
+        self,
+        dialect: Dialect,
+        argument_name: str,
+        *,
+        else_: Any = None,
+        deprecated_fallback: Optional[str] = None,
+    ) -> Any:
+        r"""Return the value of a dialect-specific option, or *else_* if
+        this dialect does not register the given argument.
+
+        This is useful for DDL compilers that may be inherited by
+        third-party dialects whose ``construct_arguments`` do not
+        include the same set of keys as the parent dialect.
+
+        :param dialect: The dialect for which to retrieve the option.
+        :param argument_name: The name of the argument to retrieve.
+        :param else\_: The value to return if the argument is not present.
+        :param deprecated_fallback: Optional dialect name to fall back to
+         if the argument is not present for the current dialect. If the
+         argument is present for the fallback dialect but not the current
+         dialect, a deprecation warning will be emitted.
+
+        """
+
+        registry = DialectKWArgs._kw_registry[dialect.name]
+        if registry is None:
+            return else_
+
+        if argument_name in registry.get(self.__class__, {}):
+            if (
+                deprecated_fallback is None
+                or dialect.name == deprecated_fallback
+            ):
+                return self.dialect_options[dialect.name][argument_name]
+
+            # deprecated_fallback is present; need to look in two places
+
+            # Current dialect has this option registered.
+            # Check if user explicitly set it.
+            if (
+                dialect.name in self.dialect_options
+                and argument_name
+                in self.dialect_options[dialect.name]._non_defaults
+            ):
+                # User explicitly set this dialect's option - use it
+                return self.dialect_options[dialect.name][argument_name]
+
+            # User didn't set current dialect's option.
+            # Check for deprecated fallback.
+            elif (
+                deprecated_fallback in self.dialect_options
+                and argument_name
+                in self.dialect_options[deprecated_fallback]._non_defaults
+            ):
+                # User set fallback option but not current dialect's option
+                warn_deprecated(
+                    f"Using '{deprecated_fallback}_{argument_name}' "
+                    f"with the '{dialect.name}' dialect is deprecated; "
+                    f"please additionally specify "
+                    f"'{dialect.name}_{argument_name}'.",
+                    version="2.1",
+                )
+                return self.dialect_options[deprecated_fallback][argument_name]
+
+            # Return default value
+            return self.dialect_options[dialect.name][argument_name]
+        else:
+            # Current dialect doesn't have the option registered at all.
+            # Don't warn - if a third-party dialect doesn't support an
+            # option, that's their choice, not a deprecation case.
+            return else_
+
     @classmethod
     def argument_for(
         cls, dialect_name: str, argument_name: str, default: Any
index 1e770cc37c49081ddd1fa3174143fc12026e282e..4a31e5a9c2f2909f8840a0a75a3f216513642b96 100644 (file)
@@ -218,6 +218,24 @@ class TestBase:
         else:
             drop_all_tables_from_metadata(metadata, config.db)
 
+    @config.fixture()
+    def thirdparty_dialect(self):
+        from ...dialects import registry
+
+        name = None
+
+        def go(dialect_cls):
+            nonlocal name
+            name = dialect_cls.name
+            assert name, "name is required"
+            registry.impls[name] = dialect_cls
+            return dialect_cls
+
+        yield go
+
+        assert name is not None
+        del registry.impls[name]
+
     @config.fixture(
         params=[
             (rollback, second_operation, begin_nested)
index 3cfa35c1dd131f1dabff13196f8682933a2770e8..c32d3f601dd46a41b2633ad9de22e96ac09c46bc 100644 (file)
@@ -1,3 +1,4 @@
+import contextlib
 import random
 
 from sqlalchemy import BLOB
@@ -72,6 +73,7 @@ from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import eq_ignore_whitespace
+from sqlalchemy.testing import expect_deprecated
 from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
@@ -186,16 +188,45 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
             schema.CreateIndex(idx), "CREATE INDEX test_idx1 ON testtbl (data)"
         )
 
-    def test_create_index_with_prefix(self):
+    def test_create_index_delegates_to_dialect_option(
+        self, thirdparty_dialect
+    ):
+        """test the original 3rd party dialect issue in #13134, which is
+        that visit_create_index() doesn't assume kw like "length" are
+        present"""
+
+        @thirdparty_dialect
+        class ThirdPartyDialect(mysql.MySQLDialect):
+            name = "thirdparty"
+            construct_arguments = []
+
+        m = MetaData()
+        tbl = Table("testtbl", m, Column("data", String(255)))
+        idx = Index("test_idx1", tbl.c.data)
+
+        self.assert_compile(
+            schema.CreateIndex(idx),
+            "CREATE INDEX test_idx1 ON testtbl (data)",
+            dialect=ThirdPartyDialect(),
+        )
+
+    @testing.combinations("mysql", "mariadb", argnames="dialect_name")
+    def test_create_index_with_prefix(self, dialect_name):
         m = MetaData()
         tbl = Table("testtbl", m, Column("data", String(255)))
         idx = Index(
-            "test_idx1", tbl.c.data, mysql_length=10, mysql_prefix="FULLTEXT"
+            "test_idx1",
+            tbl.c.data,
+            **{
+                f"{dialect_name}_length": 10,
+                f"{dialect_name}_prefix": "FULLTEXT",
+            },
         )
 
         self.assert_compile(
             schema.CreateIndex(idx),
             "CREATE FULLTEXT INDEX test_idx1 ON testtbl (data(10))",
+            dialect=dialect_name,
         )
 
     def test_create_index_with_text(self):
@@ -237,36 +268,43 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
             "CREATE INDEX test_idx1 ON testtbl (created_at desc)",
         )
 
-    def test_create_index_with_parser(self):
+    @testing.combinations("mysql", "mariadb", argnames="dialect_name")
+    def test_create_index_with_parser(self, dialect_name):
         m = MetaData()
         tbl = Table("testtbl", m, Column("data", String(255)))
         idx = Index(
             "test_idx1",
             tbl.c.data,
-            mysql_length=10,
-            mysql_prefix="FULLTEXT",
-            mysql_with_parser="ngram",
+            **{
+                f"{dialect_name}_length": 10,
+                f"{dialect_name}_prefix": "FULLTEXT",
+                f"{dialect_name}_with_parser": "ngram",
+            },
         )
 
         self.assert_compile(
             schema.CreateIndex(idx),
             "CREATE FULLTEXT INDEX test_idx1 "
             "ON testtbl (data(10)) WITH PARSER ngram",
+            dialect=dialect_name,
         )
 
-    def test_create_index_with_length(self):
+    @testing.combinations("mysql", "mariadb", argnames="dialect_name")
+    def test_create_index_with_length(self, dialect_name):
         m = MetaData()
         tbl = Table("testtbl", m, Column("data", String(255)))
-        idx1 = Index("test_idx1", tbl.c.data, mysql_length=10)
-        idx2 = Index("test_idx2", tbl.c.data, mysql_length=5)
+        idx1 = Index("test_idx1", tbl.c.data, **{f"{dialect_name}_length": 10})
+        idx2 = Index("test_idx2", tbl.c.data, **{f"{dialect_name}_length": 5})
 
         self.assert_compile(
             schema.CreateIndex(idx1),
             "CREATE INDEX test_idx1 ON testtbl (data(10))",
+            dialect=dialect_name,
         )
         self.assert_compile(
             schema.CreateIndex(idx2),
             "CREATE INDEX test_idx2 ON testtbl (data(5))",
+            dialect=dialect_name,
         )
 
     def test_drop_constraint_mysql(self):
@@ -374,19 +412,26 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
             "CREATE INDEX test_idx3 ON testtbl (a(30), b(30))",
         )
 
-    def test_create_index_with_using(self):
+    @testing.combinations("mysql", "mariadb", argnames="dialect_name")
+    def test_create_index_with_using(self, dialect_name):
         m = MetaData()
         tbl = Table("testtbl", m, Column("data", String(255)))
-        idx1 = Index("test_idx1", tbl.c.data, mysql_using="btree")
-        idx2 = Index("test_idx2", tbl.c.data, mysql_using="hash")
+        idx1 = Index(
+            "test_idx1", tbl.c.data, **{f"{dialect_name}_using": "btree"}
+        )
+        idx2 = Index(
+            "test_idx2", tbl.c.data, **{f"{dialect_name}_using": "hash"}
+        )
 
         self.assert_compile(
             schema.CreateIndex(idx1),
             "CREATE INDEX test_idx1 ON testtbl (data) USING btree",
+            dialect=dialect_name,
         )
         self.assert_compile(
             schema.CreateIndex(idx2),
             "CREATE INDEX test_idx2 ON testtbl (data) USING hash",
+            dialect=dialect_name,
         )
 
     def test_create_pk_plain(self):
@@ -404,21 +449,86 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
             "PRIMARY KEY (data))",
         )
 
-    def test_create_pk_with_using(self):
+    @testing.combinations("mysql", "mariadb", argnames="dialect_name")
+    def test_create_pk_with_using(self, dialect_name):
         m = MetaData()
         tbl = Table(
             "testtbl",
             m,
             Column("data", String(255)),
-            PrimaryKeyConstraint("data", mysql_using="btree"),
+            PrimaryKeyConstraint("data", **{f"{dialect_name}_using": "btree"}),
         )
 
         self.assert_compile(
             schema.CreateTable(tbl),
             "CREATE TABLE testtbl (data VARCHAR(255) NOT NULL, "
             "PRIMARY KEY (data) USING btree)",
+            dialect=dialect_name,
+        )
+
+    @testing.combinations(
+        ("with_parser", "ngram", "WITH PARSER ngram"),
+        ("using", "btree", "USING btree"),
+        argnames="paramname, value, expected",
+    )
+    @testing.variation("use_deprecated", [True, False])
+    def test_create_index_mysql_option_mariadb_deprecated(
+        self, paramname, value, expected, use_deprecated
+    ):
+        """Test that mysql_with_parser emits deprecation with mariadb
+        dialect"""
+        m = MetaData()
+        tbl = Table("testtbl", m, Column("data", String(255)))
+        idx = Index(
+            "test_idx1",
+            tbl.c.data,
+            mariadb_length=10,
+            mariadb_prefix="FULLTEXT",
+            **{
+                f"{'mysql' if use_deprecated else 'mariadb'}"
+                f"_{paramname}": value
+            },
         )
 
+        if use_deprecated:
+            expect = expect_deprecated(
+                f"Using 'mysql_{paramname}' with the 'mariadb' dialect is "
+                f"deprecated; please additionally specify "
+                f"'mariadb_{paramname}'.",
+            )
+        else:
+            expect = contextlib.nullcontext()
+
+        with expect:
+            self.assert_compile(
+                schema.CreateIndex(idx),
+                "CREATE FULLTEXT INDEX test_idx1 "
+                f"ON testtbl (data(10)) {expected}",
+                dialect="mariadb",
+            )
+
+    def test_create_pk_with_using_mysql_option_mariadb_deprecated(self):
+        """Test that mysql_using for PK emits deprecation with mariadb
+        dialect"""
+        m = MetaData()
+        tbl = Table(
+            "testtbl",
+            m,
+            Column("data", String(255)),
+            PrimaryKeyConstraint("data", mysql_using="btree"),
+        )
+
+        with expect_deprecated(
+            "Using 'mysql_using' with the 'mariadb' dialect is deprecated; "
+            "please additionally specify 'mariadb_using'."
+        ):
+            self.assert_compile(
+                schema.CreateTable(tbl),
+                "CREATE TABLE testtbl (data VARCHAR(255) NOT NULL, "
+                "PRIMARY KEY (data) USING btree)",
+                dialect="mariadb",
+            )
+
     @testing.combinations(
         (True, True, (10, 2, 2)),
         (True, True, (10, 2, 1)),
index 09d84683cdfeccdefd03f7d76e3f61ce6929b150..a44fd07dfa853c2f21387df08fa063d6e1665271 100644 (file)
@@ -5724,6 +5724,46 @@ class DialectKWArgTest(fixtures.TestBase):
                 5,
             )
 
+    def test_get_dialect_option_participating(self):
+        with self._fixture():
+            idx = Index("a", "b", "c", participating_x=7)
+            dialect = mock.Mock()
+            dialect.name = "participating"
+            eq_(idx.get_dialect_option(dialect, "x"), 7)
+            eq_(idx.get_dialect_option(dialect, "y"), False)
+
+    def test_get_dialect_option_participating_default(self):
+        with self._fixture():
+            idx = Index("a", "b", "c")
+            dialect = mock.Mock()
+            dialect.name = "participating"
+            eq_(idx.get_dialect_option(dialect, "x"), 5)
+            eq_(idx.get_dialect_option(dialect, "z_one"), None)
+
+    def test_get_dialect_option_participating_unregistered_arg(self):
+        with self._fixture():
+            idx = Index("a", "b", "c", participating_x=7)
+            dialect = mock.Mock()
+            dialect.name = "participating"
+            eq_(idx.get_dialect_option(dialect, "nonexistent"), None)
+            eq_(
+                idx.get_dialect_option(
+                    dialect, "nonexistent", else_="fallback"
+                ),
+                "fallback",
+            )
+
+    def test_get_dialect_option_nonparticipating(self):
+        with self._fixture():
+            idx = Index("a", "b", "c")
+            dialect = mock.Mock()
+            dialect.name = "nonparticipating"
+            eq_(idx.get_dialect_option(dialect, "x"), None)
+            eq_(
+                idx.get_dialect_option(dialect, "x", else_="fallback"),
+                "fallback",
+            )
+
 
 class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"