]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- A more efficient IdentitySet
authorJason Kirtland <jek@discorporate.us>
Wed, 31 Oct 2007 19:53:27 +0000 (19:53 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 31 Oct 2007 19:53:27 +0000 (19:53 +0000)
lib/sqlalchemy/util.py

index 9317d4b9bc701bab557b2a26d94766d64503ca8a..a4ccaac6ab88705cdde65baf16feb09f25bdc55e 100644 (file)
@@ -531,188 +531,198 @@ class IdentitySet(object):
     two 'foo' strings in one of these sets, for example.  Use sparingly.
     """
 
-    class _IdentityProxy(object):
-        """Proxies an object's id() as its hash and basis for equality."""
-
-        __slots__ = ('obj',)
-
-        def __init__(self, value):
-            self.obj = value
-        def __hash__(self):
-            return id(self.obj)
-        def __eq__(self, other):
-            if isinstance(other, type(self)):
-                return id(self.obj) == id(other.obj)
-            else:
-                return id(self.obj) == id(other)
-        def __ne__(self, other):
-            if isinstance(other, type(self)):
-                return id(self.obj) != id(other.obj)
-            else:
-                return id(self.obj) != id(other)
-
     def __init__(self, iterable=None):
-        self.set = Set()
+        self._members = {}
         if iterable:
             for o in iterable:
                 self.add(o)
 
     def add(self, value):
-        self.set.add(_id_proxy(value))
+        self._members[id(value)] = value
+
+    def __contains__(self, value):
+        return id(value) in self._members
 
     def remove(self, value):
-        value = _id_proxy(value)
-        if value not in self:
-            raise KeyError(value.obj)
-        self.set.remove(value)
+        del self._members[id(value)]
 
     def discard(self, value):
-        self.set.discard(_id_proxy(value))
+        try:
+            self.remove(value)
+        except KeyError:
+            pass
 
     def pop(self):
-        proxied = self.set.pop()
-        return proxied.obj
-
-    def issubset(self, iterable):
-        if not isinstance(iterable, type(self)):
-            iterable = type(self)(iterable)
-        return self.set.issubset(iterable)
-    __le__ = issubset
-
-    def __lt__(self, iterable):
-        if not isinstance(iterable, type(self)):
-            iterable = type(self)(iterable)
-        return len(self) < len(iterable) and self.issubset(iterable)
+        try:
+            pair = self._members.popitem()
+            return pair[1]
+        except KeyError:
+            raise KeyError('pop from an empty set')
 
-    def issuperset(self, iterable):
-        if not isinstance(iterable, type(self)):
-            iterable = type(self)(iterable)
-        return self.set.issuperset(iterable)
-    __ge__ = issuperset
+    def clear(self):
+        self._members.clear()
 
-    def __gt__(self, iterable):
-        if not isinstance(iterable, type(self)):
-            iterable = type(self)(iterable)
-        return len(self) > len(iterable) and self.issuperset(iterable)
+    def __cmp__(self, other):
+        raise TypeError('cannot compare sets using cmp()')
 
     def __eq__(self, other):
         if isinstance(other, IdentitySet):
-            return self.set == other.set
+            return self._members == other._members
         else:
             return False
 
     def __ne__(self, other):
         if isinstance(other, IdentitySet):
-            return self.set != other.set
+            return self._members != other._members
         else:
             return True
 
-    def __cmp__(self, other):
-        raise TypeError('cannot compare sets using cmp()')
+    def issubset(self, iterable):
+        other = type(self)(iterable)
 
-    def clear(self):
-        self.set.clear()
+        if len(self) > len(other):
+            return False
+        for m in itertools.ifilterfalse(other._members.has_key,
+                                        self._members.iterkeys()):
+            return False
+        return True
 
-    def copy(self):
-        return type(self)(self.set)
+    def __le__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
+            return NotImplemented
+        return self.issubset(other)
+
+    def __lt__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
+            return NotImplemented
+        return len(self) < len(other) and self.issubset(other)
+
+    def issuperset(self, iterable):
+        other = type(self)(iterable)
+
+        if len(self) < len(other):
+            return False
+
+        for m in itertools.ifilterfalse(self._members.has_key,
+                                        other._members.iterkeys()):
+            return False
+        return True
+
+    def __ge__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
+            return NotImplemented
+        return self.issuperset(other)
+
+    def __gt__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
+            return NotImplemented
+        return len(self) > len(other) and self.issuperset(other)
 
     def union(self, iterable):
-        return type(self)(self.set.union(_proxyiter(iterable)))
+        result = type(self)()
+        result._members.update(
+            Set(self._members.iteritems()).union(_iter_id(iterable)))
+        return result
 
-    def __or__(self, iterable):
-        if not isinstance(iterable, set_types + (IdentitySet,)):
+    def __or__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
             return NotImplemented
-        return self.union(iterable)
-    __ror__ = union
+        return self.union(other)
+    __ror__ = __or__
 
     def update(self, iterable):
-        self.set.update(_proxyiter(iterable))
+        self._members = self.union(iterable)._members
 
-    def __ior__(self, iterable):
-        if not isinstance(iterable, set_types + (IdentitySet,)):
+    def __ior__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
             return NotImplemented
-        self.update(iterable)
+        self.update(other)
         return self
 
     def difference(self, iterable):
-        return type(self)(self.set.difference(_proxyiter(iterable)))
+        result = type(self)()
+        result._members.update(
+            Set(self._members.iteritems()).difference(_iter_id(iterable)))
+        return result
 
-    def __sub__(self, iterable):
-        if not isinstance(iterable, set_types + (IdentitySet,)):
+    def __sub__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
             return NotImplemented
-        return self.difference(iterable)
+        return self.difference(other)
     __rsub__ = __sub__
 
     def difference_update(self, iterable):
-        self.set.difference_update(_proxyiter(iterable))
+        self._members = self.difference(iterable)._members
 
-    def __isub__(self, iterable):
-        if not isinstance(iterable, set_types + (IdentitySet,)):
+    def __isub__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
             return NotImplemented
-        self.difference_update(iterable)
+        self.difference_update(other)
         return self
 
     def intersection(self, iterable):
-        return type(self)(self.set.intersection(_proxyiter(iterable)))
+        result = type(self)()
+        result._members.update(
+            Set(self._members.iteritems()).intersection(_iter_id(iterable)))
+        return result
 
-    def __and__(self, iterable):
-        if not isinstance(iterable, set_types + (IdentitySet,)):
+    def __and__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
             return NotImplemented
-        return self.intersection(iterable)
+        return self.intersection(other)
     __rand__ = __and__
 
     def intersection_update(self, iterable):
-        self.set.intersection_update(_proxyiter(iterable))
+        self._members = self.intersection(iterable)._members
 
-    def __iand__(self, iterable):
-        if not isinstance(iterable, set_types + (IdentitySet,)):
+    def __iand__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
             return NotImplemented
-        self.intersection_update(iterable)
+        self.intersection_update(other)
         return self
 
     def symmetric_difference(self, iterable):
-        return type(self)(self.set.symmetric_difference(_proxyiter(iterable)))
+        result = type(self)()
+        result._members.update(
+            Set(self._members.iteritems()).symmetric_difference(_iter_id(iterable)))
+        return result
 
-    def __xor__(self, iterable):
-        if not isinstance(iterable, set_types + (IdentitySet,)):
+    def __xor__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
             return NotImplemented
-        return self.symmetric_difference(iterable)
+        return self.symmetric_difference(other)
     __rxor__ = __xor__
 
     def symmetric_difference_update(self, iterable):
-        self.set.symmetric_difference_update(_proxyiter(iterable))
+        self._members = self.symmetric_difference(iterable)._members
 
-    def __ixor__(self, iterable):
-        if not isinstance(iterable, set_types + (IdentitySet,)):
+    def __ixor__(self, other):
+        if not isinstance(other, set_types + (IdentitySet,)):
             return NotImplemented
-        self.symmetric_difference_update(iterable)
+        self.symmetric_difference(other)
         return self
 
-    def __iter__(self):
-        for proxy in self.set:
-            assert isinstance(proxy, self._IdentityProxy)
-            yield proxy.obj
+    def copy(self):
+        return type(self)(self._members.itervalues())
+
+    __copy__ = copy
 
     def __len__(self):
-        return len(self.set)
+        return len(self._members)
 
-    def __contains__(self, value):
-        return _id_proxy(value) in self.set
+    def __iter__(self):
+        return self._members.itervalues()
 
     def __hash__(self):
         raise TypeError('set objects are unhashable')
 
     def __repr__(self):
-        return '%s(%r)' % (type(self).__name__, list(self))
-
-def _proxyiter(iterable):
-    return itertools.imap(_id_proxy, iterable)
+        return '%s(%r)' % (type(self).__name__, self._members.values())
 
-def _id_proxy(item):
-    if isinstance(item, IdentitySet._IdentityProxy):
-        return item
-    else:
-        return IdentitySet._IdentityProxy(item)
+def _iter_id(iterable):
+    """Generator: ((id(o), o) for o in iterable)."""
+    for item in iterable:
+        yield id(item), item
 
 
 class UniqueAppender(object):