From: Mike Bayer Date: Tue, 19 Apr 2011 17:07:51 +0000 (-0400) Subject: - move -c / -n arguments to front X-Git-Tag: rel_0_1_0~79 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=936fac3e4835f6312fe588cbd909c70cab787369;p=thirdparty%2Fsqlalchemy%2Falembic.git - move -c / -n arguments to front - add create_table, drop_table - support range revs for when the --sql flag is set --- diff --git a/alembic/command.py b/alembic/command.py index b2352d6c..f1d50006 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -71,7 +71,7 @@ def upgrade(config, revision, sql=False): script = ScriptDirectory.from_config(config) context.opts( config, - fn = functools.partial(script.upgrade_from, revision), + fn = functools.partial(script.upgrade_from, sql, revision), as_sql = sql ) script.run_env() @@ -82,7 +82,7 @@ def downgrade(config, revision, sql=False): script = ScriptDirectory.from_config(config) context.opts( config, - fn = functools.partial(script.downgrade_to, revision), + fn = functools.partial(script.downgrade_to, sql, revision), as_sql = sql, ) script.run_env() diff --git a/alembic/config.py b/alembic/config.py index a8c26148..00818970 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -5,8 +5,9 @@ import inspect import os class Config(object): - def __init__(self, file_): + def __init__(self, file_, ini_section='alembic'): self.config_file_name = file_ + self.config_ini_section = ini_section @util.memoized_property def file_config(self): @@ -21,21 +22,18 @@ class Config(object): return dict(self.file_config.items(name)) def get_main_option(self, name, default=None): - if not self.file_config.has_section('alembic'): + if not self.file_config.has_section(self.config_ini_section): util.err("No config file %r found, or file has no " - "'[alembic]' section" % self.config_file_name) - if self.file_config.get('alembic', name): - return self.file_config.get('alembic', name) + "'[%s]' section" % + (self.config_file_name, self.config_ini_section)) + if self.file_config.get(self.config_ini_section, name): + return self.file_config.get(self.config_ini_section, name) else: return default def main(argv): def add_options(parser, positional, kwargs): - parser.add_argument("-c", "--config", - type=str, - default="alembic.ini", - help="Alternate config file") if 'template' in kwargs: parser.add_argument("-t", "--template", default='generic', @@ -59,6 +57,14 @@ def main(argv): subparser.add_argument(arg, help=positional_help.get(arg)) parser = ArgumentParser() + parser.add_argument("-c", "--config", + type=str, + default="alembic.ini", + help="Alternate config file") + parser.add_argument("-n", "--name", + type=str, + default="alembic", + help="Name of section in .ini file to use for Alembic config") subparsers = parser.add_subparsers() for fn in [getattr(command, n) for n in dir(command)]: @@ -84,7 +90,7 @@ def main(argv): fn, positional, kwarg = options.cmd - cfg = Config(options.config) + cfg = Config(options.config, options.name) try: fn(cfg, *[getattr(options, k) for k in positional], diff --git a/alembic/context.py b/alembic/context.py index ae1b3f09..6ea301b2 100644 --- a/alembic/context.py +++ b/alembic/context.py @@ -1,7 +1,7 @@ from alembic import util from sqlalchemy import MetaData, Table, Column, String, literal_column, \ text -from sqlalchemy.schema import CreateTable +from sqlalchemy import schema, create_engine import logging log = logging.getLogger(__name__) @@ -93,6 +93,29 @@ class DefaultContext(object): def execute(self, sql): self._exec(sql) + @util.memoized_property + def _stdout_connection(self): + def dump(construct, *multiparams, **params): + self._exec(construct) + + return create_engine(self.connection.engine.url, + strategy="mock", executor=dump) + + @property + def bind(self): + """Return a bind suitable for passing to the create() + or create_all() methods of MetaData, Table. + + Note that when "standard output" mode is enabled, + this bind will be a "mock" connection handler that cannot + return results and is only appropriate for DDL. + + """ + if self.as_sql: + return self._stdout_connection + else: + return self.connection + def alter_column(self, table_name, column_name, nullable=util.NO_VALUE, server_default=util.NO_VALUE, @@ -112,6 +135,14 @@ class DefaultContext(object): def add_constraint(self, const): self._exec(schema.AddConstraint(const)) + def create_table(self, table): + self._exec(schema.CreateTable(table)) + for index in table.indexes: + self._exec(schema.CreateIndex(index)) + + def drop_table(self, table): + self._exec(schema.DropTable(table)) + def opts(cfg, **kw): global _context_opts, config _context_opts = kw diff --git a/alembic/op.py b/alembic/op.py index 77fc7505..c402d7cb 100644 --- a/alembic/op.py +++ b/alembic/op.py @@ -6,7 +6,11 @@ from sqlalchemy import schema __all__ = [ 'alter_column', 'create_foreign_key', + 'create_table', + 'drop_table', 'create_unique_constraint', + 'get_context', + 'get_bind', 'execute'] def alter_column(table_name, column_name, @@ -90,5 +94,13 @@ def create_table(name, *columns, **kw): _table(name, *columns, **kw) ) +def drop_table(name, *columns, **kw): + get_context().drop_table( + _table(name, *columns, **kw) + ) + def execute(sql): - get_context().execute(sql) \ No newline at end of file + get_context().execute(sql) + +def get_bind(): + return get_context().bind \ No newline at end of file diff --git a/alembic/script.py b/alembic/script.py index ed075f6b..67aa654f 100644 --- a/alembic/script.py +++ b/alembic/script.py @@ -61,18 +61,35 @@ class ScriptDirectory(object): script = upper while script != lower: yield script - script = self._revision_map[script.down_revision] + downrev = script.down_revision + script = self._revision_map[downrev] + if script is None and lower is not None: + raise util.CommandError("Couldn't find revision %s" % downrev) + + def upgrade_from(self, range_ok, destination, current_rev): + if destination is not None and ':' in destination: + if not range_ok: + raise util.CommandError("Range revision not allowed") + revs = self._revs(*reversed(destination.split(':', 2))) + else: + revs = self._revs(destination, current_rev) - def upgrade_from(self, destination, current_rev): return [ (script.module.upgrade, script.revision) for script in - reversed(list(self._revs(destination, current_rev))) + reversed(list(revs)) ] - def downgrade_to(self, destination, current_rev): + def downgrade_to(self, range_ok, destination, current_rev): + if destination is not None and ':' in destination: + if not range_ok: + raise util.CommandError("Range revision not allowed") + revs = self._revs(*reversed(destination.split(':', 2))) + else: + revs = self._revs(current_rev, destination) + return [ (script.module.downgrade, script.down_revision) for script in - self._revs(current_rev, destination) + revs ] def run_env(self): diff --git a/tests/test_revision_paths.py b/tests/test_revision_paths.py index c477c0b0..7feb4445 100644 --- a/tests/test_revision_paths.py +++ b/tests/test_revision_paths.py @@ -19,7 +19,7 @@ def teardown(): def test_upgrade_path(): eq_( - env.upgrade_from(e.revision, c.revision), + env.upgrade_from(False, e.revision, c.revision), [ (d.module.upgrade, d.revision), (e.module.upgrade, e.revision), @@ -27,7 +27,7 @@ def test_upgrade_path(): ) eq_( - env.upgrade_from(c.revision, None), + env.upgrade_from(False, c.revision, None), [ (a.module.upgrade, a.revision), (b.module.upgrade, b.revision), @@ -38,7 +38,7 @@ def test_upgrade_path(): def test_downgrade_path(): eq_( - env.downgrade_to(c.revision, e.revision), + env.downgrade_to(False, c.revision, e.revision), [ (e.module.downgrade, e.down_revision), (d.module.downgrade, d.down_revision), @@ -46,7 +46,7 @@ def test_downgrade_path(): ) eq_( - env.downgrade_to(None, c.revision), + env.downgrade_to(False, None, c.revision), [ (c.module.downgrade, c.down_revision), (b.module.downgrade, b.down_revision),