]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Create schema objects fresh from ops
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 26 Feb 2021 18:11:17 +0000 (13:11 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 12 Apr 2021 18:39:45 +0000 (14:39 -0400)
Refactored the implementation of :class:`.MigrateOperation` constructs such
as :class:`.CreateIndexOp`, :class:`.CreateTableOp`, etc. so that they no
longer rely upon maintaining a persistent version of each schema object
internally; instead, the state variables of each operation object will be
used to produce the corresponding construct when the operation is invoked.
The rationale is so that environments which make use of
operation-manipulation schemes such as those those discussed in
:ref:`autogen_rewriter` are better supported, allowing end-user code to
manipulate the public attributes of these objects which will then be
expressed in the final output, an example is
``some_create_index_op.kw["postgresql_concurrently"] = True``.

Previously, these objects when generated from autogenerate would typically
hold onto the original, reflected element internally without honoring the
other state variables of each construct, preventing the public API from
working.

Change-Id: Ida2537740325de5c21e9447e930a518093f01bd4
Fixes: #803
13 files changed:
alembic/autogenerate/render.py
alembic/ddl/postgresql.py
alembic/operations/ops.py
alembic/operations/schemaobj.py
alembic/testing/requirements.py
alembic/testing/schemacompare.py [new file with mode: 0644]
alembic/util/sqla_compat.py
docs/build/unreleased/803.rst [new file with mode: 0644]
tests/test_autogen_diffs.py
tests/test_autogen_indexes.py
tests/test_autogen_render.py
tests/test_command.py
tests/test_op.py

index 23890fb22f7785329a90339f41da4f3e0035daeb..58b469c5010d1b3d8150e4efdbfe096c48c37f52 100644 (file)
@@ -178,7 +178,9 @@ def _add_table(autogen_context, op):
         [
             rcons
             for rcons in [
-                _render_constraint(cons, autogen_context)
+                _render_constraint(
+                    cons, autogen_context, op._namespace_metadata
+                )
                 for cons in table.constraints
             ]
             if rcons is not None
@@ -308,7 +310,6 @@ def _add_fk_constraint(autogen_context, op):
             repr([_ident(col) for col in op.remote_cols]),
         ]
     )
