From 0dde728591cb083e351cf1ff1998aaf1883ab7a1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 29 Sep 2007 20:17:40 +0000 Subject: [PATCH] - fixed three- and multi-level select and deferred inheritance loading (i.e. abc inheritance with no select_table), [ticket:795] --- CHANGES | 3 + lib/sqlalchemy/orm/mapper.py | 20 ++++-- lib/sqlalchemy/orm/strategies.py | 18 ++++- test/orm/inheritance/abc_polymorphic.py | 90 +++++++++++++++++++++++++ test/orm/inheritance/alltests.py | 1 + test/orm/inheritance/basic.py | 1 + test/testlib/fixtures.py | 10 +-- 7 files changed, 128 insertions(+), 15 deletions(-) create mode 100644 test/orm/inheritance/abc_polymorphic.py diff --git a/CHANGES b/CHANGES index 095e582b49..d39c17e9a0 100644 --- a/CHANGES +++ b/CHANGES @@ -16,6 +16,9 @@ CHANGES - firebird has supports_sane_rowcount and supports_sane_multi_rowcount set to False due to ticket #370 (right way). +- fixed three- and multi-level select and deferred inheritance + loading (i.e. abc inheritance with no select_table), [ticket:795] + 0.4.0beta6 ---------- diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 9764a0ae63..b2bffd6ea5 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1442,10 +1442,7 @@ class Mapper(object): return instance - def _deferred_inheritance_condition(self, needs_tables): - cond = self.inherit_condition - - param_names = [] + def _deferred_inheritance_condition(self, base_mapper, needs_tables): def visit_binary(binary): leftcol = binary.left rightcol = binary.right @@ -1457,8 +1454,17 @@ class Mapper(object): elif rightcol not in needs_tables: binary.right = sql.bindparam(rightcol.name, None, type_=binary.right.type, unique=True) param_names.append(rightcol) - cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True) - return cond, param_names + + allconds = [] + param_names = [] + + visitor = mapperutil.BinaryVisitor(visit_binary) + for mapper in self.iterate_to_root(): + if mapper is base_mapper: + break + allconds.append(visitor.traverse(mapper.inherit_condition, clone=True)) + + return sql.and_(*allconds), param_names def translate_row(self, tomapper, row): """Translate the column keys of a row into a new or proxied @@ -1532,7 +1538,7 @@ class Mapper(object): if hosted_mapper is None or len(needs_tables)==0 or hosted_mapper.polymorphic_fetch == 'deferred': return - cond, param_names = self._deferred_inheritance_condition(needs_tables) + cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables) statement = sql.select(needs_tables, cond, use_labels=True) def post_execute(instance, **flags): self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance)) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index b93993af8d..09b51c203f 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -89,7 +89,7 @@ class ColumnLoader(LoaderStrategy): # 'deferred' polymorphic row fetcher, put a callable on the property. def new_execute(instance, row, isnew, **flags): if isnew: - sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_inheritance_loader(instance, mapper, needs_tables)) + sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_inheritance_loader(instance, mapper, hosted_mapper, needs_tables)) if self._should_log_debug: self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key)) return (new_execute, None, None) @@ -99,18 +99,30 @@ class ColumnLoader(LoaderStrategy): self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key)) return (None, None, None) - def _get_deferred_inheritance_loader(self, instance, mapper, needs_tables): + def _get_deferred_inheritance_loader(self, instance, mapper, hosted_mapper, needs_tables): + # create a deferred column loader which will query the remaining not-yet-loaded tables in an inheritance load. + # the mapper for the object creates the WHERE criterion using the mapper who originally + # "hosted" the query and the list of tables which are unloaded between the "hosted" mapper + # and this mapper. (i.e. A->B->C, the query used mapper A. therefore will need B's and C's tables + # in the query). def create_statement(): - cond, param_names = mapper._deferred_inheritance_condition(needs_tables) + # TODO: the SELECT statement here should be cached in the selectcontext. we are somewhat duplicating + # efforts from mapper._get_poly_select_loader as well and should look + # for ways to simplify. + cond, param_names = mapper._deferred_inheritance_condition(hosted_mapper, needs_tables) statement = sql.select(needs_tables, cond, use_labels=True) params = {} for c in param_names: params[c.name] = mapper.get_attr_by_column(instance, c) return (statement, params) + # install the create_statement() callable using the deferred loading strategy strategy = self.parent_property._get_strategy(DeferredColumnLoader) + # assemble list of all ColumnProperties which will need to be loaded props = [p for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables] + + # set the deferred loader on the instance attribute return strategy.setup_loader(instance, props=props, create_statement=create_statement) diff --git a/test/orm/inheritance/abc_polymorphic.py b/test/orm/inheritance/abc_polymorphic.py new file mode 100644 index 0000000000..da9097637f --- /dev/null +++ b/test/orm/inheritance/abc_polymorphic.py @@ -0,0 +1,90 @@ +import testbase +from sqlalchemy import * +from sqlalchemy import exceptions, util +from sqlalchemy.orm import * +from testlib import * +from testlib import fixtures + +class ABCTest(ORMTest): + def define_tables(self, metadata): + global a, b, c + a = Table('a', metadata, + Column('id', Integer, primary_key=True), + Column('adata', String(30)), + Column('type', String(30)), + ) + b = Table('b', metadata, + Column('id', Integer, ForeignKey('a.id'), primary_key=True), + Column('bdata', String(30))) + c = Table('c', metadata, + Column('id', Integer, ForeignKey('b.id'), primary_key=True), + Column('cdata', String(30))) + + def make_test(fetchtype): + def test_roundtrip(self): + class A(fixtures.Base):pass + class B(A):pass + class C(B):pass + + if fetchtype == 'union': + abc = a.outerjoin(b).outerjoin(c) + bc = a.join(b).outerjoin(c) + else: + abc = bc = None + + mapper(A, a, select_table=abc, polymorphic_on=a.c.type, polymorphic_identity='a', polymorphic_fetch=fetchtype) + mapper(B, b, select_table=bc, inherits=A, polymorphic_identity='b', polymorphic_fetch=fetchtype) + mapper(C, c, inherits=B, polymorphic_identity='c') + + a1 = A(adata='a1') + b1 = B(bdata='b1', adata='b1') + b2 = B(bdata='b2', adata='b2') + b3 = B(bdata='b3', adata='b3') + c1 = C(cdata='c1', bdata='c1', adata='c1') + c2 = C(cdata='c2', bdata='c2', adata='c2') + c3 = C(cdata='c2', bdata='c2', adata='c2') + + sess = create_session() + for x in (a1, b1, b2, b3, c1, c2, c3): + sess.save(x) + sess.flush() + sess.clear() + + #for obj in sess.query(A).all(): + # print obj + assert [ + A(adata='a1'), + B(bdata='b1', adata='b1'), + B(bdata='b2', adata='b2'), + B(bdata='b3', adata='b3'), + C(cdata='c1', bdata='c1', adata='c1'), + C(cdata='c2', bdata='c2', adata='c2'), + C(cdata='c2', bdata='c2', adata='c2'), + ] == sess.query(A).all() + + assert [ + B(bdata='b1', adata='b1'), + B(bdata='b2', adata='b2'), + B(bdata='b3', adata='b3'), + C(cdata='c1', bdata='c1', adata='c1'), + C(cdata='c2', bdata='c2', adata='c2'), + C(cdata='c2', bdata='c2', adata='c2'), + ] == sess.query(B).all() + + assert [ + C(cdata='c1', bdata='c1', adata='c1'), + C(cdata='c2', bdata='c2', adata='c2'), + C(cdata='c2', bdata='c2', adata='c2'), + ] == sess.query(C).all() + + test_roundtrip.__name__ = 'test_%s' % fetchtype + return test_roundtrip + + test_union = make_test('union') + test_select = make_test('select') + test_deferred = make_test('deferred') + + +if __name__ == '__main__': + testbase.main() + \ No newline at end of file diff --git a/test/orm/inheritance/alltests.py b/test/orm/inheritance/alltests.py index da59dd8fb7..dc93ed9b38 100644 --- a/test/orm/inheritance/alltests.py +++ b/test/orm/inheritance/alltests.py @@ -10,6 +10,7 @@ def suite(): 'orm.inheritance.polymorph', 'orm.inheritance.polymorph2', 'orm.inheritance.poly_linked_list', + 'orm.inheritance.abc_polymorphic', 'orm.inheritance.abc_inheritance', 'orm.inheritance.productspec', 'orm.inheritance.magazine', diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index fbdb4019e1..a033d61eac 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -63,6 +63,7 @@ class O2MTest(ORMTest): self.assert_(compare == result) self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1') + class GetTest(ORMTest): def define_tables(self, metadata): global foo, bar, blub diff --git a/test/testlib/fixtures.py b/test/testlib/fixtures.py index 1b05b366da..ada254c375 100644 --- a/test/testlib/fixtures.py +++ b/test/testlib/fixtures.py @@ -10,11 +10,11 @@ class Base(object): setattr(self, k, kwargs[k]) # TODO: add recursion checks to this - #def __repr__(self): - # return "%s(%s)" % ( - # (self.__class__.__name__), - # ','.join(["%s=%s" % (key, repr(getattr(self, key))) for key in self.__dict__ if not key.startswith('_')]) - # ) + def __repr__(self): + return "%s(%s)" % ( + (self.__class__.__name__), + ','.join(["%s=%s" % (key, repr(getattr(self, key))) for key in self.__dict__ if not key.startswith('_')]) + ) def __ne__(self, other): return not self.__eq__(other) -- 2.47.3