]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
expand paren rules for default rendering, sqlite/mysql
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 Mar 2025 20:25:48 +0000 (16:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 17 Mar 2025 19:04:01 +0000 (15:04 -0400)
Expanded the rules for when to apply parenthesis to a server default in DDL
to suit the general case of a default string that contains non-word
characters such as spaces or operators and is not a string literal.

Fixed issue in MySQL server default reflection where a default that has
spaces would not be correctly reflected.  Additionally, expanded the rules
for when to apply parenthesis to a server default in DDL to suit the
general case of a default string that contains non-word characters such as
spaces or operators and is not a string literal.

Fixes: #12425
Change-Id: Ie40703dcd5fdc135025d676c01baba57ff3b71ad
(cherry picked from commit 1afb820427545e259397b98851a910d7379b2eb8)

12 files changed:
doc/build/changelog/unreleased_20/12425.rst [new file with mode: 0644]
doc/build/orm/extensions/asyncio.rst
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/reflection.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_reflection.py
test/dialect/mysql/test_compiler.py
test/dialect/mysql/test_query.py
test/dialect/test_sqlite.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_20/12425.rst b/doc/build/changelog/unreleased_20/12425.rst
new file mode 100644 (file)
index 0000000..fbc1f8a
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, sqlite
+    :tickets: 12425
+
+    Expanded the rules for when to apply parenthesis to a server default in DDL
+    to suit the general case of a default string that contains non-word
+    characters such as spaces or operators and is not a string literal.
+
+.. change::
+    :tags: bug, mysql
+    :tickets: 12425
+
+    Fixed issue in MySQL server default reflection where a default that has
+    spaces would not be correctly reflected.  Additionally, expanded the rules
+    for when to apply parenthesis to a server default in DDL to suit the
+    general case of a default string that contains non-word characters such as
+    spaces or operators and is not a string literal.
+
index fbd965d15d9fbb65a292289f2e20e4f08e81d914..5b881054304f500905ca255da4e09dd843eae385 100644 (file)
@@ -281,7 +281,7 @@ configuration:
     CREATE TABLE a (
         id INTEGER NOT NULL,
         data VARCHAR NOT NULL,
-        create_date DATETIME DEFAULT (CURRENT_TIMESTAMP) NOT NULL,
+        create_date DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL,
         PRIMARY KEY (id)
     )
     ...
index 8bae6193b51861d6de2b1c28c0137757c2339d7b..122a7cb2e5e2b541ab0f40d310db441c4a7f1bd1 100644 (file)
@@ -1928,12 +1928,13 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
             colspec.append("AUTO_INCREMENT")
         else:
             default = self.get_column_default_string(column)
+
             if default is not None:
                 if (
-                    isinstance(
-                        column.server_default.arg, functions.FunctionElement
-                    )
-                    and self.dialect._support_default_function
+                    self.dialect._support_default_function
+                    and not re.match(r"^\s*[\'\"\(]", default)
+                    and "ON UPDATE" not in default
+                    and re.match(r".*\W.*", default)
                 ):
                     colspec.append(f"DEFAULT ({default})")
                 else:
index 3998be977d9ab5b89ca9ad30db5d1886ca244f73..d62390bb8457d63da4eb3e02fff96bc1e499001b 100644 (file)
@@ -451,7 +451,7 @@ class MySQLTableDefinitionParser:
             r"(?: +COLLATE +(?P<collate>[\w_]+))?"
             r"(?: +(?P<notnull>(?:NOT )?NULL))?"
             r"(?: +DEFAULT +(?P<default>"
-            r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+"
+            r"(?:NULL|'(?:''|[^'])*'|\(.+?\)|[\-\w\.\(\)]+"
             r"(?: +ON UPDATE [\-\w\.\(\)]+)?)"
             r"))?"
             r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\("
index 96b2414ccec4a20ca90ea732850054b608032067..c09fbb32ccc04832d16ee5b5b0e6491114a56aff 100644 (file)
@@ -932,7 +932,6 @@ from ...engine import processors
 from ...engine import reflection
 from ...engine.reflection import ReflectionDefaults
 from ...sql import coercions
-from ...sql import ColumnElement
 from ...sql import compiler
 from ...sql import elements
 from ...sql import roles
@@ -1594,9 +1593,13 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
         colspec = self.preparer.format_column(column) + " " + coltype
         default = self.get_column_default_string(column)
         if default is not None:
-            if isinstance(column.server_default.arg, ColumnElement):
-                default = "(" + default + ")"
-            colspec += " DEFAULT " + default
+
+            if not re.match(r"""^\s*[\'\"\(]""", default) and re.match(
+                r".*\W.*", default
+            ):
+                colspec += f" DEFAULT ({default})"
+            else:
+                colspec += f" DEFAULT {default}"
 
         if not column.nullable:
             colspec += " NOT NULL"
index 8364c15f8ffc8cc5f0814daf7235f1526e8e89fb..719692125fbdb14a26e3bd43865059cfdb64a19b 100644 (file)
@@ -274,8 +274,8 @@ def int_within_variance(expected, received, variance):
     )
 
 
-def eq_regex(a, b, msg=None):
-    assert re.match(b, a), msg or "%r !~ %r" % (a, b)
+def eq_regex(a, b, msg=None, flags=0):
+    assert re.match(b, a, flags), msg or "%r !~ %r" % (a, b)
 
 
 def eq_(a, b, msg=None):
index 93541dca70efaaf4bea480dc067e7dfb2ff32065..a2c3aa531dc3b4e39c04ce901f2e67c875c6a45d 100644 (file)
@@ -1168,6 +1168,19 @@ class SuiteRequirements(Requirements):
         """
         return self.precision_numerics_many_significant_digits
 
+    @property
+    def server_defaults(self):
+        """Target backend supports server side defaults for columns"""
+
+        return exclusions.closed()
+
+    @property
+    def expression_server_defaults(self):
+        """Target backend supports server side defaults with SQL expressions
+        for columns"""
+
+        return exclusions.closed()
+
     @property
     def implicit_decimal_binds(self):
         """target backend will return a selected Decimal as a Decimal, not
index 2837e9fe0a33cff3e8aa5804a25b63a91a8da5fd..0f2a2062a8e5c52853c059f9f6242bfaca5261cc 100644 (file)
@@ -14,6 +14,7 @@ import sqlalchemy as sa
 from .. import config
 from .. import engines
 from .. import eq_
+from .. import eq_regex
 from .. import expect_raises
 from .. import expect_raises_message
 from .. import expect_warnings
@@ -23,6 +24,8 @@ from ..provision import get_temp_table_name
 from ..provision import temp_table_keyword_args
 from ..schema import Column
 from ..schema import Table
+from ... import Boolean
+from ... import DateTime
 from ... import event
 from ... import ForeignKey
 from ... import func
@@ -2883,6 +2886,47 @@ class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase):
         eq_(opts, expected)
         # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected)
 
