]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
expand paren rules for default rendering, sqlite/mysql
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 Mar 2025 20:25:48 +0000 (16:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 17 Mar 2025 19:02:30 +0000 (15:02 -0400)
Expanded the rules for when to apply parenthesis to a server default in DDL
to suit the general case of a default string that contains non-word
characters such as spaces or operators and is not a string literal.

Fixed issue in MySQL server default reflection where a default that has
spaces would not be correctly reflected.  Additionally, expanded the rules
for when to apply parenthesis to a server default in DDL to suit the
general case of a default string that contains non-word characters such as
spaces or operators and is not a string literal.

Fixes: #12425
Change-Id: Ie40703dcd5fdc135025d676c01baba57ff3b71ad

12 files changed:
doc/build/changelog/unreleased_20/12425.rst [new file with mode: 0644]
doc/build/orm/extensions/asyncio.rst
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/reflection.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_reflection.py
test/dialect/mysql/test_compiler.py
test/dialect/mysql/test_query.py
test/dialect/test_sqlite.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_20/12425.rst b/doc/build/changelog/unreleased_20/12425.rst
new file mode 100644 (file)
index 0000000..fbc1f8a
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, sqlite
+    :tickets: 12425
+
+    Expanded the rules for when to apply parenthesis to a server default in DDL
+    to suit the general case of a default string that contains non-word
+    characters such as spaces or operators and is not a string literal.
+
+.. change::
+    :tags: bug, mysql
+    :tickets: 12425
+
+    Fixed issue in MySQL server default reflection where a default that has
+    spaces would not be correctly reflected.  Additionally, expanded the rules
+    for when to apply parenthesis to a server default in DDL to suit the
+    general case of a default string that contains non-word characters such as
+    spaces or operators and is not a string literal.
+
index 784265f625dd325c818c861bae58695e389979e0..b06fb6315f18eba5b0fd2780c98cd26bafb1617c 100644 (file)
@@ -273,7 +273,7 @@ configuration:
     CREATE TABLE a (
         id INTEGER NOT NULL,
         data VARCHAR NOT NULL,
-        create_date DATETIME DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
+        create_date DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
         PRIMARY KEY (id)
     )
     ...
index fd60d7ba65c7072cc5115496c2fa63e64e46e6af..34aaedb849cfd546ff7fc6848de53a888cfecb3a 100644 (file)
@@ -1946,12 +1946,13 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             colspec.append("AUTO_INCREMENT")
         else:
             default = self.get_column_default_string(column)
+
             if default is not None:
                 if (
-                    isinstance(
-                        column.server_default.arg, functions.FunctionElement
-                    )
-                    and self.dialect._support_default_function
+                    self.dialect._support_default_function
+                    and not re.match(r"^\s*[\'\"\(]", default)
+                    and "ON UPDATE" not in default
+                    and re.match(r".*\W.*", default)
                 ):
                     colspec.append(f"DEFAULT ({default})")
                 else:
index 3998be977d9ab5b89ca9ad30db5d1886ca244f73..d62390bb8457d63da4eb3e02fff96bc1e499001b 100644 (file)
@@ -451,7 +451,7 @@ class MySQLTableDefinitionParser:
             r"(?: +COLLATE +(?P<collate>[\w_]+))?"
             r"(?: +(?P<notnull>(?:NOT )?NULL))?"
             r"(?: +DEFAULT +(?P<default>"
-            r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+"
+            r"(?:NULL|'(?:''|[^'])*'|\(.+?\)|[\-\w\.\(\)]+"
             r"(?: +ON UPDATE [\-\w\.\(\)]+)?)"
             r"))?"
             r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\("
index 7b8e42a2854ba23484fb88a4b80bc18656b511a9..b50915911116c791b8425044d0ff33af12b813f3 100644 (file)
@@ -932,7 +932,6 @@ from ...engine import processors
 from ...engine import reflection
 from ...engine.reflection import ReflectionDefaults
 from ...sql import coercions
-from ...sql import ColumnElement
 from ...sql import compiler
 from ...sql import elements
 from ...sql import roles
@@ -1589,9 +1588,13 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
         colspec = self.preparer.format_column(column) + " " + coltype
         default = self.get_column_default_string(column)
         if default is not None:
-            if isinstance(column.server_default.arg, ColumnElement):
-                default = "(" + default + ")"
-            colspec += " DEFAULT " + default
+
+            if not re.match(r"""^\s*[\'\"\(]""", default) and re.match(
+                r".*\W.*", default
+            ):
+                colspec += f" DEFAULT ({default})"
+            else:
+                colspec += f" DEFAULT {default}"
 
         if not column.nullable:
             colspec += " NOT NULL"
index effe50d4810952fcd2cb4c00fd7485b15ade4e1b..a22da65a625bc2ce6cd24908ef39490ebcc2e039 100644 (file)
@@ -280,8 +280,8 @@ def int_within_variance(expected, received, variance):
     )
 
 
-def eq_regex(a, b, msg=None):
-    assert re.match(b, a), msg or "%r !~ %r" % (a, b)
+def eq_regex(a, b, msg=None, flags=0):
+    assert re.match(b, a, flags), msg or "%r !~ %r" % (a, b)
 
 
 def eq_(a, b, msg=None):
index bddefc0d2a3f70a141721d9415fcf5d0970d749b..7c4d2fb605b4184ef287d06ca254a2c2b9a597eb 100644 (file)
@@ -1168,6 +1168,19 @@ class SuiteRequirements(Requirements):
         """
         return self.precision_numerics_many_significant_digits
 
+    @property
+    def server_defaults(self):
+        """Target backend supports server side defaults for columns"""
+
+        return exclusions.closed()
+
+    @property
+    def expression_server_defaults(self):
+        """Target backend supports server side defaults with SQL expressions
+        for columns"""
+
+        return exclusions.closed()
+
     @property
     def implicit_decimal_binds(self):
         """target backend will return a selected Decimal as a Decimal, not
index efc66b44a97523c92bf56581ac9267c1eb006d2b..6be86cde106e1f597cd4abc3f301df94f8ca34d9 100644 (file)
@@ -14,6 +14,7 @@ import sqlalchemy as sa
 from .. import config
 from .. import engines
 from .. import eq_
+from .. import eq_regex
 from .. import expect_raises
 from .. import expect_raises_message
 from .. import expect_warnings
@@ -23,6 +24,8 @@ from ..provision import get_temp_table_name
 from ..provision import temp_table_keyword_args
 from ..schema import Column
 from ..schema import Table
+from ... import Boolean
+from ... import DateTime
 from ... import event
 from ... import ForeignKey
 from ... import func
@@ -2884,6 +2887,47 @@ class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase):
         eq_(opts, expected)
         # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected)
 
