]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- get the "stamp" command to work in as_sql
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 11 Nov 2011 18:44:05 +0000 (10:44 -0800)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 11 Nov 2011 18:44:05 +0000 (10:44 -0800)
- overhaul context + env so that --sql mode truly does
not make any SQL connections of any kind.   The env.py
scripts create the engine and use it as the source of
a "dialect" - the "dialect" is now passed straight to the context.
- more power to env - can set output buffer, transactional ddl flag,
execute SQL via context instead of needing to import op

alembic/command.py
alembic/context.py
alembic/templates/generic/env.py
alembic/templates/multidb/env.py
alembic/templates/pylons/env.py
docs/build/api.rst
tests/__init__.py
tests/test_sql_script.py

index 0c48908026d4437b400b1f438510d4f066d13059..0ea8bdbd209886173646fc8e51f67ed0cb5dbbce 100644 (file)
@@ -131,7 +131,10 @@ def stamp(config, revision, sql=False):
 
     script = ScriptDirectory.from_config(config)
     def do_stamp(rev):
-        current = context.get_context()._current_rev()
+        if sql:
+            current = False
+        else:
+            current = context.get_context()._current_rev()
         dest = script._get_rev(revision)
         if dest is not None:
             dest = dest.revision
index cef196e49c788db90b96934477bc03c15072e6a5..a7b66cf7ef66e5ef17326aae7ac55e65142baa41 100644 (file)
@@ -2,6 +2,7 @@ from alembic import util
 from sqlalchemy import MetaData, Table, Column, String, literal_column, \
     text
 from sqlalchemy import schema, create_engine
+from sqlalchemy.engine import url as sqla_url
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql.expression import _BindParamClause
 import sys
@@ -31,7 +32,10 @@ class DefaultContext(object):
     transactional_ddl = False
     as_sql = False
 
-    def __init__(self, connection, fn, as_sql=False, output_buffer=sys.stdout):
+    def __init__(self, dialect, connection, fn, as_sql=False, 
+                        output_buffer=None,
+                        transactional_ddl=None):
+        self.dialect = dialect
         if as_sql:
             self.connection = self._stdout_connection(connection)
             assert self.connection is not None
@@ -39,7 +43,12 @@ class DefaultContext(object):
             self.connection = connection
         self._migrations_fn = fn
         self.as_sql = as_sql
-        self.output_buffer = output_buffer
+        if output_buffer is None:
+            self.output_buffer = sys.stdout
+        else:
+            self.output_buffer = output_buffer
+        if transactional_ddl is not None:
+            self.transactional_ddl = transactional_ddl
 
     def _current_rev(self):
         if self.as_sql:
@@ -67,6 +76,8 @@ class DefaultContext(object):
 
     def run_migrations(self, **kw):
         log.info("Context class %s.", self.__class__.__name__)
+        if self.as_sql:
+            log.info("Generating static SQL")
         log.info("Will assume %s DDL.", 
                         "transactional" if self.transactional_ddl 
                         else "non-transactional")
@@ -111,10 +122,6 @@ class DefaultContext(object):
         else:
             self.connection.execute(construct, *args, **kw)
 
-    @property
-    def dialect(self):
-        return self.connection.dialect
-
     def static_output(self, text):
         self.output_buffer.write(text + "\n\n")
 
@@ -125,7 +132,7 @@ class DefaultContext(object):
         def dump(construct, *multiparams, **params):
             self._exec(construct)
 
