]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Oct 2005 19:50:11 +0000 (19:50 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Oct 2005 19:50:11 +0000 (19:50 +0000)
lib/sqlalchemy/attributes.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py
lib/sqlalchemy/sql.py
test/objectstore.py
test/rundocs.py
test/testbase.py

index 331e1dc285c5d2cd14133f31b3bb4e216a455613..59f68203bfb921da23916f328f0fc22d325a5446 100644 (file)
@@ -54,6 +54,10 @@ class PropHistory(object):
         self.obj = obj
         self.key = key
         self.orig = PropHistory.NONE
+    def gethistory(self, *args, **kwargs):
+        return self
+    def __call__(self, *args, **kwargs):
+        return self.obj.__dict__[self.key]
     def history_contains(self, obj):
         return self.orig is obj or self.obj.__dict__[self.key] is obj
     def setattr_clean(self, value):
@@ -90,15 +94,30 @@ class PropHistory(object):
 
 class ListElement(util.HistoryArraySet):
     """manages the value of a particular list-based attribute on a particular object instance."""
-    def __init__(self, obj, key, items = None):
+    def __init__(self, obj, key, data=None):
         self.obj = obj
         self.key = key
-        util.HistoryArraySet.__init__(self, items)
-        obj.__dict__[key] = self.data
+        try:
+            list_ = obj.__dict__[key]
+            if data is not None:
+                list_.clear()
+                for d in data:
+                    list_.append(d)
+        except KeyError:
+            if data is not None:
+                list_ = data
+            else:
+                list_ = []
+            obj.__dict__[key] = []
+            
+        util.HistoryArraySet.__init__(self, list_)
 
+    def gethistory(self, *args, **kwargs):
+        return self
+    def __call__(self, *args, **kwargs):
+        return self
     def list_value_changed(self, obj, key, listval):
         pass    
-
     def setattr(self, value):
         self.obj.__dict__[self.key] = value
         self.set_data(value)
@@ -115,9 +134,37 @@ class ListElement(util.HistoryArraySet):
             self.list_value_changed(self.obj, self.key, self)
         return res
 
-
+class CallableProp(object):
+    """allows the attaching of a callable item, representing the future value
+    of a particular attribute on a particular object instance, to 
+    the AttributeManager.  When the attributemanager
+    accesses the object attribute, either to get its history or its real value, the __call__ method
+    is invoked which runs the underlying callable_ and sets the new value to the object attribute
+    via the manager."""
+    def __init__(self, callable_, obj, key, uselist = False):
+        self.callable_ = callable_
+        self.obj = obj
+        self.key = key
+        self.uselist = uselist
+    def gethistory(self, manager, *args, **kwargs):
+        self.__call__(manager, *args, **kwargs)
+        return manager.attribute_history[self.obj][self.key]
+    def __call__(self, manager, passive=False):
+        if passive:
+            return None
+        value = self.callable_()
+        if self.uselist:
+            p = manager.create_list(self.obj, self.key, value)
+            manager.attribute_history[self.obj][self.key] = p
+            return p
+        else:
+            self.obj.__dict__[self.key] = value
+            p = PropHistory(self.obj, self.key)
+            manager.attribute_history[self.obj][self.key] = p
+            return p
+            
 class AttributeManager(object):
-    """maintains a set of per-attribute history objects for a set of objects."""
+    """maintains a set of per-attribute callable/history manager objects for a set of objects."""
     def __init__(self):
         self.attribute_history = {}
 
@@ -130,13 +177,13 @@ class AttributeManager(object):
         
     def get_attribute(self, obj, key):
         try:
-            v = obj.__dict__[key]
+            return self.get_history(obj, key)(self)
+        except KeyError:
+            pass
+        try:
+            return obj.__dict__[key]
         except KeyError:
             raise AttributeError(key)
-        if (callable(v)):
-            v = v()
-            obj.__dict__[key] = v
-        return v
 
     def get_list_attribute(self, obj, key):
         return self.get_list_history(obj, key)
@@ -152,6 +199,13 @@ class AttributeManager(object):
         self.get_history(obj, key).delattr()
         self.value_changed(obj, key, value)
 
+    def set_callable(self, obj, key, func, uselist):
+        try:
+            d = self.attribute_history[obj]
+        except KeyError, e:
+            d = {}
+            self.attribute_history[obj] = d
+        d[key] = CallableProp(func, obj, key, uselist)
         
     def delete_list_attribute(self, obj, key):
         pass
@@ -190,7 +244,7 @@ class AttributeManager(object):
             
     def get_history(self, obj, key):
         try:
-            return self.attribute_history[obj][key]
+            return self.attribute_history[obj][key].gethistory(self)
         except KeyError, e:
             if e.args[0] is obj:
                 d = {}
@@ -205,14 +259,10 @@ class AttributeManager(object):
 
     def get_list_history(self, obj, key, passive = False):
         try:
-            return self.attribute_history[obj][key]
+            return self.attribute_history[obj][key].gethistory(self, passive)
         except KeyError, e:
             # TODO: when an callable is re-set on an existing list element
             list_ = obj.__dict__.get(key, None)
-            if callable(list_):
-                if passive:
-                    return None
-                list_ = list_()
             if e.args[0] is obj:
                 d = {}
                 self.attribute_history[obj] = d
index 87216773dacedd100484323a3b8c63b4164822ff..2f3e94de25e4063cc789fa97b1a4fca86081c99f 100644 (file)
@@ -22,7 +22,7 @@ import sqlalchemy.schema as schema
 import sqlalchemy.pool
 import sqlalchemy.util as util
 import sqlalchemy.sql as sql
-import StringIO
+import StringIO, sys
 import sqlalchemy.types as types
 
 def create_engine(name, *args ,**kwargs):
@@ -61,6 +61,7 @@ class SQLEngine(schema.SchemaEngine):
         self.context = util.ThreadLocal()
         self.tables = {}
         self.notes = {}
+        self.logger = sys.stdout
 
         
     def type_descriptor(self, typeobj):
@@ -206,7 +207,7 @@ class SQLEngine(schema.SchemaEngine):
         return ResultProxy(c, self.echo, typemap = typemap)
 
     def log(self, msg):
-        print msg
+        self.logger.write(msg + "\n")
 
 
 class ResultProxy:
index 173e35c37bb09813bb5afeb9807beae33d6d6b04..b7e87eb04917e36c75c8a9d27480a20f704d1e31 100644 (file)
@@ -694,12 +694,15 @@ class PropertyLoader(MapperProperty):
             return (obj2, obj1)
             
     def process_dependencies(self, deplist, uowcommit, delete = False):
-        #print self.mapper.table.name + " " + repr(deplist.map.values()) + " process_dep isdelete " + repr(delete)
+        print self.mapper.table.name + " " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete)
 
         # fucntion to set properties across a parent/child object plus an "association row",
         # based on a join condition
         def sync_foreign_keys(binary):
