From: Mike Bayer Date: Fri, 26 Feb 2021 18:11:17 +0000 (-0500) Subject: Create schema objects fresh from ops X-Git-Tag: rel_1_6_0~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=dd9c4d695c4463c872b25f59b6a8742bbc047150;p=thirdparty%2Fsqlalchemy%2Falembic.git Create schema objects fresh from ops Refactored the implementation of :class:`.MigrateOperation` constructs such as :class:`.CreateIndexOp`, :class:`.CreateTableOp`, etc. so that they no longer rely upon maintaining a persistent version of each schema object internally; instead, the state variables of each operation object will be used to produce the corresponding construct when the operation is invoked. The rationale is so that environments which make use of operation-manipulation schemes such as those those discussed in :ref:`autogen_rewriter` are better supported, allowing end-user code to manipulate the public attributes of these objects which will then be expressed in the final output, an example is ``some_create_index_op.kw["postgresql_concurrently"] = True``. Previously, these objects when generated from autogenerate would typically hold onto the original, reflected element internally without honoring the other state variables of each construct, preventing the public API from working. Change-Id: Ida2537740325de5c21e9447e930a518093f01bd4 Fixes: #803 --- diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 23890fb2..58b469c5 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -178,7 +178,9 @@ def _add_table(autogen_context, op): [ rcons for rcons in [ - _render_constraint(cons, autogen_context) + _render_constraint( + cons, autogen_context, op._namespace_metadata + ) for cons in table.constraints ] if rcons is not None @@ -308,7 +310,6 @@ def _add_fk_constraint(autogen_context, op): repr([_ident(col) for col in op.remote_cols]), ] ) - kwargs = [ "referent_schema", "onupdate", @@ -802,18 +803,18 @@ def _render_type_w_subtype( _constraint_renderers = util.Dispatcher() -def _render_constraint(constraint, autogen_context): +def _render_constraint(constraint, autogen_context, namespace_metadata): try: renderer = _constraint_renderers.dispatch(constraint) except ValueError: util.warn("No renderer is established for object %r" % constraint) return "[Unknown Python object %r]" % constraint else: - return renderer(constraint, autogen_context) + return renderer(constraint, autogen_context, namespace_metadata) @_constraint_renderers.dispatch_for(sa_schema.PrimaryKeyConstraint) -def _render_primary_key(constraint, autogen_context): +def _render_primary_key(constraint, autogen_context, namespace_metadata): rendered = _user_defined_render("primary_key", constraint, autogen_context) if rendered is not False: return rendered @@ -835,7 +836,7 @@ def _render_primary_key(constraint, autogen_context): } -def _fk_colspec(fk, metadata_schema): +def _fk_colspec(fk, metadata_schema, namespace_metadata): """Implement a 'safe' version of ForeignKey._get_colspec() that won't fail if the remote table can't be resolved. @@ -857,9 +858,9 @@ def _fk_colspec(fk, metadata_schema): # try to resolve the remote table in order to adjust for column.key. # the FK constraint needs to be rendered in terms of the column # name. - parent_metadata = fk.parent.table.metadata - if table_fullname in parent_metadata.tables: - col = parent_metadata.tables[table_fullname].c.get(colname) + + if table_fullname in namespace_metadata.tables: + col = namespace_metadata.tables[table_fullname].c.get(colname) if col is not None: colname = _ident(col.name) @@ -883,7 +884,7 @@ def _populate_render_fk_opts(constraint, opts): @_constraint_renderers.dispatch_for(sa_schema.ForeignKeyConstraint) -def _render_foreign_key(constraint, autogen_context): +def _render_foreign_key(constraint, autogen_context, namespace_metadata): rendered = _user_defined_render("foreign_key", constraint, autogen_context) if rendered is not False: return rendered @@ -896,7 +897,7 @@ def _render_foreign_key(constraint, autogen_context): _populate_render_fk_opts(constraint, opts) - apply_metadata_schema = constraint.parent.metadata.schema + apply_metadata_schema = namespace_metadata.schema return ( "%(prefix)sForeignKeyConstraint([%(cols)s], " "[%(refcols)s], %(args)s)" @@ -906,7 +907,7 @@ def _render_foreign_key(constraint, autogen_context): "%r" % _ident(f.parent.name) for f in constraint.elements ), "refcols": ", ".join( - repr(_fk_colspec(f, apply_metadata_schema)) + repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata)) for f in constraint.elements ), "args": ", ".join( @@ -917,7 +918,7 @@ def _render_foreign_key(constraint, autogen_context): @_constraint_renderers.dispatch_for(sa_schema.UniqueConstraint) -def _render_unique_constraint(constraint, autogen_context): +def _render_unique_constraint(constraint, autogen_context, namespace_metadata): rendered = _user_defined_render("unique", constraint, autogen_context) if rendered is not False: return rendered @@ -926,7 +927,7 @@ def _render_unique_constraint(constraint, autogen_context): @_constraint_renderers.dispatch_for(sa_schema.CheckConstraint) -def _render_check_constraint(constraint, autogen_context): +def _render_check_constraint(constraint, autogen_context, namespace_metadata): rendered = _user_defined_render("check", constraint, autogen_context) if rendered is not False: return rendered diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 6a2f007c..1e655860 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -464,7 +464,9 @@ def _add_exclude_constraint(autogen_context, op): @render._constraint_renderers.dispatch_for(ExcludeConstraint) -def _render_inline_exclude_constraint(constraint, autogen_context): +def _render_inline_exclude_constraint( + constraint, autogen_context, namespace_metadata +): rendered = render._user_defined_render( "exclude", constraint, autogen_context ) diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index ffb2f1b1..7ef21902 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -76,21 +76,16 @@ class DropConstraintOp(MigrateOperation): table_name, type_=None, schema=None, - _orig_constraint=None, + _reverse=None, ): self.constraint_name = constraint_name self.table_name = table_name self.constraint_type = type_ self.schema = schema - self._orig_constraint = _orig_constraint + self._reverse = _reverse def reverse(self): - if self._orig_constraint is None: - raise ValueError( - "operation is not reversible; " - "original constraint is not present" - ) - return AddConstraintOp.from_constraint(self._orig_constraint) + return AddConstraintOp.from_constraint(self.to_constraint()) def to_diff_tuple(self): if self.constraint_type == "foreignkey": @@ -115,12 +110,19 @@ class DropConstraintOp(MigrateOperation): constraint_table.name, schema=constraint_table.schema, type_=types[constraint.__visit_name__], - _orig_constraint=constraint, + _reverse=AddConstraintOp.from_constraint(constraint), ) def to_constraint(self): - if self._orig_constraint is not None: - return self._orig_constraint + + if self._reverse is not None: + constraint = self._reverse.to_constraint() + constraint.name = self.constraint_name + constraint_table = sqla_compat._table_for_constraint(constraint) + constraint_table.name = self.table_name + constraint_table.schema = self.schema + + return constraint else: raise ValueError( "constraint cannot be produced; " @@ -180,43 +182,33 @@ class CreatePrimaryKeyOp(AddConstraintOp): constraint_type = "primarykey" def __init__( - self, - constraint_name, - table_name, - columns, - schema=None, - _orig_constraint=None, - **kw + self, constraint_name, table_name, columns, schema=None, **kw ): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns self.schema = schema - self._orig_constraint = _orig_constraint self.kw = kw @classmethod def from_constraint(cls, constraint): constraint_table = sqla_compat._table_for_constraint(constraint) - return cls( constraint.name, constraint_table.name, - constraint.columns, + constraint.columns.keys(), schema=constraint_table.schema, - _orig_constraint=constraint, + **constraint.dialect_kwargs ) def to_constraint(self, migration_context=None): - if self._orig_constraint is not None: - return self._orig_constraint - schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.primary_key_constraint( self.constraint_name, self.table_name, self.columns, schema=self.schema, + **self.kw ) @classmethod @@ -295,19 +287,12 @@ class CreateUniqueConstraintOp(AddConstraintOp): constraint_type = "unique" def __init__( - self, - constraint_name, - table_name, - columns, - schema=None, - _orig_constraint=None, - **kw + self, constraint_name, table_name, columns, schema=None, **kw ): self.constraint_name = constraint_name self.table_name = table_name self.columns = columns self.schema = schema - self._orig_constraint = _orig_constraint self.kw = kw @classmethod @@ -319,20 +304,16 @@ class CreateUniqueConstraintOp(AddConstraintOp): kw["deferrable"] = constraint.deferrable if constraint.initially: kw["initially"] = constraint.initially - + kw.update(constraint.dialect_kwargs) return cls( constraint.name, constraint_table.name, [c.name for c in constraint.columns], schema=constraint_table.schema, - _orig_constraint=constraint, **kw ) def to_constraint(self, migration_context=None): - if self._orig_constraint is not None: - return self._orig_constraint - schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.unique_constraint( self.constraint_name, @@ -430,7 +411,6 @@ class CreateForeignKeyOp(AddConstraintOp): referent_table, local_cols, remote_cols, - _orig_constraint=None, **kw ): self.constraint_name = constraint_name @@ -438,7 +418,6 @@ class CreateForeignKeyOp(AddConstraintOp): self.referent_table = referent_table self.local_cols = local_cols self.remote_cols = remote_cols - self._orig_constraint = _orig_constraint self.kw = kw def to_diff_tuple(self): @@ -473,20 +452,17 @@ class CreateForeignKeyOp(AddConstraintOp): kw["source_schema"] = source_schema kw["referent_schema"] = target_schema - + kw.update(constraint.dialect_kwargs) return cls( constraint.name, source_table, target_table, source_columns, target_columns, - _orig_constraint=constraint, **kw ) def to_constraint(self, migration_context=None): - if self._orig_constraint is not None: - return self._orig_constraint schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.foreign_key_constraint( self.constraint_name, @@ -642,19 +618,12 @@ class CreateCheckConstraintOp(AddConstraintOp): constraint_type = "check" def __init__( - self, - constraint_name, - table_name, - condition, - schema=None, - _orig_constraint=None, - **kw + self, constraint_name, table_name, condition, schema=None, **kw ): self.constraint_name = constraint_name self.table_name = table_name self.condition = condition self.schema = schema - self._orig_constraint = _orig_constraint self.kw = kw @classmethod @@ -666,12 +635,10 @@ class CreateCheckConstraintOp(AddConstraintOp): constraint_table.name, constraint.sqltext, schema=constraint_table.schema, - _orig_constraint=constraint, + **constraint.dialect_kwargs ) def to_constraint(self, migration_context=None): - if self._orig_constraint is not None: - return self._orig_constraint schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.check_constraint( self.constraint_name, @@ -765,14 +732,7 @@ class CreateIndexOp(MigrateOperation): """Represent a create index operation.""" def __init__( - self, - index_name, - table_name, - columns, - schema=None, - unique=False, - _orig_index=None, - **kw + self, index_name, table_name, columns, schema=None, unique=False, **kw ): self.index_name = index_name self.table_name = table_name @@ -780,7 +740,6 @@ class CreateIndexOp(MigrateOperation): self.schema = schema self.unique = unique self.kw = kw - self._orig_index = _orig_index def reverse(self): return DropIndexOp.from_index(self.to_index()) @@ -796,15 +755,13 @@ class CreateIndexOp(MigrateOperation): sqla_compat._get_index_expressions(index), schema=index.table.schema, unique=index.unique, - _orig_index=index, **index.kwargs ) def to_index(self, migration_context=None): - if self._orig_index: - return self._orig_index schema_obj = schemaobj.SchemaObjects(migration_context) - return schema_obj.index( + + idx = schema_obj.index( self.index_name, self.table_name, self.columns, @@ -812,6 +769,7 @@ class CreateIndexOp(MigrateOperation): unique=self.unique, **self.kw ) + return idx @classmethod def create_index( @@ -897,23 +855,19 @@ class DropIndexOp(MigrateOperation): """Represent a drop index operation.""" def __init__( - self, index_name, table_name=None, schema=None, _orig_index=None, **kw + self, index_name, table_name=None, schema=None, _reverse=None, **kw ): self.index_name = index_name self.table_name = table_name self.schema = schema - self._orig_index = _orig_index + self._reverse = _reverse self.kw = kw def to_diff_tuple(self): return ("remove_index", self.to_index()) def reverse(self): - if self._orig_index is None: - raise ValueError( - "operation is not reversible; " "original index is not present" - ) - return CreateIndexOp.from_index(self._orig_index) + return CreateIndexOp.from_index(self.to_index()) @classmethod def from_index(cls, index): @@ -921,14 +875,11 @@ class DropIndexOp(MigrateOperation): index.name, index.table.name, schema=index.table.schema, - _orig_index=index, + _reverse=CreateIndexOp.from_index(index), **index.kwargs ) def to_index(self, migration_context=None): - if self._orig_index is not None: - return self._orig_index - schema_obj = schemaobj.SchemaObjects(migration_context) # need a dummy column name here since SQLAlchemy @@ -936,7 +887,7 @@ class DropIndexOp(MigrateOperation): return schema_obj.index( self.index_name, self.table_name, - ["x"], + self._reverse.columns if self._reverse else ["x"], schema=self.schema, **self.kw ) @@ -994,37 +945,49 @@ class CreateTableOp(MigrateOperation): """Represent a create table operation.""" def __init__( - self, table_name, columns, schema=None, _orig_table=None, **kw + self, table_name, columns, schema=None, _namespace_metadata=None, **kw ): self.table_name = table_name self.columns = columns self.schema = schema + self.comment = kw.pop("comment", None) + self.prefixes = kw.pop("prefixes", None) self.kw = kw - self._orig_table = _orig_table + self._namespace_metadata = _namespace_metadata def reverse(self): - return DropTableOp.from_table(self.to_table()) + return DropTableOp.from_table( + self.to_table(), _namespace_metadata=self._namespace_metadata + ) def to_diff_tuple(self): return ("add_table", self.to_table()) @classmethod - def from_table(cls, table): + def from_table(cls, table, _namespace_metadata=None): + if _namespace_metadata is None: + _namespace_metadata = table.metadata + return cls( table.name, list(table.c) + list(table.constraints), schema=table.schema, - _orig_table=table, + _namespace_metadata=_namespace_metadata, + comment=table.comment, + prefixes=table._prefixes, **table.kwargs ) def to_table(self, migration_context=None): - if self._orig_table is not None: - return self._orig_table schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.table( - self.table_name, *self.columns, schema=self.schema, **self.kw + self.table_name, + *self.columns, + schema=self.schema, + prefixes=self.prefixes, + comment=self.comment, + **self.kw ) @classmethod @@ -1113,35 +1076,43 @@ class CreateTableOp(MigrateOperation): class DropTableOp(MigrateOperation): """Represent a drop table operation.""" - def __init__( - self, table_name, schema=None, table_kw=None, _orig_table=None - ): + def __init__(self, table_name, schema=None, table_kw=None, _reverse=None): self.table_name = table_name self.schema = schema self.table_kw = table_kw or {} - self._orig_table = _orig_table + self._reverse = _reverse def to_diff_tuple(self): return ("remove_table", self.to_table()) def reverse(self): - if self._orig_table is None: - raise ValueError( - "operation is not reversible; " "original table is not present" - ) - return CreateTableOp.from_table(self._orig_table) + return CreateTableOp.from_table(self.to_table()) @classmethod - def from_table(cls, table): - return cls(table.name, schema=table.schema, _orig_table=table) + def from_table(cls, table, _namespace_metadata=None): + return cls( + table.name, + schema=table.schema, + table_kw=table.kwargs, + _reverse=CreateTableOp.from_table( + table, _namespace_metadata=_namespace_metadata + ), + ) def to_table(self, migration_context=None): - if self._orig_table is not None: - return self._orig_table + if self._reverse: + cols_and_constraints = self._reverse.columns + else: + cols_and_constraints = [] + schema_obj = schemaobj.SchemaObjects(migration_context) - return schema_obj.table( - self.table_name, schema=self.schema, **self.table_kw + t = schema_obj.table( + self.table_name, + *cols_and_constraints, + schema=self.schema, + **self.table_kw ) + return t @classmethod def drop_table(cls, operations, table_name, schema=None, **kw): @@ -1791,12 +1762,12 @@ class DropColumnOp(AlterTableOp): """Represent a drop column operation.""" def __init__( - self, table_name, column_name, schema=None, _orig_column=None, **kw + self, table_name, column_name, schema=None, _reverse=None, **kw ): super(DropColumnOp, self).__init__(table_name, schema=schema) self.column_name = column_name self.kw = kw - self._orig_column = _orig_column + self._reverse = _reverse def to_diff_tuple(self): return ( @@ -1807,23 +1778,28 @@ class DropColumnOp(AlterTableOp): ) def reverse(self): - if self._orig_column is None: + if self._reverse is None: raise ValueError( "operation is not reversible; " "original column is not present" ) return AddColumnOp.from_column_and_tablename( - self.schema, self.table_name, self._orig_column + self.schema, self.table_name, self._reverse.column ) @classmethod def from_column_and_tablename(cls, schema, tname, col): - return cls(tname, col.name, schema=schema, _orig_column=col) + return cls( + tname, + col.name, + schema=schema, + _reverse=AddColumnOp.from_column_and_tablename(schema, tname, col), + ) def to_column(self, migration_context=None): - if self._orig_column is not None: - return self._orig_column + if self._reverse is not None: + return self._reverse.column schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.column(self.column_name, NULLTYPE) diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py index 5e8aa4fe..5d04ee20 100644 --- a/alembic/operations/schemaobj.py +++ b/alembic/operations/schemaobj.py @@ -1,8 +1,12 @@ from sqlalchemy import schema as sa_schema +from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.schema import Constraint +from sqlalchemy.sql.schema import Index from sqlalchemy.types import Integer from sqlalchemy.types import NULLTYPE from .. import util +from ..util import sqla_compat from ..util.compat import raise_ from ..util.compat import string_types @@ -16,7 +20,6 @@ class SchemaObjects(object): columns = [sa_schema.Column(n, NULLTYPE) for n in cols] t = sa_schema.Table(table_name, m, *columns, schema=schema) p = sa_schema.PrimaryKeyConstraint(*[t.c[n] for n in cols], name=name) - t.append_constraint(p) return p def foreign_key_constraint( @@ -140,7 +143,25 @@ class SchemaObjects(object): def table(self, name, *columns, **kw): m = self.metadata() - t = sa_schema.Table(name, m, *columns, **kw) + + cols = [ + sqla_compat._copy(c) if c.table is not None else c + for c in columns + if isinstance(c, Column) + ] + t = sa_schema.Table(name, m, *cols, **kw) + + constraints = [ + sqla_compat._copy(elem, target_table=t) + if getattr(elem, "parent", None) is not None + else elem + for elem in columns + if isinstance(elem, (Constraint, Index)) + ] + + for const in constraints: + t.append_constraint(const) + for f in t.foreign_keys: self._ensure_table_for_fk(m, f) return t @@ -150,8 +171,11 @@ class SchemaObjects(object): def index(self, name, tablename, columns, schema=None, **kw): t = sa_schema.Table( - tablename or "no_table", self.metadata(), schema=schema + tablename or "no_table", + self.metadata(), + schema=schema, ) + kw["_table"] = t idx = sa_schema.Index( name, *[util.sqla_compat._textual_index_column(t, n) for n in columns], diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py index 5a110688..3a5426b6 100644 --- a/alembic/testing/requirements.py +++ b/alembic/testing/requirements.py @@ -67,6 +67,18 @@ class SuiteRequirements(Requirements): def reflects_fk_options(self): return exclusions.closed() + @property + def editor_installed(self): + def go(): + try: + import editor # noqa + except ImportError: + return False + else: + return True + + return exclusions.only_if(go, "editor package not installed") + @property def sqlalchemy_13(self): return exclusions.skip_if( diff --git a/alembic/testing/schemacompare.py b/alembic/testing/schemacompare.py new file mode 100644 index 00000000..c3a73823 --- /dev/null +++ b/alembic/testing/schemacompare.py @@ -0,0 +1,155 @@ +from sqlalchemy import schema +from sqlalchemy import util + + +class CompareTable(object): + def __init__(self, table): + self.table = table + + def __eq__(self, other): + if self.table.name != other.name or self.table.schema != other.schema: + return False + + for c1, c2 in util.zip_longest(self.table.c, other.c): + if (c1 is None and c2 is not None) or ( + c2 is None and c1 is not None + ): + return False + if CompareColumn(c1) != c2: + return False + + return True + + # TODO: compare constraints, indexes + + def __ne__(self, other): + return not self.__eq__(other) + + +class CompareColumn(object): + def __init__(self, column): + self.column = column + + def __eq__(self, other): + return ( + self.column.name == other.name + and self.column.nullable == other.nullable + ) + # TODO: datatypes etc + + def __ne__(self, other): + return not self.__eq__(other) + + +class CompareIndex(object): + def __init__(self, index): + self.index = index + + def __eq__(self, other): + return ( + str(schema.CreateIndex(self.index)) + == str(schema.CreateIndex(other)) + and self.index.dialect_kwargs == other.dialect_kwargs + ) + + def __ne__(self, other): + return not self.__eq__(other) + + +class CompareCheckConstraint(object): + def __init__(self, constraint): + self.constraint = constraint + + def __eq__(self, other): + return ( + isinstance(other, schema.CheckConstraint) + and self.constraint.name == other.name + and (str(self.constraint.sqltext) == str(other.sqltext)) + and (other.table.name == self.constraint.table.name) + and other.table.schema == self.constraint.table.schema + ) + + def __ne__(self, other): + return not self.__eq__(other) + + +class CompareForeignKey(object): + def __init__(self, constraint): + self.constraint = constraint + + def __eq__(self, other): + r1 = ( + isinstance(other, schema.ForeignKeyConstraint) + and self.constraint.name == other.name + and (other.table.name == self.constraint.table.name) + and other.table.schema == self.constraint.table.schema + ) + if not r1: + return False + for c1, c2 in util.zip_longest(self.constraint.columns, other.columns): + if (c1 is None and c2 is not None) or ( + c2 is None and c1 is not None + ): + return False + if CompareColumn(c1) != c2: + return False + return True + + def __ne__(self, other): + return not self.__eq__(other) + + +class ComparePrimaryKey(object): + def __init__(self, constraint): + self.constraint = constraint + + def __eq__(self, other): + r1 = ( + isinstance(other, schema.PrimaryKeyConstraint) + and self.constraint.name == other.name + and (other.table.name == self.constraint.table.name) + and other.table.schema == self.constraint.table.schema + ) + if not r1: + return False + + for c1, c2 in util.zip_longest(self.constraint.columns, other.columns): + if (c1 is None and c2 is not None) or ( + c2 is None and c1 is not None + ): + return False + if CompareColumn(c1) != c2: + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) + + +class CompareUniqueConstraint(object): + def __init__(self, constraint): + self.constraint = constraint + + def __eq__(self, other): + r1 = ( + isinstance(other, schema.UniqueConstraint) + and self.constraint.name == other.name + and (other.table.name == self.constraint.table.name) + and other.table.schema == self.constraint.table.schema + ) + if not r1: + return False + + for c1, c2 in util.zip_longest(self.constraint.columns, other.columns): + if (c1 is None and c2 is not None) or ( + c2 is None and c1 is not None + ): + return False + if CompareColumn(c1) != c2: + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 91e22d38..a04ab2e9 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -11,6 +11,7 @@ from sqlalchemy.ext.compiler import compiles from sqlalchemy.schema import CheckConstraint from sqlalchemy.schema import Column from sqlalchemy.schema import ForeignKeyConstraint +from sqlalchemy.sql import visitors from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.expression import _BindParamClause from sqlalchemy.sql.expression import _TextClause as TextClause @@ -252,10 +253,31 @@ def _textual_index_column(table, text_): return c elif isinstance(text_, TextClause): return _textual_index_element(table, text_) + elif isinstance(text_, sql.ColumnElement): + return _copy_expression(text_, table) else: raise ValueError("String or text() construct expected") +def _copy_expression(expression, target_table): + def replace(col): + if ( + isinstance(col, Column) + and col.table is not None + and col.table is not target_table + ): + if col.name in target_table.c: + return target_table.c[col.name] + else: + c = _copy(col) + target_table.append_column(c) + return c + else: + return None + + return visitors.replacement_traverse(expression, {}, replace) + + class _textual_index_element(sql.ColumnElement): """Wrap around a sqlalchemy text() construct in such a way that we appear like a column-oriented SQL expression to an Index diff --git a/docs/build/unreleased/803.rst b/docs/build/unreleased/803.rst new file mode 100644 index 00000000..11d105b6 --- /dev/null +++ b/docs/build/unreleased/803.rst @@ -0,0 +1,22 @@ +.. change:: + :tags: bug, autogenerate + :tickets: 803 + + Refactored the implementation of :class:`.MigrateOperation` constructs such + as :class:`.CreateIndexOp`, :class:`.CreateTableOp`, etc. so that they no + longer rely upon maintaining a persistent version of each schema object + internally; instead, the state variables of each operation object will be + used to produce the corresponding construct when the operation is invoked. + The rationale is so that environments which make use of + operation-manipulation schemes such as those those discussed in + :ref:`autogen_rewriter` are better supported, allowing end-user code to + manipulate the public attributes of these objects which will then be + expressed in the final output, an example is + ``some_create_index_op.kw["postgresql_concurrently"] = True``. + + Previously, these objects when generated from autogenerate would typically + hold onto the original, reflected element internally without honoring the + other state variables of each construct, preventing the public API from + working. + + diff --git a/tests/test_autogen_diffs.py b/tests/test_autogen_diffs.py index 02a750a2..94a91673 100644 --- a/tests/test_autogen_diffs.py +++ b/tests/test_autogen_diffs.py @@ -48,6 +48,7 @@ from alembic.testing import eq_ from alembic.testing import is_ from alembic.testing import is_not_ from alembic.testing import mock +from alembic.testing import schemacompare from alembic.testing import TestBase from alembic.testing.env import clear_staging_env from alembic.testing.env import staging_env @@ -429,7 +430,10 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestBase): autogenerate._produce_net_changes(ctx, uo) diffs = uo.as_diffs() - eq_(diffs[0], ("add_table", metadata.tables["item"])) + eq_( + diffs[0], + ("add_table", schemacompare.CompareTable(metadata.tables["item"])), + ) eq_(diffs[1][0], "remove_table") eq_(diffs[1][1].name, "extra") @@ -720,7 +724,15 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): diffs = uo.as_diffs() - eq_(diffs[0], ("add_table", metadata.tables["%s.item" % self.schema])) + eq_( + diffs[0], + ( + "add_table", + schemacompare.CompareTable( + metadata.tables["%s.item" % self.schema] + ), + ), + ) eq_(diffs[1][0], "remove_table") eq_(diffs[1][1].name, "extra") @@ -728,7 +740,12 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): eq_(diffs[2][0], "add_column") eq_(diffs[2][1], self.schema) eq_(diffs[2][2], "address") - eq_(diffs[2][3], metadata.tables["%s.address" % self.schema].c.street) + eq_( + schemacompare.CompareColumn( + metadata.tables["%s.address" % self.schema].c.street + ), + diffs[2][3], + ) eq_(diffs[3][0], "add_constraint") eq_(diffs[3][1].name, "uq_email") @@ -736,7 +753,12 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestBase): eq_(diffs[4][0], "add_column") eq_(diffs[4][1], self.schema) eq_(diffs[4][2], "order") - eq_(diffs[4][3], metadata.tables["%s.order" % self.schema].c.user_id) + eq_( + schemacompare.CompareColumn( + metadata.tables["%s.order" % self.schema].c.user_id + ), + diffs[4][3], + ) eq_(diffs[5][0][0], "modify_type") eq_(diffs[5][0][1], self.schema) @@ -1504,7 +1526,10 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestBase): diffs = autogenerate.compare_metadata(self.context, metadata) - eq_(diffs[0], ("add_table", metadata.tables["item"])) + eq_( + diffs[0], + ("add_table", schemacompare.CompareTable(metadata.tables["item"])), + ) eq_(diffs[1][0], "remove_table") eq_(diffs[1][1].name, "extra") @@ -1666,7 +1691,15 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase): diffs = autogenerate.compare_metadata(context, metadata) - eq_(diffs[0], ("add_table", metadata.tables["test_schema.item"])) + eq_( + diffs[0], + ( + "add_table", + schemacompare.CompareTable( + metadata.tables["test_schema.item"] + ), + ), + ) eq_(diffs[1][0], "remove_table") eq_(diffs[1][1].name, "extra") @@ -1674,7 +1707,12 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase): eq_(diffs[2][0], "add_column") eq_(diffs[2][1], "test_schema") eq_(diffs[2][2], "address") - eq_(diffs[2][3], metadata.tables["test_schema.address"].c.street) + eq_( + schemacompare.CompareColumn( + metadata.tables["test_schema.address"].c.street + ), + diffs[2][3], + ) eq_(diffs[3][0], "add_constraint") eq_(diffs[3][1].name, "uq_email") @@ -1682,7 +1720,12 @@ class PGCompareMetaData(ModelOne, AutogenTest, TestBase): eq_(diffs[4][0], "add_column") eq_(diffs[4][1], "test_schema") eq_(diffs[4][2], "order") - eq_(diffs[4][3], metadata.tables["test_schema.order"].c.user_id) + eq_( + schemacompare.CompareColumn( + metadata.tables["test_schema.order"].c.user_id + ), + diffs[4][3], + ) eq_(diffs[5][0][0], "modify_nullable") eq_(diffs[5][0][5], False) @@ -1711,42 +1754,54 @@ class OrigObjectTest(TestBase): def test_drop_fk(self): fk = self.fk op = ops.DropConstraintOp.from_constraint(fk) - is_(op.to_constraint(), fk) - is_(op.reverse().to_constraint(), fk) + eq_(op.to_constraint(), schemacompare.CompareForeignKey(fk)) + eq_(op.reverse().to_constraint(), schemacompare.CompareForeignKey(fk)) def test_add_fk(self): fk = self.fk op = ops.AddConstraintOp.from_constraint(fk) - is_(op.to_constraint(), fk) - is_(op.reverse().to_constraint(), fk) + eq_(op.to_constraint(), schemacompare.CompareForeignKey(fk)) + eq_(op.reverse().to_constraint(), schemacompare.CompareForeignKey(fk)) is_not_(None, op.to_constraint().table) def test_add_check(self): ck = self.ck op = ops.AddConstraintOp.from_constraint(ck) - is_(op.to_constraint(), ck) - is_(op.reverse().to_constraint(), ck) + eq_(op.to_constraint(), schemacompare.CompareCheckConstraint(ck)) + eq_( + op.reverse().to_constraint(), + schemacompare.CompareCheckConstraint(ck), + ) is_not_(None, op.to_constraint().table) def test_drop_check(self): ck = self.ck op = ops.DropConstraintOp.from_constraint(ck) - is_(op.to_constraint(), ck) - is_(op.reverse().to_constraint(), ck) + eq_(op.to_constraint(), schemacompare.CompareCheckConstraint(ck)) + eq_( + op.reverse().to_constraint(), + schemacompare.CompareCheckConstraint(ck), + ) is_not_(None, op.to_constraint().table) def test_add_unique(self): uq = self.uq op = ops.AddConstraintOp.from_constraint(uq) - is_(op.to_constraint(), uq) - is_(op.reverse().to_constraint(), uq) + eq_(op.to_constraint(), schemacompare.CompareUniqueConstraint(uq)) + eq_( + op.reverse().to_constraint(), + schemacompare.CompareUniqueConstraint(uq), + ) is_not_(None, op.to_constraint().table) def test_drop_unique(self): uq = self.uq op = ops.DropConstraintOp.from_constraint(uq) - is_(op.to_constraint(), uq) - is_(op.reverse().to_constraint(), uq) + eq_(op.to_constraint(), schemacompare.CompareUniqueConstraint(uq)) + eq_( + op.reverse().to_constraint(), + schemacompare.CompareUniqueConstraint(uq), + ) is_not_(None, op.to_constraint().table) def test_add_pk_no_orig(self): @@ -1758,15 +1813,15 @@ class OrigObjectTest(TestBase): def test_add_pk(self): pk = self.pk op = ops.AddConstraintOp.from_constraint(pk) - is_(op.to_constraint(), pk) - is_(op.reverse().to_constraint(), pk) + eq_(op.to_constraint(), schemacompare.ComparePrimaryKey(pk)) + eq_(op.reverse().to_constraint(), schemacompare.ComparePrimaryKey(pk)) is_not_(None, op.to_constraint().table) def test_drop_pk(self): pk = self.pk op = ops.DropConstraintOp.from_constraint(pk) - is_(op.to_constraint(), pk) - is_(op.reverse().to_constraint(), pk) + eq_(op.to_constraint(), schemacompare.ComparePrimaryKey(pk)) + eq_(op.reverse().to_constraint(), schemacompare.ComparePrimaryKey(pk)) is_not_(None, op.to_constraint().table) def test_drop_column(self): @@ -1789,27 +1844,25 @@ class OrigObjectTest(TestBase): t = self.table op = ops.DropTableOp.from_table(t) - is_(op.to_table(), t) - is_(op.reverse().to_table(), t) - is_(self.metadata, op.to_table().metadata) + eq_(op.to_table(), schemacompare.CompareTable(t)) + eq_(op.reverse().to_table(), schemacompare.CompareTable(t)) def test_add_table(self): t = self.table op = ops.CreateTableOp.from_table(t) - is_(op.to_table(), t) - is_(op.reverse().to_table(), t) - is_(self.metadata, op.to_table().metadata) + eq_(op.to_table(), schemacompare.CompareTable(t)) + eq_(op.reverse().to_table(), schemacompare.CompareTable(t)) def test_drop_index(self): op = ops.DropIndexOp.from_index(self.ix) - is_(op.to_index(), self.ix) - is_(op.reverse().to_index(), self.ix) + eq_(op.to_index(), schemacompare.CompareIndex(self.ix)) + eq_(op.reverse().to_index(), schemacompare.CompareIndex(self.ix)) def test_create_index(self): op = ops.CreateIndexOp.from_index(self.ix) - is_(op.to_index(), self.ix) - is_(op.reverse().to_index(), self.ix) + eq_(op.to_index(), schemacompare.CompareIndex(self.ix)) + eq_(op.reverse().to_index(), schemacompare.CompareIndex(self.ix)) class MultipleMetaDataTest(AutogenFixtureTest, TestBase): @@ -1843,7 +1896,7 @@ class MultipleMetaDataTest(AutogenFixtureTest, TestBase): c2 = Table("c2", m2c, Column("id", Integer, primary_key=True)) diffs = self._fixture([m1a, m1b, m1c], [m2a, m2b, m2c]) - eq_(diffs[0], ("add_table", c2)) + eq_(diffs[0], ("add_table", schemacompare.CompareTable(c2))) eq_(diffs[1][0], "remove_table") eq_(diffs[1][1].name, "b2") eq_(diffs[2], ("add_column", None, "a", a.c.q)) diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py index 92acb741..6fda2833 100644 --- a/tests/test_autogen_indexes.py +++ b/tests/test_autogen_indexes.py @@ -16,6 +16,7 @@ from alembic.testing import assertions from alembic.testing import combinations from alembic.testing import config from alembic.testing import eq_ +from alembic.testing import schemacompare from alembic.testing import TestBase from alembic.testing import util from alembic.testing.env import staging_env @@ -788,7 +789,7 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): ) diffs = self._fixture(m1, m2) - eq_(diffs, [("add_index", idx)]) + eq_(diffs, [("add_index", schemacompare.CompareIndex(idx))]) def test_removed_idx_index_named_as_column(self): m1 = MetaData() @@ -869,7 +870,6 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestBase): else: eq_(diffs[0][0], "remove_table") eq_(len(diffs), 1) - constraints = [ c for c in diffs[0][1].constraints diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py index 5d264935..dbd3b109 100644 --- a/tests/test_autogen_render.py +++ b/tests/test_autogen_render.py @@ -1244,6 +1244,7 @@ class AutogenRenderTest(TestBase): def render(type_, obj, context): if type_ == "foreign_key": + # causes it not to render return None if type_ == "column": if obj.name == "y": @@ -1269,7 +1270,7 @@ class AutogenRenderTest(TestBase): Column("y", Integer), Column("q", MySpecialType()), PrimaryKeyConstraint("x"), - ForeignKeyConstraint(["x"], ["y"]), + ForeignKeyConstraint(["x"], ["remote.y"]), ) op_obj = ops.CreateTableOp.from_table(t) result = autogenerate.render_op_text(self.autogen_context, op_obj) @@ -1380,7 +1381,7 @@ class AutogenRenderTest(TestBase): r"u'", "'", autogenerate.render._render_constraint( - fk, self.autogen_context + fk, self.autogen_context, m ), ), "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], onupdate='CASCADE')", @@ -1393,7 +1394,7 @@ class AutogenRenderTest(TestBase): r"u'", "'", autogenerate.render._render_constraint( - fk, self.autogen_context + fk, self.autogen_context, m ), ), "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], ondelete='CASCADE')", @@ -1405,7 +1406,7 @@ class AutogenRenderTest(TestBase): r"u'", "'", autogenerate.render._render_constraint( - fk, self.autogen_context + fk, self.autogen_context, m ), ), "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], deferrable=True)", @@ -1417,7 +1418,7 @@ class AutogenRenderTest(TestBase): r"u'", "'", autogenerate.render._render_constraint( - fk, self.autogen_context + fk, self.autogen_context, m ), ), "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], initially='XYZ')", @@ -1435,7 +1436,7 @@ class AutogenRenderTest(TestBase): r"u'", "'", autogenerate.render._render_constraint( - fk, self.autogen_context + fk, self.autogen_context, m ), ), "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], " @@ -1455,7 +1456,7 @@ class AutogenRenderTest(TestBase): r"u'", "'", autogenerate.render._render_constraint( - fk, self.autogen_context + fk, self.autogen_context, m ), ), "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )", @@ -1474,7 +1475,7 @@ class AutogenRenderTest(TestBase): r"u'", "'", autogenerate.render._render_constraint( - fk, self.autogen_context + fk, self.autogen_context, m ), ), "sa.ForeignKeyConstraint(['c'], ['t2.nonexistent'], )", @@ -1498,7 +1499,7 @@ class AutogenRenderTest(TestBase): r"u'", "'", autogenerate.render._render_constraint( - fk, self.autogen_context + fk, self.autogen_context, m ), ), "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )", @@ -1517,7 +1518,7 @@ class AutogenRenderTest(TestBase): r"u'", "'", autogenerate.render._render_constraint( - fk, self.autogen_context + fk, self.autogen_context, m ), ), "sa.ForeignKeyConstraint(['c'], ['t2.c_rem'], )", @@ -1537,7 +1538,7 @@ class AutogenRenderTest(TestBase): eq_ignore_whitespace( autogenerate.render._render_constraint( - const, self.autogen_context + const, self.autogen_context, m ), "sa.ForeignKeyConstraint(['c_rem'], ['t.c'], " "name='fk1', use_alter=True)", @@ -1555,7 +1556,7 @@ class AutogenRenderTest(TestBase): r"u'", "'", autogenerate.render._render_constraint( - fk, self.autogen_context + fk, self.autogen_context, m ), ), "sa.ForeignKeyConstraint(['c'], ['foo.t2.c_rem'], " @@ -1567,6 +1568,7 @@ class AutogenRenderTest(TestBase): autogenerate.render._render_check_constraint( CheckConstraint("im a constraint", name="cc1"), self.autogen_context, + None, ), "sa.CheckConstraint(!U'im a constraint', name='cc1')", ) @@ -1577,7 +1579,9 @@ class AutogenRenderTest(TestBase): ten = literal_column("10") eq_ignore_whitespace( autogenerate.render._render_check_constraint( - CheckConstraint(and_(c > five, c < ten)), self.autogen_context + CheckConstraint(and_(c > five, c < ten)), + self.autogen_context, + None, ), "sa.CheckConstraint(!U'c > 5 AND c < 10')", ) @@ -1586,7 +1590,9 @@ class AutogenRenderTest(TestBase): c = column("c") eq_ignore_whitespace( autogenerate.render._render_check_constraint( - CheckConstraint(and_(c > 5, c < 10)), self.autogen_context + CheckConstraint(and_(c > 5, c < 10)), + self.autogen_context, + None, ), "sa.CheckConstraint(!U'c > 5 AND c < 10')", ) @@ -1598,6 +1604,7 @@ class AutogenRenderTest(TestBase): autogenerate.render._render_unique_constraint( UniqueConstraint(t.c.c, name="uq_1", deferrable="XYZ"), self.autogen_context, + None, ), "sa.UniqueConstraint('c', deferrable='XYZ', name='uq_1')", ) @@ -2248,7 +2255,9 @@ class RenderNamingConventionTest(TestBase): t = Table("t", self.metadata, Column("c", Integer)) eq_ignore_whitespace( autogenerate.render._render_unique_constraint( - UniqueConstraint(t.c.c, deferrable="XYZ"), self.autogen_context + UniqueConstraint(t.c.c, deferrable="XYZ"), + self.autogen_context, + None, ), "sa.UniqueConstraint('c', deferrable='XYZ', " "name=op.f('uq_ct_t_c'))", @@ -2258,7 +2267,7 @@ class RenderNamingConventionTest(TestBase): t = Table("t", self.metadata, Column("c", Integer)) eq_ignore_whitespace( autogenerate.render._render_unique_constraint( - UniqueConstraint(t.c.c, name="q"), self.autogen_context + UniqueConstraint(t.c.c, name="q"), self.autogen_context, None ), "sa.UniqueConstraint('c', name='q')", ) @@ -2316,7 +2325,7 @@ class RenderNamingConventionTest(TestBase): uq = [c for c in t.constraints if isinstance(c, UniqueConstraint)][0] eq_ignore_whitespace( autogenerate.render._render_unique_constraint( - uq, self.autogen_context + uq, self.autogen_context, None ), "sa.UniqueConstraint('c', name=op.f('uq_ct_t_c'))", ) @@ -2370,7 +2379,7 @@ class RenderNamingConventionTest(TestBase): eq_ignore_whitespace( autogenerate.render._render_check_constraint( - ck, self.autogen_context + ck, self.autogen_context, None ), "sa.CheckConstraint(!U'im a constraint', name=op.f('ck_t_cc1'))", ) diff --git a/tests/test_command.py b/tests/test_command.py index 6fb99579..4114be07 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -16,6 +16,7 @@ from alembic import util from alembic.script import ScriptDirectory from alembic.testing import assert_raises from alembic.testing import assert_raises_message +from alembic.testing import config as testing_config from alembic.testing import eq_ from alembic.testing import is_false from alembic.testing import is_true @@ -937,6 +938,7 @@ class EditTest(TestBase): command.edit(self.cfg, self.b[0:3]) edit.assert_called_with(expected_call_arg) + @testing_config.requirements.editor_installed @testing.emits_python_deprecation_warning("the imp module is deprecated") def test_edit_with_missing_editor(self): with mock.patch("editor.edit") as edit_mock: diff --git a/tests/test_op.py b/tests/test_op.py index 93bba287..63d3e965 100644 --- a/tests/test_op.py +++ b/tests/test_op.py @@ -6,10 +6,12 @@ from sqlalchemy import Column from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey +from sqlalchemy import Index from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import String from sqlalchemy import Table +from sqlalchemy import UniqueConstraint from sqlalchemy.sql import column from sqlalchemy.sql import func from sqlalchemy.sql import text @@ -22,7 +24,7 @@ from alembic.testing import assert_raises_message from alembic.testing import combinations from alembic.testing import config from alembic.testing import eq_ -from alembic.testing import is_ +from alembic.testing import is_not_ from alembic.testing import mock from alembic.testing.fixtures import AlterColRoundTripFixture from alembic.testing.fixtures import op_fixture @@ -47,16 +49,10 @@ class OpTest(TestBase): op.rename_table("t1", "t2", schema="foo") context.assert_("ALTER TABLE foo.t1 RENAME TO foo.t2") - def test_create_index_no_expr_allowed(self): - op_fixture() - assert_raises_message( - ValueError, - r"String or text\(\) construct expected", - op.create_index, - "name", - "tname", - [func.foo(column("x"))], - ) + def test_create_index_arbitrary_expr(self): + context = op_fixture() + op.create_index("name", "tname", [func.foo(column("x"))]) + context.assert_("CREATE INDEX name ON tname (foo(x))") def test_add_column_schema_hard_quoting(self): @@ -843,6 +839,63 @@ class OpTest(TestBase): "FOREIGN KEY(st_id) REFERENCES some_table (id))" ) + def test_create_table_check_constraint(self): + context = op_fixture() + t1 = op.create_table( + "some_table", + Column("id", Integer, primary_key=True), + Column("foo_id", Integer), + CheckConstraint("foo_id>5", name="ck_1"), + ) + context.assert_( + "CREATE TABLE some_table (" + "id INTEGER NOT NULL, " + "foo_id INTEGER, " + "PRIMARY KEY (id), " + "CONSTRAINT ck_1 CHECK (foo_id>5))" + ) + + ck = [c for c in t1.constraints if isinstance(c, CheckConstraint)] + eq_(ck[0].name, "ck_1") + + def test_create_table_unique_constraint(self): + context = op_fixture() + t1 = op.create_table( + "some_table", + Column("id", Integer, primary_key=True), + Column("foo_id", Integer), + UniqueConstraint("foo_id", name="uq_1"), + ) + context.assert_( + "CREATE TABLE some_table (" + "id INTEGER NOT NULL, " + "foo_id INTEGER, " + "PRIMARY KEY (id), " + "CONSTRAINT uq_1 UNIQUE (foo_id))" + ) + + uq = [c for c in t1.constraints if isinstance(c, UniqueConstraint)] + eq_(uq[0].name, "uq_1") + + def test_create_table_index(self): + context = op_fixture() + t1 = op.create_table( + "some_table", + Column("id", Integer, primary_key=True), + Column("foo_id", Integer), + Index("ix_1", "foo_id"), + ) + context.assert_( + "CREATE TABLE some_table (" + "id INTEGER NOT NULL, " + "foo_id INTEGER, " + "PRIMARY KEY (id))", + "CREATE INDEX ix_1 ON some_table (foo_id)", + ) + + ix = list(t1.indexes) + eq_(ix[0].name, "ix_1") + def test_create_table_fk_and_schema(self): context = op_fixture() t1 = op.create_table( @@ -1024,13 +1077,15 @@ class CustomOpTest(TestBase): context.assert_("CREATE SEQUENCE foob") -class EnsureOrigObjectFromToTest(TestBase): - """the to_XYZ and from_XYZ methods are used heavily in autogenerate. +class ObjectFromToTest(TestBase): + """Test operation round trips for to_obj() / from_obj(). - It's critical that these methods, at least the "drop" form, - always return the *same* object if available so that all the info - passed into to_XYZ is maintained in the from_XYZ. + Previously, these needed to preserve the "original" item + to this, but this makes them harder to work with. + As of #803 the constructs try to behave more intelligently + about the state they were given, so that they can both "reverse" + themselves but also take into accout their current state. """ @@ -1038,31 +1093,112 @@ class EnsureOrigObjectFromToTest(TestBase): schema_obj = schemaobj.SchemaObjects() idx = schema_obj.index("x", "y", ["z"]) op = ops.DropIndexOp.from_index(idx) - is_(op.to_index(), idx) + is_not_(op.to_index(), idx) + + def test_drop_index_add_kw(self): + schema_obj = schemaobj.SchemaObjects() + idx = schema_obj.index("x", "y", ["z"]) + op = ops.DropIndexOp.from_index(idx) + + op.kw["postgresql_concurrently"] = True + eq_(op.to_index().dialect_kwargs["postgresql_concurrently"], True) + + eq_( + op.reverse().to_index().dialect_kwargs["postgresql_concurrently"], + True, + ) def test_create_index(self): schema_obj = schemaobj.SchemaObjects() idx = schema_obj.index("x", "y", ["z"]) op = ops.CreateIndexOp.from_index(idx) - is_(op.to_index(), idx) + + is_not_(op.to_index(), idx) + + def test_create_index_add_kw(self): + schema_obj = schemaobj.SchemaObjects() + idx = schema_obj.index("x", "y", ["z"]) + op = ops.CreateIndexOp.from_index(idx) + + op.kw["postgresql_concurrently"] = True + + eq_(op.to_index().dialect_kwargs["postgresql_concurrently"], True) + eq_( + op.reverse().to_index().dialect_kwargs["postgresql_concurrently"], + True, + ) def test_drop_table(self): schema_obj = schemaobj.SchemaObjects() table = schema_obj.table("x", Column("q", Integer)) op = ops.DropTableOp.from_table(table) - is_(op.to_table(), table) + is_not_(op.to_table(), table) + + def test_drop_table_add_kw(self): + schema_obj = schemaobj.SchemaObjects() + table = schema_obj.table("x", Column("q", Integer)) + op = ops.DropTableOp.from_table(table) + + op.table_kw["postgresql_partition_by"] = "x" + + eq_(op.to_table().dialect_kwargs["postgresql_partition_by"], "x") + eq_( + op.reverse().to_table().dialect_kwargs["postgresql_partition_by"], + "x", + ) def test_create_table(self): schema_obj = schemaobj.SchemaObjects() table = schema_obj.table("x", Column("q", Integer)) op = ops.CreateTableOp.from_table(table) - is_(op.to_table(), table) + is_not_(op.to_table(), table) + + def test_create_table_add_kw(self): + schema_obj = schemaobj.SchemaObjects() + table = schema_obj.table("x", Column("q", Integer)) + op = ops.CreateTableOp.from_table(table) + op.kw["postgresql_partition_by"] = "x" + + eq_(op.to_table().dialect_kwargs["postgresql_partition_by"], "x") + eq_( + op.reverse().to_table().dialect_kwargs["postgresql_partition_by"], + "x", + ) + + def test_create_unique_constraint(self): + schema_obj = schemaobj.SchemaObjects() + const = schema_obj.unique_constraint("x", "foobar", ["a"]) + op = ops.AddConstraintOp.from_constraint(const) + is_not_(op.to_constraint(), const) + + def test_create_unique_constraint_add_kw(self): + schema_obj = schemaobj.SchemaObjects() + const = schema_obj.unique_constraint("x", "foobar", ["a"]) + op = ops.AddConstraintOp.from_constraint(const) + is_not_(op.to_constraint(), const) + + op.kw["sqlite_on_conflict"] = "IGNORE" + + eq_(op.to_constraint().dialect_kwargs["sqlite_on_conflict"], "IGNORE") + eq_( + op.reverse().to_constraint().dialect_kwargs["sqlite_on_conflict"], + "IGNORE", + ) def test_drop_unique_constraint(self): schema_obj = schemaobj.SchemaObjects() const = schema_obj.unique_constraint("x", "foobar", ["a"]) op = ops.DropConstraintOp.from_constraint(const) - is_(op.to_constraint(), const) + is_not_(op.to_constraint(), const) + + def test_drop_unique_constraint_change_name(self): + schema_obj = schemaobj.SchemaObjects() + const = schema_obj.unique_constraint("x", "foobar", ["a"]) + op = ops.DropConstraintOp.from_constraint(const) + + op.constraint_name = "my_name" + eq_(op.to_constraint().name, "my_name") + eq_(op.reverse().to_constraint().name, "my_name") def test_drop_constraint_not_available(self): op = ops.DropConstraintOp("x", "y", type_="unique")