git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Support explicit column ordering in batch mode
author Marcin Szymanski <ms32035@gmail.com>
Fri, 31 Jan 2020 01:36:15 +0000 (20:36 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 4 Feb 2020 19:20:49 +0000 (14:20 -0500)
Added new parameters :paramref:`.BatchOperations.add_column.insert_before`,
:paramref:`.BatchOperations.add_column.insert_after` which provide for
establishing the specific position in which a new column should be placed.
Also added :paramref:`.Operations.batch_alter_table.partial_reordering`
which allows the complete set of columns to be reordered when the new table
is created.   Both operations apply only when batch mode is recreating
the whole table using ``recreate="always"``.  Thanks to Marcin Szymanski
for assistance with the implementation.

Co-Authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Fixes: #640
Closes: #646
Pull-request: https://github.com/sqlalchemy/alembic/pull/646
Pull-request-sha: 29392b796dd995fabdabe50e8725385dbe1b4883

Change-Id: Iec9942baa08da4a7fc49b69093e84785fea3cec2

alembic/operations/base.py
alembic/operations/batch.py
alembic/operations/ops.py
alembic/operations/toimpl.py
docs/build/unreleased/640.rst [new file with mode: 0644]
tests/test_batch.py

index a4adf15e3f2181e0d5cdcab581ac6eda2cf4cbf9..602b7c76c49a3f2db465e3d64353598abfc9de8f 100644 (file)
@@ -177,6 +177,7 @@ class Operations(util.ModuleClsProxy):
         table_name,
         schema=None,
         recreate="auto",
+        partial_reordering=None,
         copy_from=None,
         table_args=(),
         table_kwargs=util.immutabledict(),
@@ -301,6 +302,33 @@ class Operations(util.ModuleClsProxy):
 
          .. versionadded:: 0.7.1
 
+        :param partial_reordering: a list of tuples, each suggesting a desired
+         ordering of two or more columns in the newly created table.  Requires
+         that :paramref:`.batch_alter_table.recreate` is set to ``"always"``.
+         Examples, given a table with columns "a", "b", "c", and "d":
+
+         Specify the order of all columns::
+
+            with op.batch_alter_table(
+                    "some_table", recreate="always",
+                    partial_reordering=[("c", "d", "a", "b")]
+            ) as batch_op:
+                pass
+
+         Ensure "d" appears before "c", and "b" appears before "a"::
+
+            with op.batch_alter_table(
+                    "some_table", recreate="always",
+                    partial_reordering=[("d", "c"), ("b", "a")]
+            ) as batch_op:
+                pass
+
+         The ordering of columns not included in the partial_reordering
+         set is undefined.   Therefore it is best to specify the complete
+         ordering of all columns for best results.
+
+         .. versionadded:: 1.4.0
+
         .. note:: batch mode requires SQLAlchemy 0.8 or above.
 
         .. seealso::
@@ -319,6 +347,7 @@ class Operations(util.ModuleClsProxy):
             reflect_args,
             reflect_kwargs,
             naming_convention,
+            partial_reordering,
         )
         batch_op = BatchOperations(self.migration_context, impl=impl)
         yield batch_op
index 42db905e66bf0d47059253bd68e5ee2396d9a2d5..6ca6f90c9f260b1b9c6b4a2aec780c819f73f8d8 100644 (file)
@@ -11,7 +11,9 @@ from sqlalchemy import Table
 from sqlalchemy import types as sqltypes
 from sqlalchemy.events import SchemaEventTarget
 from sqlalchemy.util import OrderedDict
+from sqlalchemy.util import topological
 
+from ..util import exc
 from ..util.sqla_compat import _columns_for_constraint
 from ..util.sqla_compat import _fk_is_self_referential
 from ..util.sqla_compat import _is_type_bound
@@ -31,6 +33,7 @@ class BatchOperationsImpl(object):
         reflect_args,
         reflect_kwargs,
         naming_convention,
+        partial_reordering,
     ):
         self.operations = operations
         self.table_name = table_name
@@ -52,6 +55,7 @@ class BatchOperationsImpl(object):
             ("column_reflect", operations.impl.autogen_column_reflect)
         )
         self.naming_convention = naming_convention
+        self.partial_reordering = partial_reordering
         self.batch = []
 
     @property
@@ -99,7 +103,11 @@ class BatchOperationsImpl(object):
                 reflected = True
 
             batch_impl = ApplyBatchImpl(
-                existing_table, self.table_args, self.table_kwargs, reflected
+                existing_table,
+                self.table_args,
+                self.table_kwargs,
+                reflected,
+                partial_reordering=self.partial_reordering,
             )
             for opname, arg, kw in self.batch:
                 fn = getattr(batch_impl, opname)
