]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
- clean up whitespace
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 25 Feb 2011 18:00:54 +0000 (13:00 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 25 Feb 2011 18:00:54 +0000 (13:00 -0500)
- can't import sqlalchemy.test anymore

15 files changed:
alembic/command.py
alembic/config.py
alembic/context.py
alembic/ddl/base.py
alembic/ddl/mysql.py
alembic/ddl/postgresql.py
alembic/op.py
alembic/script.py
alembic/util.py
setup.py
tests/__init__.py
tests/test_revision_create.py
tests/test_revision_paths.py
tests/test_schema.py
tests/test_versioning.py

index 3a4bff423b7d269764635b1336fef0c678a2a9fb..c1147593fe71ad52c57946c3469fbef40e565b73 100644 (file)
@@ -5,7 +5,7 @@ import functools
 
 def list_templates(config):
     """List available templates"""
-    
+
     print "Available templates:\n"
     for tempname in os.listdir(config.get_template_directory()):
         readme = os.path.join(
@@ -14,13 +14,13 @@ def list_templates(config):
                         'README')
         synopsis = open(readme).next()
         print util.format_opt(tempname, synopsis)
-    
+
     print "\nTemplates are used via the 'init' command, e.g.:"
     print "\n  alembic init --template pylons ./scripts"
-    
+
 def init(config, directory, template='generic'):
     """Initialize a new scripts directory."""
-    
+
     if os.access(directory, os.F_OK):
         raise util.CommandError("Directory %s already exists" % directory)
 
@@ -31,7 +31,7 @@ def init(config, directory, template='generic'):
 
     util.status("Creating directory %s" % os.path.abspath(directory),
                 os.makedirs, directory)
-    
+
     versions = os.path.join(directory, 'versions')
     util.status("Creating directory %s" % os.path.abspath(versions),
                 os.makedirs, versions)
@@ -64,7 +64,7 @@ def revision(config, message=None):
 
     script = ScriptDirectory.from_config(config)
     script.generate_rev(util.rev_id(), message)
-    
+
 def upgrade(config, revision, sql=False):
     """Upgrade to a later version."""
 
@@ -75,10 +75,10 @@ def upgrade(config, revision, sql=False):
         as_sql = sql
     )
     script.run_env()
-    
+
 def downgrade(config, revision, sql=False):
     """Revert to a previous version."""
-    
+
     script = ScriptDirectory.from_config(config)
     context.opts(
         config,
@@ -102,10 +102,10 @@ def branches(config):
     for sc in script.walk_revisions():
         if sc.is_branch_point:
             print sc
-    
+
 def current(config):
     """Display the current revision for each database."""
-    
+
     script = ScriptDirectory.from_config(config)
     def display_version(rev):
         print "Current revision for %s: %s" % (
@@ -113,14 +113,14 @@ def current(config):
                                 context.get_context().connection.engine.url),
                             script._get_rev(rev))
         return []
-    
+
     context.opts(
         config,
         fn = display_version
-    )    
+    )
     script.run_env()
-    
+
 def splice(config, parent, child):
     """'splice' two branches, creating a new revision file."""
-    
-    
+
+
index f8673381d400d8a1f2ebe1f32b10e20db1bb7ae1..1cbe1f0a2e58dbb14ec18474a50b4174ffe0fd59 100644 (file)
@@ -4,17 +4,17 @@ import ConfigParser
 import inspect
 import os
 import sys
-    
+
 class Config(object):
     def __init__(self, file_):
         self.config_file_name = file_
-    
+
     @util.memoized_property
     def file_config(self):
         file_config = ConfigParser.ConfigParser()
         file_config.read([self.config_file_name])
         return file_config
-        
+
     def get_template_directory(self):
         # TODO: what's the official way to get at
         # setuptools-installed datafiles ?
