From: Mike Bayer Date: Thu, 11 Jul 2013 23:07:14 +0000 (-0400) Subject: Added new kw argument to :meth:`.EnvironmentContext.configure` X-Git-Tag: rel_0_6_0~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=949367b87972ad2fba7653e9bfd624b7a629c04e;p=thirdparty%2Fsqlalchemy%2Falembic.git Added new kw argument to :meth:`.EnvironmentContext.configure` ``include_object``. This is a more flexible version of the ``include_symbol`` argument which allows filtering of columns as well as tables from the autogenerate process, and in the future will also work for types, constraints and other constructs. The fully constructed schema object is passed, including its name and type as well as a flag indicating if the object is from the local application metadata or is reflected. --- diff --git a/alembic/autogenerate.py b/alembic/autogenerate.py index f1af020d..70b2e91c 100644 --- a/alembic/autogenerate.py +++ b/alembic/autogenerate.py @@ -109,12 +109,26 @@ def compare_metadata(context, metadata): def _produce_migration_diffs(context, template_args, imports, include_symbol=None, + include_object=None, include_schemas=False): opts = context.opts metadata = opts['target_metadata'] + include_object = opts.get('include_object', include_object) include_symbol = opts.get('include_symbol', include_symbol) include_schemas = opts.get('include_schemas', include_schemas) + object_filters = [] + if include_symbol: + def include_symbol_filter(object, name, type_, reflected, compare_to): + if type_ == "table": + return include_symbol(name, object.schema) + else: + return True + object_filters.append(include_symbol_filter) + if include_object: + object_filters.append(include_object) + + if metadata is None: raise util.CommandError( "Can't proceed with --autogenerate option; environment " @@ -126,8 +140,7 @@ def _produce_migration_diffs(context, template_args, diffs = [] _produce_net_changes(connection, metadata, diffs, - autogen_context, include_symbol, - include_schemas) + autogen_context, object_filters, include_schemas) template_args[opts['upgrade_token']] = \ _indent(_produce_upgrade_commands(diffs, autogen_context)) template_args[opts['downgrade_token']] = \ @@ -155,8 +168,16 @@ def _indent(text): ################################################### # walk structures + +def _run_filters(object_, name, type_, reflected, compare_to, object_filters): + for fn in object_filters: + if not fn(object_, name, type_, reflected, compare_to): + return False + else: + return True + def _produce_net_changes(connection, metadata, diffs, autogen_context, - include_symbol=None, + object_filters=(), include_schemas=False): inspector = Inspector.from_engine(connection) # TODO: not hardcode alembic_version here ? @@ -179,53 +200,53 @@ def _produce_net_changes(connection, metadata, diffs, autogen_context, metadata_table_names = OrderedSet([(table.schema, table.name) for table in metadata.sorted_tables]) - if include_symbol: - conn_table_names = set((s, name) - for s, name in conn_table_names - if include_symbol(name, s)) - metadata_table_names = OrderedSet((s, name) - for s, name in metadata_table_names - if include_symbol(name, s)) - _compare_tables(conn_table_names, metadata_table_names, + object_filters, inspector, metadata, diffs, autogen_context) def _compare_tables(conn_table_names, metadata_table_names, + object_filters, inspector, metadata, diffs, autogen_context): for s, tname in metadata_table_names.difference(conn_table_names): name = '%s.%s' % (s, tname) if s else tname - diffs.append(("add_table", metadata.tables[name])) - log.info("Detected added table %r", name) + metadata_table = metadata.tables[sa_schema._get_table_key(tname, s)] + if _run_filters(metadata_table, tname, "table", False, None, object_filters): + diffs.append(("add_table", metadata.tables[name])) + log.info("Detected added table %r", name) removal_metadata = sa_schema.MetaData() for s, tname in conn_table_names.difference(metadata_table_names): - name = '%s.%s' % (s, tname) if s else tname + name = sa_schema._get_table_key(tname, s) exists = name in removal_metadata.tables t = sa_schema.Table(tname, removal_metadata, schema=s) if not exists: inspector.reflecttable(t, None) - diffs.append(("remove_table", t)) - log.info("Detected removed table %r", name) + if _run_filters(t, tname, "table", True, None, object_filters): + diffs.append(("remove_table", t)) + log.info("Detected removed table %r", name) existing_tables = conn_table_names.intersection(metadata_table_names) - conn_column_info = dict( - ((s, tname), - dict( - (rec["name"], rec) - for rec in inspector.get_columns(tname, schema=s) - ) - ) - for s, tname in existing_tables - ) + existing_metadata = sa_schema.MetaData() + conn_column_info = {} + for s, tname in existing_tables: + name = sa_schema._get_table_key(tname, s) + exists = name in existing_metadata.tables + t = sa_schema.Table(tname, existing_metadata, schema=s) + if not exists: + inspector.reflecttable(t, None) + conn_column_info[(s, tname)] = t for s, tname in sorted(existing_tables): name = '%s.%s' % (s, tname) if s else tname - _compare_columns(s, tname, - conn_column_info[(s, tname)], - metadata.tables[name], - diffs, autogen_context) + metadata_table = metadata.tables[name] + conn_table = existing_metadata.tables[name] + if _run_filters(metadata_table, tname, "table", False, conn_table, object_filters): + _compare_columns(s, tname, object_filters, + conn_table, + metadata_table, + diffs, autogen_context) # TODO: # index add/drop @@ -235,33 +256,42 @@ def _compare_tables(conn_table_names, metadata_table_names, ################################################### # element comparison -def _compare_columns(schema, tname, conn_table, metadata_table, +def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, diffs, autogen_context): name = '%s.%s' % (schema, tname) if schema else tname metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c) - conn_col_names = set(conn_table) + conn_col_names = dict((c.name, c) for c in conn_table.c) metadata_col_names = OrderedSet(sorted(metadata_cols_by_name)) - for cname in metadata_col_names.difference(conn_col_names): - diffs.append( - ("add_column", schema, tname, metadata_cols_by_name[cname]) - ) - log.info("Detected added column '%s.%s'", name, cname) - for cname in conn_col_names.difference(metadata_col_names): - diffs.append( - ("remove_column", schema, tname, sa_schema.Column( - cname, - conn_table[cname]['type'], - nullable=conn_table[cname]['nullable'], - server_default=conn_table[cname]['default'] - )) - ) - log.info("Detected removed column '%s.%s'", name, cname) + for cname in metadata_col_names.difference(conn_col_names): + if _run_filters(metadata_cols_by_name[cname], cname, + "column", False, None, object_filters): + diffs.append( + ("add_column", schema, tname, metadata_cols_by_name[cname]) + ) + log.info("Detected added column '%s.%s'", name, cname) + + for cname in set(conn_col_names).difference(metadata_col_names): + rem_col = sa_schema.Column( + cname, + conn_table.c[cname].type, + nullable=conn_table.c[cname].nullable, + server_default=conn_table.c[cname].server_default + ) + if _run_filters(rem_col, cname, + "column", True, None, object_filters): + diffs.append( + ("remove_column", schema, tname, rem_col) + ) + log.info("Detected removed column '%s.%s'", name, cname) for colname in metadata_col_names.intersection(conn_col_names): metadata_col = metadata_cols_by_name[colname] - conn_col = conn_table[colname] + conn_col = conn_table.c[colname] + if not _run_filters( + metadata_col, colname, "column", False, conn_col, object_filters): + continue col_diff = [] _compare_type(schema, tname, colname, conn_col, @@ -284,13 +314,13 @@ def _compare_columns(schema, tname, conn_table, metadata_table, def _compare_nullable(schema, tname, cname, conn_col, metadata_col_nullable, diffs, autogen_context): - conn_col_nullable = conn_col['nullable'] + conn_col_nullable = conn_col.nullable if conn_col_nullable is not metadata_col_nullable: diffs.append( ("modify_nullable", schema, tname, cname, { - "existing_type": conn_col['type'], - "existing_server_default": conn_col['default'], + "existing_type": conn_col.type, + "existing_server_default": conn_col.server_default, }, conn_col_nullable, metadata_col_nullable), @@ -305,7 +335,7 @@ def _compare_type(schema, tname, cname, conn_col, metadata_col, diffs, autogen_context): - conn_type = conn_col['type'] + conn_type = conn_col.type metadata_type = metadata_col.type if conn_type._type_affinity is sqltypes.NullType: log.info("Couldn't determine database type " @@ -323,8 +353,8 @@ def _compare_type(schema, tname, cname, conn_col, diffs.append( ("modify_type", schema, tname, cname, { - "existing_nullable": conn_col['nullable'], - "existing_server_default": conn_col['default'], + "existing_nullable": conn_col.nullable, + "existing_server_default": conn_col.server_default, }, conn_type, metadata_type), @@ -337,22 +367,25 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col, diffs, autogen_context): metadata_default = metadata_col.server_default - conn_col_default = conn_col['default'] + conn_col_default = conn_col.server_default if conn_col_default is None and metadata_default is None: return False rendered_metadata_default = _render_server_default( metadata_default, autogen_context) + rendered_conn_default = conn_col.server_default.arg.text \ + if conn_col.server_default else None isdiff = autogen_context['context']._compare_server_default( conn_col, metadata_col, - rendered_metadata_default + rendered_metadata_default, + rendered_conn_default ) if isdiff: - conn_col_default = conn_col['default'] + conn_col_default = rendered_conn_default diffs.append( ("modify_default", schema, tname, cname, { - "existing_nullable": conn_col['nullable'], - "existing_type": conn_col['type'], + "existing_nullable": conn_col.nullable, + "existing_type": conn_col.type, }, conn_col_default, metadata_default), diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 2bb672c9..6a7c688a 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -182,7 +182,7 @@ class DefaultImpl(with_metaclass(ImplMeta)): def compare_type(self, inspector_column, metadata_column): - conn_type = inspector_column['type'] + conn_type = inspector_column.type metadata_type = metadata_column.type metadata_impl = metadata_type.dialect_impl(self.dialect) @@ -202,9 +202,9 @@ class DefaultImpl(with_metaclass(ImplMeta)): def compare_server_default(self, inspector_column, metadata_column, - rendered_metadata_default): - conn_col_default = inspector_column['default'] - return conn_col_default != rendered_metadata_default + rendered_metadata_default, + rendered_inspector_default): + return rendered_inspector_default != rendered_metadata_default def start_migrations(self): """A hook called when :meth:`.EnvironmentContext.run_migrations` diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index d63958ba..5ca0d1f5 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -11,14 +11,15 @@ class PostgresqlImpl(DefaultImpl): def compare_server_default(self, inspector_column, metadata_column, - rendered_metadata_default): + rendered_metadata_default, + rendered_inspector_default): # don't do defaults for SERIAL columns if metadata_column.primary_key and \ metadata_column is metadata_column.table._autoincrement_column: return False - conn_col_default = inspector_column['default'] + conn_col_default = rendered_inspector_default if None in (conn_col_default, rendered_metadata_default): return conn_col_default != rendered_metadata_default diff --git a/alembic/environment.py b/alembic/environment.py index 5d4e0561..641d435a 100644 --- a/alembic/environment.py +++ b/alembic/environment.py @@ -264,6 +264,7 @@ class EnvironmentContext(object): template_args=None, target_metadata=None, include_symbol=None, + include_object=None, include_schemas=False, compare_type=False, compare_server_default=False, @@ -425,9 +426,59 @@ class EnvironmentContext(object): execute the two defaults on the database side to compare for equivalence. + :param include_object: A callable function which is given + the chance to return ``True`` or ``False`` for any object, + indicating if the given object should be considered in the + autogenerate sweep. + + The function accepts the following positional arguments: + + * ``object``: a :class:`~sqlalchemy.schema.SchemaItem` object such as a + :class:`~sqlalchemy.schema.Table` or :class:`~sqlalchemy.schema.Column` + object + * ``name``: the name of the object. This is typically available + via ``object.name``. + * ``type``: a string describing the type of object; currently + ``"table"`` or ``"column"``, but will include other types in a + future release + * ``reflected``: ``True`` if the given object was produced based on + table reflection, ``False`` if it's from a local :class:`.MetaData` + object. + * ``compare_to``: the object being compared against, if available, + else ``None``. + + E.g.:: + + def include_object(object, name, type_, reflected, compare_to): + if (type_ == "column" and + not reflected and + object.info.get("skip_autogenerate", False)): + return False + else: + return True + + context.configure( + # ... + include_object = include_object + ) + + The ``include_object`` filter will be expanded in a future release + to also receive type, constraint, and default objects. + + .. versionadded:: 0.6.0 + + .. seealso:: + + ``include_schemas``, ``include_symbol`` + + :param include_symbol: A callable function which, given a table name and schema name (may be ``None``), returns ``True`` or ``False``, indicating if the given table should be considered in the autogenerate sweep. + + .. deprecated:: 0.6.0 ``include_symbol`` is superceded by the + more generic ``include_object`` parameter. + E.g.:: def include_symbol(tablename, schema): @@ -455,6 +506,24 @@ class EnvironmentContext(object): .. versionchanged:: 0.4.0 the ``include_symbol`` callable must now also accept a "schema" argument, which may be None. + .. seealso:: + + ``include_schemas``, ``include_object`` + + :param include_schemas: If True, autogenerate will scan across + all schemas located by the SQLAlchemy + :meth:`~sqlalchemy.engine.reflection.Inspector.get_schema_names` + method, and include all differences in tables found across all + those schemas. When using this option, you may want to also + use the ``include_symbol`` option to specify a callable which + can filter the tables/schemas that get included. + + .. versionadded :: 0.4.0 + + .. seealso:: + + ``include_symbol``, ``include_object`` + :param render_item: Callable that can be used to override how any schema item, i.e. column, constraint, type, etc., is rendered for autogenerate. The callable receives a @@ -506,15 +575,6 @@ class EnvironmentContext(object): will render them using the dialect module name, i.e. ``mssql.BIT()``, ``postgresql.UUID()``. - :param include_schemas: If True, autogenerate will scan across - all schemas located by the SQLAlchemy - :meth:`~sqlalchemy.engine.reflection.Inspector.get_schema_names` - method, and include all differences in tables found across all - those schemas. When using this option, you may want to also - use the ``include_symbol`` option to specify a callable which - can filter the tables/schemas that get included. - - .. versionadded :: 0.4.0 Parameters specific to individual backends: @@ -546,6 +606,7 @@ class EnvironmentContext(object): opts['template_args'].update(template_args) opts['target_metadata'] = target_metadata opts['include_symbol'] = include_symbol + opts['include_object'] = include_object opts['include_schemas'] = include_schemas opts['upgrade_token'] = upgrade_token opts['downgrade_token'] = downgrade_token diff --git a/alembic/migration.py b/alembic/migration.py index cdd244a4..4838e66b 100644 --- a/alembic/migration.py +++ b/alembic/migration.py @@ -292,7 +292,8 @@ class MigrationContext(object): def _compare_server_default(self, inspector_column, metadata_column, - rendered_metadata_default): + rendered_metadata_default, + rendered_column_default): if self._user_compare_server_default is False: return False @@ -302,7 +303,7 @@ class MigrationContext(object): self, inspector_column, metadata_column, - inspector_column['default'], + rendered_column_default, metadata_column.server_default, rendered_metadata_default ) @@ -312,5 +313,6 @@ class MigrationContext(object): return self.impl.compare_server_default( inspector_column, metadata_column, - rendered_metadata_default) + rendered_metadata_default, + rendered_column_default) diff --git a/docs/build/changelog.rst b/docs/build/changelog.rst index bc39aea2..6b98e3c4 100644 --- a/docs/build/changelog.rst +++ b/docs/build/changelog.rst @@ -7,6 +7,19 @@ Changelog :version: 0.6.0 :released: + .. change:: + :tags: feature + :tickets: 101 + + Added new kw argument to :meth:`.EnvironmentContext.configure` + ``include_object``. This is a more flexible version of the + ``include_symbol`` argument which allows filtering of columns as well as tables + from the autogenerate process, + and in the future will also work for types, constraints and + other constructs. The fully constructed schema object is passed, + including its name and type as well as a flag indicating if the object + is from the local application metadata or is reflected. + .. change:: :tags: feature diff --git a/tests/test_autogenerate.py b/tests/test_autogenerate.py index 9bdc2dec..a1b5ac3e 100644 --- a/tests/test_autogenerate.py +++ b/tests/test_autogenerate.py @@ -94,9 +94,17 @@ def _model_four(): return m -_default_include_symbol = lambda name, schema=None: name in ("parent", "child", +def _default_include_object(obj, name, type_, reflected, compare_to): + if type_ == "table": + return name in ("parent", "child", "user", "order", "item", "address", "extra") + else: + return True + +_default_object_filters = [ + _default_include_object +] class AutogenTest(object): @classmethod @@ -244,9 +252,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase): metadata = self.m2 connection = self.context.bind diffs = [] + def include_object(obj, name, type_, reflected, compare_to): + if type_ == "table": + return name == "t3" + else: + return True autogenerate._produce_net_changes(connection, metadata, diffs, self.autogen_context, - include_symbol=lambda n, s: n == 't3', + object_filters=[include_object], include_schemas=True ) eq_(diffs[0][0], "add_table") @@ -256,9 +269,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase): metadata = self.m2 connection = self.context.bind diffs = [] + def include_object(obj, name, type_, reflected, compare_to): + if type_ == "table": + return name == "t4" + else: + return True autogenerate._produce_net_changes(connection, metadata, diffs, self.autogen_context, - include_symbol=lambda n, s: n == 't4', + object_filters=[include_object], include_schemas=True ) eq_(diffs[0][0], "add_table") @@ -268,9 +286,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase): metadata = self.m2 connection = self.context.bind diffs = [] + def include_object(obj, name, type_, reflected, compare_to): + if type_ == "table": + return name == "t1" + else: + return True autogenerate._produce_net_changes(connection, metadata, diffs, self.autogen_context, - include_symbol=lambda n, s: n == 't1', + object_filters=[include_object], include_schemas=True ) eq_(diffs[0][0], "remove_table") @@ -280,9 +303,14 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase): metadata = self.m2 connection = self.context.bind diffs = [] + def include_object(obj, name, type_, reflected, compare_to): + if type_ == "table": + return name == "t2" + else: + return True autogenerate._produce_net_changes(connection, metadata, diffs, self.autogen_context, - include_symbol=lambda n, s: n == 't2', + object_filters=[include_object], include_schemas=True ) eq_(diffs[0][0], "remove_table") @@ -315,7 +343,7 @@ class AutogenerateDiffTestWSchema(AutogenTest, TestCase): diffs = [] autogenerate._produce_net_changes(connection, metadata, diffs, self.autogen_context, - include_symbol=_default_include_symbol, + object_filters=_default_object_filters, include_schemas=True ) @@ -390,7 +418,7 @@ class AutogenerateDiffTestWSchema(AutogenTest, TestCase): template_args = {} autogenerate._produce_migration_diffs( self.context, template_args, set(), - include_symbol=_default_include_symbol, + include_object=_default_include_object, include_schemas=True ) eq_(re.sub(r"u'", "'", template_args['upgrades']), @@ -472,7 +500,7 @@ class AutogenerateDiffTest(AutogenTest, TestCase): diffs = [] autogenerate._produce_net_changes(connection, metadata, diffs, self.autogen_context, - include_symbol= _default_include_symbol + object_filters=_default_object_filters, ) eq_( @@ -620,11 +648,55 @@ class AutogenerateDiffTest(AutogenTest, TestCase): assert "alter_column('order'" in template_args['upgrades'] assert "alter_column('order'" in template_args['downgrades'] + def test_include_object(self): + def include_object(obj, name, type_, reflected, compare_to): + assert obj.name == name + if type_ == "table": + if reflected: + assert obj.metadata is not self.m2 + else: + assert obj.metadata is self.m2 + return name in ("address", "order") + elif type_ == "column": + if reflected: + assert obj.table.metadata is not self.m2 + else: + assert obj.table.metadata is self.m2 + return name != "street" + else: + return True + + + context = MigrationContext.configure( + connection=self.bind.connect(), + opts={ + 'compare_type': True, + 'compare_server_default': True, + 'target_metadata': self.m2, + 'include_object': include_object, + 'upgrade_token': "upgrades", + 'downgrade_token': "downgrades", + 'alembic_module_prefix': 'op.', + 'sqlalchemy_module_prefix': 'sa.', + } + ) + template_args = {} + autogenerate._produce_migration_diffs(context, template_args, set()) + template_args['upgrades'] = template_args['upgrades'].replace("u'", "'") + template_args['downgrades'] = template_args['downgrades'].\ + replace("u'", "'") + + assert "alter_column('user'" not in template_args['upgrades'] + assert "alter_column('user'" not in template_args['downgrades'] + assert "'street'" not in template_args['upgrades'] + assert "'street'" not in template_args['downgrades'] + assert "alter_column('order'" in template_args['upgrades'] + assert "alter_column('order'" in template_args['downgrades'] + def test_skip_null_type_comparison_reflected(self): diff = [] autogenerate._compare_type(None, "sometable", "somecol", - {"name":"somecol", "type":NULLTYPE, - "nullable":True, "default":None}, + Column("somecol", NULLTYPE), Column("somecol", Integer()), diff, self.autogen_context ) @@ -633,8 +705,7 @@ class AutogenerateDiffTest(AutogenTest, TestCase): def test_skip_null_type_comparison_local(self): diff = [] autogenerate._compare_type(None, "sometable", "somecol", - {"name":"somecol", "type":Integer(), - "nullable":True, "default":None}, + Column("somecol", Integer()), Column("somecol", NULLTYPE), diff, self.autogen_context ) @@ -652,8 +723,7 @@ class AutogenerateDiffTest(AutogenTest, TestCase): diff = [] autogenerate._compare_type(None, "sometable", "somecol", - {"name":"somecol", "type":Integer(), - "nullable":True, "default":None}, + Column("somecol", Integer, nullable=True), Column("somecol", MyType()), diff, self.autogen_context ) @@ -664,7 +734,8 @@ class AutogenerateDiffTest(AutogenTest, TestCase): from sqlalchemy.util import OrderedSet inspector = Inspector.from_engine(self.bind) autogenerate._compare_tables( - OrderedSet([(None, 'extra'), (None, 'user')]), OrderedSet(), inspector, + OrderedSet([(None, 'extra'), (None, 'user')]), + OrderedSet(), [], inspector, MetaData(), diffs, self.autogen_context ) eq_( diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 047ac92d..9c530353 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -197,9 +197,10 @@ class PostgresqlDefaultCompareTest(TestCase): cols = insp.get_columns(t1.name) ctx = self.autogen_context['context'] return ctx.impl.compare_server_default( - cols[0], + None, col, - rendered) + rendered, + cols[0]['default']) def test_compare_current_timestamp(self): self._compare_default_roundtrip(