]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
just a pep8 pass of lib/sqlalchemy/testing/
authorDiana Clarke <diana.joan.clarke@gmail.com>
Tue, 20 Nov 2012 01:34:27 +0000 (20:34 -0500)
committerDiana Clarke <diana.joan.clarke@gmail.com>
Tue, 20 Nov 2012 01:34:27 +0000 (20:34 -0500)
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/assertsql.py
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/engines.py
lib/sqlalchemy/testing/entities.py
lib/sqlalchemy/testing/exclusions.py
lib/sqlalchemy/testing/fixtures.py

index 15b3471aab79bcc37134a45de1b61292ccd60722..e571a50458f5e8f5b12ded06f7ee96b8abfb6cd6 100644 (file)
@@ -10,12 +10,11 @@ from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\
 
 from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
         eq_, ne_, is_, is_not_, startswith_, assert_raises, \
-        assert_raises_message, AssertsCompiledSQL, ComparesTables, AssertsExecutionResults
+        assert_raises_message, AssertsCompiledSQL, ComparesTables, \
+        AssertsExecutionResults
 
 from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
 
 crashes = skip
 
 from .config import db, requirements as requires
-
-
index e74d13a9774011c6053545921b6c281947bb265b..ebd10b130890f4a25836d89db6f9eb2db29156aa 100644 (file)
@@ -16,6 +16,7 @@ import itertools
 from .util import fail
 import contextlib
 
+
 def emits_warning(*messages):
     """Mark a test as emitting a warning.
 
@@ -50,6 +51,7 @@ def emits_warning(*messages):
             resetwarnings()
     return decorate
 
+
 def emits_warning_on(db, *warnings):
     """Mark a test as emitting a warning on a specific dialect.
 
@@ -115,7 +117,6 @@ def uses_deprecated(*messages):
     return decorate
 
 
