From 42c53ec821b18883d6499120fc14ce8a7a1b8b46 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 24 Apr 2006 19:21:12 +0000 Subject: [PATCH] lazyload clause calculation uses anonymous keynames for the bind parameters, to avoid compilation name conflicts --- lib/sqlalchemy/mapping/properties.py | 22 +++++-- test/alltests.py | 1 + test/lazytest1.py | 92 ++++++++++++++++++++++++++++ 3 files changed, 110 insertions(+), 5 deletions(-) create mode 100644 test/lazytest1.py diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index 133ad9fe77..8175ee2641 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -18,6 +18,7 @@ import sync import mapper import objectstore from sqlalchemy.exceptions import * +import random class ColumnProperty(MapperProperty): """describes an object attribute that corresponds to a table column.""" @@ -579,7 +580,7 @@ class PropertyLoader(MapperProperty): class LazyLoader(PropertyLoader): def do_init_subclass(self, key, parent): - (self.lazywhere, self.lazybinds) = create_lazy_clause(self.parent.noninherited_table, self.primaryjoin, self.secondaryjoin, self.foreignkey) + (self.lazywhere, self.lazybinds, self.lazyreverse) = create_lazy_clause(self.parent.noninherited_table, self.primaryjoin, self.secondaryjoin, self.foreignkey) # determine if our "lazywhere" clause is the same as the mapper's # get() clause. then we can just use mapper.get() self.use_get = not self.uselist and self.mapper.query._get_clause.compare(self.lazywhere) @@ -608,7 +609,8 @@ class LazyLoader(PropertyLoader): if self.use_get: ident = [] for primary_key in self.mapper.pks_by_table[self.mapper.table]: - ident.append(params[primary_key._label]) + bind = self.lazyreverse[primary_key] + ident.append(params[bind.key]) return self.mapper.using(session).get(*ident) elif self.order_by is not False: order_by = self.order_by @@ -646,23 +648,33 @@ class LazyLoader(PropertyLoader): def create_lazy_clause(table, primaryjoin, secondaryjoin, foreignkey): binds = {} + reverselookup = {} + + def bind_label(): + return "lazy_" + hex(random.randint(0, 65535))[2:] + def visit_binary(binary): circular = isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column) and binary.left.table is binary.right.table if isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column) and ((not circular and binary.left.table is table) or (circular and binary.right is foreignkey)): + col = binary.left binary.left = binds.setdefault(binary.left, - sql.BindParamClause(binary.right._label, None, shortname = binary.left.name)) + sql.BindParamClause(bind_label(), None, shortname = binary.left.name)) + reverselookup[binary.right] = binds[col] binary.swap() if isinstance(binary.right, schema.Column) and isinstance(binary.left, schema.Column) and ((not circular and binary.right.table is table) or (circular and binary.left is foreignkey)): + col = binary.right binary.right = binds.setdefault(binary.right, - sql.BindParamClause(binary.left._label, None, shortname = binary.right.name)) + sql.BindParamClause(bind_label(), None, shortname = binary.right.name)) + reverselookup[binary.left] = binds[col] lazywhere = primaryjoin.copy_container() li = BinaryVisitor(visit_binary) lazywhere.accept_visitor(li) + #print "PRIMARYJOIN", str(lazywhere), [b.key for b in binds.values()] if secondaryjoin is not None: lazywhere = sql.and_(lazywhere, secondaryjoin) - return (lazywhere, binds) + return (lazywhere, binds, reverselookup) class EagerLoader(PropertyLoader): diff --git a/test/alltests.py b/test/alltests.py index f0e17b39c7..4e9c73c2c2 100644 --- a/test/alltests.py +++ b/test/alltests.py @@ -34,6 +34,7 @@ def suite(): # ORM selecting 'mapper', 'selectresults', + 'lazytest1', 'eagertest1', 'eagertest2', diff --git a/test/lazytest1.py b/test/lazytest1.py new file mode 100644 index 0000000000..a94940816f --- /dev/null +++ b/test/lazytest1.py @@ -0,0 +1,92 @@ +from testbase import PersistTest, AssertMixin +import testbase +import unittest, sys, os +from sqlalchemy import * +import datetime + +class LazyTest(AssertMixin): + def setUpAll(self): + global info_table, data_table, rel_table + engine = testbase.db + info_table = Table('infos', engine, + Column('pk', Integer, primary_key=True), + Column('info', String)) + + data_table = Table('data', engine, + Column('data_pk', Integer, primary_key=True), + Column('info_pk', Integer, ForeignKey(info_table.c.pk)), + Column('timeval', Integer), + Column('data_val', String)) + + rel_table = Table('rels', engine, + Column('rel_pk', Integer, primary_key=True), + Column('info_pk', Integer, ForeignKey(info_table.c.pk)), + Column('start', Integer), + Column('finish', Integer)) + + + info_table.create() + rel_table.create() + data_table.create() + info_table.insert().execute( + {'pk':1, 'info':'pk_1_info'}, + {'pk':2, 'info':'pk_2_info'}, + {'pk':3, 'info':'pk_3_info'}, + {'pk':4, 'info':'pk_4_info'}, + {'pk':5, 'info':'pk_5_info'}) + + rel_table.insert().execute( + {'rel_pk':1, 'info_pk':1, 'start':10, 'finish':19}, + {'rel_pk':2, 'info_pk':1, 'start':100, 'finish':199}, + {'rel_pk':3, 'info_pk':2, 'start':20, 'finish':29}, + {'rel_pk':4, 'info_pk':3, 'start':13, 'finish':23}, + {'rel_pk':5, 'info_pk':5, 'start':15, 'finish':25}) + + data_table.insert().execute( + {'data_pk':1, 'info_pk':1, 'timeval':11, 'data_val':'11_data'}, + {'data_pk':2, 'info_pk':1, 'timeval':9, 'data_val':'9_data'}, + {'data_pk':3, 'info_pk':1, 'timeval':13, 'data_val':'13_data'}, + {'data_pk':4, 'info_pk':2, 'timeval':23, 'data_val':'23_data'}, + {'data_pk':5, 'info_pk':2, 'timeval':13, 'data_val':'13_data'}, + {'data_pk':6, 'info_pk':1, 'timeval':15, 'data_val':'15_data'}) + + + def tearDownAll(self): + data_table.drop() + rel_table.drop() + info_table.drop() + + def testone(self): + """tests a lazy load which has multiple join conditions, including two that are against + the same column in the child table""" + class Information(object): + pass + + class Relation(object): + pass + + class Data(object): + pass + + # Create the basic mappers, with no frills or modifications + Information.mapper = mapper(Information, info_table) + Data.mapper = mapper(Data, data_table) + Relation.mapper = mapper(Relation, rel_table) + + Relation.mapper.add_property('datas', relation(Data.mapper, + primaryjoin=and_(Relation.c.info_pk==Data.c.info_pk, + Data.c.timeval >= Relation.c.start, + Data.c.timeval <= Relation.c.finish), + foreignkey=Data.c.info_pk)) + + Information.mapper.add_property('rels', relation(Relation.mapper)) + + info = Information.mapper.get(1) + assert info + assert len(info.rels) == 2 + assert len(info.rels[0].datas) == 3 + +if __name__ == "__main__": + testbase.main() + + -- 2.47.2