From: Mike Bayer Date: Sun, 27 Nov 2011 19:56:01 +0000 (-0500) Subject: implement autogenerate feature X-Git-Tag: rel_0_1_0~41 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ed25eab41d5d814b0283f9fad12fc78b6e0ebcb0;p=thirdparty%2Fsqlalchemy%2Falembic.git implement autogenerate feature --- diff --git a/alembic/autogenerate.py b/alembic/autogenerate.py new file mode 100644 index 00000000..55885057 --- /dev/null +++ b/alembic/autogenerate.py @@ -0,0 +1,267 @@ +"""Provide the 'autogenerate' feature which can produce migration operations +automatically.""" + +from alembic.context import _context_opts +from alembic import util +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy import types as sqltypes, schema + +################################################### +# top level + +def produce_migration_diffs(template_args): + metadata = _context_opts['autogenerate_metadata'] + if metadata is None: + raise util.CommandError( + "Can't proceed with --autogenerate option; environment " + "script env.py does not provide " + "a MetaData object to the context.") + connection = get_bind() + diffs = [] + _produce_net_changes(connection, metadata, diffs) + _set_upgrade(template_args, _produce_upgrade_commands(diffs)) + _set_downgrade(template_args, _produce_downgrade_commands(diffs)) + +def _set_upgrade(template_args, text): + template_args[_context_opts['upgrade_token']] = text + +def _set_downgrade(template_args, text): + template_args[_context_opts['downgrade_token']] = text + +################################################### +# walk structures + +def _produce_net_changes(connection, metadata, diffs): + inspector = Inspector.from_engine(connection) + conn_table_names = set(inspector.get_table_names()) + metadata_table_names = set(metadata.tables) + + diffs.extend( + ("upgrade_table", metadata.tables[tname]) + for tname in metadata_table_names.difference(conn_table_names) + ) + diffs.extend( + ("downgrade_table", tname) + for tname in conn_table_names.difference(metadata_table_names) + ) + + existing_tables = conn_table_names.intersection(metadata_table_names) + + conn_column_info = dict( + (tname, + dict( + (rec["name"], rec) + for rec in inspector.get_columns(tname) + ) + ) + for tname in existing_tables + ) + + for tname in existing_tables: + _compare_columns(tname, + conn_column_info[tname], + metadata.tables[tname], + diffs) + + # TODO: + # index add/drop + # table constraints + # sequences + +################################################### +# element comparison + +def _compare_columns(tname, conn_table, metadata_table, diffs): + metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c) + conn_col_names = set(conn_table) + metadata_col_names = set(metadata_cols_by_name) + + diffs.extend( + ("upgrade_column", tname, metadata_cols_by_name[cname]) + for cname in metadata_col_names.difference(conn_col_names) + ) + diffs.extend( + ("downgrade_column", tname, cname) + for cname in conn_col_names.difference(metadata_col_names) + ) + + for colname in metadata_col_names.intersection(conn_col_names): + metadata_col = metadata_table.c[colname] + conn_col = conn_table[colname] + _compare_type(tname, colname, + conn_col['type'], + metadata_col.type, + diffs + ) + _compare_nullable(tname, colname, + conn_col['nullable'], + metadata_col.nullable, + diffs + ) + +def _compare_nullable(tname, cname, conn_col_nullable, + metadata_col_nullable, diffs): + if conn_col_nullable is not metadata_col_nullable: + diffs.extend([ + ("upgrade_nullable", tname, cname, metadata_col_nullable), + ("downgrade_nullable", tname, cname, conn_col_nullable) + ]) + +def _compare_type(tname, cname, conn_type, metadata_type, diffs): + if conn_type._compare_type_affinity(metadata_type): + comparator = _type_comparators.get(conn_type._type_affinity, None) + + isdiff = comparator and comparator(metadata_type, conn_type) + else: + isdiff = True + + if isdiff: + diffs.extend([ + ("upgrade_type", tname, cname, metadata_type), + ("downgrade_type", tname, cname, conn_type) + ]) + +def _string_compare(t1, t2): + return \ + t1.length is not None and \ + t1.length != t2.length + +def _numeric_compare(t1, t2): + return \ + ( + t1.precision is not None and \ + t1.precision != t2.precision + ) or \ + ( + t1.scale is not None and \ + t1.scale != t2.scale + ) +_type_comparators = { + sqltypes.String:_string_compare, + sqltypes.Numeric:_numeric_compare +} + +################################################### +# render python + +def _produce_upgrade_commands(diffs): + for diff in diffs: + if diff.startswith('upgrade_'): + cmd = _commands[diff[0]] + cmd(*diff[1:]) + +def _produce_downgrade_commands(diffs): + for diff in diffs: + if diff.startswith('downgrade_'): + cmd = _commands[diff[0]] + cmd(*diff[1:]) + +def _upgrade_table(table): + return \ +"""create_table(%(tablename)r, + %(args)s + ) +""" % { + 'tablename':table.name, + 'args':',\n'.join( + [_render_col(col) for col in table.c] + + sorted([rcons for rcons in + [_render_constraint(cons) for cons in + table.constraints] + if rcons is not None + ]) + ), + } + +def _downgrade_table(tname): + return "drop_table(%r)" % tname + +def _upgrade_column(tname, column): + return "add_column(%r, %s)" % ( + tname, + _render_column(column)) + +def _downgrade_column(tname, cname): + return "drop_column(%r, %r)" % (tname, cname) + +def _up_or_downgrade_type(tname, cname, type_): + return "alter_column(%r, %r, type=%r)" % ( + tname, cname, type_ + ) + +def _up_or_downgrade_nullable(tname, cname, nullable): + return "alter_column(%r, %r, nullable=%r)" % ( + tname, cname, nullable + ) + +_commands = { + 'upgrade_table':_upgrade_table, + 'downgrade_table':_downgrade_table, + + 'upgrade_column':_upgrade_column, + 'downgrade_column':_downgrade_column, + + 'upgrade_type':_up_or_downgrade_type, + 'downgrde_type':_up_or_downgrade_type, + + 'upgrade_nullable':_up_or_downgrade_nullable, + 'downgrade_nullable':_up_or_downgrade_nullable, + +} + +def _render_col(column): + opts = [] + if column.server_default: + opts.append(("server_default", column.server_default)) + if column.nullable is not None: + opts.append(("nullable", column.nullable)) + + # TODO: for non-ascii colname, assign a "key" + return "Column(%(name)r, %(type)r, %(kw)s)" % { + 'name':column.name, + 'type':column.type, + 'kw':", ".join(["%s=%s" % (kwname, val) for kwname, val in opts]) + } + +def _render_constraint(constraint): + renderer = _constraint_renderers.get(type(constraint), None) + if renderer: + return renderer(constraint) + else: + return None + +def _render_primary_key(constraint): + opts = [] + if constraint.name: + opts.append(("name", constraint.name)) + return "PrimaryKeyConstraint(%(args)s)" % { + "args":", ".join( + [c.key for c in constraint.columns] + + ["%s=%s" % (kwname, val) for kwname, val in opts] + ), + } + +def _render_foreign_key(constraint): + opts = [] + if constraint.name: + opts.append(("name", constraint.name)) + # TODO: deferrable, initially, etc. + return "ForeignKeyConstraint([%(cols)s], [%(refcols)s], %(args)s)" % { + "cols":", ".join(f.parent.key for f in constraint.elements), + "refcols":", ".join(repr(f._get_colspec()) for f in constraint.elements), + "args":", ".join( + ["%s=%s" % (kwname, val) for kwname, val in opts] + ), + } + +def _render_check_constraint(constraint): + opts = [] + if constraint.name: + opts.append(("name", constraint.name)) + return "CheckConstraint('TODO')" + +_constraint_renderers = { + schema.PrimaryKeyConstraint:_render_primary_key, + schema.ForeignKeyConstraint:_render_foreign_key, + schema.CheckConstraint:_render_check_constraint +} diff --git a/alembic/command.py b/alembic/command.py index ba0810e4..ed672b83 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -1,5 +1,5 @@ from alembic.script import ScriptDirectory -from alembic import util, ddl, context +from alembic import util, ddl, context, autogenerate as autogen import os import functools @@ -59,11 +59,26 @@ def init(config, directory, template='generic'): util.msg("Please edit configuration/connection/logging "\ "settings in %r before proceeding." % config_file) -def revision(config, message=None): +def revision(config, message=None, autogenerate=False): """Create a new revision file.""" script = ScriptDirectory.from_config(config) - script.generate_rev(util.rev_id(), message) + template_args = {} + if autogenerate: + def retrieve_migrations(rev): + if script._get_rev(rev) is not script._get_rev("head"): + raise util.CommandError("Target database is not up to date.") + autogen.produce_migration_diffs(template_args) + return [] + + context._opts( + config, + script, + fn = retrieve_migrations + ) + script.run_env() + script.generate_rev(util.rev_id(), message, **template_args) + def upgrade(config, revision, sql=False, tag=None): """Upgrade to a later version.""" diff --git a/alembic/config.py b/alembic/config.py index e6867495..0c2fe194 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -137,6 +137,12 @@ def main(argv): type=str, help="Arbitrary 'tag' name - can be used by " "custom env.py scripts.") + if 'autogenerate' in kwargs: + parser.add_argument("--autogenerate", + action="store_true", + help="Populate revision script with candidate " + "migration operations, based on comparison of database to model.") + # TODO: # --dialect - name of dialect when --sql mode is set - *no DB connections diff --git a/alembic/context.py b/alembic/context.py index 20491f39..7fee2b7a 100644 --- a/alembic/context.py +++ b/alembic/context.py @@ -260,7 +260,10 @@ def configure( transactional_ddl=None, output_buffer=None, starting_rev=None, - tag=None + tag=None, + autogenerate_metadata=None, + upgrade_token="upgrades", + downgrade_token="downgrades" ): """Configure the migration environment. @@ -297,7 +300,17 @@ def configure( ``--sql`` mode. :param tag: a string tag for usage by custom ``env.py`` scripts. Set via the ``--tag`` option, can be overridden here. - + :param autogenerate_metadata: a :class:`sqlalchemy.schema.MetaData` object that + will be consulted if the ``--autogenerate`` option is passed to the + "alembic revision" command. The tables present will be compared against + what is locally available on the target :class:`~sqlalchemy.engine.base.Connection` + to produce candidate upgrade/downgrade operations. + :param upgrade_token: when running "alembic revision" with the ``--autogenerate`` + option, the text of the candidate upgrade operations will be present in this + template variable when script.py.mako is rendered. + :param downgrade_token: when running "alembic revision" with the ``--autogenerate`` + option, the text of the candidate downgrade operations will be present in this + template variable when script.py.mako is rendered. """ if connection: @@ -323,6 +336,9 @@ def configure( opts['starting_rev'] = starting_rev if tag: opts['tag'] = tag + opts['autogenerate_metadata'] = autogenerate_metadata + opts['upgrade_token'] = upgrade_token + opts['downgrade_token'] = downgrade_token _context = Context( dialect, _script, connection, opts['fn'], diff --git a/alembic/script.py b/alembic/script.py index 4c4b3659..e36b888c 100644 --- a/alembic/script.py +++ b/alembic/script.py @@ -180,7 +180,7 @@ class ScriptDirectory(object): shutil.copy, src, dest) - def generate_rev(self, revid, message): + def generate_rev(self, revid, message, **kw): current_head = self._current_head() path = self._rev_path(revid) self.generate_template( @@ -189,7 +189,8 @@ class ScriptDirectory(object): up_revision=str(revid), down_revision=current_head, create_date=datetime.datetime.now(), - message=message if message is not None else ("empty message") + message=message if message is not None else ("empty message"), + **kw ) script = Script.from_path(path) self._revision_map[script.revision] = script diff --git a/alembic/templates/generic/env.py b/alembic/templates/generic/env.py index 423c53e6..5095f0be 100644 --- a/alembic/templates/generic/env.py +++ b/alembic/templates/generic/env.py @@ -10,6 +10,12 @@ config = context.config # This line sets up loggers basically. fileConfig(config.config_file_name) +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# autogenerate_metadata = mymodel.Base.metadata +autogenerate_metadata = None + # other values from the config, defined by the needs of env.py, # can be acquired: # my_important_option = config.get_main_option("my_important_option") @@ -47,7 +53,10 @@ def run_migrations_online(): config.get_section(config.config_ini_section), prefix='sqlalchemy.') connection = engine.connect() - context.configure(connection=connection, dialect_name=engine.name) + context.configure( + connection=connection, + autogenerate_metadata=autogenerate_metadata + ) trans = connection.begin() try: diff --git a/alembic/templates/generic/script.py.mako b/alembic/templates/generic/script.py.mako index 128e3452..c6bf4041 100644 --- a/alembic/templates/generic/script.py.mako +++ b/alembic/templates/generic/script.py.mako @@ -12,7 +12,7 @@ down_revision = ${repr(down_revision)} from alembic.op import * def upgrade(): - pass + ${upgrades if upgrades else "pass"} def downgrade(): - pass + ${downgrades if downgrades else "pass"} diff --git a/alembic/templates/multidb/env.py b/alembic/templates/multidb/env.py index 1ffa597c..1ace59cd 100644 --- a/alembic/templates/multidb/env.py +++ b/alembic/templates/multidb/env.py @@ -9,9 +9,23 @@ import logging logging.fileConfig(options.config_file) # gather section names referring to different -# databases. +# databases. These are named "engine1", "engine2" +# in the sample .ini file. db_names = options.get_main_option('databases') +# add your model's MetaData objects here +# for 'autogenerate' support. These must be set +# up to hold just those tables targeting a +# particular database. table.tometadata() may be +# helpful here in case a "copy" of +# a MetaData is needed. +# from myapp import mymodel +# autogenerate_metadata = { +# 'engine1':mymodel.metadata1, +# 'engine2':mymodel.metadata2 +#} +autogenerate_metadata = {} + def run_migrations_offline(): """Run migrations in 'offline' mode. @@ -71,7 +85,9 @@ def run_migrations_online(): for name, rec in engines.items(): context.configure( connection=rec['connection'], - dialect_name=rec['engine'].name + upgrade_token="%s_upgrades", + downgrade_token="%s_downgrades", + autogenerate_metadata=autogenerate_metadata.get(name) ) context.execute("--running migrations for engine %s" % name) context.run_migrations(engine=name) diff --git a/alembic/templates/multidb/script.py.mako b/alembic/templates/multidb/script.py.mako index 5526f40e..f333d649 100644 --- a/alembic/templates/multidb/script.py.mako +++ b/alembic/templates/multidb/script.py.mako @@ -19,9 +19,11 @@ def downgrade(engine): % for engine in ["engine1", "engine2"]: - def upgrade_${engine}(): - pass - def downgrade_${engine}(): - pass +def upgrade_${engine}(): + ${context.get("%s_upgrades" % engine, "pass")} + +def downgrade_${engine}(): + ${context.get("%s_downgrades" % engine, "pass")} + % endfor \ No newline at end of file diff --git a/alembic/templates/pylons/env.py b/alembic/templates/pylons/env.py index 2726f906..da193f46 100644 --- a/alembic/templates/pylons/env.py +++ b/alembic/templates/pylons/env.py @@ -23,6 +23,12 @@ except: # customize this section for non-standard engine configurations. meta = __import__("%s.model.meta" % config['pylons.package']).model.meta +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# autogenerate_metadata = mymodel.Base.metadata +autogenerate_metadata = None + def run_migrations_offline(): """Run migrations in 'offline' mode. @@ -47,7 +53,10 @@ def run_migrations_online(): """ connection = meta.engine.connect() - context.configure_connection(connection) + context.configure( + connection=connection, + autogenerate_metadata=autogenerate_metadata + ) trans = connection.begin() try: context.run_migrations() diff --git a/alembic/templates/pylons/script.py.mako b/alembic/templates/pylons/script.py.mako index 128e3452..c6bf4041 100644 --- a/alembic/templates/pylons/script.py.mako +++ b/alembic/templates/pylons/script.py.mako @@ -12,7 +12,7 @@ down_revision = ${repr(down_revision)} from alembic.op import * def upgrade(): - pass + ${upgrades if upgrades else "pass"} def downgrade(): - pass + ${downgrades if downgrades else "pass"} diff --git a/tests/__init__.py b/tests/__init__.py index 1157aafe..720e5fba 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -48,6 +48,13 @@ def capture_context_buffer(): return capture() +def eq_ignore_whitespace(a, b, msg=None): + a = re.sub(r'^\s+?|\n', "", a) + a = re.sub(r' {2,}', " ", a) + b = re.sub(r'^\s+?|\n', "", b) + b = re.sub(r' {2,}', " ", b) + assert a == b, msg or "%r != %r" % (a, b) + def eq_(a, b, msg=None): """Assert a == b, with repr messaging on failure.""" assert a == b, msg or "%r != %r" % (a, b) @@ -112,10 +119,24 @@ def _op_fixture(dialect='default', as_sql=False): ) return ctx(dialect, as_sql) +def _env_file_fixture(txt): + dir_ = os.path.join(staging_directory, 'scripts') + txt = """ +from alembic import context + +config = context.config +""" + txt + + path = os.path.join(dir_, "env.py") + pyc_path = util.pyc_file_from_path(path) + if os.access(pyc_path, os.F_OK): + os.unlink(pyc_path) + + file(path, 'w').write(txt) + def _sqlite_testing_config(): - cfg = _testing_config() dir_ = os.path.join(staging_directory, 'scripts') - open(cfg.config_file_name, 'w').write(""" + return _write_config_file(""" [alembic] script_location = %s sqlalchemy.url = sqlite:///%s/foo.db @@ -144,29 +165,12 @@ keys = generic format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S """ % (dir_, dir_)) - return cfg - -def _env_file_fixture(txt): - dir_ = os.path.join(staging_directory, 'scripts') - txt = """ -from alembic import context - -config = context.config -""" + txt - - path = os.path.join(dir_, "env.py") - pyc_path = util.pyc_file_from_path(path) - if os.access(pyc_path, os.F_OK): - os.unlink(pyc_path) - - file(path, 'w').write(txt) def _no_sql_testing_config(): """use a postgresql url with no host so that connections guaranteed to fail""" - cfg = _testing_config() dir_ = os.path.join(staging_directory, 'scripts') - open(cfg.config_file_name, 'w').write(""" + return _write_config_file(""" [alembic] script_location = %s sqlalchemy.url = postgresql:// @@ -196,6 +200,10 @@ format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S """ % (dir_)) + +def _write_config_file(text): + cfg = _testing_config() + open(cfg.config_file_name, 'w').write(text) return cfg def sqlite_db(): @@ -205,7 +213,7 @@ def sqlite_db(): dir_ = os.path.join(staging_directory, 'scripts') return create_engine('sqlite:///%s/foo.db' % dir_) -def staging_env(create=True): +def staging_env(create=True, template="generic"): from alembic import command, script cfg = _testing_config() if create: diff --git a/tests/test_autogenerate.py b/tests/test_autogenerate.py new file mode 100644 index 00000000..82e6b4bc --- /dev/null +++ b/tests/test_autogenerate.py @@ -0,0 +1,136 @@ +from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \ + Numeric, CHAR, NUMERIC, ForeignKey, DATETIME +from alembic import autogenerate +from unittest import TestCase +from tests import staging_env, sqlite_db, clear_staging_env, eq_, eq_ignore_whitespace + +def _model_one(): + m = MetaData() + + Table('user', m, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('a1', Text), + Column("pw", String(50)) + ) + + Table('address', m, + Column('id', Integer, primary_key=True), + Column('email_address', String(100), nullable=False), + ) + + Table('order', m, + Column('order_id', Integer, primary_key=True), + Column("amount", Numeric(8, 2), nullable=False) + ) + + Table('extra', m, + Column("x", CHAR) + ) + + return m + +def _model_two(): + m = MetaData() + + Table('user', m, + Column('id', Integer, primary_key=True), + Column('name', String(50), nullable=False), + Column('a1', Text), + ) + + Table('address', m, + Column('id', Integer, primary_key=True), + Column('email_address', String(100), nullable=False), + Column('street', String(50)) + ) + + Table('order', m, + Column('order_id', Integer, primary_key=True), + Column("amount", Numeric(10, 2), nullable=True) + ) + + Table('item', m, + Column('id', Integer, primary_key=True), + Column('description', String(100)) + ) + return m + +class AutogenerateDiffTest(TestCase): + @classmethod + def setup_class(cls): + env = staging_env() + cls.bind = sqlite_db() + cls.m1 = _model_one() + cls.m1.create_all(cls.bind) + + @classmethod + def teardown_class(cls): + clear_staging_env() + + def test_diffs(self): + metadata = _model_two() + connection = self.bind.connect() + diffs = [] + autogenerate._produce_net_changes(connection, metadata, diffs) + eq_(repr(diffs[5][3]), "Numeric(precision=10, scale=2)") + eq_(repr(diffs[6][3]), "NUMERIC(precision=8, scale=2)") + del diffs[5] + del diffs[5] + eq_( + diffs, + [ + ('upgrade_table', metadata.tables['item']), + ('downgrade_table', u'extra'), + ('downgrade_column', 'user', u'pw'), + ('upgrade_nullable', 'user', 'name', False), + ('downgrade_nullable', 'user', 'name', True), + ('upgrade_nullable', 'order', u'amount', True), + ('downgrade_nullable', 'order', u'amount', False), + ('upgrade_column', 'address', + metadata.tables['address'].c.street) + ] + ) + +class AutogenRenderTest(TestCase): + def test_render_table_upgrade(self): + m = MetaData() + t = Table('test', m, + Column('id', Integer, primary_key=True), + Column("address_id", Integer, ForeignKey("address.id")), + Column("timestamp", DATETIME, server_default="NOW()"), + Column("amount", Numeric(5, 2)), + ) + eq_ignore_whitespace( + autogenerate._upgrade_table(t), + "create_table('test', " + "Column('id', Integer(), nullable=False)," + "Column('address_id', Integer(), nullable=True)," + "Column('timestamp', DATETIME(), " + "server_default=DefaultClause('NOW()', for_update=False), " + "nullable=True)," + "Column('amount', Numeric(precision=5, scale=2), nullable=True)," + "ForeignKeyConstraint([address_id], ['address.id'], )," + "PrimaryKeyConstraint(id)" + " )" + ) + + def test_render_table_downgrade(self): + eq_( + autogenerate._downgrade_table("sometable"), + "drop_table('sometable')" + ) + + def test_render_type_upgrade(self): + eq_( + autogenerate._up_or_downgrade_type( + "sometable", "somecolumn", CHAR(10)), + "alter_column('sometable', 'somecolumn', type=CHAR(length=10))" + ) + + def test_render_nullable_upgrade(self): + eq_( + autogenerate._up_or_downgrade_nullable( + "sometable", "somecolumn", True), + "alter_column('sometable', 'somecolumn', nullable=True)" + )