From: huuyafwww Date: Sat, 5 Oct 2024 06:04:13 +0000 (-0400) Subject: Fixed syntax error in mysql function defaults X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=40e990aab3f92051f3c693a81de938ab3b4eb5e4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fixed syntax error in mysql function defaults Fixed a bug that caused a syntax error when a function was specified to server_default when creating a column in MySQL or MariaDB. Fixes #11317 Closes: #11953 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11953 Pull-request-sha: d93ac419a9201134e9c4845dd2e4dc48db4b6f78 Change-Id: I67fc83867df2b7dcf591c8f53b7a97afb90ebba9 --- diff --git a/doc/build/changelog/unreleased_20/11317.rst b/doc/build/changelog/unreleased_20/11317.rst new file mode 100644 index 0000000000..e41a0733d2 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11317.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, schema + :tickets: 11317 + + Fixed a bug that caused a syntax error when a function was specified + to server_default when creating a column in MySQL or MariaDB. + Pull request courtesy of huuya. diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index aa99bf4d68..f5eb169f8c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1850,7 +1850,15 @@ class MySQLDDLCompiler(compiler.DDLCompiler): else: default = self.get_column_default_string(column) if default is not None: - colspec.append("DEFAULT " + default) + if ( + isinstance( + column.server_default.arg, functions.FunctionElement + ) + and self.dialect._support_default_function + ): + colspec.append(f"DEFAULT ({default})") + else: + colspec.append("DEFAULT " + default) return " ".join(colspec) def post_create_table(self, table): @@ -2895,6 +2903,17 @@ class MySQLDialect(default.DefaultDialect): # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa return self.server_version_info >= (8, 0, 17) + @property + def _support_default_function(self): + if not self.server_version_info: + return False + elif self.is_mariadb: + # ref https://mariadb.com/kb/en/mariadb-1021-release-notes/ + return self.server_version_info >= (10, 2, 1) + else: + # ref https://dev.mysql.com/doc/refman/8.0/en/data-type-defaults.html # noqa + return self.server_version_info >= (8, 0, 13) + @property def _is_mariadb(self): return self.is_mariadb diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 189390659a..f0dcb58388 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -25,6 +25,7 @@ from sqlalchemy import Index from sqlalchemy import INT from sqlalchemy import Integer from sqlalchemy import Interval +from sqlalchemy import JSON from sqlalchemy import LargeBinary from sqlalchemy import literal from sqlalchemy import MetaData @@ -406,6 +407,56 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL): "PRIMARY KEY (data) USING btree)", ) + @testing.combinations( + (True, True, (10, 2, 2)), + (True, True, (10, 2, 1)), + (False, True, (10, 2, 0)), + (True, False, (8, 0, 14)), + (True, False, (8, 0, 13)), + (False, False, (8, 0, 12)), + argnames="has_brackets,is_mariadb,version", + ) + def test_create_server_default_with_function_using( + self, has_brackets, is_mariadb, version + ): + dialect = mysql.dialect(is_mariadb=is_mariadb) + dialect.server_version_info = version + + m = MetaData() + tbl = Table( + "testtbl", + m, + Column("time", DateTime, server_default=func.current_timestamp()), + Column("name", String(255), server_default="some str"), + Column( + "description", String(255), server_default=func.lower("hi") + ), + Column("data", JSON, server_default=func.json_object()), + ) + + eq_(dialect._support_default_function, has_brackets) + + if has_brackets: + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE testtbl (" + "time DATETIME DEFAULT (CURRENT_TIMESTAMP), " + "name VARCHAR(255) DEFAULT 'some str', " + "description VARCHAR(255) DEFAULT (lower('hi')), " + "data JSON DEFAULT (json_object()))", + dialect=dialect, + ) + else: + self.assert_compile( + schema.CreateTable(tbl), + "CREATE TABLE testtbl (" + "time DATETIME DEFAULT CURRENT_TIMESTAMP, " + "name VARCHAR(255) DEFAULT 'some str', " + "description VARCHAR(255) DEFAULT lower('hi'), " + "data JSON DEFAULT json_object())", + dialect=dialect, + ) + def test_create_index_expr(self): m = MetaData() t1 = Table("foo", m, Column("x", Integer))