-
 def global_cleanup_assertions():
     """Check things that have to be finalized at the end of a test suite.
 
@@ -129,28 +130,32 @@ def global_cleanup_assertions():
     assert not pool._refs, str(pool._refs)
 
 
-
 def eq_(a, b, msg=None):
     """Assert a == b, with repr messaging on failure."""
     assert a == b, msg or "%r != %r" % (a, b)
 
+
 def ne_(a, b, msg=None):
     """Assert a != b, with repr messaging on failure."""
     assert a != b, msg or "%r == %r" % (a, b)
 
+
 def is_(a, b, msg=None):
     """Assert a is b, with repr messaging on failure."""
     assert a is b, msg or "%r is not %r" % (a, b)
 
+
 def is_not_(a, b, msg=None):
     """Assert a is not b, with repr messaging on failure."""
     assert a is not b, msg or "%r is %r" % (a, b)
 
+
 def startswith_(a, fragment, msg=None):
     """Assert a.startswith(fragment), with repr messaging on failure."""
     assert a.startswith(fragment), msg or "%r does not start with %r" % (
         a, fragment)
 
+
 def assert_raises(except_cls, callable_, *args, **kw):
     try:
         callable_(*args, **kw)
@@ -161,6 +166,7 @@ def assert_raises(except_cls, callable_, *args, **kw):
     # assert outside the block so it works for AssertionError too !
     assert success, "Callable did not raise an exception"
 
+
 def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
     try:
         callable_(*args, **kwargs)
@@ -214,7 +220,9 @@ class AssertsCompiledSQL(object):
             p = c.construct_params(params)
             eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
 
+
 class ComparesTables(object):
+
     def assert_tables_equal(self, table, reflected_table, strict_types=False):
         assert len(table.c) == len(reflected_table.c)
         for c, reflected_c in zip(table.c, reflected_table.c):
@@ -224,15 +232,19 @@ class ComparesTables(object):
             eq_(c.nullable, reflected_c.nullable)
 
             if strict_types:
+                msg = "Type '%s' doesn't correspond to type '%s'"
                 assert type(reflected_c.type) is type(c.type), \
-                    "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+                    msg % (reflected_c.type, c.type)
             else:
                 self.assert_types_base(reflected_c, c)
 
             if isinstance(c.type, sqltypes.String):
                 eq_(c.type.length, reflected_c.type.length)
 
-            eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
+            eq_(
+                set([f.column.name for f in c.foreign_keys]),
+                set([f.column.name for f in reflected_c.foreign_keys])
+            )
             if c.server_default:
                 assert isinstance(reflected_c.server_default,
                                   schema.FetchedValue)
@@ -246,6 +258,7 @@ class ComparesTables(object):
                 "On column %r, type '%s' doesn't correspond to type '%s'" % \
                 (c1.name, c1.type, c2.type)
 
+
 class AssertsExecutionResults(object):
     def assert_result(self, result, class_, *objects):
         result = list(result)
@@ -296,6 +309,7 @@ class AssertsExecutionResults(object):
                 len(found), len(expected)))
 
         NOVALUE = object()
+
         def _compare_item(obj, spec):
             for key, value in spec.iteritems():
                 if isinstance(value, tuple):
@@ -347,7 +361,8 @@ class AssertsExecutionResults(object):
         self.assert_sql_execution(db, callable_, *newrules)
 
     def assert_sql_count(self, db, callable_, count):
-        self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
+        self.assert_sql_execution(
+            db, callable_, assertsql.CountStatements(count))
 
     @contextlib.contextmanager
     def assert_execution(self, *rules):
@@ -359,4 +374,4 @@ class AssertsExecutionResults(object):
             assertsql.asserter.clear_rules()
 
     def assert_statement_count(self, count):
-        return self.assert_execution(assertsql.CountStatements(count))
\ No newline at end of file
+        return self.assert_execution(assertsql.CountStatements(count))
index 08ee55d57194f6192d3183fb8bb3ecd6786f9d1a..d955d15546135f7fa61f8c7d5d294e683831a5c0 100644 (file)
@@ -3,6 +3,7 @@ from ..engine.default import DefaultDialect
 from .. import util
 import re
 
+
 class AssertRule(object):
 
     def process_execute(self, clauseelement, *multiparams, **params):
@@ -40,6 +41,7 @@ class AssertRule(object):
             assert False, 'Rule has not been consumed'
         return self.is_consumed()
 
+
 class SQLMatchRule(AssertRule):
     def __init__(self):
         self._result = None
@@ -56,6 +58,7 @@ class SQLMatchRule(AssertRule):
 
         return True
 
+
 class ExactSQL(SQLMatchRule):
 
     def __init__(self, sql, params=None):
@@ -138,6 +141,7 @@ class RegexSQL(SQLMatchRule):
                                     _received_statement,
                                     _received_parameters)
 
+
 class CompiledSQL(SQLMatchRule):
 
     def __init__(self, statement, params):
@@ -217,6 +221,7 @@ class CountStatements(AssertRule):
             % (self.count, self._statement_count)
         return True
 
+
 class AllOf(AssertRule):
 
     def __init__(self, *rules):
@@ -244,6 +249,7 @@ class AllOf(AssertRule):
     def consume_final(self):
         return len(self.rules) == 0
 
+
 def _process_engine_statement(query, context):
     if util.jython:
 
@@ -256,6 +262,7 @@ def _process_engine_statement(query, context):
     query = re.sub(r'\n', '', query)
     return query
 
+
 def _process_assertion_statement(query, context):
     paramstyle = context.dialect.paramstyle
     if paramstyle == 'named':
@@ -275,6 +282,7 @@ def _process_assertion_statement(query, context):
 
     return query
 
+
 class SQLAssert(object):
 
     rules = None
@@ -311,4 +319,3 @@ class SQLAssert(object):
                     executemany)
 
 asserter = SQLAssert()
-
index 2945bd456eaaf9b268fca649db77cfb932609daf..ae4f585e1904e161d465daba26b9dd724db5004a 100644 (file)
@@ -1,3 +1,2 @@
 requirements = None
 db = None
-
index 9d15c50785157328450bdca4336276c9e85d9cea..20bcf031785d60ae41beddd99ecb1024dbb089e0 100644 (file)
@@ -9,7 +9,9 @@ from .. import event, pool
 import re
 import warnings
 
+
 class ConnectionKiller(object):
+
     def __init__(self):
         self.proxy_refs = weakref.WeakKeyDictionary()
         self.testing_engines = weakref.WeakKeyDictionary()
@@ -83,12 +85,14 @@ class ConnectionKiller(object):
 
 testing_reaper = ConnectionKiller()
 
+
 def drop_all_tables(metadata, bind):
     testing_reaper.close_all()
     if hasattr(bind, 'close'):
         bind.close()
     metadata.drop_all(bind)
 
+
 @decorator
 def assert_conns_closed(fn, *args, **kw):
     try:
@@ -96,6 +100,7 @@ def assert_conns_closed(fn, *args, **kw):
     finally:
         testing_reaper.assert_all_closed()
 
+
 @decorator
 def rollback_open_connections(fn, *args, **kw):
     """Decorator that rolls back all open connections after fn execution."""
@@ -105,6 +110,7 @@ def rollback_open_connections(fn, *args, **kw):
     finally:
         testing_reaper.rollback_all()
 
+
 @decorator
 def close_first(fn, *args, **kw):
     """Decorator that closes all connections before fn execution."""
@@ -121,6 +127,7 @@ def close_open_connections(fn, *args, **kw):
     finally:
         testing_reaper.close_all()
 
+
 def all_dialects(exclude=None):
     import sqlalchemy.databases as d
     for name in d.__all__:
@@ -129,10 +136,13 @@ def all_dialects(exclude=None):
             continue
         mod = getattr(d, name, None)
         if not mod:
-            mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
+            mod = getattr(__import__(
+                'sqlalchemy.databases.%s' % name).databases, name)
         yield mod.dialect()
 
+
 class ReconnectFixture(object):
+
     def __init__(self, dbapi):
         self.dbapi = dbapi
         self.connections = []
@@ -165,6 +175,7 @@ class ReconnectFixture(object):
             self._safe(c.close)
         self.connections = []
 
+
 def reconnecting_engine(url=None, options=None):
     url = url or config.db_url
     dbapi = config.db.dialect.dbapi
@@ -173,9 +184,11 @@ def reconnecting_engine(url=None, options=None):
     options['module'] = ReconnectFixture(dbapi)
     engine = testing_engine(url, options)
     _dispose = engine.dispose
+
     def dispose():
         engine.dialect.dbapi.shutdown()
         _dispose()
+
     engine.test_shutdown = engine.dialect.dbapi.shutdown
     engine.dispose = dispose
     return engine
@@ -209,6 +222,7 @@ def testing_engine(url=None, options=None):
 
     return engine
 
+
 def utf8_engine(url=None, options=None):
     """Hook for dialects or drivers that don't handle utf8 by default."""
 
