From: Mike Bayer Date: Sun, 2 Oct 2005 19:50:11 +0000 (+0000) Subject: (no commit message) X-Git-Tag: rel_0_1_0~575 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=24fe959df62d960ab71674c19cd473cd128e9a5a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git --- diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index 331e1dc285..59f68203bf 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -54,6 +54,10 @@ class PropHistory(object): self.obj = obj self.key = key self.orig = PropHistory.NONE + def gethistory(self, *args, **kwargs): + return self + def __call__(self, *args, **kwargs): + return self.obj.__dict__[self.key] def history_contains(self, obj): return self.orig is obj or self.obj.__dict__[self.key] is obj def setattr_clean(self, value): @@ -90,15 +94,30 @@ class PropHistory(object): class ListElement(util.HistoryArraySet): """manages the value of a particular list-based attribute on a particular object instance.""" - def __init__(self, obj, key, items = None): + def __init__(self, obj, key, data=None): self.obj = obj self.key = key - util.HistoryArraySet.__init__(self, items) - obj.__dict__[key] = self.data + try: + list_ = obj.__dict__[key] + if data is not None: + list_.clear() + for d in data: + list_.append(d) + except KeyError: + if data is not None: + list_ = data + else: + list_ = [] + obj.__dict__[key] = [] + + util.HistoryArraySet.__init__(self, list_) + def gethistory(self, *args, **kwargs): + return self + def __call__(self, *args, **kwargs): + return self def list_value_changed(self, obj, key, listval): pass - def setattr(self, value): self.obj.__dict__[self.key] = value self.set_data(value) @@ -115,9 +134,37 @@ class ListElement(util.HistoryArraySet): self.list_value_changed(self.obj, self.key, self) return res - +class CallableProp(object): + """allows the attaching of a callable item, representing the future value + of a particular attribute on a particular object instance, to + the AttributeManager. When the attributemanager + accesses the object attribute, either to get its history or its real value, the __call__ method + is invoked which runs the underlying callable_ and sets the new value to the object attribute + via the manager.""" + def __init__(self, callable_, obj, key, uselist = False): + self.callable_ = callable_ + self.obj = obj + self.key = key + self.uselist = uselist + def gethistory(self, manager, *args, **kwargs): + self.__call__(manager, *args, **kwargs) + return manager.attribute_history[self.obj][self.key] + def __call__(self, manager, passive=False): + if passive: + return None + value = self.callable_() + if self.uselist: + p = manager.create_list(self.obj, self.key, value) + manager.attribute_history[self.obj][self.key] = p + return p + else: + self.obj.__dict__[self.key] = value + p = PropHistory(self.obj, self.key) + manager.attribute_history[self.obj][self.key] = p + return p + class AttributeManager(object): - """maintains a set of per-attribute history objects for a set of objects.""" + """maintains a set of per-attribute callable/history manager objects for a set of objects.""" def __init__(self): self.attribute_history = {} @@ -130,13 +177,13 @@ class AttributeManager(object): def get_attribute(self, obj, key): try: - v = obj.__dict__[key] + return self.get_history(obj, key)(self) + except KeyError: + pass + try: + return obj.__dict__[key] except KeyError: raise AttributeError(key) - if (callable(v)): - v = v() - obj.__dict__[key] = v - return v def get_list_attribute(self, obj, key): return self.get_list_history(obj, key) @@ -152,6 +199,13 @@ class AttributeManager(object): self.get_history(obj, key).delattr() self.value_changed(obj, key, value) + def set_callable(self, obj, key, func, uselist): + try: + d = self.attribute_history[obj] + except KeyError, e: + d = {} + self.attribute_history[obj] = d + d[key] = CallableProp(func, obj, key, uselist) def delete_list_attribute(self, obj, key): pass @@ -190,7 +244,7 @@ class AttributeManager(object): def get_history(self, obj, key): try: - return self.attribute_history[obj][key] + return self.attribute_history[obj][key].gethistory(self) except KeyError, e: if e.args[0] is obj: d = {} @@ -205,14 +259,10 @@ class AttributeManager(object): def get_list_history(self, obj, key, passive = False): try: - return self.attribute_history[obj][key] + return self.attribute_history[obj][key].gethistory(self, passive) except KeyError, e: # TODO: when an callable is re-set on an existing list element list_ = obj.__dict__.get(key, None) - if callable(list_): - if passive: - return None - list_ = list_() if e.args[0] is obj: d = {} self.attribute_history[obj] = d diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 87216773da..2f3e94de25 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -22,7 +22,7 @@ import sqlalchemy.schema as schema import sqlalchemy.pool import sqlalchemy.util as util import sqlalchemy.sql as sql -import StringIO +import StringIO, sys import sqlalchemy.types as types def create_engine(name, *args ,**kwargs): @@ -61,6 +61,7 @@ class SQLEngine(schema.SchemaEngine): self.context = util.ThreadLocal() self.tables = {} self.notes = {} + self.logger = sys.stdout def type_descriptor(self, typeobj): @@ -206,7 +207,7 @@ class SQLEngine(schema.SchemaEngine): return ResultProxy(c, self.echo, typemap = typemap) def log(self, msg): - print msg + self.logger.write(msg + "\n") class ResultProxy: diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 173e35c37b..b7e87eb049 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -694,12 +694,15 @@ class PropertyLoader(MapperProperty): return (obj2, obj1) def process_dependencies(self, deplist, uowcommit, delete = False): - #print self.mapper.table.name + " " + repr(deplist.map.values()) + " process_dep isdelete " + repr(delete) + print self.mapper.table.name + " " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete) # fucntion to set properties across a parent/child object plus an "association row", # based on a join condition def sync_foreign_keys(binary): - self._sync_foreign_keys(binary, obj, child, associationrow, clearkeys) + if self.direction == PropertyLoader.RIGHT: + self._sync_foreign_keys(binary, child, obj, associationrow, clearkeys) + else: + self._sync_foreign_keys(binary, obj, child, associationrow, clearkeys) setter = BinaryVisitor(sync_foreign_keys) def getlist(obj, passive=True): @@ -744,8 +747,8 @@ class PropertyLoader(MapperProperty): if len(secondary_insert): statement = self.secondary.insert() statement.execute(*secondary_insert) - elif self.direction == PropertyLoader.LEFT: - if delete and not self.private: + elif self.direction == PropertyLoader.LEFT and delete: + if not self.private: updates = [] clearkeys = True for obj in deplist: @@ -763,33 +766,18 @@ class PropertyLoader(MapperProperty): values[bind.shortname] = None statement = self.target.update(self.lazywhere, values = values) statement.execute(*updates) - else: - for obj in deplist: - childlist = getlist(obj) - if childlist is None: return - uowcommit.register_saved_list(childlist) - clearkeys = False - for child in childlist.added_items(): - self.primaryjoin.accept_visitor(setter) - clearkeys = True - for child in childlist.deleted_items(): - self.primaryjoin.accept_visitor(setter) - elif self.direction == PropertyLoader.RIGHT: - for child in deplist: - childlist = getlist(child) + else: + for obj in deplist: + childlist = getlist(obj) if childlist is None: return uowcommit.register_saved_list(childlist) clearkeys = False - added = childlist.added_items() - if len(added): - for obj in added: - self.primaryjoin.accept_visitor(setter) - else: + for child in childlist.added_items(): + self.primaryjoin.accept_visitor(setter) + if self.direction != PropertyLoader.RIGHT or len(childlist.added_items()) == 0: clearkeys = True - for obj in childlist.deleted_items(): + for child in childlist.deleted_items(): self.primaryjoin.accept_visitor(setter) - else: - raise " no foreign key ?" #print self.mapper.table.name + " postdep " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete) @@ -797,6 +785,8 @@ class PropertyLoader(MapperProperty): """given a binary clause with an = operator joining two table columns, synchronizes the values of the corresponding attributes within a parent object and a child object, or the attributes within an an "association row" that represents an association link between the 'parent' and 'child' object.""" + if obj is child: + raise "wha?" if binary.operator == '=': if binary.left.table == binary.right.table: if binary.right is self.foreignkey: @@ -805,8 +795,9 @@ class PropertyLoader(MapperProperty): source = binary.right else: raise "Cant determine direction for relationship %s = %s" % (binary.left.fullname, binary.right.fullname) - #print "set " + repr(child) + ":" + self.foreignkey.key + " to " + repr(obj) + ":" + source.key self.mapper._setattrbycolumn(child, self.foreignkey, self.parent._getattrbycolumn(obj, source)) + print "set " + repr(id(child)) + child.__dict__['name'] + ":" + self.foreignkey.key + " to " + repr(id(obj)) + obj.__dict__['name'] + ":" + source.key + #+ "\n" + repr(child.__dict__) else: colmap = {binary.left.table : binary.left, binary.right.table : binary.right} if colmap.has_key(self.parent.primarytable) and colmap.has_key(self.target): @@ -820,18 +811,10 @@ class PropertyLoader(MapperProperty): elif colmap.has_key(self.target) and colmap.has_key(self.secondary): associationrow[colmap[self.secondary].key] = self.mapper._getattrbycolumn(child, colmap[self.target]) - -# TODO: break out the lazywhere capability so that the main PropertyLoader can use it -# to do child deletes class LazyLoader(PropertyLoader): - def execute(self, instance, row, identitykey, imap, isnew): if isnew: - # TODO: get lazy callables to be stored within the unit of work? - # allows serializable ? still need lazyload state to exist in the application - # when u deserialize tho - objectstore.uow().attribute_set_callable(instance, self.key, LazyLoadInstance(self, row)) - + objectstore.uow().register_callable(instance, self.key, LazyLoadInstance(self, row), uselist=self.uselist) def create_lazy_clause(table, primaryjoin, secondaryjoin, thiscol): binds = {} diff --git a/lib/sqlalchemy/objectstore.py b/lib/sqlalchemy/objectstore.py index 997db3e6be..0b5311f4fc 100644 --- a/lib/sqlalchemy/objectstore.py +++ b/lib/sqlalchemy/objectstore.py @@ -107,7 +107,9 @@ class UnitOfWork(object): self.new = util.HashSet(ordered = True) self.dirty = util.HashSet() self.modified_lists = util.HashSet() - self.deleted = util.HashSet() + # the delete list is ordered mostly so the unit tests can predict the argument list ordering. + # TODO: need stronger unit test fixtures.... + self.deleted = util.HashSet(ordered = True) self.parent = parent def get(self, class_, *id): @@ -136,17 +138,10 @@ class UnitOfWork(object): def register_attribute(self, class_, key, uselist): self.attributes.register_attribute(class_, key, uselist) - - def attribute_set_callable(self, obj, key, func): - # TODO: gotta work this out when a list element is already there, - # etc. - obj.__dict__[key] = func - try: - del self.attributes.attribute_history[obj][key] - except KeyError: - pass - + def register_callable(self, obj, key, func, uselist): + self.attributes.set_callable(obj, key, func, uselist) + def register_clean(self, obj): try: del self.dirty[obj] @@ -405,7 +400,32 @@ class UOWTask(object): def sort_circular_dependencies(self, trans): allobjects = self.objects tuples = [] + d = {} + def get_task(obj): + try: + return d[obj] + except KeyError: + t = UOWTask(self.mapper, self.isdelete, self.listonly) + t.taskhash = d + d[obj] = t + return t + + dependencies = {} + def get_dependency_task(obj, processor): + try: + dp = dependencies[obj] + except KeyError: + dp = {} + dependencies[obj] = dp + try: + l = dp[processor] + except KeyError: + l = UOWTask(None, None, None) + dp[processor] = l + return l + for obj in self.objects: + parenttask = get_task(obj) for dep in self.dependencies: (processor, targettask) = dep if targettask is self: @@ -414,29 +434,40 @@ class UOWTask(object): whosdep = processor.whose_dependent_on_who(obj, o, trans) if whosdep is not None: tuples.append(whosdep) + if whosdep[0] is obj: + get_dependency_task(whosdep[0], processor).objects.append(whosdep[0]) + else: + get_dependency_task(whosdep[0], processor).objects.append(whosdep[1]) + head = TupleSorter(tuples, allobjects).sort() if head is None: return None - - d = {} - def make_task(): - t = UOWTask(self.mapper, self.isdelete, self.listonly) - t.dependencies = self.dependencies - t.taskhash = d - return t def make_task_tree(node, parenttask): if node is None: return parenttask.objects.append(node.item) - t = make_task() - d[node.item] = t + if dependencies.has_key(node.item): + for processor, deptask in dependencies[node.item].iteritems(): + parenttask.dependencies.append((processor, deptask)) + t = d[node.item] for n in node.children: - make_task_tree(n, t) - - t = make_task() + t2 = make_task_tree(n, t) + return t + + t = UOWTask(self.mapper, self.isdelete, self.listonly) + t.taskhash = d make_task_tree(head, t) + + t._print_circular() return t + + def _print_circular(t): + print "-----------------------------" + print "task objects: " + repr([str(v) for v in t.objects]) + print "task depends: " + repr([(dt[0].key, [str(o) for o in dt[1].objects]) for dt in t.dependencies]) + for o in t.objects: + t.taskhash[o]._print_circular() def __str__(self): if self.isdelete: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index dfa3cdbf47..52ce5fdc9d 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -220,12 +220,9 @@ class ClauseElement(object): def compile(self, engine = None, bindparams = None): """compiles this SQL expression using its underlying SQLEngine to produce - a Compiled object. The actual SQL statement is the Compiled object's string representation. - bindparams is an optional dictionary representing the bind parameters to be used with - the statement. Currently, only the compilations of INSERT and UPDATE statements - use the bind parameters, in order to determine which - table columns should be used in the statement.""" - + a Compiled object. If no engine can be found, an ansisql engine is used. + bindparams is a dictionary representing the default bind parameters to be used with + the statement. """ if engine is None: for f in self._get_from_objects(): engine = f.engine @@ -237,6 +234,9 @@ class ClauseElement(object): return engine.compile(self, bindparams = bindparams) + def __str__(self): + return str(self.compile()) + def execute(self, *multiparams, **params): """compiles and executes this SQL expression using its underlying SQLEngine. the given **params are used as bind parameters when compiling and executing the expression. diff --git a/test/objectstore.py b/test/objectstore.py index c64b477377..8736869aa1 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -1,6 +1,7 @@ from testbase import PersistTest, AssertMixin import unittest, sys, os from sqlalchemy.mapper import * +import StringIO import sqlalchemy.objectstore as objectstore from tables import * @@ -207,7 +208,18 @@ class SaveTest(AssertMixin): objectstore.uow().register_deleted(l[0]) objectstore.uow().register_deleted(l[2]) - objectstore.uow().commit() + res = self.capture_exec(db, lambda: objectstore.uow().commit()) + state = None + for line in res.split('\n'): + if line == "DELETE FROM items WHERE items.item_id = :item_id": + self.assert_(state is None or state == 'addresses') + elif line == "DELETE FROM orders WHERE orders.order_id = :order_id": + state = 'orders' + elif line == "DELETE FROM email_addresses WHERE email_addresses.address_id = :address_id": + if state is None: + state = 'addresses' + elif line == "DELETE FROM users WHERE users.user_id = :user_id": + self.assert_(state is not None) def testbackwardsonetoone(self): # test 'backwards' @@ -238,8 +250,12 @@ class SaveTest(AssertMixin): objects[3].user = User() objects[3].user.user_name = 'imnewlyadded' - objectstore.uow().commit() - return + self.assert_enginesql(db, lambda: objectstore.uow().commit(), +"""INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name) +{'user_id': None, 'user_name': 'imnewlyadded'} +UPDATE email_addresses SET address_id=:address_id, user_id=:user_id, email_address=:email_address WHERE email_addresses.address_id = :address_id +[{'email_address': 'imnew@foo.bar', 'address_id': 3, 'user_id': 3}, {'email_address': 'adsd5@llala.net', 'address_id': 4, 'user_id': None}] +""") l = sql.select([users, addresses], sql.and_(users.c.user_id==addresses.c.address_id, addresses.c.address_id==a.address_id)).execute() self.echo( repr(l.fetchone().row)) diff --git a/test/rundocs.py b/test/rundocs.py index 227247aae7..8ed8144452 100644 --- a/test/rundocs.py +++ b/test/rundocs.py @@ -57,7 +57,6 @@ User.mapper = assignmapper(users, properties = dict( # select user = User.mapper.select(User.c.user_name == 'fred jones')[0] -print repr(user.__dict__['addresses']) address = user.addresses[0] # modify @@ -129,4 +128,4 @@ user.preferences.stylename = 'bluesteel' user.addresses.append(Address('freddy@hi.org')) # commit -objectstore.commit() \ No newline at end of file +objectstore.commit() diff --git a/test/testbase.py b/test/testbase.py index 6bdad4953a..4d4d1e408a 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -1,4 +1,5 @@ import unittest +import StringIO echo = True @@ -8,7 +9,20 @@ class PersistTest(unittest.TestCase): def echo(self, text): if echo: print text - + def capture_exec(self, db, callable_): + e = db.echo + b = db.logger + buffer = StringIO.StringIO() + db.logger = buffer + db.echo = True + try: + callable_() + if echo: + print buffer.getvalue() + return buffer.getvalue() + finally: + db.logger = b + db.echo = e class AssertMixin(PersistTest): def assert_result(self, result, class_, *objects): @@ -29,7 +43,9 @@ class AssertMixin(PersistTest): self.assert_row(value[0], getattr(rowobj, key), value[1]) else: self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value)) - + def assert_enginesql(self, db, callable_, result): + self.assert_(self.capture_exec(db, callable_) == result, result) + def runTests(suite): runner = unittest.TextTestRunner(verbosity = 2, descriptions =1) runner.run(suite)