import inspect
import sys
-import sets
#
# the "proxy" to the database engine... this can be swapped out at runtime
self.uselist = uselist
self.secondary = secondary
self.order_by = order_by
-
+ def process(self, klass, propname, relations):
+ relclass = ActiveMapperMeta.classes[self.classname]
+ if isinstance(self.order_by, str):
+ self.order_by = [ self.order_by ]
+ if isinstance(self.order_by, list):
+ for itemno in range(len(self.order_by)):
+ if isinstance(self.order_by[itemno], str):
+ self.order_by[itemno] = \
+ getattr(relclass.c, self.order_by[itemno])
+ backref = self.create_backref(klass)
+ relations[propname] = relation(relclass.mapper,
+ secondary=self.secondary,
+ backref=backref,
+ private=self.private,
+ lazy=self.lazy,
+ uselist=self.uselist,
+ order_by=self.order_by)
+ def create_backref(self, klass):
+ relclass = ActiveMapperMeta.classes[self.classname]
+ if klass.__name__ == self.classname:
+ br_fkey = getattr(relclass.c, self.colname)
+ else:
+ br_fkey = None
+ return create_backref(self.backref, foreignkey=br_fkey)
+
class one_to_many(relationship):
def __init__(self, classname, colname=None, backref=None, private=False,
lazy=True, order_by=False):
class one_to_one(relationship):
def __init__(self, classname, colname=None, backref=None, private=False,
lazy=True, order_by=False):
- if backref is not None:
- backref = create_backref(backref, uselist=False)
relationship.__init__(self, classname, colname, backref, private,
lazy, uselist=False, order_by=order_by)
+ def create_backref(self, klass):
+ relclass = ActiveMapperMeta.classes[self.classname]
+ if klass.__name__ == self.classname:
+ br_fkey = getattr(relclass.c, self.colname)
+ else:
+ br_fkey = None
+ return create_backref(self.backref, foreignkey=br_fkey, uselist=False)
class many_to_many(relationship):
def __init__(self, classname, secondary, backref=None, lazy=True,
uselist=True, secondary=secondary,
order_by=order_by)
-
#
# SQLAlchemy metaclass and superclass that can be used to do SQLAlchemy
# mapping in a declarative way, along with a function to process the
# up if the classes aren't specified in a proper order
#
-__deferred_classes__ = set()
-__processed_classes__ = set()
+__deferred_classes__ = {}
+__processed_classes__ = {}
def process_relationships(klass, was_deferred=False):
# first, we loop through all of the relationships defined on the
# class, and make sure that the related class already has been
# completely processed and defer processing if it has not
defer = False
for propname, reldesc in klass.relations.items():
- found = False
- for other_klass in __processed_classes__:
- if reldesc.classname == other_klass.__name__:
- found = True
- break
-
+ found = (reldesc.classname == klass.__name__ or reldesc.classname in __processed_classes__)
if not found:
- if not was_deferred: __deferred_classes__.add(klass)
defer = True
break
# and make sure that we can find the related tables (they do not
# have to be processed yet, just defined), and we defer if we are
# not able to find any of the related tables
- for col in klass.columns:
- if col.foreign_key is not None:
- found = False
- for other_klass in ActiveMapperMeta.classes.values():
+ if not defer:
+ for col in klass.columns:
+ if col.foreign_key is not None:
+ found = False
table_name = col.foreign_key._colspec.rsplit('.', 1)[0]
- if other_klass.table.fullname.lower() == table_name.lower():
- found = True
+ for other_klass in ActiveMapperMeta.classes.values():
+ if other_klass.table.fullname.lower() == table_name.lower():
+ found = True
- if not found:
- if not was_deferred: __deferred_classes__.add(klass)
- defer = True
- break
-
+ if not found:
+ defer = True
+ break
+
+ if defer and not was_deferred:
+ __deferred_classes__[klass.__name__] = klass
+
# if we are able to find all related and referred to tables, then
# we can go ahead and assign the relationships to the class
if not defer:
relations = {}
for propname, reldesc in klass.relations.items():
- relclass = ActiveMapperMeta.classes[reldesc.classname]
- if isinstance(reldesc.order_by, str):
- reldesc.order_by = [ reldesc.order_by ]
- if isinstance(reldesc.order_by, list):
- for itemno in range(len(reldesc.order_by)):
- if isinstance(reldesc.order_by[itemno], str):
- reldesc.order_by[itemno] = \
- getattr(relclass.c, reldesc.order_by[itemno])
- relations[propname] = relation(relclass.mapper,
- secondary=reldesc.secondary,
- backref=reldesc.backref,
- private=reldesc.private,
- lazy=reldesc.lazy,
- uselist=reldesc.uselist,
- order_by=reldesc.order_by)
+ reldesc.process(klass, propname, relations)
class_mapper(klass).add_properties(relations)
- if klass in __deferred_classes__:
- __deferred_classes__.remove(klass)
- __processed_classes__.add(klass)
+ if klass.__name__ in __deferred_classes__:
+ del __deferred_classes__[klass.__name__]
+ __processed_classes__[klass.__name__] = klass
# finally, loop through the deferred classes and attempt to process
# relationships for them
while last_count > len(__deferred_classes__):
last_count = len(__deferred_classes__)
deferred = __deferred_classes__.copy()
- for deferred_class in deferred:
+ for deferred_class in deferred.values():
process_relationships(deferred_class, was_deferred=True)
import testbase
from sqlalchemy.ext.activemapper import ActiveMapper, column, one_to_many, one_to_one, objectstore
-from sqlalchemy import and_, or_, clear_mappers
+from sqlalchemy import and_, or_, clear_mappers, backref
from sqlalchemy import ForeignKey, String, Integer, DateTime
from datetime import datetime
)
self.assertEquals(len(results), 1)
-
+class testselfreferential(testbase.PersistTest):
+ def setUpAll(self):
+ global TreeNode
+ class TreeNode(activemapper.ActiveMapper):
+ class mapping:
+ id = column(Integer, primary_key=True)
+ name = column(String(30))
+ parent_id = column(Integer, foreign_key=ForeignKey('treenode.id'))
+ children = one_to_many('TreeNode', colname='id', backref='parent')
+
+ activemapper.metadata.connect(testbase.db)
+ activemapper.create_tables()
+ def tearDownAll(self):
+ clear_mappers()
+ activemapper.drop_tables()
+
+ def testbasic(self):
+ t = TreeNode(name='node1')
+ t.children.append(TreeNode(name='node2'))
+ t.children.append(TreeNode(name='node3'))
+ objectstore.flush()
+ objectstore.clear()
+
+ t = TreeNode.get_by(name='node1')
+ assert (t.name == 'node1')
+ assert (t.children[0].name == 'node2')
+ assert (t.children[1].name == 'node3')
+ assert (t.children[1].parent is t)
+
+ objectstore.clear()
+ t = TreeNode.get_by(name='node3')
+ assert (t.parent is TreeNode.get_by(name='node1'))
+
if __name__ == '__main__':
- unittest.main()
+ testbase.main()