@@ -226,6 +240,7 @@ def utf8_engine(url=None, options=None):
 
     return testing_engine(url, options)
 
+
 def mock_engine(dialect_name=None):
     """Provides a mocking engine based on the current testing.db.
 
@@ -244,17 +259,21 @@ def mock_engine(dialect_name=None):
         dialect_name = config.db.name
 
     buffer = []
+
     def executor(sql, *a, **kw):
         buffer.append(sql)
+
     def assert_sql(stmts):
         recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
         assert  recv == stmts, recv
+
     def print_sql():
         d = engine.dialect
         return "\n".join(
             str(s.compile(dialect=d))
             for s in engine.mock
         )
+
     engine = create_engine(dialect_name + '://',
                            strategy='mock', executor=executor)
     assert not hasattr(engine, 'mock')
@@ -263,6 +282,7 @@ def mock_engine(dialect_name=None):
     engine.print_sql = print_sql
     return engine
 
+
 class DBAPIProxyCursor(object):
     """Proxy a DBAPI cursor.
 
@@ -287,6 +307,7 @@ class DBAPIProxyCursor(object):
     def __getattr__(self, key):
         return getattr(self.cursor, key)
 
+
 class DBAPIProxyConnection(object):
     """Proxy a DBAPI connection.
 
@@ -308,14 +329,17 @@ class DBAPIProxyConnection(object):
     def __getattr__(self, key):
         return getattr(self.conn, key)
 
-def proxying_engine(conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor):
+
+def proxying_engine(conn_cls=DBAPIProxyConnection,
+                    cursor_cls=DBAPIProxyCursor):
     """Produce an engine that provides proxy hooks for
     common methods.
 
     """
     def mock_conn():
         return conn_cls(config.db, cursor_cls)
-    return testing_engine(options={'creator':mock_conn})
+    return testing_engine(options={'creator': mock_conn})
+
 
 class ReplayableSession(object):
     """A simple record/playback tool.