-
     kwargs = [
         "referent_schema",
         "onupdate",
@@ -802,18 +803,18 @@ def _render_type_w_subtype(
 _constraint_renderers = util.Dispatcher()
 
 
-def _render_constraint(constraint, autogen_context):
+def _render_constraint(constraint, autogen_context, namespace_metadata):
     try:
         renderer = _constraint_renderers.dispatch(constraint)
     except ValueError:
         util.warn("No renderer is established for object %r" % constraint)
         return "[Unknown Python object %r]" % constraint
     else:
-        return renderer(constraint, autogen_context)
+        return renderer(constraint, autogen_context, namespace_metadata)
 
 
 @_constraint_renderers.dispatch_for(sa_schema.PrimaryKeyConstraint)
-def _render_primary_key(constraint, autogen_context):
+def _render_primary_key(constraint, autogen_context, namespace_metadata):
     rendered = _user_defined_render("primary_key", constraint, autogen_context)
     if rendered is not False:
         return rendered
@@ -835,7 +836,7 @@ def _render_primary_key(constraint, autogen_context):
     }
 
 
-def _fk_colspec(fk, metadata_schema):
+def _fk_colspec(fk, metadata_schema, namespace_metadata):
     """Implement a 'safe' version of ForeignKey._get_colspec() that
     won't fail if the remote table can't be resolved.
 
@@ -857,9 +858,9 @@ def _fk_colspec(fk, metadata_schema):
         # try to resolve the remote table in order to adjust for column.key.
         # the FK constraint needs to be rendered in terms of the column
         # name.
-        parent_metadata = fk.parent.table.metadata
-        if table_fullname in parent_metadata.tables:
-            col = parent_metadata.tables[table_fullname].c.get(colname)
+
+        if table_fullname in namespace_metadata.tables:
+            col = namespace_metadata.tables[table_fullname].c.get(colname)
             if col is not None:
                 colname = _ident(col.name)
 
@@ -883,7 +884,7 @@ def _populate_render_fk_opts(constraint, opts):
 
 
 @_constraint_renderers.dispatch_for(sa_schema.ForeignKeyConstraint)
-def _render_foreign_key(constraint, autogen_context):
+def _render_foreign_key(constraint, autogen_context, namespace_metadata):
     rendered = _user_defined_render("foreign_key", constraint, autogen_context)
     if rendered is not False:
         return rendered
@@ -896,7 +897,7 @@ def _render_foreign_key(constraint, autogen_context):
 
     _populate_render_fk_opts(constraint, opts)
 
-    apply_metadata_schema = constraint.parent.metadata.schema
+    apply_metadata_schema = namespace_metadata.schema
     return (
         "%(prefix)sForeignKeyConstraint([%(cols)s], "
         "[%(refcols)s], %(args)s)"
@@ -906,7 +907,7 @@ def _render_foreign_key(constraint, autogen_context):
                 "%r" % _ident(f.parent.name) for f in constraint.elements
             ),
             "refcols": ", ".join(
-                repr(_fk_colspec(f, apply_metadata_schema))
+                repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata))
                 for f in constraint.elements
             ),
             "args": ", ".join(
@@ -917,7 +918,7 @@ def _render_foreign_key(constraint, autogen_context):
 
 
 @_constraint_renderers.dispatch_for(sa_schema.UniqueConstraint)
-def _render_unique_constraint(constraint, autogen_context):
+def _render_unique_constraint(constraint, autogen_context, namespace_metadata):
     rendered = _user_defined_render("unique", constraint, autogen_context)
     if rendered is not False:
         return rendered
@@ -926,7 +927,7 @@ def _render_unique_constraint(constraint, autogen_context):
 
 
 @_constraint_renderers.dispatch_for(sa_schema.CheckConstraint)
-def _render_check_constraint(constraint, autogen_context):
+def _render_check_constraint(constraint, autogen_context, namespace_metadata):
     rendered = _user_defined_render("check", constraint, autogen_context)
     if rendered is not False:
         return rendered
index 6a2f007cea967d2c6b4e96a03e5eda74b15ac50b..1e6558605886302919c74db2ac2b7a2b8915e737 100644 (file)
@@ -464,7 +464,9 @@ def _add_exclude_constraint(autogen_context, op):
 
 
 @render._constraint_renderers.dispatch_for(ExcludeConstraint)
-def _render_inline_exclude_constraint(constraint, autogen_context):
+def _render_inline_exclude_constraint(
+    constraint, autogen_context, namespace_metadata
+):
     rendered = render._user_defined_render(
         "exclude", constraint, autogen_context
     )
index ffb2f1b104584ff638811557d6ea4139f2a06a01..7ef219020e362a3650bf38ff04a06eb21cb258cb 100644 (file)
@@ -76,21 +76,16 @@ class DropConstraintOp(MigrateOperation):
         table_name,
         type_=None,
         schema=None,
-        _orig_constraint=None,
+        _reverse=None,
     ):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.constraint_type = type_
         self.schema = schema
-        self._orig_constraint = _orig_constraint
+        self._reverse = _reverse
 
     def reverse(self):
-        if self._orig_constraint is None:
-            raise ValueError(
-                "operation is not reversible; "
-                "original constraint is not present"
-            )
-        return AddConstraintOp.from_constraint(self._orig_constraint)
+        return AddConstraintOp.from_constraint(self.to_constraint())
 
     def to_diff_tuple(self):
         if self.constraint_type == "foreignkey":
@@ -115,12 +110,19 @@ class DropConstraintOp(MigrateOperation):
             constraint_table.name,
             schema=constraint_table.schema,
             type_=types[constraint.__visit_name__],
-            _orig_constraint=constraint,
+            _reverse=AddConstraintOp.from_constraint(constraint),
         )
 
     def to_constraint(self):
-        if self._orig_constraint is not None:
-            return self._orig_constraint
+
+        if self._reverse is not None:
+            constraint = self._reverse.to_constraint()
+            constraint.name = self.constraint_name
+            constraint_table = sqla_compat._table_for_constraint(constraint)
+            constraint_table.name = self.table_name
+            constraint_table.schema = self.schema
+
+            return constraint
         else:
             raise ValueError(
                 "constraint cannot be produced; "
@@ -180,43 +182,33 @@ class CreatePrimaryKeyOp(AddConstraintOp):
     constraint_type = "primarykey"
 
     def __init__(
-        self,
-        constraint_name,
-        table_name,
-        columns,
-        schema=None,
-        _orig_constraint=None,
-        **kw
+        self, constraint_name, table_name, columns, schema=None, **kw
     ):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.columns = columns
         self.schema = schema
-        self._orig_constraint = _orig_constraint
         self.kw = kw
 
     @classmethod
     def from_constraint(cls, constraint):
         constraint_table = sqla_compat._table_for_constraint(constraint)
-
         return cls(
             constraint.name,
             constraint_table.name,
-            constraint.columns,
+            constraint.columns.keys(),
             schema=constraint_table.schema,
-            _orig_constraint=constraint,
+            **constraint.dialect_kwargs
         )
 
     def to_constraint(self, migration_context=None):
-        if self._orig_constraint is not None:
-            return self._orig_constraint
-
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.primary_key_constraint(
             self.constraint_name,
             self.table_name,
             self.columns,
             schema=self.schema,
+            **self.kw
         )
 
     @classmethod
@@ -295,19 +287,12 @@ class CreateUniqueConstraintOp(AddConstraintOp):
     constraint_type = "unique"
 
     def __init__(
-        self,
-        constraint_name,
-        table_name,
-        columns,
-        schema=None,
-        _orig_constraint=None,
-        **kw
+        self, constraint_name, table_name, columns, schema=None, **kw
     ):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.columns = columns
         self.schema = schema
-        self._orig_constraint = _orig_constraint
         self.kw = kw
 
     @classmethod
@@ -319,20 +304,16 @@ class CreateUniqueConstraintOp(AddConstraintOp):
             kw["deferrable"] = constraint.deferrable
         if constraint.initially:
             kw["initially"] = constraint.initially
-
+        kw.update(constraint.dialect_kwargs)
         return cls(
             constraint.name,
             constraint_table.name,
             [c.name for c in constraint.columns],
             schema=constraint_table.schema,
-            _orig_constraint=constraint,
             **kw
         )
 
     def to_constraint(self, migration_context=None):
-        if self._orig_constraint is not None:
-            return self._orig_constraint
-
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.unique_constraint(
             self.constraint_name,
@@ -430,7 +411,6 @@ class CreateForeignKeyOp(AddConstraintOp):
         referent_table,
         local_cols,
         remote_cols,
-        _orig_constraint=None,
         **kw
     ):
         self.constraint_name = constraint_name
@@ -438,7 +418,6 @@ class CreateForeignKeyOp(AddConstraintOp):
         self.referent_table = referent_table
         self.local_cols = local_cols
         self.remote_cols = remote_cols
-        self._orig_constraint = _orig_constraint
         self.kw = kw
 
     def to_diff_tuple(self):
@@ -473,20 +452,17 @@ class CreateForeignKeyOp(AddConstraintOp):
 
         kw["source_schema"] = source_schema
         kw["referent_schema"] = target_schema
-
+        kw.update(constraint.dialect_kwargs)
         return cls(
             constraint.name,
             source_table,
             target_table,
             source_columns,
             target_columns,
-            _orig_constraint=constraint,
             **kw
         )
 
     def to_constraint(self, migration_context=None):
-        if self._orig_constraint is not None:
-            return self._orig_constraint
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.foreign_key_constraint(
             self.constraint_name,
@@ -642,19 +618,12 @@ class CreateCheckConstraintOp(AddConstraintOp):
     constraint_type = "check"
 
     def __init__(
-        self,
-        constraint_name,
-        table_name,
-        condition,
-        schema=None,
-        _orig_constraint=None,
-        **kw
+        self, constraint_name, table_name, condition, schema=None, **kw
     ):
         self.constraint_name = constraint_name
         self.table_name = table_name
         self.condition = condition
         self.schema = schema
-        self._orig_constraint = _orig_constraint
         self.kw = kw
 
     @classmethod
@@ -666,12 +635,10 @@ class CreateCheckConstraintOp(AddConstraintOp):
             constraint_table.name,
             constraint.sqltext,
             schema=constraint_table.schema,
-            _orig_constraint=constraint,
+            **constraint.dialect_kwargs
         )
 
     def to_constraint(self, migration_context=None):
-        if self._orig_constraint is not None:
-            return self._orig_constraint
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.check_constraint(
             self.constraint_name,
@@ -765,14 +732,7 @@ class CreateIndexOp(MigrateOperation):
     """Represent a create index operation."""
 
     def __init__(
-        self,
-        index_name,
-        table_name,
-        columns,
-        schema=None,
-        unique=False,
-        _orig_index=None,
-        **kw
+        self, index_name, table_name, columns, schema=None, unique=False, **kw
     ):
         self.index_name = index_name
         self.table_name = table_name
@@ -780,7 +740,6 @@ class CreateIndexOp(MigrateOperation):
         self.schema = schema
         self.unique = unique
         self.kw = kw
-        self._orig_index = _orig_index
 
     def reverse(self):
         return DropIndexOp.from_index(self.to_index())
@@ -796,15 +755,13 @@ class CreateIndexOp(MigrateOperation):
             sqla_compat._get_index_expressions(index),
             schema=index.table.schema,
             unique=index.unique,
-            _orig_index=index,
             **index.kwargs
         )
 
     def to_index(self, migration_context=None):
-        if self._orig_index:
-            return self._orig_index
         schema_obj = schemaobj.SchemaObjects(migration_context)
-        return schema_obj.index(
+
+        idx = schema_obj.index(
             self.index_name,
             self.table_name,
             self.columns,
@@ -812,6 +769,7 @@ class CreateIndexOp(MigrateOperation):
             unique=self.unique,
             **self.kw
         )
+        return idx
 
     @classmethod
     def create_index(
@@ -897,23 +855,19 @@ class DropIndexOp(MigrateOperation):
     """Represent a drop index operation."""
 
     def __init__(
-        self, index_name, table_name=None, schema=None, _orig_index=None, **kw
+        self, index_name, table_name=None, schema=None, _reverse=None, **kw
     ):
         self.index_name = index_name
         self.table_name = table_name
         self.schema = schema
-        self._orig_index = _orig_index
+        self._reverse = _reverse
         self.kw = kw
 
     def to_diff_tuple(self):
         return ("remove_index", self.to_index())
 
     def reverse(self):
-        if self._orig_index is None:
-            raise ValueError(
-                "operation is not reversible; " "original index is not present"
-            )
-        return CreateIndexOp.from_index(self._orig_index)
+        return CreateIndexOp.from_index(self.to_index())
 
     @classmethod
     def from_index(cls, index):
@@ -921,14 +875,11 @@ class DropIndexOp(MigrateOperation):
             index.name,
             index.table.name,
             schema=index.table.schema,
-            _orig_index=index,
+            _reverse=CreateIndexOp.from_index(index),
             **index.kwargs
         )
 
     def to_index(self, migration_context=None):
-        if self._orig_index is not None:
-            return self._orig_index
-
         schema_obj = schemaobj.SchemaObjects(migration_context)
 
         # need a dummy column name here since SQLAlchemy
@@ -936,7 +887,7 @@ class DropIndexOp(MigrateOperation):
         return schema_obj.index(
             self.index_name,
             self.table_name,
-            ["x"],
+            self._reverse.columns if self._reverse else ["x"],
             schema=self.schema,
             **self.kw
         )
@@ -994,37 +945,49 @@ class CreateTableOp(MigrateOperation):
     """Represent a create table operation."""
 
     def __init__(
-        self, table_name, columns, schema=None, _orig_table=None, **kw
+        self, table_name, columns, schema=None, _namespace_metadata=None, **kw
     ):
         self.table_name = table_name
         self.columns = columns
         self.schema = schema
+        self.comment = kw.pop("comment", None)
+        self.prefixes = kw.pop("prefixes", None)
         self.kw = kw
-        self._orig_table = _orig_table
+        self._namespace_metadata = _namespace_metadata
 
     def reverse(self):
-        return DropTableOp.from_table(self.to_table())
+        return DropTableOp.from_table(
+            self.to_table(), _namespace_metadata=self._namespace_metadata
+        )
 
     def to_diff_tuple(self):
         return ("add_table", self.to_table())
 
     @classmethod
-    def from_table(cls, table):
+    def from_table(cls, table, _namespace_metadata=None):
+        if _namespace_metadata is None:
+            _namespace_metadata = table.metadata
+
         return cls(
             table.name,
             list(table.c) + list(table.constraints),
             schema=table.schema,
-            _orig_table=table,
+            _namespace_metadata=_namespace_metadata,
+            comment=table.comment,
+            prefixes=table._prefixes,
             **table.kwargs
         )
 
     def to_table(self, migration_context=None):
-        if self._orig_table is not None:
-            return self._orig_table
         schema_obj = schemaobj.SchemaObjects(migration_context)
 
         return schema_obj.table(
-            self.table_name, *self.columns, schema=self.schema, **self.kw
+            self.table_name,
+            *self.columns,
+            schema=self.schema,
+            prefixes=self.prefixes,
+            comment=self.comment,
+            **self.kw
         )
 
     @classmethod
@@ -1113,35 +1076,43 @@ class CreateTableOp(MigrateOperation):
 class DropTableOp(MigrateOperation):
     """Represent a drop table operation."""
 
-    def __init__(
-        self, table_name, schema=None, table_kw=None, _orig_table=None
-    ):
+    def __init__(self, table_name, schema=None, table_kw=None, _reverse=None):
         self.table_name = table_name
         self.schema = schema
         self.table_kw = table_kw or {}
-        self._orig_table = _orig_table
+        self._reverse = _reverse
 
     def to_diff_tuple(self):
         return ("remove_table", self.to_table())
 
     def reverse(self):
-        if self._orig_table is None:
-            raise ValueError(
-                "operation is not reversible; " "original table is not present"
-            )
-        return CreateTableOp.from_table(self._orig_table)
+        return CreateTableOp.from_table(self.to_table())
 
     @classmethod
-    def from_table(cls, table):
-        return cls(table.name, schema=table.schema, _orig_table=table)
+    def from_table(cls, table, _namespace_metadata=None):
+        return cls(
+            table.name,
+            schema=table.schema,
+            table_kw=table.kwargs,
+            _reverse=CreateTableOp.from_table(
+                table, _namespace_metadata=_namespace_metadata
+            ),
+        )
 
     def to_table(self, migration_context=None):
-        if self._orig_table is not None:
-            return self._orig_table
+        if self._reverse:
+            cols_and_constraints = self._reverse.columns
+        else:
+            cols_and_constraints = []
+
         schema_obj = schemaobj.SchemaObjects(migration_context)
-        return schema_obj.table(
-            self.table_name, schema=self.schema, **self.table_kw
+        t = schema_obj.table(
+            self.table_name,
+            *cols_and_constraints,
+            schema=self.schema,
+            **self.table_kw
         )
+        return t
 
     @classmethod
     def drop_table(cls, operations, table_name, schema=None, **kw):
@@ -1791,12 +1762,12 @@ class DropColumnOp(AlterTableOp):
     """Represent a drop column operation."""
 
     def __init__(
-        self, table_name, column_name, schema=None, _orig_column=None, **kw
+        self, table_name, column_name, schema=None, _reverse=None, **kw
     ):
         super(DropColumnOp, self).__init__(table_name, schema=schema)
         self.column_name = column_name
         self.kw = kw
-        self._orig_column = _orig_column
+        self._reverse = _reverse
 
     def to_diff_tuple(self):
         return (
@@ -1807,23 +1778,28 @@ class DropColumnOp(AlterTableOp):
         )
 
     def reverse(self):
-        if self._orig_column is None:
+        if self._reverse is None:
             raise ValueError(
                 "operation is not reversible; "
                 "original column is not present"
             )
 
         return AddColumnOp.from_column_and_tablename(
-            self.schema, self.table_name, self._orig_column
+            self.schema, self.table_name, self._reverse.column
         )
 
     @classmethod
     def from_column_and_tablename(cls, schema, tname, col):
-        return cls(tname, col.name, schema=schema, _orig_column=col)
+        return cls(
+            tname,
+            col.name,
+            schema=schema,
+            _reverse=AddColumnOp.from_column_and_tablename(schema, tname, col),
+        )
 
     def to_column(self, migration_context=None):
-        if self._orig_column is not None:
-            return self._orig_column
+        if self._reverse is not None:
+            return self._reverse.column
         schema_obj = schemaobj.SchemaObjects(migration_context)
         return schema_obj.column(self.column_name, NULLTYPE)
 
index 5e8aa4fec0d416f0c1405833fddcbaae62372256..5d04ee200bd339a0ba4d095412006d5421089c66 100644 (file)
@@ -1,8 +1,12 @@
 from sqlalchemy import schema as sa_schema
+from sqlalchemy.sql.schema import Column
+from sqlalchemy.sql.schema import Constraint
+from sqlalchemy.sql.schema import Index
 from sqlalchemy.types import Integer
 from sqlalchemy.types import NULLTYPE
 
 from .. import util
+from ..util import sqla_compat
 from ..util.compat import raise_
 from ..util.compat import string_types
 
@@ -16,7 +20,6 @@ class SchemaObjects(object):
         columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
         t = sa_schema.Table(table_name, m, *columns, schema=schema)
         p = sa_schema.PrimaryKeyConstraint(*[t.c[n] for n in cols], name=name)
-        t.append_constraint(p)
         return p
 
     def foreign_key_constraint(
@@ -140,7 +143,25 @@ class SchemaObjects(object):
 
     def table(self, name, *columns, **kw):
         m = self.metadata()
-        t = sa_schema.Table(name, m, *columns, **kw)
+
+        cols = [
+            sqla_compat._copy(c) if c.table is not None else c
+            for c in columns
+            if isinstance(c, Column)
+        ]
+        t = sa_schema.Table(name, m, *cols, **kw)
+
+        constraints = [
+            sqla_compat._copy(elem, target_table=t)
+            if getattr(elem, "parent", None) is not None
+            else elem
+            for elem in columns
+            if isinstance(elem, (Constraint, Index))
+        ]
+
+        for const in constraints:
+            t.append_constraint(const)
+
         for f in t.foreign_keys:
             self._ensure_table_for_fk(m, f)
         return t
@@ -150,8 +171,11 @@ class SchemaObjects(object):
 
     def index(self, name, tablename, columns, schema=None, **kw):
         t = sa_schema.Table(
-            tablename or "no_table", self.metadata(), schema=schema
+            tablename or "no_table",
+            self.metadata(),
+            schema=schema,
         )
+        kw["_table"] = t
         idx = sa_schema.Index(
             name,
             *[util.sqla_compat._textual_index_column(t, n) for n in columns],
index 5a1106881bb60cb96cf9cbccd622cfa5cf85d858..3a5426b64cf0f8ebc9368644f10499bf3ec52b9e 100644 (file)
@@ -67,6 +67,18 @@ class SuiteRequirements(Requirements):
     def reflects_fk_options(self):
         return exclusions.closed()
 
+    @property
+    def editor_installed(self):
+        def go():
+            try:
+                import editor  # noqa
+            except ImportError:
+                return False
+            else:
+                return True
+
+        return exclusions.only_if(go, "editor package not installed")
+
     @property
     def sqlalchemy_13(self):
         return exclusions.skip_if(
diff --git a/alembic/testing/schemacompare.py b/alembic/testing/schemacompare.py
new file mode 100644 (file)
index 0000000..c3a7382
--- /dev/null
@@ -0,0 +1,155 @@
+from sqlalchemy import schema
+from sqlalchemy import util
+
+
+class CompareTable(object):
+    def __init__(self, table):
+        self.table = table
+
+    def __eq__(self, other):
+        if self.table.name != other.name or self.table.schema != other.schema:
+            return False
+
+        for c1, c2 in util.zip_longest(self.table.c, other.c):
+            if (c1 is None and c2 is not None) or (
+                c2 is None and c1 is not None
+            ):
+                return False
+            if CompareColumn(c1) != c2:
+                return False
+
+        return True
+
+        # TODO: compare constraints, indexes
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class CompareColumn(object):
+    def __init__(self, column):
+        self.column = column
+
+    def __eq__(self, other):
+        return (
+            self.column.name == other.name
+            and self.column.nullable == other.nullable
+        )
+        # TODO: datatypes etc
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class CompareIndex(object):
+    def __init__(self, index):
+        self.index = index
+
+    def __eq__(self, other):
+        return (
+            str(schema.CreateIndex(self.index))
+            == str(schema.CreateIndex(other))
+            and self.index.dialect_kwargs == other.dialect_kwargs
+        )
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class CompareCheckConstraint(object):
+    def __init__(self, constraint):
+        self.constraint = constraint
+
+    def __eq__(self, other):
+        return (
+            isinstance(other, schema.CheckConstraint)
+            and self.constraint.name == other.name
+            and (str(self.constraint.sqltext) == str(other.sqltext))
+            and (other.table.name == self.constraint.table.name)
+            and other.table.schema == self.constraint.table.schema
+        )
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class CompareForeignKey(object):
+    def __init__(self, constraint):
+        self.constraint = constraint
+
+    def __eq__(self, other):
+        r1 = (
+            isinstance(other, schema.ForeignKeyConstraint)
+            and self.constraint.name == other.name
+            and (other.table.name == self.constraint.table.name)
+            and other.table.schema == self.constraint.table.schema
+        )
+        if not r1:
+            return False
+        for c1, c2 in util.zip_longest(self.constraint.columns, other.columns):
+            if (c1 is None and c2 is not None) or (
+                c2 is None and c1 is not None
+            ):
+                return False
+            if CompareColumn(c1) != c2:
+                return False
+        return True
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class ComparePrimaryKey(object):
+    def __init__(self, constraint):
+        self.constraint = constraint
+
+    def __eq__(self, other):
+        r1 = (
+            isinstance(other, schema.PrimaryKeyConstraint)
+            and self.constraint.name == other.name
+            and (other.table.name == self.constraint.table.name)
+            and other.table.schema == self.constraint.table.schema
+        )
+        if not r1:
+            return False
+
+        for c1, c2 in util.zip_longest(self.constraint.columns, other.columns):
+            if (c1 is None and c2 is not None) or (
+                c2 is None and c1 is not None
+            ):
+                return False
+            if CompareColumn(c1) != c2:
+                return False
+
+        return True
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class CompareUniqueConstraint(object):
+    def __init__(self, constraint):
+        self.constraint = constraint
+
+    def __eq__(self, other):
+        r1 = (
+            isinstance(other, schema.UniqueConstraint)
+            and self.constraint.name == other.name
+            and (other.table.name == self.constraint.table.name)
+            and other.table.schema == self.constraint.table.schema
+        )
+        if not r1:
+            return False
+
+        for c1, c2 in util.zip_longest(self.constraint.columns, other.columns):
+            if (c1 is None and c2 is not None) or (
+                c2 is None and c1 is not None
+            ):
+                return False
+            if CompareColumn(c1) != c2:
+                return False
+
+        return True
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
index 91e22d38455847d2ed78161a2b9b78b3464920cf..a04ab2e9ce0625c27adc49e15194f7925152912c 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.schema import CheckConstraint
 from sqlalchemy.schema import Column
 from sqlalchemy.schema import ForeignKeyConstraint
+from sqlalchemy.sql import visitors
 from sqlalchemy.sql.elements import quoted_name
 from sqlalchemy.sql.expression import _BindParamClause
 from sqlalchemy.sql.expression import _TextClause as TextClause
@@ -252,10 +253,31 @@ def _textual_index_column(table, text_):
         return c
     elif isinstance(text_, TextClause):
         return _textual_index_element(table, text_)
+    elif isinstance(text_, sql.ColumnElement):
+        return _copy_expression(text_, table)
     else:
         raise ValueError("String or text() construct expected")
 
 
+def _copy_expression(expression, target_table):
+    def replace(col):
+        if (
+            isinstance(col, Column)
+            and col.table is not None
+            and col.table is not target_table
+        ):
+            if col.name in target_table.c:
+                return target_table.c[col.name]
+            else:
+                c = _copy(col)
+                target_table.append_column(c)
+                return c
+        else:
+            return None
+
+    return visitors.replacement_traverse(expression, {}, replace)
+
+
 class _textual_index_element(sql.ColumnElement):
     """Wrap around a sqlalchemy text() construct in such a way that
     we appear like a column-oriented SQL expression to an Index
diff --git a/docs/build/unreleased/803.rst b/docs/build/unreleased/803.rst
new file mode 100644 (file)
index 0000000..11d105b
--- /dev/null
@@ -0,0 +1,22 @@
+.. change::
+    :tags: bug, autogenerate
+    :tickets: 803
+
+    Refactored the implementation of :class:`.MigrateOperation` constructs such
+    as :class:`.CreateIndexOp`, :class:`.CreateTableOp`, etc. so that they no
+    longer rely upon maintaining a persistent version of each schema object
+    internally; instead, the state variables of each operation object will be
+    used to produce the corresponding construct when the operation is invoked.
+    The rationale is so that environments which make use of
+    operation-manipulation schemes such as those those discussed in
+    :ref:`autogen_rewriter` are better supported, allowing end-user code to
+    manipulate the public attributes of these objects which will then be
+    expressed in the final output, an example is
+    ``some_create_index_op.kw["postgresql_concurrently"] = True``.
+
+    Previously, these objects when generated from autogenerate would typically
+    hold onto the original, reflected element internally without honoring the
+    other state variables of each construct, preventing the public API from
+    working.
+
+
index 02a750a2804845eb6c11a56228d05cf094090d60..94a916736099527a60757b575d0d44312f86101b 100644 (file)
@@ -48,6 +48,7 @@ from alembic.testing import eq_
 from alembic.testing import is_
 from alembic.testing import is_not_
 from alembic.testing import mock
+from alembic.testing import schemacompare
 from alembic.testing import TestBase
 from alembic.testing.env import clear_staging_env
 from alembic.testing.env import staging_env
@@ -429,7 +430,10 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase):
         autogenerate._produce_net_changes(ctx, uo)
 
         diffs = uo.as_diffs()
