]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added AssertMixin.assert_unordered_result
authorJason Kirtland <jek@discorporate.us>
Thu, 1 Nov 2007 20:59:30 +0000 (20:59 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 1 Nov 2007 20:59:30 +0000 (20:59 +0000)
test/testlib/testing.py

index 43bbb92ff2a5a844c79c62982c05068687221260..60b0173d8f81844c952865ca0892e24c7f3b1f27 100644 (file)
@@ -3,10 +3,10 @@
 # monkeypatches unittest.TestLoader.suiteClass at import time
 
 import testbase
-import unittest, re, sys, os, operator
+import itertools, unittest, re, sys, os, operator
 from cStringIO import StringIO
 import testlib.config as config
-sql, MetaData, clear_mappers, Session = None, None, None, None
+sql, MetaData, clear_mappers, Session, util = None, None, None, None, None
 
 
 __all__ = ('PersistTest', 'AssertMixin', 'ORMTest', 'SQLCompileTest')
@@ -293,7 +293,52 @@ class AssertMixin(PersistTest):
                 self.assert_(getattr(rowobj, key) == value,
                              "attribute %s value %s does not match %s" % (
                              key, getattr(rowobj, key), value))
-                
+
+    def assert_unordered_result(self, result, cls, *expected):
+        """As assert_result, but the order of objects is not considered.
+
+        The algorithm is very expensive but not a big deal for the small
+        numbers of rows that the test suite manipulates.
+        """
+
+        global util
+        if util is None:
+            from sqlalchemy import util
+
+        class frozendict(dict):
+            def __hash__(self):
+                return id(self)
+
+        found = util.IdentitySet(result)
+        expected = set([frozendict(e) for e in expected])
+
+        for wrong in itertools.ifilterfalse(lambda o: type(o) == cls, found):
+            self.fail('Unexpected type "%s", expected "%s"' % (
+                type(wrong).__name__, cls.__name__))
+
+        NOVALUE = object()
+        def _compare_item(obj, spec):
+            for key, value in spec.iteritems():
+                if isinstance(value, tuple):
+                    if (not self.assert_unordered_result(
+                          getattr(obj, key), value[0], *value[1])):
+                        return False
+                else:
+                    if getattr(obj, key, NOVALUE) != value:
+                        return False
+            return True
+
+        for expected_item in expected:
+            for found_item in found:
+                if _compare_item(found_item, expected_item):
+                    found.remove(found_item)
+                    break
+            else:
+                self.fail(
+                    "Expected %s instance with attributes %s not found." % (
+                    cls.__name__, repr(expected_item)))
+        return True
+
     def assert_sql(self, db, callable_, list, with_sequences=None):
         global testdata
         testdata = TestData()