conn_type, metadata_type, tname, cname
)
+def _render_server_default_for_compare(metadata_default,
+ metadata_col, autogen_context):
+ return _render_server_default(
+ metadata_default, autogen_context,
+ repr_=metadata_col.type._type_affinity is sqltypes.String)
+
def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
diffs, autogen_context):
conn_col_default = conn_col.server_default
if conn_col_default is None and metadata_default is None:
return False
- rendered_metadata_default = _render_server_default(
- metadata_default, autogen_context)
+ rendered_metadata_default = _render_server_default_for_compare(
+ metadata_default, metadata_col, autogen_context)
+
rendered_conn_default = conn_col.server_default.arg.text \
if conn_col.server_default else None
isdiff = autogen_context['context']._compare_server_default(
from unittest import TestCase
-from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, String
+from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, \
+ String, Interval
+from sqlalchemy.dialects.postgresql import ARRAY
+from sqlalchemy.schema import DefaultClause
from sqlalchemy.engine.reflection import Inspector
from alembic.operations import Operations
from sqlalchemy.sql import table, column
+from alembic.autogenerate.compare import _compare_server_default
from alembic import command, util
from alembic.migration import MigrationContext
'connection': connection,
'dialect': connection.dialect,
'context': context
- }
+ }
@classmethod
def teardown_class(cls):
def tearDown(self):
self.metadata.drop_all()
- def _compare_default_roundtrip(self, type_, txt, alternate=None):
- if alternate:
- expected = True
- else:
- alternate = txt
- expected = False
- t = Table("test", self.metadata,
- Column("somecol", type_, server_default=text(txt))
- )
+ def _compare_default_roundtrip(self, type_, orig_default, alternate=None):
+ diff_expected = alternate is not None
+ if alternate is None:
+ alternate = orig_default
+
+ t1 = Table("test", self.metadata,
+ Column("somecol", type_, server_default=orig_default))
t2 = Table("test", MetaData(),
- Column("somecol", type_, server_default=text(alternate))
- )
- assert self._compare_default(
- t, t2, t2.c.somecol, alternate
- ) is expected
+ Column("somecol", type_, server_default=alternate))
+
+ t1.create(self.bind)
+
+ insp = Inspector.from_engine(self.bind)
+ cols = insp.get_columns(t1.name)
+ insp_col = Column("somecol", cols[0]['type'],
+ server_default=text(cols[0]['default']))
+ diffs = []
+ _compare_server_default(None, "test", "somecol", insp_col,
+ t2.c.somecol, diffs, self.autogen_context)
+ eq_(bool(diffs), diff_expected)
def _compare_default(
self,
t1, t2, col,
rendered
):
- t1.create(self.bind)
+ t1.create(self.bind, checkfirst=True)
insp = Inspector.from_engine(self.bind)
cols = insp.get_columns(t1.name)
ctx = self.autogen_context['context']
+
return ctx.impl.compare_server_default(
None,
col,
rendered,
cols[0]['default'])
- def test_compare_current_timestamp(self):
+ def test_compare_interval_str(self):
+ # this form shouldn't be used but testing here
+ # for compatibility
+ self._compare_default_roundtrip(
+ Interval,
+ "14 days"
+ )
+
+ def test_compare_interval_text(self):
+ self._compare_default_roundtrip(
+ Interval,
+ text("'14 days'")
+ )
+
+ def test_compare_array_of_integer_text(self):
+ self._compare_default_roundtrip(
+ ARRAY(Integer),
+ text("(ARRAY[]::integer[])")
+ )
+
+ def test_compare_current_timestamp_text(self):
self._compare_default_roundtrip(
DateTime(),
- "TIMEZONE('utc', CURRENT_TIMESTAMP)",
+ text("TIMEZONE('utc', CURRENT_TIMESTAMP)"),
)
- def test_compare_integer(self):
+ def test_compare_integer_str(self):
self._compare_default_roundtrip(
Integer(),
"5",
)
- def test_compare_integer_diff(self):
+ def test_compare_integer_text(self):
self._compare_default_roundtrip(
Integer(),
- "5", "7"
+ text("5"),
+ )
+
+ def test_compare_integer_text_diff(self):
+ self._compare_default_roundtrip(
+ Integer(),
+ text("5"), "7"
+ )
+
+ def test_compare_character_str(self):
+ self._compare_default_roundtrip(
+ String(),
+ "hello",
+ )
+
+ def test_compare_character_text(self):
+ self._compare_default_roundtrip(
+ String(),
+ text("'hello'"),
+ )
+
+ def test_compare_character_str_diff(self):
+ self._compare_default_roundtrip(
+ String(),
+ "hello",
+ "there"
)
- def test_compare_character_diff(self):
+ def test_compare_character_text_diff(self):
self._compare_default_roundtrip(
String(),
- "'hello'",
- "'there'"
+ text("'hello'"),
+ text("'there'")
)
def test_primary_key_skip(self):