-        eq_(diffs[0], ("add_table", metadata.tables["item"]))
+        eq_(
+            diffs[0],
+            ("add_table", schemacompare.CompareTable(metadata.tables["item"])),
+        )
 
         eq_(diffs[1][0], "remove_table")
         eq_(diffs[1][1].name, "extra")
@@ -720,7 +724,15 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
 
         diffs = uo.as_diffs()
 
-        eq_(diffs[0], ("add_table", metadata.tables["%s.item" % self.schema]))
+        eq_(
+            diffs[0],
+            (
+                "add_table",
+                schemacompare.CompareTable(
+                    metadata.tables["%s.item" % self.schema]
+                ),
+            ),
+        )
 
         eq_(diffs[1][0], "remove_table")
         eq_(diffs[1][1].name, "extra")
@@ -728,7 +740,12 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
         eq_(diffs[2][0], "add_column")
         eq_(diffs[2][1], self.schema)
         eq_(diffs[2][2], "address")
-        eq_(diffs[2][3], metadata.tables["%s.address" % self.schema].c.street)
+        eq_(
+            schemacompare.CompareColumn(
+                metadata.tables["%s.address" % self.schema].c.street
+            ),
+            diffs[2][3],
+        )
 
         eq_(diffs[3][0], "add_constraint")
         eq_(diffs[3][1].name, "uq_email")