+    @testing.combinations(
+        (Integer, sa.text("10"), r"'?10'?"),
+        (Integer, "10", r"'?10'?"),
+        (Boolean, sa.true(), r"1|true"),
+        (
+            Integer,
+            sa.text("3 + 5"),
+            r"3\+5",
+            testing.requires.expression_server_defaults,
+        ),
+        (
+            Integer,
+            sa.text("(3 * 5)"),
+            r"3\*5",
+            testing.requires.expression_server_defaults,
+        ),
+        (DateTime, func.now(), r"current_timestamp|now|getdate"),
+        (
+            Integer,
+            sa.literal_column("3") + sa.literal_column("5"),
+            r"3\+5",
+            testing.requires.expression_server_defaults,
+        ),
+        argnames="datatype, default, expected_reg",
+    )
+    @testing.requires.server_defaults
+    def test_server_defaults(
+        self, metadata, connection, datatype, default, expected_reg
+    ):
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("thecol", datatype, server_default=default),
+        )
+        t.create(connection)
+
+        reflected = inspect(connection).get_columns("t")[1]["default"]
+        reflected_sanitized = re.sub(r"[\(\) \']", "", reflected)
+        eq_regex(reflected_sanitized, expected_reg, flags=re.IGNORECASE)
+
 
 class NormalizedNameTest(fixtures.TablesTest):
     __requires__ = ("denormalized_names",)
index 553298c549bc9575b2968df1bc42250eb21e658d..dc36973a9eaaf6b149f91724ed9336a5e98cf07f 100644 (file)
@@ -450,7 +450,7 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
             self.assert_compile(
                 schema.CreateTable(tbl),
                 "CREATE TABLE testtbl ("
-                "time DATETIME DEFAULT (CURRENT_TIMESTAMP), "
+                "time DATETIME DEFAULT CURRENT_TIMESTAMP, "
                 "name VARCHAR(255) DEFAULT 'some str', "
                 "description VARCHAR(255) DEFAULT (lower('hi')), "
                 "data JSON DEFAULT (json_object()))",
index 973fe3dbc29ec8a06380dceb978d0bd995f50630..cd1e9327d3f44b668fe0a67bd8ab1f3cf68ab6dc 100644 (file)
@@ -5,17 +5,22 @@ from sqlalchemy import Boolean
 from sqlalchemy import cast
 from sqlalchemy import Column
 from sqlalchemy import Computed
+from sqlalchemy import DateTime
 from sqlalchemy import delete
 from sqlalchemy import exc
 from sqlalchemy import false
 from sqlalchemy import ForeignKey
+from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import literal_column
 from sqlalchemy import MetaData
 from sqlalchemy import or_
 from sqlalchemy import schema
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import text
 from sqlalchemy import true
 from sqlalchemy import update
 from sqlalchemy.dialects.mysql import limit
@@ -55,6 +60,35 @@ class IdiosyncrasyTest(fixtures.TestBase):
             )
 
 
