--- /dev/null
+.. change::
+ :tags: bug, mariadb
+ :tickets: 13076
+
+ Fixes to the MySQL/MariaDB dialect so that mariadb-specific features such
+ as the :class:`.mariadb.INET4` and :class:`.mariadb.INET6` datatype may be
+ used with an :class:`.Engine` that uses a ``mysql://`` URL, if the backend
+ database is actually a mariadb database. Previously, support for MariaDB
+ features when ``mysql://`` URLs were used instead of ``mariadb://`` URLs
+ was ad-hoc; with this issue resolution, the full set of schema / compiler /
+ type features are now available regardless of how the URL was presented.
--- /dev/null
+# dialects/mysql/_mariadb_shim.py
+# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import annotations
+
+from typing import Any
+from typing import cast
+from typing import Optional
+from typing import Type
+from typing import TYPE_CHECKING
+
+from .reserved_words import RESERVED_WORDS_MARIADB
+from ... import exc
+from ... import schema as sa_schema
+from ... import util
+from ...engine import cursor as _cursor
+from ...engine import default
+from ...engine.default import DefaultDialect
+from ...engine.interfaces import TypeCompiler
+from ...sql import elements
+from ...sql import sqltypes
+from ...sql.compiler import DDLCompiler
+from ...sql.compiler import IdentifierPreparer
+from ...sql.compiler import SQLCompiler
+from ...sql.schema import SchemaConst
+from ...sql.sqltypes import _UUID_RETURN
+from ...sql.sqltypes import UUID
+from ...sql.sqltypes import Uuid
+
+if TYPE_CHECKING:
+ from .base import MySQLIdentifierPreparer
+ from .mariadb import INET4
+ from .mariadb import INET6
+ from ...engine import URL
+ from ...engine.base import Connection
+ from ...sql import ddl
+ from ...sql.schema import IdentityOptions
+ from ...sql.schema import Sequence as Sequence_SchemaItem
+ from ...sql.type_api import _BindProcessorType
+
+
+class _MariaDBUUID(UUID[_UUID_RETURN]):
+ def __init__(self, as_uuid: bool = True, native_uuid: bool = True):
+ self.as_uuid = as_uuid
+
+ # the _MariaDBUUID internal type is only invoked for a Uuid() with
+ # native_uuid=True. for non-native uuid type, the plain Uuid
+ # returns itself due to the workings of the Emulated superclass.
+ assert native_uuid
+
+ # for internal type, force string conversion for result_processor() as
+ # current drivers are returning a string, not a Python UUID object
+ self.native_uuid = False
+
+ @property
+ def native(self) -> bool: # type: ignore[override]
+ # override to return True, this is a native type, just turning
+ # off native_uuid for internal data handling
+ return True
+
+ def bind_processor(self, dialect: MariaDBShim) -> Optional[_BindProcessorType[_UUID_RETURN]]: # type: ignore[override] # noqa: E501
+ if not dialect.supports_native_uuid or not dialect._allows_uuid_binds:
+ return super().bind_processor(dialect) # type: ignore[return-value] # noqa: E501
+ else:
+ return None
+
+
+class MariaDBTypeCompilerShim(TypeCompiler):
+ def visit_INET4(self, type_: INET4, **kwargs: Any) -> str:
+ return "INET4"
+
+ def visit_INET6(self, type_: INET6, **kwargs: Any) -> str:
+ return "INET6"
+
+
+class MariadbExecutionContextShim(default.DefaultExecutionContext):
+ def post_exec(self) -> None:
+ if (
+ self.isdelete
+ and cast(SQLCompiler, self.compiled).effective_returning
+ and not self.cursor.description
+ ):
+ # All MySQL/mariadb drivers appear to not include
+ # cursor.description for DELETE..RETURNING with no rows if the
+ # WHERE criteria is a straight "false" condition such as our EMPTY
+ # IN condition. manufacture an empty result in this case (issue
+ # #10505)
+ #
+ # taken from cx_Oracle implementation
+ self.cursor_fetch_strategy = (
+ _cursor.FullyBufferedCursorFetchStrategy(
+ self.cursor,
+ [
+ (entry.keyname, None) # type: ignore[misc]
+ for entry in cast(
+ SQLCompiler, self.compiled
+ )._result_columns
+ ],
+ [],
+ )
+ )
+
+ def fire_sequence(
+ self, seq: Sequence_SchemaItem, type_: sqltypes.Integer
+ ) -> int:
+ return self._execute_scalar( # type: ignore[no-any-return]
+ (
+ "select nextval(%s)"
+ % self.identifier_preparer.format_sequence(seq)
+ ),
+ type_,
+ )
+
+
+class MariaDBIdentifierPreparerShim(IdentifierPreparer):
+ def _set_mariadb(self) -> None:
+ self.reserved_words = RESERVED_WORDS_MARIADB
+
+
+class MariaDBSQLCompilerShim(SQLCompiler):
+ def visit_sequence(self, sequence: sa_schema.Sequence, **kw: Any) -> str:
+ return "nextval(%s)" % self.preparer.format_sequence(sequence)
+
+ def _mariadb_regexp_flags(
+ self, flags: str, pattern: elements.ColumnElement[Any], **kw: Any
+ ) -> str:
+ return "CONCAT('(?', %s, ')', %s)" % (
+ self.render_literal_value(flags, sqltypes.STRINGTYPE),
+ self.process(pattern, **kw),
+ )
+
+ def _mariadb_regexp_match(
+ self,
+ op_string: str,
+ binary: elements.BinaryExpression[Any],
+ operator: Any,
+ **kw: Any,
+ ) -> str:
+ flags = binary.modifiers["flags"]
+ return "%s%s%s" % (
+ self.process(binary.left, **kw),
+ op_string,
+ self._mariadb_regexp_flags(flags, binary.right),
+ )
+
+ def _mariadb_regexp_replace_op_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
+ flags = binary.modifiers["flags"]
+ return "REGEXP_REPLACE(%s, %s, %s)" % (
+ self.process(binary.left, **kw),
+ self._mariadb_regexp_flags(flags, binary.right.clauses[0]),
+ self.process(binary.right.clauses[1], **kw),
+ )
+
+ def _mariadb_visit_drop_check_constraint(
+ self, drop: ddl.DropConstraint, **kw: Any
+ ) -> str:
+ constraint = drop.element
+ qual = "CONSTRAINT "
+ const = self.preparer.format_constraint(constraint)
+ return "ALTER TABLE %s DROP %s%s" % (
+ self.preparer.format_table(constraint.table),
+ qual,
+ const,
+ )
+
+
+class MariaDBDDLCompilerShim(DDLCompiler):
+ dialect: MariaDBShim
+
+ def _mariadb_get_column_specification(
+ self, column: sa_schema.Column[Any], **kw: Any
+ ) -> str:
+
+ if (
+ column.computed is not None
+ and column._user_defined_nullable is SchemaConst.NULL_UNSPECIFIED
+ ):
+ kw["_force_column_to_nullable"] = True
+
+ return self._mysql_get_column_specification(column, **kw)
+
+ def _mysql_get_column_specification(
+ self,
+ column: sa_schema.Column[Any],
+ *,
+ _force_column_to_nullable: bool = False,
+ **kw: Any,
+ ) -> str:
+ raise NotImplementedError()
+
+ def get_identity_options(self, identity_options: IdentityOptions) -> str:
+ text = super().get_identity_options(identity_options)
+ text = text.replace("NO CYCLE", "NOCYCLE")
+ return text
+
+ def _mariadb_visit_drop_check_constraint(
+ self, drop: ddl.DropConstraint, **kw: Any
+ ) -> str:
+ constraint = drop.element
+ qual = "CONSTRAINT "
+ const = self.preparer.format_constraint(constraint)
+ return "ALTER TABLE %s DROP %s%s" % (
+ self.preparer.format_table(constraint.table),
+ qual,
+ const,
+ )
+
+
+class MariaDBShim(DefaultDialect):
+ server_version_info: tuple[int, ...]
+ is_mariadb: bool
+ _allows_uuid_binds = False
+
+ identifier_preparer: MySQLIdentifierPreparer
+ preparer: Type[MySQLIdentifierPreparer]
+
+ def _set_mariadb(
+ self, is_mariadb: Optional[bool], server_version_info: tuple[int, ...]
+ ) -> None:
+ if is_mariadb is None:
+ return
+
+ if not is_mariadb and self.is_mariadb:
+ raise exc.InvalidRequestError(
+ "MySQL version %s is not a MariaDB variant."
+ % (".".join(map(str, server_version_info)),)
+ )
+ if is_mariadb:
+ assert isinstance(self.colspecs, dict)
+ self.colspecs = util.update_copy(
+ self.colspecs, {Uuid: _MariaDBUUID}
+ )
+
+ self.identifier_preparer = self.preparer(self)
+ self.identifier_preparer._set_mariadb()
+
+ # this will be updated on first connect in initialize()
+ # if using older mariadb version
+ self.delete_returning = True
+ self.insert_returning = True
+
+ self.is_mariadb = is_mariadb
+
+ @property
+ def _mariadb_normalized_version_info(self) -> tuple[int, ...]:
+ return self.server_version_info
+
+ @property
+ def _is_mariadb(self) -> bool:
+ return self.is_mariadb
+
+ @classmethod
+ def _is_mariadb_from_url(cls, url: URL) -> bool:
+ dbapi = cls.import_dbapi()
+ dialect = cls(dbapi=dbapi)
+
+ cargs, cparams = dialect.create_connect_args(url)
+ conn = dialect.connect(*cargs, **cparams)
+ try:
+ cursor = conn.cursor()
+ cursor.execute("SELECT VERSION() LIKE '%MariaDB%'")
+ val = cursor.fetchone()[0] # type: ignore[index]
+ except:
+ raise
+ else:
+ return bool(val)
+ finally:
+ conn.close()
+
+ def _initialize_mariadb(self, connection: Connection) -> None:
+ assert self.is_mariadb
+
+ self.supports_sequences = self.server_version_info >= (10, 3)
+
+ self.delete_returning = self.server_version_info >= (10, 0, 5)
+
+ self.insert_returning = self.server_version_info >= (10, 5)
+
+ self._warn_for_known_db_issues()
+
+ self.supports_native_uuid = (
+ self.server_version_info is not None
+ and self.server_version_info >= (10, 7)
+ )
+ self._allows_uuid_binds = True
+
+ # ref https://mariadb.com/kb/en/mariadb-1021-release-notes/
+ self._support_default_function = self.server_version_info >= (10, 2, 1)
+
+ # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/
+ self._support_float_cast = self.server_version_info >= (10, 4, 5)
+
+ def _warn_for_known_db_issues(self) -> None:
+ if self.is_mariadb:
+ mdb_version = self.server_version_info
+ assert mdb_version is not None
+ if mdb_version > (10, 2) and mdb_version < (10, 2, 9):
+ util.warn(
+ "MariaDB %r before 10.2.9 has known issues regarding "
+ "CHECK constraints, which impact handling of NULL values "
+ "with SQLAlchemy's boolean datatype (MDEV-13596). An "
+ "additional issue prevents proper migrations of columns "
+ "with CHECK constraints (MDEV-11114). Please upgrade to "
+ "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 "
+ "series, to avoid these issues." % (mdb_version,)
+ )
from typing import overload
from typing import Sequence
from typing import TYPE_CHECKING
+from typing import TypeVar
from typing import Union
+from . import _mariadb_shim
from . import reflection as _reflection
from .enumerated import ENUM
from .enumerated import SET
from .json import JSON
from .json import JSONIndexType
from .json import JSONPathType
-from .reserved_words import RESERVED_WORDS_MARIADB
from .reserved_words import RESERVED_WORDS_MYSQL
from .types import _FloatType
from .types import _IntegerType
from ... import schema as sa_schema
from ... import sql
from ... import util
-from ...engine import cursor as _cursor
from ...engine import default
from ...engine import reflection
from ...engine.reflection import ReflectionDefaults
from ...sql import util as sql_util
from ...sql import visitors
from ...sql.compiler import InsertmanyvaluesSentinelOpts
-from ...sql.compiler import SQLCompiler
-from ...sql.schema import SchemaConst
from ...types import BINARY
from ...types import BLOB
from ...types import BOOLEAN
from ...engine.interfaces import ReflectedUniqueConstraint
from ...engine.result import _Ts
from ...engine.row import Row
- from ...engine.url import URL
from ...schema import Table
from ...sql import ddl
from ...sql import selectable
from ...sql.functions import random
from ...sql.functions import rollup
from ...sql.functions import sysdate
- from ...sql.schema import IdentityOptions
- from ...sql.schema import Sequence as Sequence_SchemaItem
from ...sql.type_api import TypeEngine
from ...sql.visitors import ExternallyTraversible
from ...util.typing import TupleAny
from ...util.typing import Unpack
-
+_T = TypeVar("_T", bound=Any)
SET_RE = re.compile(
r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE
)
}
-class MySQLExecutionContext(default.DefaultExecutionContext):
- def post_exec(self) -> None:
- if (
- self.isdelete
- and cast(SQLCompiler, self.compiled).effective_returning
- and not self.cursor.description
- ):
- # All MySQL/mariadb drivers appear to not include
- # cursor.description for DELETE..RETURNING with no rows if the
- # WHERE criteria is a straight "false" condition such as our EMPTY
- # IN condition. manufacture an empty result in this case (issue
- # #10505)
- #
- # taken from cx_Oracle implementation
- self.cursor_fetch_strategy = (
- _cursor.FullyBufferedCursorFetchStrategy(
- self.cursor,
- [
- (entry.keyname, None) # type: ignore[misc]
- for entry in cast(
- SQLCompiler, self.compiled
- )._result_columns
- ],
- [],
- )
- )
-
+class MySQLExecutionContext(
+ _mariadb_shim.MariadbExecutionContextShim, default.DefaultExecutionContext
+):
def create_server_side_cursor(self) -> DBAPICursor:
if self.dialect.supports_server_side_cursors:
return self._dbapi_connection.cursor(
else:
raise NotImplementedError()
- def fire_sequence(
- self, seq: Sequence_SchemaItem, type_: sqltypes.Integer
- ) -> int:
- return self._execute_scalar( # type: ignore[no-any-return]
- (
- "select nextval(%s)"
- % self.identifier_preparer.format_sequence(seq)
- ),
- type_,
- )
-
-class MySQLCompiler(compiler.SQLCompiler):
+class MySQLCompiler(
+ _mariadb_shim.MariaDBSQLCompilerShim, compiler.SQLCompiler
+):
dialect: MySQLDialect
render_table_with_column_in_update_from = True
"""Overridden from base SQLCompiler value"""
return (
f"group_concat({expr._compiler_dispatch(self, **kw)} "
f"ORDER BY {order_by._compiler_dispatch(self, **kw)} "
- f"SEPARATOR "
+ "SEPARATOR "
f"{delimiter._compiler_dispatch(self, **literal_exec)})"
)
else:
return (
f"group_concat({expr._compiler_dispatch(self, **kw)} "
- f"SEPARATOR "
+ "SEPARATOR "
f"{delimiter._compiler_dispatch(self, **literal_exec)})"
)
- def visit_sequence(self, sequence: sa_schema.Sequence, **kw: Any) -> str:
- return "nextval(%s)" % self.preparer.format_sequence(sequence)
-
def visit_sysdate_func(self, fn: sysdate, **kw: Any) -> str:
return "SYSDATE()"
"any column keys in table '%s': %s"
% (
self.statement.table.name, # type: ignore[union-attr]
- (", ".join("'%s'" % c for c in non_matching)),
+ ", ".join("'%s'" % c for c in non_matching),
)
)
def visit_concat_op_expression_clauselist(
self, clauselist: elements.ClauseList, operator: Any, **kw: Any
) -> str:
- return "concat(%s)" % (
- ", ".join(self.process(elem, **kw) for elem in clauselist.clauses)
+ return "concat(%s)" % ", ".join(
+ self.process(elem, **kw) for elem in clauselist.clauses
)
def visit_concat_op_binary(
self, element_types: list[TypeEngine[Any]], **kw: Any
) -> str:
return (
- "SELECT %(outer)s FROM (SELECT %(inner)s) "
- "as _empty_set WHERE 1!=1"
+ "SELECT %(outer)s FROM (SELECT %(inner)s) as _empty_set WHERE 1!=1"
% {
"inner": ", ".join(
"1 AS _in_%s" % idx
self.process(binary.right),
)
- def _mariadb_regexp_flags(
- self, flags: str, pattern: elements.ColumnElement[Any], **kw: Any
+ def _mysql_regexp_match(
+ self,
+ op_string: str,
+ binary: elements.BinaryExpression[Any],
+ operator: Any,
+ **kw: Any,
) -> str:
- return "CONCAT('(?', %s, ')', %s)" % (
+ flags = binary.modifiers["flags"]
+
+ text = "REGEXP_LIKE(%s, %s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
self.render_literal_value(flags, sqltypes.STRINGTYPE),
- self.process(pattern, **kw),
)
+ if op_string == " NOT REGEXP ":
+ return "NOT %s" % text
+ else:
+ return text
def _regexp_match(
self,
flags = binary.modifiers["flags"]
if flags is None:
return self._generate_generic_binary(binary, op_string, **kw)
- elif self.dialect.is_mariadb:
- return "%s%s%s" % (
- self.process(binary.left, **kw),
- op_string,
- self._mariadb_regexp_flags(flags, binary.right),
- )
else:
- text = "REGEXP_LIKE(%s, %s, %s)" % (
- self.process(binary.left, **kw),
- self.process(binary.right, **kw),
- self.render_literal_value(flags, sqltypes.STRINGTYPE),
+ return self.dialect._dispatch_for_vendor(
+ self._mysql_regexp_match,
+ self._mariadb_regexp_match,
+ op_string,
+ binary,
+ operator,
+ **kw,
)
- if op_string == " NOT REGEXP ":
- return "NOT %s" % text
- else:
- return text
def visit_regexp_match_op_binary(
self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
self.process(binary.left, **kw),
self.process(binary.right, **kw),
)
- elif self.dialect.is_mariadb:
- return "REGEXP_REPLACE(%s, %s, %s)" % (
- self.process(binary.left, **kw),
- self._mariadb_regexp_flags(flags, binary.right.clauses[0]),
- self.process(binary.right.clauses[1], **kw),
- )
else:
- return "REGEXP_REPLACE(%s, %s, %s)" % (
- self.process(binary.left, **kw),
- self.process(binary.right, **kw),
- self.render_literal_value(flags, sqltypes.STRINGTYPE),
+ return self.dialect._dispatch_for_vendor(
+ self._mysql_regexp_replace_op_binary,
+ self._mariadb_regexp_replace_op_binary,
+ binary,
+ operator,
+ **kw,
)
+ def _mysql_regexp_replace_op_binary(
+ self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any
+ ) -> str:
+ flags = binary.modifiers["flags"]
-class MySQLDDLCompiler(compiler.DDLCompiler):
+ return "REGEXP_REPLACE(%s, %s, %s)" % (
+ self.process(binary.left, **kw),
+ self.process(binary.right, **kw),
+ self.render_literal_value(flags, sqltypes.STRINGTYPE),
+ )
+
+
+class MySQLDDLCompiler(
+ _mariadb_shim.MariaDBDDLCompilerShim, compiler.DDLCompiler
+):
dialect: MySQLDialect
def get_column_specification(
self, column: sa_schema.Column[Any], **kw: Any
) -> str:
"""Builds column DDL."""
- if (
- self.dialect.is_mariadb is True
- and column.computed is not None
- and column._user_defined_nullable is SchemaConst.NULL_UNSPECIFIED
- ):
- column.nullable = True
+
+ return self.dialect._dispatch_for_vendor(
+ self._mysql_get_column_specification,
+ self._mariadb_get_column_specification,
+ column,
+ **kw,
+ )
+
+ def _mysql_get_column_specification(
+ self,
+ column: sa_schema.Column[Any],
+ *,
+ _force_column_to_nullable: bool = False,
+ **kw: Any,
+ ) -> str:
+
colspec = [
self.preparer.format_column(column),
self.dialect.type_compiler_instance.process(
sqltypes.TIMESTAMP,
)
- if not column.nullable:
+ if not column.nullable and not _force_column_to_nullable:
colspec.append("NOT NULL")
# see: https://docs.sqlalchemy.org/en/latest/dialects/mysql.html#mysql_timestamp_null # noqa
qual = "INDEX "
const = self.preparer.format_constraint(constraint)
elif isinstance(constraint, sa_schema.CheckConstraint):
- if self.dialect.is_mariadb:
- qual = "CONSTRAINT "
- else:
- qual = "CHECK "
- const = self.preparer.format_constraint(constraint)
+ return self.dialect._dispatch_for_vendor(
+ self._mysql_visit_drop_check_constraint,
+ self._mariadb_visit_drop_check_constraint,
+ drop,
+ **kw,
+ )
else:
qual = ""
const = self.preparer.format_constraint(constraint)
const,
)
+ def _mysql_visit_drop_check_constraint(
+ self, drop: ddl.DropConstraint, **kw: Any
+ ) -> str:
+ constraint = drop.element
+ qual = "CHECK "
+ const = self.preparer.format_constraint(constraint)
+ return "ALTER TABLE %s DROP %s%s" % (
+ self.preparer.format_table(constraint.table),
+ qual,
+ const,
+ )
+
def define_constraint_match(
self, constraint: sa_schema.ForeignKeyConstraint
) -> str:
self.get_column_specification(create.element),
)
- def get_identity_options(self, identity_options: IdentityOptions) -> str:
- """mariadb-specific sequence option; this will move to a
- mariadb-specific module in 2.1
- """
- text = super().get_identity_options(identity_options)
- text = text.replace("NO CYCLE", "NOCYCLE")
- return text
-
-
-class MySQLTypeCompiler(compiler.GenericTypeCompiler):
+class MySQLTypeCompiler(
+ _mariadb_shim.MariaDBTypeCompilerShim, compiler.GenericTypeCompiler
+):
def _extend_numeric(self, type_: _NumericCommonType, spec: str) -> str:
"Extend a numeric-type declaration with MySQL specific extensions."
return "BOOL"
-class MySQLIdentifierPreparer(compiler.IdentifierPreparer):
+class MySQLIdentifierPreparer(
+ _mariadb_shim.MariaDBIdentifierPreparerShim, compiler.IdentifierPreparer
+):
reserved_words = RESERVED_WORDS_MYSQL
def __init__(
return tuple([self.quote_identifier(i) for i in ids if i is not None])
-class MariaDBIdentifierPreparer(MySQLIdentifierPreparer):
- reserved_words = RESERVED_WORDS_MARIADB
-
-
-class MySQLDialect(default.DefaultDialect):
+class MySQLDialect(_mariadb_shim.MariaDBShim, default.DefaultDialect):
"""Details of the MySQL dialect.
Not used directly in application code.
"""
name = "mysql"
+
+ is_mariadb = False
+
supports_statement_cache = True
supports_alter = True
ischema_names = ischema_names
preparer: type[MySQLIdentifierPreparer] = MySQLIdentifierPreparer
- is_mariadb: bool = False
- _mariadb_normalized_version_info = None
-
# default SQL compilation settings -
# these are modified upon initialize(),
# i.e. first connect
_backslash_escapes = True
_server_ansiquotes = False
+ _support_default_function = True
+ _support_float_cast = False
server_version_info: tuple[int, ...]
identifier_preparer: MySQLIdentifierPreparer
val = val.decode()
return val.upper().replace("-", " ") # type: ignore[no-any-return]
- @classmethod
- def _is_mariadb_from_url(cls, url: URL) -> bool:
- dbapi = cls.import_dbapi()
- dialect = cls(dbapi=dbapi)
-
- cargs, cparams = dialect.create_connect_args(url)
- conn = dialect.connect(*cargs, **cparams)
- try:
- cursor = conn.cursor()
- cursor.execute("SELECT VERSION() LIKE '%MariaDB%'")
- val = cursor.fetchone()[0] # type: ignore[index]
- except:
- raise
- else:
- return bool(val)
- finally:
- conn.close()
-
def _get_server_version_info(
self, connection: Connection
) -> tuple[int, ...]:
r = re.compile(r"[.\-+]")
tokens = r.split(val)
+
+ _mariadb_normalized_version_info = None
for token in tokens:
parsed_token = re.match(
r"^(?:(\d+)(?:a|b|c)?|(MariaDB\w*))$", token
if not parsed_token:
continue
elif parsed_token.group(2):
- self._mariadb_normalized_version_info = tuple(version[-3:])
+ _mariadb_normalized_version_info = tuple(version[-3:])
is_mariadb = True
else:
digit = int(parsed_token.group(1))
version.append(digit)
- server_version_info = tuple(version)
+ if _mariadb_normalized_version_info:
+ server_version_info = _mariadb_normalized_version_info
+ else:
+ server_version_info = tuple(version)
self._set_mariadb(
bool(server_version_info and is_mariadb), server_version_info
)
- if not is_mariadb:
- self._mariadb_normalized_version_info = server_version_info
-
if server_version_info < (5, 0, 2):
raise NotImplementedError(
"the MySQL/MariaDB dialect supports server "
self.server_version_info = server_version_info
return server_version_info
- def _set_mariadb(
- self, is_mariadb: Optional[bool], server_version_info: tuple[int, ...]
- ) -> None:
- if is_mariadb is None:
- return
-
- if not is_mariadb and self.is_mariadb:
- raise exc.InvalidRequestError(
- "MySQL version %s is not a MariaDB variant."
- % (".".join(map(str, server_version_info)),)
- )
- if is_mariadb:
-
- if not issubclass(self.preparer, MariaDBIdentifierPreparer):
- self.preparer = MariaDBIdentifierPreparer
- # this would have been set by the default dialect already,
- # so set it again
- self.identifier_preparer = self.preparer(self)
-
- # this will be updated on first connect in initialize()
- # if using older mariadb version
- self.delete_returning = True
- self.insert_returning = True
-
- self.is_mariadb = is_mariadb
-
def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid))
)
]
+ def _dispatch_for_vendor(
+ self,
+ mysql_callable: Callable[..., _T],
+ mariadb_callable: Callable[..., _T],
+ *arg: Any,
+ **kw: Any,
+ ) -> _T:
+ if not self.is_mariadb:
+ return mysql_callable(*arg, **kw)
+ else:
+ return mariadb_callable(*arg, **kw)
+
def initialize(self, connection: Connection) -> None:
# this is driver-based, does not need server version info
# and is fairly critical for even basic SQL operations
self, server_ansiquotes=self._server_ansiquotes
)
- self.supports_sequences = (
- self.is_mariadb and self.server_version_info >= (10, 3)
+ self._dispatch_for_vendor(
+ self._initialize_mysql, self._initialize_mariadb, connection
)
- self.supports_for_update_of = (
- self._is_mysql and self.server_version_info >= (8,)
- )
+ def _initialize_mysql(self, connection: Connection) -> None:
+ assert not self.is_mariadb
- self.use_mysql_for_share = (
- self._is_mysql and self.server_version_info >= (8, 0, 1)
- )
+ self.supports_for_update_of = self.server_version_info >= (8,)
- self._needs_correct_for_88718_96365 = (
- not self.is_mariadb and self.server_version_info >= (8,)
- )
-
- self.delete_returning = (
- self.is_mariadb and self.server_version_info >= (10, 0, 5)
- )
+ self.use_mysql_for_share = self.server_version_info >= (8, 0, 1)
- self.insert_returning = (
- self.is_mariadb and self.server_version_info >= (10, 5)
- )
+ self._needs_correct_for_88718_96365 = self.server_version_info >= (8,)
self._requires_alias_for_on_duplicate_key = (
- self._is_mysql and self.server_version_info >= (8, 0, 20)
+ self.server_version_info >= (8, 0, 20)
)
- self._warn_for_known_db_issues()
+ # ref https://dev.mysql.com/doc/refman/8.0/en/data-type-defaults.html # noqa
+ self._support_default_function = self.server_version_info >= (8, 0, 13)
- def _warn_for_known_db_issues(self) -> None:
- if self.is_mariadb:
- mdb_version = self._mariadb_normalized_version_info
- assert mdb_version is not None
- if mdb_version > (10, 2) and mdb_version < (10, 2, 9):
- util.warn(
- "MariaDB %r before 10.2.9 has known issues regarding "
- "CHECK constraints, which impact handling of NULL values "
- "with SQLAlchemy's boolean datatype (MDEV-13596). An "
- "additional issue prevents proper migrations of columns "
- "with CHECK constraints (MDEV-11114). Please upgrade to "
- "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 "
- "series, to avoid these issues." % (mdb_version,)
- )
-
- @property
- def _support_float_cast(self) -> bool:
- if not self.server_version_info:
- return False
- elif self.is_mariadb:
- # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/
- return self.server_version_info >= (10, 4, 5)
- else:
- # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa
- return self.server_version_info >= (8, 0, 17)
-
- @property
- def _support_default_function(self) -> bool:
- if not self.server_version_info:
- return False
- elif self.is_mariadb:
- # ref https://mariadb.com/kb/en/mariadb-1021-release-notes/
- return self.server_version_info >= (10, 2, 1)
- else:
- # ref https://dev.mysql.com/doc/refman/8.0/en/data-type-defaults.html # noqa
- return self.server_version_info >= (8, 0, 13)
-
- @property
- def _is_mariadb(self) -> bool:
- return self.is_mariadb
+ # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa
+ self._support_float_cast = self.server_version_info >= (8, 0, 17)
@property
def _is_mysql(self) -> bool:
return not self.is_mariadb
- @property
- def _is_mariadb_102(self) -> bool:
- return (
- self.is_mariadb
- and self._mariadb_normalized_version_info # type:ignore[operator]
- > (
- 10,
- 2,
- )
- )
-
@reflection.cache
def get_schema_names(self, connection: Connection, **kw: Any) -> list[str]:
rp = connection.exec_driver_sql("SHOW schemas")
}
fkeys.append(fkey_d)
- if self._needs_correct_for_88718_96365:
+ if self._is_mysql and self._needs_correct_for_88718_96365:
self._correct_for_mysql_bugs_88718_96365(fkeys, connection)
return fkeys if fkeys else ReflectionDefaults.foreign_keys()
elif code == 1356:
raise exc.UnreflectableTableError(
- "Table or view named %s could not be "
- "reflected: %s" % (full_name, e)
+ "Table or view named %s could not be reflected: %s"
+ % (full_name, e)
) from e
else:
from __future__ import annotations
from typing import Any
-from typing import Optional
-from typing import TYPE_CHECKING
-from .base import MariaDBIdentifierPreparer
from .base import MySQLDialect
-from .base import MySQLIdentifierPreparer
-from .base import MySQLTypeCompiler
-from ... import util
from ...sql import sqltypes
-from ...sql.sqltypes import _UUID_RETURN
-from ...sql.sqltypes import UUID
-from ...sql.sqltypes import Uuid
-
-if TYPE_CHECKING:
- from ...engine.base import Connection
- from ...sql.type_api import _BindProcessorType
class INET4(sqltypes.TypeEngine[str]):
__visit_name__ = "INET6"
-class _MariaDBUUID(UUID[_UUID_RETURN]):
- def __init__(self, as_uuid: bool = True, native_uuid: bool = True):
- self.as_uuid = as_uuid
-
- # the _MariaDBUUID internal type is only invoked for a Uuid() with
- # native_uuid=True. for non-native uuid type, the plain Uuid
- # returns itself due to the workings of the Emulated superclass.
- assert native_uuid
-
- # for internal type, force string conversion for result_processor() as
- # current drivers are returning a string, not a Python UUID object
- self.native_uuid = False
-
- @property
- def native(self) -> bool: # type: ignore[override]
- # override to return True, this is a native type, just turning
- # off native_uuid for internal data handling
- return True
-
- def bind_processor(self, dialect: MariaDBDialect) -> Optional[_BindProcessorType[_UUID_RETURN]]: # type: ignore[override] # noqa: E501
- if not dialect.supports_native_uuid or not dialect._allows_uuid_binds:
- return super().bind_processor(dialect) # type: ignore[return-value] # noqa: E501
- else:
- return None
-
-
-class MariaDBTypeCompiler(MySQLTypeCompiler):
- def visit_INET4(self, type_: INET4, **kwargs: Any) -> str:
- return "INET4"
-
- def visit_INET6(self, type_: INET6, **kwargs: Any) -> str:
- return "INET6"
-
-
class MariaDBDialect(MySQLDialect):
is_mariadb = True
supports_statement_cache = True
_allows_uuid_binds = True
name = "mariadb"
- preparer: type[MySQLIdentifierPreparer] = MariaDBIdentifierPreparer
- type_compiler_cls = MariaDBTypeCompiler
- colspecs = util.update_copy(MySQLDialect.colspecs, {Uuid: _MariaDBUUID})
-
- def initialize(self, connection: Connection) -> None:
- super().initialize(connection)
-
- self.supports_native_uuid = (
- self.server_version_info is not None
- and self.server_version_info >= (10, 7)
- )
+ def __init__(self, **kw: Any) -> None:
+ kw["is_mariadb"] = True
+ super().__init__(**kw)
def loader(driver: str) -> type[MariaDBDialect]:
from typing import TYPE_CHECKING
from typing import Union
-from .base import MariaDBIdentifierPreparer
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
pass
-class MariaDBIdentifierPreparer_mysqlconnector(
- IdentifierPreparerCommon_mysqlconnector, MariaDBIdentifierPreparer
-):
- pass
-
-
class _myconnpyBIT(BIT):
def result_processor(self, dialect: Any, coltype: Any) -> None:
"""MySQL-connector already converts mysql bits, so."""
):
supports_statement_cache = True
_allows_uuid_binds = False
- preparer = MariaDBIdentifierPreparer_mysqlconnector
+ preparer = MySQLIdentifierPreparer_mysqlconnector
dialect = MySQLDialect_mysqlconnector
from ..util.typing import Unpack
if typing.TYPE_CHECKING:
- from types import ModuleType
-
from .base import Engine
from .cursor import ResultFetchStrategy
from .interfaces import _CoreMultiExecuteParams
self,
paramstyle: Optional[_ParamStyle] = None,
isolation_level: Optional[IsolationLevel] = None,
- dbapi: Optional[ModuleType] = None,
+ dbapi: Optional[DBAPIModule] = None,
implicit_returning: Literal[True] = True,
supports_native_boolean: Optional[bool] = None,
max_identifier_length: Optional[int] = None,
" CASCADE" if drop.cascade else "",
)
- def get_column_specification(self, column, **kwargs):
+ def get_column_specification(
+ self, column: Column[Any], **kwargs: Any
+ ) -> str:
colspec = (
self.preparer.format_column(column)
+ " "
colspec += " NOT NULL"
return colspec
- def create_table_suffix(self, table):
+ def create_table_suffix(self, table: Table) -> str:
return ""
- def post_create_table(self, table):
+ def post_create_table(self, table: Table) -> str:
return ""
def get_column_default_string(self, column: Column[Any]) -> Optional[str]:
):
dialect = mysql.dialect(is_mariadb=is_mariadb)
dialect.server_version_info = version
+ if is_mariadb:
+ with testing.expect_warnings(".*"):
+ dialect._initialize_mariadb(None)
+ else:
+ dialect._initialize_mysql(None)
m = MetaData()
tbl = Table(
if maria_db:
dialect.is_mariadb = maria_db
dialect.server_version_info = (10, 4, 5)
+ dialect._initialize_mariadb(None)
else:
dialect.server_version_info = (8, 0, 17)
+ dialect._initialize_mysql(None)
t = sql.table("t", sql.column("col"))
self.assert_compile(cast(t.c.col, type_), expected, dialect=dialect)
)
@testing.combinations(
- ((10, 2, 7), "10.2.7-MariaDB", (10, 2, 7), True),
+ ((10, 2, 7), "10.2.7-MariaDB", True),
(
(10, 2, 7),
"5.6.15.10.2.7-MariaDB",
- (5, 6, 15, 10, 2, 7),
True,
),
- ((5, 0, 51, 24), "5.0.51a.24+lenny5", (5, 0, 51, 24), False),
- ((10, 2, 10), "10.2.10-MariaDB", (10, 2, 10), True),
- ((5, 7, 20), "5.7.20", (5, 7, 20), False),
- ((5, 6, 15), "5.6.15", (5, 6, 15), False),
+ ((5, 0, 51, 24), "5.0.51a.24+lenny5", False),
+ ((10, 2, 10), "10.2.10-MariaDB", True),
+ ((5, 7, 20), "5.7.20", False),
+ ((5, 6, 15), "5.6.15", False),
(
(10, 2, 6),
"10.2.6.MariaDB.10.2.6+maria~stretch-log",
- (10, 2, 6, 10, 2, 6),
True,
),
(
(10, 1, 9),
"10.1.9-MariaDBV1.0R050D002-20170809-1522",
- (10, 1, 9, 20170809, 1522),
True,
),
)
def test_mariadb_normalized_version(
- self, expected, raw_version, version, is_mariadb
+ self, expected, raw_version, is_mariadb
):
dialect = mysql.dialect()
- eq_(dialect._parse_server_version(raw_version), version)
- dialect.server_version_info = version
- eq_(dialect._mariadb_normalized_version_info, expected)
+ eq_(dialect._parse_server_version(raw_version), expected)
+ eq_(dialect.server_version_info, expected)
assert dialect._is_mariadb is is_mariadb
@testing.combinations(
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import literal_column
-from sqlalchemy import MetaData
from sqlalchemy import or_
from sqlalchemy import schema
from sqlalchemy import select
("unset"),
argnames="nullable",
)
- def test_column_computed_for_nullable(self, connection, nullable):
+ @testing.variation("setnull", ["constructor", "assign"])
+ def test_column_computed_for_nullable(
+ self, metadata, connection, nullable, setnull: testing.Variation
+ ):
"""test #10056
we want to make sure that nullable is always set to True for computed
ref: https://mariadb.com/kb/en/generated-columns/#statement-support
"""
- m = MetaData()
- kwargs = {"nullable": nullable} if nullable != "unset" else {}
+
+ if setnull.assign:
+ kwargs = {}
+ elif setnull.constructor:
+ kwargs = {"nullable": nullable} if nullable != "unset" else {}
+ else:
+ setnull.fail()
+
t = Table(
"t",
- m,
+ metadata,
Column("x", Integer),
Column("y", Integer, Computed("x + 2"), **kwargs),
)
- if connection.engine.dialect.name == "mariadb" and nullable in (
- False,
- None,
+
+ # the "assign" path here is to exercise the actual bug shown at
+ # #10056; user defined nullable is not passed, but MappedColumn
+ # sets it after the fact. Previously this path was not exercised
+ if setnull.assign and nullable != "unset":
+ t.c.y.nullable = nullable
+
+ if (
+ connection.engine.dialect.name == "mariadb"
+ and setnull.constructor
+ and nullable in (False, None)
):
+ # if nullable was passed explicitly as False or None,
+ # then we dont touch .nullable
assert_raises(
exc.ProgrammingError,
connection.execute,
schema.CreateTable(t),
)
- # If assertion happens table won't be created so
- # return from test
- return
- # Create and then drop table
- connection.execute(schema.CreateTable(t))
- connection.execute(schema.DropTable(t))
+ else:
+ # assert table can be created
+ connection.execute(schema.CreateTable(t))
class LimitORMTest(fixtures.MappedTest):
else:
default = er["default"]
- if default is not None and connection.dialect._is_mariadb_102:
+ if (
+ default is not None
+ and connection.dialect.is_mariadb
+ and connection.dialect.server_version_info > (10, 2)
+ ):
default = default.replace(
"CURRENT_TIMESTAMP", "current_timestamp()"
)
self.assert_compile(type_, sql_text)
-class MariaDBUUIDTest(fixtures.TestBase, AssertsCompiledSQL):
- __only_on__ = "mysql", "mariadb"
- __backend__ = True
-
- def test_requirements(self):
- if testing.against("mariadb>=10.7"):
- assert testing.requires.uuid_data_type.enabled
- else:
- assert not testing.requires.uuid_data_type.enabled
-
- def test_compile_generic(self):
- if testing.against("mariadb>=10.7"):
- self.assert_compile(sqltypes.Uuid(), "UUID")
- else:
- self.assert_compile(sqltypes.Uuid(), "CHAR(32)")
-
- def test_compile_upper(self):
- self.assert_compile(sqltypes.UUID(), "UUID")
-
-
class UUIDTest(fixtures.TestBase, AssertsCompiledSQL):
@testing.combinations(
(sqltypes.Uuid(), (10, 6, 5), "CHAR(32)"),
(sqltypes.Uuid(native_uuid=False), (10, 7, 0), "CHAR(32)"),
(sqltypes.UUID(), (10, 6, 5), "UUID"),
(sqltypes.UUID(), (10, 7, 0), "UUID"),
+ argnames="type_, version, res",
)
- def test_mariadb_uuid_combinations(self, type_, version, res):
- dialect = mariadb.MariaDBDialect()
+ @testing.variation("use_mariadb", [True, False])
+ def test_mariadb_uuid_combinations(self, type_, version, res, use_mariadb):
+ if use_mariadb:
+ dialect = mariadb.MariaDBDialect()
+ else:
+ dialect = mysql.MySQLDialect(is_mariadb=True)
+ dialect._set_mariadb(True, "10.2.0")
+
dialect.server_version_info = version
- dialect.supports_native_uuid = version >= (10, 7)
+ dialect._initialize_mariadb(None)
self.assert_compile(type_, res, dialect=dialect)
@testing.combinations(
self.assert_compile(type_, "CHAR(32)", dialect=dialect)
-class INETMariadbTest(fixtures.TestBase, AssertsCompiledSQL):
- __dialect__ = mariadb.MariaDBDialect()
+class INETTest(fixtures.TestBase, AssertsCompiledSQL):
+ __only_on__ = ("mariadb",)
+ __backend__ = True
@testing.combinations(
(mariadb.INET4(), "INET4"),
(mariadb.INET6(), "INET6"),
+ argnames="type_, res",
)
- def test_mariadb_inet6(self, type_, res):
- self.assert_compile(type_, res)
+ @testing.variation("use_mariadb", [True, False])
+ def test_mariadb_inet6(self, type_, res, use_mariadb):
+ self.assert_compile(
+ type_, res, dialect="mariadb" if use_mariadb else "mysql"
+ )
+
+ @testing.fixture
+ def inet_table(self, metadata):
+ inet_table = Table(
+ "inet_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("ip4", mariadb.INET4()),
+ Column("ip6", mariadb.INET6()),
+ )
+ return inet_table
+
+ @testing.fixture
+ def mysql_url_mariadb(self, testing_engine):
+ url = testing.db.url
+
+ url = url.set(drivername=f"mysql+{url.get_driver_name()}")
+
+ return testing_engine(url)
+
+ def test_inet(self, inet_table):
+ self._roundtrip(inet_table, testing.db)
+
+ def test_inet_w_mysql_url(self, mysql_url_mariadb, inet_table):
+ db = mysql_url_mariadb
+
+ self._roundtrip(inet_table, db)
+
+ def _roundtrip(self, inet_table, db):
+
+ with db.begin() as conn:
+ inet_table.create(conn)
+
+ conn.execute(
+ inet_table.insert(),
+ {
+ "ip4": "192.168.1.1",
+ "ip6": "2001:db8:85a3:0:0:8a2e:370:7334",
+ },
+ )
+
+ eq_(
+ conn.execute(
+ select(inet_table.c.ip4, inet_table.c.ip6)
+ ).first(),
+ ("192.168.1.1", "2001:db8:85a3::8a2e:370:7334"),
+ )
class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults):
and (
(
config.db.dialect._is_mariadb
- and config.db.dialect._mariadb_normalized_version_info
- >= (10, 2)
+ and config.db.dialect.server_version_info >= (10, 2)
)
),
"mariadb>10.2",
not config.db.dialect._is_mariadb
and against(config, "mysql >= 5.7")
)
- or (
- config.db.dialect._mariadb_normalized_version_info
- >= (10, 2, 7)
- )
+ or (config.db.dialect.server_version_info >= (10, 2, 7))
),
"mariadb>=10.2.7",
"postgresql >= 9.3",
return (
against(config, ["mysql", "mariadb"])
and config.db.dialect._is_mariadb
- and config.db.dialect._mariadb_normalized_version_info >= (10, 2)
+ and config.db.dialect.server_version_info >= (10, 2)
)
def _mariadb_105(self, config):
return (
against(config, ["mysql", "mariadb"])
and config.db.dialect._is_mariadb
- and config.db.dialect._mariadb_normalized_version_info >= (10, 5)
+ and config.db.dialect.server_version_info >= (10, 5)
)
def _mysql_and_check_constraints_exist(self, config):
# 2. it enforces check constraints
if exclusions.against(config, ["mysql", "mariadb"]):
if config.db.dialect._is_mariadb:
- norm_version_info = (
- config.db.dialect._mariadb_normalized_version_info
- )
+ norm_version_info = config.db.dialect.server_version_info
return norm_version_info >= (10, 2)
else:
norm_version_info = config.db.dialect.server_version_info