from .assertions import emits_warning, emits_warning_on, uses_deprecated, \
eq_, ne_, is_, is_not_, startswith_, assert_raises, \
- assert_raises_message, AssertsCompiledSQL, ComparesTables, AssertsExecutionResults
+ assert_raises_message, AssertsCompiledSQL, ComparesTables, \
+ AssertsExecutionResults
from .util import run_as_contextmanager, rowset, fail, provide_metadata, adict
crashes = skip
from .config import db, requirements as requires
-
-
from .util import fail
import contextlib
+
def emits_warning(*messages):
"""Mark a test as emitting a warning.
resetwarnings()
return decorate
+
def emits_warning_on(db, *warnings):
"""Mark a test as emitting a warning on a specific dialect.
return decorate
-
def global_cleanup_assertions():
"""Check things that have to be finalized at the end of a test suite.
assert not pool._refs, str(pool._refs)
-
def eq_(a, b, msg=None):
"""Assert a == b, with repr messaging on failure."""
assert a == b, msg or "%r != %r" % (a, b)
+
def ne_(a, b, msg=None):
"""Assert a != b, with repr messaging on failure."""
assert a != b, msg or "%r == %r" % (a, b)
+
def is_(a, b, msg=None):
"""Assert a is b, with repr messaging on failure."""
assert a is b, msg or "%r is not %r" % (a, b)
+
def is_not_(a, b, msg=None):
"""Assert a is not b, with repr messaging on failure."""
assert a is not b, msg or "%r is %r" % (a, b)
+
def startswith_(a, fragment, msg=None):
"""Assert a.startswith(fragment), with repr messaging on failure."""
assert a.startswith(fragment), msg or "%r does not start with %r" % (
a, fragment)
+
def assert_raises(except_cls, callable_, *args, **kw):
try:
callable_(*args, **kw)
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
+
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
try:
callable_(*args, **kwargs)
p = c.construct_params(params)
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
+
class ComparesTables(object):
+
def assert_tables_equal(self, table, reflected_table, strict_types=False):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
eq_(c.nullable, reflected_c.nullable)
if strict_types:
+ msg = "Type '%s' doesn't correspond to type '%s'"
assert type(reflected_c.type) is type(c.type), \
- "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
+ msg % (reflected_c.type, c.type)
else:
self.assert_types_base(reflected_c, c)
if isinstance(c.type, sqltypes.String):
eq_(c.type.length, reflected_c.type.length)
- eq_(set([f.column.name for f in c.foreign_keys]), set([f.column.name for f in reflected_c.foreign_keys]))
+ eq_(
+ set([f.column.name for f in c.foreign_keys]),
+ set([f.column.name for f in reflected_c.foreign_keys])
+ )
if c.server_default:
assert isinstance(reflected_c.server_default,
schema.FetchedValue)
"On column %r, type '%s' doesn't correspond to type '%s'" % \
(c1.name, c1.type, c2.type)
+
class AssertsExecutionResults(object):
def assert_result(self, result, class_, *objects):
result = list(result)
len(found), len(expected)))
NOVALUE = object()
+
def _compare_item(obj, spec):
for key, value in spec.iteritems():
if isinstance(value, tuple):
self.assert_sql_execution(db, callable_, *newrules)
def assert_sql_count(self, db, callable_, count):
- self.assert_sql_execution(db, callable_, assertsql.CountStatements(count))
+ self.assert_sql_execution(
+ db, callable_, assertsql.CountStatements(count))
@contextlib.contextmanager
def assert_execution(self, *rules):
assertsql.asserter.clear_rules()
def assert_statement_count(self, count):
- return self.assert_execution(assertsql.CountStatements(count))
\ No newline at end of file
+ return self.assert_execution(assertsql.CountStatements(count))
from .. import util
import re
+
class AssertRule(object):
def process_execute(self, clauseelement, *multiparams, **params):
assert False, 'Rule has not been consumed'
return self.is_consumed()
+
class SQLMatchRule(AssertRule):
def __init__(self):
self._result = None
return True
+
class ExactSQL(SQLMatchRule):
def __init__(self, sql, params=None):
_received_statement,
_received_parameters)
+
class CompiledSQL(SQLMatchRule):
def __init__(self, statement, params):
% (self.count, self._statement_count)
return True
+
class AllOf(AssertRule):
def __init__(self, *rules):
def consume_final(self):
return len(self.rules) == 0
+
def _process_engine_statement(query, context):
if util.jython:
query = re.sub(r'\n', '', query)
return query
+
def _process_assertion_statement(query, context):
paramstyle = context.dialect.paramstyle
if paramstyle == 'named':
return query
+
class SQLAssert(object):
rules = None
executemany)
asserter = SQLAssert()
-
requirements = None
db = None
-
import re
import warnings
+
class ConnectionKiller(object):
+
def __init__(self):
self.proxy_refs = weakref.WeakKeyDictionary()
self.testing_engines = weakref.WeakKeyDictionary()
testing_reaper = ConnectionKiller()
+
def drop_all_tables(metadata, bind):
testing_reaper.close_all()
if hasattr(bind, 'close'):
bind.close()
metadata.drop_all(bind)
+
@decorator
def assert_conns_closed(fn, *args, **kw):
try:
finally:
testing_reaper.assert_all_closed()
+
@decorator
def rollback_open_connections(fn, *args, **kw):
"""Decorator that rolls back all open connections after fn execution."""
finally:
testing_reaper.rollback_all()
+
@decorator
def close_first(fn, *args, **kw):
"""Decorator that closes all connections before fn execution."""
finally:
testing_reaper.close_all()
+
def all_dialects(exclude=None):
import sqlalchemy.databases as d
for name in d.__all__:
continue
mod = getattr(d, name, None)
if not mod:
- mod = getattr(__import__('sqlalchemy.databases.%s' % name).databases, name)
+ mod = getattr(__import__(
+ 'sqlalchemy.databases.%s' % name).databases, name)
yield mod.dialect()
+
class ReconnectFixture(object):
+
def __init__(self, dbapi):
self.dbapi = dbapi
self.connections = []
self._safe(c.close)
self.connections = []
+
def reconnecting_engine(url=None, options=None):
url = url or config.db_url
dbapi = config.db.dialect.dbapi
options['module'] = ReconnectFixture(dbapi)
engine = testing_engine(url, options)
_dispose = engine.dispose
+
def dispose():
engine.dialect.dbapi.shutdown()
_dispose()
+
engine.test_shutdown = engine.dialect.dbapi.shutdown
engine.dispose = dispose
return engine
return engine
+
def utf8_engine(url=None, options=None):
"""Hook for dialects or drivers that don't handle utf8 by default."""
return testing_engine(url, options)
+
def mock_engine(dialect_name=None):
"""Provides a mocking engine based on the current testing.db.
dialect_name = config.db.name
buffer = []
+
def executor(sql, *a, **kw):
buffer.append(sql)
+
def assert_sql(stmts):
recv = [re.sub(r'[\n\t]', '', str(s)) for s in buffer]
assert recv == stmts, recv
+
def print_sql():
d = engine.dialect
return "\n".join(
str(s.compile(dialect=d))
for s in engine.mock
)
+
engine = create_engine(dialect_name + '://',
strategy='mock', executor=executor)
assert not hasattr(engine, 'mock')
engine.print_sql = print_sql
return engine
+
class DBAPIProxyCursor(object):
"""Proxy a DBAPI cursor.
def __getattr__(self, key):
return getattr(self.cursor, key)
+
class DBAPIProxyConnection(object):
"""Proxy a DBAPI connection.
def __getattr__(self, key):
return getattr(self.conn, key)
-def proxying_engine(conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor):
+
+def proxying_engine(conn_cls=DBAPIProxyConnection,
+ cursor_cls=DBAPIProxyCursor):
"""Produce an engine that provides proxy hooks for
common methods.
"""
def mock_conn():
return conn_cls(config.db, cursor_cls)
- return testing_engine(options={'creator':mock_conn})
+ return testing_engine(options={'creator': mock_conn})
+
class ReplayableSession(object):
"""A simple record/playback tool.
raise AttributeError(key)
else:
return result
-
from sqlalchemy import exc as sa_exc
_repr_stack = set()
+
+
class BasicEntity(object):
+
def __init__(self, **kw):
for key, value in kw.iteritems():
setattr(self, key, value)
_repr_stack.remove(id(self))
_recursion_stack = set()
+
+
class ComparableEntity(BasicEntity):
+
def __hash__(self):
return hash(self.__class__)
self._fails_on = skip_if(other, reason)
return self
+
class fails_if(skip_if):
def __call__(self, fn):
@decorator
return fn(*args, **kw)
return decorate(fn)
+
def only_if(predicate, reason=None):
predicate = _as_predicate(predicate)
return skip_if(NotPredicate(predicate), reason)
+
def succeeds_if(predicate, reason=None):
predicate = _as_predicate(predicate)
return fails_if(NotPredicate(predicate), reason)
+
class Predicate(object):
@classmethod
def as_predicate(cls, predicate):
else:
assert False, "unknown predicate type: %s" % predicate
+
class BooleanPredicate(Predicate):
def __init__(self, value, description=None):
self.value = value
def __str__(self):
return self._as_string()
+
class SpecPredicate(Predicate):
def __init__(self, db, op=None, spec=None, description=None):
self.db = db
def __str__(self):
return self._as_string()
+
class LambdaPredicate(Predicate):
def __init__(self, lambda_, description=None, args=None, kw=None):
self.lambda_ = lambda_
def __str__(self):
return self._as_string()
+
class NotPredicate(Predicate):
def __init__(self, predicate):
self.predicate = predicate
def __str__(self):
return self.predicate._as_string(True)
+
class OrPredicate(Predicate):
def __init__(self, predicates, description=None):
self.predicates = predicates
_as_predicate = Predicate.as_predicate
+
def _is_excluded(db, op, spec):
return SpecPredicate(db, op, spec)()
+
def _server_version(engine):
"""Return a server_version_info tuple."""
conn.close()
return version
+
def db_spec(*dbs):
return OrPredicate(
Predicate.as_predicate(db) for db in dbs
)
+
def open():
return skip_if(BooleanPredicate(False, "mark as execute"))
+
def closed():
return skip_if(BooleanPredicate(True, "marked as skip"))
+
@decorator
def future(fn, *args, **kw):
return fails_if(LambdaPredicate(fn, *args, **kw), "Future feature")
+
def fails_on(db, reason=None):
return fails_if(SpecPredicate(db), reason)
+
def fails_on_everything_except(*dbs):
return succeeds_if(
OrPredicate([
])
)
+
def skip(db, reason=None):
return skip_if(SpecPredicate(db), reason)
+
def only_on(dbs, reason=None):
return only_if(
OrPredicate([SpecPredicate(db) for db in util.to_list(dbs)])
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
+
class TestBase(object):
# A sequence of database names to always run, regardless of the
# constraints below.
def assert_(self, val, msg=None):
assert val, msg
+
class TablesTest(TestBase):
# 'once', None
sa.orm.session.Session.close_all()
sa.orm.clear_mappers()
+
class ORMTest(_ORMTest, TestBase):
pass
+
class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
# 'once', 'each', None
run_setup_classes = 'once'
cls.classes.clear()
_ORMTest.teardown_class()
-
@classmethod
def _setup_once_classes(cls):
if cls.run_setup_classes == 'once':
"""
cls_registry = cls.classes
+
class FindFixture(type):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
return type.__init__(cls, classname, bases, dict_)
-
class _Base(object):
__metaclass__ = FindFixture
+
class Basic(BasicEntity, _Base):
pass
+
class Comparable(ComparableEntity, _Base):
pass
+
cls.Basic = Basic
cls.Comparable = Comparable
fn()
def setup_mappers(cls):
pass
+
class DeclarativeMappedTest(MappedTest):
run_setup_classes = 'once'
run_setup_mappers = 'once'
@classmethod
def _with_register_classes(cls, fn):
cls_registry = cls.classes
+
class FindFixtureDeclarative(DeclarativeMeta):
def __init__(cls, classname, bases, dict_):
cls_registry[classname] = cls
return DeclarativeMeta.__init__(
cls, classname, bases, dict_)
+
class DeclarativeBasic(object):
__table_cls__ = schema.Table
+
_DeclBase = declarative_base(metadata=cls.metadata,
metaclass=FindFixtureDeclarative,
cls=DeclarativeBasic)
cls.DeclarativeBasic = _DeclBase
fn()
+
if cls.metadata.tables:
cls.metadata.create_all(config.db)