From: Mike Bayer Date: Tue, 15 Nov 2011 00:19:11 +0000 (-0500) Subject: - refactor the migration operations out of context, which X-Git-Tag: rel_0_1_0~57 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7faa084be96698729f43cff531e18605fe8ec9fe;p=thirdparty%2Fsqlalchemy%2Falembic.git - refactor the migration operations out of context, which mediates at a high level, into ddl/impl, which deals with DB stuff - fix MSSQL add column, #2 --- diff --git a/.hgignore b/.hgignore index 89d44e2e..b8ad7f29 100644 --- a/.hgignore +++ b/.hgignore @@ -5,4 +5,5 @@ syntax:regexp .pyc$ .orig$ .egg-info - +.coverage +alembic.ini diff --git a/alembic/context.py b/alembic/context.py index a917c158..b5f0e22a 100644 --- a/alembic/context.py +++ b/alembic/context.py @@ -1,37 +1,26 @@ from alembic import util from sqlalchemy import MetaData, Table, Column, String, literal_column, \ text -from sqlalchemy import schema, create_engine +from sqlalchemy import create_engine from sqlalchemy.engine import url as sqla_url -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql.expression import _BindParamClause import sys +from alembic import ddl import logging -base = util.importlater("alembic.ddl", "base") log = logging.getLogger(__name__) -class ContextMeta(type): - def __init__(cls, classname, bases, dict_): - newtype = type.__init__(cls, classname, bases, dict_) - if '__dialect__' in dict_: - _context_impls[dict_['__dialect__']] = cls - return newtype - -_context_impls = {} - _meta = MetaData() _version = Table('alembic_version', _meta, Column('version_num', String(32), nullable=False) ) -class DefaultContext(object): - __metaclass__ = ContextMeta - __dialect__ = 'default' - - transactional_ddl = False - as_sql = False - +class Context(object): + """Maintains state throughout the migration running process. + + Mediates the relationship between an ``env.py`` environment script, + a :class:`.ScriptDirectory` instance, and a :class:`.DDLImpl` instance. + + """ def __init__(self, dialect, script, connection, fn, as_sql=False, output_buffer=None, @@ -46,13 +35,14 @@ class DefaultContext(object): self.connection = connection self._migrations_fn = fn self.as_sql = as_sql - if output_buffer is None: - self.output_buffer = sys.stdout - else: - self.output_buffer = output_buffer - if transactional_ddl is not None: - self.transactional_ddl = transactional_ddl + self.output_buffer = output_buffer if output_buffer else sys.stdout + self._start_from_rev = starting_rev + self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( + dialect, connection, self.as_sql, + transactional_ddl, + self.output_buffer + ) def _current_rev(self): if self.as_sql: @@ -69,13 +59,13 @@ class DefaultContext(object): if old == new: return if new is None: - self._exec(_version.delete()) + self.impl._exec(_version.delete()) elif old is None: - self._exec(_version.insert(). + self.impl._exec(_version.insert(). values(version_num=literal_column("'%s'" % new)) ) else: - self._exec(_version.update(). + self.impl._exec(_version.update(). values(version_num=literal_column("'%s'" % new)) ) @@ -84,11 +74,11 @@ class DefaultContext(object): if self.as_sql: log.info("Generating static SQL") log.info("Will assume %s DDL.", - "transactional" if self.transactional_ddl + "transactional" if self.impl.transactional_ddl else "non-transactional") - if self.as_sql and self.transactional_ddl: - self.static_output("BEGIN;") + if self.as_sql and self.impl.transactional_ddl: + self.impl.static_output("BEGIN;") current_rev = rev = False for change, prev_rev, rev in self._migrations_fn( @@ -99,42 +89,26 @@ class DefaultContext(object): _version.create(self.connection) log.info("Running %s %s -> %s", change.__name__, prev_rev, rev) change(**kw) - if not self.transactional_ddl: + if not self.impl.transactional_ddl: self._update_current_rev(prev_rev, rev) prev_rev = rev if rev is not False: - if self.transactional_ddl: + if self.impl.transactional_ddl: self._update_current_rev(current_rev, rev) if self.as_sql and not rev: _version.drop(self.connection) - if self.as_sql and self.transactional_ddl: - self.static_output("COMMIT;") - - def _exec(self, construct, *args, **kw): - if isinstance(construct, basestring): - construct = text(construct) - if self.as_sql: - if args or kw: - # TODO: coverage - raise Exception("Execution arguments not allowed with as_sql") - self.static_output(unicode( - construct.compile(dialect=self.dialect) - ).replace("\t", " ").strip() + ";") - else: - self.connection.execute(construct, *args, **kw) - - def static_output(self, text): - self.output_buffer.write(text + "\n\n") + if self.as_sql and self.impl.transactional_ddl: + self.impl.static_output("COMMIT;") def execute(self, sql): - self._exec(sql) + self.impl._exec(sql) def _stdout_connection(self, connection): def dump(construct, *multiparams, **params): - self._exec(construct) + self.impl._exec(construct) return create_engine("%s://" % self.dialect.name, strategy="mock", executor=dump) @@ -151,60 +125,6 @@ class DefaultContext(object): """ return self.connection - def alter_column(self, table_name, column_name, - nullable=None, - server_default=False, - name=None, - type_=None, - schema=None, - ): - - if nullable is not None: - self._exec(base.ColumnNullable(table_name, column_name, - nullable, schema=schema)) - if server_default is not False: - self._exec(base.ColumnDefault( - table_name, column_name, server_default, - schema=schema - )) - if type_ is not None: - self._exec(base.ColumnType( - table_name, column_name, type_, schema=schema - )) - - def add_column(self, table_name, column): - self._exec(base.AddColumn(table_name, column)) - - def drop_column(self, table_name, column): - self._exec(base.DropColumn(table_name, column)) - - def add_constraint(self, const): - self._exec(schema.AddConstraint(const)) - - def create_table(self, table): - self._exec(schema.CreateTable(table)) - for index in table.indexes: - self._exec(schema.CreateIndex(index)) - - def drop_table(self, table): - self._exec(schema.DropTable(table)) - - def bulk_insert(self, table, rows): - if self.as_sql: - for row in rows: - self._exec(table.insert().values(**dict( - (k, _literal_bindparam(k, v, type_=table.c[k].type)) - for k, v in row.items() - ))) - else: - self._exec(table.insert(), *rows) - -class _literal_bindparam(_BindParamClause): - pass - -@compiles(_literal_bindparam) -def _render_literal_bindparam(element, compiler, **kw): - return compiler.render_literal_bindparam(element, **kw) _context_opts = {} _context = None @@ -323,7 +243,6 @@ def configure( raise Exception("Connection, url, or dialect_name is required.") global _context - from alembic.ddl import base opts = _context_opts if transactional_ddl is not None: opts["transactional_ddl"] = transactional_ddl @@ -333,9 +252,7 @@ def configure( opts['starting_rev'] = starting_rev if tag: opts['tag'] = tag - _context = _context_impls.get( - dialect.name, - DefaultContext)( + _context = Context( dialect, _script, connection, opts['fn'], as_sql=opts.get('as_sql', False), @@ -363,7 +280,7 @@ def run_migrations(**kw): to the migration functions. """ - _context.run_migrations(**kw) + get_context().run_migrations(**kw) def execute(sql): """Execute the given SQL using the current change context. @@ -385,4 +302,7 @@ def get_context(): """ if _context is None: raise Exception("No context has been configured yet.") - return _context \ No newline at end of file + return _context + +def get_impl(): + return get_context().impl \ No newline at end of file diff --git a/alembic/ddl/__init__.py b/alembic/ddl/__init__.py index 7efc90cb..128b14cd 100644 --- a/alembic/ddl/__init__.py +++ b/alembic/ddl/__init__.py @@ -1 +1,2 @@ -import postgresql, mysql, sqlite, mssql \ No newline at end of file +import postgresql, mysql, sqlite, mssql +from impl import DefaultImpl \ No newline at end of file diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py new file mode 100644 index 00000000..f1dd4a70 --- /dev/null +++ b/alembic/ddl/impl.py @@ -0,0 +1,119 @@ +from sqlalchemy import text +from sqlalchemy.sql.expression import _BindParamClause +from sqlalchemy.ext.compiler import compiles +from sqlalchemy import schema +from alembic.ddl import base + +class ImplMeta(type): + def __init__(cls, classname, bases, dict_): + newtype = type.__init__(cls, classname, bases, dict_) + if '__dialect__' in dict_: + _impls[dict_['__dialect__']] = cls + return newtype + +_impls = {} + +class DefaultImpl(object): + """Provide the entrypoint for major migration operations, + including database-specific behavioral variances. + + While individual SQL/DDL constructs already provide + for database-specific implementations, variances here + allow for entirely different sequences of operations + to take place for a particular migration, such as + SQL Server's special 'IDENTITY INSERT' step for + bulk inserts. + + """ + __metaclass__ = ImplMeta + __dialect__ = 'default' + + transactional_ddl = False + + def __init__(self, dialect, connection, as_sql, transactional_ddl, output_buffer): + self.dialect = dialect + self.connection = connection + self.as_sql = as_sql + self.output_buffer = output_buffer + if transactional_ddl is not None: + self.transactional_ddl = transactional_ddl + + @classmethod + def get_by_dialect(cls, dialect): + return _impls[dialect.name] + + def static_output(self, text): + self.output_buffer.write(text + "\n\n") + + def _exec(self, construct, *args, **kw): + if isinstance(construct, basestring): + construct = text(construct) + if self.as_sql: + if args or kw: + # TODO: coverage + raise Exception("Execution arguments not allowed with as_sql") + self.static_output(unicode( + construct.compile(dialect=self.dialect) + ).replace("\t", " ").strip() + ";") + else: + self.connection.execute(construct, *args, **kw) + + def execute(self, sql): + self._exec(sql) + + def alter_column(self, table_name, column_name, + nullable=None, + server_default=False, + name=None, + type_=None, + schema=None, + ): + + if nullable is not None: + self._exec(base.ColumnNullable(table_name, column_name, + nullable, schema=schema)) + if server_default is not False: + self._exec(base.ColumnDefault( + table_name, column_name, server_default, + schema=schema + )) + if type_ is not None: + self._exec(base.ColumnType( + table_name, column_name, type_, schema=schema + )) + + def add_column(self, table_name, column): + self._exec(base.AddColumn(table_name, column)) + + def drop_column(self, table_name, column): + self._exec(base.DropColumn(table_name, column)) + + def add_constraint(self, const): + self._exec(schema.AddConstraint(const)) + + def create_table(self, table): + self._exec(schema.CreateTable(table)) + for index in table.indexes: + self._exec(schema.CreateIndex(index)) + + def drop_table(self, table): + self._exec(schema.DropTable(table)) + + def bulk_insert(self, table, rows): + if self.as_sql: + for row in rows: + self._exec(table.insert().values(**dict( + (k, _literal_bindparam(k, v, type_=table.c[k].type)) + for k, v in row.items() + ))) + else: + self._exec(table.insert(), *rows) + + +class _literal_bindparam(_BindParamClause): + pass + +@compiles(_literal_bindparam) +def _render_literal_bindparam(element, compiler, **kw): + return compiler.render_literal_bindparam(element, **kw) + diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index d79e6193..3c489e19 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -1,6 +1,8 @@ -from alembic.context import DefaultContext +from alembic.ddl.impl import DefaultImpl +from alembic.ddl.base import alter_table, AddColumn +from sqlalchemy.ext.compiler import compiles -class MSSQLContext(DefaultContext): +class MSSQLImpl(DefaultImpl): __dialect__ = 'mssql' transactional_ddl = True @@ -10,10 +12,22 @@ class MSSQLContext(DefaultContext): "SET IDENTITY_INSERT %s ON" % self.dialect.identifier_preparer.format_table(table) ) - super(MSSQLContext, self).bulk_insert(table, rows) + super(MSSQLImpl, self).bulk_insert(table, rows) self._exec( "SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(table) ) else: - super(MSSQLContext, self).bulk_insert(table, rows) \ No newline at end of file + super(MSSQLImpl, self).bulk_insert(table, rows) + + +@compiles(AddColumn, 'mssql') +def visit_add_column(element, compiler, **kw): + return "%s %s" % ( + alter_table(compiler, element.table_name, element.schema), + mysql_add_column(compiler, element.column, **kw) + ) + +def mysql_add_column(compiler, column, **kw): + return "ADD %s" % compiler.get_column_specification(column, **kw) + diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py index f7b7b30d..14abf261 100644 --- a/alembic/ddl/mysql.py +++ b/alembic/ddl/mysql.py @@ -1,5 +1,5 @@ -from alembic.context import DefaultContext +from alembic.ddl.impl import DefaultImpl -class MySQLContext(DefaultContext): +class MySQLImpl(DefaultImpl): __dialect__ = 'mysql' diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 79d6f1a0..f6268424 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -1,5 +1,5 @@ -from alembic.context import DefaultContext +from alembic.ddl.impl import DefaultImpl -class PostgresqlContext(DefaultContext): +class PostgresqlImpl(DefaultImpl): __dialect__ = 'postgresql' transactional_ddl = True diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py index 20ec1eba..09437137 100644 --- a/alembic/ddl/sqlite.py +++ b/alembic/ddl/sqlite.py @@ -1,5 +1,5 @@ -from alembic.context import DefaultContext +from alembic.ddl.impl import DefaultImpl -class SQLiteContext(DefaultContext): +class SQLiteImpl(DefaultImpl): __dialect__ = 'sqlite' transactional_ddl = True diff --git a/alembic/op.py b/alembic/op.py index 2e5e74f5..49631221 100644 --- a/alembic/op.py +++ b/alembic/op.py @@ -1,5 +1,5 @@ from alembic import util -from alembic.context import get_context +from alembic.context import get_impl, get_context from sqlalchemy.types import NULLTYPE from sqlalchemy import schema, sql @@ -91,7 +91,7 @@ def alter_column(table_name, column_name, ): """Issue an "alter column" instruction using the current change context.""" - get_context().alter_column(table_name, column_name, + get_impl().alter_column(table_name, column_name, nullable=nullable, server_default=server_default, name=name, @@ -110,12 +110,12 @@ def add_column(table_name, column): """ t = _table(table_name, column) - get_context().add_column( + get_impl().add_column( table_name, column ) for constraint in [f.constraint for f in t.foreign_keys]: - get_context().add_constraint(constraint) + get_impl().add_constraint(constraint) def drop_column(table_name, column_name): """Issue a "drop column" instruction using the current change context. @@ -126,7 +126,7 @@ def drop_column(table_name, column_name): """ - get_context().drop_column( + get_impl().drop_column( table_name, _column(column_name, NULLTYPE) ) @@ -135,14 +135,14 @@ def add_constraint(table_name, constraint): """Issue an "add constraint" instruction using the current change context.""" _ensure_table_for_constraint(table_name, constraint) - get_context().add_constraint( + get_impl().add_constraint( constraint ) def create_foreign_key(name, source, referent, local_cols, remote_cols): """Issue a "create foreign key" instruction using the current change context.""" - get_context().add_constraint( + get_impl().add_constraint( _foreign_key_constraint(name, source, referent, local_cols, remote_cols) ) @@ -150,7 +150,7 @@ def create_foreign_key(name, source, referent, local_cols, remote_cols): def create_unique_constraint(name, source, local_cols): """Issue a "create unique constraint" instruction using the current change context.""" - get_context().add_constraint( + get_impl().add_constraint( _unique_constraint(name, source, local_cols) ) @@ -173,7 +173,7 @@ def create_table(name, *columns, **kw): """ - get_context().create_table( + get_impl().create_table( _table(name, *columns, **kw) ) @@ -186,7 +186,7 @@ def drop_table(name, *columns, **kw): drop_table("accounts") """ - get_context().drop_table( + get_impl().drop_table( _table(name, *columns, **kw) ) @@ -212,7 +212,7 @@ def bulk_insert(table, rows): ] ) """ - get_context().bulk_insert(table, rows) + get_impl().bulk_insert(table, rows) def execute(sql): """Execute the given SQL using the current change context. @@ -221,7 +221,7 @@ def execute(sql): output stream. """ - get_context().execute(sql) + get_impl().execute(sql) def get_bind(): """Return the current 'bind'. @@ -233,4 +233,4 @@ def get_bind(): In a SQL script context, this value is ``None``. [TODO: verify this] """ - return get_context().bind \ No newline at end of file + return get_impl().bind \ No newline at end of file diff --git a/docs/build/api.rst b/docs/build/api.rst index ef442342..411600a1 100644 --- a/docs/build/api.rst +++ b/docs/build/api.rst @@ -54,6 +54,10 @@ DDL Internals :members: :undoc-members: +.. automodule:: alembic.ddl.impl + :members: + :undoc-members: + MySQL ^^^^^ diff --git a/docs/build/tutorial.rst b/docs/build/tutorial.rst index 4b64ea11..652ed7b4 100644 --- a/docs/build/tutorial.rst +++ b/docs/build/tutorial.rst @@ -517,12 +517,16 @@ the local environment, such as from a local file. A scheme like this would bas treat a local file in the same way ``alembic_version`` works:: if not context.requires_connection(): - version_file = os.path.join(os.path.dirname(config.config_file_name), "version.txt")) - current_version = file_(version_file).read() + version_file = os.path.join(os.path.dirname(config.config_file_name), "version.txt") + if os.path.exists(version_file): + current_version = file_(version_file).read() + else: + current_version = None context.configure(dialect_name=engine.name, starting_version=current_version) - end_version = context.get_revision_argument() context.run_migrations() - file_(version_file, 'w').write(end) + end_version = context.get_revision_argument() + if end_version and end_version != current_version: + file_(version_file, 'w').write(end_version) Writing Migration Scripts to Support Script Generation ------------------------------------------------------ diff --git a/tests/__init__.py b/tests/__init__.py index 68c72220..462e3d54 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -6,9 +6,10 @@ from sqlalchemy import create_engine, text from alembic import context, util import re from alembic.script import ScriptDirectory -from alembic.context import _context_impls +from alembic.context import Context from alembic import ddl import StringIO +from alembic.ddl.impl import _impls staging_directory = os.path.join(os.path.dirname(__file__), 'scratch') files_directory = os.path.join(os.path.dirname(__file__), 'files') @@ -70,18 +71,12 @@ def _testing_config(): return Config(os.path.join(staging_directory, 'test_alembic.ini')) def _op_fixture(dialect='default', as_sql=False): - _base = _context_impls[dialect] - class ctx(_base): - def __init__(self, dialect='default', as_sql=False): - self._dialect = _get_dialect(dialect) - - context._context = self - self.as_sql = as_sql + impl = _impls[dialect] + class Impl(impl): + def __init__(self, dialect, as_sql): self.assertion = [] - - @property - def dialect(self): - return self._dialect + self.dialect = dialect + self.as_sql = as_sql def _exec(self, construct, *args, **kw): if isinstance(construct, basestring): @@ -92,11 +87,28 @@ def _op_fixture(dialect='default', as_sql=False): sql ) + + class ctx(Context): + def __init__(self, dialect='default', as_sql=False): + self.dialect = _get_dialect(dialect) + self.impl = Impl(self.dialect, as_sql) +# super(ctx, self).__init__(_get_dialect(dialect), None, None, None, as_sql=as_sql) + +# def __init__(self, dialect, script, connection, fn, +# as_sql=False, +# output_buffer=None, +# transactional_ddl=None, +# starting_rev=None): + + + context._context = self + self.as_sql = as_sql + def assert_(self, *sql): # TODO: make this more flexible about # whitespace and such - eq_(self.assertion, list(sql)) - _context_impls[dialect] = _base + eq_(self.impl.assertion, list(sql)) + return ctx(dialect, as_sql) def _sqlite_testing_config(): diff --git a/tests/test_mssql.py b/tests/test_mssql.py new file mode 100644 index 00000000..1cb74651 --- /dev/null +++ b/tests/test_mssql.py @@ -0,0 +1,17 @@ +"""Test op functions against MSSQL.""" + +from tests import _op_fixture +from alembic import op +from sqlalchemy import Integer, Column, ForeignKey, \ + UniqueConstraint, Table, MetaData, String +from sqlalchemy.sql import table + +def test_add_column(): + context = _op_fixture('mssql') + op.add_column('t1', Column('c1', Integer, nullable=False)) + context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL") + +def test_add_column_with_default(): + context = _op_fixture("mssql") + op.add_column('t1', Column('c1', Integer, nullable=False, server_default="12")) + context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL DEFAULT '12'")