+    @testing.combinations(
+        (Integer, sa.text("10"), r"'?10'?"),
+        (Integer, "10", r"'?10'?"),
+        (Boolean, sa.true(), r"1|true"),
+        (
+            Integer,
+            sa.text("3 + 5"),
+            r"3\+5",
+            testing.requires.expression_server_defaults,
+        ),
+        (
+            Integer,
+            sa.text("(3 * 5)"),
+            r"3\*5",
+            testing.requires.expression_server_defaults,
+        ),
+        (DateTime, func.now(), r"current_timestamp|now|getdate"),
+        (
+            Integer,
+            sa.literal_column("3") + sa.literal_column("5"),
+            r"3\+5",
+            testing.requires.expression_server_defaults,
+        ),
+        argnames="datatype, default, expected_reg",
+    )
+    @testing.requires.server_defaults
+    def test_server_defaults(
+        self, metadata, connection, datatype, default, expected_reg
+    ):
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("thecol", datatype, server_default=default),
+        )
+        t.create(connection)
+
+        reflected = inspect(connection).get_columns("t")[1]["default"]
+        reflected_sanitized = re.sub(r"[\(\) \']", "", reflected)
+        eq_regex(reflected_sanitized, expected_reg, flags=re.IGNORECASE)
+
 
 class NormalizedNameTest(fixtures.TablesTest):
     __requires__ = ("denormalized_names",)
index 8387d4e07c67ef7bf551642f789469ec62cfeb1f..f9cfeba05b8be13460ee2a53d3423ce6d88e0c80 100644 (file)
@@ -446,7 +446,7 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
             self.assert_compile(
                 schema.CreateTable(tbl),
                 "CREATE TABLE testtbl ("
-                "time DATETIME DEFAULT (CURRENT_TIMESTAMP), "
+                "time DATETIME DEFAULT CURRENT_TIMESTAMP, "
                 "name VARCHAR(255) DEFAULT 'some str', "
                 "description VARCHAR(255) DEFAULT (lower('hi')), "
                 "data JSON DEFAULT (json_object()))",
index 9cbc38378fbfe23c660086e04635c080dfcc9e6d..96650dab564628b7a24f67260c816573ddec368b 100644 (file)
@@ -5,16 +5,21 @@ from sqlalchemy import Boolean
 from sqlalchemy import cast
 from sqlalchemy import Column
 from sqlalchemy import Computed
+from sqlalchemy import DateTime
 from sqlalchemy import exc
 from sqlalchemy import false
 from sqlalchemy import ForeignKey
+from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import literal_column
 from sqlalchemy import MetaData
 from sqlalchemy import or_
 from sqlalchemy import schema
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import text
 from sqlalchemy import true
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import combinations
@@ -50,6 +55,35 @@ class IdiosyncrasyTest(fixtures.TestBase):
             )
 
 
