From: Mike Bayer Date: Tue, 25 Jun 2019 15:40:56 +0000 (-0400) Subject: Ensure SQLite default expressions are parenthesized X-Git-Tag: rel_1_0_11~1^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d4951ffb66f933ca332baf85faa15c7547523ecb;p=thirdparty%2Fsqlalchemy%2Falembic.git Ensure SQLite default expressions are parenthesized - SQLite server default reflection will ensure parenthesis are surrounding a column default expression that is detected as being a non-constant expression, such as a ``datetime()`` default, to accommodate for the requirement that SQL expressions have to be parenthesized when being sent as DDL. Parenthesis are not added to constant expressions to allow for maximum cross-compatibility with other dialects and existing test suites (such as Alembic's), which necessarily entails scanning the expression to eliminate for constant numeric and string values. The logic is added to the two "reflection->DDL round trip" paths which are currently autogenerate and batch migration. Within autogenerate, the logic is on the rendering side, whereas in batch the logic is installed as a column reflection hook. - Improved SQLite server default comparison to accommodate for a ``text()`` construct that added parenthesis directly vs. a construct that relied upon the SQLAlchemy SQLite dialect to render the parenthesis, as well as improved support for various forms of constant expressions such as values that are quoted vs. non-quoted. - Fixed bug where the "literal_binds" flag was not being set when autogenerate would create a server default value, meaning server default comparisons would fail for functions that contained literal values. Fixes: #579 Change-Id: I78b87573b8ecd15cb4ced08f054902f574e3956c --- diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index df1e5096..37371eb6 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -873,7 +873,10 @@ def _render_server_default_for_compare( metadata_default = metadata_default.arg else: metadata_default = str( - metadata_default.arg.compile(dialect=autogen_context.dialect) + metadata_default.arg.compile( + dialect=autogen_context.dialect, + compile_kwargs={"literal_binds": True}, + ) ) if isinstance(metadata_default, compat.string_types): if metadata_col.type._type_affinity is sqltypes.String: diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 21bbb503..0e3b2d87 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -492,11 +492,10 @@ def _ident(name): return name -def _render_potential_expr(value, autogen_context, wrap_in_text=True): +def _render_potential_expr( + value, autogen_context, wrap_in_text=True, is_server_default=False +): if isinstance(value, sql.ClauseElement): - compile_kw = dict( - compile_kwargs={"literal_binds": True, "include_table": False} - ) if wrap_in_text: template = "%(prefix)stext(%(sql)r)" @@ -505,8 +504,8 @@ def _render_potential_expr(value, autogen_context, wrap_in_text=True): return template % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), - "sql": compat.text_type( - value.compile(dialect=autogen_context.dialect, **compile_kw) + "sql": autogen_context.migration_context.impl.render_ddl_sql_expr( + value, is_server_default=is_server_default ), } @@ -634,7 +633,9 @@ def _render_server_default(default, autogen_context, repr_=True): if isinstance(default.arg, compat.string_types): default = default.arg else: - return _render_potential_expr(default.arg, autogen_context) + return _render_potential_expr( + default.arg, autogen_context, is_server_default=True + ) if isinstance(default, string_types) and repr_: default = repr(re.sub(r"^'|'$", "", default)) diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 0843ebf1..17209c98 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -367,6 +367,19 @@ class DefaultImpl(with_metaclass(ImplMeta)): ): pass + def render_ddl_sql_expr(self, expr, is_server_default=False, **kw): + """Render a SQL expression that is typically a server default, + index expression, etc. + + .. versionadded:: 1.0.11 + + """ + + compile_kw = dict( + compile_kwargs={"literal_binds": True, "include_table": False} + ) + return text_type(expr.compile(dialect=self.dialect, **compile_kw)) + def _compat_autogen_column_reflect(self, inspector): return self.autogen_column_reflect diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py index c0385e17..95f814a0 100644 --- a/alembic/ddl/sqlite.py +++ b/alembic/ddl/sqlite.py @@ -56,11 +56,16 @@ class SQLiteImpl(DefaultImpl): if rendered_metadata_default is not None: rendered_metadata_default = re.sub( - r"^\"'|\"'$", "", rendered_metadata_default + r"^\((.+)\)$", r"\1", rendered_metadata_default ) + + rendered_metadata_default = re.sub( + r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default + ) + if rendered_inspector_default is not None: rendered_inspector_default = re.sub( - r"^\"'|\"'$", "", rendered_inspector_default + r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default ) return rendered_inspector_default != rendered_metadata_default @@ -91,6 +96,47 @@ class SQLiteImpl(DefaultImpl): if idx.name is None and uq_sig(idx) not in conn_unique_sigs: metadata_unique_constraints.remove(idx) + def _guess_if_default_is_unparenthesized_sql_expr(self, expr): + """Determine if a server default is a SQL expression or a constant. + + There are too many assertions that expect server defaults to round-trip + identically without parenthesis added so we will add parens only in + very specific cases. + + """ + if not expr: + return False + elif re.match(r"^[0-9\.]$", expr): + return False + elif re.match(r"^'.+'$", expr): + return False + elif re.match(r"^\(.+\)$", expr): + return False + else: + return True + + def autogen_column_reflect(self, inspector, table, column_info): + # SQLite expression defaults require parenthesis when sent + # as DDL + if self._guess_if_default_is_unparenthesized_sql_expr( + column_info.get("default", None) + ): + column_info["default"] = "(%s)" % (column_info["default"],) + + def render_ddl_sql_expr(self, expr, is_server_default=False, **kw): + # SQLite expression defaults require parenthesis when sent + # as DDL + str_expr = super(SQLiteImpl, self).render_ddl_sql_expr( + expr, is_server_default=is_server_default, **kw + ) + + if ( + is_server_default + and self._guess_if_default_is_unparenthesized_sql_expr(str_expr) + ): + str_expr = "(%s)" % (str_expr,) + return str_expr + # @compiles(AddColumn, 'sqlite') # def visit_add_column(element, compiler, **kw): diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py index 9e829b36..42db905e 100644 --- a/alembic/operations/batch.py +++ b/alembic/operations/batch.py @@ -44,7 +44,13 @@ class BatchOperationsImpl(object): self.table_args = table_args self.table_kwargs = dict(table_kwargs) self.reflect_args = reflect_args - self.reflect_kwargs = reflect_kwargs + self.reflect_kwargs = dict(reflect_kwargs) + self.reflect_kwargs.setdefault( + "listeners", list(self.reflect_kwargs.get("listeners", ())) + ) + self.reflect_kwargs["listeners"].append( + ("column_reflect", operations.impl.autogen_column_reflect) + ) self.naming_convention = naming_convention self.batch = [] diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py index 59be5712..62434672 100644 --- a/alembic/testing/fixtures.py +++ b/alembic/testing/fixtures.py @@ -194,7 +194,7 @@ class AlterColRoundTripFixture(object): # the type / server default compare logic might not work on older # SQLAlchemy versions as seems to be the case for SQLAlchemy 1.1 on Oracle - __requires__ = ("alter_column", "sqlachemy_12") + __requires__ = ("alter_column", "sqlalchemy_12") def setUp(self): self.conn = config.db.connect() diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py index 55054d92..cf570a5c 100644 --- a/alembic/testing/requirements.py +++ b/alembic/testing/requirements.py @@ -62,12 +62,19 @@ class SuiteRequirements(Requirements): return exclusions.closed() @property - def sqlachemy_12(self): + def sqlalchemy_12(self): return exclusions.skip_if( lambda config: not util.sqla_1216, "SQLAlchemy 1.2.16 or greater required", ) + @property + def sqlalchemy_10(self): + return exclusions.skip_if( + lambda config: not util.sqla_100, + "SQLAlchemy 1.0.0 or greater required", + ) + @property def fail_before_sqla_100(self): return exclusions.fails_if( diff --git a/docs/build/unreleased/579.rst b/docs/build/unreleased/579.rst new file mode 100644 index 00000000..ba8f7d68 --- /dev/null +++ b/docs/build/unreleased/579.rst @@ -0,0 +1,34 @@ +.. change:: + :tags: bug, sqlite, autogenerate, batch + :tickets: 579 + + SQLite server default reflection will ensure parenthesis are surrounding a + column default expression that is detected as being a non-constant + expression, such as a ``datetime()`` default, to accommodate for the + requirement that SQL expressions have to be parenthesized when being sent + as DDL. Parenthesis are not added to constant expressions to allow for + maximum cross-compatibility with other dialects and existing test suites + (such as Alembic's), which necessarily entails scanning the expression to + eliminate for constant numeric and string values. The logic is added to the + two "reflection->DDL round trip" paths which are currently autogenerate and + batch migration. Within autogenerate, the logic is on the rendering side, + whereas in batch the logic is installed as a column reflection hook. + + +.. change:: + :tags: bug, sqlite, autogenerate + :tickets: 579 + + Improved SQLite server default comparison to accommodate for a ``text()`` + construct that added parenthesis directly vs. a construct that relied + upon the SQLAlchemy SQLite dialect to render the parenthesis, as well + as improved support for various forms of constant expressions such as + values that are quoted vs. non-quoted. + + +.. change:: + :tags: bug, autogenerate + + Fixed bug where the "literal_binds" flag was not being set when + autogenerate would create a server default value, meaning server default + comparisons would fail for functions that contained literal values. \ No newline at end of file diff --git a/tests/test_batch.py b/tests/test_batch.py index 8879c9ca..c8c5b33d 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -9,6 +9,7 @@ from sqlalchemy import Enum from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import ForeignKeyConstraint +from sqlalchemy import func from sqlalchemy import Index from sqlalchemy import Integer from sqlalchemy import MetaData @@ -1117,6 +1118,23 @@ class BatchRoundTripTest(TestBase): t.create(self.conn) return t + def _datetime_server_default_fixture(self): + return func.datetime("now", "localtime") + + def _timestamp_w_expr_default_fixture(self): + t = Table( + "hasts", + self.metadata, + Column( + "x", + DateTime(), + server_default=self._datetime_server_default_fixture(), + nullable=False, + ), + ) + t.create(self.conn) + return t + def _int_to_boolean_fixture(self): t = Table("hasbool", self.metadata, Column("x", Integer)) t.create(self.conn) @@ -1157,6 +1175,23 @@ class BatchRoundTripTest(TestBase): [(datetime.datetime(2012, 5, 18, 15, 32, 5),)], ) + @config.requirements.sqlalchemy_12 + def test_no_net_change_timestamp_w_default(self): + t = self._timestamp_w_expr_default_fixture() + + with self.op.batch_alter_table("hasts") as batch_op: + batch_op.alter_column( + "x", + type_=DateTime(), + nullable=False, + server_default=self._datetime_server_default_fixture(), + ) + + self.conn.execute(t.insert()) + + row = self.conn.execute(select([t.c.x])).fetchone() + assert row["x"] is not None + def test_drop_col_schematype(self): self._boolean_fixture() with self.op.batch_alter_table("hasbool") as batch_op: @@ -1612,6 +1647,9 @@ class BatchRoundTripMySQLTest(BatchRoundTripTest): __only_on__ = "mysql" __backend__ = True + def _datetime_server_default_fixture(self): + return func.current_timestamp() + @exclusions.fails() def test_drop_pk_col_readd_pk_col(self): super(BatchRoundTripMySQLTest, self).test_drop_pk_col_readd_pk_col() @@ -1655,6 +1693,9 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest): __only_on__ = "postgresql" __backend__ = True + def _datetime_server_default_fixture(self): + return func.current_timestamp() + @exclusions.fails() def test_drop_pk_col_readd_pk_col(self): super( diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 9a0e8f3a..6c9838a7 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -3,6 +3,7 @@ from sqlalchemy import Boolean from sqlalchemy import Column from sqlalchemy import DateTime from sqlalchemy import Float +from sqlalchemy import func from sqlalchemy import Index from sqlalchemy import Integer from sqlalchemy import Interval @@ -507,10 +508,10 @@ class PostgresqlDefaultCompareTest(TestBase): ) def test_compare_string_blank_default(self): - self._compare_default_roundtrip(String(8), '') + self._compare_default_roundtrip(String(8), "") def test_compare_string_nonblank_default(self): - self._compare_default_roundtrip(String(8), 'hi') + self._compare_default_roundtrip(String(8), "hi") def test_compare_interval_str(self): # this form shouldn't be used but testing here @@ -534,6 +535,12 @@ class PostgresqlDefaultCompareTest(TestBase): DateTime(), text("TIMEZONE('utc', CURRENT_TIMESTAMP)") ) + @config.requirements.sqlalchemy_10 + def test_compare_current_timestamp_fn_w_binds(self): + self._compare_default_roundtrip( + DateTime(), func.timezone("utc", func.current_timestamp()) + ) + def test_compare_integer_str(self): self._compare_default_roundtrip(Integer(), "5") diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index dd81a13a..7616ef20 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -1,11 +1,28 @@ from sqlalchemy import Boolean from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import Float +from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import text +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql import column +from alembic import autogenerate from alembic import op +from alembic.autogenerate import api +from alembic.autogenerate.compare import _compare_server_default +from alembic.migration import MigrationContext +from alembic.operations import ops from alembic.testing import assert_raises_message from alembic.testing import config +from alembic.testing import eq_ +from alembic.testing import eq_ignore_whitespace +from alembic.testing.env import clear_staging_env +from alembic.testing.env import staging_env from alembic.testing.fixtures import op_fixture from alembic.testing.fixtures import TestBase @@ -63,3 +80,179 @@ class SQLiteTest(TestBase): context = op_fixture("sqlite") op.add_column("t1", Column("c1", Integer, comment="c1 comment")) context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER") + + +class SQLiteDefaultCompareTest(TestBase): + __only_on__ = "sqlite" + __backend__ = True + + @classmethod + def setup_class(cls): + cls.bind = config.db + staging_env() + cls.migration_context = MigrationContext.configure( + connection=cls.bind.connect(), + opts={"compare_type": True, "compare_server_default": True}, + ) + + def setUp(self): + self.metadata = MetaData(self.bind) + self.autogen_context = api.AutogenContext(self.migration_context) + + @classmethod + def teardown_class(cls): + clear_staging_env() + + def tearDown(self): + self.metadata.drop_all() + + def _compare_default_roundtrip( + self, type_, orig_default, alternate=None, diff_expected=None + ): + diff_expected = ( + diff_expected + if diff_expected is not None + else alternate is not None + ) + if alternate is None: + alternate = orig_default + + t1 = Table( + "test", + self.metadata, + Column("somecol", type_, server_default=orig_default), + ) + t2 = Table( + "test", + MetaData(), + Column("somecol", type_, server_default=alternate), + ) + + t1.create(self.bind) + + insp = Inspector.from_engine(self.bind) + cols = insp.get_columns(t1.name) + insp_col = Column( + "somecol", cols[0]["type"], server_default=text(cols[0]["default"]) + ) + op = ops.AlterColumnOp("test", "somecol") + _compare_server_default( + self.autogen_context, + op, + None, + "test", + "somecol", + insp_col, + t2.c.somecol, + ) + + diffs = op.to_diff_tuple() + eq_(bool(diffs), diff_expected) + + def _compare_default(self, t1, t2, col, rendered): + t1.create(self.bind, checkfirst=True) + insp = Inspector.from_engine(self.bind) + cols = insp.get_columns(t1.name) + ctx = self.autogen_context.migration_context + + return ctx.impl.compare_server_default( + None, col, rendered, cols[0]["default"] + ) + + @config.requirements.sqlalchemy_12 + def test_compare_current_timestamp_func(self): + self._compare_default_roundtrip( + DateTime(), func.datetime("now", "localtime") + ) + + def test_compare_current_timestamp_text(self): + # SQLAlchemy doesn't render the parenthesis for a + # SQLite server default specified as text(), so users will be doing + # this; sqlite comparison needs to accommodate for these. + self._compare_default_roundtrip( + DateTime(), text("(datetime('now', 'localtime'))") + ) + + def test_compare_integer_str(self): + self._compare_default_roundtrip(Integer(), "5") + + def test_compare_integer_str_diff(self): + self._compare_default_roundtrip(Integer(), "5", "7") + + def test_compare_integer_text(self): + self._compare_default_roundtrip(Integer(), text("5")) + + def test_compare_integer_text_diff(self): + self._compare_default_roundtrip(Integer(), text("5"), "7") + + def test_compare_float_str(self): + self._compare_default_roundtrip(Float(), "5.2") + + def test_compare_float_str_diff(self): + self._compare_default_roundtrip(Float(), "5.2", "5.3") + + def test_compare_float_text(self): + self._compare_default_roundtrip(Float(), text("5.2")) + + def test_compare_float_text_diff(self): + self._compare_default_roundtrip(Float(), text("5.2"), "5.3") + + def test_compare_string_literal(self): + self._compare_default_roundtrip(String(), "im a default") + + def test_compare_string_literal_diff(self): + self._compare_default_roundtrip(String(), "im a default", "me too") + + +class SQLiteAutogenRenderTest(TestBase): + def setUp(self): + ctx_opts = { + "sqlalchemy_module_prefix": "sa.", + "alembic_module_prefix": "op.", + "target_metadata": MetaData(), + } + context = MigrationContext.configure( + dialect_name="sqlite", opts=ctx_opts + ) + + self.autogen_context = api.AutogenContext(context) + + def test_render_server_default_expr_needs_parens(self): + c = Column( + "date_value", + DateTime(), + server_default=func.datetime("now", "localtime"), + ) + + result = autogenerate.render._render_column(c, self.autogen_context) + eq_ignore_whitespace( + result, + "sa.Column('date_value', sa.DateTime(), " + "server_default=sa.text(!U\"(datetime('now', 'localtime'))\"), " + "nullable=True)", + ) + + def test_render_server_default_text_expr_needs_parens(self): + c = Column( + "date_value", + DateTime(), + server_default=text("(datetime('now', 'localtime'))"), + ) + + result = autogenerate.render._render_column(c, self.autogen_context) + eq_ignore_whitespace( + result, + "sa.Column('date_value', sa.DateTime(), " + "server_default=sa.text(!U\"(datetime('now', 'localtime'))\"), " + "nullable=True)", + ) + + def test_render_server_default_const(self): + c = Column("int_value", Integer, server_default="5") + + result = autogenerate.render._render_column(c, self.autogen_context) + eq_ignore_whitespace( + result, + "sa.Column('int_value', sa.Integer(), server_default='5', " + "nullable=True)", + )