@@ -736,7 +753,12 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase):
         eq_(diffs[4][0], "add_column")
         eq_(diffs[4][1], self.schema)
         eq_(diffs[4][2], "order")
-        eq_(diffs[4][3], metadata.tables["%s.order" % self.schema].c.user_id)
+        eq_(
+            schemacompare.CompareColumn(
+                metadata.tables["%s.order" % self.schema].c.user_id
+            ),
+            diffs[4][3],
+        )
 
         eq_(diffs[5][0][0], "modify_type")
         eq_(diffs[5][0][1], self.schema)
@@ -1504,7 +1526,10 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase):
 
         diffs = autogenerate.compare_metadata(self.context, metadata)
 
-        eq_(diffs[0], ("add_table", metadata.tables["item"]))
+        eq_(
+            diffs[0],
+            ("add_table", schemacompare.CompareTable(metadata.tables["item"])),
+        )
 
         eq_(diffs[1][0], "remove_table")
         eq_(diffs[1][1].name, "extra")
@@ -1666,7 +1691,15 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase):
 
         diffs = autogenerate.compare_metadata(context, metadata)
 
-        eq_(diffs[0], ("add_table", metadata.tables["test_schema.item"]))
+        eq_(
+            diffs[0],
+            (
+                "add_table",
+                schemacompare.CompareTable(
+                    metadata.tables["test_schema.item"]
+                ),
+            ),
+        )
 
         eq_(diffs[1][0], "remove_table")
         eq_(diffs[1][1].name, "extra")
