]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 00:47:45 +0000 (00:47 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 00:47:45 +0000 (00:47 +0000)
examples/adjacencytree/byroot_tree.py
lib/sqlalchemy/objectstore.py
test/objectstore.py
test/testbase.py

index cf8b7cbbba544cd7065a186618fd78fc07803561..d352db4c711756a86abd4e1936fc72594a3f5815 100644 (file)
@@ -133,7 +133,7 @@ print "Committing:"
 print "----------------------------"
 
 objectstore.commit()
-
+#sys.exit()
 print "\n\n\n----------------------------"
 print "Tree After Save:"
 print "----------------------------"
index 59e639e615d62f1a07ae6edf505c3680f29337ed..8f7509f5e5269e0ce299f506ee91daafa9219567 100644 (file)
@@ -324,7 +324,7 @@ class UOWTransaction(object):
             task.mapper.register_dependencies(self)
 
         head = self._sort_dependencies()
-        #print "Task dump:\n" + head.dump()
+        print "Task dump:\n" + head.dump()
         if head is not None:
             head.execute(self)
             
@@ -395,7 +395,8 @@ class UOWTaskElement(object):
         
 class UOWTask(object):
     def __init__(self, uowtransaction, mapper):
-        uowtransaction.tasks[mapper] = self
+        if uowtransaction is not None:
+            uowtransaction.tasks[mapper] = self
         self.uowtransaction = uowtransaction
         self.mapper = mapper
         self.objects = util.OrderedDict()
@@ -430,15 +431,20 @@ class UOWTask(object):
         """executes this UOWTask.  saves objects to be saved, processes all dependencies
         that have been registered, and deletes objects to be deleted. """
         if self.circular is not None:
+            print "CIRCULAR !"
             self.circular.execute(trans)
+            print "CIRCULAR DONE !"
             return
 
+        print "task " + str(self) + " tosave: " + repr(self.tosave_objects())
         self.mapper.save_obj(self.tosave_objects(), trans)
         for dep in self.save_dependencies():
             (processor, targettask, isdelete) = dep
             processor.process_dependencies(targettask, [elem.obj for elem in targettask.tosave_elements()], trans, delete = False)
+            print "processed dependencies on " + repr([elem.obj for elem in targettask.tosave_elements()])
         for element in self.tosave_elements():
             if element.childtask is not None:
+                print "execute elem childtask " + str(element.childtask)
                 element.childtask.execute(trans)
         for dep in self.delete_dependencies():
             (processor, targettask, isdelete) = dep
@@ -477,7 +483,7 @@ class UOWTask(object):
             try:
                 return objecttotask[obj]
             except KeyError:
-                t = UOWTask(trans, self.mapper)
+                t = UOWTask(None, self.mapper)
                 objecttotask[obj] = t
                 return t
 
@@ -491,7 +497,7 @@ class UOWTask(object):
             try:
                 l = dp[(processor, isdelete)]
             except KeyError:
-                l = UOWTask(trans, None)
+                l = UOWTask(None, None)
                 dp[(processor, isdelete)] = l
             return l
 
@@ -538,7 +544,7 @@ class UOWTask(object):
                 t2 = make_task_tree(n, t)
             return t
             
-        t = UOWTask(trans, self.mapper)
+        t = UOWTask(None, self.mapper)
         make_task_tree(head, t)
         return t
 
@@ -550,7 +556,7 @@ class UOWTask(object):
             return s
         saveobj = self.tosave_elements()
         if len(saveobj) > 0:
-            s += "\n" + indent + "  Save Elements:"
+            s += "\n" + indent + "  Save Elements:(%d)" % len(saveobj)
             for o in saveobj:
                 if not o.listonly:
                     s += "\n     " + indent + repr(o)
index 41c3dea8ecee2846079790615a8b80b8acb10904..5b5889695d48013108160086823734713262e364 100644 (file)
@@ -378,16 +378,22 @@ class SaveTest(AssertMixin):
             dict(user_id = 8, user_name = 'ed'),
             dict(user_id = 9, user_name = 'fred')
         )
-        db.connection().commit()
+        db.commit()
 
