From: Jason Kirtland Date: Thu, 24 Jan 2008 01:00:41 +0000 (+0000) Subject: - IdentitySet binops no longer accept plain sets. X-Git-Tag: rel_0_4_3~73 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=e94c3ba27a5e07d1c77fed36bf6fcd3c44848118;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - IdentitySet binops no longer accept plain sets. --- diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 1f41c1179e..c0d0c7eed7 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -678,12 +678,12 @@ class IdentitySet(object): return True def __le__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented return self.issubset(other) def __lt__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented return len(self) < len(other) and self.issubset(other) @@ -699,12 +699,12 @@ class IdentitySet(object): return True def __ge__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented return self.issuperset(other) def __gt__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented return len(self) > len(other) and self.issuperset(other) @@ -716,16 +716,15 @@ class IdentitySet(object): return result def __or__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented return self.union(other) - __ror__ = __or__ def update(self, iterable): self._members = self.union(iterable)._members def __ior__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented self.update(other) return self @@ -738,16 +737,15 @@ class IdentitySet(object): return result def __sub__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented return self.difference(other) - __rsub__ = __sub__ def difference_update(self, iterable): self._members = self.difference(iterable)._members def __isub__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented self.difference_update(other) return self @@ -760,16 +758,15 @@ class IdentitySet(object): return result def __and__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented return self.intersection(other) - __rand__ = __and__ def intersection_update(self, iterable): self._members = self.intersection(iterable)._members def __iand__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented self.intersection_update(other) return self @@ -782,16 +779,15 @@ class IdentitySet(object): return result def __xor__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented return self.symmetric_difference(other) - __rxor__ = __xor__ def symmetric_difference_update(self, iterable): self._members = self.symmetric_difference(iterable)._members def __ixor__(self, other): - if not isinstance(other, set_types + (IdentitySet,)): + if not isinstance(other, IdentitySet): return NotImplemented self.symmetric_difference(other) return self diff --git a/test/base/utils.py b/test/base/utils.py index 837eb058f0..9d65263793 100644 --- a/test/base/utils.py +++ b/test/base/utils.py @@ -226,13 +226,30 @@ class IdentitySetTest(unittest.TestCase): self.assert_(False) except TypeError: self.assert_(True) - s = set([o1,o2]) - s |= ids - self.assert_(isinstance(s, IdentitySet)) + + try: + s = set([o1,o2]) + s |= ids + self.assert_(False) + except TypeError: + self.assert_(True) self.assertRaises(TypeError, cmp, ids) self.assertRaises(TypeError, hash, ids) + def test_difference(self): + os1 = util.IdentitySet([1,2,3]) + os2 = util.IdentitySet([3,4,5]) + s1 = set([1,2,3]) + s2 = set([3,4,5]) + + self.assertEquals(os1 - os2, util.IdentitySet([1, 2])) + self.assertEquals(os2 - os1, util.IdentitySet([4, 5])) + self.assertRaises(TypeError, lambda: os1 - s2) + self.assertRaises(TypeError, lambda: os1 - [3, 4, 5]) + self.assertRaises(TypeError, lambda: s1 - os2) + self.assertRaises(TypeError, lambda: s1 - [3, 4, 5]) + class DictlikeIteritemsTest(unittest.TestCase): baseline = set([('a', 1), ('b', 2), ('c', 3)])