import testlib.config
from testlib.schema import Table, Column
+from testlib.orm import mapper
import testlib.testing as testing
from testlib.testing import PersistTest, AssertMixin, ORMTest, SQLCompileTest
import testlib.profiling as profiling
__all__ = ('testing',
+ 'mapper',
'Table', 'Column',
'PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest',
'profiling', 'engines')
atexit.register(_stop)
coverage.erase()
coverage.start()
-
+
def _list_dbs(*args):
print "Available --db options (use --dburi to override)"
for macro in sorted(file_config.options('db')):
opt("--reversetop", action="store_true", dest="reversetop", default=False,
help="Reverse the collection ordering for topological sorts (helps "
"reveal dependency issues)")
+opt("--unhashable", action="store_true", dest="unhashable", default=False,
+ help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
+opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
+ help="Disallow SQLAlchemy from performing == on mapped test objects.")
+opt("--truthless", action="store_true", dest="truthless", default=False,
+ help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
opt("--serverside", action="callback", callback=_server_side_cursors,
help="Turn on server side cursors for PG")
opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
--- /dev/null
+import testbase
+from testlib import config
+import inspect
+orm = None
+
+__all__ = 'mapper',
+
+
+def _make_blocker(method_name, fallback):
+ def method(self, *args, **kw):
+ frame_r = None
+ try:
+ frame_r = inspect.stack()[1]
+ module = frame_r[0].f_globals.get('__name__', '')
+
+ type_ = type(self)
+
+ if not module.startswith('sqlalchemy'):
+ supermeth = getattr(super(type_, self), method_name, None)
+ if supermeth is None or supermeth.im_func is method:
+ return fallback(self, *args, **kw)
+ else:
+ return supermeth(*args, **kw)
+ else:
+ raise AssertionError(
+ "%s.%s called in %s, line %s in %s" % (
+ type_.__name__, method_name, module, frame_r[2], frame_r[3]))
+ finally:
+ del frame_r
+ method.__name__ = method_name
+ return method
+
+def mapper(type_, *args, **kw):
+ global orm
+ if orm is None:
+ from sqlalchemy import orm
+
+ forbidden = [
+ ('__hash__', 'unhashable', None),
+ ('__eq__', 'noncomparable', lambda s, x, y: x is y),
+ ('__nonzero__', 'truthless', lambda s: 1), ]
+
+ if type_.__bases__ == (object,):
+ for method_name, option, fallback in forbidden:
+ if (getattr(config.options, option, False) and
+ method_name not in type_.__dict__):
+ setattr(type_, method_name, _make_blocker(method_name, fallback))
+
+ return orm.mapper(type_, *args, **kw)