]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- add begin_transaction() env.py helper. Emits the appropriate
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Nov 2011 20:50:35 +0000 (15:50 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Nov 2011 20:50:35 +0000 (15:50 -0500)
begin/commit pair regardless of context.
- add dialect support for BEGIN/COMMIT working corresponding
to backend.  Add implementation for SQL server.
- add tests for BEGIN/COMMIT , #11
- rework SQL server test suite for more classes of test
- fix test suite to clean up after a prior failed suite

alembic/context.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
alembic/templates/generic/env.py
alembic/templates/multidb/env.py
alembic/templates/pylons/env.py
tests/__init__.py
tests/test_mssql.py
tests/test_sql_script.py

index d9500e7bf895ddd105c04f69ba9ee9a172b9b55c..e737ed96a33baf1afc2c28fd19d14a676aa0cb5b 100644 (file)
@@ -3,8 +3,9 @@ from sqlalchemy import MetaData, Table, Column, String, literal_column, \
     text
 from sqlalchemy import create_engine
 from sqlalchemy.engine import url as sqla_url
-import sys
 from alembic import ddl
+import sys
+from contextlib import contextmanager
 
 import logging
 log = logging.getLogger(__name__)
@@ -95,7 +96,10 @@ class Context(object):
                     _version.create(self.connection)
             log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
             if self.as_sql:
-                self.impl.static_output("-- Running %s %s -> %s" %(change.__name__, prev_rev, rev))
+                self.impl.static_output(
+                        "-- Running %s %s -> %s" %
+                        (change.__name__, prev_rev, rev)
+                    )
             change(**kw)
             if not self.impl.transactional_ddl:
                 self._update_current_rev(prev_rev, rev)
@@ -500,6 +504,65 @@ def execute(sql):
     """
     get_context().execute(sql)
 
+def begin_transaction():
+    """Return a context manager that will 
+    enclose an operation within a "transaction",
+    as defined by the environment's offline
+    and transactional DDL settings.
+
+    e.g.::
+    
+        with context.begin_transaction():
+            context.run_migrations()
+    
+    :func:`.begin_transaction` is intended to
+    "do the right thing" regardless of 
+    calling context:
+    
+    * If :func:`.is_transactional_ddl` is ``False``,
+      returns a "do nothing" context manager
+      which otherwise produces no transactional
+      state or directives.
+    * If :func:`.is_offline_mode` is ``True``,
+      returns a context manager that will
+      invoke the :meth:`.DefaultImpl.emit_begin`
+      and :meth:`.DefaultImpl.emit_commit`
+      methods, which will produce the string
+      directives ``BEGIN`` and ``COMMIT`` on
+      the output stream, as rendered by the
+      target backend (e.g. SQL Server would
+      emit ``BEGIN TRANSACTION``).
+    * Otherwise, calls :meth:`sqlalchemy.engine.base.Connection.begin`
+      on the current online connection, which
+      returns a :class:`sqlalchemy.engine.base.Transaction`
+      object.  This object demarcates a real
+      transaction and is itself a context manager,
+      which will roll back if an exception
+      is raised.
+    
+    Note that a custom ``env.py`` script which 
+    has more specific transactional needs can of course
+    manipulate the :class:`~sqlalchemy.engine.base.Connection`
+    directly to produce transactional state in "online"
+    mode.
+
+    """
+    if not is_transactional_ddl():
+        @contextmanager
+        def do_nothing():
+            yield
+        return do_nothing()
+    elif is_offline_mode():
+        @contextmanager
+        def begin_commit():
+            get_context().impl.emit_begin()
+            yield
+            get_context().impl.emit_commit()
+        return begin_commit()
+    else:
+        return get_bind().begin()
+
+
 def get_context():
     """Return the current :class:`.Context` object.
 
index 4159d526363aa456bfc9537c8951f35a385adbb7..8306fcb6c8b419f257b98ad01dec39f72198cf81 100644 (file)
@@ -187,6 +187,25 @@ class DefaultImpl(object):
         conn_col_default = inspector_column['default']
         return conn_col_default != rendered_metadata_default
 
+    def emit_begin(self):
+        """Emit the string ``BEGIN``, or the backend-specific
+        equivalent, on the current connection context.
+        
+        This is used in offline mode and typically
+        via :func:`.context.begin_transaction`.
+        
+        """
+        self._exec("BEGIN")
+
+    def emit_commit(self):
+        """Emit the string ``COMMIT``, or the backend-specific
+        equivalent, on the current connection context.
+        
+        This is used in offline mode and typically
+        via :func:`.context.begin_transaction`.
+        
+        """
+        self._exec("COMMIT")
 
 class _literal_bindparam(_BindParamClause):
     pass
index 47ffef7acbcd9180fb79cac339b7dd391f6c5f64..400f614751df009c98ce8219eed5c6307827fb99 100644 (file)
@@ -7,6 +7,9 @@ class MSSQLImpl(DefaultImpl):
     __dialect__ = 'mssql'
     transactional_ddl = True
 
+    def emit_begin(self):
+        self._exec("BEGIN TRANSACTION")
+
     def bulk_insert(self, table, rows):
         if self.as_sql:
             self._exec(
index 4bc6065a1d2bc27439d3f8753b3bd06e68b71011..fbd0ada002137ed45389fa18375c650c6d534de1 100644 (file)
@@ -36,11 +36,8 @@ def run_migrations_offline():
     url = config.get_main_option("sqlalchemy.url")
     context.configure(url=url)
 
-    if context.is_transactional_ddl():
-        context.execute("BEGIN")
-    context.run_migrations()
-    if context.is_transactional_ddl():
-        context.execute("COMMIT")
+    with context.begin_transaction():
+        context.run_migrations()
 
 def run_migrations_online():
     """Run migrations in 'online' mode.
@@ -58,13 +55,8 @@ def run_migrations_online():
                 target_metadata=target_metadata
                 )
 
-    trans = connection.begin()
-    try:
+    with context.begin_transaction():
         context.run_migrations()
-        trans.commit()
-    except:
-        trans.rollback()
-        raise
 
 if context.is_offline_mode():
     run_migrations_offline()
index ee16f064e6e851e13ec6f6ae73dea4026ae6ae5d..4df6ae3666baa1d1619c62419e66b34b1d0fb3a8 100644 (file)
@@ -53,7 +53,8 @@ def run_migrations_offline():
                     url=rec['url'],
                     output_buffer=open(file_, 'w')
                 )
-        context.run_migrations(engine=name)
+        with context.begin_transaction():
+            context.run_migrations(engine=name)
 
 def run_migrations_online():
     """Run migrations in 'online' mode.
index e9ea402efbf66269439680d191dd5e04fe7f78b1..25832bfaac2b9054e5ef9ac5dd4b2f5ef3c026c8 100644 (file)
@@ -42,8 +42,9 @@ def run_migrations_offline():
     
     """
     context.configure(
-                dialect_name=meta.engine.name)
-    context.run_migrations()
+                url=meta.engine.url)
+    with context.begin_transaction():
+        context.run_migrations()
 
 def run_migrations_online():
     """Run migrations in 'online' mode.
@@ -57,13 +58,9 @@ def run_migrations_online():
                 connection=connection,
                 target_metadata=target_metadata
                 )
-    trans = connection.begin()
-    try:
+
+    with context.begin_transaction():
         context.run_migrations()
-        trans.commit()
-    except:
-        trans.rollback()
-        raise
 
 if context.is_offline_mode():
     run_migrations_offline()
index e44e23b41aa647412feea0df7ade4805fe6b765f..0c35747bdca7dab609347e3e8ba25fbb98cacae0 100644 (file)
@@ -14,6 +14,7 @@ import ConfigParser
 from nose import SkipTest
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.util import decorator
+import shutil
 
 staging_directory = os.path.join(os.path.dirname(__file__), 'scratch')
 files_directory = os.path.join(os.path.dirname(__file__), 'files')
@@ -74,9 +75,13 @@ def assert_compiled(element, assert_string, dialect=None):
         assert_string.replace("\n", "").replace("\t", "")
     )
 
