From 7d3c4e27d4bda591d5907c425af1735a41ecd686 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 2 Oct 2005 01:46:42 +0000 Subject: [PATCH] --- examples/adjacencytree/tables.py | 13 ++ lib/sqlalchemy/attributes.py | 2 + lib/sqlalchemy/mapper.py | 53 +++++--- lib/sqlalchemy/objectstore.py | 211 ++++++++++++++++++++++--------- lib/sqlalchemy/util.py | 11 +- 5 files changed, 202 insertions(+), 88 deletions(-) create mode 100644 examples/adjacencytree/tables.py diff --git a/examples/adjacencytree/tables.py b/examples/adjacencytree/tables.py new file mode 100644 index 0000000000..3cfe1e9314 --- /dev/null +++ b/examples/adjacencytree/tables.py @@ -0,0 +1,13 @@ +from sqlalchemy.schema import * +import sqlalchemy.engine + +engine = sqlalchemy.engine.create_engine('sqlite', ':memory:', {}, echo = True) + +trees = Table('treenodes', engine, + Column('node_id', Integer, primary_key=True), + Column('parent_node_id', Integer, ForeignKey('treenodes.node_id'), nullable=True), + Column('root_node_id', Integer, ForeignKey('treenodes.node_id'), nullable=True), + Column('node_name', String(50), nullable=False) + ) + +trees.create() \ No newline at end of file diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index a12b57accb..30d9a296d6 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -54,6 +54,8 @@ class PropHistory(object): self.obj = obj self.key = key self.orig = PropHistory.NONE + def history_contains(self, obj): + return self.orig is obj or self.obj.__dict__[self.key] is obj def setattr_clean(self, value): self.obj.__dict__[self.key] = value def setattr(self, value): diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 218c4fb660..7585dc0cd4 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -38,8 +38,8 @@ def relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin else: return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **options) -def relation_mapper(class_, table = None, secondary = None, primaryjoin = None, secondaryjoin = None, primarytable = None, properties = None, lazy = True, foreignkey = None, primary_keys = None, **options): - return relation_loader(mapper(class_, table, primarytable=primarytable, properties=properties, primary_keys=primary_keys, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, foreignkey = foreignkey, **options) +def relation_mapper(class_, table = None, secondary = None, primaryjoin = None, secondaryjoin = None, primarytable = None, properties = None, lazy = True, foreignkey = None, primary_keys = None, thiscol = None, **options): + return relation_loader(mapper(class_, table, primarytable=primarytable, properties=properties, primary_keys=primary_keys, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, foreignkey = foreignkey, thiscol = thiscol, **options) class assignmapper(object): def __init__(self, table, **kwargs): @@ -528,7 +528,7 @@ class ColumnProperty(MapperProperty): class PropertyLoader(MapperProperty): """describes an object property that holds a single item or list of items that correspond to a related database table.""" - def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False): + def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False, thiscol = None): self.uselist = uselist self.argument = argument self.secondary = secondary @@ -536,6 +536,7 @@ class PropertyLoader(MapperProperty): self.secondaryjoin = secondaryjoin self.foreignkey = foreignkey self.private = private + self.thiscol = thiscol self._hash_key = "%s(%s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(self.argument), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist)) def _copy(self): @@ -553,10 +554,12 @@ class PropertyLoader(MapperProperty): self.mapper = self.argument self.target = self.mapper.table - self.key = key self.parent = parent - + + if self.parent.table is self.target and self.thiscol is None: + raise "Circular relationship requires 'thiscol' parameter" + # if join conditions were not specified, figure them out based on foreign keys if self.secondary is not None: if self.secondaryjoin is None: @@ -660,15 +663,33 @@ class PropertyLoader(MapperProperty): raise " no foreign key ?" def get_object_dependencies(self, obj, uowcommit, passive = True): - """function to retreive the child list off of an object. "passive" means, if its - a lazy loaded list that is not loaded yet, dont load it.""" if self.uselist: return uowcommit.uow.attributes.get_list_history(obj, self.key, passive = passive) else: return uowcommit.uow.attributes.get_history(obj, self.key) - + + def whose_dependent_on_who(self, obj1, obj2, uowcommit): + if obj1 is obj2: + return None + hist = self.get_object_dependencies(obj1, uowcommit) + if hist.history_contains(obj2): + if self.thiscol.primary_key: + return (obj1, obj2) + else: + return (obj2, obj1) + else: + hist = self.get_object_dependencies(obj2, uowcommit) + if hist.history_contains(obj1): + if self.thiscol.primary_key: + return (obj2, obj1) + else: + return (obj1, obj2) + else: + return None + + def process_dependencies(self, deplist, uowcommit, delete = False): - print self.mapper.table.name + " " + repr(deplist.map.values()) + " process_dep isdelete " + repr(delete) + #print self.mapper.table.name + " " + repr(deplist.map.values()) + " process_dep isdelete " + repr(delete) # fucntion to set properties across a parent/child object plus an "association row", # based on a join condition @@ -677,7 +698,10 @@ class PropertyLoader(MapperProperty): setter = BinaryVisitor(sync_foreign_keys) def getlist(obj, passive=True): - return self.get_object_dependencies(obj, uowcommit, passive) + if self.uselist: + return uowcommit.uow.attributes.get_list_history(obj, self.key, passive = passive) + else: + return uowcommit.uow.attributes.get_history(obj, self.key) associationrow = {} @@ -719,7 +743,6 @@ class PropertyLoader(MapperProperty): statement = self.secondary.insert() statement.execute(*secondary_insert) elif self.foreignkey.table == self.target: - print "HI" if delete and not self.private: updates = [] clearkeys = True @@ -739,11 +762,9 @@ class PropertyLoader(MapperProperty): statement = self.target.update(self.lazywhere, values = values) statement.execute(*updates) else: - print str(self.primaryjoin.compile()) for obj in deplist: childlist = getlist(obj) if childlist is None: return - print "DEP: " +str(obj) + " LIST: " + repr([str(v) for v in childlist.added_items()]) uowcommit.register_saved_list(childlist) clearkeys = False for child in childlist.added_items(): @@ -768,7 +789,7 @@ class PropertyLoader(MapperProperty): else: raise " no foreign key ?" - print self.mapper.table.name + " postdep " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete) + #print self.mapper.table.name + " postdep " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete) def _sync_foreign_keys(self, binary, obj, child, associationrow, clearkeys): """given a binary clause with an = operator joining two table columns, synchronizes the values @@ -782,12 +803,12 @@ class PropertyLoader(MapperProperty): source = binary.right else: raise "Cant determine direction for relationship %s = %s" % (binary.left.fullname, binary.right.fullname) - print "set " + repr(child) + ":" + self.foreignkey.key + " to " + repr(obj) + ":" + source.key + #print "set " + repr(child) + ":" + self.foreignkey.key + " to " + repr(obj) + ":" + source.key self.mapper._setattrbycolumn(child, self.foreignkey, self.parent._getattrbycolumn(obj, source)) else: colmap = {binary.left.table : binary.left, binary.right.table : binary.right} if colmap.has_key(self.parent.primarytable) and colmap.has_key(self.target): - print "set " + repr(child) + ":" + colmap[self.target].key + " to " + repr(obj) + ":" + colmap[self.parent.primarytable].key + #print "set " + repr(child) + ":" + colmap[self.target].key + " to " + repr(obj) + ":" + colmap[self.parent.primarytable].key if clearkeys: self.mapper._setattrbycolumn(child, colmap[self.target], None) else: diff --git a/lib/sqlalchemy/objectstore.py b/lib/sqlalchemy/objectstore.py index f4639cb47f..eb601e3cb3 100644 --- a/lib/sqlalchemy/objectstore.py +++ b/lib/sqlalchemy/objectstore.py @@ -24,7 +24,7 @@ import thread import sqlalchemy.util as util import sqlalchemy.attributes as attributes import weakref - +import string def get_id_key(ident, class_, table): """returns an identity-map key for use in storing/retrieving an item from the identity map, given @@ -302,6 +302,7 @@ class UOWTransaction(object): task.mapper.register_dependencies(self) for task in self._sort_dependencies(): + print "exec task: " + str(task) task.execute(self) def post_exec(self): @@ -326,54 +327,10 @@ class UOWTransaction(object): pass def _sort_dependencies(self): - - class Node: - def __init__(self, mapper): - #print "new node on " + str(mapper) - self.mapper = mapper - self.children = util.HashSet() - self.parent = None - - def maketree(tuples, allitems): - nodes = {} - head = None - for tup in tuples: - (parent, child) = (tup[0], tup[1]) - #print "tuple: " + str(parent) + " " + str(child) - try: - parentnode = nodes[parent] - except KeyError: - parentnode = Node(parent) - nodes[parent] = parentnode - try: - childnode = nodes[child] - except KeyError: - childnode = Node(child) - nodes[child] = childnode - - if head is None: - head = parentnode - elif head is childnode: - head = parentnode - if childnode.parent is not None: - del childnode.parent.children[childnode] - childnode.parent.children.append(parentnode) - parentnode.children.append(childnode) - childnode.parent = parentnode - - for item in allitems: - if not nodes.has_key(item): - node = Node(item) - if head is not None: - head.parent = node - node.children.append(head) - head = node - return head - bymapper = {} def sort(node, isdel, res): - #print "Sort: " + (node and str(node.mapper) or 'None') + print "Sort: " + (node and str(node.mapper) or 'None') if node is None: return res task = bymapper.get((node.mapper, isdel), None) @@ -381,6 +338,8 @@ class UOWTransaction(object): res.append(task) for child in node.children: if child is node: + print "setting circular: " + str(task) + task.iscircular = True continue sort(child, isdel, res) return res @@ -391,7 +350,7 @@ class UOWTransaction(object): mappers.append(task.mapper) bymapper[(task.mapper, task.isdelete)] = task - head = maketree(self.dependencies, mappers) + head = TupleSorter(self.dependencies, mappers).sort() res = [] tasklist = sort(head, False, res) @@ -415,9 +374,16 @@ class UOWTask(object): self.objects = util.HashSet(ordered = True) self.dependencies = [] self.listonly = listonly + self.iscircular = False #print "new task " + str(self) def execute(self, trans): + print "exec " + str(self) + " circualr=" + repr(self.iscircular) + if self.iscircular: + task = self.sort_circular_dependencies(trans) + task.execute_circular(trans) + return + obj_list = self.objects if not self.listonly and not self.isdelete: self.mapper.save_obj(obj_list, trans) @@ -427,32 +393,151 @@ class UOWTask(object): if not self.listonly and self.isdelete: self.mapper.delete_obj(obj_list, trans) - def sort_circular_dependencies(self): + def execute_circular(self, trans): + print "execcircular " + str(self) +# obj_list = self.objects + # if not self.listonly and not self.isdelete: + # self.mapper.save_obj(obj_list, trans) + # raise "hi" + self.execute(trans) + for obj in self.objects: + childtask = self.taskhash[obj] + childtask.execute_circular(trans) + + + def sort_circular_dependencies(self, trans): + allobjects = self.objects + tuples = [] + for obj in self.objects: + for dep in self.dependencies: + (processor, targettask) = dep + if targettask is self: + childlist = processor.get_object_dependencies(obj, trans, passive = True) + for o in childlist.added_items() + childlist.deleted_items(): + whosdep = processor.whose_dependent_on_who(obj, o, trans) + if whosdep is not None: + tuples.append(whosdep) + head = TupleSorter(tuples, allobjects).sort() + print "---------" + print str(head) + raise "hi" + + def old_sort_circular_dependencies(self, trans): + dependents = {} d = {} - head = None - for obj in obj_list: - d[obj] = UOWTask(self.mapper, self.isdelete, self.listonly) - d[obj].dependencies = self.dependencies - if head is None: - head = obj + + def make_task(): + t = UOWTask(self.mapper, self.isdelete, self.listonly) + t.dependencies = self.dependencies + t.taskhash = d + return t + + head = make_task() + for obj in self.objects: + print "obj: " + str(obj) + task = make_task() + d[obj] = task + if not dependents.has_key(obj): + head.objects.append(obj) for dep in self.dependencies: (processor, targettask) = dep - if targetttask is self: - for o in processor.get_object_dependencies(obj, self, passive = True): - if o is head: - head = obj - d[obj].objects.append(o) - if head is None: - return self - else: - return d[head] + if targettask is self: + childlist = processor.get_object_dependencies(obj, trans, passive = True) + for o in childlist.added_items() + childlist.deleted_items(): + whosdep = processor.whose_dependent_on_who(obj, o, trans) + if whosdep is not None: + (child, parent) = whosdep + if not d.has_key(parent): + d[parent] = make_task() + if dependents.has_key(child): + p2 = dependents[child] + wd2 = processor.whose_dependent_on_who(parent, p2, trans) + + d[parent].objects.append(child) + dependents[child] = parent + print "dependent obj: " + str(child) + " is dependent in relation " + str(obj) + " " + str(o) + if head.objects.contains(child): + del head.objects[child] + + def printtask(t): + print "l1" + print repr([str(v) for v in t.objects]) + for v in t.objects: + t2 = t.taskhash[v] + print "l2" + print repr([str(v2) for v2 in t2.objects]) + for v3 in t2.objects: + t3 = t.taskhash[v3] + print "l3" + print repr([str(v4) for v4 in t3.objects]) +# printtask(t2) + print "sorted hierarchical tasks: " + printtask(head) + raise "hi" + return head def __str__(self): if self.isdelete: return self.mapper.primarytable.name + " deletes " + repr(self.listonly) else: return self.mapper.primarytable.name + " saves " + repr(self.listonly) + +class TupleSorter(object): + + class Node: + def __init__(self, mapper): + #print "new node on " + str(mapper) + self.mapper = mapper + self.children = util.HashSet() + self.parent = None + def __str__(self): + return self.safestr({}) + def safestr(self, hash): + if hash.has_key(self): + return "[RECURSIVE:%s(%s, %s)]" % (str(self.mapper), repr(id(self)), repr(id(self.parent))) + hash[self] = True + return "%s(%s, %s)" % (str(self.mapper), repr(id(self)), repr(id(self.parent))) + "\n" + string.join([n.safestr(hash) for n in self.children]) + + def __init__(self, tuples, allitems): + self.tuples = tuples + self.allitems = allitems + def sort(self): + (tuples, allitems) = (self.tuples, self.allitems) + nodes = {} + head = None + for tup in tuples: + (parent, child) = (tup[0], tup[1]) + print "tuple: " + str(parent) + " " + str(child) + try: + parentnode = nodes[parent] + except KeyError: + parentnode = TupleSorter.Node(parent) + nodes[parent] = parentnode + try: + childnode = nodes[child] + except KeyError: + childnode = TupleSorter.Node(child) + nodes[child] = childnode + + if head is None: + head = parentnode + elif head is childnode: + head = parentnode + if childnode.parent is not None: + del childnode.parent.children[childnode] + childnode.parent.children.append(parentnode) + parentnode.children.append(childnode) + childnode.parent = parentnode + for item in allitems: + if not nodes.has_key(item): + node = TupleSorter.Node(item) + if head is not None: + head.parent = node + node.children.append(head) + head = node + return head + uow = util.ScopedRegistry(lambda: UnitOfWork(), "thread") diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index bc13f5dbde..446b87cb95 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -121,28 +121,20 @@ class HashSet(object): if iter is not None: for i in iter: self.append(i) - def __iter__(self): return iter(self.map.values()) - def contains(self, item): return self.map.has_key(item) - def clear(self): self.map.clear() - def append(self, item): self.map[item] = item - def __add__(self, other): return HashSet(self.map.values() + [i for i in other]) - def __len__(self): return len(self.map) - def __delitem__(self, key): del self.map[key] - def __getitem__(self, key): return self.map[key] @@ -179,7 +171,8 @@ class HistoryArraySet(UserList.UserList): if not self._setrecord(self.data[i]): del self.data[i] i -= 1 - + def history_contains(self, obj): + return self.records.has_key(obj) def __hash__(self): return id(self) def _setrecord(self, item): -- 2.47.2