@@ -1674,7 +1707,12 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase):
         eq_(diffs[2][0], "add_column")
         eq_(diffs[2][1], "test_schema")
         eq_(diffs[2][2], "address")
-        eq_(diffs[2][3], metadata.tables["test_schema.address"].c.street)
+        eq_(
+            schemacompare.CompareColumn(
+                metadata.tables["test_schema.address"].c.street
+            ),
+            diffs[2][3],
+        )
 
         eq_(diffs[3][0], "add_constraint")
         eq_(diffs[3][1].name, "uq_email")
@@ -1682,7 +1720,12 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase):
         eq_(diffs[4][0], "add_column")
         eq_(diffs[4][1], "test_schema")
         eq_(diffs[4][2], "order")
-        eq_(diffs[4][3], metadata.tables["test_schema.order"].c.user_id)
+        eq_(
+            schemacompare.CompareColumn(
+                metadata.tables["test_schema.order"].c.user_id
+            ),
+            diffs[4][3],
+        )
 
         eq_(diffs[5][0][0], "modify_nullable")
         eq_(diffs[5][0][5], False)
@@ -1711,42 +1754,54 @@ class OrigObjectTest(TestBase):
     def test_drop_fk(self):
         fk = self.fk
         op = ops.DropConstraintOp.from_constraint(fk)
-        is_(op.to_constraint(), fk)
-        is_(op.reverse().to_constraint(), fk)
+        eq_(op.to_constraint(), schemacompare.CompareForeignKey(fk))
+        eq_(op.reverse().to_constraint(), schemacompare.CompareForeignKey(fk))
 
     def test_add_fk(self):
         fk = self.fk
         op = ops.AddConstraintOp.from_constraint(fk)
-        is_(op.to_constraint(), fk)
-        is_(op.reverse().to_constraint(), fk)
+        eq_(op.to_constraint(), schemacompare.CompareForeignKey(fk))
+        eq_(op.reverse().to_constraint(), schemacompare.CompareForeignKey(fk))
         is_not_(None, op.to_constraint().table)
 
     def test_add_check(self):
         ck = self.ck
         op = ops.AddConstraintOp.from_constraint(ck)
-        is_(op.to_constraint(), ck)
-        is_(op.reverse().to_constraint(), ck)
+        eq_(op.to_constraint(), schemacompare.CompareCheckConstraint(ck))
+        eq_(
+            op.reverse().to_constraint(),
+            schemacompare.CompareCheckConstraint(ck),
+        )
         is_not_(None, op.to_constraint().table)
 
     def test_drop_check(self):
         ck = self.ck
         op = ops.DropConstraintOp.from_constraint(ck)
-        is_(op.to_constraint(), ck)
-        is_(op.reverse().to_constraint(), ck)
+        eq_(op.to_constraint(), schemacompare.CompareCheckConstraint(ck))
+        eq_(
+            op.reverse().to_constraint(),
+            schemacompare.CompareCheckConstraint(ck),
+        )
         is_not_(None, op.to_constraint().table)
 
     def test_add_unique(self):
         uq = self.uq
         op = ops.AddConstraintOp.from_constraint(uq)