+        # mapper with just users table
         User.mapper = assignmapper(users)
         User.mapper.select()
+        oldmapper = User.mapper
+        # now a mapper with the users table plus a relation to the addresses
         User.mapper = assignmapper(users, properties = dict(
             addresses = relation(Address, addresses, lazy = False)
         ))
+        self.assert_(oldmapper is not User.mapper)
         u = User.mapper.select()
         u[0].addresses.append(Address())
         u[0].addresses[0].email_address='hi'
+        
+        # insure that upon commit, the new mapper with the address relation is used
         self.assert_sql(db, lambda: objectstore.commit(), 
                 [
                     (
index 77688f7da6e68ea58bf23d54cb605cf7ec5c0119..b4fbb3142c885e31e09041aea63004de978f93fb 100644 (file)
@@ -7,6 +7,7 @@ import sqlalchemy.databases.postgres as postgres
 echo = True
 
 class PersistTest(unittest.TestCase):
+    """persist base class, provides default setUpAll, tearDownAll and echo functionality"""
     def __init__(self, *args, **params):
         unittest.TestCase.__init__(self, *args, **params)
     def echo(self, text):
@@ -18,6 +19,8 @@ class PersistTest(unittest.TestCase):
         pass
 
 class AssertMixin(PersistTest):
+    """given a list-based structure of keys/properties which represent information within an object structure, and
+    a list of actual objects, asserts that the list of objects corresponds to the structure."""
     def assert_result(self, result, class_, *objects):
         if echo:
             print repr(result)
@@ -44,6 +47,7 @@ class AssertMixin(PersistTest):
             db.set_assert_list(None, None)
         
 class EngineAssert(object):
+    """decorates a SQLEngine object to match the incoming queries against a set of assertions."""
     def __init__(self, engine):
         self.engine = engine
         self.realexec = engine.execute
@@ -76,40 +80,41 @@ class EngineAssert(object):
 
 
 class TTestSuite(unittest.TestSuite):
-        def __init__(self, tests=()):
-            if len(tests) >0 and isinstance(tests[0], PersistTest):
-                self._initTest = tests[0]
-            else:
-                self._initTest = None
-            unittest.TestSuite.__init__(self, tests)
+    """override unittest.TestSuite to provide per-TestCase class setUpAll() and tearDownAll() functionality"""
+    def __init__(self, tests=()):
+        if len(tests) >0 and isinstance(tests[0], PersistTest):
+            self._initTest = tests[0]
+        else:
+            self._initTest = None
+        unittest.TestSuite.__init__(self, tests)
 
-        def run(self, result):
+    def run(self, result):
+        try:
+            if self._initTest is not None:
+                self._initTest.setUpAll()
+        except:
+            result.addError(self._initTest, self.__exc_info())
+            pass
+        try:
+            return unittest.TestSuite.run(self, result)
+        finally:
             try:
                 if self._initTest is not None:
-                    self._initTest.setUpAll()
+                    self._initTest.tearDownAll()
             except:
                 result.addError(self._initTest, self.__exc_info())
                 pass
-            try:
-                return unittest.TestSuite.run(self, result)
-            finally:
-                try:
-                    if self._initTest is not None:
-                        self._initTest.tearDownAll()
-                except:
-                    result.addError(self._initTest, self.__exc_info())
-                    pass
 
-        def __exc_info(self):
-            """Return a version of sys.exc_info() with the traceback frame
-               minimised; usually the top level of the traceback frame is not
-               needed.
-               ripped off out of unittest module since its double __
-            """
-            exctype, excvalue, tb = sys.exc_info()
-            if sys.platform[:4] == 'java': ## tracebacks look different in Jython
-                return (exctype, excvalue, tb)
+    def __exc_info(self):
+        """Return a version of sys.exc_info() with the traceback frame
+           minimised; usually the top level of the traceback frame is not
+           needed.
+           ripped off out of unittest module since its double __
+        """
+        exctype, excvalue, tb = sys.exc_info()
+        if sys.platform[:4] == 'java': ## tracebacks look different in Jython
             return (exctype, excvalue, tb)
+        return (exctype, excvalue, tb)
 
 
 unittest.TestLoader.suiteClass = TTestSuite