-            self._sync_foreign_keys(binary, obj, child, associationrow, clearkeys)
+            if self.direction == PropertyLoader.RIGHT:
+                self._sync_foreign_keys(binary, child, obj, associationrow, clearkeys)
+            else:
+                self._sync_foreign_keys(binary, obj, child, associationrow, clearkeys)
         setter = BinaryVisitor(sync_foreign_keys)
 
         def getlist(obj, passive=True):
@@ -744,8 +747,8 @@ class PropertyLoader(MapperProperty):
                 if len(secondary_insert):
                     statement = self.secondary.insert()
                     statement.execute(*secondary_insert)
-        elif self.direction == PropertyLoader.LEFT:
-            if delete and not self.private:
+        elif self.direction == PropertyLoader.LEFT and delete:
+            if not self.private:
                 updates = []
                 clearkeys = True
                 for obj in deplist:
@@ -763,33 +766,18 @@ class PropertyLoader(MapperProperty):
                         values[bind.shortname] = None
                     statement = self.target.update(self.lazywhere, values = values)
                     statement.execute(*updates)
-            else:
-                for obj in deplist:
-                    childlist = getlist(obj)
-                    if childlist is None: return
-                    uowcommit.register_saved_list(childlist)
-                    clearkeys = False
-                    for child in childlist.added_items():
-                        self.primaryjoin.accept_visitor(setter)
-                    clearkeys = True
-                    for child in childlist.deleted_items():
-                         self.primaryjoin.accept_visitor(setter)
-        elif self.direction == PropertyLoader.RIGHT:
-            for child in deplist:
-                childlist = getlist(child)
+        else:
+            for obj in deplist:
+                childlist = getlist(obj)
                 if childlist is None: return
                 uowcommit.register_saved_list(childlist)
                 clearkeys = False