-        is_(op.to_constraint(), uq)
-        is_(op.reverse().to_constraint(), uq)
+        eq_(op.to_constraint(), schemacompare.CompareUniqueConstraint(uq))
+        eq_(
+            op.reverse().to_constraint(),
+            schemacompare.CompareUniqueConstraint(uq),
+        )
         is_not_(None, op.to_constraint().table)
 
     def test_drop_unique(self):
         uq = self.uq
         op = ops.DropConstraintOp.from_constraint(uq)
-        is_(op.to_constraint(), uq)
-        is_(op.reverse().to_constraint(), uq)
+        eq_(op.to_constraint(), schemacompare.CompareUniqueConstraint(uq))
+        eq_(
+            op.reverse().to_constraint(),
+            schemacompare.CompareUniqueConstraint(uq),
+        )
         is_not_(None, op.to_constraint().table)
 
     def test_add_pk_no_orig(self):
@@ -1758,15 +1813,15 @@ class OrigObjectTest(TestBase):
     def test_add_pk(self):
         pk = self.pk
         op = ops.AddConstraintOp.from_constraint(pk)
-        is_(op.to_constraint(), pk)
-        is_(op.reverse().to_constraint(), pk)
+        eq_(op.to_constraint(), schemacompare.ComparePrimaryKey(pk))
+        eq_(op.reverse().to_constraint(), schemacompare.ComparePrimaryKey(pk))
         is_not_(None, op.to_constraint().table)
 
     def test_drop_pk(self):
         pk = self.pk
         op = ops.DropConstraintOp.from_constraint(pk)
-        is_(op.to_constraint(), pk)
-        is_(op.reverse().to_constraint(), pk)
+        eq_(op.to_constraint(), schemacompare.ComparePrimaryKey(pk))
+        eq_(op.reverse().to_constraint(), schemacompare.ComparePrimaryKey(pk))
         is_not_(None, op.to_constraint().table)
 
     def test_drop_column(self):
@@ -1789,27 +1844,25 @@ class OrigObjectTest(TestBase):
         t = self.table
 
         op = ops.DropTableOp.from_table(t)
-        is_(op.to_table(), t)
-        is_(op.reverse().to_table(), t)
-        is_(self.metadata, op.to_table().metadata)
+        eq_(op.to_table(), schemacompare.CompareTable(t))
+        eq_(op.reverse().to_table(), schemacompare.CompareTable(t))
 
     def test_add_table(self):
         t = self.table
 
         op = ops.CreateTableOp.from_table(t)
-        is_(op.to_table(), t)
-        is_(op.reverse().to_table(), t)
-        is_(self.metadata, op.to_table().metadata)
+        eq_(op.to_table(), schemacompare.CompareTable(t))
+        eq_(op.reverse().to_table(), schemacompare.CompareTable(t))
 
     def test_drop_index(self):
         op = ops.DropIndexOp.from_index(self.ix)
-        is_(op.to_index(), self.ix)
-        is_(op.reverse().to_index(), self.ix)
+        eq_(op.to_index(), schemacompare.CompareIndex(self.ix))
+        eq_(op.reverse().to_index(), schemacompare.CompareIndex(self.ix))
 
     def test_create_index(self):
         op = ops.CreateIndexOp.from_index(self.ix)
-        is_(op.to_index(), self.ix)
-        is_(op.reverse().to_index(), self.ix)
+        eq_(op.to_index(), schemacompare.CompareIndex(self.ix))
+        eq_(op.reverse().to_index(), schemacompare.CompareIndex(self.ix))
 
 
 class MultipleMetaDataTest(AutogenFixtureTest, TestBase):
@@ -1843,7 +1896,7 @@ class MultipleMetaDataTest(AutogenFixtureTest, TestBase):
         c2 = Table("c2", m2c, Column("id", Integer, primary_key=True))
 
         diffs = self._fixture([m1a, m1b, m1c], [m2a, m2b, m2c])
-        eq_(diffs[0], ("add_table", c2))
+        eq_(diffs[0], ("add_table", schemacompare.CompareTable(c2)))
         eq_(diffs[1][0], "remove_table")
         eq_(diffs[1][1].name, "b2")
         eq_(diffs[2], ("add_column", None, "a", a.c.q))
index 92acb74197ba3d95d7e141aebcfa4f76f16fe0dd..6fda2833495ede448b9ff4e2b253f25326d5dc6a 100644 (file)
@@ -16,6 +16,7 @@ from alembic.testing import assertions
 from alembic.testing import combinations
 from alembic.testing import config
 from alembic.testing import eq_
+from alembic.testing import schemacompare
 from alembic.testing import TestBase
 from alembic.testing import util
 from alembic.testing.env import staging_env
@@ -788,7 +789,7 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         )
 
         diffs = self._fixture(m1, m2)
-        eq_(diffs, [("add_index", idx)])
+        eq_(diffs, [("add_index", schemacompare.CompareIndex(idx))])
 
     def test_removed_idx_index_named_as_column(self):
         m1 = MetaData()
@@ -869,7 +870,6 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase):
         else:
             eq_(diffs[0][0], "remove_table")
             eq_(len(diffs), 1)
-
             constraints = [
                 c
                 for c in diffs[0][1].constraints
index 5d26493592b4eb87c433518bc593a2eb6f060880..dbd3b10914c5f8f60a96fe97379e7397a7b45e55 100644 (file)
@@ -1244,6 +1244,7 @@ class AutogenRenderTest(TestBase):
 
         def render(type_, obj, context):
             if type_ == "foreign_key":
+                # causes it not to render
                 return None
             if type_ == "column":
                 if obj.name == "y":
@@ -1269,7 +1270,7 @@ class AutogenRenderTest(TestBase):
             Column("y", Integer),
             Column("q", MySpecialType()),
             PrimaryKeyConstraint("x"),
-            ForeignKeyConstraint(["x"], ["y"]),
+            ForeignKeyConstraint(["x"], ["remote.y"]),
         )
         op_obj = ops.CreateTableOp.from_table(t)
         result = autogenerate.render_op_text(self.autogen_context, op_obj)
@@ -1380,7 +1381,7 @@ class AutogenRenderTest(TestBase):
                 r"u'",
                 "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context
+                    fk, self.autogen_context, m
                 ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], onupdate='CASCADE')",
@@ -1393,7 +1394,7 @@ class AutogenRenderTest(TestBase):
                 r"u'",
                 "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context
+                    fk, self.autogen_context, m
                 ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], ondelete='CASCADE')",
@@ -1405,7 +1406,7 @@ class AutogenRenderTest(TestBase):
                 r"u'",
                 "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context
+                    fk, self.autogen_context, m
                 ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], deferrable=True)",
@@ -1417,7 +1418,7 @@ class AutogenRenderTest(TestBase):
                 r"u'",
                 "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context
+                    fk, self.autogen_context, m
                 ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], initially='XYZ')",
@@ -1435,7 +1436,7 @@ class AutogenRenderTest(TestBase):
                 r"u'",
                 "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context
+                    fk, self.autogen_context, m
                 ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], "
@@ -1455,7 +1456,7 @@ class AutogenRenderTest(TestBase):
                 r"u'",
                 "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context
+                    fk, self.autogen_context, m
                 ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )",
@@ -1474,7 +1475,7 @@ class AutogenRenderTest(TestBase):
                 r"u'",
                 "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context
+                    fk, self.autogen_context, m
                 ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['t2.nonexistent'], )",
@@ -1498,7 +1499,7 @@ class AutogenRenderTest(TestBase):
                 r"u'",
                 "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context
+                    fk, self.autogen_context, m
                 ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )",
