]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
more tests and now its sort of working
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 27 Nov 2011 22:43:31 +0000 (17:43 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 27 Nov 2011 22:43:31 +0000 (17:43 -0500)
alembic/autogenerate.py
alembic/context.py
alembic/script.py
alembic/templates/generic/script.py.mako
alembic/templates/multidb/script.py.mako
alembic/templates/pylons/script.py.mako
tests/__init__.py
tests/test_autogenerate.py
tests/test_revision_create.py
tests/test_revision_paths.py
tests/test_versioning.py

index e017bb10d88bf350b9e2e6447ce021e42a9bf343..2eb6e71894376f3c4d1f2b8281aaa6affa567e7a 100644 (file)
@@ -5,6 +5,7 @@ from alembic.context import _context_opts, get_bind
 from alembic import util
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy import types as sqltypes, schema
+import re
 
 ###################################################
 # top level
@@ -19,8 +20,8 @@ def produce_migration_diffs(template_args):
     connection = get_bind()
     diffs = []
     _produce_net_changes(connection, metadata, diffs)
-    _set_upgrade(template_args, _produce_upgrade_commands(diffs))
-    _set_downgrade(template_args, _produce_downgrade_commands(diffs))
+    _set_upgrade(template_args, _indent(_produce_upgrade_commands(diffs)))
+    _set_downgrade(template_args, _indent(_produce_downgrade_commands(diffs)))
 
 def _set_upgrade(template_args, text):
     template_args[_context_opts['upgrade_token']] = text
@@ -28,6 +29,12 @@ def _set_upgrade(template_args, text):
 def _set_downgrade(template_args, text):
     template_args[_context_opts['downgrade_token']] = text
 
+def _indent(text):
+    text = "### commands auto generated by Alembic - please adjust! ###\n" + text
+    text += "\n### end Alembic commands ###"
+    text = re.compile(r'^', re.M).sub("    ", text).strip()
+    return text
+
 ###################################################
 # walk structures
 
@@ -85,7 +92,12 @@ def _compare_columns(tname, conn_table, metadata_table, diffs):
         for cname in metadata_col_names.difference(conn_col_names)
     )
     diffs.extend(
-        ("remove_column", tname, cname)
+        ("remove_column", tname, schema.Column(
+            cname,
+            conn_table[cname]['type'],
+            nullable=conn_table[cname]['nullable'],
+            server_default=conn_table[cname]['default']
+        ))
         for cname in conn_col_names.difference(metadata_col_names)
     )
 
@@ -145,28 +157,49 @@ _type_comparators = {
 }
 
 ###################################################
-# render python
+# produce command structure
 
 def _produce_upgrade_commands(diffs):
     buf = []
     for diff in diffs:
-        cmd = _commands[diff[0]]
-        buf.append(cmd(*diff[1:]))
+        buf.append(_invoke_command("upgrade", diff))
     return "\n".join(buf)
 
 def _produce_downgrade_commands(diffs):
     buf = []
     for diff in diffs:
-        cmd = _commands[diff[0]]
-        buf.append(cmd(*diff[1:]))
+        buf.append(_invoke_command("downgrade", diff))
     return "\n".join(buf)
 
+def _invoke_command(updown, args):
+    cmd_type = args[0]
+    adddrop, cmd_type = cmd_type.split("_")
+
+    cmd_args = args[1:]
+    cmd_callables = _commands[cmd_type]
+
+    if len(cmd_callables) == 2:
+        if (
+            updown == "upgrade" and adddrop == "add"
+        ) or (
+            updown == "downgrade" and adddrop == "remove"
+        ):
+            return cmd_callables[1](*cmd_args)
+        else:
+            return cmd_callables[0](*cmd_args)
+    else:
+        if updown == "upgrade":
+            return cmd_callables[0](
+                    cmd_args[0], cmd_args[1], cmd_args[3])
+        else:
+            return cmd_callables[0](
+                    cmd_args[0], cmd_args[1], cmd_args[2])
+
+###################################################
+# render python
+
 def _add_table(table):
