From: Mike Bayer Date: Mon, 14 Apr 2008 18:23:59 +0000 (+0000) Subject: - simplified __create_lazy_clause to make better usage of the new local/remote pairs... X-Git-Tag: rel_0_5beta1~183 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f5126ab3a169b6f8a9171868fe32b2bd385f8b8f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - simplified __create_lazy_clause to make better usage of the new local/remote pairs collection - corrected the direction of local/remote pairs for manytoone - added new tests which demonstrate lazyloading working when the bind param is embedded inside of a SQL function, when _local_remote_pairs argument is used; fixes the viewonly version of [ticket:610] - removed needless kwargs check from visitors.traverse --- diff --git a/CHANGES b/CHANGES index 37861592d7..3bcfd89159 100644 --- a/CHANGES +++ b/CHANGES @@ -17,8 +17,11 @@ CHANGES across functions, etc., _local_remote_pairs=[tuples]. This complements a complex primaryjoin condition allowing you to provide the individual column pairs which comprise - the relation's local and remote sides. - + the relation's local and remote sides. Also improved + lazy load SQL generation to handle placing bind params + inside of functions and other expressions. + (partial progress towards [ticket:610]) + - removed ancient assertion that mapped selectables require "alias names" - the mapper creates its own alias now if none is present. Though in this case you need to use diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 140fffd892..55f6c98752 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -621,9 +621,10 @@ class PropertyLoader(StrategizedProperty): if self.direction is MANYTOONE: self.remote_side, self.local_side = [util.OrderedSet(s) for s in zip(*eq_pairs)] + self.local_remote_pairs = [(r, l) for l, r in eq_pairs] else: self.local_side, self.remote_side = [util.OrderedSet(s) for s in zip(*eq_pairs)] - self.local_remote_pairs = zip(self.local_side, self.remote_side) + self.local_remote_pairs = eq_pairs if self.direction is ONETOMANY: for l in self.local_side: @@ -651,12 +652,15 @@ class PropertyLoader(StrategizedProperty): self.direction = ONETOMANY else: self.direction = MANYTOONE - + elif self._arg_local_remote_pairs: + remote = util.Set([r for l, r in self._arg_local_remote_pairs]) + if self.foreign_keys.intersection(remote): + self.direction = ONETOMANY + else: + self.direction = MANYTOONE elif self.remote_side: - for f in self.foreign_keys: - if f in self.remote_side: - self.direction = ONETOMANY - return + if self.foreign_keys.intersection(self.remote_side): + self.direction = ONETOMANY else: self.direction = MANYTOONE else: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 3bbb380d13..e758ac08b1 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -351,41 +351,39 @@ class LazyLoader(AbstractRelationLoader): def __create_lazy_clause(cls, prop, reverse_direction=False): binds = {} + lookup = {} equated_columns = {} - secondaryjoin = prop.secondaryjoin - local = prop.local_side - - def should_bind(targetcol, othercol): - if reverse_direction and not secondaryjoin: - return othercol in local - else: - return targetcol in local - - def visit_binary(binary): - leftcol = binary.left - rightcol = binary.right - - equated_columns[rightcol] = leftcol - equated_columns[leftcol] = rightcol - - if should_bind(leftcol, rightcol): - if leftcol not in binds: - binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type) - binary.left = binds[leftcol] - elif should_bind(rightcol, leftcol): - if rightcol not in binds: - binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type) - binary.right = binds[rightcol] - + if reverse_direction and not prop.secondaryjoin: + for l, r in prop.local_remote_pairs: + _list = lookup.setdefault(r, []) + _list.append((r, l)) + equated_columns[l] = r + else: + for l, r in prop.local_remote_pairs: + _list = lookup.setdefault(l, []) + _list.append((l, r)) + equated_columns[r] = l + + def col_to_bind(col): + if col in lookup: + for tobind, equated in lookup[col]: + if equated in binds: + return None + if col not in binds: + binds[col] = sql.bindparam(None, None, type_=col.type) + return binds[col] + return None + lazywhere = prop.primaryjoin if not prop.secondaryjoin or not reverse_direction: - lazywhere = visitors.traverse(lazywhere, clone=True, visit_binary=visit_binary) + lazywhere = visitors.traverse(lazywhere, before_clone=col_to_bind, clone=True) if prop.secondaryjoin is not None: + secondaryjoin = prop.secondaryjoin if reverse_direction: - secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary) + secondaryjoin = visitors.traverse(secondaryjoin, before_clone=col_to_bind, clone=True) lazywhere = sql.and_(lazywhere, secondaryjoin) bind_to_col = dict([(binds[col].key, col) for col in binds]) diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 792391929f..9888a228a3 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -177,7 +177,6 @@ def traverse(clause, **kwargs): __traverse_options__ = kwargs.pop('traverse_options', {}) vis = Vis() for key in kwargs: - if key.startswith('visit_'): - setattr(vis, key, kwargs[key]) + setattr(vis, key, kwargs[key]) return vis.traverse(clause, clone=clone) diff --git a/test/orm/relationships.py b/test/orm/relationships.py index 6e9a0f96e1..40773f8359 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -1066,7 +1066,7 @@ class ExplicitLocalRemoteTest(ORMTest): Column('t1id', String(50)), ) - def test_onetomany(self): + def test_onetomany_funcfk(self): class T1(fixtures.Base): pass class T2(fixtures.Base): @@ -1084,17 +1084,21 @@ class ExplicitLocalRemoteTest(ORMTest): sess = create_session() a1 = T1(id='number1', data='a1') + a2 = T1(id='number2', data='a2') b1 = T2(data='b1', t1id='NuMbEr1') b2 = T2(data='b2', t1id='Number1') + b3 = T2(data='b3', t1id='Number2') sess.save(a1) + sess.save(a2) sess.save(b1) sess.save(b2) + sess.save(b3) sess.flush() sess.clear() self.assertEquals(sess.query(T1).first(), T1(id='number1', data='a1', t2s=[T2(data='b1', t1id='NuMbEr1'), T2(data='b2', t1id='Number1')])) - def test_manytoone(self): + def test_manytoone_funcfk(self): class T1(fixtures.Base): pass class T2(fixtures.Base): @@ -1103,25 +1107,95 @@ class ExplicitLocalRemoteTest(ORMTest): mapper(T2, t2, properties={ 't1':relation(T1, primaryjoin=t1.c.id==func.lower(t2.c.t1id), _local_remote_pairs=[(t2.c.t1id, t1.c.id)], - foreign_keys=[t2.c.t1id] + foreign_keys=[t2.c.t1id], + uselist=True ) }) sess = create_session() a1 = T1(id='number1', data='a1') + a2 = T1(id='number2', data='a2') b1 = T2(data='b1', t1id='NuMbEr1') b2 = T2(data='b2', t1id='Number1') + b3 = T2(data='b3', t1id='Number2') sess.save(a1) + sess.save(a2) sess.save(b1) sess.save(b2) + sess.save(b3) sess.flush() sess.clear() - self.assertEquals(sess.query(T2).all(), + self.assertEquals(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), [ - T2(data='b1', t1=T1(id='number1', data='a1')), - T2(data='b2', t1=T1(id='number1', data='a1')) + T2(data='b1', t1=[T1(id='number1', data='a1')]), + T2(data='b2', t1=[T1(id='number1', data='a1')]) ] ) + def test_onetomany_func_referent(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + + mapper(T1, t1, properties={ + 't2s':relation(T2, primaryjoin=func.lower(t1.c.id)==t2.c.t1id, + _local_remote_pairs=[(t1.c.id, t2.c.t1id)], + foreign_keys=[t2.c.t1id] + ) + }) + mapper(T2, t2) + + sess = create_session() + a1 = T1(id='NuMbeR1', data='a1') + a2 = T1(id='NuMbeR2', data='a2') + b1 = T2(data='b1', t1id='number1') + b2 = T2(data='b2', t1id='number1') + b3 = T2(data='b2', t1id='number2') + sess.save(a1) + sess.save(a2) + sess.save(b1) + sess.save(b2) + sess.save(b3) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(T1).first(), T1(id='NuMbeR1', data='a1', t2s=[T2(data='b1', t1id='number1'), T2(data='b2', t1id='number1')])) + + def test_manytoone_func_referent(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + + mapper(T1, t1) + mapper(T2, t2, properties={ + 't1':relation(T1, primaryjoin=func.lower(t1.c.id)==t2.c.t1id, + _local_remote_pairs=[(t2.c.t1id, t1.c.id)], + foreign_keys=[t2.c.t1id], uselist=True + ) + }) + + sess = create_session() + a1 = T1(id='NuMbeR1', data='a1') + a2 = T1(id='NuMbeR2', data='a2') + b1 = T2(data='b1', t1id='number1') + b2 = T2(data='b2', t1id='number1') + b3 = T2(data='b3', t1id='number2') + sess.save(a1) + sess.save(a2) + sess.save(b1) + sess.save(b2) + sess.save(b3) + sess.flush() + sess.clear() + + self.assertEquals(sess.query(T2).filter(T2.data.in_(['b1', 'b2'])).all(), + [ + T2(data='b1', t1=[T1(id='NuMbeR1', data='a1')]), + T2(data='b2', t1=[T1(id='NuMbeR1', data='a1')]) + ] + ) + def test_escalation(self): class T1(fixtures.Base): pass