+class ServerDefaultCreateTest(fixtures.TestBase):
+    @testing.combinations(
+        (Integer, text("10")),
+        (Integer, text("'10'")),
+        (Integer, "10"),
+        (Boolean, true()),
+        (Integer, text("3+5"), testing.requires.mysql_expression_defaults),
+        (Integer, text("3 + 5"), testing.requires.mysql_expression_defaults),
+        (Integer, text("(3 * 5)"), testing.requires.mysql_expression_defaults),
+        (DateTime, func.now()),
+        (
+            Integer,
+            literal_column("3") + literal_column("5"),
+            testing.requires.mysql_expression_defaults,
+        ),
+        argnames="datatype, default",
+    )
+    def test_create_server_defaults(
+        self, connection, metadata, datatype, default
+    ):
+        t = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("thecol", datatype, server_default=default),
+        )
+        t.create(connection)
+
+
 class MatchTest(fixtures.TablesTest):
     __only_on__ = "mysql", "mariadb"
     __backend__ = True
index 819bf8aa06b75dd02d2ff48dfdd8ce2564782850..c2c63e9ef06fdb0a7dc51d23ca1892c4c7d8c19e 100644 (file)
@@ -1032,39 +1032,60 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
             ")",
         )
 
-    def test_column_defaults_ddl(self):
+    @testing.combinations(
+        (
+            Boolean(create_constraint=True),
+            sql.false(),
+            "BOOLEAN DEFAULT 0, CHECK (x IN (0, 1))",
+        ),
+        (
+            String(),
+            func.sqlite_version(),
+            "VARCHAR DEFAULT (sqlite_version())",
+        ),
+        (Integer(), func.abs(-5) + 17, "INTEGER DEFAULT (abs(-5) + 17)"),
+        (
+            # test #12425
+            String(),
+            func.now(),
+            "VARCHAR DEFAULT CURRENT_TIMESTAMP",
+        ),
+        (
+            # test #12425
+            String(),
+            func.datetime(func.now(), "localtime"),
+            "VARCHAR DEFAULT (datetime(CURRENT_TIMESTAMP, 'localtime'))",
+        ),
+        (
+            # test #12425
+            String(),
+            text("datetime(CURRENT_TIMESTAMP, 'localtime')"),
+            "VARCHAR DEFAULT (datetime(CURRENT_TIMESTAMP, 'localtime'))",
+        ),
+        (
+            # default with leading spaces that should not be
+            # parenthesized
+            String,
+            text("  'some default'"),
+            "VARCHAR DEFAULT   'some default'",
+        ),
+        (String, text("'some default'"), "VARCHAR DEFAULT 'some default'"),
+        argnames="datatype,default,expected",
+    )
+    def test_column_defaults_ddl(self, datatype, default, expected):
         t = Table(
             "t",
             MetaData(),
             Column(
                 "x",
-                Boolean(create_constraint=True),
-                server_default=sql.false(),
+                datatype,
+                server_default=default,
             ),
         )
 
         self.assert_compile(
             CreateTable(t),
-            "CREATE TABLE t (x BOOLEAN DEFAULT (0), CHECK (x IN (0, 1)))",
-        )
-
-        t = Table(
-            "t",
-            MetaData(),
-            Column("x", String(), server_default=func.sqlite_version()),
-        )
-        self.assert_compile(
-            CreateTable(t),
-            "CREATE TABLE t (x VARCHAR DEFAULT (sqlite_version()))",
-        )
-
-        t = Table(
-            "t",
-            MetaData(),
-            Column("x", Integer(), server_default=func.abs(-5) + 17),
-        )
-        self.assert_compile(
-            CreateTable(t), "CREATE TABLE t (x INTEGER DEFAULT (abs(-5) + 17))"
+            f"CREATE TABLE t (x {expected})",
         )
 
     def test_create_partial_index(self):
index 12c25ece1aab0f7b260bdbabd6dd9fef2de4b79d..2311f6e35fcdf8d822565454934ffb08f73c4ddf 100644 (file)
@@ -1,7 +1,4 @@
-"""Requirements specific to SQLAlchemy's own unit tests.
-
-
-"""
+"""Requirements specific to SQLAlchemy's own unit tests."""
 
 from sqlalchemy import exc
 from sqlalchemy.sql import sqltypes
@@ -212,6 +209,19 @@ class DefaultRequirements(SuiteRequirements):
             ]
         )
 
+    @property
+    def server_defaults(self):
+        """Target backend supports server side defaults for columns"""
+
+        return exclusions.open()
+
+    @property
+    def expression_server_defaults(self):
+        return skip_if(
+            lambda config: against(config, "mysql", "mariadb")
+            and not self._mysql_expression_defaults(config)
+        )
+
     @property
     def qmark_paramstyle(self):
         return only_on(["sqlite", "+pyodbc"])
@@ -1818,6 +1828,15 @@ class DefaultRequirements(SuiteRequirements):
         # 2. they dont enforce check constraints
         return not self._mysql_check_constraints_exist(config)
 
+    def _mysql_expression_defaults(self, config):
+        return (against(config, ["mysql", "mariadb"])) and (
+            config.db.dialect._support_default_function
+        )
+
+    @property
+    def mysql_expression_defaults(self):
+        return only_if(self._mysql_expression_defaults)
+
     def _mysql_not_mariadb_102(self, config):
         return (against(config, ["mysql", "mariadb"])) and (
             not config.db.dialect._is_mariadb