]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- sqlite dialect
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Apr 2010 19:47:18 +0000 (15:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Apr 2010 19:47:18 +0000 (15:47 -0400)
- SQL text mode
- some methods to help with upcoming tests

alembic/command.py
alembic/config.py
alembic/context.py
alembic/ddl/sqlite.py [new file with mode: 0644]
alembic/script.py

index 050111a571c2672eb6628b8274c69dddec03fac4..3a4bff423b7d269764635b1336fef0c678a2a9fb 100644 (file)
@@ -22,12 +22,12 @@ def init(config, directory, template='generic'):
     """Initialize a new scripts directory."""
     
     if os.access(directory, os.F_OK):
-        raise util.CommandException("Directory %s already exists" % directory)
+        raise util.CommandError("Directory %s already exists" % directory)
 
     template_dir = os.path.join(config.get_template_directory(),
                                     template)
     if not os.access(template_dir, os.F_OK):
-        raise util.CommandException("No such template %r" % template)
+        raise util.CommandError("No such template %r" % template)
 
     util.status("Creating directory %s" % os.path.abspath(directory),
                 os.makedirs, directory)
@@ -65,20 +65,26 @@ def revision(config, message=None):
     script = ScriptDirectory.from_config(config)
     script.generate_rev(util.rev_id(), message)
     
-def upgrade(config, revision):
+def upgrade(config, revision, sql=False):
     """Upgrade to a later version."""
 
     script = ScriptDirectory.from_config(config)
-    context._migration_fn = functools.partial(script.upgrade_from, revision)
-    context.config = config
+    context.opts(
+        config,
+        fn = functools.partial(script.upgrade_from, revision),
+        as_sql = sql
+    )
     script.run_env()
     
-def downgrade(config, revision):
+def downgrade(config, revision, sql=False):
     """Revert to a previous version."""
     
     script = ScriptDirectory.from_config(config)
-    context._migration_fn = functools.partial(script.downgrade_to, revision)
-    context.config = config
+    context.opts(
+        config,
+        fn = functools.partial(script.downgrade_to, revision),
+        as_sql = sql,
+    )
     script.run_env()
 
 def history(config):
@@ -107,9 +113,11 @@ def current(config):
                                 context.get_context().connection.engine.url),
                             script._get_rev(rev))
         return []
-        
-    context._migration_fn = display_version
-    context.config = config
+    
+    context.opts(
+        config,
+        fn = display_version
+    )    
     script.run_env()
     
 def splice(config, parent, child):
index d244193a44b4f73b6c89f5216dff64ec88f43d39..f8673381d400d8a1f2ebe1f32b10e20db1bb7ae1 100644 (file)
@@ -79,7 +79,7 @@ def main(argv):
                     format_opt(cmd)
                     for cmd in commands.values()
                 ])) +