@@ -427,4 +451,3 @@ class ReplayableSession(object):
                 raise AttributeError(key)
             else:
                 return result
-
index 1b24e73b7cdcdc3e330600b2d62a87c2efa6d969..5c5e6915451b823bc433fc0fd6f5bde5444f12d3 100644 (file)
@@ -2,7 +2,10 @@ import sqlalchemy as sa
 from sqlalchemy import exc as sa_exc
 
 _repr_stack = set()
+
+
 class BasicEntity(object):
+
     def __init__(self, **kw):
         for key, value in kw.iteritems():
             setattr(self, key, value)
@@ -21,7 +24,10 @@ class BasicEntity(object):
             _repr_stack.remove(id(self))
 
 _recursion_stack = set()
+
+
 class ComparableEntity(BasicEntity):
+
     def __hash__(self):
         return hash(self.__class__)
 
index 3c70ec8d9a198acb5948db71ee8541212f8ae60b..f105c8b6a2267780e202ac9c04eafb6760f28505 100644 (file)
@@ -61,6 +61,7 @@ class skip_if(object):
         self._fails_on = skip_if(other, reason)
         return self
 
+
 class fails_if(skip_if):
     def __call__(self, fn):
         @decorator
@@ -69,14 +70,17 @@ class fails_if(skip_if):
                 return fn(*args, **kw)
         return decorate(fn)
 
+
 def only_if(predicate, reason=None):
     predicate = _as_predicate(predicate)
     return skip_if(NotPredicate(predicate), reason)
 
+
 def succeeds_if(predicate, reason=None):
     predicate = _as_predicate(predicate)
     return fails_if(NotPredicate(predicate), reason)
 
+
 class Predicate(object):
     @classmethod
     def as_predicate(cls, predicate):
@@ -93,6 +97,7 @@ class Predicate(object):
         else:
             assert False, "unknown predicate type: %s" % predicate
 
+
 class BooleanPredicate(Predicate):
     def __init__(self, value, description=None):
         self.value = value
@@ -110,6 +115,7 @@ class BooleanPredicate(Predicate):
     def __str__(self):
         return self._as_string()
 
+
 class SpecPredicate(Predicate):
     def __init__(self, db, op=None, spec=None, description=None):
         self.db = db
@@ -177,6 +183,7 @@ class SpecPredicate(Predicate):
     def __str__(self):
         return self._as_string()
 
+
 class LambdaPredicate(Predicate):
     def __init__(self, lambda_, description=None, args=None, kw=None):
         self.lambda_ = lambda_
@@ -201,6 +208,7 @@ class LambdaPredicate(Predicate):
     def __str__(self):
         return self._as_string()
 
+
 class NotPredicate(Predicate):
     def __init__(self, predicate):
         self.predicate = predicate
@@ -211,6 +219,7 @@ class NotPredicate(Predicate):
     def __str__(self):
         return self.predicate._as_string(True)
 
+
 class OrPredicate(Predicate):
     def __init__(self, predicates, description=None):
         self.predicates = predicates
