]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- move -c / -n arguments to front
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Apr 2011 17:07:51 +0000 (13:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Apr 2011 17:07:51 +0000 (13:07 -0400)
- add create_table, drop_table
- support range revs for when the --sql flag is set

alembic/command.py
alembic/config.py
alembic/context.py
alembic/op.py
alembic/script.py
tests/test_revision_paths.py

index b2352d6c448e33da41352e8ad944105d15bacccd..f1d5000634ab4f1106ebcffca5222954e18f4f0f 100644 (file)
@@ -71,7 +71,7 @@ def upgrade(config, revision, sql=False):
     script = ScriptDirectory.from_config(config)
     context.opts(
         config,
-        fn = functools.partial(script.upgrade_from, revision),
+        fn = functools.partial(script.upgrade_from, sql, revision),
         as_sql = sql
     )
     script.run_env()
@@ -82,7 +82,7 @@ def downgrade(config, revision, sql=False):
     script = ScriptDirectory.from_config(config)
     context.opts(
         config,
-        fn = functools.partial(script.downgrade_to, revision),
+        fn = functools.partial(script.downgrade_to, sql, revision),
         as_sql = sql,
     )
     script.run_env()
index a8c261483d8cbcad3e21c16e23371e2d479952f0..008189700f937a4d12c7806e25a53ac483c20ad3 100644 (file)
@@ -5,8 +5,9 @@ import inspect
 import os
 
 class Config(object):
-    def __init__(self, file_):
+    def __init__(self, file_, ini_section='alembic'):
         self.config_file_name = file_
+        self.config_ini_section = ini_section
 
     @util.memoized_property
     def file_config(self):
@@ -21,21 +22,18 @@ class Config(object):
         return dict(self.file_config.items(name))
 
     def get_main_option(self, name, default=None):
-        if not self.file_config.has_section('alembic'):
+        if not self.file_config.has_section(self.config_ini_section):
             util.err("No config file %r found, or file has no "
-                                "'[alembic]' section" % self.config_file_name)
-        if self.file_config.get('alembic', name):
-            return self.file_config.get('alembic', name)
+                                "'[%s]' section" % 
+                                (self.config_file_name, self.config_ini_section))
+        if self.file_config.get(self.config_ini_section, name):
+            return self.file_config.get(self.config_ini_section, name)
         else:
             return default
 
 def main(argv):
 
     def add_options(parser, positional, kwargs):