+class ServerDefaultCreateTest(fixtures.TestBase):
+    @testing.combinations(
+        (Integer, text("10")),
+        (Integer, text("'10'")),
+        (Integer, "10"),
+        (Boolean, true()),
+        (Integer, text("3+5"), testing.requires.mysql_expression_defaults),
+        (Integer, text("3 + 5"), testing.requires.mysql_expression_defaults),
+        (Integer, text("(3 * 5)"), testing.requires.mysql_expression_defaults),
+        (DateTime, func.now()),
+        (
+            Integer,
+            literal_column("3") + literal_column("5"),
+            testing.requires.mysql_expression_defaults,
+        ),
+        argnames="datatype, default",
+    )
+    def test_create_server_defaults(
+        self, connection, metadata, datatype, default
+    ):
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("thecol", datatype, server_default=default),
+        )
+        t.create(connection)
+
+
 class MatchTest(fixtures.TablesTest):
     __only_on__ = "mysql", "mariadb"
     __backend__ = True
index c5b4f62e2969e67d8c90904ba080eb457685e799..104cc86e2b34fd310e48ca63287e1bb889edf05a 100644 (file)
@@ -1033,39 +1033,60 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
             ")",
         )
 
-    def test_column_defaults_ddl(self):
+    @testing.combinations(
+        (
+            Boolean(create_constraint=True),
+            sql.false(),
+            "BOOLEAN DEFAULT 0, CHECK (x IN (0, 1))",
+        ),
+        (
+            String(),
+            func.sqlite_version(),
+            "VARCHAR DEFAULT (sqlite_version())",
+        ),
+        (Integer(), func.abs(-5) + 17, "INTEGER DEFAULT (abs(-5) + 17)"),
+        (
+            # test #12425
+            String(),
+            func.now(),
+            "VARCHAR DEFAULT CURRENT_TIMESTAMP",
+        ),
+        (
+            # test #12425
+            String(),
+            func.datetime(func.now(), "localtime"),
+            "VARCHAR DEFAULT (datetime(CURRENT_TIMESTAMP, 'localtime'))",
+        ),
+        (
+            # test #12425
+            String(),
+            text("datetime(CURRENT_TIMESTAMP, 'localtime')"),
+            "VARCHAR DEFAULT (datetime(CURRENT_TIMESTAMP, 'localtime'))",
+        ),
+        (
+            # default with leading spaces that should not be
+            # parenthesized
+            String,
+            text("  'some default'"),
+            "VARCHAR DEFAULT   'some default'",
+        ),
+        (String, text("'some default'"), "VARCHAR DEFAULT 'some default'"),
+        argnames="datatype,default,expected",
+    )
+    def test_column_defaults_ddl(self, datatype, default, expected):
         t = Table(
             "t",
             MetaData(),
             Column(
                 "x",
-                Boolean(create_constraint=True),
-                server_default=sql.false(),
+                datatype,
+                server_default=default,
             ),
         )
 
         self.assert_compile(
             CreateTable(t),
-            "CREATE TABLE t (x BOOLEAN DEFAULT (0), CHECK (x IN (0, 1)))",
-        )
-
-        t = Table(
-            "t",
-            MetaData(),
-            Column("x", String(), server_default=func.sqlite_version()),
-        )
-        self.assert_compile(
-            CreateTable(t),
-            "CREATE TABLE t (x VARCHAR DEFAULT (sqlite_version()))",
-        )
-
-        t = Table(
-            "t",
-            MetaData(),
-            Column("x", Integer(), server_default=func.abs(-5) + 17),
-        )
-        self.assert_compile(
-            CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT (abs(-5) + 17))"
+            f"CREATE TABLE t (x {expected})",
         )
 
     def test_create_partial_index(self):
index 92fadf45dac9d360e4185f42daf9d9bd3d796de2..1f4a4eb392344e9c278f01215df54bc1594cd83e 100644 (file)
@@ -1,7 +1,4 @@
-"""Requirements specific to SQLAlchemy's own unit tests.
-
-
-"""
+"""Requirements specific to SQLAlchemy's own unit tests."""
 
 from sqlalchemy import exc
 from sqlalchemy.sql import sqltypes
@@ -212,6 +209,19 @@ class DefaultRequirements(SuiteRequirements):
             ]
         )
 
+    @property
+    def server_defaults(self):
+        """Target backend supports server side defaults for columns"""
+
+        return exclusions.open()
+
+    @property
+    def expression_server_defaults(self):
+        return skip_if(
+            lambda config: against(config, "mysql", "mariadb")
+            and not self._mysql_expression_defaults(config)
+        )
+
     @property
     def qmark_paramstyle(self):
         return only_on(["sqlite", "+pyodbc"])
@@ -1814,6 +1824,15 @@ class DefaultRequirements(SuiteRequirements):
         # 2. they dont enforce check constraints
         return not self._mysql_check_constraints_exist(config)
 
+    def _mysql_expression_defaults(self, config):
+        return (against(config, ["mysql", "mariadb"])) and (
+            config.db.dialect._support_default_function
+        )
+
+    @property
+    def mysql_expression_defaults(self):
+        return only_if(self._mysql_expression_defaults)
+
     def _mysql_not_mariadb_102(self, config):
         return (against(config, ["mysql", "mariadb"])) and (
             not config.db.dialect._is_mariadb