-                added = childlist.added_items()
-                if len(added):
-                    for obj in added:
-                        self.primaryjoin.accept_visitor(setter)
-                else:
+                for child in childlist.added_items():
+                    self.primaryjoin.accept_visitor(setter)
+                if self.direction != PropertyLoader.RIGHT or len(childlist.added_items()) == 0:
                     clearkeys = True
-                    for obj in childlist.deleted_items():
+                    for child in childlist.deleted_items():
                         self.primaryjoin.accept_visitor(setter)
-        else:
-            raise " no foreign key ?"
     
         #print self.mapper.table.name + " postdep " + repr([str(v) for v in deplist.map.values()]) + " process_dep isdelete " + repr(delete)
 
@@ -797,6 +785,8 @@ class PropertyLoader(MapperProperty):
         """given a binary clause with an = operator joining two table columns, synchronizes the values 
         of the corresponding attributes within a parent object and a child object, or the attributes within an 
         an "association row" that represents an association link between the 'parent' and 'child' object."""
+        if obj is child:
+            raise "wha?"
         if binary.operator == '=':
             if binary.left.table == binary.right.table:
                 if binary.right is self.foreignkey:
@@ -805,8 +795,9 @@ class PropertyLoader(MapperProperty):
                     source = binary.right
                 else:
                     raise "Cant determine direction for relationship %s = %s" % (binary.left.fullname, binary.right.fullname)
-                #print "set " + repr(child) + ":" + self.foreignkey.key + " to " + repr(obj) + ":" + source.key
                 self.mapper._setattrbycolumn(child, self.foreignkey, self.parent._getattrbycolumn(obj, source))
+                print "set " + repr(id(child)) + child.__dict__['name'] + ":" + self.foreignkey.key + " to " + repr(id(obj)) + obj.__dict__['name'] + ":" + source.key 
+                #+ "\n" + repr(child.__dict__)
             else:
                 colmap = {binary.left.table : binary.left, binary.right.table : binary.right}
                 if colmap.has_key(self.parent.primarytable) and colmap.has_key(self.target):
@@ -820,18 +811,10 @@ class PropertyLoader(MapperProperty):
                 elif colmap.has_key(self.target) and colmap.has_key(self.secondary):
                     associationrow[colmap[self.secondary].key] = self.mapper._getattrbycolumn(child, colmap[self.target])
             
-
-# TODO: break out the lazywhere capability so that the main PropertyLoader can use it
-# to do child deletes
 class LazyLoader(PropertyLoader):
-
     def execute(self, instance, row, identitykey, imap, isnew):
         if isnew:
-            # TODO: get lazy callables to be stored within the unit of work?
-            # allows serializable ?  still need lazyload state to exist in the application
-            # when u deserialize tho
-            objectstore.uow().attribute_set_callable(instance, self.key, LazyLoadInstance(self, row))
-
+            objectstore.uow().register_callable(instance, self.key, LazyLoadInstance(self, row), uselist=self.uselist)
 
 def create_lazy_clause(table, primaryjoin, secondaryjoin, thiscol):
     binds = {}
index 997db3e6beb17859fc241247cdbd540860334f40..0b5311f4fc580c9b1901c88520f409c18af3743a 100644 (file)
@@ -107,7 +107,9 @@ class UnitOfWork(object):
         self.new = util.HashSet(ordered = True)
         self.dirty = util.HashSet()
         self.modified_lists = util.HashSet()
-        self.deleted = util.HashSet()
+        # the delete list is ordered mostly so the unit tests can predict the argument list ordering.
+        # TODO: need stronger unit test fixtures....
+        self.deleted = util.HashSet(ordered = True)
         self.parent = parent
 
     def get(self, class_, *id):
@@ -136,17 +138,10 @@ class UnitOfWork(object):
         
     def register_attribute(self, class_, key, uselist):
         self.attributes.register_attribute(class_, key, uselist)
-        
-    def attribute_set_callable(self, obj, key, func):
-        # TODO: gotta work this out when a list element is already there,
-        # etc.
-        obj.__dict__[key] = func
-        try:
-            del self.attributes.attribute_history[obj][key]
-        except KeyError:
-            pass
 
