]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- IdentitySet binops no longer accept plain sets.
authorJason Kirtland <jek@discorporate.us>
Thu, 24 Jan 2008 01:00:41 +0000 (01:00 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 24 Jan 2008 01:00:41 +0000 (01:00 +0000)
lib/sqlalchemy/util.py
test/base/utils.py

index 1f41c1179e57d79d17125df2ceecc6f9935ca4f4..c0d0c7eed7b58c363103e6edf0f4b8078eeca752 100644 (file)
@@ -678,12 +678,12 @@ class IdentitySet(object):
         return True
 
     def __le__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.issubset(other)
 
     def __lt__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         return len(self) < len(other) and self.issubset(other)
 
@@ -699,12 +699,12 @@ class IdentitySet(object):
         return True
 
     def __ge__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.issuperset(other)
 
     def __gt__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         return len(self) > len(other) and self.issuperset(other)
 
@@ -716,16 +716,15 @@ class IdentitySet(object):
         return result
 
     def __or__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.union(other)
-    __ror__ = __or__
 
     def update(self, iterable):
         self._members = self.union(iterable)._members
 
     def __ior__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         self.update(other)
         return self
@@ -738,16 +737,15 @@ class IdentitySet(object):
         return result
 
     def __sub__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.difference(other)
-    __rsub__ = __sub__
 
     def difference_update(self, iterable):
         self._members = self.difference(iterable)._members
 
     def __isub__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         self.difference_update(other)
         return self
@@ -760,16 +758,15 @@ class IdentitySet(object):
         return result
 
     def __and__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.intersection(other)
-    __rand__ = __and__
 
     def intersection_update(self, iterable):
         self._members = self.intersection(iterable)._members
 
     def __iand__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         self.intersection_update(other)
         return self
@@ -782,16 +779,15 @@ class IdentitySet(object):
         return result
 
     def __xor__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         return self.symmetric_difference(other)
-    __rxor__ = __xor__
 
     def symmetric_difference_update(self, iterable):
         self._members = self.symmetric_difference(iterable)._members
 
     def __ixor__(self, other):
-        if not isinstance(other, set_types + (IdentitySet,)):
+        if not isinstance(other, IdentitySet):
             return NotImplemented
         self.symmetric_difference(other)
         return self
index 837eb058f0623a801b7184297d5865741d101d7a..9d65263793fa820cf1ea657cc82ba911d7d4f63e 100644 (file)
@@ -226,13 +226,30 @@ class IdentitySetTest(unittest.TestCase):
             self.assert_(False)
         except TypeError:
             self.assert_(True)
-        s = set([o1,o2])
-        s |= ids
-        self.assert_(isinstance(s, IdentitySet))
+
+        try:
+            s = set([o1,o2])
+            s |= ids
+            self.assert_(False)
+        except TypeError:
+            self.assert_(True)
 
         self.assertRaises(TypeError, cmp, ids)
         self.assertRaises(TypeError, hash, ids)
 
+    def test_difference(self):
+        os1 = util.IdentitySet([1,2,3])
+        os2 = util.IdentitySet([3,4,5])
+        s1 = set([1,2,3])
+        s2 = set([3,4,5])
+
+        self.assertEquals(os1 - os2, util.IdentitySet([1, 2]))
+        self.assertEquals(os2 - os1, util.IdentitySet([4, 5]))
+        self.assertRaises(TypeError, lambda: os1 - s2)
+        self.assertRaises(TypeError, lambda: os1 - [3, 4, 5])
+        self.assertRaises(TypeError, lambda: s1 - os2)
+        self.assertRaises(TypeError, lambda: s1 - [3, 4, 5])
+
 
 class DictlikeIteritemsTest(unittest.TestCase):
     baseline = set([('a', 1), ('b', 2), ('c', 3)])