-        return create_engine(connection.engine.url
+        return create_engine("%s://" % self.dialect.name
                         strategy="mock", executor=dump)
 
     @property
@@ -196,6 +203,7 @@ def _render_literal_bindparam(element, compiler, **kw):
     return compiler.render_literal_bindparam(element, **kw)
 
 _context_opts = {}
+_context = None
 
 def opts(cfg, **kw):
     """Set up options that will be used by the :func:`.configure_connection`
@@ -208,9 +216,29 @@ def opts(cfg, **kw):
     _context_opts.update(kw)
     config = cfg
 
-def configure_connection(connection):
-    """Configure the migration environment against a specific
-    database connection, an instance of :class:`sqlalchemy.engine.Connection`.
+def requires_connection():
+    """Return True if the current migrations environment should have
+    an active database connection.
+    
+    """
+    return not _context_opts.get('as_sql', False)
+
+def configure(
+        connection=None,
+        url=None,
+        dialect_name=None,
+        transactional_ddl=None,
+        output_buffer=None
+    ):
+    """Configure the migration environment.
+    
+    The important thing needed here is first a way to figure out
+    what kind of "dialect" is in use.   The second is to pass
+    an actual database connection, if one is required.
+    
+    If the :func:`requires_connection` function returns False,
+    then no connection is needed here.  Otherwise, the
+    object should be an instance of :class:`sqlalchemy.engine.Connection`.
     
     This function is typically called from the ``env.py``
     script within a migration environment.  It can be called
@@ -218,12 +246,44 @@ def configure_connection(connection):
     for which it was called is the one that will be operated upon
     by the next call to :func:`.run_migrations`.
     
+    :param connection: a :class:`sqlalchemy.engine.Connection`.  The type of dialect
+     to be used will be derived from this.
+    :param url: a string database url, or a :class:`sqlalchemy.engine.url.URL` object.
+     The type of dialect to be used will be derived from this if ``connection`` is
+     not passed.
+    :param dialect_name: string name of a dialect, such as "postgresql", "mssql", etc.
+     The type of dialect to be used will be derived from this if ``connection``
+     and ``url`` are not passed.
+    :param transactional_ddl: Force the usage of "transactional" DDL on or off;
+     this otherwise defaults to whether or not the dialect in use supports it.
+    :param output_buffer: a file-like object that will be used for textual output
+     when the ``--sql`` option is used to generate SQL scripts.  Defaults to
+     ``sys.stdout`` it not passed here.
     """
+
+    if connection:
+        dialect = connection.dialect
+    elif url:
+        url = sqla_url.make_url(url)
+        dialect = url.get_dialect()()
+    elif dialect_name:
+        url = sqla_url.make_url("%s://" % dialect_name)
+        dialect = url.get_dialect()()
+    else:
+        raise Exception("Connection, url, or dialect_name is required.")
+
     global _context
     from alembic.ddl import base
+    opts = _context_opts.copy()
+    opts.setdefault("transactional_ddl", transactional_ddl)
+    opts.setdefault("output_buffer", output_buffer)
     _context = _context_impls.get(
-                    connection.dialect.name, 
-                    DefaultContext)(connection, **_context_opts)
+                    dialect.name, 
+                    DefaultContext)(dialect, connection, **opts)
+
+def configure_connection(connection):
+    """Deprecated; use :func:`alembic.context.configure`."""
+    configure(connection=connection)
 
 def run_migrations(**kw):
     """Run migrations as determined by the current command line configuration
@@ -232,5 +292,16 @@ def run_migrations(**kw):
     """
     _context.run_migrations(**kw)
 
+def execute(sql):
+    """Execute the given SQL using the current change context.
+    
+    In a SQL script context, the statement is emitted directly to the 
+    output stream.
+    
+    """
+    get_context().execute(sql)
+
 def get_context():
+    if _context is None:
+        raise Exception("No context has been configured yet.")
     return _context
\ No newline at end of file
index 347356e791e5925f47be6abefafd7643d992d10c..e10e682c8ab9bb515e66a86ff9d42d3ada70dc02 100644 (file)
@@ -8,12 +8,17 @@ fileConfig(config.config_file_name)
 engine = engine_from_config(
             config.get_section('alembic'), prefix='sqlalchemy.')
 
-connection = engine.connect()
-context.configure_connection(connection)
-trans = connection.begin()
-try:
+if not context.requires_connection():
+    context.configure(dialect_name=engine.name)
     context.run_migrations()
-    trans.commit()
-except:
-    trans.rollback()
-    raise
\ No newline at end of file
+else:
+    connection = engine.connect()
+    context.configure(connection=connection, dialect_name=engine.name)
+
+    trans = connection.begin()
+    try:
+        context.run_migrations()
+        trans.commit()
+    except:
+        trans.rollback()
+        raise
\ No newline at end of file
index 892b580670909072640c4337991e81b7d3983993..a46038ff928163ca3c702f576d8097d60f83ddcc 100644 (file)
@@ -8,32 +8,46 @@ import logging
 logging.fileConfig(options.config_file)
 
 db_names = options.get_main_option('databases')
-
 engines = {}
 for name in re.split(r',\s*', db_names):
     engines[name] = rec = {}
     rec['engine'] = engine = \
                 engine_from_config(options.get_section(name),
                                 prefix='sqlalchemy.')
-    rec['connection'] = conn = engine.connect()
 
-    if USE_TWOPHASE:
-        rec['transaction'] = conn.begin_twophase()
-    else:
-        rec['transaction'] = conn.begin()
 
-try:
+if not context.requires_connection():
     for name, rec in engines.items():
-        context.configure_connection(rec['connection'])
+        context.configure(
+                    dialect_name=rec['engine'].name
+                )
         context.run_migrations(engine=name)
