From: Mike Bayer Date: Sat, 2 Dec 2006 23:36:59 +0000 (+0000) Subject: added 'remote_side' functionality to lazy clause generation X-Git-Tag: rel_0_3_2~26 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=18ff5f572fb83baa6d8f74aa4c978e2c1f09558e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git added 'remote_side' functionality to lazy clause generation --- diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 28fbdcfd31..a3a92841de 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -129,7 +129,8 @@ class AbstractRelationLoader(LoaderStrategy): self.cascade = self.parent_property.cascade self.attributeext = self.parent_property.attributeext self.order_by = self.parent_property.order_by - + self.remote_side = self.parent_property.remote_side + def _init_instance_attribute(self, instance, callable_=None): return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True, callable_=callable_) @@ -151,7 +152,7 @@ NoLoader.logger = logging.class_logger(NoLoader) class LazyLoader(AbstractRelationLoader): def init(self): super(LazyLoader, self).init() - (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self.parent.unjoined_table, self.primaryjoin, self.secondaryjoin, self.foreignkey) + (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self.parent.unjoined_table, self.primaryjoin, self.secondaryjoin, self.foreignkey, self.remote_side) # determine if our "lazywhere" clause is the same as the mapper's # get() clause. then we can just use mapper.get() self.use_get = not self.uselist and query.Query(self.mapper)._get_clause.compare(self.lazywhere) @@ -235,12 +236,15 @@ class LazyLoader(AbstractRelationLoader): # to load data into it. sessionlib.attribute_manager.reset_instance_attribute(instance, self.key) - def _create_lazy_clause(self, table, primaryjoin, secondaryjoin, foreignkey): + def _create_lazy_clause(self, table, primaryjoin, secondaryjoin, foreignkey, remote_side): binds = {} reverse = {} def column_in_table(table, column): return table.corresponding_column(column, raiseerr=False, keys_ok=False) is not None + if remote_side is None or len(remote_side) == 0: + remote_side = foreignkey + def find_column_in_expr(expr): if not isinstance(expr, sql.ColumnElement): return None @@ -259,13 +263,13 @@ class LazyLoader(AbstractRelationLoader): if leftcol is None or rightcol is None: return circular = leftcol.table is rightcol.table - if ((not circular and column_in_table(table, leftcol)) or (circular and rightcol in foreignkey)): + if ((not circular and column_in_table(table, leftcol)) or (circular and rightcol in remote_side)): col = leftcol binary.left = binds.setdefault(leftcol, sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type)) reverse[rightcol] = binds[col] - if (leftcol is not rightcol) and ((not circular and column_in_table(table, rightcol)) or (circular and leftcol in foreignkey)): + if (leftcol is not rightcol) and ((not circular and column_in_table(table, rightcol)) or (circular and leftcol in remote_side)): col = rightcol binary.right = binds.setdefault(rightcol, sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type)) diff --git a/test/orm/cycles.py b/test/orm/cycles.py index b5ad6ce332..5ec0426489 100644 --- a/test/orm/cycles.py +++ b/test/orm/cycles.py @@ -147,6 +147,22 @@ class SelfReferentialNoPKTest(AssertMixin): s.clear() t = s.query(TT).get_by(id=t1.id) assert t.children[0].parent_uuid == t1.uuid + def testlazyclause(self): + class TT(object): + def __init__(self): + self.uuid = hex(id(self)) + mapper(TT, table, properties={'children':relation(TT, remote_side=[table.c.parent_uuid], backref=backref('parent', remote_side=[table.c.uuid]))}) + s = create_session() + t1 = TT() + t2 = TT() + t1.children.append(t2) + s.save(t1) + s.flush() + s.clear() + + t = s.query(TT).get_by(id=t2.id) + assert t.uuid == t2.uuid + assert t.parent.uuid == t1.uuid class InheritTestOne(AssertMixin): def setUpAll(self):