-def capture_context_buffer():
+def capture_context_buffer(transactional_ddl=None):
     buf = StringIO.StringIO()
 
+    if transactional_ddl is not None:
+        context._context_opts['transactional_ddl'] = \
+            transactional_ddl
+
     class capture(object):
         def __enter__(self):
             context._context_opts['output_buffer'] = buf
@@ -210,13 +215,13 @@ datefmt = %%H:%%M:%%S
     """ % (dir_, dir_))
 
 
-def no_sql_testing_config():
+def no_sql_testing_config(dialect="postgresql"):
     """use a postgresql url with no host so that connections guaranteed to fail"""
     dir_ = os.path.join(staging_directory, 'scripts')
     return _write_config_file("""
 [alembic]
 script_location = %s
-sqlalchemy.url = postgresql://
+sqlalchemy.url = %s://
 
 [loggers]
 keys = root
@@ -242,7 +247,7 @@ keys = generic
 format = %%(levelname)-5.5s [%%(name)s] %%(message)s
 datefmt = %%H:%%M:%%S
 
-""" % (dir_))
+""" % (dir_, dialect))
 
 def _write_config_file(text):
     cfg = _testing_config()
@@ -254,7 +259,10 @@ def staging_env(create=True, template="generic"):
     from alembic import command, script
     cfg = _testing_config()
     if create:
