]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
separate out mariadb/mysql implementations but remain monolithic
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Jan 2026 16:06:50 +0000 (11:06 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Jan 2026 21:17:32 +0000 (16:17 -0500)
Fixes to the MySQL/MariaDB dialect so that mariadb-specific features such
as the :class:`.mariadb.INET4` and :class:`.mariadb.INET6` datatype may be
used with an :class:`.Engine` that uses a ``mysql://`` URL, if the backend
database is actually a mariadb database.   Previously, support for MariaDB
features when ``mysql://`` URLs were used instead of ``mariadb://`` URLs
was ad-hoc; with this issue resolution, the full set of schema / compiler /
type features are now available regardless of how the URL was presented.

After much discussion it seems premature to formally separate the
mysql/mariadb dialects to be mutually exclusive, however we'd like
to standardize on how dialect-exclusive behaviors are architected.
For now, use an approach of MariaDB "shim" classes which provide
MariaDB behaviors into all relevant MySQL classes up front.  Where
behaviors are mutually exclusive, support `_set_mariadb()` methods
that enable the mariadb version of things.

this approach may or may not have resilience against future
divergences in mysql/mariadb , but at least starts to separate the
source code for the two databases and hopefully provides a clearer
separation of concerns.

Fixes: #13076
Change-Id: I42cffc0563a2c65c38e854e9cc2181353b230c44

13 files changed:
doc/build/changelog/unreleased_21/13076.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/_mariadb_shim.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mariadb.py
lib/sqlalchemy/dialects/mysql/mysqlconnector.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/dialect/mysql/test_compiler.py
test/dialect/mysql/test_dialect.py
test/dialect/mysql/test_query.py
test/dialect/mysql/test_reflection.py
test/dialect/mysql/test_types.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_21/13076.rst b/doc/build/changelog/unreleased_21/13076.rst
new file mode 100644 (file)
index 0000000..9a06919
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, mariadb
+    :tickets: 13076
+
+    Fixes to the MySQL/MariaDB dialect so that mariadb-specific features such
+    as the :class:`.mariadb.INET4` and :class:`.mariadb.INET6` datatype may be
+    used with an :class:`.Engine` that uses a ``mysql://`` URL, if the backend
+    database is actually a mariadb database.   Previously, support for MariaDB
+    features when ``mysql://`` URLs were used instead of ``mariadb://`` URLs
+    was ad-hoc; with this issue resolution, the full set of schema / compiler /
+    type features are now available regardless of how the URL was presented.
diff --git a/lib/sqlalchemy/dialects/mysql/_mariadb_shim.py b/lib/sqlalchemy/dialects/mysql/_mariadb_shim.py
new file mode 100644 (file)
index 0000000..fdf210c
--- /dev/null
@@ -0,0 +1,312 @@
+# dialects/mysql/_mariadb_shim.py
+# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import annotations
+
+from typing import Any
+from typing import cast
+from typing import Optional
+from typing import Type
+from typing import TYPE_CHECKING
+
+from .reserved_words import RESERVED_WORDS_MARIADB
+from ... import exc
+from ... import schema as sa_schema
+from ... import util
+from ...engine import cursor as _cursor
+from ...engine import default
+from ...engine.default import DefaultDialect
+from ...engine.interfaces import TypeCompiler
+from ...sql import elements
+from ...sql import sqltypes
+from ...sql.compiler import DDLCompiler
+from ...sql.compiler import IdentifierPreparer
+from ...sql.compiler import SQLCompiler
+from ...sql.schema import SchemaConst
+from ...sql.sqltypes import _UUID_RETURN
+from ...sql.sqltypes import UUID
+from ...sql.sqltypes import Uuid
+
+if TYPE_CHECKING:
+    from .base import MySQLIdentifierPreparer
+    from .mariadb import INET4
+    from .mariadb import INET6
+    from ...engine import URL
+    from ...engine.base import Connection
+    from ...sql import ddl
+    from ...sql.schema import IdentityOptions
+    from ...sql.schema import Sequence as Sequence_SchemaItem
+    from ...sql.type_api import _BindProcessorType
+
+
+class _MariaDBUUID(UUID[_UUID_RETURN]):
+    def __init__(self, as_uuid: bool = True, native_uuid: bool = True):
+        self.as_uuid = as_uuid
+
+        # the _MariaDBUUID internal type is only invoked for a Uuid() with
+        # native_uuid=True.   for non-native uuid type, the plain Uuid
+        # returns itself due to the workings of the Emulated superclass.
+        assert native_uuid
+
+        # for internal type, force string conversion for result_processor() as
+        # current drivers are returning a string, not a Python UUID object
+        self.native_uuid = False
+
+    @property
+    def native(self) -> bool:  # type: ignore[override]
+        # override to return True, this is a native type, just turning
+        # off native_uuid for internal data handling
+        return True
+
+    def bind_processor(self, dialect: MariaDBShim) -> Optional[_BindProcessorType[_UUID_RETURN]]:  # type: ignore[override] # noqa: E501
+        if not dialect.supports_native_uuid or not dialect._allows_uuid_binds:
+            return super().bind_processor(dialect)  # type: ignore[return-value] # noqa: E501
+        else:
+            return None
+
+
+class MariaDBTypeCompilerShim(TypeCompiler):
+    def visit_INET4(self, type_: INET4, **kwargs: Any) -> str:
+        return "INET4"
+
+    def visit_INET6(self, type_: INET6, **kwargs: Any) -> str:
+        return "INET6"
+
+
+class MariadbExecutionContextShim(default.DefaultExecutionContext):
+    def post_exec(self) -> None:
+        if (
+            self.isdelete
+            and cast(SQLCompiler, self.compiled).effective_returning
+            and not self.cursor.description
+        ):
+            # All MySQL/mariadb drivers appear to not include
+            # cursor.description for DELETE..RETURNING with no rows if the
+            # WHERE criteria is a straight "false" condition such as our EMPTY
+            # IN condition. manufacture an empty result in this case (issue
+            # #10505)
+            #
+            # taken from cx_Oracle implementation
+            self.cursor_fetch_strategy = (
+                _cursor.FullyBufferedCursorFetchStrategy(
+                    self.cursor,
+                    [
+                        (entry.keyname, None)  # type: ignore[misc]
+                        for entry in cast(
+                            SQLCompiler, self.compiled
+                        )._result_columns
+                    ],
+                    [],
+                )
+            )
+
+    def fire_sequence(
+        self, seq: Sequence_SchemaItem, type_: sqltypes.Integer
+    ) -> int:
+        return self._execute_scalar(  # type: ignore[no-any-return]
+            (
+                "select nextval(%s)"
+                % self.identifier_preparer.format_sequence(seq)
+            ),
+            type_,
+        )
+
+
+class MariaDBIdentifierPreparerShim(IdentifierPreparer):
+    def _set_mariadb(self) -> None:
+        self.reserved_words = RESERVED_WORDS_MARIADB
+
+
+class MariaDBSQLCompilerShim(SQLCompiler):
+    def visit_sequence(self, sequence: sa_schema.Sequence, **kw: Any) -> str:
+        return "nextval(%s)" % self.preparer.format_sequence(sequence)
+
+    def _mariadb_regexp_flags(
+        self, flags: str, pattern: elements.ColumnElement[Any], **kw: Any
+    ) -> str:
+        return "CONCAT('(?', %s, ')', %s)" % (
+            self.render_literal_value(flags, sqltypes.STRINGTYPE),
+            self.process(pattern, **kw),
+        )
+
+    def _mariadb_regexp_match(
+        self,
+        op_string: str,
+        binary: elements.BinaryExpression[Any],
+        operator: Any,
+        **kw: Any,
+    ) -> str:
+        flags = binary.modifiers["flags"]
+        return "%s%s%s" % (
+            self.process(binary.left, **kw),
+            op_string,
+            self._mariadb_regexp_flags(flags, binary.right),
+        )
+
+    def _mariadb_regexp_replace_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
+        flags = binary.modifiers["flags"]
+        return "REGEXP_REPLACE(%s, %s, %s)" % (
+            self.process(binary.left, **kw),
+            self._mariadb_regexp_flags(flags, binary.right.clauses[0]),
+            self.process(binary.right.clauses[1], **kw),
+        )
+
+    def _mariadb_visit_drop_check_constraint(
+        self, drop: ddl.DropConstraint, **kw: Any
+    ) -> str:
+        constraint = drop.element
+        qual = "CONSTRAINT "
+        const = self.preparer.format_constraint(constraint)
+        return "ALTER TABLE %s DROP %s%s" % (
+            self.preparer.format_table(constraint.table),
+            qual,
+            const,
+        )
+
+
+class MariaDBDDLCompilerShim(DDLCompiler):
+    dialect: MariaDBShim
+
+    def _mariadb_get_column_specification(
+        self, column: sa_schema.Column[Any], **kw: Any
+    ) -> str:
+
+        if (
+            column.computed is not None
+            and column._user_defined_nullable is SchemaConst.NULL_UNSPECIFIED
+        ):
+            kw["_force_column_to_nullable"] = True
+
+        return self._mysql_get_column_specification(column, **kw)
+
+    def _mysql_get_column_specification(
+        self,
+        column: sa_schema.Column[Any],
+        *,
+        _force_column_to_nullable: bool = False,
+        **kw: Any,
+    ) -> str:
+        raise NotImplementedError()
+
+    def get_identity_options(self, identity_options: IdentityOptions) -> str:
+        text = super().get_identity_options(identity_options)
+        text = text.replace("NO CYCLE", "NOCYCLE")
+        return text
+
+    def _mariadb_visit_drop_check_constraint(
+        self, drop: ddl.DropConstraint, **kw: Any
+    ) -> str:
+        constraint = drop.element
+        qual = "CONSTRAINT "
+        const = self.preparer.format_constraint(constraint)
+        return "ALTER TABLE %s DROP %s%s" % (
+            self.preparer.format_table(constraint.table),
+            qual,
+            const,
+        )
+
+
+class MariaDBShim(DefaultDialect):
+    server_version_info: tuple[int, ...]
+    is_mariadb: bool
+    _allows_uuid_binds = False
+
+    identifier_preparer: MySQLIdentifierPreparer
+    preparer: Type[MySQLIdentifierPreparer]
+
+    def _set_mariadb(
+        self, is_mariadb: Optional[bool], server_version_info: tuple[int, ...]
+    ) -> None:
+        if is_mariadb is None:
+            return
+
+        if not is_mariadb and self.is_mariadb:
+            raise exc.InvalidRequestError(
+                "MySQL version %s is not a MariaDB variant."
+                % (".".join(map(str, server_version_info)),)
+            )
+        if is_mariadb:
+            assert isinstance(self.colspecs, dict)
+            self.colspecs = util.update_copy(
+                self.colspecs, {Uuid: _MariaDBUUID}
+            )
+
+            self.identifier_preparer = self.preparer(self)
+            self.identifier_preparer._set_mariadb()
+
+            # this will be updated on first connect in initialize()
+            # if using older mariadb version
+            self.delete_returning = True
+            self.insert_returning = True
+
+        self.is_mariadb = is_mariadb
+
+    @property
+    def _mariadb_normalized_version_info(self) -> tuple[int, ...]:
+        return self.server_version_info
+
+    @property
+    def _is_mariadb(self) -> bool:
+        return self.is_mariadb
+
+    @classmethod
+    def _is_mariadb_from_url(cls, url: URL) -> bool:
+        dbapi = cls.import_dbapi()
+        dialect = cls(dbapi=dbapi)
+
+        cargs, cparams = dialect.create_connect_args(url)
+        conn = dialect.connect(*cargs, **cparams)
+        try:
+            cursor = conn.cursor()
+            cursor.execute("SELECT VERSION() LIKE '%MariaDB%'")
+            val = cursor.fetchone()[0]  # type: ignore[index]
+        except:
+            raise
+        else:
+            return bool(val)
+        finally:
+            conn.close()
+
+    def _initialize_mariadb(self, connection: Connection) -> None:
+        assert self.is_mariadb
+
+        self.supports_sequences = self.server_version_info >= (10, 3)
+
+        self.delete_returning = self.server_version_info >= (10, 0, 5)
+
+        self.insert_returning = self.server_version_info >= (10, 5)
+
+        self._warn_for_known_db_issues()
+
+        self.supports_native_uuid = (
+            self.server_version_info is not None
+            and self.server_version_info >= (10, 7)
+        )
+        self._allows_uuid_binds = True
+
+        # ref https://mariadb.com/kb/en/mariadb-1021-release-notes/
+        self._support_default_function = self.server_version_info >= (10, 2, 1)
+
+        # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/
+        self._support_float_cast = self.server_version_info >= (10, 4, 5)
+
+    def _warn_for_known_db_issues(self) -> None:
+        if self.is_mariadb:
+            mdb_version = self.server_version_info
+            assert mdb_version is not None
+            if mdb_version > (10, 2) and mdb_version < (10, 2, 9):
+                util.warn(
+                    "MariaDB %r before 10.2.9 has known issues regarding "
+                    "CHECK constraints, which impact handling of NULL values "
+                    "with SQLAlchemy's boolean datatype (MDEV-13596). An "
+                    "additional issue prevents proper migrations of columns "
+                    "with CHECK constraints (MDEV-11114).  Please upgrade to "
+                    "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 "
+                    "series, to avoid these issues." % (mdb_version,)
+                )
index be6343aec12841e6882e424ce6e0545088583b75..941aee0c490aba2708706e3193986b9041f86f6e 100644 (file)
@@ -1075,15 +1075,16 @@ from typing import Optional
 from typing import overload
 from typing import Sequence
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
+from . import _mariadb_shim
 from . import reflection as _reflection
 from .enumerated import ENUM
 from .enumerated import SET
 from .json import JSON
 from .json import JSONIndexType
 from .json import JSONPathType
-from .reserved_words import RESERVED_WORDS_MARIADB
 from .reserved_words import RESERVED_WORDS_MYSQL
 from .types import _FloatType
 from .types import _IntegerType
@@ -1122,7 +1123,6 @@ from ... import literal_column
 from ... import schema as sa_schema
 from ... import sql
 from ... import util
-from ...engine import cursor as _cursor
 from ...engine import default
 from ...engine import reflection
 from ...engine.reflection import ReflectionDefaults
@@ -1136,8 +1136,6 @@ from ...sql import sqltypes
 from ...sql import util as sql_util
 from ...sql import visitors
 from ...sql.compiler import InsertmanyvaluesSentinelOpts
-from ...sql.compiler import SQLCompiler
-from ...sql.schema import SchemaConst
 from ...types import BINARY
 from ...types import BLOB
 from ...types import BOOLEAN
@@ -1168,7 +1166,6 @@ if TYPE_CHECKING:
     from ...engine.interfaces import ReflectedUniqueConstraint
     from ...engine.result import _Ts
     from ...engine.row import Row
-    from ...engine.url import URL
     from ...schema import Table
     from ...sql import ddl
     from ...sql import selectable
@@ -1180,14 +1177,12 @@ if TYPE_CHECKING:
     from ...sql.functions import random
     from ...sql.functions import rollup
     from ...sql.functions import sysdate
-    from ...sql.schema import IdentityOptions
-    from ...sql.schema import Sequence as Sequence_SchemaItem
     from ...sql.type_api import TypeEngine
     from ...sql.visitors import ExternallyTraversible
     from ...util.typing import TupleAny
     from ...util.typing import Unpack
 
-
+_T = TypeVar("_T", bound=Any)
 SET_RE = re.compile(
     r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE
 )
@@ -1281,33 +1276,9 @@ ischema_names = {
 }
 
 
-class MySQLExecutionContext(default.DefaultExecutionContext):
-    def post_exec(self) -> None:
-        if (
-            self.isdelete
-            and cast(SQLCompiler, self.compiled).effective_returning
-            and not self.cursor.description
-        ):
-            # All MySQL/mariadb drivers appear to not include
-            # cursor.description for DELETE..RETURNING with no rows if the
-            # WHERE criteria is a straight "false" condition such as our EMPTY
-            # IN condition. manufacture an empty result in this case (issue
-            # #10505)
-            #
-            # taken from cx_Oracle implementation
-            self.cursor_fetch_strategy = (
-                _cursor.FullyBufferedCursorFetchStrategy(
-                    self.cursor,
-                    [
-                        (entry.keyname, None)  # type: ignore[misc]
-                        for entry in cast(
-                            SQLCompiler, self.compiled
-                        )._result_columns
-                    ],
-                    [],
-                )
-            )
-
+class MySQLExecutionContext(
+    _mariadb_shim.MariadbExecutionContextShim, default.DefaultExecutionContext
+):
     def create_server_side_cursor(self) -> DBAPICursor:
         if self.dialect.supports_server_side_cursors:
             return self._dbapi_connection.cursor(
@@ -1316,19 +1287,10 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
         else:
             raise NotImplementedError()
 
-    def fire_sequence(
-        self, seq: Sequence_SchemaItem, type_: sqltypes.Integer
-    ) -> int:
-        return self._execute_scalar(  # type: ignore[no-any-return]
-            (
-                "select nextval(%s)"
-                % self.identifier_preparer.format_sequence(seq)
-            ),
-            type_,
-        )
-
 
-class MySQLCompiler(compiler.SQLCompiler):
+class MySQLCompiler(
+    _mariadb_shim.MariaDBSQLCompilerShim, compiler.SQLCompiler
+):
     dialect: MySQLDialect
     render_table_with_column_in_update_from = True
     """Overridden from base SQLCompiler value"""
@@ -1373,19 +1335,16 @@ class MySQLCompiler(compiler.SQLCompiler):
             return (
                 f"group_concat({expr._compiler_dispatch(self, **kw)} "
                 f"ORDER BY {order_by._compiler_dispatch(self, **kw)} "
-                f"SEPARATOR "
+                "SEPARATOR "
                 f"{delimiter._compiler_dispatch(self, **literal_exec)})"
             )
         else:
             return (
                 f"group_concat({expr._compiler_dispatch(self, **kw)} "
-                f"SEPARATOR "
+                "SEPARATOR "
                 f"{delimiter._compiler_dispatch(self, **literal_exec)})"
             )
 
-    def visit_sequence(self, sequence: sa_schema.Sequence, **kw: Any) -> str:
-        return "nextval(%s)" % self.preparer.format_sequence(sequence)
-
     def visit_sysdate_func(self, fn: sysdate, **kw: Any) -> str:
         return "SYSDATE()"
 
@@ -1554,7 +1513,7 @@ class MySQLCompiler(compiler.SQLCompiler):
                 "any column keys in table '%s': %s"
                 % (
                     self.statement.table.name,  # type: ignore[union-attr]
-                    (", ".join("'%s'" % c for c in non_matching)),
+                    ", ".join("'%s'" % c for c in non_matching),
                 )
             )
 
@@ -1569,8 +1528,8 @@ class MySQLCompiler(compiler.SQLCompiler):
     def visit_concat_op_expression_clauselist(
         self, clauselist: elements.ClauseList, operator: Any, **kw: Any
     ) -> str:
-        return "concat(%s)" % (
-            ", ".join(self.process(elem, **kw) for elem in clauselist.clauses)
+        return "concat(%s)" % ", ".join(
+            self.process(elem, **kw) for elem in clauselist.clauses
         )
 
     def visit_concat_op_binary(
@@ -1942,8 +1901,7 @@ class MySQLCompiler(compiler.SQLCompiler):
         self, element_types: list[TypeEngine[Any]], **kw: Any
     ) -> str:
         return (
-            "SELECT %(outer)s FROM (SELECT %(inner)s) "
-            "as _empty_set WHERE 1!=1"
+            "SELECT %(outer)s FROM (SELECT %(inner)s) as _empty_set WHERE 1!=1"
             % {
                 "inner": ", ".join(
                     "1 AS _in_%s" % idx
@@ -1971,13 +1929,24 @@ class MySQLCompiler(compiler.SQLCompiler):
             self.process(binary.right),
         )
 
-    def _mariadb_regexp_flags(
-        self, flags: str, pattern: elements.ColumnElement[Any], **kw: Any
+    def _mysql_regexp_match(
+        self,
+        op_string: str,
+        binary: elements.BinaryExpression[Any],
+        operator: Any,
+        **kw: Any,
     ) -> str:
-        return "CONCAT('(?', %s, ')', %s)" % (
+        flags = binary.modifiers["flags"]
+
+        text = "REGEXP_LIKE(%s, %s, %s)" % (
+            self.process(binary.left, **kw),
+            self.process(binary.right, **kw),
             self.render_literal_value(flags, sqltypes.STRINGTYPE),
-            self.process(pattern, **kw),
         )
+        if op_string == " NOT REGEXP ":
+            return "NOT %s" % text
+        else:
+            return text
 
     def _regexp_match(
         self,
@@ -1990,22 +1959,15 @@ class MySQLCompiler(compiler.SQLCompiler):
         flags = binary.modifiers["flags"]
         if flags is None:
             return self._generate_generic_binary(binary, op_string, **kw)
-        elif self.dialect.is_mariadb:
-            return "%s%s%s" % (
-                self.process(binary.left, **kw),
-                op_string,
-                self._mariadb_regexp_flags(flags, binary.right),
-            )
         else:
-            text = "REGEXP_LIKE(%s, %s, %s)" % (
-                self.process(binary.left, **kw),
-                self.process(binary.right, **kw),
-                self.render_literal_value(flags, sqltypes.STRINGTYPE),
+            return self.dialect._dispatch_for_vendor(
+                self._mysql_regexp_match,
+                self._mariadb_regexp_match,
+                op_string,
+                binary,
+                operator,
+                **kw,
             )
-            if op_string == " NOT REGEXP ":
-                return "NOT %s" % text
-            else:
-                return text
 
     def visit_regexp_match_op_binary(
         self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
@@ -2027,33 +1989,52 @@ class MySQLCompiler(compiler.SQLCompiler):
                 self.process(binary.left, **kw),
                 self.process(binary.right, **kw),
             )
-        elif self.dialect.is_mariadb:
-            return "REGEXP_REPLACE(%s, %s, %s)" % (
-                self.process(binary.left, **kw),
-                self._mariadb_regexp_flags(flags, binary.right.clauses[0]),
-                self.process(binary.right.clauses[1], **kw),
-            )
         else:
-            return "REGEXP_REPLACE(%s, %s, %s)" % (
-                self.process(binary.left, **kw),
-                self.process(binary.right, **kw),
-                self.render_literal_value(flags, sqltypes.STRINGTYPE),
+            return self.dialect._dispatch_for_vendor(
+                self._mysql_regexp_replace_op_binary,
+                self._mariadb_regexp_replace_op_binary,
+                binary,
+                operator,
+                **kw,
             )
 
+    def _mysql_regexp_replace_op_binary(
+        self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+    ) -> str:
+        flags = binary.modifiers["flags"]
 
-class MySQLDDLCompiler(compiler.DDLCompiler):
+        return "REGEXP_REPLACE(%s, %s, %s)" % (
+            self.process(binary.left, **kw),
+            self.process(binary.right, **kw),
+            self.render_literal_value(flags, sqltypes.STRINGTYPE),
+        )
+
+
+class MySQLDDLCompiler(
+    _mariadb_shim.MariaDBDDLCompilerShim, compiler.DDLCompiler
+):
     dialect: MySQLDialect
 
     def get_column_specification(
         self, column: sa_schema.Column[Any], **kw: Any
     ) -> str:
         """Builds column DDL."""
-        if (
-            self.dialect.is_mariadb is True
-            and column.computed is not None
-            and column._user_defined_nullable is SchemaConst.NULL_UNSPECIFIED
-        ):
-            column.nullable = True
+
+        return self.dialect._dispatch_for_vendor(
+            self._mysql_get_column_specification,
+            self._mariadb_get_column_specification,
+            column,
+            **kw,
+        )
+
+    def _mysql_get_column_specification(
+        self,
+        column: sa_schema.Column[Any],
+        *,
+        _force_column_to_nullable: bool = False,
+        **kw: Any,
+    ) -> str:
+
         colspec = [
             self.preparer.format_column(column),
             self.dialect.type_compiler_instance.process(
@@ -2069,7 +2050,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             sqltypes.TIMESTAMP,
         )
 
-        if not column.nullable:
+        if not column.nullable and not _force_column_to_nullable:
             colspec.append("NOT NULL")
 
         # see: https://docs.sqlalchemy.org/en/latest/dialects/mysql.html#mysql_timestamp_null  # noqa
@@ -2315,11 +2296,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             qual = "INDEX "
             const = self.preparer.format_constraint(constraint)
         elif isinstance(constraint, sa_schema.CheckConstraint):
-            if self.dialect.is_mariadb:
-                qual = "CONSTRAINT "
-            else:
-                qual = "CHECK "
-            const = self.preparer.format_constraint(constraint)
+            return self.dialect._dispatch_for_vendor(
+                self._mysql_visit_drop_check_constraint,
+                self._mariadb_visit_drop_check_constraint,
+                drop,
+                **kw,
+            )
         else:
             qual = ""
             const = self.preparer.format_constraint(constraint)
@@ -2329,6 +2311,18 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             const,
         )
 
+    def _mysql_visit_drop_check_constraint(
+        self, drop: ddl.DropConstraint, **kw: Any
+    ) -> str:
+        constraint = drop.element
+        qual = "CHECK "
+        const = self.preparer.format_constraint(constraint)
+        return "ALTER TABLE %s DROP %s%s" % (
+            self.preparer.format_table(constraint.table),
+            qual,
+            const,
+        )
+
     def define_constraint_match(
         self, constraint: sa_schema.ForeignKeyConstraint
     ) -> str:
@@ -2365,17 +2359,10 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             self.get_column_specification(create.element),
         )
 
-    def get_identity_options(self, identity_options: IdentityOptions) -> str:
-        """mariadb-specific sequence option; this will move to a
-        mariadb-specific module in 2.1
 
-        """
-        text = super().get_identity_options(identity_options)
-        text = text.replace("NO CYCLE", "NOCYCLE")
-        return text
-
-
-class MySQLTypeCompiler(compiler.GenericTypeCompiler):
+class MySQLTypeCompiler(
+    _mariadb_shim.MariaDBTypeCompilerShim, compiler.GenericTypeCompiler
+):
     def _extend_numeric(self, type_: _NumericCommonType, spec: str) -> str:
         "Extend a numeric-type declaration with MySQL specific extensions."
 
@@ -2687,7 +2674,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         return "BOOL"
 
 
-class MySQLIdentifierPreparer(compiler.IdentifierPreparer):
+class MySQLIdentifierPreparer(
+    _mariadb_shim.MariaDBIdentifierPreparerShim, compiler.IdentifierPreparer
+):
     reserved_words = RESERVED_WORDS_MYSQL
 
     def __init__(
@@ -2709,16 +2698,15 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer):
         return tuple([self.quote_identifier(i) for i in ids if i is not None])
 
 
-class MariaDBIdentifierPreparer(MySQLIdentifierPreparer):
-    reserved_words = RESERVED_WORDS_MARIADB
-
-
-class MySQLDialect(default.DefaultDialect):
+class MySQLDialect(_mariadb_shim.MariaDBShim, default.DefaultDialect):
     """Details of the MySQL dialect.
     Not used directly in application code.
     """
 
     name = "mysql"
+
+    is_mariadb = False
+
     supports_statement_cache = True
 
     supports_alter = True
@@ -2784,14 +2772,13 @@ class MySQLDialect(default.DefaultDialect):
     ischema_names = ischema_names
     preparer: type[MySQLIdentifierPreparer] = MySQLIdentifierPreparer
 
-    is_mariadb: bool = False
-    _mariadb_normalized_version_info = None
-
     # default SQL compilation settings -
     # these are modified upon initialize(),
     # i.e. first connect
     _backslash_escapes = True
     _server_ansiquotes = False
+    _support_default_function = True
+    _support_float_cast = False
 
     server_version_info: tuple[int, ...]
     identifier_preparer: MySQLIdentifierPreparer
@@ -2864,24 +2851,6 @@ class MySQLDialect(default.DefaultDialect):
             val = val.decode()
         return val.upper().replace("-", " ")  # type: ignore[no-any-return]
 
-    @classmethod
-    def _is_mariadb_from_url(cls, url: URL) -> bool:
-        dbapi = cls.import_dbapi()
-        dialect = cls(dbapi=dbapi)
-
-        cargs, cparams = dialect.create_connect_args(url)
-        conn = dialect.connect(*cargs, **cparams)
-        try:
-            cursor = conn.cursor()
-            cursor.execute("SELECT VERSION() LIKE '%MariaDB%'")
-            val = cursor.fetchone()[0]  # type: ignore[index]
-        except:
-            raise
-        else:
-            return bool(val)
-        finally:
-            conn.close()
-
     def _get_server_version_info(
         self, connection: Connection
     ) -> tuple[int, ...]:
@@ -2905,6 +2874,8 @@ class MySQLDialect(default.DefaultDialect):
 
         r = re.compile(r"[.\-+]")
         tokens = r.split(val)
+
+        _mariadb_normalized_version_info = None
         for token in tokens:
             parsed_token = re.match(
                 r"^(?:(\d+)(?:a|b|c)?|(MariaDB\w*))$", token
@@ -2912,21 +2883,21 @@ class MySQLDialect(default.DefaultDialect):
             if not parsed_token:
                 continue
             elif parsed_token.group(2):
-                self._mariadb_normalized_version_info = tuple(version[-3:])
+                _mariadb_normalized_version_info = tuple(version[-3:])
                 is_mariadb = True
             else:
                 digit = int(parsed_token.group(1))
                 version.append(digit)
 
-        server_version_info = tuple(version)
+        if _mariadb_normalized_version_info:
+            server_version_info = _mariadb_normalized_version_info
+        else:
+            server_version_info = tuple(version)
 
         self._set_mariadb(
             bool(server_version_info and is_mariadb), server_version_info
         )
 
-        if not is_mariadb:
-            self._mariadb_normalized_version_info = server_version_info
-
         if server_version_info < (5, 0, 2):
             raise NotImplementedError(
                 "the MySQL/MariaDB dialect supports server "
@@ -2937,32 +2908,6 @@ class MySQLDialect(default.DefaultDialect):
         self.server_version_info = server_version_info
         return server_version_info
 
-    def _set_mariadb(
-        self, is_mariadb: Optional[bool], server_version_info: tuple[int, ...]
-    ) -> None:
-        if is_mariadb is None:
-            return
-
-        if not is_mariadb and self.is_mariadb:
-            raise exc.InvalidRequestError(
-                "MySQL version %s is not a MariaDB variant."
-                % (".".join(map(str, server_version_info)),)
-            )
-        if is_mariadb:
-
-            if not issubclass(self.preparer, MariaDBIdentifierPreparer):
-                self.preparer = MariaDBIdentifierPreparer
-                # this would have been set by the default dialect already,
-                # so set it again
-                self.identifier_preparer = self.preparer(self)
-
-            # this will be updated on first connect in initialize()
-            # if using older mariadb version
-            self.delete_returning = True
-            self.insert_returning = True
-
-        self.is_mariadb = is_mariadb
-
     def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
         connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid))
 
@@ -3179,6 +3124,18 @@ class MySQLDialect(default.DefaultDialect):
             )
         ]
 
+    def _dispatch_for_vendor(
+        self,
+        mysql_callable: Callable[..., _T],
+        mariadb_callable: Callable[..., _T],
+        *arg: Any,
+        **kw: Any,
+    ) -> _T:
+        if not self.is_mariadb:
+            return mysql_callable(*arg, **kw)
+        else:
+            return mariadb_callable(*arg, **kw)
+
     def initialize(self, connection: Connection) -> None:
         # this is driver-based, does not need server version info
         # and is fairly critical for even basic SQL operations
@@ -3202,92 +3159,33 @@ class MySQLDialect(default.DefaultDialect):
                 self, server_ansiquotes=self._server_ansiquotes
             )
 
-        self.supports_sequences = (
-            self.is_mariadb and self.server_version_info >= (10, 3)
+        self._dispatch_for_vendor(
+            self._initialize_mysql, self._initialize_mariadb, connection
         )
 
-        self.supports_for_update_of = (
-            self._is_mysql and self.server_version_info >= (8,)
-        )
+    def _initialize_mysql(self, connection: Connection) -> None:
+        assert not self.is_mariadb
 
-        self.use_mysql_for_share = (
-            self._is_mysql and self.server_version_info >= (8, 0, 1)
-        )
+        self.supports_for_update_of = self.server_version_info >= (8,)
 
-        self._needs_correct_for_88718_96365 = (
-            not self.is_mariadb and self.server_version_info >= (8,)
-        )
-
-        self.delete_returning = (
-            self.is_mariadb and self.server_version_info >= (10, 0, 5)
-        )
+        self.use_mysql_for_share = self.server_version_info >= (8, 0, 1)
 
-        self.insert_returning = (
-            self.is_mariadb and self.server_version_info >= (10, 5)
-        )
+        self._needs_correct_for_88718_96365 = self.server_version_info >= (8,)
 
         self._requires_alias_for_on_duplicate_key = (
-            self._is_mysql and self.server_version_info >= (8, 0, 20)
+            self.server_version_info >= (8, 0, 20)
         )
 
-        self._warn_for_known_db_issues()
+        # ref https://dev.mysql.com/doc/refman/8.0/en/data-type-defaults.html # noqa
+        self._support_default_function = self.server_version_info >= (8, 0, 13)
 
-    def _warn_for_known_db_issues(self) -> None:
-        if self.is_mariadb:
-            mdb_version = self._mariadb_normalized_version_info
-            assert mdb_version is not None
-            if mdb_version > (10, 2) and mdb_version < (10, 2, 9):
-                util.warn(
-                    "MariaDB %r before 10.2.9 has known issues regarding "
-                    "CHECK constraints, which impact handling of NULL values "
-                    "with SQLAlchemy's boolean datatype (MDEV-13596). An "
-                    "additional issue prevents proper migrations of columns "
-                    "with CHECK constraints (MDEV-11114).  Please upgrade to "
-                    "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 "
-                    "series, to avoid these issues." % (mdb_version,)
-                )
-
-    @property
-    def _support_float_cast(self) -> bool:
-        if not self.server_version_info:
-            return False
-        elif self.is_mariadb:
-            # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/
-            return self.server_version_info >= (10, 4, 5)
-        else:
-            # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature  # noqa
-            return self.server_version_info >= (8, 0, 17)
-
-    @property
-    def _support_default_function(self) -> bool:
-        if not self.server_version_info:
-            return False
-        elif self.is_mariadb:
-            # ref https://mariadb.com/kb/en/mariadb-1021-release-notes/
-            return self.server_version_info >= (10, 2, 1)
-        else:
-            # ref https://dev.mysql.com/doc/refman/8.0/en/data-type-defaults.html # noqa
-            return self.server_version_info >= (8, 0, 13)
-
-    @property
-    def _is_mariadb(self) -> bool:
-        return self.is_mariadb
+        # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature  # noqa
+        self._support_float_cast = self.server_version_info >= (8, 0, 17)
 
     @property
     def _is_mysql(self) -> bool:
         return not self.is_mariadb
 
-    @property
-    def _is_mariadb_102(self) -> bool:
-        return (
-            self.is_mariadb
-            and self._mariadb_normalized_version_info  # type:ignore[operator]
-            > (
-                10,
-                2,
-            )
-        )
-
     @reflection.cache
     def get_schema_names(self, connection: Connection, **kw: Any) -> list[str]:
         rp = connection.exec_driver_sql("SHOW schemas")
@@ -3427,7 +3325,7 @@ class MySQLDialect(default.DefaultDialect):
             }
             fkeys.append(fkey_d)
 
-        if self._needs_correct_for_88718_96365:
+        if self._is_mysql and self._needs_correct_for_88718_96365:
             self._correct_for_mysql_bugs_88718_96365(fkeys, connection)
 
         return fkeys if fkeys else ReflectionDefaults.foreign_keys()
@@ -3904,8 +3802,8 @@ class MySQLDialect(default.DefaultDialect):
 
                 elif code == 1356:
                     raise exc.UnreflectableTableError(
-                        "Table or view named %s could not be "
-                        "reflected: %s" % (full_name, e)
+                        "Table or view named %s could not be reflected: %s"
+                        % (full_name, e)
                     ) from e
 
                 else:
index faad28af10c6b97bcce5a670f865e1c7e065616e..cdb1e2ffeed477a5221294aa02648ef11e4ded7c 100644 (file)
@@ -8,22 +8,9 @@
 from __future__ import annotations
 
 from typing import Any
-from typing import Optional
-from typing import TYPE_CHECKING
 
-from .base import MariaDBIdentifierPreparer
 from .base import MySQLDialect
-from .base import MySQLIdentifierPreparer
-from .base import MySQLTypeCompiler
-from ... import util
 from ...sql import sqltypes
-from ...sql.sqltypes import _UUID_RETURN
-from ...sql.sqltypes import UUID
-from ...sql.sqltypes import Uuid
-
-if TYPE_CHECKING:
-    from ...engine.base import Connection
-    from ...sql.type_api import _BindProcessorType
 
 
 class INET4(sqltypes.TypeEngine[str]):
@@ -44,40 +31,6 @@ class INET6(sqltypes.TypeEngine[str]):
     __visit_name__ = "INET6"
 
 
-class _MariaDBUUID(UUID[_UUID_RETURN]):
-    def __init__(self, as_uuid: bool = True, native_uuid: bool = True):
-        self.as_uuid = as_uuid
-
-        # the _MariaDBUUID internal type is only invoked for a Uuid() with
-        # native_uuid=True.   for non-native uuid type, the plain Uuid
-        # returns itself due to the workings of the Emulated superclass.
-        assert native_uuid
-
-        # for internal type, force string conversion for result_processor() as
-        # current drivers are returning a string, not a Python UUID object
-        self.native_uuid = False
-
-    @property
-    def native(self) -> bool:  # type: ignore[override]
-        # override to return True, this is a native type, just turning
-        # off native_uuid for internal data handling
-        return True
-
-    def bind_processor(self, dialect: MariaDBDialect) -> Optional[_BindProcessorType[_UUID_RETURN]]:  # type: ignore[override] # noqa: E501
-        if not dialect.supports_native_uuid or not dialect._allows_uuid_binds:
-            return super().bind_processor(dialect)  # type: ignore[return-value] # noqa: E501
-        else:
-            return None
-
-
-class MariaDBTypeCompiler(MySQLTypeCompiler):
-    def visit_INET4(self, type_: INET4, **kwargs: Any) -> str:
-        return "INET4"
-
-    def visit_INET6(self, type_: INET6, **kwargs: Any) -> str:
-        return "INET6"
-
-
 class MariaDBDialect(MySQLDialect):
     is_mariadb = True
     supports_statement_cache = True
@@ -86,18 +39,10 @@ class MariaDBDialect(MySQLDialect):
     _allows_uuid_binds = True
 
     name = "mariadb"
-    preparer: type[MySQLIdentifierPreparer] = MariaDBIdentifierPreparer
-    type_compiler_cls = MariaDBTypeCompiler
 
-    colspecs = util.update_copy(MySQLDialect.colspecs, {Uuid: _MariaDBUUID})
-
-    def initialize(self, connection: Connection) -> None:
-        super().initialize(connection)
-
-        self.supports_native_uuid = (
-            self.server_version_info is not None
-            and self.server_version_info >= (10, 7)
-        )
+    def __init__(self, **kw: Any) -> None:
+        kw["is_mariadb"] = True
+        super().__init__(**kw)
 
 
 def loader(driver: str) -> type[MariaDBDialect]:
index 6159293c07886ba2b6c0f74f864177757a63f64e..6ace9bb3030774e29b6f5596e7586e0318eb7048 100644 (file)
@@ -55,7 +55,6 @@ from typing import Sequence
 from typing import TYPE_CHECKING
 from typing import Union
 
-from .base import MariaDBIdentifierPreparer
 from .base import MySQLCompiler
 from .base import MySQLDialect
 from .base import MySQLExecutionContext
@@ -123,12 +122,6 @@ class MySQLIdentifierPreparer_mysqlconnector(
     pass
 
 
-class MariaDBIdentifierPreparer_mysqlconnector(
-    IdentifierPreparerCommon_mysqlconnector, MariaDBIdentifierPreparer
-):
-    pass
-
-
 class _myconnpyBIT(BIT):
     def result_processor(self, dialect: Any, coltype: Any) -> None:
         """MySQL-connector already converts mysql bits, so."""
@@ -296,7 +289,7 @@ class MariaDBDialect_mysqlconnector(
 ):
     supports_statement_cache = True
     _allows_uuid_binds = False
-    preparer = MariaDBIdentifierPreparer_mysqlconnector
+    preparer = MySQLIdentifierPreparer_mysqlconnector
 
 
 dialect = MySQLDialect_mysqlconnector
index bbbcb34859f2ee8ac20b335ab0a2a1b190737746..22253d6a612fde0ad0828d3cdd1f551bf5d8d2f1 100644 (file)
@@ -72,8 +72,6 @@ from ..util.typing import TupleAny
 from ..util.typing import Unpack
 
 if typing.TYPE_CHECKING:
-    from types import ModuleType
-
     from .base import Engine
     from .cursor import ResultFetchStrategy
     from .interfaces import _CoreMultiExecuteParams
@@ -302,7 +300,7 @@ class DefaultDialect(Dialect):
         self,
         paramstyle: Optional[_ParamStyle] = None,
         isolation_level: Optional[IsolationLevel] = None,
-        dbapi: Optional[ModuleType] = None,
+        dbapi: Optional[DBAPIModule] = None,
         implicit_returning: Literal[True] = True,
         supports_native_boolean: Optional[bool] = None,
         max_identifier_length: Optional[int] = None,
index 911026feaa6d3857547962e582f1069d0920740f..31c8ac8b8a4498fbec9ad1d006e25296128df471 100644 (file)
@@ -7314,7 +7314,9 @@ class DDLCompiler(Compiled):
             " CASCADE" if drop.cascade else "",
         )
 
-    def get_column_specification(self, column, **kwargs):
+    def get_column_specification(
+        self, column: Column[Any], **kwargs: Any
+    ) -> str:
         colspec = (
             self.preparer.format_column(column)
             + " "
@@ -7341,10 +7343,10 @@ class DDLCompiler(Compiled):
             colspec += " NOT NULL"
         return colspec
 
-    def create_table_suffix(self, table):
+    def create_table_suffix(self, table: Table) -> str:
         return ""
 
-    def post_create_table(self, table):
+    def post_create_table(self, table: Table) -> str:
         return ""
 
     def get_column_default_string(self, column: Column[Any]) -> Optional[str]:
index 265d08c11150c48019dd37e96e8b5db8ec323f1e..3cfa35c1dd131f1dabff13196f8682933a2770e8 100644 (file)
@@ -433,6 +433,11 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
     ):
         dialect = mysql.dialect(is_mariadb=is_mariadb)
         dialect.server_version_info = version
+        if is_mariadb:
+            with testing.expect_warnings(".*"):
+                dialect._initialize_mariadb(None)
+        else:
+            dialect._initialize_mysql(None)
 
         m = MetaData()
         tbl = Table(
@@ -1143,8 +1148,10 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL, CacheKeyFixture):
         if maria_db:
             dialect.is_mariadb = maria_db
             dialect.server_version_info = (10, 4, 5)
+            dialect._initialize_mariadb(None)
         else:
             dialect.server_version_info = (8, 0, 17)
+            dialect._initialize_mysql(None)
         t = sql.table("t", sql.column("col"))
         self.assert_compile(cast(t.c.col, type_), expected, dialect=dialect)
 
index 0dcc079a22bcd7e218badb48ae18c201def26054..41bf091c5a1fcf4e26af28c4408411331b5b73c5 100644 (file)
@@ -428,37 +428,33 @@ class ParseVersionTest(fixtures.TestBase):
         )
 
     @testing.combinations(
-        ((10, 2, 7), "10.2.7-MariaDB", (10, 2, 7), True),
+        ((10, 2, 7), "10.2.7-MariaDB", True),
         (
             (10, 2, 7),
             "5.6.15.10.2.7-MariaDB",
-            (5, 6, 15, 10, 2, 7),
             True,
         ),
-        ((5, 0, 51, 24), "5.0.51a.24+lenny5", (5, 0, 51, 24), False),
-        ((10, 2, 10), "10.2.10-MariaDB", (10, 2, 10), True),
-        ((5, 7, 20), "5.7.20", (5, 7, 20), False),
-        ((5, 6, 15), "5.6.15", (5, 6, 15), False),
+        ((5, 0, 51, 24), "5.0.51a.24+lenny5", False),
+        ((10, 2, 10), "10.2.10-MariaDB", True),
+        ((5, 7, 20), "5.7.20", False),
+        ((5, 6, 15), "5.6.15", False),
         (
             (10, 2, 6),
             "10.2.6.MariaDB.10.2.6+maria~stretch-log",
-            (10, 2, 6, 10, 2, 6),
             True,
         ),
         (
             (10, 1, 9),
             "10.1.9-MariaDBV1.0R050D002-20170809-1522",
-            (10, 1, 9, 20170809, 1522),
             True,
         ),
     )
     def test_mariadb_normalized_version(
-        self, expected, raw_version, version, is_mariadb
+        self, expected, raw_version, is_mariadb
     ):
         dialect = mysql.dialect()
-        eq_(dialect._parse_server_version(raw_version), version)
-        dialect.server_version_info = version
-        eq_(dialect._mariadb_normalized_version_info, expected)
+        eq_(dialect._parse_server_version(raw_version), expected)
+        eq_(dialect.server_version_info, expected)
         assert dialect._is_mariadb is is_mariadb
 
     @testing.combinations(
index 06626b579323ea389586b25c44a0ee61d5fdd34d..19c7ec6214e91cf718eff5f01be2b9fb642fa8c5 100644 (file)
@@ -14,7 +14,6 @@ from sqlalchemy import func
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import literal_column
-from sqlalchemy import MetaData
 from sqlalchemy import or_
 from sqlalchemy import schema
 from sqlalchemy import select
@@ -408,7 +407,10 @@ class ComputedTest(fixtures.TestBase):
         ("unset"),
         argnames="nullable",
     )
-    def test_column_computed_for_nullable(self, connection, nullable):
+    @testing.variation("setnull", ["constructor", "assign"])
+    def test_column_computed_for_nullable(
+        self, metadata, connection, nullable, setnull: testing.Variation
+    ):
         """test #10056
 
         we want to make sure that nullable is always set to True for computed
@@ -416,29 +418,42 @@ class ComputedTest(fixtures.TestBase):
         ref: https://mariadb.com/kb/en/generated-columns/#statement-support
 
         """
-        m = MetaData()
-        kwargs = {"nullable": nullable} if nullable != "unset" else {}
+
+        if setnull.assign:
+            kwargs = {}
+        elif setnull.constructor:
+            kwargs = {"nullable": nullable} if nullable != "unset" else {}
+        else:
+            setnull.fail()
+
         t = Table(
             "t",
-            m,
+            metadata,
             Column("x", Integer),
             Column("y", Integer, Computed("x + 2"), **kwargs),
         )
-        if connection.engine.dialect.name == "mariadb" and nullable in (
-            False,
-            None,
+
+        # the "assign" path here is to exercise the actual bug shown at
+        # #10056; user defined nullable is not passed, but MappedColumn
+        # sets it after the fact.   Previously this path was not exercised
+        if setnull.assign and nullable != "unset":
+            t.c.y.nullable = nullable
+
+        if (
+            connection.engine.dialect.name == "mariadb"
+            and setnull.constructor
+            and nullable in (False, None)
         ):
+            # if nullable was passed explicitly as False or None,
+            # then we dont touch .nullable
             assert_raises(
                 exc.ProgrammingError,
                 connection.execute,
                 schema.CreateTable(t),
             )
-            # If assertion happens table won't be created so
-            #  return from test
-            return
-        # Create and then drop table
-        connection.execute(schema.CreateTable(t))
-        connection.execute(schema.DropTable(t))
+        else:
+            # assert table can be created
+            connection.execute(schema.CreateTable(t))
 
 
 class LimitORMTest(fixtures.MappedTest):
index f3210ad61522ce829a1eb6f347dfe1d7def5cced..e42ad32ff97a4d7bd3e4367295f6d5f0d47e9222 100644 (file)
@@ -872,7 +872,11 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
             else:
                 default = er["default"]
 
-            if default is not None and connection.dialect._is_mariadb_102:
+            if (
+                default is not None
+                and connection.dialect.is_mariadb
+                and connection.dialect.server_version_info > (10, 2)
+            ):
                 default = default.replace(
                     "CURRENT_TIMESTAMP", "current_timestamp()"
                 )
index 284370c4ac76d632463eaca1f28b1f654dae514d..9ce8e3b33c6182f6ba7b12123c3aab50cee960a2 100644 (file)
@@ -475,26 +475,6 @@ class TypeCompileTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(type_, sql_text)
 
 
-class MariaDBUUIDTest(fixtures.TestBase, AssertsCompiledSQL):
-    __only_on__ = "mysql", "mariadb"
-    __backend__ = True
-
-    def test_requirements(self):
-        if testing.against("mariadb>=10.7"):
-            assert testing.requires.uuid_data_type.enabled
-        else:
-            assert not testing.requires.uuid_data_type.enabled
-
-    def test_compile_generic(self):
-        if testing.against("mariadb>=10.7"):
-            self.assert_compile(sqltypes.Uuid(), "UUID")
-        else:
-            self.assert_compile(sqltypes.Uuid(), "CHAR(32)")
-
-    def test_compile_upper(self):
-        self.assert_compile(sqltypes.UUID(), "UUID")
-
-
 class UUIDTest(fixtures.TestBase, AssertsCompiledSQL):
     @testing.combinations(
         (sqltypes.Uuid(), (10, 6, 5), "CHAR(32)"),
@@ -503,11 +483,18 @@ class UUIDTest(fixtures.TestBase, AssertsCompiledSQL):
         (sqltypes.Uuid(native_uuid=False), (10, 7, 0), "CHAR(32)"),
         (sqltypes.UUID(), (10, 6, 5), "UUID"),
         (sqltypes.UUID(), (10, 7, 0), "UUID"),
+        argnames="type_, version, res",
     )
-    def test_mariadb_uuid_combinations(self, type_, version, res):
-        dialect = mariadb.MariaDBDialect()
+    @testing.variation("use_mariadb", [True, False])
+    def test_mariadb_uuid_combinations(self, type_, version, res, use_mariadb):
+        if use_mariadb:
+            dialect = mariadb.MariaDBDialect()
+        else:
+            dialect = mysql.MySQLDialect(is_mariadb=True)
+            dialect._set_mariadb(True, "10.2.0")
+
         dialect.server_version_info = version
-        dialect.supports_native_uuid = version >= (10, 7)
+        dialect._initialize_mariadb(None)
         self.assert_compile(type_, res, dialect=dialect)
 
     @testing.combinations(
@@ -519,15 +506,67 @@ class UUIDTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(type_, "CHAR(32)", dialect=dialect)
 
 
-class INETMariadbTest(fixtures.TestBase, AssertsCompiledSQL):
-    __dialect__ = mariadb.MariaDBDialect()
+class INETTest(fixtures.TestBase, AssertsCompiledSQL):
+    __only_on__ = ("mariadb",)
+    __backend__ = True
 
     @testing.combinations(
         (mariadb.INET4(), "INET4"),
         (mariadb.INET6(), "INET6"),
+        argnames="type_, res",
     )
-    def test_mariadb_inet6(self, type_, res):
-        self.assert_compile(type_, res)
+    @testing.variation("use_mariadb", [True, False])
+    def test_mariadb_inet6(self, type_, res, use_mariadb):
+        self.assert_compile(
+            type_, res, dialect="mariadb" if use_mariadb else "mysql"
+        )
+
+    @testing.fixture
+    def inet_table(self, metadata):
+        inet_table = Table(
+            "inet_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("ip4", mariadb.INET4()),
+            Column("ip6", mariadb.INET6()),
+        )
+        return inet_table
+
+    @testing.fixture
+    def mysql_url_mariadb(self, testing_engine):
+        url = testing.db.url
+
+        url = url.set(drivername=f"mysql+{url.get_driver_name()}")
+
+        return testing_engine(url)
+
+    def test_inet(self, inet_table):
+        self._roundtrip(inet_table, testing.db)
+
+    def test_inet_w_mysql_url(self, mysql_url_mariadb, inet_table):
+        db = mysql_url_mariadb
+
+        self._roundtrip(inet_table, db)
+
+    def _roundtrip(self, inet_table, db):
+
+        with db.begin() as conn:
+            inet_table.create(conn)
+
+            conn.execute(
+                inet_table.insert(),
+                {
+                    "ip4": "192.168.1.1",
+                    "ip6": "2001:db8:85a3:0:0:8a2e:370:7334",
+                },
+            )
+
+            eq_(
+                conn.execute(
+                    select(inet_table.c.ip4, inet_table.c.ip6)
+                ).first(),
+                ("192.168.1.1", "2001:db8:85a3::8a2e:370:7334"),
+            )
 
 
 class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults):
index cdc5e4f869e0317e3d5ff8a5f5ecd776170f325b..164c407d5e9ae7ca3aaace5552b7e428c294fa24 100644 (file)
@@ -796,8 +796,7 @@ class DefaultRequirements(SuiteRequirements):
                 and (
                     (
                         config.db.dialect._is_mariadb
-                        and config.db.dialect._mariadb_normalized_version_info
-                        >= (10, 2)
+                        and config.db.dialect.server_version_info >= (10, 2)
                     )
                 ),
                 "mariadb>10.2",
@@ -1238,10 +1237,7 @@ class DefaultRequirements(SuiteRequirements):
                         not config.db.dialect._is_mariadb
                         and against(config, "mysql >= 5.7")
                     )
-                    or (
-                        config.db.dialect._mariadb_normalized_version_info
-                        >= (10, 2, 7)
-                    )
+                    or (config.db.dialect.server_version_info >= (10, 2, 7))
                 ),
                 "mariadb>=10.2.7",
                 "postgresql >= 9.3",
@@ -1915,14 +1911,14 @@ class DefaultRequirements(SuiteRequirements):
         return (
             against(config, ["mysql", "mariadb"])
             and config.db.dialect._is_mariadb
-            and config.db.dialect._mariadb_normalized_version_info >= (10, 2)
+            and config.db.dialect.server_version_info >= (10, 2)
         )
 
     def _mariadb_105(self, config):
         return (
             against(config, ["mysql", "mariadb"])
             and config.db.dialect._is_mariadb
-            and config.db.dialect._mariadb_normalized_version_info >= (10, 5)
+            and config.db.dialect.server_version_info >= (10, 5)
         )
 
     def _mysql_and_check_constraints_exist(self, config):
@@ -1930,9 +1926,7 @@ class DefaultRequirements(SuiteRequirements):
         # 2. it enforces check constraints
         if exclusions.against(config, ["mysql", "mariadb"]):
             if config.db.dialect._is_mariadb:
-                norm_version_info = (
-                    config.db.dialect._mariadb_normalized_version_info
-                )
+                norm_version_info = config.db.dialect.server_version_info
                 return norm_version_info >= (10, 2)
             else:
                 norm_version_info = config.db.dialect.server_version_info