From a4252a12b0e1411cea7a636025ef9b97cb824f17 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 5 Jun 2006 23:28:44 +0000 Subject: [PATCH] HashSet is gone, uses set() for most sets in py2.4 or sets.Set. ordered set functionality supplied by a subclass of sets.Set --- CHANGES | 6 ++ lib/sqlalchemy/ansisql.py | 4 +- lib/sqlalchemy/ext/activemapper.py | 4 +- lib/sqlalchemy/orm/mapper.py | 26 +++---- lib/sqlalchemy/orm/properties.py | 6 +- lib/sqlalchemy/orm/session.py | 2 +- lib/sqlalchemy/orm/topological.py | 116 ----------------------------- lib/sqlalchemy/orm/unitofwork.py | 50 ++++++------- lib/sqlalchemy/orm/util.py | 6 +- lib/sqlalchemy/sql.py | 4 +- lib/sqlalchemy/util.py | 57 ++++---------- 11 files changed, 73 insertions(+), 208 deletions(-) diff --git a/CHANGES b/CHANGES index 68007a4e65..9d200b8a05 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,9 @@ +0.2.3 +- py2.4 "set" construct used internally, falls back to sets.Set when +"set" not available/ordering is needed. +- "foreignkey" argument to relation() can also be a list. fixed +auto-foreignkey detection [ticket:151] + 0.2.2 - big improvements to polymorphic inheritance behavior, enabling it to work with adjacency list table structures [ticket:190] diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 6956c5379d..f1030b8354 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -9,9 +9,9 @@ in the sql module.""" from sqlalchemy import schema, sql, engine, util import sqlalchemy.engine.default as default -import string, re +import string, re, sets -ANSI_FUNCS = util.HashSet([ +ANSI_FUNCS = sets.ImmutableSet([ 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'CURRENT_DATE', diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py index 74f4df349b..aa71471d21 100644 --- a/lib/sqlalchemy/ext/activemapper.py +++ b/lib/sqlalchemy/ext/activemapper.py @@ -1,4 +1,4 @@ -from sqlalchemy import create_session, relation, mapper, join, DynamicMetaData, class_mapper +from sqlalchemy import create_session, relation, mapper, join, DynamicMetaData, class_mapper, util from sqlalchemy import and_, or_ from sqlalchemy import Table, Column, ForeignKey from sqlalchemy.ext.sessioncontext import SessionContext @@ -108,7 +108,7 @@ def process_relationships(klass, was_deferred=False): class ActiveMapperMeta(type): classes = {} - metadatas = sets.Set() + metadatas = util.Set() def __init__(cls, clsname, bases, dict): table_name = clsname.lower() columns = [] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index eba2203843..64889b9a66 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -56,16 +56,16 @@ class Mapper(object): # uber-pendantic style of making mapper chain, as various testbase/ # threadlocal/assignmapper combinations keep putting dupes etc. in the list # TODO: do something that isnt 21 lines.... - extlist = util.HashSet() + extlist = util.Set() for ext_class in global_extensions: if isinstance(ext_class, MapperExtension): - extlist.append(ext_class) + extlist.add(ext_class) else: - extlist.append(ext_class()) + extlist.add(ext_class()) if extension is not None: for ext_obj in util.to_list(extension): - extlist.append(ext_obj) + extlist.add(ext_obj) self.extension = None previous = None @@ -87,7 +87,7 @@ class Mapper(object): self._options = {} self.always_refresh = always_refresh self.version_id_col = version_id_col - self._inheriting_mappers = sets.Set() + self._inheriting_mappers = util.Set() self.polymorphic_on = polymorphic_on if polymorphic_map is None: self.polymorphic_map = {} @@ -146,7 +146,7 @@ class Mapper(object): # stricter set of tables to create "sync rules" by,based on the immediate # inherited table, rather than all inherited tables self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) - self._synchronizer.compile(self.mapped_table.onclause, util.HashSet([inherits.local_table]), sqlutil.TableFinder(self.local_table)) + self._synchronizer.compile(self.mapped_table.onclause, util.Set([inherits.local_table]), sqlutil.TableFinder(self.local_table)) else: self._synchronizer = None self.mapped_table = self.local_table @@ -182,19 +182,19 @@ class Mapper(object): self.pks_by_table = {} if primary_key is not None: for k in primary_key: - self.pks_by_table.setdefault(k.table, util.HashSet(ordered=True)).append(k) + self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k) if k.table != self.mapped_table: # associate pk cols from subtables to the "main" table - self.pks_by_table.setdefault(self.mapped_table, util.HashSet(ordered=True)).append(k) + self.pks_by_table.setdefault(self.mapped_table, util.OrderedSet()).add(k) # TODO: need local_table properly accounted for when custom primary key is sent else: for t in self.tables + [self.mapped_table]: try: l = self.pks_by_table[t] except KeyError: - l = self.pks_by_table.setdefault(t, util.HashSet(ordered=True)) + l = self.pks_by_table.setdefault(t, util.OrderedSet()) for k in t.primary_key: - l.append(k) + l.add(k) if len(self.pks_by_table[self.mapped_table]) == 0: raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name)) @@ -582,7 +582,7 @@ class Mapper(object): params[col.key] = params[col._label] + 1 else: params[col.key] = 1 - elif self.pks_by_table[table].contains(col): + elif col in self.pks_by_table[table]: # column is a primary key ? if not isinsert: # doing an UPDATE? put primary key values as "WHERE" parameters @@ -756,14 +756,14 @@ class Mapper(object): def cascade_iterator(self, type, object, callable_=None, recursive=None): if recursive is None: - recursive=sets.Set() + recursive=util.Set() for prop in self.props.values(): for c in prop.cascade_iterator(type, object, recursive): yield c def cascade_callable(self, type, object, callable_, recursive=None): if recursive is None: - recursive=sets.Set() + recursive=util.Set() for prop in self.props.values(): prop.cascade_callable(type, object, callable_, recursive) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 34529a1366..0b609e3008 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -271,7 +271,7 @@ class PropertyLoader(mapper.MapperProperty): """searches through the primary join condition to determine which side has the foreign key - from this we return the "foreign key" for this property which helps determine one-to-many/many-to-one.""" - foreignkeys = sets.Set() + foreignkeys = util.Set() def foo(binary): if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return @@ -305,8 +305,8 @@ class PropertyLoader(mapper.MapperProperty): The list of rules is used within commits by the _synchronize() method when dependent objects are processed.""" - parent_tables = util.HashSet(self.parent.tables + [self.parent.mapped_table]) - target_tables = util.HashSet(self.mapper.tables + [self.mapper.mapped_table]) + parent_tables = util.Set(self.parent.tables + [self.parent.mapped_table]) + target_tables = util.Set(self.mapper.tables + [self.mapper.mapped_table]) self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction) if self.direction == sync.MANYTOMANY: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index bd17501653..1ba6d1a35e 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -393,7 +393,7 @@ class Session(object): def __contains__(self, obj): return self._is_attached(obj) and (obj in self.uow.new or self.uow.has_key(obj._instance_key)) def __iter__(self): - return iter(self.uow.new + self.uow.identity_map.values()) + return iter(list(self.uow.new) + self.uow.identity_map.values()) def _get(self, key): return self.uow._get(key) def has_key(self, key): diff --git a/lib/sqlalchemy/orm/topological.py b/lib/sqlalchemy/orm/topological.py index 89e7600391..d9ec5cde98 100644 --- a/lib/sqlalchemy/orm/topological.py +++ b/lib/sqlalchemy/orm/topological.py @@ -231,119 +231,3 @@ class QueueDependencySorter(object): else: return cycled_edges -class TreeDependencySorter(object): - """ - this is my first topological sorting algorithm. its crazy, but matched my thinking - at the time. it also creates the kind of structure I want. but, I am not 100% sure - it works in all cases since I always did really poorly in linear algebra. anyway, - I got the other one above to produce a tree structure too so we should be OK. - """ - class Node: - """represents a node in a tree. stores an 'item' which represents the - dependent thing we are talking about. if node 'a' is an ancestor node of - node 'b', it means 'a's item is *not* dependent on that of 'b'.""" - def __init__(self, item): - #print "new node on " + str(item) - self.item = item - self.children = HashSet() - self.parent = None - def append(self, node): - """appends the given node as a child on this node. removes the node from - its preexisting parent.""" - if node.parent is not None: - del node.parent.children[node] - self.children.append(node) - node.parent = self - def is_descendant_of(self, node): - """returns true if this node is a descendant of the given node""" - n = self - while n is not None: - if n is node: - return True - else: - n = n.parent - return False - def get_root(self): - """returns the highest ancestor node of this node, i.e. which has no parent""" - n = self - while n.parent is not None: - n = n.parent - return n - def get_sibling_ancestor(self, node): - """returns the node which is: - - an ancestor of this node - - is a sibling of the given node - - not an ancestor of the given node - - - else returns this node's root node.""" - n = self - while n.parent is not None and n.parent is not node.parent and not node.is_descendant_of(n.parent): - n = n.parent - return n - def __str__(self): - return self.safestr({}) - def safestr(self, hash, indent = 0): - if hash.has_key(self): - return (' ' * indent) + "RECURSIVE:%s(%s, %s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or 'None') - hash[self] = True - return (' ' * indent) + "%s (idself=%s, idparent=%s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or "None") + "\n" + string.join([n.safestr(hash, indent + 1) for n in self.children], '') - def describe(self): - return "%s (idself=%s)" % (str(self.item), repr(id(self))) - - def __init__(self, tuples, allitems): - self.tuples = tuples - self.allitems = allitems - - def sort(self): - (tuples, allitems) = (self.tuples, self.allitems) - - nodes = {} - # make nodes for all the items and store in the hash - for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]: - if not nodes.has_key(item): - nodes[item] = TreeDependencySorter.Node(item) - - # loop through tuples - for tup in tuples: - (parent, child) = (tup[0], tup[1]) - # get parent node - parentnode = nodes[parent] - - # if parent is child, mark "circular" attribute on the node - if parent is child: - parentnode.circular = True - # and just continue - continue - - # get child node - childnode = nodes[child] - - if parentnode.parent is childnode: - # check for "a switch" - t = parentnode.item - parentnode.item = childnode.item - childnode.item = t - nodes[parentnode.item] = parentnode - nodes[childnode.item] = childnode - elif parentnode.is_descendant_of(childnode): - # check for a line thats backwards with nodes in between, this is a - # circular dependency (although confirmation on this would be helpful) - raise FlushError("Circular dependency detected") - elif not childnode.is_descendant_of(parentnode): - # if relationship doesnt exist, connect nodes together - root = childnode.get_sibling_ancestor(parentnode) - parentnode.append(root) - - - # now we have a collection of subtrees which represent dependencies. - # go through the collection root nodes wire them together into one tree - head = None - for node in nodes.values(): - if node.parent is None: - if head is not None: - head.append(node) - else: - head = node - #print str(head) - return head - \ No newline at end of file diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 9e9778cad9..b8c6939f76 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -104,10 +104,10 @@ class UnitOfWork(object): self.identity_map = weakref.WeakValueDictionary() self.attributes = global_attributes - self.new = util.HashSet(ordered = True) - self.dirty = util.HashSet() + self.new = util.OrderedSet() + self.dirty = util.Set() - self.deleted = util.HashSet() + self.deleted = util.Set() def get(self, class_, *id): """given a class and a list of primary key values in their table-order, locates the mapper @@ -149,15 +149,15 @@ class UnitOfWork(object): if hasattr(obj, "_instance_key"): del self.identity_map[obj._instance_key] try: - del self.deleted[obj] + self.deleted.remove(obj) except KeyError: pass try: - del self.dirty[obj] + self.dirty.remove(obj) except KeyError: pass try: - del self.new[obj] + self.new.remove(obj) except KeyError: pass #self.attributes.commit(obj) @@ -183,11 +183,11 @@ class UnitOfWork(object): def register_clean(self, obj): try: - del self.dirty[obj] + self.dirty.remove(obj) except KeyError: pass try: - del self.new[obj] + self.new.remove(obj) except KeyError: pass if not hasattr(obj, '_instance_key'): @@ -199,26 +199,26 @@ class UnitOfWork(object): def register_new(self, obj): if hasattr(obj, '_instance_key'): raise InvalidRequestError("Object '%s' already has an identity - it cant be registered as new" % repr(obj)) - if not self.new.contains(obj): - self.new.append(obj) + if obj not in self.new: + self.new.add(obj) self.unregister_deleted(obj) def register_dirty(self, obj): - if not self.dirty.contains(obj): + if obj not in self.dirty: self._validate_obj(obj) - self.dirty.append(obj) + self.dirty.add(obj) self.unregister_deleted(obj) def is_dirty(self, obj): - if not self.dirty.contains(obj): + if obj not in self.dirty: return False else: return True def register_deleted(self, obj): - if not self.deleted.contains(obj): + if obj not in self.deleted: self._validate_obj(obj) - self.deleted.append(obj) + self.deleted.add(obj) def unregister_deleted(self, obj): try: @@ -230,14 +230,14 @@ class UnitOfWork(object): flush_context = UOWTransaction(self, session) if objects is not None: - objset = sets.Set(objects) + objset = util.Set(objects) else: objset = None for obj in [n for n in self.new] + [d for d in self.dirty]: if objset is not None and not obj in objset: continue - if self.deleted.contains(obj): + if obj in self.deleted: continue flush_context.register_object(obj) @@ -262,11 +262,11 @@ class UnitOfWork(object): """'rolls back' the attributes that have been changed on an object instance.""" self.attributes.rollback(obj) try: - del self.dirty[obj] + self.dirty.remove(obj) except KeyError: pass try: - del self.deleted[obj] + self.deleted.remove(obj) except KeyError: pass @@ -277,7 +277,7 @@ class UOWTransaction(object): self.uow = uow self.session = session # unique list of all the mappers we come across - self.mappers = sets.Set() + self.mappers = util.Set() self.dependencies = {} self.tasks = {} self.__modified = False @@ -463,7 +463,7 @@ class UOWTransaction(object): def _get_noninheriting_mappers(self): """returns a list of UOWTasks whose mappers are not inheriting from the mapper of another UOWTask. i.e., this returns the root UOWTasks for all the inheritance hierarchies represented in this UOWTransaction.""" - mappers = sets.Set() + mappers = util.Set() for task in self.tasks.values(): base = task.mapper.base_mapper() mappers.add(base) @@ -580,7 +580,7 @@ class UOWTask(object): # a list of UOWDependencyProcessors which are executed after saves and # before deletes, to synchronize data to dependent objects - self.dependencies = sets.Set() + self.dependencies = util.Set() # a list of UOWTasks that are dependent on this UOWTask, which # are to be executed after this UOWTask performs saves and post-save @@ -589,7 +589,7 @@ class UOWTask(object): # a list of UOWTasks that correspond to Mappers which are inheriting # mappers of this UOWTask's Mapper - #self.inheriting_tasks = sets.Set() + #self.inheriting_tasks = util.Set() # whether this UOWTask is circular, meaning it holds a second # UOWTask that contains a special row-based dependency structure. @@ -603,7 +603,7 @@ class UOWTask(object): # set of dependencies, referencing sub-UOWTasks attached to this # one which represent portions of the total list of objects. # this is used for the row-based "circular sort" - self.cyclical_dependencies = sets.Set() + self.cyclical_dependencies = util.Set() def is_empty(self): return len(self.objects) == 0 and len(self.dependencies) == 0 and len(self.childtasks) == 0 @@ -773,7 +773,7 @@ class UOWTask(object): allobjects += [e.obj for e in task.get_elements(polymorphic=True)] tuples = [] - cycles = sets.Set(cycles) + cycles = util.Set(cycles) #print "BEGIN CIRC SORT-------" #print "PRE-CIRC:" diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 19cb213673..86799b3116 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -4,13 +4,13 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sets +import sqlalchemy.util as util import sqlalchemy.sql as sql class CascadeOptions(object): """keeps track of the options sent to relation().cascade""" def __init__(self, arg=""): - values = sets.Set([c.strip() for c in arg.split(',')]) + values = util.Set([c.strip() for c in arg.split(',')]) self.delete_orphan = "delete-orphan" in values self.delete = "delete" in values or self.delete_orphan or "all" in values self.save_update = "save-update" in values or "all" in values @@ -22,7 +22,7 @@ class CascadeOptions(object): def polymorphic_union(table_map, typecolname, aliasname='p_union'): - colnames = sets.Set() + colnames = util.Set() colnamemaps = {} for key in table_map.keys(): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index c0769c5e65..2af45c6da6 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -624,7 +624,7 @@ class ColumnElement(Selectable, CompareMixin): try: return self.__orig_set except AttributeError: - self.__orig_set = sets.Set([self]) + self.__orig_set = util.Set([self]) return self.__orig_set def _set_orig_set(self, s): if len(s) == 0: @@ -1334,7 +1334,7 @@ class CompoundSelect(SelectBaseMixin, FromClause): try: colset = self._col_map[col.name] except KeyError: - colset = sets.Set() + colset = util.Set() self._col_map[col.name] = colset [colset.add(c) for c in col.orig_set] col.orig_set = colset diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index bb1b8bbeea..c3017e40c2 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -4,10 +4,15 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import thread, threading, weakref, UserList, time, string, inspect, sys +import thread, threading, weakref, UserList, time, string, inspect, sys, sets from exceptions import * import __builtin__ +try: + Set = set +except: + Set = sets.Set + def to_list(x): if x is None: return None @@ -18,9 +23,9 @@ def to_list(x): def to_set(x): if x is None: - return HashSet() - if not isinstance(x, HashSet): - return HashSet(to_list(x)) + return Set() + if not isinstance(x, Set): + return Set(to_list(x)) else: return x @@ -189,43 +194,13 @@ class DictDecorator(dict): return self.decorate[key] def __repr__(self): return dict.__repr__(self) + repr(self.decorate) - -class HashSet(object): - """implements a Set, including ordering capability""" - def __init__(self, iter=None, ordered=False): - if ordered: - self.map = OrderedDict() - else: - self.map = {} - if iter is not None: - for i in iter: - self.append(i) - def __iter__(self): - return iter(self.map.values()) - def contains(self, item): - return self.map.has_key(item) - def __contains__(self, item): - return self.map.has_key(item) - def clear(self): - self.map.clear() - def intersection(self, l): - return HashSet([x for x in l if self.contains(x)]) - def empty(self): - return len(self.map) == 0 - def append(self, item): - self.map[item] = item - def remove(self, item): - del self.map[item] - def __add__(self, other): - return HashSet(self.map.values() + [i for i in other]) - def __len__(self): - return len(self.map) - def __delitem__(self, key): - del self.map[key] - def __getitem__(self, key): - return self.map[key] - def __repr__(self): - return repr(self.map.values()) + +class OrderedSet(sets.Set): + def __init__(self, iterable=None): + """Construct a set from an optional iterable.""" + self._data = OrderedDict() + if iterable is not None: + self._update(iterable) class HistoryArraySet(UserList.UserList): """extends a UserList to provide unique-set functionality as well as history-aware -- 2.47.2