From: Mike Bayer Date: Wed, 12 Mar 2025 20:25:48 +0000 (-0400) Subject: expand paren rules for default rendering, sqlite/mysql X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1afb820427545e259397b98851a910d7379b2eb8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git expand paren rules for default rendering, sqlite/mysql Expanded the rules for when to apply parenthesis to a server default in DDL to suit the general case of a default string that contains non-word characters such as spaces or operators and is not a string literal. Fixed issue in MySQL server default reflection where a default that has spaces would not be correctly reflected. Additionally, expanded the rules for when to apply parenthesis to a server default in DDL to suit the general case of a default string that contains non-word characters such as spaces or operators and is not a string literal. Fixes: #12425 Change-Id: Ie40703dcd5fdc135025d676c01baba57ff3b71ad --- diff --git a/doc/build/changelog/unreleased_20/12425.rst b/doc/build/changelog/unreleased_20/12425.rst new file mode 100644 index 0000000000..fbc1f8a4ef --- /dev/null +++ b/doc/build/changelog/unreleased_20/12425.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: bug, sqlite + :tickets: 12425 + + Expanded the rules for when to apply parenthesis to a server default in DDL + to suit the general case of a default string that contains non-word + characters such as spaces or operators and is not a string literal. + +.. change:: + :tags: bug, mysql + :tickets: 12425 + + Fixed issue in MySQL server default reflection where a default that has + spaces would not be correctly reflected. Additionally, expanded the rules + for when to apply parenthesis to a server default in DDL to suit the + general case of a default string that contains non-word characters such as + spaces or operators and is not a string literal. + diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst index 784265f625..b06fb6315f 100644 --- a/doc/build/orm/extensions/asyncio.rst +++ b/doc/build/orm/extensions/asyncio.rst @@ -273,7 +273,7 @@ configuration: CREATE TABLE a ( id INTEGER NOT NULL, data VARCHAR NOT NULL, - create_date DATETIME DEFAULT (CURRENT_TIMESTAMP) NOT NULL, + create_date DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL, PRIMARY KEY (id) ) ... diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index fd60d7ba65..34aaedb849 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1946,12 +1946,13 @@ class MySQLDDLCompiler(compiler.DDLCompiler): colspec.append("AUTO_INCREMENT") else: default = self.get_column_default_string(column) + if default is not None: if ( - isinstance( - column.server_default.arg, functions.FunctionElement - ) - and self.dialect._support_default_function + self.dialect._support_default_function + and not re.match(r"^\s*[\'\"\(]", default) + and "ON UPDATE" not in default + and re.match(r".*\W.*", default) ): colspec.append(f"DEFAULT ({default})") else: diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py index 3998be977d..d62390bb84 100644 --- a/lib/sqlalchemy/dialects/mysql/reflection.py +++ b/lib/sqlalchemy/dialects/mysql/reflection.py @@ -451,7 +451,7 @@ class MySQLTableDefinitionParser: r"(?: +COLLATE +(?P[\w_]+))?" r"(?: +(?P(?:NOT )?NULL))?" r"(?: +DEFAULT +(?P" - r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+" + r"(?:NULL|'(?:''|[^'])*'|\(.+?\)|[\-\w\.\(\)]+" r"(?: +ON UPDATE [\-\w\.\(\)]+)?)" r"))?" r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P\(" diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 7b8e42a285..b509159111 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -932,7 +932,6 @@ from ...engine import processors from ...engine import reflection from ...engine.reflection import ReflectionDefaults from ...sql import coercions -from ...sql import ColumnElement from ...sql import compiler from ...sql import elements from ...sql import roles @@ -1589,9 +1588,13 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): colspec = self.preparer.format_column(column) + " " + coltype default = self.get_column_default_string(column) if default is not None: - if isinstance(column.server_default.arg, ColumnElement): - default = "(" + default + ")" - colspec += " DEFAULT " + default + + if not re.match(r"""^\s*[\'\"\(]""", default) and re.match( + r".*\W.*", default + ): + colspec += f" DEFAULT ({default})" + else: + colspec += f" DEFAULT {default}" if not column.nullable: colspec += " NOT NULL" diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index effe50d481..a22da65a62 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -280,8 +280,8 @@ def int_within_variance(expected, received, variance): ) -def eq_regex(a, b, msg=None): - assert re.match(b, a), msg or "%r !~ %r" % (a, b) +def eq_regex(a, b, msg=None, flags=0): + assert re.match(b, a, flags), msg or "%r !~ %r" % (a, b) def eq_(a, b, msg=None): diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index bddefc0d2a..7c4d2fb605 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1168,6 +1168,19 @@ class SuiteRequirements(Requirements): """ return self.precision_numerics_many_significant_digits + @property + def server_defaults(self): + """Target backend supports server side defaults for columns""" + + return exclusions.closed() + + @property + def expression_server_defaults(self): + """Target backend supports server side defaults with SQL expressions + for columns""" + + return exclusions.closed() + @property def implicit_decimal_binds(self): """target backend will return a selected Decimal as a Decimal, not diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index efc66b44a9..6be86cde10 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -14,6 +14,7 @@ import sqlalchemy as sa from .. import config from .. import engines from .. import eq_ +from .. import eq_regex from .. import expect_raises from .. import expect_raises_message from .. import expect_warnings @@ -23,6 +24,8 @@ from ..provision import get_temp_table_name from ..provision import temp_table_keyword_args from ..schema import Column from ..schema import Table +from ... import Boolean +from ... import DateTime from ... import event from ... import ForeignKey from ... import func @@ -2884,6 +2887,47 @@ class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): eq_(opts, expected) # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) + @testing.combinations( + (Integer, sa.text("10"), r"'?10'?"), + (Integer, "10", r"'?10'?"), + (Boolean, sa.true(), r"1|true"), + ( + Integer, + sa.text("3 + 5"), + r"3\+5", + testing.requires.expression_server_defaults, + ), + ( + Integer, + sa.text("(3 * 5)"), + r"3\*5", + testing.requires.expression_server_defaults, + ), + (DateTime, func.now(), r"current_timestamp|now|getdate"), + ( + Integer, + sa.literal_column("3") + sa.literal_column("5"), + r"3\+5", + testing.requires.expression_server_defaults, + ), + argnames="datatype, default, expected_reg", + ) + @testing.requires.server_defaults + def test_server_defaults( + self, metadata, connection, datatype, default, expected_reg + ): + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("thecol", datatype, server_default=default), + ) + t.create(connection) + + reflected = inspect(connection).get_columns("t")[1]["default"] + reflected_sanitized = re.sub(r"[\(\) \']", "", reflected) + eq_regex(reflected_sanitized, expected_reg, flags=re.IGNORECASE) + class NormalizedNameTest(fixtures.TablesTest): __requires__ = ("denormalized_names",) diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 553298c549..dc36973a9e 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -450,7 +450,7 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE testtbl (" - "time DATETIME DEFAULT (CURRENT_TIMESTAMP), " + "time DATETIME DEFAULT CURRENT_TIMESTAMP, " "name VARCHAR(255) DEFAULT 'some str', " "description VARCHAR(255) DEFAULT (lower('hi')), " "data JSON DEFAULT (json_object()))", diff --git a/test/dialect/mysql/test_query.py b/test/dialect/mysql/test_query.py index 973fe3dbc2..cd1e9327d3 100644 --- a/test/dialect/mysql/test_query.py +++ b/test/dialect/mysql/test_query.py @@ -5,17 +5,22 @@ from sqlalchemy import Boolean from sqlalchemy import cast from sqlalchemy import Column from sqlalchemy import Computed +from sqlalchemy import DateTime from sqlalchemy import delete from sqlalchemy import exc from sqlalchemy import false from sqlalchemy import ForeignKey +from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import or_ from sqlalchemy import schema from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import text from sqlalchemy import true from sqlalchemy import update from sqlalchemy.dialects.mysql import limit @@ -55,6 +60,35 @@ class IdiosyncrasyTest(fixtures.TestBase): ) +class ServerDefaultCreateTest(fixtures.TestBase): + @testing.combinations( + (Integer, text("10")), + (Integer, text("'10'")), + (Integer, "10"), + (Boolean, true()), + (Integer, text("3+5"), testing.requires.mysql_expression_defaults), + (Integer, text("3 + 5"), testing.requires.mysql_expression_defaults), + (Integer, text("(3 * 5)"), testing.requires.mysql_expression_defaults), + (DateTime, func.now()), + ( + Integer, + literal_column("3") + literal_column("5"), + testing.requires.mysql_expression_defaults, + ), + argnames="datatype, default", + ) + def test_create_server_defaults( + self, connection, metadata, datatype, default + ): + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("thecol", datatype, server_default=default), + ) + t.create(connection) + + class MatchTest(fixtures.TablesTest): __only_on__ = "mysql", "mariadb" __backend__ = True diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index c5b4f62e29..104cc86e2b 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -1033,39 +1033,60 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): ")", ) - def test_column_defaults_ddl(self): + @testing.combinations( + ( + Boolean(create_constraint=True), + sql.false(), + "BOOLEAN DEFAULT 0, CHECK (x IN (0, 1))", + ), + ( + String(), + func.sqlite_version(), + "VARCHAR DEFAULT (sqlite_version())", + ), + (Integer(), func.abs(-5) + 17, "INTEGER DEFAULT (abs(-5) + 17)"), + ( + # test #12425 + String(), + func.now(), + "VARCHAR DEFAULT CURRENT_TIMESTAMP", + ), + ( + # test #12425 + String(), + func.datetime(func.now(), "localtime"), + "VARCHAR DEFAULT (datetime(CURRENT_TIMESTAMP, 'localtime'))", + ), + ( + # test #12425 + String(), + text("datetime(CURRENT_TIMESTAMP, 'localtime')"), + "VARCHAR DEFAULT (datetime(CURRENT_TIMESTAMP, 'localtime'))", + ), + ( + # default with leading spaces that should not be + # parenthesized + String, + text(" 'some default'"), + "VARCHAR DEFAULT 'some default'", + ), + (String, text("'some default'"), "VARCHAR DEFAULT 'some default'"), + argnames="datatype,default,expected", + ) + def test_column_defaults_ddl(self, datatype, default, expected): t = Table( "t", MetaData(), Column( "x", - Boolean(create_constraint=True), - server_default=sql.false(), + datatype, + server_default=default, ), ) self.assert_compile( CreateTable(t), - "CREATE TABLE t (x BOOLEAN DEFAULT (0), CHECK (x IN (0, 1)))", - ) - - t = Table( - "t", - MetaData(), - Column("x", String(), server_default=func.sqlite_version()), - ) - self.assert_compile( - CreateTable(t), - "CREATE TABLE t (x VARCHAR DEFAULT (sqlite_version()))", - ) - - t = Table( - "t", - MetaData(), - Column("x", Integer(), server_default=func.abs(-5) + 17), - ) - self.assert_compile( - CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT (abs(-5) + 17))" + f"CREATE TABLE t (x {expected})", ) def test_create_partial_index(self): diff --git a/test/requirements.py b/test/requirements.py index 92fadf45da..1f4a4eb392 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1,7 +1,4 @@ -"""Requirements specific to SQLAlchemy's own unit tests. - - -""" +"""Requirements specific to SQLAlchemy's own unit tests.""" from sqlalchemy import exc from sqlalchemy.sql import sqltypes @@ -212,6 +209,19 @@ class DefaultRequirements(SuiteRequirements): ] ) + @property + def server_defaults(self): + """Target backend supports server side defaults for columns""" + + return exclusions.open() + + @property + def expression_server_defaults(self): + return skip_if( + lambda config: against(config, "mysql", "mariadb") + and not self._mysql_expression_defaults(config) + ) + @property def qmark_paramstyle(self): return only_on(["sqlite", "+pyodbc"]) @@ -1814,6 +1824,15 @@ class DefaultRequirements(SuiteRequirements): # 2. they dont enforce check constraints return not self._mysql_check_constraints_exist(config) + def _mysql_expression_defaults(self, config): + return (against(config, ["mysql", "mariadb"])) and ( + config.db.dialect._support_default_function + ) + + @property + def mysql_expression_defaults(self): + return only_if(self._mysql_expression_defaults) + def _mysql_not_mariadb_102(self, config): return (against(config, ["mysql", "mariadb"])) and ( not config.db.dialect._is_mariadb