From: Mike Bayer Date: Tue, 24 Jan 2012 18:42:43 +0000 (-0500) Subject: this is all tests passing with the refactor, which IMHO is X-Git-Tag: rel_0_2_0~15 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=a851daaa1b2a4efa5990f55c3c97282cafdab9e1;p=thirdparty%2Fsqlalchemy%2Falembic.git this is all tests passing with the refactor, which IMHO is miraculous --- diff --git a/alembic/autogenerate.py b/alembic/autogenerate.py index 728d1e22..d90114a3 100644 --- a/alembic/autogenerate.py +++ b/alembic/autogenerate.py @@ -13,7 +13,8 @@ log = logging.getLogger(__name__) # top level -def produce_migration_diffs(context, opts, template_args, imports): +def produce_migration_diffs(context, template_args, imports): + opts = context.opts metadata = opts['target_metadata'] if metadata is None: raise util.CommandError( @@ -22,7 +23,7 @@ def produce_migration_diffs(context, opts, template_args, imports): "a MetaData object to the context." % ( context._script.env_py_location )) - connection = get_bind() + connection = context.bind diffs = [] autogen_context = { 'imports':imports, @@ -308,7 +309,7 @@ def _add_table(table, autogen_context): 'args':',\n'.join( [_render_column(col, autogen_context) for col in table.c] + sorted([rcons for rcons in - [_render_constraint(cons) for cons in + [_render_constraint(cons, autogen_context) for cons in table.constraints] if rcons is not None ]) @@ -420,14 +421,14 @@ def _repr_type(prefix, type_, autogen_context): else: return "%s%r" % (prefix, type_) -def _render_constraint(constraint): +def _render_constraint(constraint, autogen_context): renderer = _constraint_renderers.get(type(constraint), None) if renderer: - return renderer(constraint) + return renderer(constraint, autogen_context) else: return None -def _render_primary_key(constraint): +def _render_primary_key(constraint, autogen_context): opts = [] if constraint.name: opts.append(("name", repr(constraint.name))) @@ -439,7 +440,7 @@ def _render_primary_key(constraint): ), } -def _render_foreign_key(constraint): +def _render_foreign_key(constraint, autogen_context): opts = [] if constraint.name: opts.append(("name", repr(constraint.name))) diff --git a/alembic/command.py b/alembic/command.py index b743f3b5..a8c0fc42 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -67,10 +67,10 @@ def revision(config, message=None, autogenerate=False): imports = set() if autogenerate: util.requires_07("autogenerate") - def retrieve_migrations(rev): + def retrieve_migrations(rev, context): if script._get_rev(rev) is not script._get_rev("head"): raise util.CommandError("Target database is not up to date.") - autogen.produce_migration_diffs(template_args, imports) + autogen.produce_migration_diffs(context, template_args, imports) return [] with environment.configure( @@ -150,7 +150,7 @@ def current(config): """Display the current revision for each database.""" script = ScriptDirectory.from_config(config) - def display_version(rev): + def display_version(rev, context): print "Current revision for %s: %s" % ( util.obfuscate_url_pw( context.get_context().connection.engine.url), @@ -169,15 +169,15 @@ def stamp(config, revision, sql=False, tag=None): run any migrations.""" script = ScriptDirectory.from_config(config) - def do_stamp(rev): + def do_stamp(rev, context): if sql: current = False else: - current = context.get_context()._current_rev() + current = context._current_rev() dest = script._get_rev(revision) if dest is not None: dest = dest.revision - context.get_context()._update_current_rev(current, dest) + context._update_current_rev(current, dest) return [] with environment.configure( config, @@ -186,7 +186,7 @@ def stamp(config, revision, sql=False, tag=None): as_sql = sql, destination_rev = revision, tag = tag - ): + ) as env: script.run_env() def splice(config, parent, child): diff --git a/alembic/config.py b/alembic/config.py index 22a48749..1dc6eb95 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -107,6 +107,9 @@ class Config(object): """ self.file_config.set(self.config_ini_section, name, value) + def remove_main_option(self, name): + self.file_config.remove_option(self.config_ini_section, name) + def set_section_option(self, section, name, value): """Set an option programmatically within the given section. diff --git a/alembic/environment.py b/alembic/environment.py index 12518798..53c35594 100644 --- a/alembic/environment.py +++ b/alembic/environment.py @@ -2,17 +2,20 @@ import alembic from alembic.operations import Operations from alembic.migration import MigrationContext from alembic import util -from sqlalchemy.engine import url as sqla_url +from contextlib import contextmanager class EnvironmentContext(object): """Represent the state made available to an env.py script.""" _migration_context = None + _default_opts = None def __init__(self, config, script, **kw): self.config = config self.script = script self.context_opts = kw + if self._default_opts: + self.context_opts.update(self._default_opts) def __enter__(self): """Establish a context which provides a @@ -264,18 +267,6 @@ class EnvironmentContext(object): one step. """ - - if connection: - dialect = connection.dialect - elif url: - url = sqla_url.make_url(url) - dialect = url.get_dialect()() - elif dialect_name: - url = sqla_url.make_url("%s://" % dialect_name) - dialect = url.get_dialect()() - else: - raise Exception("Connection, url, or dialect_name is required.") - opts = self.context_opts if transactional_ddl is not None: opts["transactional_ddl"] = transactional_ddl @@ -292,19 +283,19 @@ class EnvironmentContext(object): opts['downgrade_token'] = downgrade_token opts['sqlalchemy_module_prefix'] = sqlalchemy_module_prefix opts['alembic_module_prefix'] = alembic_module_prefix + if compare_type is not None: + opts['compare_type'] = compare_type + if compare_server_default is not None: + opts['compare_server_default'] = compare_server_default + opts['script'] = self.script opts.update(kw) - self._migration_context = MigrationContext( - dialect, self.script, connection, - opts, - as_sql=opts.get('as_sql', False), - output_buffer=opts.get("output_buffer"), - transactional_ddl=opts.get("transactional_ddl"), - starting_rev=opts.get("starting_rev"), - compare_type=compare_type, - compare_server_default=compare_server_default, - ) - alembic.op._proxy = Operations(self._migration_context) + self._migration_context = MigrationContext.configure( + connection=connection, + url=url, + dialect_name=dialect_name, + opts=opts + ) def run_migrations(self, **kw): """Run migrations as determined by the current command line configuration @@ -324,7 +315,8 @@ class EnvironmentContext(object): made available via :func:`.configure`. """ - self.migration_context.run_migrations(**kw) + with Operations.context(self._migration_context): + self.migration_context.run_migrations(**kw) def execute(self, sql): """Execute the given SQL using the current change context. diff --git a/alembic/migration.py b/alembic/migration.py index 69e69304..733727d2 100644 --- a/alembic/migration.py +++ b/alembic/migration.py @@ -4,7 +4,7 @@ from sqlalchemy import MetaData, Table, Column, String, literal_column, \ from sqlalchemy import create_engine from alembic import ddl import sys -from contextlib import contextmanager +from sqlalchemy.engine import url as sqla_url import logging log = logging.getLogger(__name__) @@ -21,22 +21,19 @@ class MigrationContext(object): Mediates the relationship between an ``env.py`` environment script, a :class:`.ScriptDirectory` instance, and a :class:`.DefaultImpl` instance. - The :class:`.Context` is available directly via the :func:`.get_context` function, + The :class:`.MigrationContext` is available directly via the :func:`.get_context` function, though usually it is referenced behind the scenes by the various module level functions within the :mod:`alembic.context` module. """ - def __init__(self, dialect, script, connection, - opts, - as_sql=False, - output_buffer=None, - transactional_ddl=None, - starting_rev=None, - compare_type=False, - compare_server_default=False): + def __init__(self, dialect, connection, opts): + self.opts = opts self.dialect = dialect - # TODO: need this ? - self.script = script + self.script = opts.get('script') + + as_sql=opts.get('as_sql', False) + transactional_ddl=opts.get("transactional_ddl") + if as_sql: self.connection = self._stdout_connection(connection) assert self.connection is not None @@ -44,12 +41,12 @@ class MigrationContext(object): self.connection = connection self._migrations_fn = opts.get('fn') self.as_sql = as_sql - self.output_buffer = output_buffer if output_buffer else sys.stdout + self.output_buffer = opts.get("output_buffer", sys.stdout) - self._user_compare_type = compare_type - self._user_compare_server_default = compare_server_default + self._user_compare_type = opts.get('compare_type', False) + self._user_compare_server_default = opts.get('compare_server_default', False) - self._start_from_rev = starting_rev + self._start_from_rev = opts.get("starting_rev") self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( dialect, self.connection, self.as_sql, transactional_ddl, @@ -63,6 +60,46 @@ class MigrationContext(object): "transactional" if self.impl.transactional_ddl else "non-transactional") + @classmethod + def configure(cls, + connection=None, + url=None, + dialect_name=None, + opts=None, + ): + """Create a new :class:`.MigrationContext`. + + This is a factory method usually called + by :meth:`.EnvironmentContext.configure`. + + :param connection: a :class:`~sqlalchemy.engine.base.Connection` to use + for SQL execution in "online" mode. When present, is also used to + determine the type of dialect in use. + :param url: a string database url, or a :class:`sqlalchemy.engine.url.URL` object. + The type of dialect to be used will be derived from this if ``connection`` is + not passed. + :param dialect_name: string name of a dialect, such as "postgresql", "mssql", etc. + The type of dialect to be used will be derived from this if ``connection`` + and ``url`` are not passed. + :param opts: dictionary of options. Most other options + accepted by :meth:`.EnvironmentContext.configure` are passed via + this dictionary. + + """ + if connection: + dialect = connection.dialect + elif url: + url = sqla_url.make_url(url) + dialect = url.get_dialect()() + elif dialect_name: + url = sqla_url.make_url("%s://" % dialect_name) + dialect = url.get_dialect()() + else: + raise Exception("Connection, url, or dialect_name is required.") + + return MigrationContext(dialect, connection, opts) + + def _current_rev(self): if self.as_sql: return self._start_from_rev @@ -93,7 +130,8 @@ class MigrationContext(object): current_rev = rev = False self.impl.start_migrations() for change, prev_rev, rev in self._migrations_fn( - self._current_rev()): + self._current_rev(), + self): if current_rev is False: current_rev = prev_rev if self.as_sql and not current_rev: diff --git a/alembic/operations.py b/alembic/operations.py index cc2ef48b..a6530cf3 100644 --- a/alembic/operations.py +++ b/alembic/operations.py @@ -2,6 +2,8 @@ from alembic import util from alembic.ddl import impl from sqlalchemy.types import NULLTYPE, Integer from sqlalchemy import schema, sql +from contextlib import contextmanager +import alembic __all__ = sorted([ 'alter_column', @@ -34,6 +36,14 @@ class Operations(object): self.migration_context = migration_context self.impl = migration_context.impl + @classmethod + @contextmanager + def context(cls, migration_context): + op = Operations(migration_context) + alembic.op._proxy = op + yield op + del alembic.op._proxy + def _foreign_key_constraint(self, name, source, referent, local_cols, remote_cols): m = schema.MetaData() t1 = schema.Table(source, m, diff --git a/alembic/script.py b/alembic/script.py index 4b7eaf25..7c3bb0fc 100644 --- a/alembic/script.py +++ b/alembic/script.py @@ -88,14 +88,14 @@ class ScriptDirectory(object): if script is None and lower is not None: raise util.CommandError("Couldn't find revision %s" % downrev) - def upgrade_from(self, destination, current_rev): + def upgrade_from(self, destination, current_rev, context): revs = self._revs(destination, current_rev) return [ (script.module.upgrade, script.down_revision, script.revision) for script in reversed(list(revs)) ] - def downgrade_to(self, destination, current_rev): + def downgrade_to(self, destination, current_rev, context): revs = self._revs(current_rev, destination) return [ (script.module.downgrade, script.revision, script.down_revision) for script in diff --git a/tests/__init__.py b/tests/__init__.py index 328040a4..7e3e4b9f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -7,6 +7,7 @@ import itertools from sqlalchemy import create_engine, text, MetaData from alembic import util from alembic.migration import MigrationContext +from alembic.environment import EnvironmentContext import re import alembic from alembic.operations import Operations @@ -84,17 +85,16 @@ def capture_context_buffer(**kw): class capture(object): def __enter__(self): - context.configure( - dialect_name="sqlite", - output_buffer = buf, - **kw - ) + EnvironmentContext._default_opts = { + 'dialect_name':"sqlite", + 'output_buffer':buf + } + EnvironmentContext._default_opts.update(kw) return buf def __exit__(self, *arg, **kwarg): print buf.getvalue() - for k in kw: - context._context_opts.pop(k, None) + EnvironmentContext._default_opts = None return capture() diff --git a/tests/test_autogenerate.py b/tests/test_autogenerate.py index 0e913cf5..2b6cd3a5 100644 --- a/tests/test_autogenerate.py +++ b/tests/test_autogenerate.py @@ -2,7 +2,8 @@ from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \ Numeric, CHAR, ForeignKey, DATETIME, TypeDecorator from sqlalchemy.types import NULLTYPE from sqlalchemy.engine.reflection import Inspector -from alembic import autogenerate, context +from alembic import autogenerate +from alembic.migration import MigrationContext from unittest import TestCase from tests import staging_env, sqlite_db, clear_staging_env, eq_, \ eq_ignore_whitespace, requires_07 @@ -76,18 +77,26 @@ class AutogenerateDiffTest(TestCase): cls.m1 = _model_one() cls.m1.create_all(cls.bind) cls.m2 = _model_two() - context.configure( + + cls.context = context = MigrationContext.configure( connection = cls.bind.connect(), - compare_type = True, - compare_server_default = True, - target_metadata=cls.m2 + opts = { + 'compare_type':True, + 'compare_server_default':True, + 'target_metadata':cls.m2, + 'upgrade_token':"upgrades", + 'downgrade_token':"downgrades", + 'alembic_module_prefix':'op.', + 'sqlalchemy_module_prefix':'sa.' + } ) - connection = context.get_bind() + + connection = context.bind cls.autogen_context = { 'imports':set(), 'connection':connection, 'dialect':connection.dialect, - 'context':context.get_context() + 'context':context } @classmethod @@ -98,7 +107,7 @@ class AutogenerateDiffTest(TestCase): """test generation of diff rules""" metadata = self.m2 - connection = context.get_bind() + connection = self.context.bind diffs = [] autogenerate._produce_net_changes(connection, metadata, diffs, self.autogen_context) @@ -140,14 +149,18 @@ class AutogenerateDiffTest(TestCase): def test_render_nothing(self): - context.configure( + context = MigrationContext.configure( connection = self.bind.connect(), - compare_type = True, - compare_server_default = True, - target_metadata=self.m1 + opts = { + 'compare_type' : True, + 'compare_server_default' : True, + 'target_metadata' : self.m1, + 'upgrade_token':"upgrades", + 'downgrade_token':"downgrades", + } ) template_args = {} - autogenerate.produce_migration_diffs(template_args, self.autogen_context) + autogenerate.produce_migration_diffs(context, template_args, set()) eq_(re.sub(r"u'", "'", template_args['upgrades']), """### commands auto generated by Alembic - please adjust! ### pass @@ -162,7 +175,7 @@ class AutogenerateDiffTest(TestCase): metadata = self.m2 template_args = {} - autogenerate.produce_migration_diffs(template_args, self.autogen_context) + autogenerate.produce_migration_diffs(self.context, template_args, set()) eq_(re.sub(r"u'", "'", template_args['upgrades']), """### commands auto generated by Alembic - please adjust! ### op.create_table('item', @@ -273,8 +286,12 @@ class AutogenRenderTest(TestCase): @classmethod @requires_07 def setup_class(cls): - context._context_opts['sqlalchemy_module_prefix'] = 'sa.' - context._context_opts['alembic_module_prefix'] = 'op.' + cls.autogen_context = { + 'opts':{ + 'sqlalchemy_module_prefix' : 'sa.', + 'alembic_module_prefix' : 'op.' + } + } def test_render_table_upgrade(self): m = MetaData() @@ -285,7 +302,7 @@ class AutogenRenderTest(TestCase): Column("amount", Numeric(5, 2)), ) eq_ignore_whitespace( - autogenerate._add_table(t, {}), + autogenerate._add_table(t, self.autogen_context), "op.create_table('test'," "sa.Column('id', sa.Integer(), nullable=False)," "sa.Column('address_id', sa.Integer(), nullable=True)," @@ -300,14 +317,14 @@ class AutogenRenderTest(TestCase): def test_render_drop_table(self): eq_( - autogenerate._drop_table(Table("sometable", MetaData()), {}), + autogenerate._drop_table(Table("sometable", MetaData()), self.autogen_context), "op.drop_table('sometable')" ) def test_render_add_column(self): eq_( autogenerate._add_column( - "foo", Column("x", Integer, server_default="5"), {}), + "foo", Column("x", Integer, server_default="5"), self.autogen_context), "op.add_column('foo', sa.Column('x', sa.Integer(), " "server_default='5', nullable=True))" ) @@ -315,7 +332,7 @@ class AutogenRenderTest(TestCase): def test_render_drop_column(self): eq_( autogenerate._drop_column( - "foo", Column("x", Integer, server_default="5"), {}), + "foo", Column("x", Integer, server_default="5"), self.autogen_context), "op.drop_column('foo', 'x')" ) @@ -324,7 +341,7 @@ class AutogenRenderTest(TestCase): eq_ignore_whitespace( autogenerate._modify_col( "sometable", "somecolumn", - {}, + self.autogen_context, type_=CHAR(10), existing_type=CHAR(20)), "op.alter_column('sometable', 'somecolumn', " "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))" @@ -334,7 +351,7 @@ class AutogenRenderTest(TestCase): eq_ignore_whitespace( autogenerate._modify_col( "sometable", "somecolumn", - {}, + self.autogen_context, existing_type=Integer(), nullable=True), "op.alter_column('sometable', 'somecolumn', " @@ -345,7 +362,7 @@ class AutogenRenderTest(TestCase): eq_ignore_whitespace( autogenerate._modify_col( "sometable", "somecolumn", - {}, + self.autogen_context, existing_type=Integer(), existing_server_default="5", nullable=True), diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 46cd81d6..feb867d8 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -5,7 +5,8 @@ from tests import op_fixture, db_for_dialect, eq_, staging_env, \ from unittest import TestCase from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, String from sqlalchemy.engine.reflection import Inspector -from alembic import context, command, util +from alembic import command, util +from alembic.migration import MigrationContext from alembic.script import ScriptDirectory class PGOfflineEnumTest(TestCase): @@ -27,37 +28,37 @@ class PGOfflineEnumTest(TestCase): self.script.write(self.rid, """ down_revision = None -from alembic.op import * +from alembic import op from sqlalchemy.dialects.postgresql import ENUM from sqlalchemy import Column def upgrade(): - create_table("sometable", + op.create_table("sometable", Column("data", ENUM("one", "two", "three", name="pgenum")) ) def downgrade(): - drop_table("sometable") + op.drop_table("sometable") """) def _distinct_enum_script(self): self.script.write(self.rid, """ down_revision = None -from alembic.op import * +from alembic import op from sqlalchemy.dialects.postgresql import ENUM from sqlalchemy import Column def upgrade(): enum = ENUM("one", "two", "three", name="pgenum", create_type=False) - enum.create(get_bind(), checkfirst=False) - create_table("sometable", + enum.create(op.get_bind(), checkfirst=False) + op.create_table("sometable", Column("data", enum) ) def downgrade(): - drop_table("sometable") - ENUM(name="pgenum").drop(get_bind(), checkfirst=False) + op.drop_table("sometable") + ENUM(name="pgenum").drop(op.get_bind(), checkfirst=False) """) @@ -97,17 +98,19 @@ class PostgresqlDefaultCompareTest(TestCase): def setup_class(cls): cls.bind = db_for_dialect("postgresql") staging_env() - context.configure( + context = MigrationContext.configure( connection = cls.bind.connect(), - compare_type = True, - compare_server_default = True, + opts = { + 'compare_type':True, + 'compare_server_default':True + } ) - connection = context.get_bind() + connection = context.bind cls.autogen_context = { 'imports':set(), 'connection':connection, 'dialect':connection.dialect, - 'context':context.get_context() + 'context':context } @classmethod @@ -145,7 +148,7 @@ class PostgresqlDefaultCompareTest(TestCase): t1.create(self.bind) insp = Inspector.from_engine(self.bind) cols = insp.get_columns(t1.name) - ctx = context.get_context() + ctx = self.autogen_context['context'] return ctx.impl.compare_server_default( cols[0], col, diff --git a/tests/test_revision_paths.py b/tests/test_revision_paths.py index b4bff6e0..fd09a85b 100644 --- a/tests/test_revision_paths.py +++ b/tests/test_revision_paths.py @@ -19,7 +19,7 @@ def teardown(): def test_upgrade_path(): eq_( - env.upgrade_from(e.revision, c.revision), + env.upgrade_from(e.revision, c.revision, None), [ (d.module.upgrade, c.revision, d.revision), (e.module.upgrade, d.revision, e.revision), @@ -27,7 +27,7 @@ def test_upgrade_path(): ) eq_( - env.upgrade_from(c.revision, None), + env.upgrade_from(c.revision, None, None), [ (a.module.upgrade, None, a.revision), (b.module.upgrade, a.revision, b.revision), @@ -38,7 +38,7 @@ def test_upgrade_path(): def test_downgrade_path(): eq_( - env.downgrade_to(c.revision, e.revision), + env.downgrade_to(c.revision, e.revision, None), [ (e.module.downgrade, e.revision, e.down_revision), (d.module.downgrade, d.revision, d.down_revision), @@ -46,7 +46,7 @@ def test_downgrade_path(): ) eq_( - env.downgrade_to(None, c.revision), + env.downgrade_to(None, c.revision, None), [ (c.module.downgrade, c.revision, c.down_revision), (b.module.downgrade, b.revision, b.down_revision), diff --git a/tests/test_sql_script.py b/tests/test_sql_script.py index ab86f19b..af9b1361 100644 --- a/tests/test_sql_script.py +++ b/tests/test_sql_script.py @@ -9,7 +9,8 @@ def setup(): global cfg, env env = staging_env() cfg = _no_sql_testing_config() - + cfg.set_main_option('dialect_name', 'sqlite') + cfg.remove_main_option('url') global a, b, c a, b, c = three_rev_fixture(cfg)