self._configure_class_instrumentation()
self._configure_listeners()
self._configure_properties()
+ self._configure_polymorphic_setter()
self._configure_pks()
global _new_mappers
_new_mappers = True
if self.polymorphic_identity is not None:
self.polymorphic_map[self.polymorphic_identity] = self
self._identity_class = self.class_
-
+
if self.mapped_table is None:
raise sa_exc.ArgumentError(
"Mapper '%s' does not have a mapped_table specified."
init=False,
setparent=True)
+ def _configure_polymorphic_setter(self):
+ """Configure an attribute on the mapper representing the
+ 'polymorphic_on' column, if applicable, and not
+ already generated by _configure_properties (which is typical).
+
+ Also create a setter function which will assign this
+ attribute to the value of the 'polymorphic_identity'
+ upon instance construction, also if applicable. This
+ routine will run when an instance is created.
+
+ """
# do a special check for the "discriminiator" column, as it
# may only be present in the 'with_polymorphic' selectable
# but we need it for the base mapper
- if self.polymorphic_on is not None and \
- self.polymorphic_on not in self._columntoproperty:
- col = self.mapped_table.corresponding_column(self.polymorphic_on)
- if col is None:
- instrument = False
- col = self.polymorphic_on
- if self.with_polymorphic is None \
- or self.with_polymorphic[1].corresponding_column(col) \
- is None:
- util.warn("Could not map polymorphic_on column "
- "'%s' to the mapped table - polymorphic "
- "loads will not function properly"
- % col.description)
- else:
- instrument = True
- if self._should_exclude(col.key, col.key, local=False, column=col):
- raise sa_exc.InvalidRequestError(
- "Cannot exclude or override the discriminator column %r" %
- col.key)
+ setter = False
+
+ if self.polymorphic_on is not None:
+ setter = True
+
+ if self.polymorphic_on not in self._columntoproperty:
+ col = self.mapped_table.corresponding_column(self.polymorphic_on)
+ if col is None:
+ setter = False
+ instrument = False
+ col = self.polymorphic_on
+ if self.with_polymorphic is None \
+ or self.with_polymorphic[1].corresponding_column(col) \
+ is None:
+ util.warn("Could not map polymorphic_on column "
+ "'%s' to the mapped table - polymorphic "
+ "loads will not function properly"
+ % col.description)
+ else:
+ instrument = True
+
+ if self._should_exclude(col.key, col.key, False, col):
+ raise sa_exc.InvalidRequestError(
+ "Cannot exclude or override the discriminator column %r" %
+ col.key)
- self._configure_property(
- col.key,
- properties.ColumnProperty(col, _instrument=instrument),
- init=False, setparent=True)
+ self._configure_property(
+ col.key,
+ properties.ColumnProperty(col, _instrument=instrument),
+ init=False, setparent=True)
+ polymorphic_key = col.key
+ else:
+ polymorphic_key = self._columntoproperty[self.polymorphic_on].key
+
+ if setter:
+ def _set_polymorphic_identity(state):
+ dict_ = state.dict
+ state.get_impl(polymorphic_key).set(state, dict_,
+ self.polymorphic_identity, None)
+
+ self._set_polymorphic_identity = _set_polymorphic_identity
+ else:
+ self._set_polymorphic_identity = None
+
+
def _adapt_inherited_property(self, key, prop, init):
if not self.concrete:
if col is mapper.version_id_col:
params[col.key] = \
mapper.version_id_generator(None)
- elif mapper.polymorphic_on is not None and \
- mapper.polymorphic_on.shares_lineage(col):
- value = mapper.polymorphic_identity
- if ((col.default is None and
- col.server_default is None) or
- value is not None):
- params[col.key] = value
elif col in pks:
value = \
mapper._get_state_attr_by_column(
passive=True)
if history.added:
hasdata = True
- elif mapper.polymorphic_on is not None and \
- mapper.polymorphic_on.shares_lineage(col) and \
- col not in pks:
- pass
else:
prop = mapper._columntoproperty[col]
history = attributes.get_state_history(
postfetch_cols = resultproxy.postfetch_cols()
generated_cols = list(resultproxy.prefetch_cols())
- if self.polymorphic_on is not None:
- po = table.corresponding_column(self.polymorphic_on)
- if po is not None:
- generated_cols.append(po)
-
if self.version_id_col is not None:
generated_cols.append(self.version_id_col)
if instrumenting_mapper:
# compile() always compiles all mappers
instrumenting_mapper.compile()
+
+ if instrumenting_mapper._set_polymorphic_identity:
+ instrumenting_mapper._set_polymorphic_identity(state)
def _event_on_resurrect(state):
# re-populate the primary key elements
eq_(sess.query(T2).filter(T2.info=='at2').one(), at2)
eq_(at2.info, "THE INFO IS:at2")
+class PolymorphicAttributeManagementTest(_base.MappedTest):
+ """Test polymorphic_on can be assigned, can be mirrored, etc."""
+
+ run_setup_mappers = 'once'
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table('table_a', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('class_name', String(50))
+ )
+ Table('table_b', metadata,
+ Column('id', Integer, ForeignKey('table_a.id'), primary_key=True),
+ Column('class_name', String(50))
+ )
+ Table('table_c', metadata,
+ Column('id', Integer, ForeignKey('table_b.id'),primary_key=True)
+ )
+
+ @classmethod
+ @testing.resolve_artifact_names
+ def setup_classes(cls):
+ class A(_base.ComparableEntity):
+ pass
+ class B(A):
+ pass
+ class C(B):
+ pass
+
+ mapper(A, table_a,
+ polymorphic_on=table_a.c.class_name,
+ polymorphic_identity='a')
+ mapper(B, table_b, inherits=A,
+ polymorphic_on=table_b.c.class_name,
+ polymorphic_identity='b')
+ mapper(C, table_c, inherits=B,
+ polymorphic_identity='c')
+
+ @testing.resolve_artifact_names
+ def test_poly_configured_immediate(self):
+ a = A()
+ b = B()
+ c = C()
+ eq_(a.class_name, 'a')
+ eq_(b.class_name, 'b')
+ eq_(c.class_name, 'c')
+
+ @testing.resolve_artifact_names
+ def test_base_class(self):
+ sess = Session()
+ c1 = C()
+ sess.add(c1)
+ sess.commit()
+
+ assert isinstance(sess.query(B).first(), C)
+
+ sess.close()
+
+ assert isinstance(sess.query(A).first(), C)
+
+ @testing.resolve_artifact_names
+ def test_assignment(self):
+ sess = Session()
+ b1 = B()
+ b1.class_name = 'c'
+ sess.add(b1)
+ sess.commit()
+ sess.close()
+ assert isinstance(sess.query(B).first(), C)
class CascadeTest(_base.MappedTest):
"""that cascades on polymorphic relationships continue