@@ -111,6 +119,13 @@ class BatchOperationsImpl(object):
         self.batch.append(("alter_column", arg, kw))
 
     def add_column(self, *arg, **kw):
+        if (
+            "insert_before" in kw or "insert_after" in kw
+        ) and not self._should_recreate():
+            raise exc.CommandError(
+                "Can't specify insert_before or insert_after when using "
+                "ALTER; please specify recreate='always'"
+            )
         self.batch.append(("add_column", arg, kw))
 
     def drop_column(self, *arg, **kw):
@@ -139,15 +154,23 @@ class BatchOperationsImpl(object):
 
 
 class ApplyBatchImpl(object):
-    def __init__(self, table, table_args, table_kwargs, reflected):
+    def __init__(
+        self, table, table_args, table_kwargs, reflected, partial_reordering=()
+    ):
         self.table = table  # this is a Table object
         self.table_args = table_args
         self.table_kwargs = table_kwargs
         self.temp_table_name = self._calc_temp_name(table.name)
         self.new_table = None
+
+        self.partial_reordering = partial_reordering  # tuple of tuples
+        self.add_col_ordering = ()  # tuple of tuples
+
         self.column_transfers = OrderedDict(
             (c.name, {"expr": c}) for c in self.table.c
         )
+        self.existing_ordering = list(self.column_transfers)
+
         self.reflected = reflected
         self._grab_table_elements()
 
@@ -188,12 +211,45 @@ class ApplyBatchImpl(object):
         for k in self.table.kwargs:
             self.table_kwargs.setdefault(k, self.table.kwargs[k])
 
+    def _adjust_self_columns_for_partial_reordering(self):
+        pairs = set()
+
+        col_by_idx = list(self.columns)
+
+        if self.partial_reordering:
+            for tuple_ in self.partial_reordering:
+                for index, elem in enumerate(tuple_):
+                    if index > 0:
+                        pairs.add((tuple_[index - 1], elem))
+        else:
+            for index, elem in enumerate(self.existing_ordering):
+                if index > 0:
+                    pairs.add((col_by_idx[index - 1], elem))
+
+        pairs.update(self.add_col_ordering)
+
+        # this can happen if some columns were dropped and not removed
+        # from existing_ordering.  this should be prevented already, but
+        # conservatively making sure this didn't happen
+        pairs = [p for p in pairs if p[0] != p[1]]
+
+        sorted_ = list(
+            topological.sort(pairs, col_by_idx, deterministic_order=True)
+        )
+        self.columns = OrderedDict((k, self.columns[k]) for k in sorted_)
+        self.column_transfers = OrderedDict(
+            (k, self.column_transfers[k]) for k in sorted_
+        )
+
     def _transfer_elements_to_new_table(self):
         assert self.new_table is None, "Can only create new table once"
 
         m = MetaData()
         schema = self.table.schema
 
+        if self.partial_reordering or self.add_col_ordering:
+            self._adjust_self_columns_for_partial_reordering()
+
         self.new_table = new_table = Table(
             self.temp_table_name,
             m,
@@ -371,7 +427,57 @@ class ApplyBatchImpl(object):
         if autoincrement is not None:
             existing.autoincrement = bool(autoincrement)
 
-    def add_column(self, table_name, column, **kw):
+    def _setup_dependencies_for_add_column(
+        self, colname, insert_before, insert_after
+    ):
+        index_cols = self.existing_ordering
+        col_indexes = {name: i for i, name in enumerate(index_cols)}
+
+        if not self.partial_reordering:
+            if insert_after:
+                if not insert_before:
+                    if insert_after in col_indexes:
+                        # insert after an existing column
+                        idx = col_indexes[insert_after] + 1
+                        if idx < len(index_cols):
+                            insert_before = index_cols[idx]
+                    else:
+                        # insert after a column that is also new
+                        insert_before = dict(self.add_col_ordering)[
+                            insert_after
+                        ]
+            if insert_before:
+                if not insert_after:
+                    if insert_before in col_indexes:
+                        # insert before an existing column
+                        idx = col_indexes[insert_before] - 1
+                        if idx >= 0:
+                            insert_after = index_cols[idx]
+                    else:
+                        # insert before a column that is also new
+                        insert_after = dict(
+                            (b, a) for a, b in self.add_col_ordering
+                        )[insert_before]
+
+        if insert_before:
+            self.add_col_ordering += ((colname, insert_before),)
+        if insert_after:
+            self.add_col_ordering += ((insert_after, colname),)
+
+        if (
+            not self.partial_reordering
+            and not insert_before
+            and not insert_after
+            and col_indexes
+        ):
+            self.add_col_ordering += ((index_cols[-1], colname),)
+
+    def add_column(
+        self, table_name, column, insert_before=None, insert_after=None, **kw
+    ):
+        self._setup_dependencies_for_add_column(
+            column.name, insert_before, insert_after
+        )
         # we copy the column because operations.add_column()
         # gives us a Column that is part of a Table already.
         self.columns[column.name] = column.copy(schema=self.table.schema)
@@ -384,6 +490,7 @@ class ApplyBatchImpl(object):
             )
         del self.columns[column.name]
         del self.column_transfers[column.name]
