From 4875c3702c3d59072a3d074f4c3c384ebd6ffa9b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 23 Oct 2005 06:25:43 +0000 Subject: [PATCH] --- lib/sqlalchemy/databases/postgres.py | 12 +++-- lib/sqlalchemy/mapper.py | 11 ++-- lib/sqlalchemy/objectstore.py | 5 +- lib/sqlalchemy/util.py | 76 ++++++++++------------------ test/dependency.py | 57 +++++++++++++++++++++ 5 files changed, 100 insertions(+), 61 deletions(-) create mode 100644 test/dependency.py diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 1d0349a07b..547a2c0b78 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -26,8 +26,11 @@ from sqlalchemy.ansisql import * try: import psycopg2 as psycopg except: - import psycopg - + try: + import psycopg + except: + psycopg = None + class PGNumeric(sqltypes.Numeric): def get_col_spec(self): return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} @@ -70,8 +73,9 @@ def engine(opts, **params): class PGSQLEngine(ansisql.ANSISQLEngine): def __init__(self, opts, module = None, **params): if module is None: - self.module = __import__('psycopg2') - #self.module = psycopg + if psycopg is None: + raise "Couldnt locate psycopg1 or psycopg2: specify postgres module argument" + self.module = psycopg else: self.module = module self.opts = opts or {} diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 5f65f02942..c90b50769d 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -443,11 +443,12 @@ class Mapper(object): (obj, params) = rec statement.execute(**params) primary_keys = table.engine.last_inserted_ids() - i = 0 - for col in self.primary_keys[table]: - if self._getattrbycolumn(obj, col) is None: - self._setattrbycolumn(obj, col, primary_keys[i]) - i+=1 + if primary_keys is not None: + i = 0 + for col in self.primary_keys[table]: + if self._getattrbycolumn(obj, col) is None: + self._setattrbycolumn(obj, col, primary_keys[i]) + i+=1 self.extension.after_insert(self, obj) def delete_obj(self, objects, uow): diff --git a/lib/sqlalchemy/objectstore.py b/lib/sqlalchemy/objectstore.py index 4fe35272b8..9b414f1818 100644 --- a/lib/sqlalchemy/objectstore.py +++ b/lib/sqlalchemy/objectstore.py @@ -324,7 +324,7 @@ class UOWTransaction(object): task.mapper.register_dependencies(self) head = self._sort_dependencies() - print "Task dump:\n" + head.dump() + #print "Task dump:\n" + head.dump() if head is not None: head.execute(self) @@ -434,15 +434,12 @@ class UOWTask(object): self.circular.execute(trans) return -# print "task " + str(self) + " tosave: " + repr(self.tosave_objects()) self.mapper.save_obj(self.tosave_objects(), trans) for dep in self.save_dependencies(): (processor, targettask, isdelete) = dep processor.process_dependencies(targettask, [elem.obj for elem in targettask.tosave_elements()], trans, delete = False) - # print "processed dependencies on " + repr([elem.obj for elem in targettask.tosave_elements()]) for element in self.tosave_elements(): if element.childtask is not None: -# print "execute elem childtask " + str(element.childtask) element.childtask.execute(trans) for dep in self.delete_dependencies(): (processor, targettask, isdelete) = dep diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index da5e7c7e82..063291dde7 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -337,55 +337,46 @@ class DependencySorter(object): self.children = HashSet() self.parent = None self.circular = False + def append(self, node): + if node.parent is not None: + del node.parent.children[node] + self.children.append(node) + node.parent = self def __str__(self): return self.safestr({}) def safestr(self, hash, indent = 0): if hash.has_key(self): return (' ' * indent) + "RECURSIVE:%s(%s, %s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or 'None') hash[self] = True - return (' ' * indent) + "%s(%s, %s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or "None") + "\n" + string.join([n.safestr(hash, indent + 1) for n in self.children], '') + return (' ' * indent) + "%s (idself=%s, idparent=%s)" % (str(self.item), repr(id(self)), self.parent and repr(id(self.parent)) or "None") + "\n" + string.join([n.safestr(hash, indent + 1) for n in self.children], '') def __init__(self, tuples, allitems): self.tuples = tuples self.allitems = allitems def sort(self): (tuples, allitems) = (self.tuples, self.allitems) + nodes = {} - head = None + # make nodes for all the items and store in the hash + for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]: + if not nodes.has_key(item): + nodes[item] = DependencySorter.Node(item) + + # loop through tuples for tup in tuples: (parent, child) = (tup[0], tup[1]) - #print "tuple: " + str(parent) + " " + str(child) - # get parent node - try: - parentnode = nodes[parent] - except KeyError: - parentnode = DependencySorter.Node(parent) - nodes[parent] = parentnode + parentnode = nodes[parent] # if parent is child, mark "circular" attribute on the node if parent is child: parentnode.circular = True - # set head if its nothing - if head is None: - head = parentnode - # nothing more to do for this one + # and just continue continue # get child node - try: - childnode = nodes[child] - except KeyError: - childnode = DependencySorter.Node(child) - nodes[child] = childnode - - # set head if its nothing, move it up to the parent - # if its the child node - if head is None: - head = parentnode - elif head is childnode: - head = parentnode - + childnode = nodes[child] + # now see, if the parent is an ancestor of the child c = childnode while c is not None and c is not parentnode: @@ -394,28 +385,17 @@ class DependencySorter(object): # nope, so we have to move the child down from whereever # it currently is to a child of the parent if c is None: - if childnode.parent is not None: - del childnode.parent.children[childnode] - childnode.parent.children.append(parentnode) - parentnode.children.append(childnode) - childnode.parent = parentnode - - # go through the total list of items. for those - # that had no dependency tuples, and therefore are not - # in the tree, add them as head nodes in a line - newhead = None - for item in allitems: - if not nodes.has_key(item): - if newhead is None: - newhead = DependencySorter.Node(item) - if head is not None: - head.parent = newhead - newhead.children.append(head) - head = newhead + parentnode.append(childnode) + + # now we have a collection of subtrees which represent dependencies. + # go through the collection root nodes wire them together into one tree + head = None + for node in nodes.values(): + if node.parent is None: + if head is not None: + head.append(node) else: - n = DependencySorter.Node(item) - head.children.append(n) - n.parent = head - #print str(head) + head = node + return head \ No newline at end of file diff --git a/test/dependency.py b/test/dependency.py new file mode 100644 index 0000000000..7e04e7ec6c --- /dev/null +++ b/test/dependency.py @@ -0,0 +1,57 @@ +from testbase import PersistTest +import sqlalchemy.util as util +import unittest, sys, os + +class thingy(object): + def __init__(self, name): + self.name = name + def __repr__(self): + return "thingy(%d, %s)" % (id(self), self.name) + def __str__(self): + return repr(self) + +class DependencySortTest(PersistTest): + def testsort(self): + rootnode = thingy('root') + node2 = thingy('node2') + node3 = thingy('node3') + node4 = thingy('node4') + subnode1 = thingy('subnode1') + subnode2 = thingy('subnode2') + subnode3 = thingy('subnode3') + subnode4 = thingy('subnode4') + subsubnode1 = thingy('subsubnode1') + tuples = [ + (subnode3, subsubnode1), + (node2, subnode1), + (node2, subnode2), + (rootnode, node2), + (rootnode, node3), + (rootnode, node4), + (node4, subnode3), + (node4, subnode4) + ] + head = util.DependencySorter(tuples, []).sort() + print str(head) + + def testsort2(self): + node1 = thingy('node1') + node2 = thingy('node2') + node3 = thingy('node3') + node4 = thingy('node4') + node5 = thingy('node5') + node6 = thingy('node6') + node7 = thingy('node7') + tuples = [ + (node1, node2), + (node3, node4), + (node5, node6), + (node6, node2) + ] + head = util.DependencySorter(tuples, [node7]).sort() + print "\n" + str(head) + + + +if __name__ == "__main__": + unittest.main() -- 2.47.2