From: Mike Bayer Date: Wed, 14 Jan 2026 16:06:50 +0000 (-0500) Subject: separate out mariadb/mysql implementations but remain monolithic X-Git-Tag: rel_2_1_0b1~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=80ced341154d02e0ff56180420905164abd3fba0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git separate out mariadb/mysql implementations but remain monolithic Fixes to the MySQL/MariaDB dialect so that mariadb-specific features such as the :class:`.mariadb.INET4` and :class:`.mariadb.INET6` datatype may be used with an :class:`.Engine` that uses a ``mysql://`` URL, if the backend database is actually a mariadb database. Previously, support for MariaDB features when ``mysql://`` URLs were used instead of ``mariadb://`` URLs was ad-hoc; with this issue resolution, the full set of schema / compiler / type features are now available regardless of how the URL was presented. After much discussion it seems premature to formally separate the mysql/mariadb dialects to be mutually exclusive, however we'd like to standardize on how dialect-exclusive behaviors are architected. For now, use an approach of MariaDB "shim" classes which provide MariaDB behaviors into all relevant MySQL classes up front. Where behaviors are mutually exclusive, support `_set_mariadb()` methods that enable the mariadb version of things. this approach may or may not have resilience against future divergences in mysql/mariadb , but at least starts to separate the source code for the two databases and hopefully provides a clearer separation of concerns. Fixes: #13076 Change-Id: I42cffc0563a2c65c38e854e9cc2181353b230c44 --- diff --git a/doc/build/changelog/unreleased_21/13076.rst b/doc/build/changelog/unreleased_21/13076.rst new file mode 100644 index 0000000000..9a06919fbd --- /dev/null +++ b/doc/build/changelog/unreleased_21/13076.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, mariadb + :tickets: 13076 + + Fixes to the MySQL/MariaDB dialect so that mariadb-specific features such + as the :class:`.mariadb.INET4` and :class:`.mariadb.INET6` datatype may be + used with an :class:`.Engine` that uses a ``mysql://`` URL, if the backend + database is actually a mariadb database. Previously, support for MariaDB + features when ``mysql://`` URLs were used instead of ``mariadb://`` URLs + was ad-hoc; with this issue resolution, the full set of schema / compiler / + type features are now available regardless of how the URL was presented. diff --git a/lib/sqlalchemy/dialects/mysql/_mariadb_shim.py b/lib/sqlalchemy/dialects/mysql/_mariadb_shim.py new file mode 100644 index 0000000000..fdf210c1fc --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/_mariadb_shim.py @@ -0,0 +1,312 @@ +# dialects/mysql/_mariadb_shim.py +# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Optional +from typing import Type +from typing import TYPE_CHECKING + +from .reserved_words import RESERVED_WORDS_MARIADB +from ... import exc +from ... import schema as sa_schema +from ... import util +from ...engine import cursor as _cursor +from ...engine import default +from ...engine.default import DefaultDialect +from ...engine.interfaces import TypeCompiler +from ...sql import elements +from ...sql import sqltypes +from ...sql.compiler import DDLCompiler +from ...sql.compiler import IdentifierPreparer +from ...sql.compiler import SQLCompiler +from ...sql.schema import SchemaConst +from ...sql.sqltypes import _UUID_RETURN +from ...sql.sqltypes import UUID +from ...sql.sqltypes import Uuid + +if TYPE_CHECKING: + from .base import MySQLIdentifierPreparer + from .mariadb import INET4 + from .mariadb import INET6 + from ...engine import URL + from ...engine.base import Connection + from ...sql import ddl + from ...sql.schema import IdentityOptions + from ...sql.schema import Sequence as Sequence_SchemaItem + from ...sql.type_api import _BindProcessorType + + +class _MariaDBUUID(UUID[_UUID_RETURN]): + def __init__(self, as_uuid: bool = True, native_uuid: bool = True): + self.as_uuid = as_uuid + + # the _MariaDBUUID internal type is only invoked for a Uuid() with + # native_uuid=True. for non-native uuid type, the plain Uuid + # returns itself due to the workings of the Emulated superclass. + assert native_uuid + + # for internal type, force string conversion for result_processor() as + # current drivers are returning a string, not a Python UUID object + self.native_uuid = False + + @property + def native(self) -> bool: # type: ignore[override] + # override to return True, this is a native type, just turning + # off native_uuid for internal data handling + return True + + def bind_processor(self, dialect: MariaDBShim) -> Optional[_BindProcessorType[_UUID_RETURN]]: # type: ignore[override] # noqa: E501 + if not dialect.supports_native_uuid or not dialect._allows_uuid_binds: + return super().bind_processor(dialect) # type: ignore[return-value] # noqa: E501 + else: + return None + + +class MariaDBTypeCompilerShim(TypeCompiler): + def visit_INET4(self, type_: INET4, **kwargs: Any) -> str: + return "INET4" + + def visit_INET6(self, type_: INET6, **kwargs: Any) -> str: + return "INET6" + + +class MariadbExecutionContextShim(default.DefaultExecutionContext): + def post_exec(self) -> None: + if ( + self.isdelete + and cast(SQLCompiler, self.compiled).effective_returning + and not self.cursor.description + ): + # All MySQL/mariadb drivers appear to not include + # cursor.description for DELETE..RETURNING with no rows if the + # WHERE criteria is a straight "false" condition such as our EMPTY + # IN condition. manufacture an empty result in this case (issue + # #10505) + # + # taken from cx_Oracle implementation + self.cursor_fetch_strategy = ( + _cursor.FullyBufferedCursorFetchStrategy( + self.cursor, + [ + (entry.keyname, None) # type: ignore[misc] + for entry in cast( + SQLCompiler, self.compiled + )._result_columns + ], + [], + ) + ) + + def fire_sequence( + self, seq: Sequence_SchemaItem, type_: sqltypes.Integer + ) -> int: + return self._execute_scalar( # type: ignore[no-any-return] + ( + "select nextval(%s)" + % self.identifier_preparer.format_sequence(seq) + ), + type_, + ) + + +class MariaDBIdentifierPreparerShim(IdentifierPreparer): + def _set_mariadb(self) -> None: + self.reserved_words = RESERVED_WORDS_MARIADB + + +class MariaDBSQLCompilerShim(SQLCompiler): + def visit_sequence(self, sequence: sa_schema.Sequence, **kw: Any) -> str: + return "nextval(%s)" % self.preparer.format_sequence(sequence) + + def _mariadb_regexp_flags( + self, flags: str, pattern: elements.ColumnElement[Any], **kw: Any + ) -> str: + return "CONCAT('(?', %s, ')', %s)" % ( + self.render_literal_value(flags, sqltypes.STRINGTYPE), + self.process(pattern, **kw), + ) + + def _mariadb_regexp_match( + self, + op_string: str, + binary: elements.BinaryExpression[Any], + operator: Any, + **kw: Any, + ) -> str: + flags = binary.modifiers["flags"] + return "%s%s%s" % ( + self.process(binary.left, **kw), + op_string, + self._mariadb_regexp_flags(flags, binary.right), + ) + + def _mariadb_regexp_replace_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: + flags = binary.modifiers["flags"] + return "REGEXP_REPLACE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self._mariadb_regexp_flags(flags, binary.right.clauses[0]), + self.process(binary.right.clauses[1], **kw), + ) + + def _mariadb_visit_drop_check_constraint( + self, drop: ddl.DropConstraint, **kw: Any + ) -> str: + constraint = drop.element + qual = "CONSTRAINT " + const = self.preparer.format_constraint(constraint) + return "ALTER TABLE %s DROP %s%s" % ( + self.preparer.format_table(constraint.table), + qual, + const, + ) + + +class MariaDBDDLCompilerShim(DDLCompiler): + dialect: MariaDBShim + + def _mariadb_get_column_specification( + self, column: sa_schema.Column[Any], **kw: Any + ) -> str: + + if ( + column.computed is not None + and column._user_defined_nullable is SchemaConst.NULL_UNSPECIFIED + ): + kw["_force_column_to_nullable"] = True + + return self._mysql_get_column_specification(column, **kw) + + def _mysql_get_column_specification( + self, + column: sa_schema.Column[Any], + *, + _force_column_to_nullable: bool = False, + **kw: Any, + ) -> str: + raise NotImplementedError() + + def get_identity_options(self, identity_options: IdentityOptions) -> str: + text = super().get_identity_options(identity_options) + text = text.replace("NO CYCLE", "NOCYCLE") + return text + + def _mariadb_visit_drop_check_constraint( + self, drop: ddl.DropConstraint, **kw: Any + ) -> str: + constraint = drop.element + qual = "CONSTRAINT " + const = self.preparer.format_constraint(constraint) + return "ALTER TABLE %s DROP %s%s" % ( + self.preparer.format_table(constraint.table), + qual, + const, + ) + + +class MariaDBShim(DefaultDialect): + server_version_info: tuple[int, ...] + is_mariadb: bool + _allows_uuid_binds = False + + identifier_preparer: MySQLIdentifierPreparer + preparer: Type[MySQLIdentifierPreparer] + + def _set_mariadb( + self, is_mariadb: Optional[bool], server_version_info: tuple[int, ...] + ) -> None: + if is_mariadb is None: + return + + if not is_mariadb and self.is_mariadb: + raise exc.InvalidRequestError( + "MySQL version %s is not a MariaDB variant." + % (".".join(map(str, server_version_info)),) + ) + if is_mariadb: + assert isinstance(self.colspecs, dict) + self.colspecs = util.update_copy( + self.colspecs, {Uuid: _MariaDBUUID} + ) + + self.identifier_preparer = self.preparer(self) + self.identifier_preparer._set_mariadb() + + # this will be updated on first connect in initialize() + # if using older mariadb version + self.delete_returning = True + self.insert_returning = True + + self.is_mariadb = is_mariadb + + @property + def _mariadb_normalized_version_info(self) -> tuple[int, ...]: + return self.server_version_info + + @property + def _is_mariadb(self) -> bool: + return self.is_mariadb + + @classmethod + def _is_mariadb_from_url(cls, url: URL) -> bool: + dbapi = cls.import_dbapi() + dialect = cls(dbapi=dbapi) + + cargs, cparams = dialect.create_connect_args(url) + conn = dialect.connect(*cargs, **cparams) + try: + cursor = conn.cursor() + cursor.execute("SELECT VERSION() LIKE '%MariaDB%'") + val = cursor.fetchone()[0] # type: ignore[index] + except: + raise + else: + return bool(val) + finally: + conn.close() + + def _initialize_mariadb(self, connection: Connection) -> None: + assert self.is_mariadb + + self.supports_sequences = self.server_version_info >= (10, 3) + + self.delete_returning = self.server_version_info >= (10, 0, 5) + + self.insert_returning = self.server_version_info >= (10, 5) + + self._warn_for_known_db_issues() + + self.supports_native_uuid = ( + self.server_version_info is not None + and self.server_version_info >= (10, 7) + ) + self._allows_uuid_binds = True + + # ref https://mariadb.com/kb/en/mariadb-1021-release-notes/ + self._support_default_function = self.server_version_info >= (10, 2, 1) + + # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/ + self._support_float_cast = self.server_version_info >= (10, 4, 5) + + def _warn_for_known_db_issues(self) -> None: + if self.is_mariadb: + mdb_version = self.server_version_info + assert mdb_version is not None + if mdb_version > (10, 2) and mdb_version < (10, 2, 9): + util.warn( + "MariaDB %r before 10.2.9 has known issues regarding " + "CHECK constraints, which impact handling of NULL values " + "with SQLAlchemy's boolean datatype (MDEV-13596). An " + "additional issue prevents proper migrations of columns " + "with CHECK constraints (MDEV-11114). Please upgrade to " + "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 " + "series, to avoid these issues." % (mdb_version,) + ) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index be6343aec1..941aee0c49 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1075,15 +1075,16 @@ from typing import Optional from typing import overload from typing import Sequence from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union +from . import _mariadb_shim from . import reflection as _reflection from .enumerated import ENUM from .enumerated import SET from .json import JSON from .json import JSONIndexType from .json import JSONPathType -from .reserved_words import RESERVED_WORDS_MARIADB from .reserved_words import RESERVED_WORDS_MYSQL from .types import _FloatType from .types import _IntegerType @@ -1122,7 +1123,6 @@ from ... import literal_column from ... import schema as sa_schema from ... import sql from ... import util -from ...engine import cursor as _cursor from ...engine import default from ...engine import reflection from ...engine.reflection import ReflectionDefaults @@ -1136,8 +1136,6 @@ from ...sql import sqltypes from ...sql import util as sql_util from ...sql import visitors from ...sql.compiler import InsertmanyvaluesSentinelOpts -from ...sql.compiler import SQLCompiler -from ...sql.schema import SchemaConst from ...types import BINARY from ...types import BLOB from ...types import BOOLEAN @@ -1168,7 +1166,6 @@ if TYPE_CHECKING: from ...engine.interfaces import ReflectedUniqueConstraint from ...engine.result import _Ts from ...engine.row import Row - from ...engine.url import URL from ...schema import Table from ...sql import ddl from ...sql import selectable @@ -1180,14 +1177,12 @@ if TYPE_CHECKING: from ...sql.functions import random from ...sql.functions import rollup from ...sql.functions import sysdate - from ...sql.schema import IdentityOptions - from ...sql.schema import Sequence as Sequence_SchemaItem from ...sql.type_api import TypeEngine from ...sql.visitors import ExternallyTraversible from ...util.typing import TupleAny from ...util.typing import Unpack - +_T = TypeVar("_T", bound=Any) SET_RE = re.compile( r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE ) @@ -1281,33 +1276,9 @@ ischema_names = { } -class MySQLExecutionContext(default.DefaultExecutionContext): - def post_exec(self) -> None: - if ( - self.isdelete - and cast(SQLCompiler, self.compiled).effective_returning - and not self.cursor.description - ): - # All MySQL/mariadb drivers appear to not include - # cursor.description for DELETE..RETURNING with no rows if the - # WHERE criteria is a straight "false" condition such as our EMPTY - # IN condition. manufacture an empty result in this case (issue - # #10505) - # - # taken from cx_Oracle implementation - self.cursor_fetch_strategy = ( - _cursor.FullyBufferedCursorFetchStrategy( - self.cursor, - [ - (entry.keyname, None) # type: ignore[misc] - for entry in cast( - SQLCompiler, self.compiled - )._result_columns - ], - [], - ) - ) - +class MySQLExecutionContext( + _mariadb_shim.MariadbExecutionContextShim, default.DefaultExecutionContext +): def create_server_side_cursor(self) -> DBAPICursor: if self.dialect.supports_server_side_cursors: return self._dbapi_connection.cursor( @@ -1316,19 +1287,10 @@ class MySQLExecutionContext(default.DefaultExecutionContext): else: raise NotImplementedError() - def fire_sequence( - self, seq: Sequence_SchemaItem, type_: sqltypes.Integer - ) -> int: - return self._execute_scalar( # type: ignore[no-any-return] - ( - "select nextval(%s)" - % self.identifier_preparer.format_sequence(seq) - ), - type_, - ) - -class MySQLCompiler(compiler.SQLCompiler): +class MySQLCompiler( + _mariadb_shim.MariaDBSQLCompilerShim, compiler.SQLCompiler +): dialect: MySQLDialect render_table_with_column_in_update_from = True """Overridden from base SQLCompiler value""" @@ -1373,19 +1335,16 @@ class MySQLCompiler(compiler.SQLCompiler): return ( f"group_concat({expr._compiler_dispatch(self, **kw)} " f"ORDER BY {order_by._compiler_dispatch(self, **kw)} " - f"SEPARATOR " + "SEPARATOR " f"{delimiter._compiler_dispatch(self, **literal_exec)})" ) else: return ( f"group_concat({expr._compiler_dispatch(self, **kw)} " - f"SEPARATOR " + "SEPARATOR " f"{delimiter._compiler_dispatch(self, **literal_exec)})" ) - def visit_sequence(self, sequence: sa_schema.Sequence, **kw: Any) -> str: - return "nextval(%s)" % self.preparer.format_sequence(sequence) - def visit_sysdate_func(self, fn: sysdate, **kw: Any) -> str: return "SYSDATE()" @@ -1554,7 +1513,7 @@ class MySQLCompiler(compiler.SQLCompiler): "any column keys in table '%s': %s" % ( self.statement.table.name, # type: ignore[union-attr] - (", ".join("'%s'" % c for c in non_matching)), + ", ".join("'%s'" % c for c in non_matching), ) ) @@ -1569,8 +1528,8 @@ class MySQLCompiler(compiler.SQLCompiler): def visit_concat_op_expression_clauselist( self, clauselist: elements.ClauseList, operator: Any, **kw: Any ) -> str: - return "concat(%s)" % ( - ", ".join(self.process(elem, **kw) for elem in clauselist.clauses) + return "concat(%s)" % ", ".join( + self.process(elem, **kw) for elem in clauselist.clauses ) def visit_concat_op_binary( @@ -1942,8 +1901,7 @@ class MySQLCompiler(compiler.SQLCompiler): self, element_types: list[TypeEngine[Any]], **kw: Any ) -> str: return ( - "SELECT %(outer)s FROM (SELECT %(inner)s) " - "as _empty_set WHERE 1!=1" + "SELECT %(outer)s FROM (SELECT %(inner)s) as _empty_set WHERE 1!=1" % { "inner": ", ".join( "1 AS _in_%s" % idx @@ -1971,13 +1929,24 @@ class MySQLCompiler(compiler.SQLCompiler): self.process(binary.right), ) - def _mariadb_regexp_flags( - self, flags: str, pattern: elements.ColumnElement[Any], **kw: Any + def _mysql_regexp_match( + self, + op_string: str, + binary: elements.BinaryExpression[Any], + operator: Any, + **kw: Any, ) -> str: - return "CONCAT('(?', %s, ')', %s)" % ( + flags = binary.modifiers["flags"] + + text = "REGEXP_LIKE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), self.render_literal_value(flags, sqltypes.STRINGTYPE), - self.process(pattern, **kw), ) + if op_string == " NOT REGEXP ": + return "NOT %s" % text + else: + return text def _regexp_match( self, @@ -1990,22 +1959,15 @@ class MySQLCompiler(compiler.SQLCompiler): flags = binary.modifiers["flags"] if flags is None: return self._generate_generic_binary(binary, op_string, **kw) - elif self.dialect.is_mariadb: - return "%s%s%s" % ( - self.process(binary.left, **kw), - op_string, - self._mariadb_regexp_flags(flags, binary.right), - ) else: - text = "REGEXP_LIKE(%s, %s, %s)" % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw), - self.render_literal_value(flags, sqltypes.STRINGTYPE), + return self.dialect._dispatch_for_vendor( + self._mysql_regexp_match, + self._mariadb_regexp_match, + op_string, + binary, + operator, + **kw, ) - if op_string == " NOT REGEXP ": - return "NOT %s" % text - else: - return text def visit_regexp_match_op_binary( self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any @@ -2027,33 +1989,52 @@ class MySQLCompiler(compiler.SQLCompiler): self.process(binary.left, **kw), self.process(binary.right, **kw), ) - elif self.dialect.is_mariadb: - return "REGEXP_REPLACE(%s, %s, %s)" % ( - self.process(binary.left, **kw), - self._mariadb_regexp_flags(flags, binary.right.clauses[0]), - self.process(binary.right.clauses[1], **kw), - ) else: - return "REGEXP_REPLACE(%s, %s, %s)" % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw), - self.render_literal_value(flags, sqltypes.STRINGTYPE), + return self.dialect._dispatch_for_vendor( + self._mysql_regexp_replace_op_binary, + self._mariadb_regexp_replace_op_binary, + binary, + operator, + **kw, ) + def _mysql_regexp_replace_op_binary( + self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any + ) -> str: + flags = binary.modifiers["flags"] -class MySQLDDLCompiler(compiler.DDLCompiler): + return "REGEXP_REPLACE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + self.render_literal_value(flags, sqltypes.STRINGTYPE), + ) + + +class MySQLDDLCompiler( + _mariadb_shim.MariaDBDDLCompilerShim, compiler.DDLCompiler +): dialect: MySQLDialect def get_column_specification( self, column: sa_schema.Column[Any], **kw: Any ) -> str: """Builds column DDL.""" - if ( - self.dialect.is_mariadb is True - and column.computed is not None - and column._user_defined_nullable is SchemaConst.NULL_UNSPECIFIED - ): - column.nullable = True + + return self.dialect._dispatch_for_vendor( + self._mysql_get_column_specification, + self._mariadb_get_column_specification, + column, + **kw, + ) + + def _mysql_get_column_specification( + self, + column: sa_schema.Column[Any], + *, + _force_column_to_nullable: bool = False, + **kw: Any, + ) -> str: + colspec = [ self.preparer.format_column(column), self.dialect.type_compiler_instance.process( @@ -2069,7 +2050,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler): sqltypes.TIMESTAMP, ) - if not column.nullable: + if not column.nullable and not _force_column_to_nullable: colspec.append("NOT NULL") # see: https://docs.sqlalchemy.org/en/latest/dialects/mysql.html#mysql_timestamp_null # noqa @@ -2315,11 +2296,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler): qual = "INDEX " const = self.preparer.format_constraint(constraint) elif isinstance(constraint, sa_schema.CheckConstraint): - if self.dialect.is_mariadb: - qual = "CONSTRAINT " - else: - qual = "CHECK " - const = self.preparer.format_constraint(constraint) + return self.dialect._dispatch_for_vendor( + self._mysql_visit_drop_check_constraint, + self._mariadb_visit_drop_check_constraint, + drop, + **kw, + ) else: qual = "" const = self.preparer.format_constraint(constraint) @@ -2329,6 +2311,18 @@ class MySQLDDLCompiler(compiler.DDLCompiler): const, ) + def _mysql_visit_drop_check_constraint( + self, drop: ddl.DropConstraint, **kw: Any + ) -> str: + constraint = drop.element + qual = "CHECK " + const = self.preparer.format_constraint(constraint) + return "ALTER TABLE %s DROP %s%s" % ( + self.preparer.format_table(constraint.table), + qual, + const, + ) + def define_constraint_match( self, constraint: sa_schema.ForeignKeyConstraint ) -> str: @@ -2365,17 +2359,10 @@ class MySQLDDLCompiler(compiler.DDLCompiler): self.get_column_specification(create.element), ) - def get_identity_options(self, identity_options: IdentityOptions) -> str: - """mariadb-specific sequence option; this will move to a - mariadb-specific module in 2.1 - """ - text = super().get_identity_options(identity_options) - text = text.replace("NO CYCLE", "NOCYCLE") - return text - - -class MySQLTypeCompiler(compiler.GenericTypeCompiler): +class MySQLTypeCompiler( + _mariadb_shim.MariaDBTypeCompilerShim, compiler.GenericTypeCompiler +): def _extend_numeric(self, type_: _NumericCommonType, spec: str) -> str: "Extend a numeric-type declaration with MySQL specific extensions." @@ -2687,7 +2674,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return "BOOL" -class MySQLIdentifierPreparer(compiler.IdentifierPreparer): +class MySQLIdentifierPreparer( + _mariadb_shim.MariaDBIdentifierPreparerShim, compiler.IdentifierPreparer +): reserved_words = RESERVED_WORDS_MYSQL def __init__( @@ -2709,16 +2698,15 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer): return tuple([self.quote_identifier(i) for i in ids if i is not None]) -class MariaDBIdentifierPreparer(MySQLIdentifierPreparer): - reserved_words = RESERVED_WORDS_MARIADB - - -class MySQLDialect(default.DefaultDialect): +class MySQLDialect(_mariadb_shim.MariaDBShim, default.DefaultDialect): """Details of the MySQL dialect. Not used directly in application code. """ name = "mysql" + + is_mariadb = False + supports_statement_cache = True supports_alter = True @@ -2784,14 +2772,13 @@ class MySQLDialect(default.DefaultDialect): ischema_names = ischema_names preparer: type[MySQLIdentifierPreparer] = MySQLIdentifierPreparer - is_mariadb: bool = False - _mariadb_normalized_version_info = None - # default SQL compilation settings - # these are modified upon initialize(), # i.e. first connect _backslash_escapes = True _server_ansiquotes = False + _support_default_function = True + _support_float_cast = False server_version_info: tuple[int, ...] identifier_preparer: MySQLIdentifierPreparer @@ -2864,24 +2851,6 @@ class MySQLDialect(default.DefaultDialect): val = val.decode() return val.upper().replace("-", " ") # type: ignore[no-any-return] - @classmethod - def _is_mariadb_from_url(cls, url: URL) -> bool: - dbapi = cls.import_dbapi() - dialect = cls(dbapi=dbapi) - - cargs, cparams = dialect.create_connect_args(url) - conn = dialect.connect(*cargs, **cparams) - try: - cursor = conn.cursor() - cursor.execute("SELECT VERSION() LIKE '%MariaDB%'") - val = cursor.fetchone()[0] # type: ignore[index] - except: - raise - else: - return bool(val) - finally: - conn.close() - def _get_server_version_info( self, connection: Connection ) -> tuple[int, ...]: @@ -2905,6 +2874,8 @@ class MySQLDialect(default.DefaultDialect): r = re.compile(r"[.\-+]") tokens = r.split(val) + + _mariadb_normalized_version_info = None for token in tokens: parsed_token = re.match( r"^(?:(\d+)(?:a|b|c)?|(MariaDB\w*))$", token @@ -2912,21 +2883,21 @@ class MySQLDialect(default.DefaultDialect): if not parsed_token: continue elif parsed_token.group(2): - self._mariadb_normalized_version_info = tuple(version[-3:]) + _mariadb_normalized_version_info = tuple(version[-3:]) is_mariadb = True else: digit = int(parsed_token.group(1)) version.append(digit) - server_version_info = tuple(version) + if _mariadb_normalized_version_info: + server_version_info = _mariadb_normalized_version_info + else: + server_version_info = tuple(version) self._set_mariadb( bool(server_version_info and is_mariadb), server_version_info ) - if not is_mariadb: - self._mariadb_normalized_version_info = server_version_info - if server_version_info < (5, 0, 2): raise NotImplementedError( "the MySQL/MariaDB dialect supports server " @@ -2937,32 +2908,6 @@ class MySQLDialect(default.DefaultDialect): self.server_version_info = server_version_info return server_version_info - def _set_mariadb( - self, is_mariadb: Optional[bool], server_version_info: tuple[int, ...] - ) -> None: - if is_mariadb is None: - return - - if not is_mariadb and self.is_mariadb: - raise exc.InvalidRequestError( - "MySQL version %s is not a MariaDB variant." - % (".".join(map(str, server_version_info)),) - ) - if is_mariadb: - - if not issubclass(self.preparer, MariaDBIdentifierPreparer): - self.preparer = MariaDBIdentifierPreparer - # this would have been set by the default dialect already, - # so set it again - self.identifier_preparer = self.preparer(self) - - # this will be updated on first connect in initialize() - # if using older mariadb version - self.delete_returning = True - self.insert_returning = True - - self.is_mariadb = is_mariadb - def do_begin_twophase(self, connection: Connection, xid: Any) -> None: connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid)) @@ -3179,6 +3124,18 @@ class MySQLDialect(default.DefaultDialect): ) ] + def _dispatch_for_vendor( + self, + mysql_callable: Callable[..., _T], + mariadb_callable: Callable[..., _T], + *arg: Any, + **kw: Any, + ) -> _T: + if not self.is_mariadb: + return mysql_callable(*arg, **kw) + else: + return mariadb_callable(*arg, **kw) + def initialize(self, connection: Connection) -> None: # this is driver-based, does not need server version info # and is fairly critical for even basic SQL operations @@ -3202,92 +3159,33 @@ class MySQLDialect(default.DefaultDialect): self, server_ansiquotes=self._server_ansiquotes ) - self.supports_sequences = ( - self.is_mariadb and self.server_version_info >= (10, 3) + self._dispatch_for_vendor( + self._initialize_mysql, self._initialize_mariadb, connection ) - self.supports_for_update_of = ( - self._is_mysql and self.server_version_info >= (8,) - ) + def _initialize_mysql(self, connection: Connection) -> None: + assert not self.is_mariadb - self.use_mysql_for_share = ( - self._is_mysql and self.server_version_info >= (8, 0, 1) - ) + self.supports_for_update_of = self.server_version_info >= (8,) - self._needs_correct_for_88718_96365 = ( - not self.is_mariadb and self.server_version_info >= (8,) - ) - - self.delete_returning = ( - self.is_mariadb and self.server_version_info >= (10, 0, 5) - ) + self.use_mysql_for_share = self.server_version_info >= (8, 0, 1) - self.insert_returning = ( - self.is_mariadb and self.server_version_info >= (10, 5) - ) + self._needs_correct_for_88718_96365 = self.server_version_info >= (8,) self._requires_alias_for_on_duplicate_key = ( - self._is_mysql and self.server_version_info >= (8, 0, 20) + self.server_version_info >= (8, 0, 20) ) - self._warn_for_known_db_issues() + # ref https://dev.mysql.com/doc/refman/8.0/en/data-type-defaults.html # noqa + self._support_default_function = self.server_version_info >= (8, 0, 13) - def _warn_for_known_db_issues(self) -> None: - if self.is_mariadb: - mdb_version = self._mariadb_normalized_version_info - assert mdb_version is not None - if mdb_version > (10, 2) and mdb_version < (10, 2, 9): - util.warn( - "MariaDB %r before 10.2.9 has known issues regarding " - "CHECK constraints, which impact handling of NULL values " - "with SQLAlchemy's boolean datatype (MDEV-13596). An " - "additional issue prevents proper migrations of columns " - "with CHECK constraints (MDEV-11114). Please upgrade to " - "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 " - "series, to avoid these issues." % (mdb_version,) - ) - - @property - def _support_float_cast(self) -> bool: - if not self.server_version_info: - return False - elif self.is_mariadb: - # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/ - return self.server_version_info >= (10, 4, 5) - else: - # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa - return self.server_version_info >= (8, 0, 17) - - @property - def _support_default_function(self) -> bool: - if not self.server_version_info: - return False - elif self.is_mariadb: - # ref https://mariadb.com/kb/en/mariadb-1021-release-notes/ - return self.server_version_info >= (10, 2, 1) - else: - # ref https://dev.mysql.com/doc/refman/8.0/en/data-type-defaults.html # noqa - return self.server_version_info >= (8, 0, 13) - - @property - def _is_mariadb(self) -> bool: - return self.is_mariadb + # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa + self._support_float_cast = self.server_version_info >= (8, 0, 17) @property def _is_mysql(self) -> bool: return not self.is_mariadb - @property - def _is_mariadb_102(self) -> bool: - return ( - self.is_mariadb - and self._mariadb_normalized_version_info # type:ignore[operator] - > ( - 10, - 2, - ) - ) - @reflection.cache def get_schema_names(self, connection: Connection, **kw: Any) -> list[str]: rp = connection.exec_driver_sql("SHOW schemas") @@ -3427,7 +3325,7 @@ class MySQLDialect(default.DefaultDialect): } fkeys.append(fkey_d) - if self._needs_correct_for_88718_96365: + if self._is_mysql and self._needs_correct_for_88718_96365: self._correct_for_mysql_bugs_88718_96365(fkeys, connection) return fkeys if fkeys else ReflectionDefaults.foreign_keys() @@ -3904,8 +3802,8 @@ class MySQLDialect(default.DefaultDialect): elif code == 1356: raise exc.UnreflectableTableError( - "Table or view named %s could not be " - "reflected: %s" % (full_name, e) + "Table or view named %s could not be reflected: %s" + % (full_name, e) ) from e else: diff --git a/lib/sqlalchemy/dialects/mysql/mariadb.py b/lib/sqlalchemy/dialects/mysql/mariadb.py index faad28af10..cdb1e2ffee 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadb.py +++ b/lib/sqlalchemy/dialects/mysql/mariadb.py @@ -8,22 +8,9 @@ from __future__ import annotations from typing import Any -from typing import Optional -from typing import TYPE_CHECKING -from .base import MariaDBIdentifierPreparer from .base import MySQLDialect -from .base import MySQLIdentifierPreparer -from .base import MySQLTypeCompiler -from ... import util from ...sql import sqltypes -from ...sql.sqltypes import _UUID_RETURN -from ...sql.sqltypes import UUID -from ...sql.sqltypes import Uuid - -if TYPE_CHECKING: - from ...engine.base import Connection - from ...sql.type_api import _BindProcessorType class INET4(sqltypes.TypeEngine[str]): @@ -44,40 +31,6 @@ class INET6(sqltypes.TypeEngine[str]): __visit_name__ = "INET6" -class _MariaDBUUID(UUID[_UUID_RETURN]): - def __init__(self, as_uuid: bool = True, native_uuid: bool = True): - self.as_uuid = as_uuid - - # the _MariaDBUUID internal type is only invoked for a Uuid() with - # native_uuid=True. for non-native uuid type, the plain Uuid - # returns itself due to the workings of the Emulated superclass. - assert native_uuid - - # for internal type, force string conversion for result_processor() as - # current drivers are returning a string, not a Python UUID object - self.native_uuid = False - - @property - def native(self) -> bool: # type: ignore[override] - # override to return True, this is a native type, just turning - # off native_uuid for internal data handling - return True - - def bind_processor(self, dialect: MariaDBDialect) -> Optional[_BindProcessorType[_UUID_RETURN]]: # type: ignore[override] # noqa: E501 - if not dialect.supports_native_uuid or not dialect._allows_uuid_binds: - return super().bind_processor(dialect) # type: ignore[return-value] # noqa: E501 - else: - return None - - -class MariaDBTypeCompiler(MySQLTypeCompiler): - def visit_INET4(self, type_: INET4, **kwargs: Any) -> str: - return "INET4" - - def visit_INET6(self, type_: INET6, **kwargs: Any) -> str: - return "INET6" - - class MariaDBDialect(MySQLDialect): is_mariadb = True supports_statement_cache = True @@ -86,18 +39,10 @@ class MariaDBDialect(MySQLDialect): _allows_uuid_binds = True name = "mariadb" - preparer: type[MySQLIdentifierPreparer] = MariaDBIdentifierPreparer - type_compiler_cls = MariaDBTypeCompiler - colspecs = util.update_copy(MySQLDialect.colspecs, {Uuid: _MariaDBUUID}) - - def initialize(self, connection: Connection) -> None: - super().initialize(connection) - - self.supports_native_uuid = ( - self.server_version_info is not None - and self.server_version_info >= (10, 7) - ) + def __init__(self, **kw: Any) -> None: + kw["is_mariadb"] = True + super().__init__(**kw) def loader(driver: str) -> type[MariaDBDialect]: diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index 6159293c07..6ace9bb303 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -55,7 +55,6 @@ from typing import Sequence from typing import TYPE_CHECKING from typing import Union -from .base import MariaDBIdentifierPreparer from .base import MySQLCompiler from .base import MySQLDialect from .base import MySQLExecutionContext @@ -123,12 +122,6 @@ class MySQLIdentifierPreparer_mysqlconnector( pass -class MariaDBIdentifierPreparer_mysqlconnector( - IdentifierPreparerCommon_mysqlconnector, MariaDBIdentifierPreparer -): - pass - - class _myconnpyBIT(BIT): def result_processor(self, dialect: Any, coltype: Any) -> None: """MySQL-connector already converts mysql bits, so.""" @@ -296,7 +289,7 @@ class MariaDBDialect_mysqlconnector( ): supports_statement_cache = True _allows_uuid_binds = False - preparer = MariaDBIdentifierPreparer_mysqlconnector + preparer = MySQLIdentifierPreparer_mysqlconnector dialect = MySQLDialect_mysqlconnector diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index bbbcb34859..22253d6a61 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -72,8 +72,6 @@ from ..util.typing import TupleAny from ..util.typing import Unpack if typing.TYPE_CHECKING: - from types import ModuleType - from .base import Engine from .cursor import ResultFetchStrategy from .interfaces import _CoreMultiExecuteParams @@ -302,7 +300,7 @@ class DefaultDialect(Dialect): self, paramstyle: Optional[_ParamStyle] = None, isolation_level: Optional[IsolationLevel] = None, - dbapi: Optional[ModuleType] = None, + dbapi: Optional[DBAPIModule] = None, implicit_returning: Literal[True] = True, supports_native_boolean: Optional[bool] = None, max_identifier_length: Optional[int] = None, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 911026feaa..31c8ac8b8a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -7314,7 +7314,9 @@ class DDLCompiler(Compiled): " CASCADE" if drop.cascade else "", ) - def get_column_specification(self, column, **kwargs): + def get_column_specification( + self, column: Column[Any], **kwargs: Any + ) -> str: colspec = ( self.preparer.format_column(column) + " " @@ -7341,10 +7343,10 @@ class DDLCompiler(Compiled): colspec += " NOT NULL" return colspec - def create_table_suffix(self, table): + def create_table_suffix(self, table: Table) -> str: return "" - def post_create_table(self, table): + def post_create_table(self, table: Table) -> str: return "" def get_column_default_string(self, column: Column[Any]) -> Optional[str]: diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 265d08c111..3cfa35c1dd 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -433,6 +433,11 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL): ): dialect = mysql.dialect(is_mariadb=is_mariadb) dialect.server_version_info = version + if is_mariadb: + with testing.expect_warnings(".*"): + dialect._initialize_mariadb(None) + else: + dialect._initialize_mysql(None) m = MetaData() tbl = Table( @@ -1143,8 +1148,10 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL, CacheKeyFixture): if maria_db: dialect.is_mariadb = maria_db dialect.server_version_info = (10, 4, 5) + dialect._initialize_mariadb(None) else: dialect.server_version_info = (8, 0, 17) + dialect._initialize_mysql(None) t = sql.table("t", sql.column("col")) self.assert_compile(cast(t.c.col, type_), expected, dialect=dialect) diff --git a/test/dialect/mysql/test_dialect.py b/test/dialect/mysql/test_dialect.py index 0dcc079a22..41bf091c5a 100644 --- a/test/dialect/mysql/test_dialect.py +++ b/test/dialect/mysql/test_dialect.py @@ -428,37 +428,33 @@ class ParseVersionTest(fixtures.TestBase): ) @testing.combinations( - ((10, 2, 7), "10.2.7-MariaDB", (10, 2, 7), True), + ((10, 2, 7), "10.2.7-MariaDB", True), ( (10, 2, 7), "5.6.15.10.2.7-MariaDB", - (5, 6, 15, 10, 2, 7), True, ), - ((5, 0, 51, 24), "5.0.51a.24+lenny5", (5, 0, 51, 24), False), - ((10, 2, 10), "10.2.10-MariaDB", (10, 2, 10), True), - ((5, 7, 20), "5.7.20", (5, 7, 20), False), - ((5, 6, 15), "5.6.15", (5, 6, 15), False), + ((5, 0, 51, 24), "5.0.51a.24+lenny5", False), + ((10, 2, 10), "10.2.10-MariaDB", True), + ((5, 7, 20), "5.7.20", False), + ((5, 6, 15), "5.6.15", False), ( (10, 2, 6), "10.2.6.MariaDB.10.2.6+maria~stretch-log", - (10, 2, 6, 10, 2, 6), True, ), ( (10, 1, 9), "10.1.9-MariaDBV1.0R050D002-20170809-1522", - (10, 1, 9, 20170809, 1522), True, ), ) def test_mariadb_normalized_version( - self, expected, raw_version, version, is_mariadb + self, expected, raw_version, is_mariadb ): dialect = mysql.dialect() - eq_(dialect._parse_server_version(raw_version), version) - dialect.server_version_info = version - eq_(dialect._mariadb_normalized_version_info, expected) + eq_(dialect._parse_server_version(raw_version), expected) + eq_(dialect.server_version_info, expected) assert dialect._is_mariadb is is_mariadb @testing.combinations( diff --git a/test/dialect/mysql/test_query.py b/test/dialect/mysql/test_query.py index 06626b5793..19c7ec6214 100644 --- a/test/dialect/mysql/test_query.py +++ b/test/dialect/mysql/test_query.py @@ -14,7 +14,6 @@ from sqlalchemy import func from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import literal_column -from sqlalchemy import MetaData from sqlalchemy import or_ from sqlalchemy import schema from sqlalchemy import select @@ -408,7 +407,10 @@ class ComputedTest(fixtures.TestBase): ("unset"), argnames="nullable", ) - def test_column_computed_for_nullable(self, connection, nullable): + @testing.variation("setnull", ["constructor", "assign"]) + def test_column_computed_for_nullable( + self, metadata, connection, nullable, setnull: testing.Variation + ): """test #10056 we want to make sure that nullable is always set to True for computed @@ -416,29 +418,42 @@ class ComputedTest(fixtures.TestBase): ref: https://mariadb.com/kb/en/generated-columns/#statement-support """ - m = MetaData() - kwargs = {"nullable": nullable} if nullable != "unset" else {} + + if setnull.assign: + kwargs = {} + elif setnull.constructor: + kwargs = {"nullable": nullable} if nullable != "unset" else {} + else: + setnull.fail() + t = Table( "t", - m, + metadata, Column("x", Integer), Column("y", Integer, Computed("x + 2"), **kwargs), ) - if connection.engine.dialect.name == "mariadb" and nullable in ( - False, - None, + + # the "assign" path here is to exercise the actual bug shown at + # #10056; user defined nullable is not passed, but MappedColumn + # sets it after the fact. Previously this path was not exercised + if setnull.assign and nullable != "unset": + t.c.y.nullable = nullable + + if ( + connection.engine.dialect.name == "mariadb" + and setnull.constructor + and nullable in (False, None) ): + # if nullable was passed explicitly as False or None, + # then we dont touch .nullable assert_raises( exc.ProgrammingError, connection.execute, schema.CreateTable(t), ) - # If assertion happens table won't be created so - # return from test - return - # Create and then drop table - connection.execute(schema.CreateTable(t)) - connection.execute(schema.DropTable(t)) + else: + # assert table can be created + connection.execute(schema.CreateTable(t)) class LimitORMTest(fixtures.MappedTest): diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py index f3210ad615..e42ad32ff9 100644 --- a/test/dialect/mysql/test_reflection.py +++ b/test/dialect/mysql/test_reflection.py @@ -872,7 +872,11 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): else: default = er["default"] - if default is not None and connection.dialect._is_mariadb_102: + if ( + default is not None + and connection.dialect.is_mariadb + and connection.dialect.server_version_info > (10, 2) + ): default = default.replace( "CURRENT_TIMESTAMP", "current_timestamp()" ) diff --git a/test/dialect/mysql/test_types.py b/test/dialect/mysql/test_types.py index 284370c4ac..9ce8e3b33c 100644 --- a/test/dialect/mysql/test_types.py +++ b/test/dialect/mysql/test_types.py @@ -475,26 +475,6 @@ class TypeCompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(type_, sql_text) -class MariaDBUUIDTest(fixtures.TestBase, AssertsCompiledSQL): - __only_on__ = "mysql", "mariadb" - __backend__ = True - - def test_requirements(self): - if testing.against("mariadb>=10.7"): - assert testing.requires.uuid_data_type.enabled - else: - assert not testing.requires.uuid_data_type.enabled - - def test_compile_generic(self): - if testing.against("mariadb>=10.7"): - self.assert_compile(sqltypes.Uuid(), "UUID") - else: - self.assert_compile(sqltypes.Uuid(), "CHAR(32)") - - def test_compile_upper(self): - self.assert_compile(sqltypes.UUID(), "UUID") - - class UUIDTest(fixtures.TestBase, AssertsCompiledSQL): @testing.combinations( (sqltypes.Uuid(), (10, 6, 5), "CHAR(32)"), @@ -503,11 +483,18 @@ class UUIDTest(fixtures.TestBase, AssertsCompiledSQL): (sqltypes.Uuid(native_uuid=False), (10, 7, 0), "CHAR(32)"), (sqltypes.UUID(), (10, 6, 5), "UUID"), (sqltypes.UUID(), (10, 7, 0), "UUID"), + argnames="type_, version, res", ) - def test_mariadb_uuid_combinations(self, type_, version, res): - dialect = mariadb.MariaDBDialect() + @testing.variation("use_mariadb", [True, False]) + def test_mariadb_uuid_combinations(self, type_, version, res, use_mariadb): + if use_mariadb: + dialect = mariadb.MariaDBDialect() + else: + dialect = mysql.MySQLDialect(is_mariadb=True) + dialect._set_mariadb(True, "10.2.0") + dialect.server_version_info = version - dialect.supports_native_uuid = version >= (10, 7) + dialect._initialize_mariadb(None) self.assert_compile(type_, res, dialect=dialect) @testing.combinations( @@ -519,15 +506,67 @@ class UUIDTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(type_, "CHAR(32)", dialect=dialect) -class INETMariadbTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = mariadb.MariaDBDialect() +class INETTest(fixtures.TestBase, AssertsCompiledSQL): + __only_on__ = ("mariadb",) + __backend__ = True @testing.combinations( (mariadb.INET4(), "INET4"), (mariadb.INET6(), "INET6"), + argnames="type_, res", ) - def test_mariadb_inet6(self, type_, res): - self.assert_compile(type_, res) + @testing.variation("use_mariadb", [True, False]) + def test_mariadb_inet6(self, type_, res, use_mariadb): + self.assert_compile( + type_, res, dialect="mariadb" if use_mariadb else "mysql" + ) + + @testing.fixture + def inet_table(self, metadata): + inet_table = Table( + "inet_table", + metadata, + Column("id", Integer, primary_key=True), + Column("ip4", mariadb.INET4()), + Column("ip6", mariadb.INET6()), + ) + return inet_table + + @testing.fixture + def mysql_url_mariadb(self, testing_engine): + url = testing.db.url + + url = url.set(drivername=f"mysql+{url.get_driver_name()}") + + return testing_engine(url) + + def test_inet(self, inet_table): + self._roundtrip(inet_table, testing.db) + + def test_inet_w_mysql_url(self, mysql_url_mariadb, inet_table): + db = mysql_url_mariadb + + self._roundtrip(inet_table, db) + + def _roundtrip(self, inet_table, db): + + with db.begin() as conn: + inet_table.create(conn) + + conn.execute( + inet_table.insert(), + { + "ip4": "192.168.1.1", + "ip6": "2001:db8:85a3:0:0:8a2e:370:7334", + }, + ) + + eq_( + conn.execute( + select(inet_table.c.ip4, inet_table.c.ip6) + ).first(), + ("192.168.1.1", "2001:db8:85a3::8a2e:370:7334"), + ) class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults): diff --git a/test/requirements.py b/test/requirements.py index cdc5e4f869..164c407d5e 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -796,8 +796,7 @@ class DefaultRequirements(SuiteRequirements): and ( ( config.db.dialect._is_mariadb - and config.db.dialect._mariadb_normalized_version_info - >= (10, 2) + and config.db.dialect.server_version_info >= (10, 2) ) ), "mariadb>10.2", @@ -1238,10 +1237,7 @@ class DefaultRequirements(SuiteRequirements): not config.db.dialect._is_mariadb and against(config, "mysql >= 5.7") ) - or ( - config.db.dialect._mariadb_normalized_version_info - >= (10, 2, 7) - ) + or (config.db.dialect.server_version_info >= (10, 2, 7)) ), "mariadb>=10.2.7", "postgresql >= 9.3", @@ -1915,14 +1911,14 @@ class DefaultRequirements(SuiteRequirements): return ( against(config, ["mysql", "mariadb"]) and config.db.dialect._is_mariadb - and config.db.dialect._mariadb_normalized_version_info >= (10, 2) + and config.db.dialect.server_version_info >= (10, 2) ) def _mariadb_105(self, config): return ( against(config, ["mysql", "mariadb"]) and config.db.dialect._is_mariadb - and config.db.dialect._mariadb_normalized_version_info >= (10, 5) + and config.db.dialect.server_version_info >= (10, 5) ) def _mysql_and_check_constraints_exist(self, config): @@ -1930,9 +1926,7 @@ class DefaultRequirements(SuiteRequirements): # 2. it enforces check constraints if exclusions.against(config, ["mysql", "mariadb"]): if config.db.dialect._is_mariadb: - norm_version_info = ( - config.db.dialect._mariadb_normalized_version_info - ) + norm_version_info = config.db.dialect.server_version_info return norm_version_info >= (10, 2) else: norm_version_info = config.db.dialect.server_version_info