From: Diana Clarke Date: Tue, 20 Nov 2012 01:34:27 +0000 (-0500) Subject: just a pep8 pass of lib/sqlalchemy/testing/ X-Git-Tag: rel_0_8_0b2~33^2~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cc42604ba453d9affa7e3feda985dcdb25cd8f57;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git just a pep8 pass of lib/sqlalchemy/testing/ --- diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 15b3471aab..e571a50458 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -10,12 +10,11 @@ from .exclusions import db_spec, _is_excluded, fails_if, skip_if, future,\ from .assertions import emits_warning, emits_warning_on, uses_deprecated, \ eq_, ne_, is_, is_not_, startswith_, assert_raises, \ - assert_raises_message, AssertsCompiledSQL, ComparesTables, AssertsExecutionResults + assert_raises_message, AssertsCompiledSQL, ComparesTables, \ + AssertsExecutionResults from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict crashes = skip from .config import db, requirements as requires - - diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index e74d13a977..ebd10b1308 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -16,6 +16,7 @@ import itertools from .util import fail import contextlib + def emits_warning(*messages): """Mark a test as emitting a warning. @@ -50,6 +51,7 @@ def emits_warning(*messages): resetwarnings() return decorate + def emits_warning_on(db, *warnings): """Mark a test as emitting a warning on a specific dialect. @@ -115,7 +117,6 @@ def uses_deprecated(*messages): return decorate - def global_cleanup_assertions(): """Check things that have to be finalized at the end of a test suite. @@ -129,28 +130,32 @@ def global_cleanup_assertions(): assert not pool._refs, str(pool._refs) - def eq_(a, b, msg=None): """Assert a == b, with repr messaging on failure.""" assert a == b, msg or "%r != %r" % (a, b) + def ne_(a, b, msg=None): """Assert a != b, with repr messaging on failure.""" assert a != b, msg or "%r == %r" % (a, b) + def is_(a, b, msg=None): """Assert a is b, with repr messaging on failure.""" assert a is b, msg or "%r is not %r" % (a, b) + def is_not_(a, b, msg=None): """Assert a is not b, with repr messaging on failure.""" assert a is not b, msg or "%r is %r" % (a, b) + def startswith_(a, fragment, msg=None): """Assert a.startswith(fragment), with repr messaging on failure.""" assert a.startswith(fragment), msg or "%r does not start with %r" % ( a, fragment) + def assert_raises(except_cls, callable_, *args, **kw): try: callable_(*args, **kw) @@ -161,6 +166,7 @@ def assert_raises(except_cls, callable_, *args, **kw): # assert outside the block so it works for AssertionError too ! assert success, "Callable did not raise an exception" + def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): try: callable_(*args, **kwargs) @@ -214,7 +220,9 @@ class AssertsCompiledSQL(object): p = c.construct_params(params) eq_(tuple([p[x] for x in c.positiontup]), checkpositional) + class ComparesTables(object): + def assert_tables_equal(self, table, reflected_table, strict_types=False): assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): @@ -224,15 +232,19 @@ class ComparesTables(object): eq_(c.nullable, reflected_c.nullable) if strict_types: + msg = "Type '%s' doesn't correspond to type '%s'" assert type(reflected_c.type) is type(c.type), \ - "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type) + msg % (reflected_c.type, c.type) else: self.assert_types_base(reflected_c, c) if isinstance(c.type, sqltypes.String): eq_(c.type.length, reflected_c.type.length) - eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys])) + eq_( + set([f.column.name for f in c.foreign_keys]), + set([f.column.name for f in reflected_c.foreign_keys]) + ) if c.server_default: assert isinstance(reflected_c.server_default, schema.FetchedValue) @@ -246,6 +258,7 @@ class ComparesTables(object): "On column %r, type '%s' doesn't correspond to type '%s'" % \ (c1.name, c1.type, c2.type) + class AssertsExecutionResults(object): def assert_result(self, result, class_, *objects): result = list(result) @@ -296,6 +309,7 @@ class AssertsExecutionResults(object): len(found), len(expected))) NOVALUE = object() + def _compare_item(obj, spec): for key, value in spec.iteritems(): if isinstance(value, tuple): @@ -347,7 +361,8 @@ class AssertsExecutionResults(object): self.assert_sql_execution(db, callable_, *newrules) def assert_sql_count(self, db, callable_, count): - self.assert_sql_execution(db, callable_, assertsql.CountStatements(count)) + self.assert_sql_execution( + db, callable_, assertsql.CountStatements(count)) @contextlib.contextmanager def assert_execution(self, *rules): @@ -359,4 +374,4 @@ class AssertsExecutionResults(object): assertsql.asserter.clear_rules() def assert_statement_count(self, count): - return self.assert_execution(assertsql.CountStatements(count)) \ No newline at end of file + return self.assert_execution(assertsql.CountStatements(count)) diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 08ee55d571..d955d15546 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -3,6 +3,7 @@ from ..engine.default import DefaultDialect from .. import util import re + class AssertRule(object): def process_execute(self, clauseelement, *multiparams, **params): @@ -40,6 +41,7 @@ class AssertRule(object): assert False, 'Rule has not been consumed' return self.is_consumed() + class SQLMatchRule(AssertRule): def __init__(self): self._result = None @@ -56,6 +58,7 @@ class SQLMatchRule(AssertRule): return True + class ExactSQL(SQLMatchRule): def __init__(self, sql, params=None): @@ -138,6 +141,7 @@ class RegexSQL(SQLMatchRule): _received_statement, _received_parameters) + class CompiledSQL(SQLMatchRule): def __init__(self, statement, params): @@ -217,6 +221,7 @@ class CountStatements(AssertRule): % (self.count, self._statement_count) return True + class AllOf(AssertRule): def __init__(self, *rules): @@ -244,6 +249,7 @@ class AllOf(AssertRule): def consume_final(self): return len(self.rules) == 0 + def _process_engine_statement(query, context): if util.jython: @@ -256,6 +262,7 @@ def _process_engine_statement(query, context): query = re.sub(r'\n', '', query) return query + def _process_assertion_statement(query, context): paramstyle = context.dialect.paramstyle if paramstyle == 'named': @@ -275,6 +282,7 @@ def _process_assertion_statement(query, context): return query + class SQLAssert(object): rules = None @@ -311,4 +319,3 @@ class SQLAssert(object): executemany) asserter = SQLAssert() - diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 2945bd456e..ae4f585e19 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -1,3 +1,2 @@ requirements = None db = None - diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 9d15c50785..20bcf03178 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -9,7 +9,9 @@ from .. import event, pool import re import warnings + class ConnectionKiller(object): + def __init__(self): self.proxy_refs = weakref.WeakKeyDictionary() self.testing_engines = weakref.WeakKeyDictionary() @@ -83,12 +85,14 @@ class ConnectionKiller(object): testing_reaper = ConnectionKiller() + def drop_all_tables(metadata, bind): testing_reaper.close_all() if hasattr(bind, 'close'): bind.close() metadata.drop_all(bind) + @decorator def assert_conns_closed(fn, *args, **kw): try: @@ -96,6 +100,7 @@ def assert_conns_closed(fn, *args, **kw): finally: testing_reaper.assert_all_closed() + @decorator def rollback_open_connections(fn, *args, **kw): """Decorator that rolls back all open connections after fn execution.""" @@ -105,6 +110,7 @@ def rollback_open_connections(fn, *args, **kw): finally: testing_reaper.rollback_all() + @decorator def close_first(fn, *args, **kw): """Decorator that closes all connections before fn execution.""" @@ -121,6 +127,7 @@ def close_open_connections(fn, *args, **kw): finally: testing_reaper.close_all() + def all_dialects(exclude=None): import sqlalchemy.databases as d for name in d.__all__: @@ -129,10 +136,13 @@ def all_dialects(exclude=None): continue mod = getattr(d, name, None) if not mod: - mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name) + mod = getattr(__import__( + 'sqlalchemy.databases.%s' % name).databases, name) yield mod.dialect() + class ReconnectFixture(object): + def __init__(self, dbapi): self.dbapi = dbapi self.connections = [] @@ -165,6 +175,7 @@ class ReconnectFixture(object): self._safe(c.close) self.connections = [] + def reconnecting_engine(url=None, options=None): url = url or config.db_url dbapi = config.db.dialect.dbapi @@ -173,9 +184,11 @@ def reconnecting_engine(url=None, options=None): options['module'] = ReconnectFixture(dbapi) engine = testing_engine(url, options) _dispose = engine.dispose + def dispose(): engine.dialect.dbapi.shutdown() _dispose() + engine.test_shutdown = engine.dialect.dbapi.shutdown engine.dispose = dispose return engine @@ -209,6 +222,7 @@ def testing_engine(url=None, options=None): return engine + def utf8_engine(url=None, options=None): """Hook for dialects or drivers that don't handle utf8 by default.""" @@ -226,6 +240,7 @@ def utf8_engine(url=None, options=None): return testing_engine(url, options) + def mock_engine(dialect_name=None): """Provides a mocking engine based on the current testing.db. @@ -244,17 +259,21 @@ def mock_engine(dialect_name=None): dialect_name = config.db.name buffer = [] + def executor(sql, *a, **kw): buffer.append(sql) + def assert_sql(stmts): recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer] assert recv == stmts, recv + def print_sql(): d = engine.dialect return "\n".join( str(s.compile(dialect=d)) for s in engine.mock ) + engine = create_engine(dialect_name + '://', strategy='mock', executor=executor) assert not hasattr(engine, 'mock') @@ -263,6 +282,7 @@ def mock_engine(dialect_name=None): engine.print_sql = print_sql return engine + class DBAPIProxyCursor(object): """Proxy a DBAPI cursor. @@ -287,6 +307,7 @@ class DBAPIProxyCursor(object): def __getattr__(self, key): return getattr(self.cursor, key) + class DBAPIProxyConnection(object): """Proxy a DBAPI connection. @@ -308,14 +329,17 @@ class DBAPIProxyConnection(object): def __getattr__(self, key): return getattr(self.conn, key) -def proxying_engine(conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor): + +def proxying_engine(conn_cls=DBAPIProxyConnection, + cursor_cls=DBAPIProxyCursor): """Produce an engine that provides proxy hooks for common methods. """ def mock_conn(): return conn_cls(config.db, cursor_cls) - return testing_engine(options={'creator':mock_conn}) + return testing_engine(options={'creator': mock_conn}) + class ReplayableSession(object): """A simple record/playback tool. @@ -427,4 +451,3 @@ class ReplayableSession(object): raise AttributeError(key) else: return result - diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py index 1b24e73b7c..5c5e691545 100644 --- a/lib/sqlalchemy/testing/entities.py +++ b/lib/sqlalchemy/testing/entities.py @@ -2,7 +2,10 @@ import sqlalchemy as sa from sqlalchemy import exc as sa_exc _repr_stack = set() + + class BasicEntity(object): + def __init__(self, **kw): for key, value in kw.iteritems(): setattr(self, key, value) @@ -21,7 +24,10 @@ class BasicEntity(object): _repr_stack.remove(id(self)) _recursion_stack = set() + + class ComparableEntity(BasicEntity): + def __hash__(self): return hash(self.__class__) diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 3c70ec8d9a..f105c8b6a2 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -61,6 +61,7 @@ class skip_if(object): self._fails_on = skip_if(other, reason) return self + class fails_if(skip_if): def __call__(self, fn): @decorator @@ -69,14 +70,17 @@ class fails_if(skip_if): return fn(*args, **kw) return decorate(fn) + def only_if(predicate, reason=None): predicate = _as_predicate(predicate) return skip_if(NotPredicate(predicate), reason) + def succeeds_if(predicate, reason=None): predicate = _as_predicate(predicate) return fails_if(NotPredicate(predicate), reason) + class Predicate(object): @classmethod def as_predicate(cls, predicate): @@ -93,6 +97,7 @@ class Predicate(object): else: assert False, "unknown predicate type: %s" % predicate + class BooleanPredicate(Predicate): def __init__(self, value, description=None): self.value = value @@ -110,6 +115,7 @@ class BooleanPredicate(Predicate): def __str__(self): return self._as_string() + class SpecPredicate(Predicate): def __init__(self, db, op=None, spec=None, description=None): self.db = db @@ -177,6 +183,7 @@ class SpecPredicate(Predicate): def __str__(self): return self._as_string() + class LambdaPredicate(Predicate): def __init__(self, lambda_, description=None, args=None, kw=None): self.lambda_ = lambda_ @@ -201,6 +208,7 @@ class LambdaPredicate(Predicate): def __str__(self): return self._as_string() + class NotPredicate(Predicate): def __init__(self, predicate): self.predicate = predicate @@ -211,6 +219,7 @@ class NotPredicate(Predicate): def __str__(self): return self.predicate._as_string(True) + class OrPredicate(Predicate): def __init__(self, predicates, description=None): self.predicates = predicates @@ -256,9 +265,11 @@ class OrPredicate(Predicate): _as_predicate = Predicate.as_predicate + def _is_excluded(db, op, spec): return SpecPredicate(db, op, spec)() + def _server_version(engine): """Return a server_version_info tuple.""" @@ -268,24 +279,30 @@ def _server_version(engine): conn.close() return version + def db_spec(*dbs): return OrPredicate( Predicate.as_predicate(db) for db in dbs ) + def open(): return skip_if(BooleanPredicate(False, "mark as execute")) + def closed(): return skip_if(BooleanPredicate(True, "marked as skip")) + @decorator def future(fn, *args, **kw): return fails_if(LambdaPredicate(fn, *args, **kw), "Future feature") + def fails_on(db, reason=None): return fails_if(SpecPredicate(db), reason) + def fails_on_everything_except(*dbs): return succeeds_if( OrPredicate([ @@ -293,9 +310,11 @@ def fails_on_everything_except(*dbs): ]) ) + def skip(db, reason=None): return skip_if(SpecPredicate(db), reason) + def only_on(dbs, reason=None): return only_if( OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)]) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 1a1204898c..5c587cb2f2 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -7,6 +7,7 @@ import sys import sqlalchemy as sa from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta + class TestBase(object): # A sequence of database names to always run, regardless of the # constraints below. @@ -29,6 +30,7 @@ class TestBase(object): def assert_(self, val, msg=None): assert val, msg + class TablesTest(TestBase): # 'once', None @@ -208,9 +210,11 @@ class _ORMTest(object): sa.orm.session.Session.close_all() sa.orm.clear_mappers() + class ORMTest(_ORMTest, TestBase): pass + class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): # 'once', 'each', None run_setup_classes = 'once' @@ -252,7 +256,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): cls.classes.clear() _ORMTest.teardown_class() - @classmethod def _setup_once_classes(cls): if cls.run_setup_classes == 'once': @@ -275,18 +278,21 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): """ cls_registry = cls.classes + class FindFixture(type): def __init__(cls, classname, bases, dict_): cls_registry[classname] = cls return type.__init__(cls, classname, bases, dict_) - class _Base(object): __metaclass__ = FindFixture + class Basic(BasicEntity, _Base): pass + class Comparable(ComparableEntity, _Base): pass + cls.Basic = Basic cls.Comparable = Comparable fn() @@ -306,6 +312,7 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): def setup_mappers(cls): pass + class DeclarativeMappedTest(MappedTest): run_setup_classes = 'once' run_setup_mappers = 'once' @@ -317,17 +324,21 @@ class DeclarativeMappedTest(MappedTest): @classmethod def _with_register_classes(cls, fn): cls_registry = cls.classes + class FindFixtureDeclarative(DeclarativeMeta): def __init__(cls, classname, bases, dict_): cls_registry[classname] = cls return DeclarativeMeta.__init__( cls, classname, bases, dict_) + class DeclarativeBasic(object): __table_cls__ = schema.Table + _DeclBase = declarative_base(metadata=cls.metadata, metaclass=FindFixtureDeclarative, cls=DeclarativeBasic) cls.DeclarativeBasic = _DeclBase fn() + if cls.metadata.tables: cls.metadata.create_all(config.db)