]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement native uuid for mariadb >= 10.7
authorVolodymyr Kochetkov <whysages@gmail.com>
Fri, 26 Jan 2024 15:54:11 +0000 (10:54 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Feb 2024 17:44:38 +0000 (12:44 -0500)
Modified the MariaDB dialect so that when using the :class:`_sqltypes.Uuid`
datatype with  MariaDB >= 10.7, leaving the
:paramref:`_sqltypes.Uuid.native_uuid` parameter at its default of True,
the native ``UUID`` datatype will be rendered in DDL and used for database
communication, rather than ``CHAR(32)`` (the non-native UUID type) as was
the case previously.   This is a behavioral change since 2.0, where the
generic :class:`_sqltypes.Uuid` datatype delivered ``CHAR(32)`` for all
MySQL and MariaDB variants.   Support for all major DBAPIs is implemented
including support for less common "insertmanyvalues" scenarios where UUID
values are generated in different ways for primary keys.   Thanks much to
Volodymyr Kochetkov for delivering the PR.

To support this fully without hacks, the mariadb dialect now supports
driver-specific mariadb dialects as well, where we add one here for the
mysqlconnector DBAPI that doesn't accept Python UUID objects, whereas
all the other ones do.

Fixes: #10339
Closes: #10849
Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10849
Pull-request-sha: 8490b08713f6c19692b11c084ae38d19e60dd396

Change-Id: Ib920871102b9b64f2cba9697f5cb72b6263e4ed8

doc/build/changelog/unreleased_21/10339.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/mariadb.py
lib/sqlalchemy/dialects/mysql/mariadbconnector.py
lib/sqlalchemy/dialects/mysql/mysqlconnector.py
test/dialect/mysql/test_types.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_21/10339.rst b/doc/build/changelog/unreleased_21/10339.rst
new file mode 100644 (file)
index 0000000..91fe20d
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: usecase, mariadb
+    :tickets: 10339
+
+    Modified the MariaDB dialect so that when using the :class:`_sqltypes.Uuid`
+    datatype with  MariaDB >= 10.7, leaving the
+    :paramref:`_sqltypes.Uuid.native_uuid` parameter at its default of True,
+    the native ``UUID`` datatype will be rendered in DDL and used for database
+    communication, rather than ``CHAR(32)`` (the non-native UUID type) as was
+    the case previously.   This is a behavioral change since 2.0, where the
+    generic :class:`_sqltypes.Uuid` datatype delivered ``CHAR(32)`` for all
+    MySQL and MariaDB variants.   Support for all major DBAPIs is implemented
+    including support for less common "insertmanyvalues" scenarios where UUID
+    values are generated in different ways for primary keys.   Thanks much to
+    Volodymyr Kochetkov for delivering the PR.
+
index 10a05f9cb36cfdfda1d1a1e62cefd4a4b940666d..baf57c91200c4ae1b11f29fb372dd3e82dbca6b3 100644 (file)
 # mypy: ignore-errors
 from .base import MariaDBIdentifierPreparer
 from .base import MySQLDialect
+from ... import util
+from ...sql.sqltypes import UUID
+from ...sql.sqltypes import Uuid
+
+
+class _MariaDBUUID(UUID):
+    def __init__(self, as_uuid: bool = True, native_uuid: bool = True):
+        self.as_uuid = as_uuid
+
+        # the _MariaDBUUID internal type is only invoked for a Uuid() with
+        # native_uuid=True.   for non-native uuid type, the plain Uuid
+        # returns itself due to the workings of the Emulated superclass.
+        assert native_uuid
+
+        # for internal type, force string conversion for result_processor() as
+        # current drivers are returning a string, not a Python UUID object
+        self.native_uuid = False
+
+    @property
+    def native(self):
+        # override to return True, this is a native type, just turning
+        # off native_uuid for internal data handling
+        return True
+
+    def bind_processor(self, dialect):
+        if not dialect.supports_native_uuid or not dialect._allows_uuid_binds:
+            return super().bind_processor(dialect)
+        else:
+            return None
+
+    def _sentinel_value_resolver(self, dialect):
+        """Return a callable that will receive the uuid object or string
+        as it is normally passed to the DB in the parameter set, after
+        bind_processor() is called.  Convert this value to match
+        what it would be as coming back from MariaDB RETURNING.  this seems
+        to be *after* SQLAlchemy's datatype has converted, so these
+        will be UUID objects if as_uuid=True and dashed strings if
+        as_uuid=False
+
+        """
+
+        if not dialect._allows_uuid_binds:
+
+            def process(value):
+                return (
+                    f"{value[0:8]}-{value[8:12]}-"
+                    f"{value[12:16]}-{value[16:20]}-{value[20:]}"
+                )
+
+            return process
+        elif self.as_uuid:
+            return str
+        else:
+            return None
 
 
 class MariaDBDialect(MySQLDialect):
     is_mariadb = True
     supports_statement_cache = True
+    supports_native_uuid = True
+
+    _allows_uuid_binds = True
+
     name = "mariadb"
     preparer = MariaDBIdentifierPreparer
 
+    colspecs = util.update_copy(MySQLDialect.colspecs, {Uuid: _MariaDBUUID})
+
+    def initialize(self, connection):
+        super().initialize(connection)
+
+        self.supports_native_uuid = (
+            self.server_version_info is not None
+            and self.server_version_info >= (10, 7)
+        )
+
 
 def loader(driver):
-    driver_mod = __import__(
+    dialect_mod = __import__(
         "sqlalchemy.dialects.mysql.%s" % driver
     ).dialects.mysql
-    driver_cls = getattr(driver_mod, driver).dialect
-
-    return type(
-        "MariaDBDialect_%s" % driver,
-        (
-            MariaDBDialect,
-            driver_cls,
-        ),
-        {"supports_statement_cache": True},
-    )
+
+    driver_mod = getattr(dialect_mod, driver)
+    if hasattr(driver_mod, "mariadb_dialect"):
+        driver_cls = driver_mod.mariadb_dialect
+        return driver_cls
+    else:
+        driver_cls = driver_mod.dialect
+
+        return type(
+            "MariaDBDialect_%s" % driver,
+            (
+                MariaDBDialect,
+                driver_cls,
+            ),
+            {"supports_statement_cache": True},
+        )
index 2fe3a192aa9e8e6bc547c7dc7500daf1b7b6bb1b..86bc59d45a39392393ffa8289fa3f5c8318b54a0 100644 (file)
@@ -35,6 +35,7 @@ from uuid import UUID as _python_UUID
 from .base import MySQLCompiler
 from .base import MySQLDialect
 from .base import MySQLExecutionContext
+from .mariadb import MariaDBDialect
 from ... import sql
 from ... import util
 from ...sql import sqltypes
@@ -279,4 +280,12 @@ class MySQLDialect_mariadbconnector(MySQLDialect):
         )
 
 
+class MariaDBDialect_mariadbconnector(
+    MariaDBDialect, MySQLDialect_mariadbconnector
+):
+    supports_statement_cache = True
+    _allows_uuid_binds = False
+
+
 dialect = MySQLDialect_mariadbconnector
+mariadb_dialect = MariaDBDialect_mariadbconnector
index b1523392d8cc8af9b96f4099609bd3ef9dc32a34..8a6c2da8b4f304a45230ee3ccb3d74d73ca7c741 100644 (file)
@@ -29,6 +29,7 @@ from .base import BIT
 from .base import MySQLCompiler
 from .base import MySQLDialect
 from .base import MySQLIdentifierPreparer
+from .mariadb import MariaDBDialect
 from ... import util
 
 
@@ -176,4 +177,12 @@ class MySQLDialect_mysqlconnector(MySQLDialect):
             super()._set_isolation_level(connection, level)
 
 
+class MariaDBDialect_mysqlconnector(
+    MariaDBDialect, MySQLDialect_mysqlconnector
+):
+    supports_statement_cache = True
+    _allows_uuid_binds = False
+
+
 dialect = MySQLDialect_mysqlconnector
+mariadb_dialect = MariaDBDialect_mysqlconnector
index c73e82a945b9f1a26fb3160facf59f2f9a2c1c99..5c72d2ae8879c3407d897fdf31b2a59a4e9dd454 100644 (file)
@@ -21,6 +21,7 @@ from sqlalchemy import TypeDecorator
 from sqlalchemy import types as sqltypes
 from sqlalchemy import UnicodeText
 from sqlalchemy.dialects.mysql import base as mysql
+from sqlalchemy.dialects.mysql.mariadb import MariaDBDialect
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -474,6 +475,48 @@ class TypeCompileTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(type_, sql_text)
 
 
+class MariaDBUUIDTest(fixtures.TestBase, AssertsCompiledSQL):
+    __only_on__ = "mysql", "mariadb"
+    __backend__ = True
+
+    def test_requirements(self):
+        if testing.against("mariadb>=10.7"):
+            assert testing.requires.uuid_data_type.enabled
+        else:
+            assert not testing.requires.uuid_data_type.enabled
+
+    def test_compile_generic(self):
+        if testing.against("mariadb>=10.7"):
+            self.assert_compile(sqltypes.Uuid(), "UUID")
+        else:
+            self.assert_compile(sqltypes.Uuid(), "CHAR(32)")
+
+    def test_compile_upper(self):
+        self.assert_compile(sqltypes.UUID(), "UUID")
+
+    @testing.combinations(
+        (sqltypes.Uuid(), (10, 6, 5), "CHAR(32)"),
+        (sqltypes.Uuid(native_uuid=False), (10, 6, 5), "CHAR(32)"),
+        (sqltypes.Uuid(), (10, 7, 0), "UUID"),
+        (sqltypes.Uuid(native_uuid=False), (10, 7, 0), "CHAR(32)"),
+        (sqltypes.UUID(), (10, 6, 5), "UUID"),
+        (sqltypes.UUID(), (10, 7, 0), "UUID"),
+    )
+    def test_mariadb_uuid_combinations(self, type_, version, res):
+        dialect = MariaDBDialect()
+        dialect.server_version_info = version
+        dialect.supports_native_uuid = version >= (10, 7)
+        self.assert_compile(type_, res, dialect=dialect)
+
+    @testing.combinations(
+        (sqltypes.Uuid(),),
+        (sqltypes.Uuid(native_uuid=False),),
+    )
+    def test_mysql_uuid_combinations(self, type_):
+        dialect = mysql.MySQLDialect()
+        self.assert_compile(type_, "CHAR(32)", dialect=dialect)
+
+
 class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults):
     __dialect__ = mysql.dialect()
     __only_on__ = "mysql", "mariadb"
index 78a933358e238629c8602f73362de041c19f6023..a692cd3fee3c3a43e42a30102472b36a02f0b580 100644 (file)
@@ -1527,8 +1527,7 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def async_dialect(self):
-        """dialect makes use of await_() to invoke operations on
-        the DBAPI."""
+        """dialect makes use of await_() to invoke operations on the DBAPI."""
 
         return self.asyncio + only_on(
             LambdaPredicate(