From edbfbf81f7ffa4f38e086e51d8efbde7230e0f28 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 11 Dec 2007 17:29:08 +0000 Subject: [PATCH] - fix to cascades on polymorphic relations, such that cascades from an object to a polymorphic collection continue cascading along the set of attributes specific to each element in the collection. --- CHANGES | 4 ++ lib/sqlalchemy/orm/mapper.py | 4 +- lib/sqlalchemy/orm/properties.py | 9 ++-- test/orm/inheritance/basic.py | 70 +++++++++++++++++++++++++++++++- 4 files changed, 81 insertions(+), 6 deletions(-) diff --git a/CHANGES b/CHANGES index 6154bc3d18..5a9893368c 100644 --- a/CHANGES +++ b/CHANGES @@ -87,6 +87,10 @@ CHANGES - also with dynamic, implemented correct count() behavior as well as other helper methods. + - fix to cascades on polymorphic relations, such that cascades + from an object to a polymorphic collection continue cascading + along the set of attributes specific to each element in the collection. + - query.get() and query.load() do not take existing filter or other criterion into account; these methods *always* look up the given id in the database or return the current instance from the identity map, diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 0f5dbaaf54..e9fe41fdc7 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1538,8 +1538,8 @@ def has_mapper(object): return hasattr(object, '_entity_name') -def _state_mapper(state): - return state.class_._class_state.mappers[state.dict.get('_entity_name', None)] +def _state_mapper(state, entity_name=None): + return state.class_._class_state.mappers[state.dict.get('_entity_name', entity_name)] def object_mapper(object, entity_name=None, raiseerror=True): """Given an object, return the primary Mapper associated with the object instance. diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 9394e9aead..4d41556a07 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -13,7 +13,7 @@ to handle flush-time dependency sorting and processing. from sqlalchemy import sql, schema, util, exceptions, logging from sqlalchemy.sql import util as sql_util, visitors, operators, ColumnElement -from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency +from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty @@ -365,8 +365,11 @@ class PropertyLoader(StrategizedProperty): if not isinstance(c, self.mapper.class_): raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) recursive.add(c) - yield (c, mapper) - for (c2, m) in mapper.cascade_iterator(type, c._state, recursive): + + # cascade using the mapper local to this object, so that its individual properties are located + instance_mapper = object_mapper(c, entity_name=mapper.entity_name) + yield (c, instance_mapper) + for (c2, m) in instance_mapper.cascade_iterator(type, c._state, recursive): yield (c2, m) def _get_target_class(self): diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index 05603ac864..2ef76b6d8d 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -9,7 +9,6 @@ class O2MTest(ORMTest): """deals with inheritance and one-to-many relationships""" def define_tables(self, metadata): global foo, bar, blub - # the 'data' columns are to appease SQLite which cant handle a blank INSERT foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_seq', optional=True), primary_key=True), @@ -65,7 +64,76 @@ class O2MTest(ORMTest): self.assert_(compare == result) self.assert_(l[0].parent_foo.data == 'foo #1' and l[1].parent_foo.data == 'foo #1') +class CascadeTest(ORMTest): + """that cascades on polymorphic relations continue + cascading along the path of the instance's mapper, not + the base mapper.""" + + def define_tables(self, metadata): + global t1, t2, t3, t4 + t1= Table('t1', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30)) + ) + + t2 = Table('t2', metadata, + Column('id', Integer, primary_key=True), + Column('t1id', Integer, ForeignKey('t1.id')), + Column('type', String(30)), + Column('data', String(30)) + ) + t3 = Table('t3', metadata, + Column('id', Integer, ForeignKey('t2.id'), primary_key=True), + Column('moredata', String(30))) + + t4 = Table('t4', metadata, + Column('id', Integer, primary_key=True), + Column('t3id', Integer, ForeignKey('t3.id')), + Column('data', String(30))) + + def test_cascade(self): + class T1(fixtures.Base): + pass + class T2(fixtures.Base): + pass + class T3(T2): + pass + class T4(fixtures.Base): + pass + + mapper(T1, t1, properties={ + 't2s':relation(T2, cascade="all") + }) + mapper(T2, t2, polymorphic_on=t2.c.type, polymorphic_identity='t2') + mapper(T3, t3, inherits=T2, polymorphic_identity='t3', properties={ + 't4s':relation(T4, cascade="all") + }) + mapper(T4, t4) + + sess = create_session() + t1_1 = T1(data='t1') + + t3_1 = T3(data ='t3', moredata='t3') + t2_1 = T2(data='t2') + + t1_1.t2s.append(t2_1) + t1_1.t2s.append(t3_1) + + t4_1 = T4(data='t4') + t3_1.t4s.append(t4_1) + + sess.save(t1_1) + + assert t4_1 in sess.new + sess.flush() + + sess.delete(t1_1) + assert t4_1 in sess.deleted + sess.flush() + + + class GetTest(ORMTest): def define_tables(self, metadata): global foo, bar, blub -- 2.47.3