+        self.existing_ordering.remove(column.name)
 
     def add_constraint(self, const):
         if not const.name:
index 8f00b0c3a9e523f0f9fea82b35a9ba4578adbd3d..5ec27623358a5e89fe25f97c20e59f21eb316135 100644 (file)
@@ -1790,15 +1790,35 @@ class AlterColumnOp(AlterTableOp):
         existing_server_default=False,
         existing_nullable=None,
         existing_comment=None,
+        insert_before=None,
+        insert_after=None,
         **kw
     ):
         """Issue an "alter column" instruction using the current
         batch migration context.
 
+        Parameters are the same as those of :meth:`.Operations.alter_column`,
+        as well as the following option(s):
+
+        :param insert_before: String name of an existing column which this
+         column should be placed before, when creating the new table.
+
+         .. versionadded:: 1.4.0
+
+        :param insert_after: String name of an existing column which this
+         column should be placed after, when creating the new table.  If
+         both :paramref:`.BatchOperations.alter_column.insert_before`
+         and :paramref:`.BatchOperations.alter_column.insert_after` are
+         omitted, the column is inserted after the last existing column
+         in the table.
+
+         .. versionadded:: 1.4.0
+
         .. seealso::
 
             :meth:`.Operations.alter_column`
 
+
         """
         alt = cls(
             operations.impl.table_name,
@@ -1824,9 +1844,10 @@ class AlterColumnOp(AlterTableOp):
 class AddColumnOp(AlterTableOp):
     """Represent an add column operation."""
 
-    def __init__(self, table_name, column, schema=None):
+    def __init__(self, table_name, column, schema=None, **kw):
         super(AddColumnOp, self).__init__(table_name, schema=schema)
         self.column = column
+        self.kw = kw
 
     def reverse(self):
         return DropColumnOp.from_column_and_tablename(
@@ -1906,7 +1927,9 @@ class AddColumnOp(AlterTableOp):
         return operations.invoke(op)
 
     @classmethod
-    def batch_add_column(cls, operations, column):
+    def batch_add_column(
+        cls, operations, column, insert_before=None, insert_after=None
+    ):
         """Issue an "add column" instruction using the current
         batch migration context.
 
@@ -1915,8 +1938,18 @@ class AddColumnOp(AlterTableOp):
             :meth:`.Operations.add_column`
 
         """
+
+        kw = {}
+        if insert_before:
+            kw["insert_before"] = insert_before
+        if insert_after:
+            kw["insert_after"] = insert_after
+
         op = cls(
-            operations.impl.table_name, column, schema=operations.impl.schema
+            operations.impl.table_name,
+            column,
+            schema=operations.impl.schema,
+            **kw
         )
         return operations.invoke(op)
 
index 569942354691d16edad0fbe64f6b95ba471e0dce..24a2e3706316e24981289636375991d204893459 100644 (file)
@@ -126,9 +126,10 @@ def add_column(operations, operation):
     table_name = operation.table_name
     column = operation.column
     schema = operation.schema
+    kw = operation.kw
 
     t = operations.schema_obj.table(table_name, column, schema=schema)
-    operations.impl.add_column(table_name, column, schema=schema)
+    operations.impl.add_column(table_name, column, schema=schema, **kw)
     for constraint in t.constraints:
         if not isinstance(constraint, sa_schema.PrimaryKeyConstraint):
             operations.impl.add_constraint(constraint)
diff --git a/docs/build/unreleased/640.rst b/docs/build/unreleased/640.rst
new file mode 100644 (file)
index 0000000..8a28b76
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: feature, batch
+    :tickets: 640
+
+    Added new parameters :paramref:`.BatchOperations.add_column.insert_before`,
+    :paramref:`.BatchOperations.add_column.insert_after` which provide for
+    establishing the specific position in which a new column should be placed.
+    Also added :paramref:`.Operations.batch_alter_table.partial_reordering`
+    which allows the complete set of columns to be reordered when the new table
+    is created.   Both operations apply only when batch mode is recreating
+    the whole table using ``recreate="always"``.  Thanks to Marcin Szymanski
+    for assistance with the implementation.
index a344d0c48c79893e04e0d4ef6d51e52425ed6b7b..5b4d3ec063aef5e220ccff6fca129205a3c7f83e 100644 (file)
@@ -34,6 +34,7 @@ from alembic.testing import exclusions
 from alembic.testing import mock
 from alembic.testing import TestBase
 from alembic.testing.fixtures import op_fixture
+from alembic.util import exc as alembic_exc
 from alembic.util.sqla_compat import sqla_14
 
 
@@ -41,7 +42,7 @@ class BatchApplyTest(TestBase):
     def setUp(self):
         self.op = Operations(mock.Mock(opts={}))
 
-    def _simple_fixture(self, table_args=(), table_kwargs={}):
+    def _simple_fixture(self, table_args=(), table_kwargs={}, **kw):
         m = MetaData()
         t = Table(
             "tname",
@@ -50,7 +51,7 @@ class BatchApplyTest(TestBase):
             Column("x", String(10)),
             Column("y", Integer),
         )
-        return ApplyBatchImpl(t, table_args, table_kwargs, False)
+        return ApplyBatchImpl(t, table_args, table_kwargs, False, **kw)
 
     def _uq_fixture(self, table_args=(), table_kwargs={}):
         m = MetaData()
@@ -466,6 +467,98 @@ class BatchApplyTest(TestBase):
         new_table = self._assert_impl(impl, colnames=["id", "x", "y", "g"])
         eq_(new_table.c.g.name, "g")
 
+    def test_partial_reordering(self):
+        impl = self._simple_fixture(partial_reordering=[("x", "id", "y")])
+        new_table = self._assert_impl(impl, colnames=["x", "id", "y"])
+        eq_(new_table.c.x.name, "x")
+
+    def test_add_col_partial_reordering(self):
+        impl = self._simple_fixture(partial_reordering=[("id", "x", "g", "y")])
+        col = Column("g", Integer)
+        # operations.add_column produces a table
+        t = self.op.schema_obj.table("tname", col)  # noqa
+        impl.add_column("tname", col)
+        new_table = self._assert_impl(impl, colnames=["id", "x", "g", "y"])
+        eq_(new_table.c.g.name, "g")
+
+    def test_add_col_insert_before(self):
+        impl = self._simple_fixture()
+        col = Column("g", Integer)
+        # operations.add_column produces a table
+        t = self.op.schema_obj.table("tname", col)  # noqa
+        impl.add_column("tname", col, insert_before="x")
+        new_table = self._assert_impl(impl, colnames=["id", "g", "x", "y"])
+        eq_(new_table.c.g.name, "g")
+
+    def test_add_col_insert_before_beginning(self):
+        impl = self._simple_fixture()
+        impl.add_column("tname", Column("g", Integer), insert_before="id")
+        new_table = self._assert_impl(impl, colnames=["g", "id", "x", "y"])
+        eq_(new_table.c.g.name, "g")
+
+    def test_add_col_insert_before_middle(self):
+        impl = self._simple_fixture()
+        impl.add_column("tname", Column("g", Integer), insert_before="y")
+        new_table = self._assert_impl(impl, colnames=["id", "x", "g", "y"])
+        eq_(new_table.c.g.name, "g")
+
+    def test_add_col_insert_after_middle(self):
+        impl = self._simple_fixture()
+        impl.add_column("tname", Column("g", Integer), insert_after="id")
+        new_table = self._assert_impl(impl, colnames=["id", "g", "x", "y"])
+        eq_(new_table.c.g.name, "g")
+
+    def test_add_col_insert_after_penultimate(self):
+        impl = self._simple_fixture()
+        impl.add_column("tname", Column("g", Integer), insert_after="x")
+        self._assert_impl(impl, colnames=["id", "x", "g", "y"])
+
+    def test_add_col_insert_after_end(self):
+        impl = self._simple_fixture()
+        impl.add_column("tname", Column("g", Integer), insert_after="y")
+        new_table = self._assert_impl(impl, colnames=["id", "x", "y", "g"])
+        eq_(new_table.c.g.name, "g")
+
+    def test_add_col_insert_after_plus_no_order(self):
+        impl = self._simple_fixture()
+        # operations.add_column produces a table
+        impl.add_column("tname", Column("g", Integer), insert_after="id")
+        impl.add_column("tname", Column("q", Integer))
+        new_table = self._assert_impl(
+            impl, colnames=["id", "g", "x", "y", "q"]
+        )
+        eq_(new_table.c.g.name, "g")
+
+    def test_add_col_no_order_plus_insert_after(self):
+        impl = self._simple_fixture()
+        col = Column("g", Integer)
+        # operations.add_column produces a table
+        t = self.op.schema_obj.table("tname", col)  # noqa
+        impl.add_column("tname", Column("q", Integer))
+        impl.add_column("tname", Column("g", Integer), insert_after="id")
+        new_table = self._assert_impl(
+            impl, colnames=["id", "g", "x", "y", "q"]
+        )
+        eq_(new_table.c.g.name, "g")
+
+    def test_add_col_insert_after_another_insert(self):
+        impl = self._simple_fixture()
+        impl.add_column("tname", Column("g", Integer), insert_after="id")
+        impl.add_column("tname", Column("q", Integer), insert_after="g")
+        new_table = self._assert_impl(
+            impl, colnames=["id", "g", "q", "x", "y"]
+        )
+        eq_(new_table.c.g.name, "g")
+
+    def test_add_col_insert_before_another_insert(self):
+        impl = self._simple_fixture()
+        impl.add_column("tname", Column("g", Integer), insert_after="id")
+        impl.add_column("tname", Column("q", Integer), insert_before="g")
+        new_table = self._assert_impl(
+            impl, colnames=["id", "q", "g", "x", "y"]
+        )
+        eq_(new_table.c.g.name, "g")
+
     def test_add_server_default(self):
         impl = self._simple_fixture()
         impl.alter_column("tname", "y", server_default="10")
@@ -1593,6 +1686,64 @@ class BatchRoundTripTest(TestBase):
                 {"id": 5, "data": "d5", "x": 9, "data2": "hi"},
             ]
         )
