From: Mike Bayer Date: Mon, 28 Nov 2011 07:24:18 +0000 (-0500) Subject: - rework MySQL + autogenerate so that X-Git-Tag: rel_0_1_0~29 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2bf20596567ba13cec357ae51d496fbd5fff69b7;p=thirdparty%2Fsqlalchemy%2Falembic.git - rework MySQL + autogenerate so that multiple changes to a single col are collapsed into one step, will work for other dialects which may support this too - add support for "imports" in scripts so that dialect-specific types can be rendered straight in from their parent module and work immediately - rework the internals of autogenerate to be more succinct, though there's a lot more that could happen here to make this easier --- diff --git a/alembic/autogenerate.py b/alembic/autogenerate.py index 89e9f904..4f81dd7d 100644 --- a/alembic/autogenerate.py +++ b/alembic/autogenerate.py @@ -22,9 +22,11 @@ def produce_migration_diffs(template_args): "a MetaData object to the context.") connection = get_bind() diffs = [] + imports = set() _produce_net_changes(connection, metadata, diffs) - _set_upgrade(template_args, _indent(_produce_upgrade_commands(diffs))) - _set_downgrade(template_args, _indent(_produce_downgrade_commands(diffs))) + _set_upgrade(template_args, _indent(_produce_upgrade_commands(diffs, imports))) + _set_downgrade(template_args, _indent(_produce_downgrade_commands(diffs, imports))) + template_args['imports'] = "\n".join(sorted(imports)) def _set_upgrade(template_args, text): template_args[_context_opts['upgrade_token']] = text @@ -110,21 +112,24 @@ def _compare_columns(tname, conn_table, metadata_table, diffs): for colname in metadata_col_names.intersection(conn_col_names): metadata_col = metadata_table.c[colname] conn_col = conn_table[colname] + col_diff = [] _compare_type(tname, colname, conn_col, metadata_col.type, - diffs + col_diff ) _compare_nullable(tname, colname, conn_col, metadata_col.nullable, - diffs + col_diff ) _compare_server_default(tname, colname, conn_col, metadata_col.server_default, - diffs + col_diff ) + if col_diff: + diffs.append(col_diff) def _compare_nullable(tname, cname, conn_col, metadata_col_nullable, diffs): @@ -209,67 +214,84 @@ _type_comparators = { ################################################### # produce command structure -def _produce_upgrade_commands(diffs): +def _produce_upgrade_commands(diffs, imports): buf = [] for diff in diffs: - buf.append(_invoke_command("upgrade", diff)) + buf.append(_invoke_command("upgrade", diff, imports)) return "\n".join(buf) -def _produce_downgrade_commands(diffs): +def _produce_downgrade_commands(diffs, imports): buf = [] for diff in diffs: - buf.append(_invoke_command("downgrade", diff)) + buf.append(_invoke_command("downgrade", diff, imports)) return "\n".join(buf) -def _invoke_command(updown, args): +def _invoke_command(updown, args, imports): + if isinstance(args, tuple): + return _invoke_adddrop_command(updown, args, imports) + else: + return _invoke_modify_command(updown, args, imports) + +def _invoke_adddrop_command(updown, args, imports): cmd_type = args[0] adddrop, cmd_type = cmd_type.split("_") - cmd_args = args[1:] + cmd_args = args[1:] + (imports,) - # TODO: MySQL really blew this up - # so try to clean this up _commands = { "table":(_drop_table, _add_table), "column":(_drop_column, _add_column), - "type":(_modify_type,), - "nullable":(_modify_nullable,), - "default":(_modify_server_default,), } cmd_callables = _commands[cmd_type] - if len(cmd_callables) == 2: - if ( - updown == "upgrade" and adddrop == "add" - ) or ( - updown == "downgrade" and adddrop == "remove" - ): - return cmd_callables[1](*cmd_args) - else: - return cmd_callables[0](*cmd_args) + if ( + updown == "upgrade" and adddrop == "add" + ) or ( + updown == "downgrade" and adddrop == "remove" + ): + return cmd_callables[1](*cmd_args) else: - tname, cname = cmd_args[0:2] - args = [] + return cmd_callables[0](*cmd_args) + +def _invoke_modify_command(updown, args, imports): + tname, cname = args[0][1:3] + kw = {} + + _arg_struct = { + "modify_type":("existing_type", "type_"), + "modify_nullable":("existing_nullable", "nullable"), + "modify_default":("existing_server_default", "server_default"), + } + for diff in args: + diff_kw = diff[3] for arg in ("existing_type", \ "existing_nullable", \ "existing_server_default"): - if arg in cmd_args[2]: - args.append(cmd_args[2][arg]) + if arg in diff_kw: + kw.setdefault(arg, diff_kw[arg]) + old_kw, new_kw = _arg_struct[diff[0]] if updown == "upgrade": - args += (cmd_args[-1], cmd_args[-2]) + kw[new_kw] = diff[-1] + kw[old_kw] = diff[-2] else: - args += cmd_args[-2:] - return cmd_callables[0](tname, cname, *args) + kw[new_kw] = diff[-2] + kw[old_kw] = diff[-1] + + if "nullable" in kw: + kw.pop("existing_nullable", None) + if "server_default" in kw: + kw.pop("existing_server_default", None) + return _modify_col(tname, cname, imports, **kw) ################################################### # render python -def _add_table(table): +def _add_table(table, imports): return "create_table(%(tablename)r,\n%(args)s\n)" % { 'tablename':table.name, 'args':',\n'.join( - [_render_column(col) for col in table.c] + + [_render_column(col, imports) for col in table.c] + sorted([rcons for rcons in [_render_constraint(cons) for cons in table.constraints] @@ -278,74 +300,46 @@ def _add_table(table): ), } -def _drop_table(table): +def _drop_table(table, imports): return "drop_table(%r)" % table.name -def _add_column(tname, column): +def _add_column(tname, column, imports): return "add_column(%r, %s)" % ( tname, - _render_column(column)) + _render_column(column, imports)) -def _drop_column(tname, column): +def _drop_column(tname, column, imports): return "drop_column(%r, %r)" % (tname, column.name) -def _modify_type(tname, cname, - existing_nullable, - existing_server_default, - type_, existing_type): - return _modify_col(tname, cname, - existing_server_default=existing_server_default, - existing_nullable=existing_nullable, - existing_type=existing_type, - type_=type_ - ) - return text - -def _modify_nullable(tname, cname, - existing_type, - existing_server_default, - nullable, previous): - return _modify_col(tname, cname, - existing_type=existing_type, - existing_server_default=existing_server_default, - existing_nullable=previous, - nullable=nullable - ) - return text - -def _modify_server_default(tname, cname, - existing_type, - existing_nullable, - server_default, prev_default): - return _modify_col(tname, cname, - server_default=server_default, - existing_nullable=existing_nullable, - existing_type=existing_type, - ) - -def _modify_col(tname, cname, existing_type, - server_default=None, +def _modify_col(tname, cname, + imports, + server_default=False, type_=None, nullable=None, + existing_type=None, existing_nullable=None, - existing_server_default=None): + existing_server_default=False): prefix = _autogenerate_prefix() indent = " " * 11 text = "alter_column(%r, %r" % (tname, cname) - text += ", \n%sexisting_type=%s%r" % (indent, prefix, existing_type,) - if server_default: + text += ", \n%sexisting_type=%s" % (indent, + _repr_type(prefix, existing_type, imports)) + if server_default is not False: text += ", \n%sserver_default=%s" % (indent, _render_server_default(server_default),) if type_ is not None: - text += ", \n%stype_=%s%r" % (indent, prefix, type_) + text += ", \n%stype_=%s" % (indent, _repr_type(prefix, type_, imports)) if nullable is not None: - text += ", \n%snullable=%r" % (indent, nullable,) + text += ", \n%snullable=%r" % ( + indent, nullable,) if existing_nullable is not None: - text += ", \n%sexisting_nullable=%r" % (indent, existing_nullable) + text += ", \n%sexisting_nullable=%r" % ( + indent, existing_nullable) if existing_server_default: text += ", \n%sexisting_server_default=%s" % ( indent, - _render_server_default(existing_server_default), + _render_server_default( + existing_server_default), ) text += ")" return text @@ -353,18 +347,19 @@ def _modify_col(tname, cname, existing_type, def _autogenerate_prefix(): return _context_opts['autogenerate_sqlalchemy_prefix'] -def _render_column(column): +def _render_column(column, imports): opts = [] if column.server_default: - opts.append(("server_default", _render_server_default(column.server_default))) + opts.append(("server_default", + _render_server_default(column.server_default))) if column.nullable is not None: opts.append(("nullable", column.nullable)) # TODO: for non-ascii colname, assign a "key" - return "%(prefix)sColumn(%(name)r, %(prefix)s%(type)r, %(kw)s)" % { + return "%(prefix)sColumn(%(name)r, %(type)s, %(kw)s)" % { 'prefix':_autogenerate_prefix(), 'name':column.name, - 'type':column.type, + 'type':_repr_type(_autogenerate_prefix(), column.type, imports), 'kw':", ".join(["%s=%s" % (kwname, val) for kwname, val in opts]) } @@ -383,6 +378,15 @@ def _render_server_default(default): else: return None +def _repr_type(prefix, type_, imports): + mod = type(type_).__module__ + if mod.startswith("sqlalchemy.dialects"): + dname = re.match(r"sqlalchemy\.dialects\.(\w+)", mod).group(1) + imports.add("from sqlalchemy.dialects import %s" % dname) + return "%s.%r" % (dname, type_) + else: + return "%s%r" % (prefix, type_) + def _render_constraint(constraint): renderer = _constraint_renderers.get(type(constraint), None) if renderer: diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py index 7797c759..00cf7ab5 100644 --- a/alembic/ddl/mysql.py +++ b/alembic/ddl/mysql.py @@ -1,63 +1,76 @@ from alembic.ddl.impl import DefaultImpl -from alembic.ddl.base import ColumnNullable, ColumnName, ColumnDefault, ColumnType +from alembic.ddl.base import ColumnNullable, ColumnName, ColumnDefault, ColumnType, AlterColumn from sqlalchemy.ext.compiler import compiles from alembic.ddl.base import alter_table +from alembic import util +from sqlalchemy import types as sqltypes class MySQLImpl(DefaultImpl): __dialect__ = 'mysql' + def alter_column(self, table_name, column_name, + nullable=None, + server_default=False, + name=None, + type_=None, + schema=None, + existing_type=None, + existing_server_default=None, + existing_nullable=None + ): + self._exec( + MySQLAlterColumn( + table_name, column_name, + schema=schema, + newname=name if name is not None else column_name, + nullable =nullable if nullable is not None else + existing_nullable if existing_nullable is not None + else True, + type_=type_ if type_ is not None else existing_type, + default=server_default if server_default is not False else existing_server_default, + ) + ) -@compiles(ColumnNullable, 'mysql') -def _change_column_nullable(element, compiler, **kw): - return _mysql_change( - element, compiler, - nullable=element.nullable, - ) +class MySQLAlterColumn(AlterColumn): + def __init__(self, name, column_name, schema=None, + newname=None, + type_=None, + nullable=None, + default=False): + super(AlterColumn, self).__init__(name, schema=schema) + self.column_name = column_name + self.nullable = nullable + self.newname = newname + self.default = default + if type_ is None: + raise util.CommandError( + "All MySQL ALTER COLUMN operations " + "require the existing type." + ) -@compiles(ColumnName, 'mysql') -def _change_column_name(element, compiler, **kw): - return _mysql_change( - element, compiler, - name=element.newname, - ) + self.type_ = sqltypes.to_instance(type_) +@compiles(ColumnNullable, 'mysql') +@compiles(ColumnName, 'mysql') @compiles(ColumnDefault, 'mysql') -def _change_column_default(element, compiler, **kw): - return _mysql_change( - element, compiler, - server_default=element.default, - ) - @compiles(ColumnType, 'mysql') -def _change_column_type(element, compiler, **kw): - return _mysql_change( - element, compiler, - type_=element.type_ - ) +def _mysql_doesnt_support_individual(element, compiler, **kw): + raise NotImplementedError( + "Individual alter column constructs not supported by MySQL" + ) + -def _mysql_change(element, compiler, nullable=None, - server_default=False, type_=None, - name=None): - if name is None: - name = element.column_name - if nullable is None: - nullable=True - if server_default is False: - server_default = element.existing_server_default - if type_ is None: - if element.existing_type is None: - raise util.CommandError("All MySQL column alterations " - "require the existing type") - type_ = element.existing_type +@compiles(MySQLAlterColumn, "mysql") +def _mysql_alter_column(element, compiler, **kw): return "%s CHANGE %s %s" % ( alter_table(compiler, element.table_name, element.schema), element.column_name, _mysql_colspec( compiler, - name=name, - nullable=nullable, - server_default=server_default, - type_=type_ + name=element.newname, + nullable=element.nullable, + server_default=element.default, + type_=element.type_ ), ) diff --git a/alembic/templates/generic/script.py.mako b/alembic/templates/generic/script.py.mako index 369c8374..b17c3ad3 100644 --- a/alembic/templates/generic/script.py.mako +++ b/alembic/templates/generic/script.py.mako @@ -11,6 +11,7 @@ down_revision = ${repr(down_revision)} from alembic.op import * import sqlalchemy as sa +${imports if imports else ""} def upgrade(): ${upgrades if upgrades else "pass"} diff --git a/alembic/templates/multidb/script.py.mako b/alembic/templates/multidb/script.py.mako index b3b5da2d..b9e606ca 100644 --- a/alembic/templates/multidb/script.py.mako +++ b/alembic/templates/multidb/script.py.mako @@ -11,6 +11,7 @@ down_revision = ${repr(down_revision)} from alembic.op import * import sqlalchemy as sa +${imports if imports else ""} def upgrade(engine): eval("upgrade_%s" % engine.name)() diff --git a/alembic/templates/pylons/script.py.mako b/alembic/templates/pylons/script.py.mako index 369c8374..b17c3ad3 100644 --- a/alembic/templates/pylons/script.py.mako +++ b/alembic/templates/pylons/script.py.mako @@ -11,6 +11,7 @@ down_revision = ${repr(down_revision)} from alembic.op import * import sqlalchemy as sa +${imports if imports else ""} def upgrade(): ${upgrades if upgrades else "pass"} diff --git a/tests/test_autogenerate.py b/tests/test_autogenerate.py index b1388927..7fd6ea4f 100644 --- a/tests/test_autogenerate.py +++ b/tests/test_autogenerate.py @@ -92,37 +92,40 @@ class AutogenerateDiffTest(TestCase): eq_(diffs[2][0], 'remove_column') eq_(diffs[2][2].name, 'pw') - eq_(diffs[3][0], "modify_default") - eq_(diffs[3][1], "user") - eq_(diffs[3][2], "a1") - eq_(diffs[3][5].arg, "x") + eq_(diffs[3][0][0], "modify_default") + eq_(diffs[3][0][1], "user") + eq_(diffs[3][0][2], "a1") + eq_(diffs[3][0][5].arg, "x") - eq_(diffs[4][0], 'modify_nullable') - eq_(diffs[4][4], True) - eq_(diffs[4][5], False) + eq_(diffs[4][0][0], 'modify_nullable') + eq_(diffs[4][0][4], True) + eq_(diffs[4][0][5], False) eq_(diffs[5][0], "add_column") eq_(diffs[5][1], "order") eq_(diffs[5][2], metadata.tables['order'].c.user_id) - eq_(diffs[6][0], "modify_type") - eq_(diffs[6][1], "order") - eq_(diffs[6][2], "amount") - eq_(repr(diffs[6][4]), "NUMERIC(precision=8, scale=2)") - eq_(repr(diffs[6][5]), "Numeric(precision=10, scale=2)") + eq_(diffs[6][0][0], "modify_type") + eq_(diffs[6][0][1], "order") + eq_(diffs[6][0][2], "amount") + eq_(repr(diffs[6][0][4]), "NUMERIC(precision=8, scale=2)") + eq_(repr(diffs[6][0][5]), "Numeric(precision=10, scale=2)") - eq_(diffs[7][0], 'modify_nullable') - eq_(diffs[7][4], False) - eq_(diffs[7][5], True) + eq_(diffs[6][1][0], 'modify_nullable') + eq_(diffs[6][1][4], False) + eq_(diffs[6][1][5], True) - eq_(diffs[8][0], "add_column") - eq_(diffs[8][1], "address") - eq_(diffs[8][2], metadata.tables['address'].c.street) + eq_(diffs[7][0], "add_column") + eq_(diffs[7][1], "address") + eq_(diffs[7][2], metadata.tables['address'].c.street) def test_render_diffs(self): """test a full render including indentation""" + # TODO: this test isn't going + # to be so spectacular on Py3K... + metadata = _model_two() connection = self.bind.connect() template_args = {} @@ -130,9 +133,6 @@ class AutogenerateDiffTest(TestCase): connection=connection, autogenerate_metadata=metadata) autogenerate.produce_migration_diffs(template_args) - - print template_args['upgrades'] - return eq_(template_args['upgrades'], """### commands auto generated by Alembic - please adjust! ### create_table('item', @@ -146,7 +146,8 @@ class AutogenerateDiffTest(TestCase): drop_column('user', u'pw') alter_column('user', 'a1', existing_type=sa.TEXT(), - server_default='x') + server_default='x', + existing_nullable=True) alter_column('user', 'name', existing_type=sa.VARCHAR(length=50), nullable=False) @@ -154,9 +155,6 @@ class AutogenerateDiffTest(TestCase): alter_column('order', u'amount', existing_type=sa.NUMERIC(precision=8, scale=2), type_=sa.Numeric(precision=10, scale=2), - existing_server_default='0') - alter_column('order', u'amount', - existing_type=sa.NUMERIC(precision=8, scale=2), nullable=True, existing_server_default='0') add_column('address', sa.Column('street', sa.String(length=50), nullable=True)) @@ -171,7 +169,9 @@ class AutogenerateDiffTest(TestCase): ) add_column('user', sa.Column(u'pw', sa.VARCHAR(length=50), nullable=True)) alter_column('user', 'a1', - existing_type=sa.TEXT()) + existing_type=sa.TEXT(), + server_default=None, + existing_nullable=True) alter_column('user', 'name', existing_type=sa.VARCHAR(length=50), nullable=True) @@ -179,9 +179,6 @@ class AutogenerateDiffTest(TestCase): alter_column('order', u'amount', existing_type=sa.Numeric(precision=10, scale=2), type_=sa.NUMERIC(precision=8, scale=2), - existing_server_default='0') - alter_column('order', u'amount', - existing_type=sa.NUMERIC(precision=8, scale=2), nullable=False, existing_server_default='0') drop_column('address', 'street') @@ -203,7 +200,7 @@ class AutogenRenderTest(TestCase): Column("amount", Numeric(5, 2)), ) eq_ignore_whitespace( - autogenerate._add_table(t), + autogenerate._add_table(t, set()), "create_table('test'," "sa.Column('id', sa.Integer(), nullable=False)," "sa.Column('address_id', sa.Integer(), nullable=True)," @@ -218,14 +215,14 @@ class AutogenRenderTest(TestCase): def test_render_drop_table(self): eq_( - autogenerate._drop_table(Table("sometable", MetaData())), + autogenerate._drop_table(Table("sometable", MetaData()), set()), "drop_table('sometable')" ) def test_render_add_column(self): eq_( autogenerate._add_column( - "foo", Column("x", Integer, server_default="5")), + "foo", Column("x", Integer, server_default="5"), set()), "add_column('foo', sa.Column('x', sa.Integer(), " "server_default='5', nullable=True))" ) @@ -233,37 +230,43 @@ class AutogenRenderTest(TestCase): def test_render_drop_column(self): eq_( autogenerate._drop_column( - "foo", Column("x", Integer, server_default="5")), + "foo", Column("x", Integer, server_default="5"), set()), + "drop_column('foo', 'x')" ) def test_render_modify_type(self): eq_ignore_whitespace( - autogenerate._modify_type( + autogenerate._modify_col( "sometable", "somecolumn", - None, None, - CHAR(10), CHAR(20)), + set(), + type_=CHAR(10), existing_type=CHAR(20)), "alter_column('sometable', 'somecolumn', " "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))" ) def test_render_modify_nullable(self): eq_ignore_whitespace( - autogenerate._modify_nullable( - "sometable", "somecolumn", Integer(), - None, - True, None), + autogenerate._modify_col( + "sometable", "somecolumn", + set(), + existing_type=Integer(), + nullable=True), "alter_column('sometable', 'somecolumn', " "existing_type=sa.Integer(), nullable=True)" ) def test_render_modify_nullable_w_default(self): eq_ignore_whitespace( - autogenerate._modify_nullable( - "sometable", "somecolumn", Integer(), - "5", - True, None), + autogenerate._modify_col( + "sometable", "somecolumn", + set(), + existing_type=Integer(), + existing_server_default="5", + nullable=True), "alter_column('sometable', 'somecolumn', " "existing_type=sa.Integer(), nullable=True, " "existing_server_default='5')" ) + +# TODO: tests for dialect-specific type rendering + imports diff --git a/tests/test_mysql.py b/tests/test_mysql.py index 8090d8cb..4951aba8 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -1,5 +1,5 @@ -from tests import _op_fixture -from alembic import op +from tests import _op_fixture, assert_raises_message +from alembic import op, util from sqlalchemy import Integer, Column, ForeignKey, \ UniqueConstraint, Table, MetaData, String from sqlalchemy.sql import table @@ -25,3 +25,18 @@ def test_col_nullable(): 'ALTER TABLE t1 CHANGE c1 c1 INTEGER NOT NULL' ) +def test_col_multi_alter(): + context = _op_fixture('mysql') + op.alter_column('t1', 'c1', nullable=False, server_default="q", type_=Integer) + context.assert_( + "ALTER TABLE t1 CHANGE c1 c1 INTEGER NOT NULL DEFAULT 'q'" + ) + + +def test_col_alter_type_required(): + context = _op_fixture('mysql') + assert_raises_message( + util.CommandError, + "All MySQL ALTER COLUMN operations require the existing type.", + op.alter_column, 't1', 'c1', nullable=False, server_default="q" + )