import re
import sys
from unittest import TestCase
+from mock import Mock, patch
from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
- Numeric, CHAR, ForeignKey, DATETIME, \
+ Numeric, CHAR, ForeignKey, DATETIME, VARCHAR, \
TypeDecorator, CheckConstraint, Unicode, Enum,\
UniqueConstraint, Boolean, ForeignKeyConstraint,\
PrimaryKeyConstraint
[('remove_table', 'extra'), ('remove_table', 'user')]
)
+ def test_uses_custom_compare_type_function(self):
+ my_compare_type = Mock()
+ my_compare_type.return_value = None
+
+ context = MigrationContext.configure(
+ connection=self.bind.connect(),
+ opts={
+ 'compare_type': my_compare_type,
+ 'target_metadata': self.m1,
+ 'upgrade_token':"upgrades",
+ 'downgrade_token':"downgrades",
+ 'alembic_module_prefix':'op.',
+ 'sqlalchemy_module_prefix':'sa.'
+ }
+ )
+ autogenerate._produce_migration_diffs(context, {}, set())
+
+ first_table = self.m1.tables['address']
+ first_column = first_table.columns['email_address']
+
+ # We'll just test the first call
+ _, args, _ = my_compare_type.mock_calls[0]
+ ctx, inspected_column, metadata_column, inspected_type, metadata_type = args
+ eq_(ctx, context)
+ eq_(metadata_column, first_column)
+ eq_(metadata_type, first_column.type)
+ eq_(inspected_column.name, first_column.name)
+ eq_(type(inspected_type), VARCHAR)
+
+ def test_fields_excluded_when_custom_compare_type_returns_False(self):
+ my_compare_type = Mock()
+ my_compare_type.return_value = False
+
+ context = MigrationContext.configure(
+ connection=self.bind.connect(),
+ opts={
+ 'compare_type': my_compare_type,
+ 'target_metadata': self.m1,
+ 'upgrade_token':"upgrades",
+ 'downgrade_token':"downgrades",
+ 'alembic_module_prefix':'op.',
+ 'sqlalchemy_module_prefix':'sa.'
+ }
+ )
+ template_args = {}
+ newtype = String(length=30)
+ with patch.object(self.m1.tables['address'].columns['email_address'], 'type', new=newtype):
+ autogenerate._produce_migration_diffs(context, template_args, set())
+
+ eq_(re.sub(r"u'", "'", template_args['upgrades']),
+"""### commands auto generated by Alembic - please adjust! ###
+ pass
+ ### end Alembic commands ###""")
+
+ def test_fields_included_when_custom_compare_type_returns_True(self):
+ my_compare_type = Mock()
+ my_compare_type.return_value = True
+
+ context = MigrationContext.configure(
+ connection=self.bind.connect(),
+ opts={
+ 'compare_type': my_compare_type,
+ 'target_metadata': self.m1,
+ 'upgrade_token':"upgrades",
+ 'downgrade_token':"downgrades",
+ 'alembic_module_prefix':'op.',
+ 'sqlalchemy_module_prefix':'sa.'
+ }
+ )
+ template_args = {}
+ autogenerate._produce_migration_diffs(context, template_args, set())
+
+ eq_(re.sub(r"u'", "'", template_args['upgrades']),
+"""### commands auto generated by Alembic - please adjust! ###
+ op.alter_column('address', 'email_address',
+ existing_type=sa.VARCHAR(length=100),
+ type_=sa.String(length=100),
+ existing_nullable=False)
+ op.alter_column('address', 'id',
+ existing_type=sa.INTEGER(),
+ type_=sa.Integer(),
+ existing_nullable=False)
+ op.alter_column('extra', 'uid',
+ existing_type=sa.INTEGER(),
+ type_=sa.Integer(),
+ existing_nullable=True)
+ op.alter_column('extra', 'x',
+ existing_type=sa.CHAR(),
+ type_=sa.CHAR(),
+ existing_nullable=True)
+ op.alter_column('order', 'amount',
+ existing_type=sa.NUMERIC(precision=8, scale=2),
+ type_=sa.Numeric(precision=8, scale=2),
+ existing_nullable=False,
+ existing_server_default='0')
+ op.alter_column('order', 'order_id',
+ existing_type=sa.INTEGER(),
+ type_=sa.Integer(),
+ existing_nullable=False)
+ op.alter_column('user', 'a1',
+ existing_type=sa.TEXT(),
+ type_=sa.Text(),
+ existing_nullable=True)
+ op.alter_column('user', 'id',
+ existing_type=sa.INTEGER(),
+ type_=sa.Integer(),
+ existing_nullable=False)
+ op.alter_column('user', 'name',
+ existing_type=sa.VARCHAR(length=50),
+ type_=sa.String(length=50),
+ existing_nullable=True)
+ op.alter_column('user', 'pw',
+ existing_type=sa.VARCHAR(length=50),
+ type_=sa.String(length=50),
+ existing_nullable=True)
+ ### end Alembic commands ###""")
+
+
class AutogenKeyTest(AutogenTest, TestCase):
@classmethod
def _get_db_schema(cls):