]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- Added new feature :paramref:`.EnvironmentContext.configure.transaction_per_migration`,
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 May 2014 19:46:00 +0000 (15:46 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 May 2014 19:46:00 +0000 (15:46 -0400)
which when True causes the BEGIN/COMMIT pair to incur for each migration
individually, rather than for the whole series of migrations.  This is
to assist with some database directives that need to be within individual
transactions, without the need to disable transactional DDL entirely.
fixes #201

alembic/environment.py
alembic/migration.py
docs/build/changelog.rst
tests/__init__.py
tests/test_sql_script.py

index cba3beb590ceb96e1fe8de54269aa2ec7a7997fa..f8875a2ad69864a04781d4547b0a494fe0cc4e24 100644 (file)
@@ -1,5 +1,3 @@
-from contextlib import contextmanager
-
 from .operations import Operations
 from .migration import MigrationContext
 from . import util
@@ -60,7 +58,6 @@ class EnvironmentContext(object):
     """
 
     _migration_context = None
-    _default_opts = None
 
     config = None
     """An instance of :class:`.Config` representing the
@@ -87,8 +84,6 @@ class EnvironmentContext(object):
         self.config = config
         self.script = script
         self.context_opts = kw
-        if self._default_opts:
-            self.context_opts.update(self._default_opts)
 
     def __enter__(self):
         """Establish a context which provides a
@@ -261,6 +256,7 @@ class EnvironmentContext(object):
             url=None,
             dialect_name=None,
             transactional_ddl=None,
+            transaction_per_migration=False,
             output_buffer=None,
             starting_rev=None,
             tag=None,
@@ -325,6 +321,12 @@ class EnvironmentContext(object):
          DDL on or off;
          this otherwise defaults to whether or not the dialect in
          use supports it.
+        :param transaction_per_migration: if True, nest each migration script
+         in a transaction rather than the full series of migrations to
+         run.
+
+         .. versionadded:: 0.6.5
+
         :param output_buffer: a file-like object that will be used
          for textual output
          when the ``--sql`` option is used to generate SQL scripts.
@@ -635,6 +637,7 @@ class EnvironmentContext(object):
             opts['tag'] = tag
         if template_args and 'template_args' in opts:
             opts['template_args'].update(template_args)
+        opts["transaction_per_migration"] = transaction_per_migration
         opts['target_metadata'] = target_metadata
         opts['include_symbol'] = include_symbol
         opts['include_object'] = include_object
@@ -651,6 +654,7 @@ class EnvironmentContext(object):
         if compare_server_default is not None:
             opts['compare_server_default'] = compare_server_default
         opts['script'] = self.script
+
         opts.update(kw)
 
         self._migration_context = MigrationContext.configure(
@@ -709,6 +713,7 @@ class EnvironmentContext(object):
         """
         self.get_context().impl.static_output(text)
 
+
     def begin_transaction(self):
         """Return a context manager that will
         enclose an operation within a "transaction",
@@ -752,20 +757,9 @@ class EnvironmentContext(object):
         mode.
 
         """
-        if not self.is_transactional_ddl():
-            @contextmanager
-            def do_nothing():
-                yield
-            return do_nothing()
-        elif self.is_offline_mode():
-            @contextmanager
-            def begin_commit():
-                self.get_context().impl.emit_begin()
-                yield
-                self.get_context().impl.emit_commit()
-            return begin_commit()
-        else:
-            return self.get_bind().begin()
+
+        return self.get_context().begin_transaction()
+
 
     def get_context(self):
         """Return the current :class:`.MigrationContext` object.
index 9a01b0161db05e8cfd83366c9434916d5f075754..e554515b7a6be9fb549ac181a1d880100e670dba 100644 (file)
@@ -1,6 +1,8 @@
 import io
 import logging
 import sys
+from contextlib import contextmanager
+
 
 from sqlalchemy import MetaData, Table, Column, String, literal_column
 from sqlalchemy import create_engine
@@ -64,6 +66,9 @@ class MigrationContext(object):
         as_sql = opts.get('as_sql', False)
         transactional_ddl = opts.get("transactional_ddl")
 
+        self._transaction_per_migration = opts.get(
+                                            "transaction_per_migration", False)
+
         if as_sql:
             self.connection = self._stdout_connection(connection)
             assert self.connection is not None
@@ -146,6 +151,30 @@ class MigrationContext(object):
         return MigrationContext(dialect, connection, opts)
 
 
+    def begin_transaction(self, _per_migration=False):
+        transaction_now = _per_migration == self._transaction_per_migration
+
+        if not transaction_now:
+            @contextmanager
+            def do_nothing():
+                yield
+            return do_nothing()
+
+        elif not self.impl.transactional_ddl:
+            @contextmanager
+            def do_nothing():
+                yield
+            return do_nothing()
+        elif self.as_sql:
+            @contextmanager
+            def begin_commit():
+                self.impl.emit_begin()
+                yield
+                self.impl.emit_commit()
+            return begin_commit()
+        else:
+            return self.bind.begin()
+
     def get_current_revision(self):
         """Return the current revision, usually that which is present
         in the ``alembic_version`` table in the database.
@@ -204,31 +233,35 @@ class MigrationContext(object):
 
         """
         current_rev = rev = False
+        stamp_per_migration = not self.impl.transactional_ddl or \
+                                    self._transaction_per_migration
+
         self.impl.start_migrations()
         for change, prev_rev, rev, doc in self._migrations_fn(
                                             self.get_current_revision(),
                                             self):
-            if current_rev is False:
-                current_rev = prev_rev
-                if self.as_sql and not current_rev:
-                    self._version.create(self.connection)
-            if doc:
-                log.info("Running %s %s -> %s, %s", change.__name__, prev_rev,
-                    rev, doc)
-            else:
-                log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
-            if self.as_sql:
-                self.impl.static_output(
-                        "-- Running %s %s -> %s" %
-                        (change.__name__, prev_rev, rev)
-                    )
-            change(**kw)
-            if not self.impl.transactional_ddl:
-                self._update_current_rev(prev_rev, rev)
-            prev_rev = rev
+            with self.begin_transaction(_per_migration=True):
+                if current_rev is False:
+                    current_rev = prev_rev
+                    if self.as_sql and not current_rev:
+                        self._version.create(self.connection)
+                if doc:
+                    log.info("Running %s %s -> %s, %s", change.__name__, prev_rev,
+                        rev, doc)
+                else:
+                    log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
+                if self.as_sql:
+                    self.impl.static_output(
+                            "-- Running %s %s -> %s" %
+                            (change.__name__, prev_rev, rev)
+                        )
+                change(**kw)
+                if stamp_per_migration:
+                    self._update_current_rev(prev_rev, rev)
+                prev_rev = rev
 
         if rev is not False:
-            if self.impl.transactional_ddl:
+            if not stamp_per_migration:
                 self._update_current_rev(current_rev, rev)
 
             if self.as_sql and not rev:
index d4ccb33b758a531efa85aa3ec76603c67a97f870..e312603edbfdee7fd07bcb148db2cc208a542e81 100644 (file)
@@ -5,6 +5,16 @@ Changelog
 .. changelog::
     :version: 0.6.5
 
+    .. change::
+      :tags: feature, environment
+      :tickets: 201
+
+      Added new feature :paramref:`.EnvironmentContext.configure.transaction_per_migration`,
+      which when True causes the BEGIN/COMMIT pair to incur for each migration
+      individually, rather than for the whole series of migrations.  This is
+      to assist with some database directives that need to be within individual
+      transactions, without the need to disable transactional DDL entirely.
+
     .. change::
       :tags: bug, autogenerate
       :tickets: 200
index 9df275564fb1022cdf821940f7fc0419a82e4ed4..cfb90983713645c61206f02827e30b718c51d4fb 100644 (file)
@@ -20,6 +20,7 @@ from alembic.environment import EnvironmentContext
 from alembic.operations import Operations
 from alembic.script import ScriptDirectory, Script
 from alembic.ddl.impl import _impls
+from contextlib import contextmanager
 
 staging_directory = os.path.join(os.path.dirname(__file__), 'scratch')
 files_directory = os.path.join(os.path.dirname(__file__), 'files')
@@ -121,25 +122,24 @@ def assert_compiled(element, assert_string, dialect=None):
         assert_string.replace("\n", "").replace("\t", "")
     )
 
+@contextmanager
 def capture_context_buffer(**kw):
     if kw.pop('bytes_io', False):
         buf = io.BytesIO()
     else:
         buf = io.StringIO()
 
-    class capture(object):
-        def __enter__(self):
-            EnvironmentContext._default_opts = {
+    kw.update({
                 'dialect_name': "sqlite",
                 'output_buffer': buf
-            }
-            EnvironmentContext._default_opts.update(kw)
-            return buf
-
-        def __exit__(self, *arg, **kwarg):
-            EnvironmentContext._default_opts = None
-
-    return capture()
+    })
+    conf = EnvironmentContext.configure
+    def configure(*arg, **opt):
+        opt.update(**kw)
+        return conf(*arg, **opt)
+
+    with mock.patch.object(EnvironmentContext, "configure", configure):
+        yield buf
 
 def eq_ignore_whitespace(a, b, msg=None):
     a = re.sub(r'^\s+?|\n', "", a)
index 3d45998a3709435bd764149e60569ecb0b348e60..7aae797a2e63984ce0592035b3947303606d68ca 100644 (file)
@@ -9,6 +9,7 @@ from . import clear_staging_env, staging_env, \
     three_rev_fixture, write_script
 from alembic import command, util
 from alembic.script import ScriptDirectory
+import re
 
 cfg = None
 a, b, c = None, None, None
@@ -27,17 +28,32 @@ class ThreeRevTest(unittest.TestCase):
     def tearDown(self):
         clear_staging_env()
 
-    def test_begin_comit(self):
+    def test_begin_commit_transactional_ddl(self):
         with capture_context_buffer(transactional_ddl=True) as buf:
-            command.upgrade(cfg, a, sql=True)
-        assert "BEGIN;" in buf.getvalue()
-        assert "COMMIT;" in buf.getvalue()
+            command.upgrade(cfg, c, sql=True)
+        assert re.match(
+                    (r"^BEGIN;\s+CREATE TABLE.*?%s.*" % a) +
+                    (r".*%s" % b) +
+                    (r".*%s.*?COMMIT;.*$" % c),
+
+                buf.getvalue(), re.S)
 
+    def test_begin_commit_nontransactional_ddl(self):
         with capture_context_buffer(transactional_ddl=False) as buf:
             command.upgrade(cfg, a, sql=True)
-        assert "BEGIN;" not in buf.getvalue()
+        assert re.match(r"^CREATE TABLE.*?\n+$", buf.getvalue(), re.S)
         assert "COMMIT;" not in buf.getvalue()
 
+    def test_begin_commit_per_rev_ddl(self):
+        with capture_context_buffer(transaction_per_migration=True) as buf:
+            command.upgrade(cfg, c, sql=True)
+        assert re.match(
+                    (r"^BEGIN;\s+CREATE TABLE.*%s.*?COMMIT;.*" % a) +
+                    (r"BEGIN;.*?%s.*?COMMIT;.*" % b) +
+                    (r"BEGIN;.*?%s.*?COMMIT;.*$" % c),
+
+                buf.getvalue(), re.S)
+
     def test_version_from_none_insert(self):
         with capture_context_buffer() as buf:
             command.upgrade(cfg, a, sql=True)