- firebird has supports_sane_rowcount and supports_sane_multi_rowcount set
to False due to ticket #370 (right way).
+- fixed three- and multi-level select and deferred inheritance
+ loading (i.e. abc inheritance with no select_table), [ticket:795]
+
0.4.0beta6
----------
return instance
- def _deferred_inheritance_condition(self, needs_tables):
- cond = self.inherit_condition
-
- param_names = []
+ def _deferred_inheritance_condition(self, base_mapper, needs_tables):
def visit_binary(binary):
leftcol = binary.left
rightcol = binary.right
elif rightcol not in needs_tables:
binary.right = sql.bindparam(rightcol.name, None, type_=binary.right.type, unique=True)
param_names.append(rightcol)
- cond = mapperutil.BinaryVisitor(visit_binary).traverse(cond, clone=True)
- return cond, param_names
+
+ allconds = []
+ param_names = []
+
+ visitor = mapperutil.BinaryVisitor(visit_binary)
+ for mapper in self.iterate_to_root():
+ if mapper is base_mapper:
+ break
+ allconds.append(visitor.traverse(mapper.inherit_condition, clone=True))
+
+ return sql.and_(*allconds), param_names
def translate_row(self, tomapper, row):
"""Translate the column keys of a row into a new or proxied
if hosted_mapper is None or len(needs_tables)==0 or hosted_mapper.polymorphic_fetch == 'deferred':
return
- cond, param_names = self._deferred_inheritance_condition(needs_tables)
+ cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables)
statement = sql.select(needs_tables, cond, use_labels=True)
def post_execute(instance, **flags):
self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance))
# 'deferred' polymorphic row fetcher, put a callable on the property.
def new_execute(instance, row, isnew, **flags):
if isnew:
- sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_inheritance_loader(instance, mapper, needs_tables))
+ sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_inheritance_loader(instance, mapper, hosted_mapper, needs_tables))
if self._should_log_debug:
self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key))
return (new_execute, None, None)
self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key))
return (None, None, None)
- def _get_deferred_inheritance_loader(self, instance, mapper, needs_tables):
+ def _get_deferred_inheritance_loader(self, instance, mapper, hosted_mapper, needs_tables):
+ # create a deferred column loader which will query the remaining not-yet-loaded tables in an inheritance load.
+ # the mapper for the object creates the WHERE criterion using the mapper who originally
+ # "hosted" the query and the list of tables which are unloaded between the "hosted" mapper
+ # and this mapper. (i.e. A->B->C, the query used mapper A. therefore will need B's and C's tables
+ # in the query).
def create_statement():
- cond, param_names = mapper._deferred_inheritance_condition(needs_tables)
+ # TODO: the SELECT statement here should be cached in the selectcontext. we are somewhat duplicating
+ # efforts from mapper._get_poly_select_loader as well and should look
+ # for ways to simplify.
+ cond, param_names = mapper._deferred_inheritance_condition(hosted_mapper, needs_tables)
statement = sql.select(needs_tables, cond, use_labels=True)
params = {}
for c in param_names:
params[c.name] = mapper.get_attr_by_column(instance, c)
return (statement, params)
+ # install the create_statement() callable using the deferred loading strategy
strategy = self.parent_property._get_strategy(DeferredColumnLoader)
+ # assemble list of all ColumnProperties which will need to be loaded
props = [p for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables]
+
+ # set the deferred loader on the instance attribute
return strategy.setup_loader(instance, props=props, create_statement=create_statement)
--- /dev/null
+import testbase
+from sqlalchemy import *
+from sqlalchemy import exceptions, util
+from sqlalchemy.orm import *
+from testlib import *
+from testlib import fixtures
+
+class ABCTest(ORMTest):
+ def define_tables(self, metadata):
+ global a, b, c
+ a = Table('a', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('adata', String(30)),
+ Column('type', String(30)),
+ )
+ b = Table('b', metadata,
+ Column('id', Integer, ForeignKey('a.id'), primary_key=True),
+ Column('bdata', String(30)))
+ c = Table('c', metadata,
+ Column('id', Integer, ForeignKey('b.id'), primary_key=True),
+ Column('cdata', String(30)))
+
+ def make_test(fetchtype):
+ def test_roundtrip(self):
+ class A(fixtures.Base):pass
+ class B(A):pass
+ class C(B):pass
+
+ if fetchtype == 'union':
+ abc = a.outerjoin(b).outerjoin(c)
+ bc = a.join(b).outerjoin(c)
+ else:
+ abc = bc = None
+
+ mapper(A, a, select_table=abc, polymorphic_on=a.c.type, polymorphic_identity='a', polymorphic_fetch=fetchtype)
+ mapper(B, b, select_table=bc, inherits=A, polymorphic_identity='b', polymorphic_fetch=fetchtype)
+ mapper(C, c, inherits=B, polymorphic_identity='c')
+
+ a1 = A(adata='a1')
+ b1 = B(bdata='b1', adata='b1')
+ b2 = B(bdata='b2', adata='b2')
+ b3 = B(bdata='b3', adata='b3')
+ c1 = C(cdata='c1', bdata='c1', adata='c1')
+ c2 = C(cdata='c2', bdata='c2', adata='c2')
+ c3 = C(cdata='c2', bdata='c2', adata='c2')
+
+ sess = create_session()
+ for x in (a1, b1, b2, b3, c1, c2, c3):
+ sess.save(x)
+ sess.flush()
+ sess.clear()
+
+ #for obj in sess.query(A).all():
+ # print obj
+ assert [
+ A(adata='a1'),
+ B(bdata='b1', adata='b1'),
+ B(bdata='b2', adata='b2'),
+ B(bdata='b3', adata='b3'),
+ C(cdata='c1', bdata='c1', adata='c1'),
+ C(cdata='c2', bdata='c2', adata='c2'),
+ C(cdata='c2', bdata='c2', adata='c2'),
+ ] == sess.query(A).all()
+
+ assert [
+ B(bdata='b1', adata='b1'),
+ B(bdata='b2', adata='b2'),
+ B(bdata='b3', adata='b3'),
+ C(cdata='c1', bdata='c1', adata='c1'),
+ C(cdata='c2', bdata='c2', adata='c2'),
+ C(cdata='c2', bdata='c2', adata='c2'),
+ ] == sess.query(B).all()
+
+ assert [
+ C(cdata='c1', bdata='c1', adata='c1'),
+ C(cdata='c2', bdata='c2', adata='c2'),
+ C(cdata='c2', bdata='c2', adata='c2'),
+ ] == sess.query(C).all()
+
+ test_roundtrip.__name__ = 'test_%s' % fetchtype
+ return test_roundtrip
+
+ test_union = make_test('union')
+ test_select = make_test('select')
+ test_deferred = make_test('deferred')
+
+
+if __name__ == '__main__':
+ testbase.main()
+
\ No newline at end of file
'orm.inheritance.polymorph',
'orm.inheritance.polymorph2',
'orm.inheritance.poly_linked_list',
+ 'orm.inheritance.abc_polymorphic',
'orm.inheritance.abc_inheritance',
'orm.inheritance.productspec',
'orm.inheritance.magazine',
self.assert_(compare == result)
self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1')
+
class GetTest(ORMTest):
def define_tables(self, metadata):
global foo, bar, blub
setattr(self, k, kwargs[k])
# TODO: add recursion checks to this
- #def __repr__(self):
- # return "%s(%s)" % (
- # (self.__class__.__name__),
- # ','.join(["%s=%s" % (key, repr(getattr(self, key))) for key in self.__dict__ if not key.startswith('_')])
- # )
+ def __repr__(self):
+ return "%s(%s)" % (
+ (self.__class__.__name__),
+ ','.join(["%s=%s" % (key, repr(getattr(self, key))) for key in self.__dict__ if not key.startswith('_')])
+ )
def __ne__(self, other):
return not self.__eq__(other)