git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
HashSet is gone, uses set() for most sets in py2.4 or sets.Set.
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Jun 2006 23:28:44 +0000 (23:28 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Jun 2006 23:28:44 +0000 (23:28 +0000)
ordered set functionality supplied by a subclass of sets.Set

CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/ext/activemapper.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/topological.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py

diff --git a/CHANGES b/CHANGES
index 68007a4e6545ecfbe34f26b11d52e66d8ec089d1..9d200b8a0599febaf0fbb40efc01f802e698b3da 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,9 @@
+0.2.3
+- py2.4 "set" construct used internally; falls back to sets.Set when
+"set" is not available or ordering is needed.
+- "foreignkey" argument to relation() can also be a list.  fixed
+auto-foreignkey detection [ticket:151]
+
 0.2.2
 - big improvements to polymorphic inheritance behavior, enabling it
 to work with adjacency list table structures [ticket:190]
index 6956c5379d34d6ef7557bd0ba054034d37823f5f..f1030b835458adc802fc99a498f281b08ffd3dcd 100644 (file)
@@ -9,9 +9,9 @@ in the sql module."""
 
 from sqlalchemy import schema, sql, engine, util
 import sqlalchemy.engine.default as default
-import string, re
+import string, re, sets
 
-ANSI_FUNCS = util.HashSet([
+ANSI_FUNCS = sets.ImmutableSet([
 'CURRENT_TIME',
 'CURRENT_TIMESTAMP',
 'CURRENT_DATE',
index 74f4df349b1a76f7b3d78df02b67eb1b8da9d7de..aa71471d2117f5235a7a840493f0676fd00bf073 100644 (file)
@@ -1,4 +1,4 @@
-from sqlalchemy             import create_session, relation, mapper, join, DynamicMetaData, class_mapper
+from sqlalchemy             import create_session, relation, mapper, join, DynamicMetaData, class_mapper, util
 from sqlalchemy             import and_, or_
 from sqlalchemy             import Table, Column, ForeignKey
 from sqlalchemy.ext.sessioncontext import SessionContext
@@ -108,7 +108,7 @@ def process_relationships(klass, was_deferred=False):
 
 class ActiveMapperMeta(type):
     classes = {}
-    metadatas = sets.Set()
+    metadatas = util.Set()
     def __init__(cls, clsname, bases, dict):
         table_name = clsname.lower()
         columns    = []
index eba2203843d74ef052e7d96fd690500274ab2be0..64889b9a667cfd8c040104b1772c736f6d9038cf 100644 (file)
@@ -56,16 +56,16 @@ class Mapper(object):
         # uber-pedantic style of making mapper chain, as various testbase/
         # threadlocal/assignmapper combinations keep putting dupes etc. in the list
         # TODO: do something that isnt 21 lines....
-        extlist = util.HashSet()
+        extlist = util.Set()
         for ext_class in global_extensions:
             if isinstance(ext_class, MapperExtension):
-                extlist.append(ext_class)
+                extlist.add(ext_class)
             else:
-                extlist.append(ext_class())
+                extlist.add(ext_class())
 
         if extension is not None:
             for ext_obj in util.to_list(extension):
-                extlist.append(ext_obj)
+                extlist.add(ext_obj)
         
         self.extension = None
         previous = None
@@ -87,7 +87,7 @@ class Mapper(object):
         self._options = {}
         self.always_refresh = always_refresh
         self.version_id_col = version_id_col
-        self._inheriting_mappers = sets.Set()
+        self._inheriting_mappers = util.Set()
         self.polymorphic_on = polymorphic_on
         if polymorphic_map is None:
             self.polymorphic_map = {}
@@ -146,7 +146,7 @@ class Mapper(object):
                     # stricter set of tables to create "sync rules" by,based on the immediate
                     # inherited table, rather than all inherited tables
                     self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY)
-                    self._synchronizer.compile(self.mapped_table.onclause, util.HashSet([inherits.local_table]), sqlutil.TableFinder(self.local_table))
+                    self._synchronizer.compile(self.mapped_table.onclause, util.Set([inherits.local_table]), sqlutil.TableFinder(self.local_table))
             else:
                 self._synchronizer = None
                 self.mapped_table = self.local_table
@@ -182,19 +182,19 @@ class Mapper(object):
         self.pks_by_table = {}
         if primary_key is not None:
             for k in primary_key:
-                self.pks_by_table.setdefault(k.table, util.HashSet(ordered=True)).append(k)
+                self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k)
                 if k.table != self.mapped_table:
                     # associate pk cols from subtables to the "main" table
-                    self.pks_by_table.setdefault(self.mapped_table, util.HashSet(ordered=True)).append(k)
+                    self.pks_by_table.setdefault(self.mapped_table, util.OrderedSet()).add(k)
                 # TODO: need local_table properly accounted for when custom primary key is sent
         else:
             for t in self.tables + [self.mapped_table]:
                 try:
                     l = self.pks_by_table[t]
                 except KeyError:
-                    l = self.pks_by_table.setdefault(t, util.HashSet(ordered=True))
+                    l = self.pks_by_table.setdefault(t, util.OrderedSet())
                 for k in t.primary_key:
-                    l.append(k)
+                    l.add(k)
                     
         if len(self.pks_by_table[self.mapped_table]) == 0:
             raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
@@ -582,7 +582,7 @@ class Mapper(object):
                             params[col.key] = params[col._label] + 1
                         else:
                             params[col.key] = 1
-                    elif self.pks_by_table[table].contains(col):
+                    elif col in self.pks_by_table[table]:
                         # column is a primary key ?
                         if not isinsert:
                             # doing an UPDATE?  put primary key values as "WHERE" parameters
@@ -756,14 +756,14 @@ class Mapper(object):
     
     def cascade_iterator(self, type, object, callable_=None, recursive=None):
         if recursive is None:
-            recursive=sets.Set()
+            recursive=util.Set()
         for prop in self.props.values():
             for c in prop.cascade_iterator(type, object, recursive):
                 yield c
 
     def cascade_callable(self, type, object, callable_, recursive=None):
         if recursive is None:
-            recursive=sets.Set()
+            recursive=util.Set()
         for prop in self.props.values():
             prop.cascade_callable(type, object, callable_, recursive)
             
index 34529a1366d0f5f22cb3d45ceb8ca4228d71e435..0b609e3008c802c28576927524e001841cf96aa5 100644 (file)
@@ -271,7 +271,7 @@ class PropertyLoader(mapper.MapperProperty):
         """searches through the primary join condition to determine which side
         has the foreign key - from this we return 
         the "foreign key" for this property which helps determine one-to-many/many-to-one."""
-        foreignkeys = sets.Set()
+        foreignkeys = util.Set()
         def foo(binary):
             if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
                 return
@@ -305,8 +305,8 @@ class PropertyLoader(mapper.MapperProperty):
         
         The list of rules is used within commits by the _synchronize() method when dependent 
         objects are processed."""
-        parent_tables = util.HashSet(self.parent.tables + [self.parent.mapped_table])
-        target_tables = util.HashSet(self.mapper.tables + [self.mapper.mapped_table])
+        parent_tables = util.Set(self.parent.tables + [self.parent.mapped_table])
+        target_tables = util.Set(self.mapper.tables + [self.mapper.mapped_table])
 
         self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction)
         if self.direction == sync.MANYTOMANY:
index bd17501653c663975144f1de6a574b499cc2e5c8..1ba6d1a35e23408adfef71245a9ac2a0cd4f20e6 100644 (file)
@@ -393,7 +393,7 @@ class Session(object):
     def __contains__(self, obj):
         return self._is_attached(obj) and (obj in self.uow.new or self.uow.has_key(obj._instance_key))
     def __iter__(self):
-        return iter(self.uow.new + self.uow.identity_map.values())
+        return iter(list(self.uow.new) + self.uow.identity_map.values())
     def _get(self, key):
         return self.uow._get(key)
     def has_key(self, key):
index 89e760039146e694b0f08afffd9002666bac609a..d9ec5cde98bb83223b7881ac9c2d134a71a69a93 100644 (file)
@@ -231,119 +231,3 @@ class QueueDependencySorter(object):
         else:
             return cycled_edges
 
-class TreeDependencySorter(object):
-    """
-    this is my first topological sorting algorithm.  its crazy, but matched my thinking
-    at the time.  it also creates the kind of structure I want.  but, I am not 100% sure
-    it works in all cases since I always did really poorly in linear algebra.  anyway,
-    I got the other one above to produce a tree structure too so we should be OK.
-    """
-    class Node:
-        """represents a node in a tree.  stores an 'item' which represents the 
-        dependent thing we are talking about.  if node 'a' is an ancestor node of 
-        node 'b', it means 'a's item is *not* dependent on that of 'b'."""
-        def __init__(self, item):
-            #print "new node on " + str(item)
-            self.item = item
-            self.children = HashSet()
-            self.parent = None
-        def append(self, node):
-            """appends the given node as a child on this node.  removes the node from 
-            its preexisting parent."""
-            if node.parent is not None:
-                del node.parent.children[node]
-            self.children.append(node)
-            node.parent = self
-        def is_descendant_of(self, node):
-            """returns true if this node is a descendant of the given node"""
-            n = self
-            while n is not None:
-                if n is node:
-                    return True
-                else:
-                    n = n.parent
-            return False
-        def get_root(self):
-            """returns the highest ancestor node of this node, i.e. which has no parent"""
-            n = self
-            while n.parent is not None:
-                n = n.parent
-            return n
-        def get_sibling_ancestor(self, node):
-            """returns the node which is:
-            - an ancestor of this node 
-            - is a sibling of the given node 
-            - not an ancestor of the given node
-            
-            - else returns this node's root node."""
-            n = self
-            while n.parent is not None and n.parent is not node.parent and not node.is_descendant_of(n.parent):
-                n = n.parent
-            return n
-        def __str__(self):
-            return self.safestr({})
-        def safestr(self, hash, indent = 0):
-            if hash.has_key(self):
-                return (' ' * indent) + "RECURSIVE:%s(%s, %s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or 'None')
-            hash[self] = True
-            return (' ' * indent) + "%s  (idself=%s, idparent=%s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or "None") + "\n" + string.join([n.safestr(hash, indent + 1) for n in self.children], '')
-        def describe(self):
-            return "%s  (idself=%s)" % (str(self.item), repr(id(self)))
-            
-    def __init__(self, tuples, allitems):
-        self.tuples = tuples
-        self.allitems = allitems
-        
-    def sort(self):
-        (tuples, allitems) = (self.tuples, self.allitems)
-        
-        nodes = {}
-        # make nodes for all the items and store in the hash
-        for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]:
-            if not nodes.has_key(item):
-                nodes[item] = TreeDependencySorter.Node(item)
-
-        # loop through tuples
-        for tup in tuples:
-            (parent, child) = (tup[0], tup[1])
-            # get parent node
-            parentnode = nodes[parent]
-
-            # if parent is child, mark "circular" attribute on the node
-            if parent is child:
-                parentnode.circular = True
-                # and just continue
-                continue
-
-            # get child node
-            childnode = nodes[child]
-            
-            if parentnode.parent is childnode:
-                # check for "a switch"
-                t = parentnode.item
-                parentnode.item = childnode.item
-                childnode.item = t
-                nodes[parentnode.item] = parentnode
-                nodes[childnode.item] = childnode
-            elif parentnode.is_descendant_of(childnode):
-                # check for a line thats backwards with nodes in between, this is a 
-                # circular dependency (although confirmation on this would be helpful)
-                raise FlushError("Circular dependency detected")
-            elif not childnode.is_descendant_of(parentnode):
-                # if relationship doesnt exist, connect nodes together
-                root = childnode.get_sibling_ancestor(parentnode)
-                parentnode.append(root)
-
-            
-        # now we have a collection of subtrees which represent dependencies.
-        # go through the collection root nodes wire them together into one tree        
-        head = None
-        for node in nodes.values():
-            if node.parent is None:
-                if head is not None:
-                    head.append(node)
-                else:
-                    head = node
-        #print str(head)
-        return head
-                        
\ No newline at end of file
index 9e9778cad98c898697e5d2b6c0405db483763d32..b8c6939f76774f1e9409d76f64294c1cb77af895 100644 (file)
@@ -104,10 +104,10 @@ class UnitOfWork(object):
             self.identity_map = weakref.WeakValueDictionary()
             
         self.attributes = global_attributes