@@ -1517,7 +1518,7 @@ class AutogenRenderTest(TestBase):
                 r"u'",
                 "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context
+                    fk, self.autogen_context, m
                 ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )",
@@ -1537,7 +1538,7 @@ class AutogenRenderTest(TestBase):
 
         eq_ignore_whitespace(
             autogenerate.render._render_constraint(
-                const, self.autogen_context
+                const, self.autogen_context, m
             ),
             "sa.ForeignKeyConstraint(['c_rem'], ['t.c'], "
             "name='fk1', use_alter=True)",
@@ -1555,7 +1556,7 @@ class AutogenRenderTest(TestBase):
                 r"u'",
                 "'",
                 autogenerate.render._render_constraint(
-                    fk, self.autogen_context
+                    fk, self.autogen_context, m
                 ),
             ),
             "sa.ForeignKeyConstraint(['c'], ['foo.t2.c_rem'], "
@@ -1567,6 +1568,7 @@ class AutogenRenderTest(TestBase):
             autogenerate.render._render_check_constraint(
                 CheckConstraint("im a constraint", name="cc1"),
                 self.autogen_context,
+                None,
             ),
             "sa.CheckConstraint(!U'im a constraint', name='cc1')",
         )
@@ -1577,7 +1579,9 @@ class AutogenRenderTest(TestBase):
         ten = literal_column("10")
         eq_ignore_whitespace(
             autogenerate.render._render_check_constraint(
-                CheckConstraint(and_(c > five, c < ten)), self.autogen_context
+                CheckConstraint(and_(c > five, c < ten)),
+                self.autogen_context,
+                None,
             ),
             "sa.CheckConstraint(!U'c > 5 AND c < 10')",
         )
@@ -1586,7 +1590,9 @@ class AutogenRenderTest(TestBase):
         c = column("c")
         eq_ignore_whitespace(
             autogenerate.render._render_check_constraint(
-                CheckConstraint(and_(c > 5, c < 10)), self.autogen_context
+                CheckConstraint(and_(c > 5, c < 10)),
+                self.autogen_context,
+                None,
             ),
             "sa.CheckConstraint(!U'c > 5 AND c < 10')",
         )
@@ -1598,6 +1604,7 @@ class AutogenRenderTest(TestBase):
             autogenerate.render._render_unique_constraint(
                 UniqueConstraint(t.c.c, name="uq_1", deferrable="XYZ"),
                 self.autogen_context,
+                None,
             ),
             "sa.UniqueConstraint('c', deferrable='XYZ', name='uq_1')",
         )
@@ -2248,7 +2255,9 @@ class RenderNamingConventionTest(TestBase):
         t = Table("t", self.metadata, Column("c", Integer))
         eq_ignore_whitespace(
             autogenerate.render._render_unique_constraint(
-                UniqueConstraint(t.c.c, deferrable="XYZ"), self.autogen_context
+                UniqueConstraint(t.c.c, deferrable="XYZ"),
+                self.autogen_context,
+                None,
             ),
             "sa.UniqueConstraint('c', deferrable='XYZ', "
             "name=op.f('uq_ct_t_c'))",
@@ -2258,7 +2267,7 @@ class RenderNamingConventionTest(TestBase):
         t = Table("t", self.metadata, Column("c", Integer))
         eq_ignore_whitespace(
             autogenerate.render._render_unique_constraint(
-                UniqueConstraint(t.c.c, name="q"), self.autogen_context
+                UniqueConstraint(t.c.c, name="q"), self.autogen_context, None
             ),
             "sa.UniqueConstraint('c', name='q')",
         )
@@ -2316,7 +2325,7 @@ class RenderNamingConventionTest(TestBase):
         uq = [c for c in t.constraints if isinstance(c, UniqueConstraint)][0]
         eq_ignore_whitespace(
             autogenerate.render._render_unique_constraint(
-                uq, self.autogen_context
+                uq, self.autogen_context, None
             ),
             "sa.UniqueConstraint('c', name=op.f('uq_ct_t_c'))",
         )
@@ -2370,7 +2379,7 @@ class RenderNamingConventionTest(TestBase):
 
         eq_ignore_whitespace(
             autogenerate.render._render_check_constraint(
-                ck, self.autogen_context
+                ck, self.autogen_context, None
             ),
             "sa.CheckConstraint(!U'im a constraint', name=op.f('ck_t_cc1'))",
         )
index 6fb99579865ea46e83ecdfb396c54691d7951a30..4114be070efcbdcb2999a1894a5eef6f26540875 100644 (file)
@@ -16,6 +16,7 @@ from alembic import util
 from alembic.script import ScriptDirectory
 from alembic.testing import assert_raises
 from alembic.testing import assert_raises_message
+from alembic.testing import config as testing_config
 from alembic.testing import eq_
 from alembic.testing import is_false
 from alembic.testing import is_true
@@ -937,6 +938,7 @@ class EditTest(TestBase):
             command.edit(self.cfg, self.b[0:3])
             edit.assert_called_with(expected_call_arg)
 
+    @testing_config.requirements.editor_installed
     @testing.emits_python_deprecation_warning("the imp module is deprecated")
     def test_edit_with_missing_editor(self):
         with mock.patch("editor.edit") as edit_mock:
index 93bba287e305958941b354f46c2f3464c97b8765..63d3e965379e33e6de907e64c33b91e83eaa722e 100644 (file)
@@ -6,10 +6,12 @@ from sqlalchemy import Column
 from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
+from sqlalchemy import Index
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import String
 from sqlalchemy import Table
+from sqlalchemy import UniqueConstraint
 from sqlalchemy.sql import column
 from sqlalchemy.sql import func
 from sqlalchemy.sql import text
@@ -22,7 +24,7 @@ from alembic.testing import assert_raises_message
 from alembic.testing import combinations
 from alembic.testing import config
 from alembic.testing import eq_
-from alembic.testing import is_
+from alembic.testing import is_not_
 from alembic.testing import mock
 from alembic.testing.fixtures import AlterColRoundTripFixture
 from alembic.testing.fixtures import op_fixture
@@ -47,16 +49,10 @@ class OpTest(TestBase):
         op.rename_table("t1", "t2", schema="foo")
         context.assert_("ALTER TABLE foo.t1 RENAME TO foo.t2")
 
-    def test_create_index_no_expr_allowed(self):
-        op_fixture()
-        assert_raises_message(
-            ValueError,
-            r"String or text\(\) construct expected",
-            op.create_index,
-            "name",
-            "tname",
-            [func.foo(column("x"))],
-        )
+    def test_create_index_arbitrary_expr(self):
+        context = op_fixture()
+        op.create_index("name", "tname", [func.foo(column("x"))])
+        context.assert_("CREATE INDEX name ON tname (foo(x))")
 
     def test_add_column_schema_hard_quoting(self):
 
@@ -843,6 +839,63 @@ class OpTest(TestBase):
             "FOREIGN KEY(st_id) REFERENCES some_table (id))"
         )
 