-    return \
-"""create_table(%(tablename)r, 
-        %(args)s
-    )
-""" % {
+    return "create_table(%(tablename)r,\n%(args)s\n)" % {
         'tablename':table.name,
         'args':',\n'.join(
             [_render_column(col) for col in table.c] +
@@ -178,16 +211,16 @@ def _add_table(table):
         ),
     }
 
-def _drop_table(tname):
-    return "drop_table(%r)" % tname
+def _drop_table(table):
+    return "drop_table(%r)" % table.name
 
 def _add_column(tname, column):
     return "add_column(%r, %s)" % (
             tname, 
             _render_column(column))
 
-def _drop_column(tname, cname):
-    return "drop_column(%r, %r)" % (tname, cname)
+def _drop_column(tname, column):
+    return "drop_column(%r, %r)" % (tname, column.name)
 
 def _modify_type(tname, cname, type_):
     return "alter_column(%r, %r, type=%r)" % (
@@ -200,22 +233,37 @@ def _modify_nullable(tname, cname, nullable):
     )
 
 _commands = {
+    "table":(_drop_table, _add_table),
+    "column":(_drop_column, _add_column),
+    "type":(_modify_type,),
+    "nullable":(_modify_nullable,),
 }
 
+def _autogenerate_prefix():
+    return _context_opts['autogenerate_sqlalchemy_prefix']
+
 def _render_column(column):
     opts = []
     if column.server_default:
-        opts.append(("server_default", column.server_default))
+        opts.append(("server_default", _render_server_default(column.server_default)))
     if column.nullable is not None:
         opts.append(("nullable", column.nullable))
 
     # TODO: for non-ascii colname, assign a "key"
-    return "Column(%(name)r, %(type)r, %(kw)s)" % {
+    return "%(prefix)sColumn(%(name)r, %(prefix)s%(type)r, %(kw)s)" % {
+        'prefix':_autogenerate_prefix(),
         'name':column.name,
         'type':column.type,
         'kw':", ".join(["%s=%s" % (kwname, val) for kwname, val in opts])
     }
 
+def _render_server_default(default):
+    assert isinstance(default, schema.DefaultClause)
+    return "%(prefix)sDefaultClause(%(arg)r)" % {
+                'prefix':_autogenerate_prefix(),
+                'arg':str(default.arg)
+            }
+
 def _render_constraint(constraint):
     renderer = _constraint_renderers.get(type(constraint), None)
     if renderer:
@@ -226,10 +274,11 @@ def _render_constraint(constraint):
 def _render_primary_key(constraint):
     opts = []
     if constraint.name:
-        opts.append(("name", constraint.name))
-    return "PrimaryKeyConstraint(%(args)s)" % {
+        opts.append(("name", repr(constraint.name)))
+    return "%(prefix)sPrimaryKeyConstraint(%(args)s)" % {
+        "prefix":_autogenerate_prefix(),
         "args":", ".join(
-            [c.key for c in constraint.columns] +
+            [repr(c.key) for c in constraint.columns] +
             ["%s=%s" % (kwname, val) for kwname, val in opts]
         ),
     }
@@ -237,9 +286,10 @@ def _render_primary_key(constraint):
 def _render_foreign_key(constraint):
     opts = []
     if constraint.name:
-        opts.append(("name", constraint.name))
+        opts.append(("name", repr(constraint.name)))
     # TODO: deferrable, initially, etc.
-    return "ForeignKeyConstraint([%(cols)s], [%(refcols)s], %(args)s)" % {
+    return "%(prefix)sForeignKeyConstraint([%(cols)s], [%(refcols)s], %(args)s)" % {
+        "prefix":_autogenerate_prefix(),
         "cols":", ".join(f.parent.key for f in constraint.elements),
         "refcols":", ".join(repr(f._get_colspec()) for f in constraint.elements),
         "args":", ".join(
@@ -250,8 +300,10 @@ def _render_foreign_key(constraint):
 def _render_check_constraint(constraint):
     opts = []
     if constraint.name:
-        opts.append(("name", constraint.name))
-    return "CheckConstraint('TODO')"
+        opts.append(("name", repr(constraint.name)))
+    return "%(prefix)sCheckConstraint('TODO')" % {
+            "prefix":_autogenerate_prefix()
+        }
 
 _constraint_renderers = {
     schema.PrimaryKeyConstraint:_render_primary_key,
index 7fee2b7a0a2949676135a0427062a5d2eb34c7fb..adba2300b7b87320fe8799e5d2f388d86ea19d8e 100644 (file)
@@ -149,7 +149,7 @@ _context = None
 _script = None
 
 def _opts(cfg, script, **kw):
-    """Set up options that will be used by the :func:`.configure_connection`
+    """Set up options that will be used by the :func:`.configure`
     function.
     
     This basically sets some global variables.
@@ -263,7 +263,8 @@ def configure(
         tag=None,
         autogenerate_metadata=None,
         upgrade_token="upgrades",
-        downgrade_token="downgrades"
+        downgrade_token="downgrades",
+        autogenerate_sqlalchemy_prefix="sa.",
     ):
     """Configure the migration environment.
     
@@ -311,6 +312,10 @@ def configure(
     :param downgrade_token: when running "alembic revision" with the ``--autogenerate``
      option, the text of the candidate downgrade operations will be present in this
      template variable when script.py.mako is rendered.
+    :param autogenerate_sqlalchemy_prefix: When autogenerate refers to SQLAlchemy 
+     :class:`~sqlalchemy.schema.Column` or type classes, this prefix will be used
+     (i.e. ``sa.Column("somename", sa.Integer)``)
+     
     """
 
     if connection:
@@ -339,6 +344,7 @@ def configure(
     opts['autogenerate_metadata'] = autogenerate_metadata
     opts['upgrade_token'] = upgrade_token
     opts['downgrade_token'] = downgrade_token
+    opts['autogenerate_sqlalchemy_prefix'] = autogenerate_sqlalchemy_prefix
     _context = Context(
                         dialect, _script, connection, 
                         opts['fn'],
index e36b888cec8e1d59242d1b23fae27c7f637aeabf..7d2ee46eeacb4bc781c5012fe8cbe9a7c0447b15 100644 (file)
@@ -180,7 +180,7 @@ class ScriptDirectory(object):
                     shutil.copy, 
                     src, dest)
 
-    def generate_rev(self, revid, message, **kw):
+    def generate_rev(self, revid, message, refresh=False, **kw):
         current_head = self._current_head()
         path = self._rev_path(revid)
         self.generate_template(
@@ -192,12 +192,16 @@ class ScriptDirectory(object):
             message=message if message is not None else ("empty message"),
             **kw
         )
-        script = Script.from_path(path)
-        self._revision_map[script.revision] = script
-        if script.down_revision:
-            self._revision_map[script.down_revision].\
-                    add_nextrev(script.revision)
-        return script
+        if refresh:
+            script = Script.from_path(path)
+            self._revision_map[script.revision] = script
+            if script.down_revision:
+                self._revision_map[script.down_revision].\
+                        add_nextrev(script.revision)
+            return script
+        else:
+            return revid
+
 
 class Script(object):
     """Represent a single revision file in a ``versions/`` directory."""
index c6bf4041ad8d7ad61090de70693219c337dac939..369c8374257e2ac81a47f11317fde8f013c1abc3 100644 (file)
@@ -10,6 +10,7 @@ Create Date: ${create_date}
 down_revision = ${repr(down_revision)}
 
 from alembic.op import *
+import sqlalchemy as sa
 
 def upgrade():
     ${upgrades if upgrades else "pass"}
index f333d6490ff86e253ab3206ba92dd955cab0cbba..b3b5da2da1e6b37a13cb167c2b0592337e25abf6 100644 (file)
@@ -10,6 +10,7 @@ Create Date: ${create_date}
 down_revision = ${repr(down_revision)}
 
 from alembic.op import *
+import sqlalchemy as sa
 
 def upgrade(engine):
     eval("upgrade_%s" % engine.name)()
index c6bf4041ad8d7ad61090de70693219c337dac939..369c8374257e2ac81a47f11317fde8f013c1abc3 100644 (file)
@@ -10,6 +10,7 @@ Create Date: ${create_date}
 down_revision = ${repr(down_revision)}
 
 from alembic.op import *
+import sqlalchemy as sa
 
 def upgrade():
     ${upgrades if upgrades else "pass"}
index 720e5fbaef8133de05d926ad6c349e8e3e3e8f27..032d9b31f2b244f3fedf7a7679cbc2fa00caeaf7 100644 (file)
@@ -2,7 +2,7 @@ from sqlalchemy.engine import url, default
 import shutil
 import os
 import itertools
-from sqlalchemy import create_engine, text
+from sqlalchemy import create_engine, text, MetaData
 from alembic import context, util
 import re
 from alembic.script import ScriptDirectory
@@ -218,7 +218,9 @@ def staging_env(create=True, template="generic"):
     cfg = _testing_config()
     if create:
         command.init(cfg, os.path.join(staging_directory, 'scripts'))
-    return script.ScriptDirectory.from_config(cfg)
+    sc = script.ScriptDirectory.from_config(cfg)
+    context._opts(cfg,sc, fn=lambda:None)
+    return sc
 
 def clear_staging_env():
     shutil.rmtree(staging_directory, True)
@@ -230,7 +232,7 @@ def three_rev_fixture(cfg):
     c = util.rev_id()
 
     script = ScriptDirectory.from_config(cfg)
-    script.generate_rev(a, None)
+    script.generate_rev(a, None, refresh=True)
     script.write(a, """
 down_revision = None
 
@@ -244,7 +246,7 @@ def downgrade():
 
 """)
 
-    script.generate_rev(b, None)
+    script.generate_rev(b, None, refresh=True)
     script.write(b, """
 down_revision = '%s'
 
@@ -258,7 +260,7 @@ def downgrade():
 
 """ % a)
 
-    script.generate_rev(c, None)
+    script.generate_rev(c, None, refresh=True)
     script.write(c, """
 down_revision = '%s'
 
index d6d18e0199fb0c736dbe30f1f56f7f521aa9c44a..475135617a0b67ed33fbd8aa3e7a6e74cac0b733 100644 (file)
@@ -1,8 +1,8 @@
 from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
     Numeric, CHAR, NUMERIC, ForeignKey, DATETIME
-from alembic import autogenerate
+from alembic import autogenerate, context
 from unittest import TestCase
-from tests import staging_env, sqlite_db, clear_staging_env, eq_, eq_ignore_whitespace
+from tests import staging_env, sqlite_db, clear_staging_env, eq_, eq_ignore_whitespace, capture_context_buffer, _no_sql_testing_config, _testing_config
 
 def _model_one():
     m = MetaData()
@@ -76,14 +76,21 @@ class AutogenerateDiffTest(TestCase):
         extra = diffs[1][1]
         eq_(extra.name, "extra")
         del diffs[1]
-        eq_(repr(diffs[3][3]), "NUMERIC(precision=8, scale=2)")
-        eq_(repr(diffs[3][4]), "Numeric(precision=10, scale=2)")
-        del diffs[3]
+
+        dropcol = diffs[1][2]
+        del diffs[1]
+        eq_(dropcol.name, "pw")
+        eq_(dropcol.nullable, True)
+        eq_(dropcol.type._type_affinity, String)
+        eq_(dropcol.type.length, 50)
+
+        eq_(repr(diffs[2][3]), "NUMERIC(precision=8, scale=2)")
+        eq_(repr(diffs[2][4]), "Numeric(precision=10, scale=2)")
+        del diffs[2]
         eq_(
             diffs,
             [
                 ('add_table', metadata.tables['item']), 
-                ('remove_column', 'user', u'pw'), 
                 ('modify_nullable', 'user', 'name', True, False), 
                 ('modify_nullable', 'order', u'amount', False, True), 
                 ('add_column', 'address', 
@@ -91,7 +98,47 @@ class AutogenerateDiffTest(TestCase):
             ]
         )
 
+    def test_render_diffs(self):
+        metadata = _model_two()
+        connection = self.bind.connect()
+        template_args = {}
+        context.configure(
+            connection=self.bind.connect(), 
+            autogenerate_metadata=metadata)
+        autogenerate.produce_migration_diffs(template_args)
+        eq_(template_args['upgrades'],
+"""### commands auto generated by Alembic - please adjust! ###
+    create_table('item',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('description', sa.String(length=100), nullable=True),
+    sa.PrimaryKeyConstraint('id')
+    )
+    drop_table(u'extra')
+    drop_column('user', u'pw')
+    alter_column('user', 'name', nullable=False)
+    alter_column('order', u'amount', type=Numeric(precision=10, scale=2))
+    alter_column('order', u'amount', nullable=True)
+    add_column('address', sa.Column('street', sa.String(length=50), nullable=True))
+    ### end Alembic commands ###""")
+        eq_(template_args['downgrades'],
+"""### commands auto generated by Alembic - please adjust! ###
+    drop_table('item')
+    create_table(u'extra',
+    sa.Column(u'x', sa.CHAR(), nullable=True),
+    sa.PrimaryKeyConstraint()
+    )
+    add_column('user', sa.Column(u'pw', sa.VARCHAR(length=50), nullable=True))
+    alter_column('user', 'name', nullable=True)
+    alter_column('order', u'amount', type=NUMERIC(precision=8, scale=2))
+    alter_column('order', u'amount', nullable=False)
+    drop_column('address', 'street')
+    ### end Alembic commands ###""")
+
 class AutogenRenderTest(TestCase):
+    @classmethod
+    def setup_class(cls):
+        context._context_opts['autogenerate_sqlalchemy_prefix'] = 'sa.'
+
     def test_render_table_upgrade(self):
         m = MetaData()
         t = Table('test', m,
@@ -102,32 +149,47 @@ class AutogenRenderTest(TestCase):
         )
         eq_ignore_whitespace(
             autogenerate._add_table(t),
-            "create_table('test', "
-            "Column('id', Integer(), nullable=False),"
-            "Column('address_id', Integer(), nullable=True),"
-            "Column('timestamp', DATETIME(), "
-                "server_default=DefaultClause('NOW()', for_update=False), "
+            "create_table('test',"
+            "sa.Column('id', sa.Integer(), nullable=False),"
+            "sa.Column('address_id', sa.Integer(), nullable=True),"
+            "sa.Column('timestamp', sa.DATETIME(), "
+                "server_default=sa.DefaultClause('NOW()'), "
                 "nullable=True),"
-            "Column('amount', Numeric(precision=5, scale=2), nullable=True),"
-            "ForeignKeyConstraint([address_id], ['address.id'], ),"
-            "PrimaryKeyConstraint(id)"
-            " )"
+            "sa.Column('amount', sa.Numeric(precision=5, scale=2), nullable=True),"
+            "sa.ForeignKeyConstraint([address_id], ['address.id'], ),"
+            "sa.PrimaryKeyConstraint('id')"
+            ")"
         )
 
-    def test_render_table_downgrade(self):
+    def test_render_drop_table(self):
         eq_(
-            autogenerate._drop_table("sometable"),
+            autogenerate._drop_table(Table("sometable", MetaData())),
             "drop_table('sometable')"
         )
 
-    def test_render_type_upgrade(self):
+    def test_render_add_column(self):
+        eq_(
+            autogenerate._add_column(
+                    "foo", Column("x", Integer, server_default="5")),
+            "add_column('foo', sa.Column('x', sa.Integer(), "
+                "server_default=sa.DefaultClause('5'), nullable=True))"
+        )
+
+    def test_render_drop_column(self):
+        eq_(
+            autogenerate._drop_column(
+                    "foo", Column("x", Integer, server_default="5")),
+            "drop_column('foo', 'x')"
+        )
+
+    def test_render_modify_type(self):
         eq_(
             autogenerate._modify_type(
                         "sometable", "somecolumn", CHAR(10)),
             "alter_column('sometable', 'somecolumn', type=CHAR(length=10))"
         )
 
-    def test_render_nullable_upgrade(self):
+    def test_render_modify_nullable(self):
         eq_(
             autogenerate._modify_nullable(
                         "sometable", "somecolumn", True),
index be1e7819635d24628fc60de6aa13ea8766cf0748..6bc858397a315080712ade1a480c2ed47a4a931b 100644 (file)
@@ -19,7 +19,7 @@ def test_003_heads():
     eq_(env._get_heads(), [])
 
 def test_004_rev():
-    script = env.generate_rev(abc, "this is a message")
+    script = env.generate_rev(abc, "this is a message", refresh=True)
     eq_(script.doc, "this is a message")
     eq_(script.revision, abc)
     eq_(script.down_revision, None)
@@ -29,7 +29,7 @@ def test_004_rev():
     eq_(env._get_heads(), [abc])
 
 def test_005_nextrev():
-    script = env.generate_rev(def_, "this is the next rev")
+    script = env.generate_rev(def_, "this is the next rev", refresh=True)
     eq_(script.revision, def_)
     eq_(script.down_revision, abc)
     eq_(env._revision_map[abc].nextrev, set([def_]))
@@ -50,6 +50,13 @@ def test_006_from_clean_env():
     eq_(def_rev.down_revision, abc)
     eq_(env._get_heads(), [def_])
 
+def test_007_no_refresh():
+    script = env.generate_rev(util.rev_id(), "dont' refresh")
+    ne_(script, env._as_rev_number("head"))
+    env2 = staging_env(create=False)
+    eq_(script, env2._as_rev_number("head"))
+
+
 def setup():
     global env
     env = staging_env()
index c6e0ea676ac2a2a7044ad64a74ee360bcabd60e0..b4bff6e045651fe0d2659d842519f35ce247a6e7 100644 (file)
@@ -6,11 +6,11 @@ def setup():
     global env
     env = staging_env()
     global a, b, c, d, e
-    a = env.generate_rev(util.rev_id(), None)
-    b = env.generate_rev(util.rev_id(), None)
-    c = env.generate_rev(util.rev_id(), None)
-    d = env.generate_rev(util.rev_id(), None)
-    e = env.generate_rev(util.rev_id(), None)
+    a = env.generate_rev(util.rev_id(), None, refresh=True)
+    b = env.generate_rev(util.rev_id(), None, refresh=True)
+    c = env.generate_rev(util.rev_id(), None, refresh=True)
+    d = env.generate_rev(util.rev_id(), None, refresh=True)
+    e = env.generate_rev(util.rev_id(), None, refresh=True)
 
 def teardown():
     clear_staging_env()
index 6928c50e3f66eb91b576f5d7a7a123d7d4d39b3b..75ba348a2b09332c240ad32dd1099d27a433ecc7 100644 (file)
@@ -10,7 +10,7 @@ def test_001_revisions():
     c = util.rev_id()
 
     script = ScriptDirectory.from_config(cfg)
-    script.generate_rev(a, None)
+    script.generate_rev(a, None, refresh=True)
     script.write(a, """
 down_revision = None
 
@@ -24,7 +24,7 @@ def downgrade():
 
 """)
 
-    script.generate_rev(b, None)
+    script.generate_rev(b, None, refresh=True)
     script.write(b, """
 down_revision = '%s'
 
@@ -38,7 +38,7 @@ def downgrade():
 
 """ % a)
 
-    script.generate_rev(c, None)
+    script.generate_rev(c, None, refresh=True)
     script.write(c, """
 down_revision = '%s'