-        self.new = util.HashSet(ordered = True)
-        self.dirty = util.HashSet()
+        self.new = util.OrderedSet()
+        self.dirty = util.Set()
         
-        self.deleted = util.HashSet()
+        self.deleted = util.Set()
 
     def get(self, class_, *id):
         """given a class and a list of primary key values in their table-order, locates the mapper 
@@ -149,15 +149,15 @@ class UnitOfWork(object):
         if hasattr(obj, "_instance_key"):
             del self.identity_map[obj._instance_key]
         try:            
-            del self.deleted[obj]
+            self.deleted.remove(obj)
         except KeyError:
             pass
         try:
-            del self.dirty[obj]
+            self.dirty.remove(obj)
         except KeyError:
             pass
         try:
-            del self.new[obj]
+            self.new.remove(obj)
         except KeyError:
             pass
         #self.attributes.commit(obj)
@@ -183,11 +183,11 @@ class UnitOfWork(object):
     
     def register_clean(self, obj):
         try:
-            del self.dirty[obj]
+            self.dirty.remove(obj)
         except KeyError:
             pass
         try:
-            del self.new[obj]
+            self.new.remove(obj)
         except KeyError:
             pass
         if not hasattr(obj, '_instance_key'):
@@ -199,26 +199,26 @@ class UnitOfWork(object):
     def register_new(self, obj):
         if hasattr(obj, '_instance_key'):
             raise InvalidRequestError("Object '%s' already has an identity - it cant be registered as new" % repr(obj))
-        if not self.new.contains(obj):
-            self.new.append(obj)
+        if obj not in self.new:
+            self.new.add(obj)
         self.unregister_deleted(obj)
         
     def register_dirty(self, obj):
-        if not self.dirty.contains(obj):
+        if obj not in self.dirty:
             self._validate_obj(obj)
-            self.dirty.append(obj)
+            self.dirty.add(obj)
         self.unregister_deleted(obj)
         
     def is_dirty(self, obj):
-        if not self.dirty.contains(obj):
+        if obj not in self.dirty:
             return False
         else:
             return True
         
     def register_deleted(self, obj):
-        if not self.deleted.contains(obj):
+        if obj not in self.deleted:
             self._validate_obj(obj)
-            self.deleted.append(obj)  
+            self.deleted.add(obj)  
 
     def unregister_deleted(self, obj):
         try:
@@ -230,14 +230,14 @@ class UnitOfWork(object):
         flush_context = UOWTransaction(self, session)
 
         if objects is not None:
-            objset = sets.Set(objects)
+            objset = util.Set(objects)
         else:
             objset = None
 
         for obj in [n for n in self.new] + [d for d in self.dirty]:
             if objset is not None and not obj in objset:
                 continue
-            if self.deleted.contains(obj):
+            if obj in self.deleted:
                 continue
             flush_context.register_object(obj)
             
@@ -262,11 +262,11 @@ class UnitOfWork(object):
         """'rolls back' the attributes that have been changed on an object instance."""
         self.attributes.rollback(obj)
         try:
-            del self.dirty[obj]
+            self.dirty.remove(obj)
         except KeyError:
             pass
         try:
-            del self.deleted[obj]
+            self.deleted.remove(obj)
         except KeyError:
             pass
             
@@ -277,7 +277,7 @@ class UOWTransaction(object):
         self.uow = uow
         self.session = session
         #  unique list of all the mappers we come across
-        self.mappers = sets.Set()
+        self.mappers = util.Set()
         self.dependencies = {}
         self.tasks = {}
         self.__modified = False
@@ -463,7 +463,7 @@ class UOWTransaction(object):
     def _get_noninheriting_mappers(self):
         """returns a list of UOWTasks whose mappers are not inheriting from the mapper of another UOWTask.
         i.e., this returns the root UOWTasks for all the inheritance hierarchies represented in this UOWTransaction."""
-        mappers = sets.Set()
+        mappers = util.Set()
         for task in self.tasks.values():
             base = task.mapper.base_mapper()
             mappers.add(base)
@@ -580,7 +580,7 @@ class UOWTask(object):
         
         # a list of UOWDependencyProcessors which are executed after saves and
         # before deletes, to synchronize data to dependent objects
-        self.dependencies = sets.Set()
+        self.dependencies = util.Set()
 
         # a list of UOWTasks that are dependent on this UOWTask, which 
         # are to be executed after this UOWTask performs saves and post-save
@@ -589,7 +589,7 @@ class UOWTask(object):
         
         # a list of UOWTasks that correspond to Mappers which are inheriting
         # mappers of this UOWTask's Mapper
-        #self.inheriting_tasks = sets.Set()
+        #self.inheriting_tasks = util.Set()
 
         # whether this UOWTask is circular, meaning it holds a second
         # UOWTask that contains a special row-based dependency structure.
@@ -603,7 +603,7 @@ class UOWTask(object):
         # set of dependencies, referencing sub-UOWTasks attached to this
         # one which represent portions of the total list of objects.
         # this is used for the row-based "circular sort"
-        self.cyclical_dependencies = sets.Set()
+        self.cyclical_dependencies = util.Set()
         
     def is_empty(self):
         return len(self.objects) == 0 and len(self.dependencies) == 0 and len(self.childtasks) == 0
@@ -773,7 +773,7 @@ class UOWTask(object):
             allobjects += [e.obj for e in task.get_elements(polymorphic=True)]
         tuples = []
         
-        cycles = sets.Set(cycles)
+        cycles = util.Set(cycles)
         
         #print "BEGIN CIRC SORT-------"
         #print "PRE-CIRC:"
index 19cb213673c96926321cdf4ce9b8804c92843d31..86799b31166e25d67dab1a4091cc14ee9698b5f5 100644 (file)
@@ -4,13 +4,13 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import sets
+import sqlalchemy.util as util
 import sqlalchemy.sql as sql
 
 class CascadeOptions(object):
     """keeps track of the options sent to relation().cascade"""
     def __init__(self, arg=""):
-        values = sets.Set([c.strip() for c in arg.split(',')])
+        values = util.Set([c.strip() for c in arg.split(',')])
         self.delete_orphan = "delete-orphan" in values
         self.delete = "delete" in values or self.delete_orphan or "all" in values
         self.save_update = "save-update" in values or "all" in values
@@ -22,7 +22,7 @@ class CascadeOptions(object):
     
 
 def polymorphic_union(table_map, typecolname, aliasname='p_union'):
-    colnames = sets.Set()
+    colnames = util.Set()
     colnamemaps = {}
     
     for key in table_map.keys():
index c0769c5e65237828a3e394c11aa12d20b33ac77e..2af45c6da6cc5b74eb518bcfe72fa2dc3d2bd50c 100644 (file)
@@ -624,7 +624,7 @@ class ColumnElement(Selectable, CompareMixin):
         try:
             return self.__orig_set
         except AttributeError:
-            self.__orig_set = sets.Set([self])
+            self.__orig_set = util.Set([self])
             return self.__orig_set
     def _set_orig_set(self, s):
         if len(s) == 0:
@@ -1334,7 +1334,7 @@ class CompoundSelect(SelectBaseMixin, FromClause):
         try:
             colset = self._col_map[col.name]
         except KeyError:
-            colset = sets.Set()
+            colset = util.Set()
             self._col_map[col.name] = colset
         [colset.add(c) for c in col.orig_set]
         col.orig_set = colset
index bb1b8bbeea110ad572fc5b30b86a30eac9f48058..c3017e40c2b8fb4dbda35b6c19b68d69503ab330 100644 (file)
@@ -4,10 +4,15 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import thread, threading, weakref, UserList, time, string, inspect, sys
+import thread, threading, weakref, UserList, time, string, inspect, sys, sets
 from exceptions import *
 import __builtin__
 
+try:
+    Set = set
+except:
+    Set = sets.Set
+    
 def to_list(x):
     if x is None:
         return None
@@ -18,9 +23,9 @@ def to_list(x):
 
 def to_set(x):
     if x is None:
-        return HashSet()
-    if not isinstance(x, HashSet):
-        return HashSet(to_list(x))
+        return Set()
+    if not isinstance(x, Set):
+        return Set(to_list(x))
     else:
         return x
 
@@ -189,43 +194,13 @@ class DictDecorator(dict):
             return self.decorate[key]
     def __repr__(self):
         return dict.__repr__(self) + repr(self.decorate)
-        
-class HashSet(object):
-    """implements a Set, including ordering capability"""
-    def __init__(self, iter=None, ordered=False):
-        if ordered:
-            self.map = OrderedDict()
-        else:
-            self.map  = {}
-        if iter is not None:
-            for i in iter:
-                self.append(i)
-    def __iter__(self):
-        return iter(self.map.values())
-    def contains(self, item):
-        return self.map.has_key(item)
-    def __contains__(self, item):
-        return self.map.has_key(item)
-    def clear(self):
-        self.map.clear()
-    def intersection(self, l):
-        return HashSet([x for x in l if self.contains(x)])
-    def empty(self):
-        return len(self.map) == 0
-    def append(self, item):
-        self.map[item] = item
-    def remove(self, item):
-        del self.map[item]
-    def __add__(self, other):
-        return HashSet(self.map.values() + [i for i in other])
-    def __len__(self):
-        return len(self.map)
-    def __delitem__(self, key):
-        del self.map[key]
-    def __getitem__(self, key):
-        return self.map[key]
-    def __repr__(self):
-        return repr(self.map.values())
+
+class OrderedSet(sets.Set):
+    def __init__(self, iterable=None):
+        """Construct a set from an optional iterable."""
+        self._data = OrderedDict()
+        if iterable is not None: 
+          self._update(iterable)
         
 class HistoryArraySet(UserList.UserList):
     """extends a UserList to provide unique-set functionality as well as history-aware