+else:
+    for name, rec in engines.items():
+        engine = rec['engine']
+        rec['connection'] = conn = engine.connect()
+
+        if USE_TWOPHASE:
+            rec['transaction'] = conn.begin_twophase()
+        else:
+            rec['transaction'] = conn.begin()
+
+    try:
+        for name, rec in engines.items():
+            context.configure(
+                        connection=rec['connection'],
+                        dialect_name=rec['engine'].name
+                    )
+            context.execute("--running migrations for engine %s" % name)
+            context.run_migrations(engine=name)
+
+        if USE_TWOPHASE:
+            for rec in engines.values():
+                rec['transaction'].prepare()
 
-    if USE_TWOPHASE:
         for rec in engines.values():
-            rec['transaction'].prepare()
-
-    for rec in engines.values():
-        rec['transaction'].commit()
-except:
-    for rec in engines.values():
-        rec['transaction'].rollback()
-    raise
\ No newline at end of file
+            rec['transaction'].commit()
+    except:
+        for rec in engines.values():
+            rec['transaction'].rollback()
+        raise
\ No newline at end of file
index 8f868c36ca5c3dda96e9745a0d8f74b71b8fdb81..4e9212cbbf14d6b7820d34318ad387712f263a03 100644 (file)
@@ -23,12 +23,16 @@ except:
 # customize this section for non-standard engine configurations.
 meta = __import__("%s.model.meta" % config['pylons.package']).model.meta
 
-connection = meta.engine.connect()
-context.configure_connection(connection)
-trans = connection.begin()
-try:
+if not context.requires_connection():
+    context.configure(dialect_name=meta.engine.name)
     context.run_migrations()
-    trans.commit()
-except:
-    trans.rollback()
-    raise
\ No newline at end of file
+else:
+    connection = meta.engine.connect()
+    context.configure_connection(connection)
+    trans = connection.begin()
+    try:
+        context.run_migrations()
+        trans.commit()
+    except:
+        trans.rollback()
+        raise
\ No newline at end of file
index 238137c171b8cf7daa577d7caf47dd7dfd061bb8..4747ae60d01a90b29f25f889717d7bf65176301f 100644 (file)
@@ -9,7 +9,10 @@ env.py Directives
 =================
 
 .. autofunction:: sqlalchemy.engine.engine_from_config
-.. autofunction:: alembic.context.configure_connection
+.. autofunction:: alembic.context.configure
+.. autofunction:: alembic.context.get_context
+.. autofunction:: alembic.context.execute
+.. autofunction:: alembic.context.requires_connection
 .. autofunction:: alembic.context.run_migrations
 
 Internals
index 1fc72721caeb3c322c37df0279a6083320731fce..1fec30c9cd0638cc361e2e1734c2a77a848818a4 100644 (file)
@@ -122,6 +122,42 @@ datefmt = %%H:%%M:%%S
     """ % (dir_, dir_))
     return cfg
 
+def _no_sql_testing_config():
+    """use a postgresql url with no host so that connections guaranteed to fail"""
+    cfg = _testing_config()
+    dir_ = os.path.join(staging_directory, 'scripts')
+    open(cfg.config_file_name, 'w').write("""
+[alembic]
+script_location = %s
+sqlalchemy.url = postgresql://
+
+[loggers]
+keys = root
+
+[handlers]
+keys = console
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatters]
+keys = generic
+
+[formatter_generic]
+format = %%(levelname)-5.5s [%%(name)s] %%(message)s
+datefmt = %%H:%%M:%%S
+
+""" % (dir_))
+    return cfg
+
 def sqlite_db():
     # sqlite caches table pragma info 
     # per connection, so create a new
index 1971b6e696cc1574b5d5dba72e2a54f833fd4165..5fa7fe514b8622cc74ed1d256a198912006bb034 100644 (file)
@@ -1,11 +1,11 @@
-from tests import clear_staging_env, staging_env, _sqlite_testing_config, sqlite_db, eq_, ne_, capture_context_buffer
+from tests import clear_staging_env, staging_env, _no_sql_testing_config, sqlite_db, eq_, ne_, capture_context_buffer
 from alembic import command, util
 from alembic.script import ScriptDirectory
 
 def setup():
     global cfg, env
     env = staging_env()
-    cfg = _sqlite_testing_config()
+    cfg = _no_sql_testing_config()
 
     global a, b, c
     a = util.rev_id()
@@ -96,3 +96,8 @@ def test_version_to_middle():
     assert "DROP STEP 2" in buf.getvalue()
     assert "DROP STEP 1" not in buf.getvalue()
 
+
+def test_stamp():
+    with capture_context_buffer() as buf:
+        command.stamp(cfg, "head", sql=True)
+    assert "UPDATE alembic_version SET version_num='%s';" % c in buf.getvalue()