+    def test_create_table_check_constraint(self):
+        context = op_fixture()
+        t1 = op.create_table(
+            "some_table",
+            Column("id", Integer, primary_key=True),
+            Column("foo_id", Integer),
+            CheckConstraint("foo_id>5", name="ck_1"),
+        )
+        context.assert_(
+            "CREATE TABLE some_table ("
+            "id INTEGER NOT NULL, "
+            "foo_id INTEGER, "
+            "PRIMARY KEY (id), "
+            "CONSTRAINT ck_1 CHECK (foo_id>5))"
+        )
+
+        ck = [c for c in t1.constraints if isinstance(c, CheckConstraint)]
+        eq_(ck[0].name, "ck_1")
+
+    def test_create_table_unique_constraint(self):
+        context = op_fixture()
+        t1 = op.create_table(
+            "some_table",
+            Column("id", Integer, primary_key=True),
+            Column("foo_id", Integer),
+            UniqueConstraint("foo_id", name="uq_1"),
+        )
+        context.assert_(
+            "CREATE TABLE some_table ("
+            "id INTEGER NOT NULL, "
+            "foo_id INTEGER, "
+            "PRIMARY KEY (id), "
+            "CONSTRAINT uq_1 UNIQUE (foo_id))"
+        )
+
+        uq = [c for c in t1.constraints if isinstance(c, UniqueConstraint)]
+        eq_(uq[0].name, "uq_1")
+
+    def test_create_table_index(self):
+        context = op_fixture()
+        t1 = op.create_table(
+            "some_table",
+            Column("id", Integer, primary_key=True),
+            Column("foo_id", Integer),
+            Index("ix_1", "foo_id"),
+        )
+        context.assert_(
+            "CREATE TABLE some_table ("
+            "id INTEGER NOT NULL, "
+            "foo_id INTEGER, "
+            "PRIMARY KEY (id))",
+            "CREATE INDEX ix_1 ON some_table (foo_id)",
+        )
+
+        ix = list(t1.indexes)
+        eq_(ix[0].name, "ix_1")
+
     def test_create_table_fk_and_schema(self):
         context = op_fixture()
         t1 = op.create_table(
@@ -1024,13 +1077,15 @@ class CustomOpTest(TestBase):
         context.assert_("CREATE SEQUENCE foob")
 
 
-class EnsureOrigObjectFromToTest(TestBase):
-    """the to_XYZ and from_XYZ methods are used heavily in autogenerate.
+class ObjectFromToTest(TestBase):
+    """Test operation round trips for to_obj() / from_obj().
 
-    It's critical that these methods, at least the "drop" form,
-    always return the *same* object if available so that all the info
-    passed into to_XYZ is maintained in the from_XYZ.
+    Previously, these needed to preserve the "original" item
+    to this, but this makes them harder to work with.
 
+    As of #803 the constructs try to behave more intelligently
+    about the state they were given, so that they can both "reverse"
+    themselves but also take into accout their current state.
 
     """
 
@@ -1038,31 +1093,112 @@ class EnsureOrigObjectFromToTest(TestBase):
         schema_obj = schemaobj.SchemaObjects()
         idx = schema_obj.index("x", "y", ["z"])
         op = ops.DropIndexOp.from_index(idx)
-        is_(op.to_index(), idx)
+        is_not_(op.to_index(), idx)
+
+    def test_drop_index_add_kw(self):
+        schema_obj = schemaobj.SchemaObjects()
+        idx = schema_obj.index("x", "y", ["z"])
+        op = ops.DropIndexOp.from_index(idx)
+
+        op.kw["postgresql_concurrently"] = True
+        eq_(op.to_index().dialect_kwargs["postgresql_concurrently"], True)
+
+        eq_(
+            op.reverse().to_index().dialect_kwargs["postgresql_concurrently"],
+            True,
+        )
 
     def test_create_index(self):
         schema_obj = schemaobj.SchemaObjects()
         idx = schema_obj.index("x", "y", ["z"])
         op = ops.CreateIndexOp.from_index(idx)
-        is_(op.to_index(), idx)
+
+        is_not_(op.to_index(), idx)
+
+    def test_create_index_add_kw(self):
+        schema_obj = schemaobj.SchemaObjects()
+        idx = schema_obj.index("x", "y", ["z"])
+        op = ops.CreateIndexOp.from_index(idx)
+
+        op.kw["postgresql_concurrently"] = True
+
+        eq_(op.to_index().dialect_kwargs["postgresql_concurrently"], True)
+        eq_(
+            op.reverse().to_index().dialect_kwargs["postgresql_concurrently"],
+            True,
+        )
 
     def test_drop_table(self):
         schema_obj = schemaobj.SchemaObjects()
         table = schema_obj.table("x", Column("q", Integer))
         op = ops.DropTableOp.from_table(table)
-        is_(op.to_table(), table)
+        is_not_(op.to_table(), table)
+
+    def test_drop_table_add_kw(self):
+        schema_obj = schemaobj.SchemaObjects()
+        table = schema_obj.table("x", Column("q", Integer))
+        op = ops.DropTableOp.from_table(table)
+
+        op.table_kw["postgresql_partition_by"] = "x"
+
+        eq_(op.to_table().dialect_kwargs["postgresql_partition_by"], "x")
+        eq_(
+            op.reverse().to_table().dialect_kwargs["postgresql_partition_by"],
+            "x",
+        )
 
     def test_create_table(self):
         schema_obj = schemaobj.SchemaObjects()
         table = schema_obj.table("x", Column("q", Integer))
         op = ops.CreateTableOp.from_table(table)
-        is_(op.to_table(), table)
+        is_not_(op.to_table(), table)
+
+    def test_create_table_add_kw(self):
+        schema_obj = schemaobj.SchemaObjects()
+        table = schema_obj.table("x", Column("q", Integer))
+        op = ops.CreateTableOp.from_table(table)
+        op.kw["postgresql_partition_by"] = "x"
+
+        eq_(op.to_table().dialect_kwargs["postgresql_partition_by"], "x")
+        eq_(
+            op.reverse().to_table().dialect_kwargs["postgresql_partition_by"],
+            "x",
+        )
+
+    def test_create_unique_constraint(self):
+        schema_obj = schemaobj.SchemaObjects()
+        const = schema_obj.unique_constraint("x", "foobar", ["a"])
+        op = ops.AddConstraintOp.from_constraint(const)
+        is_not_(op.to_constraint(), const)
+
+    def test_create_unique_constraint_add_kw(self):
+        schema_obj = schemaobj.SchemaObjects()
+        const = schema_obj.unique_constraint("x", "foobar", ["a"])
+        op = ops.AddConstraintOp.from_constraint(const)
+        is_not_(op.to_constraint(), const)
+
+        op.kw["sqlite_on_conflict"] = "IGNORE"
+
+        eq_(op.to_constraint().dialect_kwargs["sqlite_on_conflict"], "IGNORE")
+        eq_(
+            op.reverse().to_constraint().dialect_kwargs["sqlite_on_conflict"],
+            "IGNORE",
+        )
 
     def test_drop_unique_constraint(self):
         schema_obj = schemaobj.SchemaObjects()
         const = schema_obj.unique_constraint("x", "foobar", ["a"])
         op = ops.DropConstraintOp.from_constraint(const)
-        is_(op.to_constraint(), const)
+        is_not_(op.to_constraint(), const)
+
+    def test_drop_unique_constraint_change_name(self):
+        schema_obj = schemaobj.SchemaObjects()
+        const = schema_obj.unique_constraint("x", "foobar", ["a"])
+        op = ops.DropConstraintOp.from_constraint(const)
+
+        op.constraint_name = "my_name"
+        eq_(op.to_constraint().name, "my_name")
+        eq_(op.reverse().to_constraint().name, "my_name")
 
     def test_drop_constraint_not_available(self):
         op = ops.DropConstraintOp("x", "y", type_="unique")