]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed syntax error in mysql function defaults
authorhuuyafwww <huuya1234fwww@gmail.com>
Sat, 5 Oct 2024 06:04:13 +0000 (02:04 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Sat, 5 Oct 2024 07:42:00 +0000 (09:42 +0200)
Fixed a bug that caused a syntax error when a function was specified
to server_default when creating a column in MySQL or MariaDB.

Fixes #11317
Closes: #11953
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11953
Pull-request-sha: d93ac419a9201134e9c4845dd2e4dc48db4b6f78

Change-Id: I67fc83867df2b7dcf591c8f53b7a97afb90ebba9

doc/build/changelog/unreleased_20/11317.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_20/11317.rst b/doc/build/changelog/unreleased_20/11317.rst
new file mode 100644 (file)
index 0000000..e41a073
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, schema
+    :tickets: 11317
+
+    Fixed a bug that caused a syntax error when a function was specified
+    to server_default when creating a column in MySQL or MariaDB.
+    Pull request courtesy of huuya.
index aa99bf4d6849d24fdae11f5924810e02e748708f..f5eb169f8c46f09eeb62636f3462f4e8de4ed62a 100644 (file)
@@ -1850,7 +1850,15 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
         else:
             default = self.get_column_default_string(column)
             if default is not None:
-                colspec.append("DEFAULT " + default)
+                if (
+                    isinstance(
+                        column.server_default.arg, functions.FunctionElement
+                    )
+                    and self.dialect._support_default_function
+                ):
+                    colspec.append(f"DEFAULT ({default})")
+                else:
+                    colspec.append("DEFAULT " + default)
         return " ".join(colspec)
 
     def post_create_table(self, table):
@@ -2895,6 +2903,17 @@ class MySQLDialect(default.DefaultDialect):
             # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature  # noqa
             return self.server_version_info >= (8, 0, 17)
 
+    @property
+    def _support_default_function(self):
+        if not self.server_version_info:
+            return False
+        elif self.is_mariadb:
+            # ref https://mariadb.com/kb/en/mariadb-1021-release-notes/
+            return self.server_version_info >= (10, 2, 1)
+        else:
+            # ref https://dev.mysql.com/doc/refman/8.0/en/data-type-defaults.html # noqa
+            return self.server_version_info >= (8, 0, 13)
+
     @property
     def _is_mariadb(self):
         return self.is_mariadb
index 189390659add8df9c39b465ed5a9ae652895543e..f0dcb5838847c92c0884e3aa0899120239dd766b 100644 (file)
@@ -25,6 +25,7 @@ from sqlalchemy import Index
 from sqlalchemy import INT
 from sqlalchemy import Integer
 from sqlalchemy import Interval
+from sqlalchemy import JSON
 from sqlalchemy import LargeBinary
 from sqlalchemy import literal
 from sqlalchemy import MetaData
@@ -406,6 +407,56 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
             "PRIMARY KEY (data) USING btree)",
         )
 
+    @testing.combinations(
+        (True, True, (10, 2, 2)),
+        (True, True, (10, 2, 1)),
+        (False, True, (10, 2, 0)),
+        (True, False, (8, 0, 14)),
+        (True, False, (8, 0, 13)),
+        (False, False, (8, 0, 12)),
+        argnames="has_brackets,is_mariadb,version",
+    )
+    def test_create_server_default_with_function_using(
+        self, has_brackets, is_mariadb, version
+    ):
+        dialect = mysql.dialect(is_mariadb=is_mariadb)
+        dialect.server_version_info = version
+
+        m = MetaData()
+        tbl = Table(
+            "testtbl",
+            m,
+            Column("time", DateTime, server_default=func.current_timestamp()),
+            Column("name", String(255), server_default="some str"),
+            Column(
+                "description", String(255), server_default=func.lower("hi")
+            ),
+            Column("data", JSON, server_default=func.json_object()),
+        )
+
+        eq_(dialect._support_default_function, has_brackets)
+
+        if has_brackets:
+            self.assert_compile(
+                schema.CreateTable(tbl),
+                "CREATE TABLE testtbl ("
+                "time DATETIME DEFAULT (CURRENT_TIMESTAMP), "
+                "name VARCHAR(255) DEFAULT 'some str', "
+                "description VARCHAR(255) DEFAULT (lower('hi')), "
+                "data JSON DEFAULT (json_object()))",
+                dialect=dialect,
+            )
+        else:
+            self.assert_compile(
+                schema.CreateTable(tbl),
+                "CREATE TABLE testtbl ("
+                "time DATETIME DEFAULT CURRENT_TIMESTAMP, "
+                "name VARCHAR(255) DEFAULT 'some str', "
+                "description VARCHAR(255) DEFAULT lower('hi'), "
+                "data JSON DEFAULT json_object())",
+                dialect=dialect,
+            )
+
     def test_create_index_expr(self):
         m = MetaData()
         t1 = Table("foo", m, Column("x", Integer))