]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- testing approaches for BatchOperationsImpl and ApplyBatchImpl
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Nov 2014 16:22:34 +0000 (11:22 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Nov 2014 16:22:34 +0000 (11:22 -0500)
alembic/batch.py
tests/test_batch.py [new file with mode: 0644]

index c258a1a9a5015e0b5b71c4803e6bdd5965ea1621..35b7494f7ccc821e73e42b3dff80fcf7a8c14eda 100644 (file)
@@ -1,4 +1,5 @@
-from sqlalchemy import Table, MetaData, Index, select
+from sqlalchemy import Table, MetaData, Index, select, Column, \
+    ForeignKeyConstraint
 from sqlalchemy import types as sqltypes
 from sqlalchemy.util import OrderedDict
 
@@ -51,36 +52,23 @@ class BatchOperationsImpl(object):
 
             batch_impl._create(self.impl)
 
-
     def alter_column(self, *arg, **kw):
-        self.batch.append(
-            ("alter_column", arg, kw)
-        )
+        self.batch.append(("alter_column", arg, kw))
 
     def add_column(self, *arg, **kw):
-        self.batch.append(
-            ("add_column", arg, kw)
-        )
+        self.batch.append(("add_column", arg, kw))
 
     def drop_column(self, *arg, **kw):
-        self.batch.append(
-            ("drop_column", arg, kw)
-        )
+        self.batch.append(("drop_column", arg, kw))
 
     def add_constraint(self, const):
-        self.batch.append(
-            ("add_constraint", (const,), {})
-        )
+        self.batch.append(("add_constraint", (const,), {}))
 
     def drop_constraint(self, const):
-        self.batch.append(
-            ("drop_constraint", (const, ), {})
-        )
+        self.batch.append(("drop_constraint", (const, ), {}))
 
     def rename_table(self, *arg, **kw):
-        self.batch.append(
-            ("rename_table", arg, kw)
-        )
+        self.batch.append(("rename_table", arg, kw))
 
     def create_table(self, table):
         raise NotImplementedError("Can't create table in batch mode")
@@ -98,7 +86,8 @@ class BatchOperationsImpl(object):
 class ApplyBatchImpl(object):
     def __init__(self, table):
         self.table = table  # this is a Table object
-        self.column_transfers = dict(
+        self.new_table = None
+        self.column_transfers = OrderedDict(
             (c.name, {}) for c in self.table.c
         )
         self._grab_table_elements()
@@ -122,29 +111,51 @@ class ApplyBatchImpl(object):
             self.indexes[idx.name] = idx
 
     def _transfer_elements_to_new_table(self):
+        assert self.new_table is None, "Can only create new table once"
+
         m = MetaData()
         schema = self.table.schema
-        new_table = Table(
+        self.new_table = new_table = Table(
             '_alembic_batch_temp', m, *self.columns.values(), schema=schema)
 
-        for c in list(self.named_constraints.values()) + \
+        for const in list(self.named_constraints.values()) + \
                 self.unnamed_constraints:
-            c_copy = c.copy(schema=schema, target_table=new_table)
-            new_table.append_constraint(c_copy)
+            const_columns = set([c.key for c in const.columns])
+            if not const_columns.issubset(self.column_transfers):
+                continue
+            const_copy = const.copy(schema=schema, target_table=new_table)
+            if isinstance(const, ForeignKeyConstraint):
+                self._setup_referent(m, const)
+            new_table.append_constraint(const_copy)
 
         for index in self.indexes.values():
             Index(index.name,
                   unique=index.unique,
                   *[new_table.c[col] for col in index.columns.keys()],
                   **index.kwargs)
-        return new_table
+
+    def _setup_referent(self, metadata, constraint):
+        spec = constraint.elements[0]._get_colspec()
+        parts = spec.split(".")
+        tname = parts[-2]
+        if len(parts) == 3:
+            referent_schema = parts[0]
+        else:
+            referent_schema = None
+        if tname != '_alembic_batch_temp':
+            Table(
+                tname, metadata,
+                *[Column(n, sqltypes.NULLTYPE) for n in
+                    [elem._get_colspec().split(".")[-1]
+                     for elem in constraint.elements]],
+                schema=referent_schema)
 
     def _create(self, op_impl):
-        new_table = self._transfer_elements_to_new_table()
-        op_impl.create_table(new_table)
+        self._transfer_elements_to_new_table()
+        op_impl.create_table(self.new_table)
 
-        op_impl.bind.execute(
-            new_table.insert(inline=True).from_select(
+        op_impl._exec(
+            self.new_table.insert(inline=True).from_select(
                 list(self.column_transfers.keys()),
                 select([
                     self.table.c[key]
diff --git a/tests/test_batch.py b/tests/test_batch.py
new file mode 100644 (file)
index 0000000..c3df74d
--- /dev/null
@@ -0,0 +1,182 @@
+from contextlib import contextmanager
+import re
+
+from alembic.testing import TestBase, eq_
+from alembic.testing.fixtures import op_fixture
+from alembic.testing import mock
+from alembic.operations import Operations
+from alembic.batch import ApplyBatchImpl
+
+from sqlalchemy import Integer, Table, Column, String, MetaData, ForeignKey, \
+    UniqueConstraint, Index, CheckConstraint, PrimaryKeyConstraint, \
+    ForeignKeyConstraint
+from sqlalchemy.sql import column
+from sqlalchemy.schema import CreateTable
+
+
+class BatchApplyTest(TestBase):
+    def _simple_fixture(self):
+        m = MetaData()
+        t = Table(
+            'tname', m,
+            Column('id', Integer, primary_key=True),
+            Column('x', String()),
+            Column('y', Integer)
+        )
+        return ApplyBatchImpl(t)
+
+    def _fk_fixture(self):
+        m = MetaData()
+        t = Table(
+            'tname', m,
+            Column('id', Integer, primary_key=True),
+            Column('email', String()),
+            Column('user_id', Integer, ForeignKey('user.id'))
+        )
+        return ApplyBatchImpl(t)
+
+    def _selfref_fk_fixture(self):
+        m = MetaData()
+        t = Table(
+            'tname', m,
+            Column('id', Integer, primary_key=True),
+            Column('parent_id', ForeignKey('tname.id')),
+            Column('data', String)
+        )
+        return ApplyBatchImpl(t)
+
+    def _assert_impl(self, impl, colnames=None):
+        context = op_fixture()
+
+        impl._create(context.impl)
+
+        if colnames is None:
+            colnames = ['id', 'x', 'y']
+        eq_(impl.new_table.c.keys(), colnames)
+
+        pk_cols = [col for col in impl.new_table.c if col.primary_key]
+        eq_(list(impl.new_table.primary_key), pk_cols)
+
+        create_stmt = str(
+            CreateTable(impl.new_table).compile(dialect=context.dialect))
+        create_stmt = re.sub(r'[\n\t]', '', create_stmt)
+        if pk_cols:
+            assert "PRIMARY KEY" in create_stmt
+        else:
+            assert "PRIMARY KEY" not in create_stmt
+
+        context.assert_(
+            create_stmt,
+            'INSERT INTO _alembic_batch_temp (%(colnames)s) '
+            'SELECT %(tname_colnames)s FROM tname' % {
+                "colnames": ", ".join([
+                    impl.new_table.c[name].name for name in colnames]),
+                "tname_colnames":
+                ", ".join("tname.%s" % name for name in colnames)
+            },
+            'DROP TABLE tname',
+            'ALTER TABLE _alembic_batch_temp RENAME TO tname'
+        )
+        return impl.new_table
+
+    def test_change_type(self):
+        impl = self._simple_fixture()
+        impl.alter_column('tname', 'x', type_=Integer)
+        new_table = self._assert_impl(impl)
+        assert new_table.c.x.type._type_affinity is Integer
+
+    def test_rename_col(self):
+        impl = self._simple_fixture()
+        impl.alter_column('tname', 'x', new_column_name='q')
+        new_table = self._assert_impl(impl)
+        eq_(new_table.c.x.name, 'q')
+
+    def test_rename_col_pk(self):
+        impl = self._simple_fixture()
+        impl.alter_column('tname', 'id', new_column_name='foobar')
+        new_table = self._assert_impl(impl)
+        eq_(new_table.c.id.name, 'foobar')
+        eq_(list(new_table.primary_key), [new_table.c.id])
+
+    def test_rename_col_fk(self):
+        impl = self._fk_fixture()
+        impl.alter_column('tname', 'user_id', new_column_name='foobar')
+        new_table = self._assert_impl(
+            impl, colnames=['id', 'email', 'user_id'])
+        eq_(new_table.c.user_id.name, 'foobar')
+        eq_(
+            list(new_table.c.user_id.foreign_keys)[0]._get_colspec(),
+            "user.id"
+        )
+
+    def test_drop_col(self):
+        impl = self._simple_fixture()
+        impl.drop_column('tname', column('x'))
+        new_table = self._assert_impl(impl, colnames=['id', 'y'])
+        assert 'y' in new_table.c
+        assert 'x' not in new_table.c
+
+    def test_drop_col_remove_pk(self):
+        impl = self._simple_fixture()
+        impl.drop_column('tname', column('id'))
+        new_table = self._assert_impl(impl, colnames=['x', 'y'])
+        assert 'y' in new_table.c
+        assert 'id' not in new_table.c
+        assert not new_table.primary_key
+
+    def test_drop_col_remove_fk(self):
+        impl = self._fk_fixture()
+        impl.drop_column('tname', column('user_id'))
+        new_table = self._assert_impl(impl, colnames=['id', 'email'])
+        assert 'user_id' not in new_table.c
+        assert not new_table.foreign_keys
+
+    def test_drop_col_retain_fk(self):
+        impl = self._fk_fixture()
+        impl.drop_column('tname', column('email'))
+        new_table = self._assert_impl(impl, colnames=['id', 'user_id'])
+        assert 'email' not in new_table.c
+        assert new_table.c.user_id.foreign_keys
+
+    def test_drop_col_retain_fk_selfref(self):
+        impl = self._selfref_fk_fixture()
+        impl.drop_column('tname', column('data'))
+        new_table = self._assert_impl(impl, colnames=['id', 'parent_id'])
+        assert 'data' not in new_table.c
+        assert new_table.c.parent_id.foreign_keys
+
+
+class BatchAPITest(TestBase):
+    @contextmanager
+    def _fixture(self):
+        migration_context = mock.Mock(opts={})
+        op = Operations(migration_context)
+        batch = op.batch_alter_table('tname', recreate='never').__enter__()
+
+        with mock.patch("alembic.operations.sa_schema") as mock_schema:
+            yield batch
+        self.mock_schema = mock_schema
+
+    def test_drop_col(self):
+        with self._fixture() as batch:
+            batch.drop_column('q')
+            batch.impl.flush()
+
+        eq_(
+            batch.impl.operations.impl.mock_calls,
+            [mock.call.drop_column(
+                'tname', self.mock_schema.Column(), schema=None)]
+        )
+
+    def test_add_col(self):
+        column = Column('w', String(50))
+
+        with self._fixture() as batch:
+            batch.add_column(column)
+            batch.impl.flush()
+
+        eq_(
+            batch.impl.operations.impl.mock_calls,
+            [mock.call.add_column(
+                'tname', column, schema=None)]
+        )