-        command.init(cfg, os.path.join(staging_directory, 'scripts'))
+        path = os.path.join(staging_directory, 'scripts')
+        if os.path.exists(path):
+            shutil.rmtree(path)
+        command.init(cfg, path)
     sc = script.ScriptDirectory.from_config(cfg)
     context._opts(cfg,sc, fn=lambda:None)
     return sc
index ab546e6238b35239dd75954c6606e6a6d6491760..9392c6c9c54dcf9002566775f26e161586e177e8 100644 (file)
@@ -1,60 +1,83 @@
 """Test op functions against MSSQL."""
 
-from tests import op_fixture
-from alembic import op
+from tests import op_fixture, capture_context_buffer, no_sql_testing_config, staging_env, three_rev_fixture, clear_staging_env
+from alembic import op, command
 from sqlalchemy import Integer, Column, ForeignKey, \
             UniqueConstraint, Table, MetaData, String
 from sqlalchemy.sql import table
+from unittest import TestCase
 
-def test_add_column():
-    context = op_fixture('mssql')
-    op.add_column('t1', Column('c1', Integer, nullable=False))
-    context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL")
-
-def test_add_column_with_default():
-    context = op_fixture("mssql")
-    op.add_column('t1', Column('c1', Integer, nullable=False, server_default="12"))
-    context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL DEFAULT '12'")
-
-def test_alter_column_rename_mssql():
-    context = op_fixture('mssql')
-    op.alter_column("t", "c", name="x")
-    context.assert_(
-        "EXEC sp_rename 't.c', 'x', 'COLUMN'"
-    )
-
-def test_drop_column_w_default():
-    context = op_fixture('mssql')
-    op.drop_column('t1', 'c1', mssql_drop_default=True)
-    context.assert_contains("exec('alter table t1 drop constraint ' + @const_name)")
-    context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
-
-
-def test_drop_column_w_check():
-    context = op_fixture('mssql')
-    op.drop_column('t1', 'c1', mssql_drop_check=True)
-    context.assert_contains("exec('alter table t1 drop constraint ' + @const_name)")
-    context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
-
-def test_alter_column_nullable():
-    context = op_fixture('mssql')
-    op.alter_column("t", "c", nullable=True)
-    context.assert_(
-        "ALTER TABLE t ALTER COLUMN c NULL"
-    )
-
-def test_alter_column_not_nullable():
-    context = op_fixture('mssql')
-    op.alter_column("t", "c", nullable=False)
-    context.assert_(
-        "ALTER TABLE t ALTER COLUMN c SET NOT NULL"
-    )
-
-# TODO: when we add schema support
-#def test_alter_column_rename_mssql_schema():
-#    context = op_fixture('mssql')
-#    op.alter_column("t", "c", name="x", schema="y")
-#    context.assert_(
-#        "EXEC sp_rename 'y.t.c', 'x', 'COLUMN'"
-#    )
+
+class FullEnvironmentTests(TestCase):
+    @classmethod
+    def setup_class(cls):
+        env = staging_env()
+        cls.cfg = cfg = no_sql_testing_config("mssql")
+
+        cls.a, cls.b, cls.c = \
+            three_rev_fixture(cfg)
+
+    @classmethod
+    def teardown_class(cls):
+        clear_staging_env()
+
+    def test_begin_comit(self):
+        with capture_context_buffer(transactional_ddl=True) as buf:
+            command.upgrade(self.cfg, self.a, sql=True)
+        assert "BEGIN TRANSACTION" in buf.getvalue()
+        assert "COMMIT" in buf.getvalue()
+
+class OpTest(TestCase):
+    def test_add_column(self):
+        context = op_fixture('mssql')
+        op.add_column('t1', Column('c1', Integer, nullable=False))
+        context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL")
+
+
+    def test_add_column_with_default(self):
+        context = op_fixture("mssql")
+        op.add_column('t1', Column('c1', Integer, nullable=False, server_default="12"))
+        context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL DEFAULT '12'")
+
+    def test_alter_column_rename_mssql(self):
+        context = op_fixture('mssql')
+        op.alter_column("t", "c", name="x")
+        context.assert_(
+            "EXEC sp_rename 't.c', 'x', 'COLUMN'"
+        )
+
+    def test_drop_column_w_default(self):
+        context = op_fixture('mssql')
+        op.drop_column('t1', 'c1', mssql_drop_default=True)
+        context.assert_contains("exec('alter table t1 drop constraint ' + @const_name)")
+        context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
+
+
+    def test_drop_column_w_check(self):
+        context = op_fixture('mssql')
+        op.drop_column('t1', 'c1', mssql_drop_check=True)
+        context.assert_contains("exec('alter table t1 drop constraint ' + @const_name)")
+        context.assert_contains("ALTER TABLE t1 DROP COLUMN c1")
+
+    def test_alter_column_nullable(self):
+        context = op_fixture('mssql')
+        op.alter_column("t", "c", nullable=True)
+        context.assert_(
+            "ALTER TABLE t ALTER COLUMN c NULL"
+        )
+
+    def test_alter_column_not_nullable(self):
+        context = op_fixture('mssql')
+        op.alter_column("t", "c", nullable=False)
+        context.assert_(
+            "ALTER TABLE t ALTER COLUMN c SET NOT NULL"
+        )
+
+    # TODO: when we add schema support
+    #def test_alter_column_rename_mssql_schema(self):
+    #    context = op_fixture('mssql')
+    #    op.alter_column("t", "c", name="x", schema="y")
+    #    context.assert_(
+    #        "EXEC sp_rename 'y.t.c', 'x', 'COLUMN'"
+    #    )
 
index 1df94cdcb1be99d44b795fd61e3df95f7755cc9a..a615bdf1364320f4456a9b56fb563520e310d3b7 100644 (file)
@@ -14,6 +14,17 @@ def setup():
 def teardown():
     clear_staging_env()
 
+def test_begin_comit():
+    with capture_context_buffer(transactional_ddl=True) as buf:
+        command.upgrade(cfg, a, sql=True)
+    assert "BEGIN" in buf.getvalue()
+    assert "COMMIT" in buf.getvalue()
+
+    with capture_context_buffer(transactional_ddl=False) as buf:
+        command.upgrade(cfg, a, sql=True)
+    assert "BEGIN" not in buf.getvalue()
+    assert "COMMIT" not in buf.getvalue()
+
 def test_version_from_none_insert():
     with capture_context_buffer() as buf:
         command.upgrade(cfg, a, sql=True)