-                "\n\n<revision> is a hex revision id or 'head'"
+                "\n\n<revision> is a hex revision id, 'head' or 'base'."
                 )
 
     parser.add_option("-c", "--config", 
@@ -93,6 +93,9 @@ def main(argv):
     parser.add_option("-m", "--message",
                         type="string",
                         help="Message string to use with 'revision'")
+    parser.add_option("--sql",
+                        action="store_true",
+                        help="Dump output to a SQL file")
 
     cmd_line_options, cmd_line_args = parser.parse_args(argv[1:])
 
index 96941902aa00ec0aea14771e62cced2c440f2eb7..176388bd1e318ecd184c2aca95243c1dc018070e 100644 (file)
@@ -1,6 +1,7 @@
 from alembic.ddl import base
 from alembic import util
-from sqlalchemy import MetaData, Table, Column, String
+from sqlalchemy import MetaData, Table, Column, String, literal_column, text
+from sqlalchemy.schema import CreateTable
 import logging
 
 log = logging.getLogger(__name__)
@@ -24,13 +25,20 @@ class DefaultContext(object):
     __dialect__ = 'default'
     
     transactional_ddl = False
+    as_sql = False
     
-    def __init__(self, connection, fn):
+    def __init__(self, connection, fn, as_sql=False):
         self.connection = connection
         self._migrations_fn = fn
+        self.as_sql = as_sql
         
     def _current_rev(self):
-        _version.create(self.connection, checkfirst=True)
+        if self.as_sql:
+            if not self.connection.dialect.has_table(self.connection, 'alembic_version'):
+                self._exec(CreateTable(_version))
+                return None
+        else:
+            _version.create(self.connection, checkfirst=True)
         return self.connection.scalar(_version.select())
     
     def _update_current_rev(self, old, new):
@@ -38,17 +46,21 @@ class DefaultContext(object):
             return
             
         if new is None:
-            self.connection.execute(_version.delete())
+            self._exec(_version.delete())
         elif old is None:
-            self.connection.execute(_version.insert(), {'version_num':new})
+            self._exec(_version.insert().values(version_num=literal_column("'%s'" % new)))
         else:
-            self.connection.execute(_version.update(), {'version_num':new})
+            self._exec(_version.update().values(version_num=literal_column("'%s'" % new)))
             
     def run_migrations(self, **kw):
         log.info("Context class %s.", self.__class__.__name__)
         log.info("Will assume %s DDL.", 
                         "transactional" if self.transactional_ddl 
                         else "non-transactional")
+
+        if self.as_sql and self.transactional_ddl:
+            print "BEGIN;\n"
+
         current_rev = prev_rev = rev = self._current_rev()
         for change, rev in self._migrations_fn(current_rev):
             log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
@@ -60,8 +72,16 @@ class DefaultContext(object):
         if self.transactional_ddl:
             self._update_current_rev(current_rev, rev)
         
+        if self.as_sql and self.transactional_ddl:
+            print "COMMIT;\n"
+            
     def _exec(self, construct):
-        self.connection.execute(construct)
+        if isinstance(construct, basestring):
+            construct = text(construct)
+        if self.as_sql:
+            print unicode(construct.compile(dialect=self.connection.dialect)).replace("\t", "    ") + ";"
+        else:
+            self.connection.execute(construct)
     
     def execute(self, sql):
         self._exec(sql)
@@ -83,10 +103,14 @@ class DefaultContext(object):
     def add_constraint(self, const):
         self._exec(schema.AddConstraint(const))
 
-
+def opts(cfg, **kw):
+    global _context_opts, config
+    _context_opts = kw
+    config = cfg
+    
 def configure_connection(connection):
     global _context
-    _context = _context_impls.get(connection.dialect.name, DefaultContext)(connection, _migration_fn)
+    _context = _context_impls.get(connection.dialect.name, DefaultContext)(connection, **_context_opts)
     
 def run_migrations(**kw):
     _context.run_migrations(**kw)
diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py
new file mode 100644 (file)
index 0000000..20ec1eb
--- /dev/null
@@ -0,0 +1,5 @@
+from alembic.context import DefaultContext
+
+class SQLiteContext(DefaultContext):
+    __dialect__ = 'sqlite'
+    transactional_ddl = True
index d0de004c9faf2f5d4abf085d57b6ef3c4a0647dd..541b32f3af627c0301e3a48c858a978919b15eb1 100644 (file)
@@ -82,7 +82,7 @@ class ScriptDirectory(object):
     def _revision_map(self):
         map_ = {}
         for file_ in os.listdir(self.versions):
-            script = Script.from_path(self.versions, file_)
+            script = Script.from_filename(self.versions, file_)
             if script is None:
                 continue
             if script.revision in map_:
@@ -100,6 +100,18 @@ class ScriptDirectory(object):
         map_[None] = None
         return map_
     
+    def rev_path(self, rev_id):
+        filename = "%s.py" % rev_id
+        return os.path.join(self.versions, filename)
+    
+    def refresh(self, rev_id):
+        script = Script.from_path(self.rev_path(rev_id))
+        old = self._revision_map[script.revision]
+        if old.down_revision != script.down_revision:
+            raise Exception("Can't change down_revision on a refresh operation.")
+        self._revision_map[script.revision] = script
+        script.nextrev = old.nextrev
+        
     def _current_head(self):
         current_heads = self._get_heads()
         if len(current_heads) > 1:
@@ -139,16 +151,16 @@ class ScriptDirectory(object):
     
     def generate_rev(self, revid, message):
         current_head = self._current_head()
-        filename = "%s.py" % revid
+        path = self.rev_path(revid)
         self.generate_template(
             os.path.join(self.dir, "script.py.mako"),
-            os.path.join(self.versions, filename), 
+            path,
             up_revision=str(revid),
             down_revision=current_head,
             create_date=datetime.datetime.now(),
             message=message if message is not None else ("empty message")
         )
-        script = Script.from_path(self.versions, filename)
+        script = Script.from_path(path)
         self._revision_map[script.revision] = script
         if script.down_revision:
             self._revision_map[script.down_revision].add_nextrev(script.revision)
@@ -186,11 +198,15 @@ class Script(object):
                         self.doc)
     
     @classmethod
-    def from_path(cls, dir_, filename):
+    def from_path(cls, path):
+        dir_, filename = os.path.split(path)
+        return cls.from_filename(dir_, filename)
+        
+    @classmethod
+    def from_filename(cls, dir_, filename):
         m = _rev_file.match(filename)
         if not m:
             return None
-        
         module = util.load_python_file(dir_, filename)
         return Script(module, m.group(1))
         
\ No newline at end of file