-    
+    def register_callable(self, obj, key, func, uselist):
+        self.attributes.set_callable(obj, key, func, uselist)
+        
     def register_clean(self, obj):
         try:
             del self.dirty[obj]
@@ -405,7 +400,32 @@ class UOWTask(object):
     def sort_circular_dependencies(self, trans):
         allobjects = self.objects
         tuples = []
+        d = {}
+        def get_task(obj):
+            try:
+                return d[obj]
+            except KeyError:
+                t = UOWTask(self.mapper, self.isdelete, self.listonly)
+                t.taskhash = d
+                d[obj] = t
+                return t
+
+        dependencies = {}
+        def get_dependency_task(obj, processor):
+            try:
+                dp = dependencies[obj]
+            except KeyError:
+                dp = {}
+                dependencies[obj] = dp
+            try:
+                l = dp[processor]
+            except KeyError:
+                l = UOWTask(None, None, None)
+                dp[processor] = l
+            return l
+            
         for obj in self.objects:
+            parenttask = get_task(obj)
             for dep in self.dependencies:
                 (processor, targettask) = dep
                 if targettask is self:
@@ -414,29 +434,40 @@ class UOWTask(object):
                         whosdep = processor.whose_dependent_on_who(obj, o, trans)
                         if whosdep is not None:
                             tuples.append(whosdep)
+                            if whosdep[0] is obj:
+                                get_dependency_task(whosdep[0], processor).objects.append(whosdep[0])
+                            else:
+                                get_dependency_task(whosdep[0], processor).objects.append(whosdep[1])
+        
         head = TupleSorter(tuples, allobjects).sort()
         if head is None:
             return None
-
-        d = {}
-        def make_task():
-            t = UOWTask(self.mapper, self.isdelete, self.listonly)
-            t.dependencies = self.dependencies
-            t.taskhash = d
-            return t
         
         def make_task_tree(node, parenttask):
             if node is None:
                 return
             parenttask.objects.append(node.item)
-            t = make_task()
-            d[node.item] = t
+            if dependencies.has_key(node.item):
+                for processor, deptask in dependencies[node.item].iteritems():
+                    parenttask.dependencies.append((processor, deptask))
+            t = d[node.item]
             for n in node.children:
-                make_task_tree(n, t)
-        
-        t = make_task()
+                t2 = make_task_tree(n, t)
+            return t
+            
+        t = UOWTask(self.mapper, self.isdelete, self.listonly)
+        t.taskhash = d
         make_task_tree(head, t)
+        
+        t._print_circular()        
         return t
+
+    def _print_circular(t):
+        print "-----------------------------"
+        print "task objects: " + repr([str(v) for v in t.objects])
+        print "task depends: " + repr([(dt[0].key, [str(o) for o in dt[1].objects]) for dt in t.dependencies])
+        for o in t.objects:
+            t.taskhash[o]._print_circular()
         
     def __str__(self):
         if self.isdelete:
index dfa3cdbf47ca0d13422034c150ea484a141e4433..52ce5fdc9dbc34c8f8b550b30c9c544f982e6a95 100644 (file)
@@ -220,12 +220,9 @@ class ClauseElement(object):
 
     def compile(self, engine = None, bindparams = None):
         """compiles this SQL expression using its underlying SQLEngine to produce
-        a Compiled object.  The actual SQL statement is the Compiled object's string representation.   
-        bindparams is an optional dictionary representing the bind parameters to be used with 
-        the statement.  Currently, only the compilations of INSERT and UPDATE statements
-        use the bind parameters, in order to determine which
-        table columns should be used in the statement."""
-
+        a Compiled object.  If no engine can be found, an ansisql engine is used.
+        bindparams is a dictionary representing the default bind parameters to be used with 
+        the statement.  """
         if engine is None:
             for f in self._get_from_objects():
                 engine = f.engine
@@ -237,6 +234,9 @@ class ClauseElement(object):
 
         return engine.compile(self, bindparams = bindparams)
 
+    def __str__(self):
+        return str(self.compile())
+        
     def execute(self, *multiparams, **params):
         """compiles and executes this SQL expression using its underlying SQLEngine.
         the given **params are used as bind parameters when compiling and executing the expression. 