-        parser.add_argument("-c", "--config", 
-                            type=str, 
-                            default="alembic.ini", 
-                            help="Alternate config file")
         if 'template' in kwargs:
             parser.add_argument("-t", "--template",
                             default='generic',
@@ -59,6 +57,14 @@ def main(argv):
             subparser.add_argument(arg, help=positional_help.get(arg))
 
     parser = ArgumentParser()
+    parser.add_argument("-c", "--config", 
+                        type=str, 
+                        default="alembic.ini", 
+                        help="Alternate config file")
+    parser.add_argument("-n", "--name", 
+                        type=str, 
+                        default="alembic", 
+                        help="Name of section in .ini file to use for Alembic config")
     subparsers = parser.add_subparsers()
 
     for fn in [getattr(command, n) for n in dir(command)]:
@@ -84,7 +90,7 @@ def main(argv):
 
     fn, positional, kwarg = options.cmd
 
-    cfg = Config(options.config)
+    cfg = Config(options.config, options.name)
     try:
         fn(cfg, 
                     *[getattr(options, k) for k in positional], 
index ae1b3f09430e606f5dc86db5b6d61dc999432f1d..6ea301b2f7427d0abb385386df8eff4acd58f0ce 100644 (file)
@@ -1,7 +1,7 @@
 from alembic import util
 from sqlalchemy import MetaData, Table, Column, String, literal_column, \
     text
-from sqlalchemy.schema import CreateTable
+from sqlalchemy import schema, create_engine
 import logging
 
 log = logging.getLogger(__name__)
@@ -93,6 +93,29 @@ class DefaultContext(object):
     def execute(self, sql):
         self._exec(sql)
 
+    @util.memoized_property
+    def _stdout_connection(self):
+        def dump(construct, *multiparams, **params):
+            self._exec(construct)
+
+        return create_engine(self.connection.engine.url, 
+                        strategy="mock", executor=dump)
+
+    @property
+    def bind(self):
+        """Return a bind suitable for passing to the create() 
+        or create_all() methods of MetaData, Table.
+        
+        Note that when "standard output" mode is enabled, 
+        this bind will be a "mock" connection handler that cannot
+        return results and is only appropriate for DDL.
+        
+        """
+        if self.as_sql:
+            return self._stdout_connection
+        else:
+            return self.connection
+
     def alter_column(self, table_name, column_name, 
                         nullable=util.NO_VALUE,
                         server_default=util.NO_VALUE,
@@ -112,6 +135,14 @@ class DefaultContext(object):
     def add_constraint(self, const):
         self._exec(schema.AddConstraint(const))
 
+    def create_table(self, table):
+        self._exec(schema.CreateTable(table))
+        for index in table.indexes:
+            self._exec(schema.CreateIndex(index))
+
+    def drop_table(self, table):
+        self._exec(schema.DropTable(table))
+
 def opts(cfg, **kw):
     global _context_opts, config
     _context_opts = kw
index 77fc7505502c374ea4747cd3985361bd3ccd3897..c402d7cb31f06e46ad6a0da44c0f87aee5072c4d 100644 (file)
@@ -6,7 +6,11 @@ from sqlalchemy import schema
 __all__ = [
             'alter_column', 
             'create_foreign_key', 
+            'create_table',
+            'drop_table',
             'create_unique_constraint', 
+            'get_context',
+            'get_bind',
             'execute']
 
 def alter_column(table_name, column_name, 
@@ -90,5 +94,13 @@ def create_table(name, *columns, **kw):
         _table(name, *columns, **kw)
     )
 
+def drop_table(name, *columns, **kw):
+    get_context().drop_table(
+        _table(name, *columns, **kw)
+    )
+
 def execute(sql):
-    get_context().execute(sql)
\ No newline at end of file
+    get_context().execute(sql)
+
+def get_bind():
+    return get_context().bind
\ No newline at end of file
index ed075f6b3e40b3de3a8cc7b8c9c907a4d8a8f71b..67aa654f57450b37a759c353f3ecdaf3c462c871 100644 (file)
@@ -61,18 +61,35 @@ class ScriptDirectory(object):
         script = upper
         while script != lower:
             yield script
-            script = self._revision_map[script.down_revision]
+            downrev = script.down_revision
+            script = self._revision_map[downrev]
+            if script is None and lower is not None:
+                raise util.CommandError("Couldn't find revision %s" % downrev)
+
+    def upgrade_from(self, range_ok, destination, current_rev):
+        if destination is not None and ':' in destination:
+            if not range_ok:
+                raise util.CommandError("Range revision not allowed")
+            revs = self._revs(*reversed(destination.split(':', 2)))
+        else:
+            revs = self._revs(destination, current_rev)
 
-    def upgrade_from(self, destination, current_rev):
         return [
             (script.module.upgrade, script.revision) for script in 
-            reversed(list(self._revs(destination, current_rev)))
+            reversed(list(revs))
             ]
 
-    def downgrade_to(self, destination, current_rev):
+    def downgrade_to(self, range_ok, destination, current_rev):
+        if destination is not None and ':' in destination:
+            if not range_ok:
+                raise util.CommandError("Range revision not allowed")
+            revs = self._revs(*reversed(destination.split(':', 2)))
+        else:
+            revs = self._revs(current_rev, destination)
+
         return [
             (script.module.downgrade, script.down_revision) for script in 
-            self._revs(current_rev, destination)
+            revs
             ]
 
     def run_env(self):
index c477c0b06e8aa7e5e561e2f56880d79b04f6f612..7feb44452dc427132196e97f21c7f19423788bb7 100644 (file)
@@ -19,7 +19,7 @@ def teardown():
 def test_upgrade_path():
 
     eq_(
-        env.upgrade_from(e.revision, c.revision),
+        env.upgrade_from(False, e.revision, c.revision),
         [
             (d.module.upgrade, d.revision),
             (e.module.upgrade, e.revision),
@@ -27,7 +27,7 @@ def test_upgrade_path():
     )
 
     eq_(
-        env.upgrade_from(c.revision, None),
+        env.upgrade_from(False, c.revision, None),
         [
             (a.module.upgrade, a.revision),
             (b.module.upgrade, b.revision),
@@ -38,7 +38,7 @@ def test_upgrade_path():
 def test_downgrade_path():
 
     eq_(
-        env.downgrade_to(c.revision, e.revision),
+        env.downgrade_to(False, c.revision, e.revision),
         [
             (e.module.downgrade, e.down_revision),
             (d.module.downgrade, d.down_revision),
@@ -46,7 +46,7 @@ def test_downgrade_path():
     )
 
     eq_(
-        env.downgrade_to(None, c.revision),
+        env.downgrade_to(False, None, c.revision),
         [
             (c.module.downgrade, c.down_revision),
             (b.module.downgrade, b.down_revision),