]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- refactor the migration operations out of context, which
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Nov 2011 00:19:11 +0000 (19:19 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Nov 2011 00:19:11 +0000 (19:19 -0500)
mediates at a high level, into ddl/impl, which deals with DB stuff
- fix MSSQL add column, #2

13 files changed:
.hgignore
alembic/context.py
alembic/ddl/__init__.py
alembic/ddl/impl.py [new file with mode: 0644]
alembic/ddl/mssql.py
alembic/ddl/mysql.py
alembic/ddl/postgresql.py
alembic/ddl/sqlite.py
alembic/op.py
docs/build/api.rst
docs/build/tutorial.rst
tests/__init__.py
tests/test_mssql.py [new file with mode: 0644]

index 89d44e2eee45bbdb68aea5dcda7fb9524f37682a..b8ad7f29976ec32d27a504c68447293b0640391b 100644 (file)
--- a/.hgignore
+++ b/.hgignore
@@ -5,4 +5,5 @@ syntax:regexp
 .pyc$
 .orig$
 .egg-info
-
+.coverage
+alembic.ini
index a917c158a241a86b8661f2af2a3a5730d45ea281..b5f0e22aea801610e0dc752b8193b5f0164444a9 100644 (file)
@@ -1,37 +1,26 @@
 from alembic import util
 from sqlalchemy import MetaData, Table, Column, String, literal_column, \
     text
-from sqlalchemy import schema, create_engine
+from sqlalchemy import create_engine
 from sqlalchemy.engine import url as sqla_url
-from sqlalchemy.ext.compiler import compiles
-from sqlalchemy.sql.expression import _BindParamClause
 import sys
+from alembic import ddl
 
 import logging
-base = util.importlater("alembic.ddl", "base")
 log = logging.getLogger(__name__)
 
-class ContextMeta(type):
-    def __init__(cls, classname, bases, dict_):
-        newtype = type.__init__(cls, classname, bases, dict_)
-        if '__dialect__' in dict_:
-            _context_impls[dict_['__dialect__']] = cls
-        return newtype
-
-_context_impls = {}
-
 _meta = MetaData()
 _version = Table('alembic_version', _meta, 
                 Column('version_num', String(32), nullable=False)
             )
 
-class DefaultContext(object):
-    __metaclass__ = ContextMeta
-    __dialect__ = 'default'
-
-    transactional_ddl = False
-    as_sql = False
-
+class Context(object):
+    """Maintains state throughout the migration running process.
+    
+    Mediates the relationship between an ``env.py`` environment script, 
+    a :class:`.ScriptDirectory` instance, and a :class:`.DDLImpl` instance.
+    
+    """
     def __init__(self, dialect, script, connection, fn, 
                         as_sql=False, 
                         output_buffer=None,
@@ -46,13 +35,14 @@ class DefaultContext(object):
             self.connection = connection
         self._migrations_fn = fn
         self.as_sql = as_sql
-        if output_buffer is None:
-            self.output_buffer = sys.stdout
-        else:
-            self.output_buffer = output_buffer
-        if transactional_ddl is not None:
-            self.transactional_ddl = transactional_ddl
+        self.output_buffer = output_buffer if output_buffer else sys.stdout
+
         self._start_from_rev = starting_rev
+        self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
+                            dialect, connection, self.as_sql,
+                            transactional_ddl,
+                            self.output_buffer
+                            )
 
     def _current_rev(self):
         if self.as_sql:
@@ -69,13 +59,13 @@ class DefaultContext(object):
         if old == new:
             return
         if new is None:
-            self._exec(_version.delete())
+            self.impl._exec(_version.delete())
         elif old is None:
-            self._exec(_version.insert().
+            self.impl._exec(_version.insert().
                         values(version_num=literal_column("'%s'" % new))
                     )
         else:
-            self._exec(_version.update().
+            self.impl._exec(_version.update().
                         values(version_num=literal_column("'%s'" % new))
                     )
 
@@ -84,11 +74,11 @@ class DefaultContext(object):
         if self.as_sql:
             log.info("Generating static SQL")
         log.info("Will assume %s DDL.", 
-                        "transactional" if self.transactional_ddl 
+                        "transactional" if self.impl.transactional_ddl 
                         else "non-transactional")
 
-        if self.as_sql and self.transactional_ddl:
-            self.static_output("BEGIN;")
+        if self.as_sql and self.impl.transactional_ddl:
+            self.impl.static_output("BEGIN;")
 
         current_rev = rev = False
         for change, prev_rev, rev in self._migrations_fn(
@@ -99,42 +89,26 @@ class DefaultContext(object):
                     _version.create(self.connection)
             log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
             change(**kw)
-            if not self.transactional_ddl:
+            if not self.impl.transactional_ddl:
                 self._update_current_rev(prev_rev, rev)
             prev_rev = rev
 
         if rev is not False:
-            if self.transactional_ddl:
+            if self.impl.transactional_ddl:
                 self._update_current_rev(current_rev, rev)
 
             if self.as_sql and not rev:
                 _version.drop(self.connection)
 
-        if self.as_sql and self.transactional_ddl:
-            self.static_output("COMMIT;")
-
-    def _exec(self, construct, *args, **kw):
-        if isinstance(construct, basestring):
-            construct = text(construct)
-        if self.as_sql:
-            if args or kw:
-                # TODO: coverage
-                raise Exception("Execution arguments not allowed with as_sql")
-            self.static_output(unicode(
-                    construct.compile(dialect=self.dialect)
-                    ).replace("\t", "    ").strip() + ";")
-        else:
-            self.connection.execute(construct, *args, **kw)
-
-    def static_output(self, text):
-        self.output_buffer.write(text + "\n\n")
+        if self.as_sql and self.impl.transactional_ddl:
+            self.impl.static_output("COMMIT;")
 
     def execute(self, sql):
-        self._exec(sql)
+        self.impl._exec(sql)
 
     def _stdout_connection(self, connection):
         def dump(construct, *multiparams, **params):
-            self._exec(construct)
+            self.impl._exec(construct)
 
         return create_engine("%s://" % self.dialect.name, 
                         strategy="mock", executor=dump)
@@ -151,60 +125,6 @@ class DefaultContext(object):
         """
         return self.connection
 
-    def alter_column(self, table_name, column_name, 
-                        nullable=None,
-                        server_default=False,
-                        name=None,
-                        type_=None,
-                        schema=None,
-    ):
-
-        if nullable is not None:
-            self._exec(base.ColumnNullable(table_name, column_name, 
-                                nullable, schema=schema))
-        if server_default is not False:
-            self._exec(base.ColumnDefault(
-                                table_name, column_name, server_default,
-                                schema=schema
-                            ))
-        if type_ is not None:
-            self._exec(base.ColumnType(
-                                table_name, column_name, type_, schema=schema
-                            ))
-
-    def add_column(self, table_name, column):
-        self._exec(base.AddColumn(table_name, column))
-
-    def drop_column(self, table_name, column):
-        self._exec(base.DropColumn(table_name, column))
-
-    def add_constraint(self, const):
-        self._exec(schema.AddConstraint(const))
-
-    def create_table(self, table):
-        self._exec(schema.CreateTable(table))
-        for index in table.indexes:
-            self._exec(schema.CreateIndex(index))
-
-    def drop_table(self, table):
-        self._exec(schema.DropTable(table))
-
-    def bulk_insert(self, table, rows):
-        if self.as_sql:
-            for row in rows:
-                self._exec(table.insert().values(**dict(
-                    (k, _literal_bindparam(k, v, type_=table.c[k].type))
-                    for k, v in row.items()
-                )))
-        else:
-            self._exec(table.insert(), *rows)
-
-class _literal_bindparam(_BindParamClause):
-    pass
-
-@compiles(_literal_bindparam)
-def _render_literal_bindparam(element, compiler, **kw):
-    return compiler.render_literal_bindparam(element, **kw)
 
 _context_opts = {}
 _context = None
@@ -323,7 +243,6 @@ def configure(
         raise Exception("Connection, url, or dialect_name is required.")
 
     global _context
-    from alembic.ddl import base
     opts = _context_opts
     if transactional_ddl is not None:
         opts["transactional_ddl"] =  transactional_ddl
@@ -333,9 +252,7 @@ def configure(
         opts['starting_rev'] = starting_rev
     if tag:
         opts['tag'] = tag
-    _context = _context_impls.get(
-                    dialect.name, 
-                    DefaultContext)(
+    _context = Context(
                         dialect, _script, connection, 
                         opts['fn'],
                         as_sql=opts.get('as_sql', False), 
@@ -363,7 +280,7 @@ def run_migrations(**kw):
     to the migration functions.
     
     """
-    _context.run_migrations(**kw)
+    get_context().run_migrations(**kw)
 
 def execute(sql):
     """Execute the given SQL using the current change context.
@@ -385,4 +302,7 @@ def get_context():
     """
     if _context is None:
         raise Exception("No context has been configured yet.")
-    return _context
\ No newline at end of file
+    return _context
+
+def get_impl():
+    return get_context().impl
\ No newline at end of file
index 7efc90cb5259702f32439dc5dca53fb66fade080..128b14cd33c20473c671794ffacfeaaef5ab801c 100644 (file)
@@ -1 +1,2 @@
-import postgresql, mysql, sqlite, mssql
\ No newline at end of file
+import postgresql, mysql, sqlite, mssql
+from impl import DefaultImpl
\ No newline at end of file
diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py
new file mode 100644 (file)
index 0000000..f1dd4a7
--- /dev/null
@@ -0,0 +1,119 @@
+from sqlalchemy import text
+from sqlalchemy.sql.expression import _BindParamClause
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy import schema
+from alembic.ddl import base
+
+class ImplMeta(type):
+    def __init__(cls, classname, bases, dict_):
+        newtype = type.__init__(cls, classname, bases, dict_)
+        if '__dialect__' in dict_:
+            _impls[dict_['__dialect__']] = cls
+        return newtype
+
+_impls = {}
+
+class DefaultImpl(object):
+    """Provide the entrypoint for major migration operations,
+    including database-specific behavioral variances.
+    
+    While individual SQL/DDL constructs already provide
+    for database-specific implementations, variances here
+    allow for entirely different sequences of operations
+    to take place for a particular migration, such as
+    SQL Server's special 'IDENTITY INSERT' step for 
+    bulk inserts.
+
+    """
+    __metaclass__ = ImplMeta
+    __dialect__ = 'default'
+
+    transactional_ddl = False
+
+    def __init__(self, dialect, connection, as_sql, transactional_ddl, output_buffer):
+        self.dialect = dialect
+        self.connection = connection
+        self.as_sql = as_sql
+        self.output_buffer = output_buffer
+        if transactional_ddl is not None:
+            self.transactional_ddl = transactional_ddl
+
+    @classmethod
+    def get_by_dialect(cls, dialect):
+        return _impls[dialect.name]
+
+    def static_output(self, text):
+        self.output_buffer.write(text + "\n\n")
+
+    def _exec(self, construct, *args, **kw):
+        if isinstance(construct, basestring):
+            construct = text(construct)
+        if self.as_sql:
+            if args or kw:
+                # TODO: coverage
+                raise Exception("Execution arguments not allowed with as_sql")
+            self.static_output(unicode(
+                    construct.compile(dialect=self.dialect)
+                    ).replace("\t", "    ").strip() + ";")
+        else:
+            self.connection.execute(construct, *args, **kw)
+
+    def execute(self, sql):
+        self._exec(sql)
+
+    def alter_column(self, table_name, column_name, 
+                        nullable=None,
+                        server_default=False,
+                        name=None,
+                        type_=None,
+                        schema=None,
+    ):
+
+        if nullable is not None:
+            self._exec(base.ColumnNullable(table_name, column_name, 
+                                nullable, schema=schema))
+        if server_default is not False:
+            self._exec(base.ColumnDefault(
+                                table_name, column_name, server_default,
+                                schema=schema
+                            ))
+        if type_ is not None:
+            self._exec(base.ColumnType(
+                                table_name, column_name, type_, schema=schema
+                            ))
+
+    def add_column(self, table_name, column):
+        self._exec(base.AddColumn(table_name, column))
+
+    def drop_column(self, table_name, column):
+        self._exec(base.DropColumn(table_name, column))
+
+    def add_constraint(self, const):
+        self._exec(schema.AddConstraint(const))
+
+    def create_table(self, table):
+        self._exec(schema.CreateTable(table))
+        for index in table.indexes:
+            self._exec(schema.CreateIndex(index))
+
+    def drop_table(self, table):
+        self._exec(schema.DropTable(table))
+
+    def bulk_insert(self, table, rows):
+        if self.as_sql:
+            for row in rows:
+                self._exec(table.insert().values(**dict(
+                    (k, _literal_bindparam(k, v, type_=table.c[k].type))
+                    for k, v in row.items()
+                )))
+        else:
+            self._exec(table.insert(), *rows)
+
+
+class _literal_bindparam(_BindParamClause):
+    pass
+
+@compiles(_literal_bindparam)
+def _render_literal_bindparam(element, compiler, **kw):
+    return compiler.render_literal_bindparam(element, **kw)
+
index d79e61936384cdc35cb0686b8baa68ecf6554946..3c489e19b3b087b45a7b845efa2a6cc09f51f6ad 100644 (file)
@@ -1,6 +1,8 @@
-from alembic.context import DefaultContext
+from alembic.ddl.impl import DefaultImpl
+from alembic.ddl.base import alter_table, AddColumn
+from sqlalchemy.ext.compiler import compiles
 
-class MSSQLContext(DefaultContext):
+class MSSQLImpl(DefaultImpl):
     __dialect__ = 'mssql'
     transactional_ddl = True
 
@@ -10,10 +12,22 @@ class MSSQLContext(DefaultContext):
                 "SET IDENTITY_INSERT %s ON" % 
                     self.dialect.identifier_preparer.format_table(table)
             )
-            super(MSSQLContext, self).bulk_insert(table, rows)
+            super(MSSQLImpl, self).bulk_insert(table, rows)
             self._exec(
                 "SET IDENTITY_INSERT %s OFF" % 
                     self.dialect.identifier_preparer.format_table(table)
             )
         else:
-            super(MSSQLContext, self).bulk_insert(table, rows)
\ No newline at end of file
+            super(MSSQLImpl, self).bulk_insert(table, rows)
+
+
+@compiles(AddColumn, 'mssql')
+def visit_add_column(element, compiler, **kw):
+    return "%s %s" % (
+        alter_table(compiler, element.table_name, element.schema),
+        mysql_add_column(compiler, element.column, **kw)
+    )
+
+def mysql_add_column(compiler, column, **kw):
+    return "ADD %s" % compiler.get_column_specification(column, **kw)
+
index f7b7b30d62c030daf881ba8240c70d34050e4ee5..14abf261c65e46cf40735ed86320da8b277be070 100644 (file)
@@ -1,5 +1,5 @@
-from alembic.context import DefaultContext
+from alembic.ddl.impl import DefaultImpl
 
-class MySQLContext(DefaultContext):
+class MySQLImpl(DefaultImpl):
     __dialect__ = 'mysql'
 
index 79d6f1a042f9a6032156e9435716297a4ac2df1b..f6268424ec8c077952672ac981fa7e0d5e9b088e 100644 (file)
@@ -1,5 +1,5 @@
-from alembic.context import DefaultContext
+from alembic.ddl.impl import DefaultImpl
 
-class PostgresqlContext(DefaultContext):
+class PostgresqlImpl(DefaultImpl):
     __dialect__ = 'postgresql'
     transactional_ddl = True
index 20ec1ebac60a3f930111599a5e6eeb3b9c3218ed..094371370630976c449685fb95d9cc917ff6c50a 100644 (file)
@@ -1,5 +1,5 @@
-from alembic.context import DefaultContext
+from alembic.ddl.impl import DefaultImpl
 
-class SQLiteContext(DefaultContext):
+class SQLiteImpl(DefaultImpl):
     __dialect__ = 'sqlite'
     transactional_ddl = True
index 2e5e74f531dc5e4faabd564d989a6b35f3586e93..49631221b1d332fe61ccb5ef0a16d01ed019b4b1 100644 (file)
@@ -1,5 +1,5 @@
 from alembic import util
-from alembic.context import get_context
+from alembic.context import get_impl, get_context
 from sqlalchemy.types import NULLTYPE
 from sqlalchemy import schema, sql
 
@@ -91,7 +91,7 @@ def alter_column(table_name, column_name,
 ):
     """Issue an "alter column" instruction using the current change context."""
 
-    get_context().alter_column(table_name, column_name, 
+    get_impl().alter_column(table_name, column_name, 
         nullable=nullable,
         server_default=server_default,
         name=name,
@@ -110,12 +110,12 @@ def add_column(table_name, column):
     """
 
     t = _table(table_name, column)
-    get_context().add_column(
+    get_impl().add_column(
         table_name,
         column
     )
     for constraint in [f.constraint for f in t.foreign_keys]:
-        get_context().add_constraint(constraint)
+        get_impl().add_constraint(constraint)
 
 def drop_column(table_name, column_name):
     """Issue a "drop column" instruction using the current change context.
@@ -126,7 +126,7 @@ def drop_column(table_name, column_name):
     
     """
 
-    get_context().drop_column(
+    get_impl().drop_column(
         table_name,
         _column(column_name, NULLTYPE)
     )
@@ -135,14 +135,14 @@ def add_constraint(table_name, constraint):
     """Issue an "add constraint" instruction using the current change context."""
 
     _ensure_table_for_constraint(table_name, constraint)
-    get_context().add_constraint(
+    get_impl().add_constraint(
         constraint
     )
 
 def create_foreign_key(name, source, referent, local_cols, remote_cols):
     """Issue a "create foreign key" instruction using the current change context."""
 
-    get_context().add_constraint(
+    get_impl().add_constraint(
                 _foreign_key_constraint(name, source, referent, 
                         local_cols, remote_cols)
             )
@@ -150,7 +150,7 @@ def create_foreign_key(name, source, referent, local_cols, remote_cols):
 def create_unique_constraint(name, source, local_cols):
     """Issue a "create unique constraint" instruction using the current change context."""
 
-    get_context().add_constraint(
+    get_impl().add_constraint(
                 _unique_constraint(name, source, local_cols)
             )
 
@@ -173,7 +173,7 @@ def create_table(name, *columns, **kw):
 
     """
 
-    get_context().create_table(
+    get_impl().create_table(
         _table(name, *columns, **kw)
     )
 
@@ -186,7 +186,7 @@ def drop_table(name, *columns, **kw):
         drop_table("accounts")
         
     """
-    get_context().drop_table(
+    get_impl().drop_table(
         _table(name, *columns, **kw)
     )
 
@@ -212,7 +212,7 @@ def bulk_insert(table, rows):
             ]
         )
       """
-    get_context().bulk_insert(table, rows)
+    get_impl().bulk_insert(table, rows)
 
 def execute(sql):
     """Execute the given SQL using the current change context.
@@ -221,7 +221,7 @@ def execute(sql):
     output stream.
     
     """
-    get_context().execute(sql)
+    get_impl().execute(sql)
 
 def get_bind():
     """Return the current 'bind'.
@@ -233,4 +233,4 @@ def get_bind():
     In a SQL script context, this value is ``None``. [TODO: verify this]
     
     """
-    return get_context().bind
\ No newline at end of file
+    return get_impl().bind
\ No newline at end of file
index ef442342caab79ce2a9c6f4b492bc5124fb10ea1..411600a1d6d091643746a3c6ec8516b0de72bb18 100644 (file)
@@ -54,6 +54,10 @@ DDL Internals
     :members:
     :undoc-members:
 
+.. automodule:: alembic.ddl.impl
+    :members:
+    :undoc-members:
+
 MySQL
 ^^^^^
 
index 4b64ea11d1a584a015d0b2037bc0f9dc92e50d78..652ed7b4d9ceec85a0d86a12f25248ca43f92721 100644 (file)
@@ -517,12 +517,16 @@ the local environment, such as from a local file.   A scheme like this would bas
 treat a local file in the same way ``alembic_version`` works::
 
     if not context.requires_connection():
-        version_file = os.path.join(os.path.dirname(config.config_file_name), "version.txt"))
-        current_version = file_(version_file).read()
+        version_file = os.path.join(os.path.dirname(config.config_file_name), "version.txt")
+        if os.path.exists(version_file):
+            current_version = file_(version_file).read()
+        else:
+            current_version = None
         context.configure(dialect_name=engine.name, starting_version=current_version)
-        end_version = context.get_revision_argument()
         context.run_migrations()
-        file_(version_file, 'w').write(end)
+        end_version = context.get_revision_argument()
+        if end_version and end_version != current_version:
+            file_(version_file, 'w').write(end_version)
 
 Writing Migration Scripts to Support Script Generation
 ------------------------------------------------------
index 68c7222056e38007088c777620caa8ff6a0e2abe..462e3d54efe8b82341e2516fadf40c3589be3009 100644 (file)
@@ -6,9 +6,10 @@ from sqlalchemy import create_engine, text
 from alembic import context, util
 import re
 from alembic.script import ScriptDirectory
-from alembic.context import _context_impls
+from alembic.context import Context
 from alembic import ddl
 import StringIO
+from alembic.ddl.impl import _impls
 
 staging_directory = os.path.join(os.path.dirname(__file__), 'scratch')
 files_directory = os.path.join(os.path.dirname(__file__), 'files')
@@ -70,18 +71,12 @@ def _testing_config():
     return Config(os.path.join(staging_directory, 'test_alembic.ini'))
 
 def _op_fixture(dialect='default', as_sql=False):
-    _base = _context_impls[dialect]
-    class ctx(_base):
-        def __init__(self, dialect='default', as_sql=False):
-            self._dialect = _get_dialect(dialect)
-
-            context._context = self
-            self.as_sql = as_sql
+    impl = _impls[dialect]
+    class Impl(impl):
+        def __init__(self, dialect, as_sql):
             self.assertion = []
-
-        @property
-        def dialect(self):
-            return self._dialect
+            self.dialect = dialect
+            self.as_sql = as_sql
 
         def _exec(self, construct, *args, **kw):
             if isinstance(construct, basestring):
@@ -92,11 +87,28 @@ def _op_fixture(dialect='default', as_sql=False):
                 sql
             )
 
+
+    class ctx(Context):
+        def __init__(self, dialect='default', as_sql=False):
+            self.dialect = _get_dialect(dialect)
+            self.impl = Impl(self.dialect, as_sql)
+#            super(ctx, self).__init__(_get_dialect(dialect), None, None, None, as_sql=as_sql)
+
+#    def __init__(self, dialect, script, connection, fn, 
+#                        as_sql=False, 
+#                       output_buffer=None,
+#                        transactional_ddl=None,
+#                        starting_rev=None):
+
+
+            context._context = self
+            self.as_sql = as_sql
+
         def assert_(self, *sql):
             # TODO: make this more flexible about 
             # whitespace and such
-            eq_(self.assertion, list(sql))
-    _context_impls[dialect] = _base
+            eq_(self.impl.assertion, list(sql))
+
     return ctx(dialect, as_sql)
 
 def _sqlite_testing_config():
diff --git a/tests/test_mssql.py b/tests/test_mssql.py
new file mode 100644 (file)
index 0000000..1cb7465
--- /dev/null
@@ -0,0 +1,17 @@
+"""Test op functions against MSSQL."""
+
+from tests import _op_fixture
+from alembic import op
+from sqlalchemy import Integer, Column, ForeignKey, \
+            UniqueConstraint, Table, MetaData, String
+from sqlalchemy.sql import table
+
+def test_add_column():
+    context = _op_fixture('mssql')
+    op.add_column('t1', Column('c1', Integer, nullable=False))
+    context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL")
+
+def test_add_column_with_default():
+    context = _op_fixture("mssql")
+    op.add_column('t1', Column('c1', Integer, nullable=False, server_default="12"))
+    context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL DEFAULT '12'")