From: Mike Bayer Date: Wed, 12 Mar 2025 20:25:48 +0000 (-0400) Subject: expand paren rules for default rendering, sqlite/mysql X-Git-Tag: rel_2_0_40~13^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a936360ef01ab78b83d0c16ebbd61b1c55801ac2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git expand paren rules for default rendering, sqlite/mysql Expanded the rules for when to apply parenthesis to a server default in DDL to suit the general case of a default string that contains non-word characters such as spaces or operators and is not a string literal. Fixed issue in MySQL server default reflection where a default that has spaces would not be correctly reflected. Additionally, expanded the rules for when to apply parenthesis to a server default in DDL to suit the general case of a default string that contains non-word characters such as spaces or operators and is not a string literal. Fixes: #12425 Change-Id: Ie40703dcd5fdc135025d676c01baba57ff3b71ad (cherry picked from commit 1afb820427545e259397b98851a910d7379b2eb8) --- diff --git a/doc/build/changelog/unreleased_20/12425.rst b/doc/build/changelog/unreleased_20/12425.rst new file mode 100644 index 0000000000..fbc1f8a4ef --- /dev/null +++ b/doc/build/changelog/unreleased_20/12425.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: bug, sqlite + :tickets: 12425 + + Expanded the rules for when to apply parenthesis to a server default in DDL + to suit the general case of a default string that contains non-word + characters such as spaces or operators and is not a string literal. + +.. change:: + :tags: bug, mysql + :tickets: 12425 + + Fixed issue in MySQL server default reflection where a default that has + spaces would not be correctly reflected. Additionally, expanded the rules + for when to apply parenthesis to a server default in DDL to suit the + general case of a default string that contains non-word characters such as + spaces or operators and is not a string literal. + diff --git a/doc/build/orm/extensions/asyncio.rst b/doc/build/orm/extensions/asyncio.rst index fbd965d15d..5b88105430 100644 --- a/doc/build/orm/extensions/asyncio.rst +++ b/doc/build/orm/extensions/asyncio.rst @@ -281,7 +281,7 @@ configuration: CREATE TABLE a ( id INTEGER NOT NULL, data VARCHAR NOT NULL, - create_date DATETIME DEFAULT (CURRENT_TIMESTAMP) NOT NULL, + create_date DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL, PRIMARY KEY (id) ) ... diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 8bae6193b5..122a7cb2e5 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1928,12 +1928,13 @@ class MySQLDDLCompiler(compiler.DDLCompiler): colspec.append("AUTO_INCREMENT") else: default = self.get_column_default_string(column) + if default is not None: if ( - isinstance( - column.server_default.arg, functions.FunctionElement - ) - and self.dialect._support_default_function + self.dialect._support_default_function + and not re.match(r"^\s*[\'\"\(]", default) + and "ON UPDATE" not in default + and re.match(r".*\W.*", default) ): colspec.append(f"DEFAULT ({default})") else: diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py index 3998be977d..d62390bb84 100644 --- a/lib/sqlalchemy/dialects/mysql/reflection.py +++ b/lib/sqlalchemy/dialects/mysql/reflection.py @@ -451,7 +451,7 @@ class MySQLTableDefinitionParser: r"(?: +COLLATE +(?P[\w_]+))?" r"(?: +(?P(?:NOT )?NULL))?" r"(?: +DEFAULT +(?P" - r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+" + r"(?:NULL|'(?:''|[^'])*'|\(.+?\)|[\-\w\.\(\)]+" r"(?: +ON UPDATE [\-\w\.\(\)]+)?)" r"))?" r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P\(" diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 96b2414cce..c09fbb32cc 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -932,7 +932,6 @@ from ...engine import processors from ...engine import reflection from ...engine.reflection import ReflectionDefaults from ...sql import coercions -from ...sql import ColumnElement from ...sql import compiler from ...sql import elements from ...sql import roles @@ -1594,9 +1593,13 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): colspec = self.preparer.format_column(column) + " " + coltype default = self.get_column_default_string(column) if default is not None: - if isinstance(column.server_default.arg, ColumnElement): - default = "(" + default + ")" - colspec += " DEFAULT " + default + + if not re.match(r"""^\s*[\'\"\(]""", default) and re.match( + r".*\W.*", default + ): + colspec += f" DEFAULT ({default})" + else: + colspec += f" DEFAULT {default}" if not column.nullable: colspec += " NOT NULL" diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 8364c15f8f..719692125f 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -274,8 +274,8 @@ def int_within_variance(expected, received, variance): ) -def eq_regex(a, b, msg=None): - assert re.match(b, a), msg or "%r !~ %r" % (a, b) +def eq_regex(a, b, msg=None, flags=0): + assert re.match(b, a, flags), msg or "%r !~ %r" % (a, b) def eq_(a, b, msg=None): diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 93541dca70..a2c3aa531d 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1168,6 +1168,19 @@ class SuiteRequirements(Requirements): """ return self.precision_numerics_many_significant_digits + @property + def server_defaults(self): + """Target backend supports server side defaults for columns""" + + return exclusions.closed() + + @property + def expression_server_defaults(self): + """Target backend supports server side defaults with SQL expressions + for columns""" + + return exclusions.closed() + @property def implicit_decimal_binds(self): """target backend will return a selected Decimal as a Decimal, not diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 2837e9fe0a..0f2a2062a8 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -14,6 +14,7 @@ import sqlalchemy as sa from .. import config from .. import engines from .. import eq_ +from .. import eq_regex from .. import expect_raises from .. import expect_raises_message from .. import expect_warnings @@ -23,6 +24,8 @@ from ..provision import get_temp_table_name from ..provision import temp_table_keyword_args from ..schema import Column from ..schema import Table +from ... import Boolean +from ... import DateTime from ... import event from ... import ForeignKey from ... import func @@ -2883,6 +2886,47 @@ class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase): eq_(opts, expected) # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) + @testing.combinations( + (Integer, sa.text("10"), r"'?10'?"), + (Integer, "10", r"'?10'?"), + (Boolean, sa.true(), r"1|true"), + ( + Integer, + sa.text("3 + 5"), + r"3\+5", + testing.requires.expression_server_defaults, + ), + ( + Integer, + sa.text("(3 * 5)"), + r"3\*5", + testing.requires.expression_server_defaults, + ), + (DateTime, func.now(), r"current_timestamp|now|getdate"), + ( + Integer, + sa.literal_column("3") + sa.literal_column("5"), + r"3\+5", + testing.requires.expression_server_defaults, + ), + argnames="datatype, default, expected_reg", + ) + @testing.requires.server_defaults + def test_server_defaults( + self, metadata, connection, datatype, default, expected_reg + ): + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("thecol", datatype, server_default=default), + ) + t.create(connection) + + reflected = inspect(connection).get_columns("t")[1]["default"] + reflected_sanitized = re.sub(r"[\(\) \']", "", reflected) + eq_regex(reflected_sanitized, expected_reg, flags=re.IGNORECASE) + class NormalizedNameTest(fixtures.TablesTest): __requires__ = ("denormalized_names",) diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 8387d4e07c..f9cfeba05b 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -446,7 +446,7 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( schema.CreateTable(tbl), "CREATE TABLE testtbl (" - "time DATETIME DEFAULT (CURRENT_TIMESTAMP), " + "time DATETIME DEFAULT CURRENT_TIMESTAMP, " "name VARCHAR(255) DEFAULT 'some str', " "description VARCHAR(255) DEFAULT (lower('hi')), " "data JSON DEFAULT (json_object()))", diff --git a/test/dialect/mysql/test_query.py b/test/dialect/mysql/test_query.py index 9cbc38378f..96650dab56 100644 --- a/test/dialect/mysql/test_query.py +++ b/test/dialect/mysql/test_query.py @@ -5,16 +5,21 @@ from sqlalchemy import Boolean from sqlalchemy import cast from sqlalchemy import Column from sqlalchemy import Computed +from sqlalchemy import DateTime from sqlalchemy import exc from sqlalchemy import false from sqlalchemy import ForeignKey +from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import or_ from sqlalchemy import schema from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import text from sqlalchemy import true from sqlalchemy.testing import assert_raises from sqlalchemy.testing import combinations @@ -50,6 +55,35 @@ class IdiosyncrasyTest(fixtures.TestBase): ) +class ServerDefaultCreateTest(fixtures.TestBase): + @testing.combinations( + (Integer, text("10")), + (Integer, text("'10'")), + (Integer, "10"), + (Boolean, true()), + (Integer, text("3+5"), testing.requires.mysql_expression_defaults), + (Integer, text("3 + 5"), testing.requires.mysql_expression_defaults), + (Integer, text("(3 * 5)"), testing.requires.mysql_expression_defaults), + (DateTime, func.now()), + ( + Integer, + literal_column("3") + literal_column("5"), + testing.requires.mysql_expression_defaults, + ), + argnames="datatype, default", + ) + def test_create_server_defaults( + self, connection, metadata, datatype, default + ): + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("thecol", datatype, server_default=default), + ) + t.create(connection) + + class MatchTest(fixtures.TablesTest): __only_on__ = "mysql", "mariadb" __backend__ = True diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 819bf8aa06..c2c63e9ef0 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -1032,39 +1032,60 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): ")", ) - def test_column_defaults_ddl(self): + @testing.combinations( + ( + Boolean(create_constraint=True), + sql.false(), + "BOOLEAN DEFAULT 0, CHECK (x IN (0, 1))", + ), + ( + String(), + func.sqlite_version(), + "VARCHAR DEFAULT (sqlite_version())", + ), + (Integer(), func.abs(-5) + 17, "INTEGER DEFAULT (abs(-5) + 17)"), + ( + # test #12425 + String(), + func.now(), + "VARCHAR DEFAULT CURRENT_TIMESTAMP", + ), + ( + # test #12425 + String(), + func.datetime(func.now(), "localtime"), + "VARCHAR DEFAULT (datetime(CURRENT_TIMESTAMP, 'localtime'))", + ), + ( + # test #12425 + String(), + text("datetime(CURRENT_TIMESTAMP, 'localtime')"), + "VARCHAR DEFAULT (datetime(CURRENT_TIMESTAMP, 'localtime'))", + ), + ( + # default with leading spaces that should not be + # parenthesized + String, + text(" 'some default'"), + "VARCHAR DEFAULT 'some default'", + ), + (String, text("'some default'"), "VARCHAR DEFAULT 'some default'"), + argnames="datatype,default,expected", + ) + def test_column_defaults_ddl(self, datatype, default, expected): t = Table( "t", MetaData(), Column( "x", - Boolean(create_constraint=True), - server_default=sql.false(), + datatype, + server_default=default, ), ) self.assert_compile( CreateTable(t), - "CREATE TABLE t (x BOOLEAN DEFAULT (0), CHECK (x IN (0, 1)))", - ) - - t = Table( - "t", - MetaData(), - Column("x", String(), server_default=func.sqlite_version()), - ) - self.assert_compile( - CreateTable(t), - "CREATE TABLE t (x VARCHAR DEFAULT (sqlite_version()))", - ) - - t = Table( - "t", - MetaData(), - Column("x", Integer(), server_default=func.abs(-5) + 17), - ) - self.assert_compile( - CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT (abs(-5) + 17))" + f"CREATE TABLE t (x {expected})", ) def test_create_partial_index(self): diff --git a/test/requirements.py b/test/requirements.py index 12c25ece1a..2311f6e35f 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1,7 +1,4 @@ -"""Requirements specific to SQLAlchemy's own unit tests. - - -""" +"""Requirements specific to SQLAlchemy's own unit tests.""" from sqlalchemy import exc from sqlalchemy.sql import sqltypes @@ -212,6 +209,19 @@ class DefaultRequirements(SuiteRequirements): ] ) + @property + def server_defaults(self): + """Target backend supports server side defaults for columns""" + + return exclusions.open() + + @property + def expression_server_defaults(self): + return skip_if( + lambda config: against(config, "mysql", "mariadb") + and not self._mysql_expression_defaults(config) + ) + @property def qmark_paramstyle(self): return only_on(["sqlite", "+pyodbc"]) @@ -1818,6 +1828,15 @@ class DefaultRequirements(SuiteRequirements): # 2. they dont enforce check constraints return not self._mysql_check_constraints_exist(config) + def _mysql_expression_defaults(self, config): + return (against(config, ["mysql", "mariadb"])) and ( + config.db.dialect._support_default_function + ) + + @property + def mysql_expression_defaults(self): + return only_if(self._mysql_expression_defaults) + def _mysql_not_mariadb_102(self, config): return (against(config, ["mysql", "mariadb"])) and ( not config.db.dialect._is_mariadb