index c64b477377f5e8fa0e0197a3863d0eda1eda1495..8736869aa1d50443fadbcb63fe04947384f47d24 100644 (file)
@@ -1,6 +1,7 @@
 from testbase import PersistTest, AssertMixin
 import unittest, sys, os
 from sqlalchemy.mapper import *
+import StringIO
 import sqlalchemy.objectstore as objectstore
 
 from tables import *
@@ -207,7 +208,18 @@ class SaveTest(AssertMixin):
         
         objectstore.uow().register_deleted(l[0])
         objectstore.uow().register_deleted(l[2])
-        objectstore.uow().commit()
+        res = self.capture_exec(db, lambda: objectstore.uow().commit())
+        state = None
+        for line in res.split('\n'):
+            if line == "DELETE FROM items WHERE items.item_id = :item_id":
+                self.assert_(state is None or state == 'addresses')
+            elif line == "DELETE FROM orders WHERE orders.order_id = :order_id":
+                state = 'orders'
+            elif line == "DELETE FROM email_addresses WHERE email_addresses.address_id = :address_id":
+                if state is None:
+                    state = 'addresses'
+            elif line == "DELETE FROM users WHERE users.user_id = :user_id":
+                self.assert_(state is not None)
         
     def testbackwardsonetoone(self):
         # test 'backwards'
@@ -238,8 +250,12 @@ class SaveTest(AssertMixin):
         objects[3].user = User()
         objects[3].user.user_name = 'imnewlyadded'
         
-        objectstore.uow().commit()
-        return
+        self.assert_enginesql(db, lambda: objectstore.uow().commit(), 
+"""INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)
+{'user_id': None, 'user_name': 'imnewlyadded'}
+UPDATE email_addresses SET address_id=:address_id, user_id=:user_id, email_address=:email_address WHERE email_addresses.address_id = :address_id
+[{'email_address': 'imnew@foo.bar', 'address_id': 3, 'user_id': 3}, {'email_address': 'adsd5@llala.net', 'address_id': 4, 'user_id': None}]
+""")
         l = sql.select([users, addresses], sql.and_(users.c.user_id==addresses.c.address_id, addresses.c.address_id==a.address_id)).execute()
         self.echo( repr(l.fetchone().row))
         
index 227247aae77084bf28154724a90840f318b39489..8ed8144452ed50f90ec174782725cc3cfcb18b0a 100644 (file)
@@ -57,7 +57,6 @@ User.mapper = assignmapper(users, properties = dict(
 
 # select
 user = User.mapper.select(User.c.user_name == 'fred jones')[0]
-print repr(user.__dict__['addresses'])
 address = user.addresses[0]
 
 # modify
@@ -129,4 +128,4 @@ user.preferences.stylename = 'bluesteel'
 user.addresses.append(Address('freddy@hi.org'))
 
 # commit
-objectstore.commit()
\ No newline at end of file
+objectstore.commit()
index 6bdad4953ae2627d8c8e7cf5ce7f6b1f589e51bd..4d4d1e408af961bf54460864cfaf6acee3df8acb 100644 (file)
@@ -1,4 +1,5 @@
 import unittest
+import StringIO
 
 echo = True
 
@@ -8,7 +9,20 @@ class PersistTest(unittest.TestCase):
     def echo(self, text):
         if echo:
             print text
-        
+    def capture_exec(self, db, callable_):
+        e = db.echo
+        b = db.logger
+        buffer = StringIO.StringIO()
+        db.logger = buffer
+        db.echo = True
+        try:
+            callable_()
+            if echo:
+                print buffer.getvalue()
+            return buffer.getvalue()
+        finally:
+            db.logger = b
+            db.echo = e
 
 class AssertMixin(PersistTest):
     def assert_result(self, result, class_, *objects):
@@ -29,7 +43,9 @@ class AssertMixin(PersistTest):
                     self.assert_row(value[0], getattr(rowobj, key), value[1])
             else:
                 self.assert_(getattr(rowobj, key) == value, "attribute %s value %s does not match %s" % (key, getattr(rowobj, key), value))
-
+    def assert_enginesql(self, db, callable_, result):
+        self.assert_(self.capture_exec(db, callable_) == result, result)
+        
 def runTests(suite):
     runner = unittest.TextTestRunner(verbosity = 2, descriptions =1)
     runner.run(suite)