]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Ensure SQLite default expressions are parenthesized
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 25 Jun 2019 15:40:56 +0000 (11:40 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 25 Jun 2019 20:31:43 +0000 (16:31 -0400)
- SQLite server default reflection will ensure parenthesis are surrounding a
column default expression that is detected as being a non-constant
expression, such as a ``datetime()`` default, to accommodate for the
requirement that SQL expressions have to be parenthesized when being sent
as DDL.  Parenthesis are not added to constant expressions to allow for
maximum cross-compatibility with other dialects and existing test suites
(such as Alembic's), which necessarily entails scanning the expression to
eliminate for constant numeric and string values. The logic is added to the
two "reflection->DDL round trip" paths which are currently autogenerate and
batch migration.  Within autogenerate, the logic is on the rendering side,
whereas in batch the logic is installed as a column reflection hook.

- Improved SQLite server default comparison to accommodate for a ``text()``
construct that added parenthesis directly vs. a construct that relied
upon the SQLAlchemy SQLite dialect to render the parenthesis, as well
as improved support for various forms of constant expressions such as
values that are quoted vs. non-quoted.

- Fixed bug where the "literal_binds" flag was not being set when
autogenerate would create a server default value, meaning server default
comparisons would fail for functions that contained literal values.

Fixes: #579
Change-Id: I78b87573b8ecd15cb4ced08f054902f574e3956c

alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/ddl/impl.py
alembic/ddl/sqlite.py
alembic/operations/batch.py
alembic/testing/fixtures.py
alembic/testing/requirements.py
docs/build/unreleased/579.rst [new file with mode: 0644]
tests/test_batch.py
tests/test_postgresql.py
tests/test_sqlite.py

index df1e5096d3a3addf67ba1cd4b952cf5d71e22072..37371eb6ae041651a1681396576b58d078dd7e8c 100644 (file)
@@ -873,7 +873,10 @@ def _render_server_default_for_compare(
             metadata_default = metadata_default.arg
         else:
             metadata_default = str(
-                metadata_default.arg.compile(dialect=autogen_context.dialect)
+                metadata_default.arg.compile(
+                    dialect=autogen_context.dialect,
+                    compile_kwargs={"literal_binds": True},
+                )
             )
     if isinstance(metadata_default, compat.string_types):
         if metadata_col.type._type_affinity is sqltypes.String:
index 21bbb50322629c871f02734e6dab4e5386506724..0e3b2d872f09099b534bba6b616dacf3018ebb29 100644 (file)
@@ -492,11 +492,10 @@ def _ident(name):
         return name
 
 
-def _render_potential_expr(value, autogen_context, wrap_in_text=True):
+def _render_potential_expr(
+    value, autogen_context, wrap_in_text=True, is_server_default=False
+):
     if isinstance(value, sql.ClauseElement):
-        compile_kw = dict(
-            compile_kwargs={"literal_binds": True, "include_table": False}
-        )
 
         if wrap_in_text:
             template = "%(prefix)stext(%(sql)r)"
@@ -505,8 +504,8 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True):
 
         return template % {
             "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
-            "sql": compat.text_type(
-                value.compile(dialect=autogen_context.dialect, **compile_kw)
+            "sql": autogen_context.migration_context.impl.render_ddl_sql_expr(
+                value, is_server_default=is_server_default
             ),
         }
 
@@ -634,7 +633,9 @@ def _render_server_default(default, autogen_context, repr_=True):
         if isinstance(default.arg, compat.string_types):
             default = default.arg
         else:
-            return _render_potential_expr(default.arg, autogen_context)
+            return _render_potential_expr(
+                default.arg, autogen_context, is_server_default=True
+            )
 
     if isinstance(default, string_types) and repr_:
         default = repr(re.sub(r"^'|'$", "", default))
index 0843ebf1d73d62f5a8680214365516b7d23f372d..17209c988bc23a4dee1434fbd7962c15386b5f3c 100644 (file)
@@ -367,6 +367,19 @@ class DefaultImpl(with_metaclass(ImplMeta)):
     ):
         pass
 
+    def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
+        """Render a SQL expression that is typically a server default,
+        index expression, etc.
+
+        .. versionadded:: 1.0.11
+
+        """
+
+        compile_kw = dict(
+            compile_kwargs={"literal_binds": True, "include_table": False}
+        )
+        return text_type(expr.compile(dialect=self.dialect, **compile_kw))
+
     def _compat_autogen_column_reflect(self, inspector):
         return self.autogen_column_reflect
 
index c0385e17a3af239e1a0afa709e43342084580e96..95f814a0efd0bb15e0e91ed0a2af987dd30cbc7c 100644 (file)
@@ -56,11 +56,16 @@ class SQLiteImpl(DefaultImpl):
 
         if rendered_metadata_default is not None:
             rendered_metadata_default = re.sub(
-                r"^\"'|\"'$", "", rendered_metadata_default
+                r"^\((.+)\)$", r"\1", rendered_metadata_default
             )
+
+            rendered_metadata_default = re.sub(
+                r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default
+            )
+
         if rendered_inspector_default is not None:
             rendered_inspector_default = re.sub(
-                r"^\"'|\"'$", "", rendered_inspector_default
+                r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default
             )
 
         return rendered_inspector_default != rendered_metadata_default
@@ -91,6 +96,47 @@ class SQLiteImpl(DefaultImpl):
             if idx.name is None and uq_sig(idx) not in conn_unique_sigs:
                 metadata_unique_constraints.remove(idx)
 
+    def _guess_if_default_is_unparenthesized_sql_expr(self, expr):
+        """Determine if a server default is a SQL expression or a constant.
+
+        There are too many assertions that expect server defaults to round-trip
+        identically without parenthesis added so we will add parens only in
+        very specific cases.
+
+        """
+        if not expr:
+            return False
+        elif re.match(r"^[0-9\.]$", expr):
+            return False
+        elif re.match(r"^'.+'$", expr):
+            return False
+        elif re.match(r"^\(.+\)$", expr):
+            return False
+        else:
+            return True
+
+    def autogen_column_reflect(self, inspector, table, column_info):
+        # SQLite expression defaults require parenthesis when sent
+        # as DDL
+        if self._guess_if_default_is_unparenthesized_sql_expr(
+            column_info.get("default", None)
+        ):
+            column_info["default"] = "(%s)" % (column_info["default"],)
+
+    def render_ddl_sql_expr(self, expr, is_server_default=False, **kw):
+        # SQLite expression defaults require parenthesis when sent
+        # as DDL
+        str_expr = super(SQLiteImpl, self).render_ddl_sql_expr(
+            expr, is_server_default=is_server_default, **kw
+        )
+
+        if (
+            is_server_default
+            and self._guess_if_default_is_unparenthesized_sql_expr(str_expr)
+        ):
+            str_expr = "(%s)" % (str_expr,)
+        return str_expr
+
 
 # @compiles(AddColumn, 'sqlite')
 # def visit_add_column(element, compiler, **kw):
index 9e829b36258a37845baf76f2baa5ea2ee83d1c27..42db905e66bf0d47059253bd68e5ee2396d9a2d5 100644 (file)
@@ -44,7 +44,13 @@ class BatchOperationsImpl(object):
         self.table_args = table_args
         self.table_kwargs = dict(table_kwargs)
         self.reflect_args = reflect_args
-        self.reflect_kwargs = reflect_kwargs
+        self.reflect_kwargs = dict(reflect_kwargs)
+        self.reflect_kwargs.setdefault(
+            "listeners", list(self.reflect_kwargs.get("listeners", ()))
+        )
+        self.reflect_kwargs["listeners"].append(
+            ("column_reflect", operations.impl.autogen_column_reflect)
+        )
         self.naming_convention = naming_convention
         self.batch = []
 
index 59be5712a1ef6a3448e9ac7d68f6ad64cc2ea9e7..624346724475e6f3b98b77d3764ca72b8e487171 100644 (file)
@@ -194,7 +194,7 @@ class AlterColRoundTripFixture(object):
     # the type / server default compare logic might not work on older
     # SQLAlchemy versions as seems to be the case for SQLAlchemy 1.1 on Oracle
 
-    __requires__ = ("alter_column", "sqlachemy_12")
+    __requires__ = ("alter_column", "sqlalchemy_12")
 
     def setUp(self):
         self.conn = config.db.connect()
index 55054d9209bc4701cfaa3ffa1f5a5b97cea5a884..cf570a5cae8ccaab272cb42bd0989702bc14d45a 100644 (file)
@@ -62,12 +62,19 @@ class SuiteRequirements(Requirements):
         return exclusions.closed()
 
     @property
-    def sqlachemy_12(self):
+    def sqlalchemy_12(self):
         return exclusions.skip_if(
             lambda config: not util.sqla_1216,
             "SQLAlchemy 1.2.16 or greater required",
         )
 
+    @property
+    def sqlalchemy_10(self):
+        return exclusions.skip_if(
+            lambda config: not util.sqla_100,
+            "SQLAlchemy 1.0.0 or greater required",
+        )
+
     @property
     def fail_before_sqla_100(self):
         return exclusions.fails_if(
diff --git a/docs/build/unreleased/579.rst b/docs/build/unreleased/579.rst
new file mode 100644 (file)
index 0000000..ba8f7d6
--- /dev/null
@@ -0,0 +1,34 @@
+.. change::
+    :tags: bug, sqlite, autogenerate, batch
+    :tickets: 579
+
+    SQLite server default reflection will ensure parenthesis are surrounding a
+    column default expression that is detected as being a non-constant
+    expression, such as a ``datetime()`` default, to accommodate for the
+    requirement that SQL expressions have to be parenthesized when being sent
+    as DDL.  Parenthesis are not added to constant expressions to allow for
+    maximum cross-compatibility with other dialects and existing test suites
+    (such as Alembic's), which necessarily entails scanning the expression to
+    eliminate for constant numeric and string values. The logic is added to the
+    two "reflection->DDL round trip" paths which are currently autogenerate and
+    batch migration.  Within autogenerate, the logic is on the rendering side,
+    whereas in batch the logic is installed as a column reflection hook.
+
+
+.. change::
+    :tags: bug, sqlite, autogenerate
+    :tickets: 579
+
+    Improved SQLite server default comparison to accommodate for a ``text()``
+    construct that added parenthesis directly vs. a construct that relied
+    upon the SQLAlchemy SQLite dialect to render the parenthesis, as well
+    as improved support for various forms of constant expressions such as
+    values that are quoted vs. non-quoted.
+
+
+.. change::
+    :tags: bug, autogenerate
+
+    Fixed bug where the "literal_binds" flag was not being set when
+    autogenerate would create a server default value, meaning server default
+    comparisons would fail for functions that contained literal values.
\ No newline at end of file
index 8879c9ca5cd321f0ada1d5fb51f5ced7790cc70d..c8c5b33d6c4ac8ddbde0fb076ebcc8f5cbc196bc 100644 (file)
@@ -9,6 +9,7 @@ from sqlalchemy import Enum
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import ForeignKeyConstraint
+from sqlalchemy import func
 from sqlalchemy import Index
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
@@ -1117,6 +1118,23 @@ class BatchRoundTripTest(TestBase):
         t.create(self.conn)
         return t
 
+    def _datetime_server_default_fixture(self):
+        return func.datetime("now", "localtime")
+
+    def _timestamp_w_expr_default_fixture(self):
+        t = Table(
+            "hasts",
+            self.metadata,
+            Column(
+                "x",
+                DateTime(),
+                server_default=self._datetime_server_default_fixture(),
+                nullable=False,
+            ),
+        )
+        t.create(self.conn)
+        return t
+
     def _int_to_boolean_fixture(self):
         t = Table("hasbool", self.metadata, Column("x", Integer))
         t.create(self.conn)
@@ -1157,6 +1175,23 @@ class BatchRoundTripTest(TestBase):
             [(datetime.datetime(2012, 5, 18, 15, 32, 5),)],
         )
 
+    @config.requirements.sqlalchemy_12
+    def test_no_net_change_timestamp_w_default(self):
+        t = self._timestamp_w_expr_default_fixture()
+
+        with self.op.batch_alter_table("hasts") as batch_op:
+            batch_op.alter_column(
+                "x",
+                type_=DateTime(),
+                nullable=False,
+                server_default=self._datetime_server_default_fixture(),
+            )
+
+        self.conn.execute(t.insert())
+
+        row = self.conn.execute(select([t.c.x])).fetchone()
+        assert row["x"] is not None
+
     def test_drop_col_schematype(self):
         self._boolean_fixture()
         with self.op.batch_alter_table("hasbool") as batch_op:
@@ -1612,6 +1647,9 @@ class BatchRoundTripMySQLTest(BatchRoundTripTest):
     __only_on__ = "mysql"
     __backend__ = True
 
+    def _datetime_server_default_fixture(self):
+        return func.current_timestamp()
+
     @exclusions.fails()
     def test_drop_pk_col_readd_pk_col(self):
         super(BatchRoundTripMySQLTest, self).test_drop_pk_col_readd_pk_col()
@@ -1655,6 +1693,9 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest):
     __only_on__ = "postgresql"
     __backend__ = True
 
+    def _datetime_server_default_fixture(self):
+        return func.current_timestamp()
+
     @exclusions.fails()
     def test_drop_pk_col_readd_pk_col(self):
         super(
index 9a0e8f3aa93bb8ce3f38adfa7041a64074432318..6c9838a7c4a495f8fc5ee97fe0d6845317df8f76 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import Boolean
 from sqlalchemy import Column
 from sqlalchemy import DateTime
 from sqlalchemy import Float
+from sqlalchemy import func
 from sqlalchemy import Index
 from sqlalchemy import Integer
 from sqlalchemy import Interval
@@ -507,10 +508,10 @@ class PostgresqlDefaultCompareTest(TestBase):
         )
 
     def test_compare_string_blank_default(self):
-        self._compare_default_roundtrip(String(8), '')
+        self._compare_default_roundtrip(String(8), "")
 
     def test_compare_string_nonblank_default(self):
-        self._compare_default_roundtrip(String(8), 'hi')
+        self._compare_default_roundtrip(String(8), "hi")
 
     def test_compare_interval_str(self):
         # this form shouldn't be used but testing here
@@ -534,6 +535,12 @@ class PostgresqlDefaultCompareTest(TestBase):
             DateTime(), text("TIMEZONE('utc', CURRENT_TIMESTAMP)")
         )
 
+    @config.requirements.sqlalchemy_10
+    def test_compare_current_timestamp_fn_w_binds(self):
+        self._compare_default_roundtrip(
+            DateTime(), func.timezone("utc", func.current_timestamp())
+        )
+
     def test_compare_integer_str(self):
         self._compare_default_roundtrip(Integer(), "5")
 
index dd81a13a343828241cce1afa1566fc3f3bdd5d90..7616ef20c77c5782253b7a005c1ad4542d92717d 100644 (file)
@@ -1,11 +1,28 @@
 from sqlalchemy import Boolean
 from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import Float
+from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import text
+from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.sql import column
 
+from alembic import autogenerate
 from alembic import op
+from alembic.autogenerate import api
+from alembic.autogenerate.compare import _compare_server_default
+from alembic.migration import MigrationContext
+from alembic.operations import ops
 from alembic.testing import assert_raises_message
 from alembic.testing import config
+from alembic.testing import eq_
+from alembic.testing import eq_ignore_whitespace
+from alembic.testing.env import clear_staging_env
+from alembic.testing.env import staging_env
 from alembic.testing.fixtures import op_fixture
 from alembic.testing.fixtures import TestBase
 
@@ -63,3 +80,179 @@ class SQLiteTest(TestBase):
         context = op_fixture("sqlite")
         op.add_column("t1", Column("c1", Integer, comment="c1 comment"))
         context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER")
+
+
+class SQLiteDefaultCompareTest(TestBase):
+    __only_on__ = "sqlite"
+    __backend__ = True
+
+    @classmethod
+    def setup_class(cls):
+        cls.bind = config.db
+        staging_env()
+        cls.migration_context = MigrationContext.configure(
+            connection=cls.bind.connect(),
+            opts={"compare_type": True, "compare_server_default": True},
+        )
+
+    def setUp(self):
+        self.metadata = MetaData(self.bind)
+        self.autogen_context = api.AutogenContext(self.migration_context)
+
+    @classmethod
+    def teardown_class(cls):
+        clear_staging_env()
+
+    def tearDown(self):
+        self.metadata.drop_all()
+
+    def _compare_default_roundtrip(
+        self, type_, orig_default, alternate=None, diff_expected=None
+    ):
+        diff_expected = (
+            diff_expected
+            if diff_expected is not None
+            else alternate is not None
+        )
+        if alternate is None:
+            alternate = orig_default
+
+        t1 = Table(
+            "test",
+            self.metadata,
+            Column("somecol", type_, server_default=orig_default),
+        )
+        t2 = Table(
+            "test",
+            MetaData(),
+            Column("somecol", type_, server_default=alternate),
+        )
+
+        t1.create(self.bind)
+
+        insp = Inspector.from_engine(self.bind)
+        cols = insp.get_columns(t1.name)
+        insp_col = Column(
+            "somecol", cols[0]["type"], server_default=text(cols[0]["default"])
+        )
+        op = ops.AlterColumnOp("test", "somecol")
+        _compare_server_default(
+            self.autogen_context,
+            op,
+            None,
+            "test",
+            "somecol",
+            insp_col,
+            t2.c.somecol,
+        )
+
+        diffs = op.to_diff_tuple()
+        eq_(bool(diffs), diff_expected)
+
+    def _compare_default(self, t1, t2, col, rendered):
+        t1.create(self.bind, checkfirst=True)
+        insp = Inspector.from_engine(self.bind)
+        cols = insp.get_columns(t1.name)
+        ctx = self.autogen_context.migration_context
+
+        return ctx.impl.compare_server_default(
+            None, col, rendered, cols[0]["default"]
+        )
+
+    @config.requirements.sqlalchemy_12
+    def test_compare_current_timestamp_func(self):
+        self._compare_default_roundtrip(
+            DateTime(), func.datetime("now", "localtime")
+        )
+
+    def test_compare_current_timestamp_text(self):
+        # SQLAlchemy doesn't render the parenthesis for a
+        # SQLite server default specified as text(), so users will be doing
+        # this; sqlite comparison needs to accommodate for these.
+        self._compare_default_roundtrip(
+            DateTime(), text("(datetime('now', 'localtime'))")
+        )
+
+    def test_compare_integer_str(self):
+        self._compare_default_roundtrip(Integer(), "5")
+
+    def test_compare_integer_str_diff(self):
+        self._compare_default_roundtrip(Integer(), "5", "7")
+
+    def test_compare_integer_text(self):
+        self._compare_default_roundtrip(Integer(), text("5"))
+
+    def test_compare_integer_text_diff(self):
+        self._compare_default_roundtrip(Integer(), text("5"), "7")
+
+    def test_compare_float_str(self):
+        self._compare_default_roundtrip(Float(), "5.2")
+
+    def test_compare_float_str_diff(self):
+        self._compare_default_roundtrip(Float(), "5.2", "5.3")
+
+    def test_compare_float_text(self):
+        self._compare_default_roundtrip(Float(), text("5.2"))
+
+    def test_compare_float_text_diff(self):
+        self._compare_default_roundtrip(Float(), text("5.2"), "5.3")
+
+    def test_compare_string_literal(self):
+        self._compare_default_roundtrip(String(), "im a default")
+
+    def test_compare_string_literal_diff(self):
+        self._compare_default_roundtrip(String(), "im a default", "me too")
+
+
+class SQLiteAutogenRenderTest(TestBase):
+    def setUp(self):
+        ctx_opts = {
+            "sqlalchemy_module_prefix": "sa.",
+            "alembic_module_prefix": "op.",
+            "target_metadata": MetaData(),
+        }
+        context = MigrationContext.configure(
+            dialect_name="sqlite", opts=ctx_opts
+        )
+
+        self.autogen_context = api.AutogenContext(context)
+
+    def test_render_server_default_expr_needs_parens(self):
+        c = Column(
+            "date_value",
+            DateTime(),
+            server_default=func.datetime("now", "localtime"),
+        )
+
+        result = autogenerate.render._render_column(c, self.autogen_context)
+        eq_ignore_whitespace(
+            result,
+            "sa.Column('date_value', sa.DateTime(), "
+            "server_default=sa.text(!U\"(datetime('now', 'localtime'))\"), "
+            "nullable=True)",
+        )
+
+    def test_render_server_default_text_expr_needs_parens(self):
+        c = Column(
+            "date_value",
+            DateTime(),
+            server_default=text("(datetime('now', 'localtime'))"),
+        )
+
+        result = autogenerate.render._render_column(c, self.autogen_context)
+        eq_ignore_whitespace(
+            result,
+            "sa.Column('date_value', sa.DateTime(), "
+            "server_default=sa.text(!U\"(datetime('now', 'localtime'))\"), "
+            "nullable=True)",
+        )
+
+    def test_render_server_default_const(self):
+        c = Column("int_value", Integer, server_default="5")
+
+        result = autogenerate.render._render_column(c, self.autogen_context)
+        eq_ignore_whitespace(
+            result,
+            "sa.Column('int_value', sa.Integer(), server_default='5', "
+            "nullable=True)",
+        )