]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
some refactorings to activemapper, made relationship() class have some polymorphic...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Jul 2006 19:48:02 +0000 (19:48 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Jul 2006 19:48:02 +0000 (19:48 +0000)
CHANGES
lib/sqlalchemy/ext/activemapper.py
test/ext/activemapper.py

diff --git a/CHANGES b/CHANGES
index 63f9052368d2b1e27c4a476776c018b55fb9c378..f6e4e63abb77f1635a2ff2d19e666b8aca4d7828 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,6 @@
+0.2.6
+- tweaks to ActiveMapper, supports self-referential relationships
+
 0.2.5
 - fixed endless loop bug in select_by(), if the traversal hit
 two mappers that referenced each other
index 32efc5a05ccc97f10923ab80f2a4fee78ab1fa52..a0984f46c103a1d24dd0f7655e0c470ff2b92706 100644 (file)
@@ -9,7 +9,6 @@ from sqlalchemy import backref as create_backref
 
 import inspect
 import sys
-import sets
 
 #
 # the "proxy" to the database engine... this can be swapped out at runtime
@@ -59,7 +58,31 @@ class relationship(object):
         self.uselist   = uselist
         self.secondary = secondary
         self.order_by  = order_by
-
+    def process(self, klass, propname, relations):
+        relclass = ActiveMapperMeta.classes[self.classname]
+        if isinstance(self.order_by, str):
+            self.order_by = [ self.order_by ]
+        if isinstance(self.order_by, list):
+            for itemno in range(len(self.order_by)):
+                if isinstance(self.order_by[itemno], str):
+                    self.order_by[itemno] = \
+                        getattr(relclass.c, self.order_by[itemno])
+        backref = self.create_backref(klass)
+        relations[propname] = relation(relclass.mapper,
+                                       secondary=self.secondary,
+                                       backref=backref, 
+                                       private=self.private, 
+                                       lazy=self.lazy, 
+                                       uselist=self.uselist,
+                                       order_by=self.order_by)
+    def create_backref(self, klass):
+        relclass = ActiveMapperMeta.classes[self.classname]
+        if klass.__name__ == self.classname:
+            br_fkey = getattr(relclass.c, self.colname)
+        else:
+            br_fkey = None
+        return create_backref(self.backref, foreignkey=br_fkey)
+        
 class one_to_many(relationship):
     def __init__(self, classname, colname=None, backref=None, private=False,
                  lazy=True, order_by=False):
@@ -69,10 +92,15 @@ class one_to_many(relationship):
 class one_to_one(relationship):
     def __init__(self, classname, colname=None, backref=None, private=False,
                  lazy=True, order_by=False):
-        if backref is not None:
-            backref = create_backref(backref, uselist=False)
         relationship.__init__(self, classname, colname, backref, private, 
                               lazy, uselist=False, order_by=order_by)
+    def create_backref(self, klass):
+        relclass = ActiveMapperMeta.classes[self.classname]
+        if klass.__name__ == self.classname:
+            br_fkey = getattr(relclass.c, self.colname)
+        else:
+            br_fkey = None
+        return create_backref(self.backref, foreignkey=br_fkey, uselist=False)
 
 class many_to_many(relationship):
     def __init__(self, classname, secondary, backref=None, lazy=True,
@@ -81,7 +109,6 @@ class many_to_many(relationship):
                               uselist=True, secondary=secondary,
                               order_by=order_by)
 
-
 # 
 # SQLAlchemy metaclass and superclass that can be used to do SQLAlchemy 
 # mapping in a declarative way, along with a function to process the 
@@ -89,22 +116,16 @@ class many_to_many(relationship):
 # up if the classes aren't specified in a proper order
 # 
 
-__deferred_classes__ = set()
-__processed_classes__ = set()
+__deferred_classes__ = {}
+__processed_classes__ = {}
 def process_relationships(klass, was_deferred=False):
     # first, we loop through all of the relationships defined on the
     # class, and make sure that the related class already has been
     # completely processed and defer processing if it has not
     defer = False
     for propname, reldesc in klass.relations.items():
-        found = False
-        for other_klass in __processed_classes__:
-            if reldesc.classname == other_klass.__name__:
-                found = True
-                break
-        
+        found = (reldesc.classname == klass.__name__ or reldesc.classname in __processed_classes__)
         if not found:
-            if not was_deferred: __deferred_classes__.add(klass)
             defer = True
             break
     
@@ -112,44 +133,33 @@ def process_relationships(klass, was_deferred=False):
     # and make sure that we can find the related tables (they do not 
     # have to be processed yet, just defined), and we defer if we are 
     # not able to find any of the related tables
-    for col in klass.columns:
-        if col.foreign_key is not None:
-            found = False
-            for other_klass in ActiveMapperMeta.classes.values():
+    if not defer:
+        for col in klass.columns:
+            if col.foreign_key is not None:
+                found = False
                 table_name = col.foreign_key._colspec.rsplit('.', 1)[0]
-                if other_klass.table.fullname.lower() == table_name.lower():
-                    found = True
+                for other_klass in ActiveMapperMeta.classes.values():
+                    if other_klass.table.fullname.lower() == table_name.lower():
+                        found = True
                         
-            if not found:
-                if not was_deferred: __deferred_classes__.add(klass)
-                defer = True
-                break
-    
+                if not found:
+                    defer = True
+                    break
+
+    if defer and not was_deferred:
+        __deferred_classes__[klass.__name__] = klass
+        
     # if we are able to find all related and referred to tables, then
     # we can go ahead and assign the relationships to the class
     if not defer:
         relations = {}
         for propname, reldesc in klass.relations.items():
-            relclass = ActiveMapperMeta.classes[reldesc.classname]
-            if isinstance(reldesc.order_by, str):
-                reldesc.order_by = [ reldesc.order_by ]
-            if isinstance(reldesc.order_by, list):
-                for itemno in range(len(reldesc.order_by)):
-                    if isinstance(reldesc.order_by[itemno], str):
-                        reldesc.order_by[itemno] = \
-                            getattr(relclass.c, reldesc.order_by[itemno])
-            relations[propname] = relation(relclass.mapper,
-                                           secondary=reldesc.secondary,
-                                           backref=reldesc.backref, 
-                                           private=reldesc.private, 
-                                           lazy=reldesc.lazy, 
-                                           uselist=reldesc.uselist,
-                                           order_by=reldesc.order_by)
+            reldesc.process(klass, propname, relations)
         
         class_mapper(klass).add_properties(relations)
-        if klass in __deferred_classes__: 
-            __deferred_classes__.remove(klass)
-        __processed_classes__.add(klass)
+        if klass.__name__ in __deferred_classes__: 
+            del __deferred_classes__[klass.__name__]
+        __processed_classes__[klass.__name__] = klass
     
     # finally, loop through the deferred classes and attempt to process
     # relationships for them
@@ -160,7 +170,7 @@ def process_relationships(klass, was_deferred=False):
         while last_count > len(__deferred_classes__):
             last_count = len(__deferred_classes__)
             deferred = __deferred_classes__.copy()
-            for deferred_class in deferred:
+            for deferred_class in deferred.values():
                 process_relationships(deferred_class, was_deferred=True)
 
 
index 1bb93dd634e1fa64d5e0b69659e7ed06faa1c653..2a44f8e5bf8a1533097dcf5ffd0eba2500dbbe40 100644 (file)
@@ -1,6 +1,6 @@
 import testbase
 from sqlalchemy.ext.activemapper           import ActiveMapper, column, one_to_many, one_to_one, objectstore
-from sqlalchemy             import and_, or_, clear_mappers
+from sqlalchemy             import and_, or_, clear_mappers, backref
 from sqlalchemy             import ForeignKey, String, Integer, DateTime
 from datetime               import datetime
 
@@ -218,6 +218,38 @@ class testcase(testbase.PersistTest):
         )
         self.assertEquals(len(results), 1)
 