index 6978bbaad6d0d57f6e219e72c31526064fd9039d..711d1cd40d9615268bd8518c86d524e56432f2d6 100644 (file)
@@ -22,15 +22,15 @@ _version = Table('alembic_version', _meta,
 class DefaultContext(object):
     __metaclass__ = ContextMeta
     __dialect__ = 'default'
-    
+
     transactional_ddl = False
     as_sql = False
-    
+
     def __init__(self, connection, fn, as_sql=False):
         self.connection = connection
         self._migrations_fn = fn
         self.as_sql = as_sql
-        
+
     def _current_rev(self):
         if self.as_sql:
             if not self.connection.dialect.has_table(self.connection, 'alembic_version'):
@@ -39,18 +39,18 @@ class DefaultContext(object):
         else:
             _version.create(self.connection, checkfirst=True)
         return self.connection.scalar(_version.select())
-    
+
     def _update_current_rev(self, old, new):
         if old == new:
             return
-            
+
         if new is None:
             self._exec(_version.delete())
         elif old is None:
             self._exec(_version.insert().values(version_num=literal_column("'%s'" % new)))
         else:
             self._exec(_version.update().values(version_num=literal_column("'%s'" % new)))
-            
+
     def run_migrations(self, **kw):
         log.info("Context class %s.", self.__class__.__name__)
         log.info("Will assume %s DDL.", 
@@ -67,13 +67,13 @@ class DefaultContext(object):
             if not self.transactional_ddl:
                 self._update_current_rev(prev_rev, rev)
             prev_rev = rev
-            
+
         if self.transactional_ddl:
             self._update_current_rev(current_rev, rev)
-        
+
         if self.as_sql and self.transactional_ddl:
             print "COMMIT;\n"
-            
+
     def _exec(self, construct):
         if isinstance(construct, basestring):
             construct = text(construct)
@@ -81,24 +81,24 @@ class DefaultContext(object):
             print unicode(construct.compile(dialect=self.connection.dialect)).replace("\t", "    ") + ";"
         else:
             self.connection.execute(construct)
-    
+
     def execute(self, sql):
         self._exec(sql)
-        
+
     def alter_column(self, table_name, column_name, 
                         nullable=util.NO_VALUE,
                         server_default=util.NO_VALUE,
                         name=util.NO_VALUE,
                         type=util.NO_VALUE
     ):
-    
+
         if nullable is not util.NO_VALUE:
             self._exec(base.ColumnNullable(table_name, column_name, nullable))
         if server_default is not util.NO_VALUE:
             self._exec(base.ColumnDefault(table_name, column_name, server_default))
-    
+
         # ... etc
-        
+
     def add_constraint(self, const):
         self._exec(schema.AddConstraint(const))
 
@@ -106,12 +106,12 @@ def opts(cfg, **kw):
     global _context_opts, config
     _context_opts = kw
     config = cfg
-    
+
 def configure_connection(connection):
     global _context
     from alembic.ddl import base
     _context = _context_impls.get(connection.dialect.name, DefaultContext)(connection, **_context_opts)
-    
+
 def run_migrations(**kw):
     _context.run_migrations(**kw)
 
index ff86f1f6dbdfd19688c9d53f2cc06589445e3791..3c50bf12f0c6b3cc488bb183fa29a0e432ee5545 100644 (file)
@@ -4,10 +4,10 @@ from sqlalchemy.schema import DDLElement
 
 class AlterTable(DDLElement):
     """Represent an ALTER TABLE statement.
-    
+
     Only the string name and optional schema name of the table
     is required, not a full Table object.
-    
+
     """
     def __init__(self, table_name, schema=None):
         self.table_name = table_name
@@ -37,7 +37,7 @@ class ColumnDefault(AlterColumn):
     def __init__(self, name, column_name, default, schema=None):
         super(ColumnDefault, self).__init__(name, column_name, schema=schema)
         self.default = default
-    
+
 class AddColumn(AlterTable):
     def __init__(self, name, column, schema=None):
         super(AddColumn, self).__init__(name, schema=schema)
@@ -60,7 +60,7 @@ def visit_column_nullable(element, compiler, **kw):
 
 def quote_dotted(name, quote):
     """quote the elements of a dotted name"""
-    
+
     result = '.'.join([quote(x) for x in name.split('.')])
     return result
 
index 0b60bf2a5f86ee5c7ac0503f5b5f5a41b98684fa..f7b7b30d62c030daf881ba8240c70d34050e4ee5 100644 (file)
@@ -2,4 +2,4 @@ from alembic.context import DefaultContext
 
 class MySQLContext(DefaultContext):
     __dialect__ = 'mysql'
-    
+
index ebd2f00cc0b1fbbff5a32ce68bd872dc8f4a376a..79d6f1a042f9a6032156e9435716297a4ac2df1b 100644 (file)
@@ -3,4 +3,3 @@ from alembic.context import DefaultContext
 class PostgresqlContext(DefaultContext):
     __dialect__ = 'postgresql'
     transactional_ddl = True
-    
\ No newline at end of file
index d2b1083db2f317f03d1e94c87449a9b3e1678c9f..74acb26a1749c543a9f826ba6548bef7b361c63f 100644 (file)
@@ -16,7 +16,7 @@ def alter_column(table_name, column_name,
                     type_=util.NO_VALUE
 ):
     """Issue ALTER COLUMN using the current change context."""
-    
+
     context.alter_column(table_name, column_name, 
         nullable=nullable,
         server_default=server_default,
@@ -38,7 +38,7 @@ def _foreign_key_constraint(name, source, referent, local_cols, remote_cols):
                                         name=name
                                         )
     t1.append_constraint(f)
-    
+
     return f
 
 def _unique_constraint(name, source, local_cols):
@@ -56,7 +56,7 @@ def _table(name, *columns, **kw):
 def _ensure_table_for_fk(metadata, fk):
     """create a placeholder Table object for the referent of a
     ForeignKey.
-    
+
     """
     if isinstance(fk._colspec, basestring):
         table_key, cname = fk._colspec.split('.')
index ede4c05075f3d36f97eb5113f07104eeef7f09f6..888d4caecd3790d8accc0a58f0b78f67fbb8670d 100644 (file)
@@ -12,23 +12,23 @@ class ScriptDirectory(object):
     def __init__(self, dir):
         self.dir = dir
         self.versions = os.path.join(self.dir, 'versions')
-        
+
         if not os.access(dir, os.F_OK):
             raise util.CommandError("Path doesn't exist: %r.  Please use "
                         "the 'init' command to create a new "
                         "scripts folder." % dir)
-        
+
     @classmethod
     def from_config(cls, config):
         return ScriptDirectory(
                     config.get_main_option('script_location'))
-    
+
     def walk_revisions(self):
         """Iterate through all revisions.
-        
+
         This is actually a breadth-first tree traversal,
         with leaf nodes being heads.
-        
+
         """
         heads = set(self._get_heads())
         base = self._get_rev("base")
@@ -44,7 +44,7 @@ class ScriptDirectory(object):
                         break
                     else:
                         yield sc
-        
+
     def _get_rev(self, id_):
         if id_ == 'head':
             id_ = self._current_head()
@@ -54,7 +54,7 @@ class ScriptDirectory(object):
             return self._revision_map[id_]
         except KeyError:
             raise util.CommandError("No such revision %s" % id_)
-            
+
     def _revs(self, upper, lower):
         lower = self._get_rev(lower)
         upper = self._get_rev(upper)
@@ -62,19 +62,19 @@ class ScriptDirectory(object):
         while script != lower:
             yield script
             script = self._revision_map[script.down_revision]
-        
+
     def upgrade_from(self, destination, current_rev):
         return [
             (script.module.upgrade, script.revision) for script in 
             reversed(list(self._revs(destination, current_rev)))
             ]
-        
+
     def downgrade_to(self, destination, current_rev):
         return [
             (script.module.downgrade, script.down_revision) for script in 
             self._revs(current_rev, destination)
             ]
-        
+
     def run_env(self):
         util.load_python_file(self.dir, 'env.py')
 
@@ -99,11 +99,11 @@ class ScriptDirectory(object):
                 map_[rev.down_revision].add_nextrev(rev.revision)
         map_[None] = None
         return map_
-    
+
     def _rev_path(self, rev_id):
         filename = "%s.py" % rev_id
         return os.path.join(self.versions, filename)
-    
+
     def write(self, rev_id, content):
         path = self._rev_path(rev_id)
         file(path, 'w').write(content)
@@ -115,7 +115,7 @@ class ScriptDirectory(object):
             raise Exception("Can't change down_revision on a refresh operation.")
         self._revision_map[script.revision] = script
         script.nextrev = old.nextrev
-        
+
     def _current_head(self):
         current_heads = self._get_heads()
         if len(current_heads) > 1:
@@ -124,14 +124,14 @@ class ScriptDirectory(object):
             return current_heads[0]
         else:
             return None
-        
+
     def _get_heads(self):
         heads = []
         for script in self._revision_map.values():
             if script and script.is_head:
                 heads.append(script.revision)
         return heads
-    
+
     def _get_origin(self):
         for script in self._revision_map.values():
             if script.down_revision is None \
@@ -139,7 +139,7 @@ class ScriptDirectory(object):
                 return script
         else:
             return None
-        
+
     def generate_template(self, src, dest, **kw):
         util.status("Generating %s" % os.path.abspath(dest),
             util.template_to_file,
@@ -147,12 +147,12 @@ class ScriptDirectory(object):
             dest,
             **kw
         )
-        
+
     def copy_file(self, src, dest):
         util.status("Generating %s" % os.path.abspath(dest), 
                     shutil.copy, 
                     src, dest)
-    
+
     def generate_rev(self, revid, message):
         current_head = self._current_head()
         path = self._rev_path(revid)
@@ -169,30 +169,30 @@ class ScriptDirectory(object):
         if script.down_revision:
             self._revision_map[script.down_revision].add_nextrev(script.revision)
         return script
-        
+
 class Script(object):
     nextrev = frozenset()
-    
+
     def __init__(self, module, rev_id):
         self.module = module
         self.revision = rev_id
         self.down_revision = getattr(module, 'down_revision', None)
-    
+
     @property
     def doc(self):
         return re.split(r"\n\n", self.module.__doc__)[0]
 
     def add_nextrev(self, rev):
         self.nextrev = self.nextrev.union([rev])
-        
+
     @property
     def is_head(self):
         return not bool(self.nextrev)
-    
+
     @property
     def is_branch_point(self):
         return len(self.nextrev) > 1
-        
+
     def __str__(self):
         return "%s -> %s%s%s, %s" % (
                         self.down_revision, 
@@ -200,12 +200,12 @@ class Script(object):
                         " (head)" if self.is_head else "", 
                         " (branchpoint)" if self.is_branch_point else "",
                         self.doc)
-    
+
     @classmethod
     def from_path(cls, path):
         dir_, filename = os.path.split(path)
         return cls.from_filename(dir_, filename)
-        
+
     @classmethod
     def from_filename(cls, dir_, filename):
         m = _rev_file.match(filename)
@@ -213,4 +213,3 @@ class Script(object):
             return None
         module = util.load_python_file(dir_, filename)
         return Script(module, m.group(1))
-        
\ No newline at end of file
index 087e189eb9f610a7e9b9d8e3113598a18613e666..f333bf1608ee1562fdadd4a0884c243fed17393d 100644 (file)
@@ -15,7 +15,7 @@ NO_VALUE = util.symbol("NO_VALUE")
 
 class CommandError(Exception):
     pass
-    
+
 try:
     width = int(os.environ['COLUMNS'])
 except (KeyError, ValueError):
@@ -48,10 +48,10 @@ def obfuscate_url_pw(u):
     if u.password:
         u.password = 'XXXXX'
     return str(u)
-    
+
 def warn(msg):
     warnings.warn(msg)
-    
+
 def msg(msg, newline=True):
     lines = textwrap.wrap(msg, width)
     if len(lines) > 1:
@@ -61,7 +61,7 @@ def msg(msg, newline=True):
 
 def load_python_file(dir_, filename):
     """Load a file from the given path as a Python module."""
-    
+
     module_id = re.sub(r'\W', "_", filename)
     path = os.path.join(dir_, filename)
     module = imp.load_source(module_id, path, open(path, 'rb'))
@@ -71,7 +71,7 @@ def load_python_file(dir_, filename):
 def rev_id():
     val = int(uuid.uuid4()) % 100000000000000
     return hex(val)[2:-1]
-    
+
 class memoized_property(object):
     """A read-only @property that is only evaluated once."""
     def __init__(self, fget, doc=None):
index e2813621d6feb7d069810a97e4982a25f7966213..2462ba39551af97a19f0bfeb82036e3c6b7e8f7d 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -13,14 +13,14 @@ def datafiles():
         if files:
             out.append((root, [os.path.join(root, f) for f in files]))
     return out
-    
+
 setup(name='alembic',
       version=VERSION,
       description="A database migration tool for SQLAlchemy.",
       long_description="""\
 Alembic is an open ended migrations tool.
 Basic operation involves the creation of script files, 
-each representing a version transition for one or more databases.  
+each representing a version transition for one or more databases.
 The scripts execute within the context of a particular connection 
 and transactional configuration that is explicitly constructed.
 
@@ -49,7 +49,7 @@ Key goals of Alembic are:
  * The ability to integrate configuration with other frameworks.
    A Pylons template is included which pulls all configuration
    from the Pylons project environment.
-    
+
 """,
       classifiers=[
       'Development Status :: 3 - Alpha',
index 64ebcb5eb1f8586b2bbd9bf1a04df71d2f8d168c..60869318528b0d94aad163d4bd5287af90b36168 100644 (file)
@@ -1,4 +1,3 @@
-from sqlalchemy.test.testing import eq_, ne_
 from sqlalchemy.util import defaultdict
 from sqlalchemy.engine import url, default
 import shutil
@@ -14,8 +13,8 @@ def _get_dialect(name):
         return default.DefaultDialect()
     else:
         return _dialects[name]
-    
-    
+
+
 def assert_compiled(element, assert_string, dialect=None):
     dialect = _get_dialect(dialect)
     eq_(
@@ -23,6 +22,14 @@ def assert_compiled(element, assert_string, dialect=None):
         assert_string.replace("\n", "").replace("\t", "")
     )
 
+def eq_(a, b, msg=None):
+    """Assert a == b, with repr messaging on failure."""
+    assert a == b, msg or "%r != %r" % (a, b)
+
+def ne_(a, b, msg=None):
+    """Assert a != b, with repr messaging on failure."""
+    assert a != b, msg or "%r == %r" % (a, b)
+
 def _testing_config():
     from alembic.config import Config
     if not os.access(staging_directory, os.F_OK):
@@ -62,21 +69,20 @@ format = %%(levelname)-5.5s [%%(name)s] %%(message)s
 datefmt = %%H:%%M:%%S
     """ % (dir_, dir_))
     return cfg
-    
+
 def sqlite_db():
     # sqlite caches table pragma info 
     # per connection, so create a new
     # engine for each assertion
     dir_ = os.path.join(staging_directory, 'scripts')
     return create_engine('sqlite:///%s/foo.db' % dir_)
-    
+
 def staging_env(create=True):
     from alembic import command, script
     cfg = _testing_config()
     if create:
         command.init(cfg, os.path.join(staging_directory, 'scripts'))
     return script.ScriptDirectory.from_config(cfg)
-    
+
 def clear_staging_env():
     shutil.rmtree(staging_directory, True)
-    
\ No newline at end of file
index 9ad6e03b18bb0a71d5d0b19b3a14b09896b72505..6c8163b351ac9f40168238eca5bc441c3decb642 100644 (file)
@@ -14,10 +14,10 @@ def test_002_rev_ids():
     abc = util.rev_id()
     def_ = util.rev_id()
     ne_(abc, def_)
-    
+
 def test_003_heads():
     eq_(env._get_heads(), [])
-    
+
 def test_004_rev():
     script = env.generate_rev(abc, "this is a message")
     eq_(script.doc, "this is a message")
@@ -26,7 +26,7 @@ def test_004_rev():
     assert os.access(os.path.join(env.dir, 'versions', '%s.py' % abc), os.F_OK)
     assert callable(script.module.upgrade)
     eq_(env._get_heads(), [abc])
-    
+
 def test_005_nextrev():
     script = env.generate_rev(def_, "this is the next rev")
     eq_(script.revision, def_)
@@ -40,7 +40,7 @@ def test_005_nextrev():
 def test_006_from_clean_env():
     # test the environment so far with a 
     # new ScriptDirectory instance.
-    
+
     env = staging_env(create=False)
     abc_rev = env._revision_map[abc]
     def_rev = env._revision_map[def_]
@@ -48,10 +48,10 @@ def test_006_from_clean_env():
     eq_(abc_rev.revision, abc)
     eq_(def_rev.down_revision, abc)
     eq_(env._get_heads(), [def_])
-    
+
 def setup():
     global env
     env = staging_env()
-    
+
 def teardown():
     clear_staging_env()
\ No newline at end of file
index c891dab04c47e0d97ee6ac7a973e376b9c627271..c477c0b06e8aa7e5e561e2f56880d79b04f6f612 100644 (file)
@@ -11,13 +11,13 @@ def setup():
     c = env.generate_rev(util.rev_id(), None)
     d = env.generate_rev(util.rev_id(), None)
     e = env.generate_rev(util.rev_id(), None)
-    
+
 def teardown():
     clear_staging_env()
 
 
 def test_upgrade_path():
-    
+
     eq_(
         env.upgrade_from(e.revision, c.revision),
         [
@@ -34,7 +34,7 @@ def test_upgrade_path():
             (c.module.upgrade, c.revision),
         ]
     )
-    
+
 def test_downgrade_path():
 
     eq_(
index 84bff84dc8d7753bf79da95528ba2833edb1f18b..f5bcb8e322a923a01265b701370d322967b785b3 100644 (file)
@@ -11,14 +11,14 @@ def test_foreign_key():
         AddConstraint(fk),
         "ALTER TABLE t1 ADD CONSTRAINT hoho FOREIGN KEY(foo, bar) REFERENCES t2 (bat, hoho)"
     )
-    
+
 def test_unique_constraint():
     uc = op._unique_constraint('uk_test', 't1', ['foo', 'bar'])
     assert_compiled(
         AddConstraint(uc),
         "ALTER TABLE t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)"
     )
-    
+
 
 def test_table():
     tb = op._table("some_table", 
@@ -50,7 +50,7 @@ def test_table():
             "FOREIGN KEY(foo_id) REFERENCES foo (id), "
             "FOREIGN KEY(foo_bar) REFERENCES foo (bar))"
     )
-    
+
     m = MetaData()
     foo = Table('foo', m, Column('id', Integer, primary_key=True))
     tb = op._table("some_table", 
index bb2f29f727157c28056d0ad67eff49ea4966127f..94065686a1d6b4ad45c3d4749fe8a79cff80ccc5 100644 (file)
@@ -8,7 +8,7 @@ def test_001_revisions():
     a = util.rev_id()
     b = util.rev_id()
     c = util.rev_id()
-    
+
     script = ScriptDirectory.from_config(cfg)
     script.generate_rev(a, None)
     script.write(a, """
@@ -51,8 +51,8 @@ def downgrade():
     execute("DROP TABLE bat")
 
 """ % b)
-    
-    
+
+
 def test_002_upgrade():
     command.upgrade(cfg, c)
     db = sqlite_db()
@@ -88,7 +88,7 @@ def setup():
     global cfg, env
     env = staging_env()
     cfg = _sqlite_testing_config()
-    
-    
+
+
 def teardown():
     clear_staging_env()
\ No newline at end of file