]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added command line options to add tripwires for __hash__, __eq__ and __nonzero__...
authorJason Kirtland <jek@discorporate.us>
Wed, 31 Oct 2007 23:49:09 +0000 (23:49 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 31 Oct 2007 23:49:09 +0000 (23:49 +0000)
test/testlib/__init__.py
test/testlib/config.py
test/testlib/orm.py [new file with mode: 0644]

index d30e4fc28738c9e0f65988fdd117271fc689fb09..29b258c9f810f407336c830b2687887ca94148f7 100644 (file)
@@ -5,6 +5,7 @@ Load after sqlalchemy imports to use instrumented stand-ins like Table.
 
 import testlib.config
 from testlib.schema import Table, Column
+from testlib.orm import mapper
 import testlib.testing as testing
 from testlib.testing import PersistTest, AssertMixin, ORMTest, SQLCompileTest
 import testlib.profiling as profiling
@@ -12,6 +13,7 @@ import testlib.engines as engines
 
 
 __all__ = ('testing',
+           'mapper',
            'Table', 'Column',
            'PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest',
            'profiling', 'engines')
index a4306e9b4b1abf29914f5f4f0ad51049b278d380..36db872eb43e042eaf2564c8dd90ee5358a0c4c0 100644 (file)
@@ -74,7 +74,7 @@ def _start_coverage(option, opt_str, value, parser):
     atexit.register(_stop)
     coverage.erase()
     coverage.start()
-    
+
 def _list_dbs(*args):
     print "Available --db options (use --dburi to override)"
     for macro in sorted(file_config.options('db')):
@@ -114,6 +114,12 @@ opt("--enginestrategy", action="callback", type="string",
 opt("--reversetop", action="store_true", dest="reversetop", default=False,
     help="Reverse the collection ordering for topological sorts (helps "
           "reveal dependency issues)")
+opt("--unhashable", action="store_true", dest="unhashable", default=False,
+    help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
+opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
+    help="Disallow SQLAlchemy from performing == on mapped test objects.")
+opt("--truthless", action="store_true", dest="truthless", default=False,
+    help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
 opt("--serverside", action="callback", callback=_server_side_cursors,
     help="Turn on server side cursors for PG")
 opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
diff --git a/test/testlib/orm.py b/test/testlib/orm.py
new file mode 100644 (file)
index 0000000..3a10b08
--- /dev/null
@@ -0,0 +1,49 @@
+import testbase
+from testlib import config
+import inspect
+orm = None
+
+__all__ = 'mapper',
+
+
+def _make_blocker(method_name, fallback):
+    def method(self, *args, **kw):
+        frame_r = None
+        try:
+            frame_r = inspect.stack()[1]
+            module = frame_r[0].f_globals.get('__name__', '')
+
+            type_ = type(self)
+
+            if not module.startswith('sqlalchemy'):
+                supermeth = getattr(super(type_, self), method_name, None)
+                if supermeth is None or supermeth.im_func is method:
+                    return fallback(self, *args, **kw)
+                else:
+                    return supermeth(*args, **kw)
+            else:
+                raise AssertionError(
+                    "%s.%s called in %s, line %s in %s" % (
+                    type_.__name__, method_name, module, frame_r[2], frame_r[3]))
+        finally:
+            del frame_r
+    method.__name__ = method_name
+    return method
+
+def mapper(type_, *args, **kw):
+    global orm
+    if orm is None:
+        from sqlalchemy import orm
+
+    forbidden = [
+        ('__hash__', 'unhashable', None),
+        ('__eq__', 'noncomparable', lambda s, x, y: x is y),
+        ('__nonzero__', 'truthless', lambda s: 1), ]
+
+    if type_.__bases__ == (object,):
+        for method_name, option, fallback in forbidden:
+            if (getattr(config.options, option, False) and
+                method_name not in type_.__dict__):
+                setattr(type_, method_name, _make_blocker(method_name, fallback))
+
+    return orm.mapper(type_, *args, **kw)