-    
+class testselfreferential(testbase.PersistTest):
+    def setUpAll(self):
+        global TreeNode
+        class TreeNode(activemapper.ActiveMapper):
+            class mapping:
+                id = column(Integer, primary_key=True)
+                name = column(String(30))
+                parent_id = column(Integer, foreign_key=ForeignKey('treenode.id'))
+                children = one_to_many('TreeNode', colname='id', backref='parent')
+                
+        activemapper.metadata.connect(testbase.db)
+        activemapper.create_tables()
+    def tearDownAll(self):
+        clear_mappers()
+        activemapper.drop_tables()
+
+    def testbasic(self):
+        t = TreeNode(name='node1')
+        t.children.append(TreeNode(name='node2'))
+        t.children.append(TreeNode(name='node3'))
+        objectstore.flush()
+        objectstore.clear()
+        
+        t = TreeNode.get_by(name='node1')
+        assert (t.name == 'node1')
+        assert (t.children[0].name == 'node2')
+        assert (t.children[1].name == 'node3')
+        assert (t.children[1].parent is t)
+
+        objectstore.clear()
+        t = TreeNode.get_by(name='node3')
+        assert (t.parent is TreeNode.get_by(name='node1'))
+        
 if __name__ == '__main__':
-    unittest.main()
+    testbase.main()