+        eq_(
+            [col["name"] for col in inspect(config.db).get_columns("foo")],
+            ["id", "data", "x", "data2"],
+        )
+
+    def test_add_column_insert_before_recreate(self):
+        with self.op.batch_alter_table("foo", recreate="always") as batch_op:
+            batch_op.add_column(
+                Column("data2", String(50), server_default="hi"),
+                insert_before="data",
+            )
+        self._assert_data(
+            [
+                {"id": 1, "data": "d1", "x": 5, "data2": "hi"},
+                {"id": 2, "data": "22", "x": 6, "data2": "hi"},
+                {"id": 3, "data": "8.5", "x": 7, "data2": "hi"},
+                {"id": 4, "data": "9.46", "x": 8, "data2": "hi"},
+                {"id": 5, "data": "d5", "x": 9, "data2": "hi"},
+            ]
+        )
+        eq_(
+            [col["name"] for col in inspect(config.db).get_columns("foo")],
+            ["id", "data2", "data", "x"],
+        )
+
+    def test_add_column_insert_after_recreate(self):
+        with self.op.batch_alter_table("foo", recreate="always") as batch_op:
+            batch_op.add_column(
+                Column("data2", String(50), server_default="hi"),
+                insert_after="data",
+            )
+        self._assert_data(
+            [
+                {"id": 1, "data": "d1", "x": 5, "data2": "hi"},
+                {"id": 2, "data": "22", "x": 6, "data2": "hi"},
+                {"id": 3, "data": "8.5", "x": 7, "data2": "hi"},
+                {"id": 4, "data": "9.46", "x": 8, "data2": "hi"},
+                {"id": 5, "data": "d5", "x": 9, "data2": "hi"},
+            ]
+        )
+        eq_(
+            [col["name"] for col in inspect(config.db).get_columns("foo")],
+            ["id", "data", "data2", "x"],
+        )
+
+    def test_add_column_insert_before_raise_on_alter(self):
+        def go():
+            with self.op.batch_alter_table("foo") as batch_op:
+                batch_op.add_column(
+                    Column("data2", String(50), server_default="hi"),
+                    insert_before="data",
+                )
+
+        assert_raises_message(
+            alembic_exc.CommandError,
+            "Can't specify insert_before or insert_after when using ALTER",
+            go,
+        )
 
     def test_add_column_recreate(self):
         with self.op.batch_alter_table("foo", recreate="always") as batch_op:
@@ -1609,6 +1760,10 @@ class BatchRoundTripTest(TestBase):
                 {"id": 5, "data": "d5", "x": 9, "data2": "hi"},
             ]
         )
+        eq_(
+            [col["name"] for col in inspect(config.db).get_columns("foo")],
+            ["id", "data", "x", "data2"],
+        )
 
     def test_create_drop_index(self):
         insp = inspect(config.db)