@@ -256,9 +265,11 @@ class OrPredicate(Predicate):
 
 _as_predicate = Predicate.as_predicate
 
+
 def _is_excluded(db, op, spec):
     return SpecPredicate(db, op, spec)()
 
+
 def _server_version(engine):
     """Return a server_version_info tuple."""
 
@@ -268,24 +279,30 @@ def _server_version(engine):
     conn.close()
     return version
 
+
 def db_spec(*dbs):
     return OrPredicate(
             Predicate.as_predicate(db) for db in dbs
         )
 
+
 def open():
     return skip_if(BooleanPredicate(False, "mark as execute"))
 
+
 def closed():
     return skip_if(BooleanPredicate(True, "marked as skip"))
 
+
 @decorator
 def future(fn, *args, **kw):
     return fails_if(LambdaPredicate(fn, *args, **kw), "Future feature")
 
+
 def fails_on(db, reason=None):
     return fails_if(SpecPredicate(db), reason)
 
+
 def fails_on_everything_except(*dbs):
     return succeeds_if(
                 OrPredicate([
@@ -293,9 +310,11 @@ def fails_on_everything_except(*dbs):
                     ])
             )
 
+
 def skip(db, reason=None):
     return skip_if(SpecPredicate(db), reason)
 
+
 def only_on(dbs, reason=None):
     return only_if(
             OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)])
index 1a1204898cc48f1724d7b3c3ba8c51711a2acbeb..5c587cb2f26f3157229f1678364c6539af9c6fdf 100644 (file)
@@ -7,6 +7,7 @@ import sys
 import sqlalchemy as sa
 from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
 
+
 class TestBase(object):
     # A sequence of database names to always run, regardless of the
     # constraints below.
@@ -29,6 +30,7 @@ class TestBase(object):
     def assert_(self, val, msg=None):
         assert val, msg
 
+
 class TablesTest(TestBase):
 
     # 'once', None
@@ -208,9 +210,11 @@ class _ORMTest(object):
         sa.orm.session.Session.close_all()
         sa.orm.clear_mappers()
 
+
 class ORMTest(_ORMTest, TestBase):
     pass
 
+
 class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
     # 'once', 'each', None
     run_setup_classes = 'once'
@@ -252,7 +256,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
         cls.classes.clear()
         _ORMTest.teardown_class()
 
-
     @classmethod
     def _setup_once_classes(cls):
         if cls.run_setup_classes == 'once':
@@ -275,18 +278,21 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
 
         """
         cls_registry = cls.classes
+
         class FindFixture(type):
             def __init__(cls, classname, bases, dict_):
                 cls_registry[classname] = cls
                 return type.__init__(cls, classname, bases, dict_)
 
-
         class _Base(object):
             __metaclass__ = FindFixture
+
         class Basic(BasicEntity, _Base):
             pass
+
         class Comparable(ComparableEntity, _Base):
             pass
+
         cls.Basic = Basic
         cls.Comparable = Comparable
         fn()
@@ -306,6 +312,7 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
     def setup_mappers(cls):
         pass
 
+
 class DeclarativeMappedTest(MappedTest):
     run_setup_classes = 'once'
     run_setup_mappers = 'once'
@@ -317,17 +324,21 @@ class DeclarativeMappedTest(MappedTest):
     @classmethod
     def _with_register_classes(cls, fn):
         cls_registry = cls.classes
+
         class FindFixtureDeclarative(DeclarativeMeta):
             def __init__(cls, classname, bases, dict_):
                 cls_registry[classname] = cls
                 return DeclarativeMeta.__init__(
                         cls, classname, bases, dict_)
+
         class DeclarativeBasic(object):
             __table_cls__ = schema.Table
+
         _DeclBase = declarative_base(metadata=cls.metadata,
                             metaclass=FindFixtureDeclarative,
                             cls=DeclarativeBasic)
         cls.DeclarativeBasic = _DeclBase
         fn()
+
         if cls.metadata.tables:
             cls.metadata.create_all(config.db)