]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add server-specific support for alias
authorCaspar Wylie <casparwylie@Caspars-MacBook-Pro.local>
Fri, 3 Feb 2023 14:14:21 +0000 (14:14 +0000)
committerCaspar Wylie <casparwylie@Caspars-MacBook-Pro.local>
Fri, 3 Feb 2023 14:17:30 +0000 (14:17 +0000)
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_compiler.py

index 8949d4abf255b99e00cdfcc3009cf642b8576cc4..a732581952a92ec8ec4c8c6c04602bf335ef71ac 100644 (file)
@@ -249,7 +249,7 @@ different isolation level settings.  See the discussion at
 
     :ref:`dbapi_autocommit`
 
-AUTO_INCREMENT Behavior
+AUTO_INCREMENT Behaviour
 -----------------------
 
 When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT`` on
@@ -1339,12 +1339,17 @@ class MySQLCompiler(compiler.SQLCompiler):
                         isinstance(obj, elements.ColumnClause)
                         and obj.table is on_duplicate.inserted_alias
                     ):
-                        if not alias_clause:
+                        if self.dialect.supports_mysql_on_duplicate_alias:
+                            column_literal_clause = (
+                                f"{ON_DUP_ALIAS_NAME}."
+                                + self.preparer.quote(obj.name)
+                            )
                             alias_clause = f"AS {ON_DUP_ALIAS_NAME} "
-                        return literal_column(
-                            f"{ON_DUP_ALIAS_NAME}."
-                            + self.preparer.quote(obj.name)
-                        )
+                        else:
+                            column_literal_clause = (
+                                "VALUES(" + self.preparer.quote(obj.name) + ")"
+                            )
+                        return literal_column(column_literal_clause)
                     else:
                         # element is not replaced
                         return None
@@ -2393,6 +2398,9 @@ class MySQLDialect(default.DefaultDialect):
     supports_for_update_of = False  # default for MySQL ...
     # ... may be updated to True for MySQL 8+ in initialize()
 
+    supports_mysql_on_duplicate_alias = False  # Only available ...
+    # ... in MySQL 8+
+
     # MySQL doesn't support "DEFAULT VALUES" but *does* support
     # "VALUES (DEFAULT)"
     supports_default_values = False
@@ -2785,6 +2793,10 @@ class MySQLDialect(default.DefaultDialect):
             self.is_mariadb and self.server_version_info >= (10, 5)
         )
 
+        self.supports_mysql_on_duplicate_alias = (
+            self._is_mysql and self.server_version_info >= (8, 0, 20)
+        )
+
         self._warn_for_known_db_issues()
 
     def _warn_for_known_db_issues(self):
index 1184c9e1b48e6f61fb2e16823075619d7283346f..0ec6457f0e642cc4322133935f9a12b9caadcfad 100644 (file)
@@ -508,7 +508,6 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
         table1 = table(
             "mytable", column("myid"), column("name"), column("description")
         )
-
         self.assert_compile(
             table1.select()
             .where(table1.c.myid == 7)
@@ -1108,12 +1107,22 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
             bar=stmt.inserted.bar, baz=stmt.inserted.baz
         )
         expected_sql = (
+            "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) "
+            "ON DUPLICATE KEY UPDATE bar = VALUES(bar), baz = VALUES(baz)"
+        )
+        dialect = mysql.dialect()
+        self.assert_compile(stmt, expected_sql, dialect=dialect)
+
+        expected_alias_supported_sql = (
             "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) "
             f"AS {mysql.ON_DUP_ALIAS_NAME} ON DUPLICATE KEY UPDATE "
             f"bar = {mysql.ON_DUP_ALIAS_NAME}.bar, "
             f"baz = {mysql.ON_DUP_ALIAS_NAME}.baz"
         )
-        self.assert_compile(stmt, expected_sql)
+        dialect.supports_mysql_on_duplicate_alias = True
+        self.assert_compile(
+            stmt, expected_alias_supported_sql, dialect=dialect
+        )
 
     def test_from_literal(self):
         stmt = insert(self.table).values(
@@ -1146,15 +1155,35 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
             baz=stmt.inserted.baz + "some literal" + stmt.inserted.bar,
         )
         expected_sql = (
+            "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) ON "
+            "DUPLICATE KEY UPDATE bar = coalesce(VALUES(bar)), "
+            "baz = (concat(VALUES(baz), %s, VALUES(bar)))"
+        )
+        dialect = mysql.dialect()
+        self.assert_compile(
+            stmt,
+            expected_sql,
+            checkparams={
+                "id_m0": 1,
+                "bar_m0": "ab",
+                "id_m1": 2,
+                "bar_m1": "b",
+                "baz_1": "some literal",
+            },
+            dialect=dialect,
+        )
+
+        expected_alias_supported_sql = (
             "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) "
             f"AS {mysql.ON_DUP_ALIAS_NAME} ON DUPLICATE KEY UPDATE bar = "
             f"coalesce({mysql.ON_DUP_ALIAS_NAME}.bar), "
             f"baz = (concat({mysql.ON_DUP_ALIAS_NAME}.baz, %s, "
             f"{mysql.ON_DUP_ALIAS_NAME}.bar))"
         )
+        dialect.supports_mysql_on_duplicate_alias = True
         self.assert_compile(
             stmt,
-            expected_sql,
+            expected_alias_supported_sql,
             checkparams={
                 "id_m0": 1,
                 "bar_m0": "ab",
@@ -1162,6 +1191,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
                 "bar_m1": "b",
                 "baz_1": "some literal",
             },
+            dialect=dialect,
         )