]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
alternate OrderedSet implementation courtesy sdobrev
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Feb 2007 23:15:26 +0000 (23:15 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Feb 2007 23:15:26 +0000 (23:15 +0000)
lib/sqlalchemy/util.py

index c2e0dbc45ff1655ca676c5aab354a3dbfee7b0af..0e888da370d8af473d20c2272098f7038578c8a9 100644 (file)
@@ -18,7 +18,7 @@ try:
     Set = set
 except:
     Set = sets.Set
-    
+
 def to_list(x):
     if x is None:
         return None
@@ -221,12 +221,81 @@ class DictDecorator(dict):
     def __repr__(self):
         return dict.__repr__(self) + repr(self.decorate)
 
-class OrderedSet(sets.Set):
-    def __init__(self, iterable=None):
-        """Construct a set from an optional iterable."""
-        self._data = OrderedDict()
-        if iterable is not None: 
-          self._update(iterable)
+class OrderedSet(Set):
+    def __init__(self, d=None, **kwargs):
+      super(OrderedSet, self).__init__(**kwargs)
+      self._list = []
+      if d: self.update( d, **kwargs)
+
+    def add(self, key):
+      if key not in self:
+          self._list.append(key)
+      Set.add( self, key)
+
+    def remove( self, element):
+      Set.remove( self, element)
+      self._list.remove( element)
+
+    def discard( self, element):
+      try:
+          Set.remove( self, element)
+      except KeyError: pass
+      else:
+          self._list.remove( element)
+
+    def clear(self):
+      Set.clear( self)
+      self._list=[]
+
+    def __iter__(self): return iter(self._list)
+
+    def update(self, iterable):
+      add = self.add
+      for i in iterable: add(i)
+      return self
+
+    def __repr__( self):
+      return '%s(%r)' % (self.__class__.__name__, self._list)
+    __str__ = __repr__
+
+    def union(self, other):
+      result = self.__class__(self)
+      result.update(other)
+      return result
+    __or__ = union
+    def intersection(self, other):
+      return self.__class__( [a for a in self if a in other])
+    __and__ = intersection
+    def symmetric_difference(self, other):
+      result = self.__class__( [a for a in self if a not in other])
+      result.update( [a for a in other if a not in self])
+      return result
+    __xor__ = symmetric_difference
+
+    def difference(self, other):
+      return self.__class__( [a for a in self if a not in other])
+    __sub__ = difference
+
+    __ior__ = update
+
+    def intersection_update(self, other):
+      Set.intersection_update( self, other)
+      self._list = [ a for a in self._list if a in other]
+      return self
+    __iand__ = intersection_update
+
+    def symmetric_difference_update(self, other):
+      Set.symmetric_difference_update( self, other)
+      self._list =  [ a for a in self._list if a in self]
+      self._list += [ a for a in other._list if a in self]
+      return self
+    __ixor__ = symmetric_difference_update
+
+    def difference_update(self, other):
+      Set.difference_update( self, other)
+      self._list = [ a for a in self._list if a in self]
+      return self
+    __isub__ = difference_update
 
 class UniqueAppender(object):
     def __init__(self, data):