From: Mike Bayer Date: Sat, 20 Nov 2010 21:25:12 +0000 (-0500) Subject: - the column assigned to polymorphic_on now behaves like any other X-Git-Tag: rel_0_7b1~242 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3906eed72bb1372e70977092d4900459a97d8e74;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - the column assigned to polymorphic_on now behaves like any other mapped attribute, in that it can be assigned to, mapped to multiple columns. It is also populated immediately upon object construction with its class-based value, so the attribute can be read before any flush occurs. [ticket:1895] --- diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 8fc0b41321..32eb8f6433 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -208,6 +208,7 @@ class Mapper(object): self._configure_class_instrumentation() self._configure_listeners() self._configure_properties() + self._configure_polymorphic_setter() self._configure_pks() global _new_mappers _new_mappers = True @@ -310,7 +311,7 @@ class Mapper(object): if self.polymorphic_identity is not None: self.polymorphic_map[self.polymorphic_identity] = self self._identity_class = self.class_ - + if self.mapped_table is None: raise sa_exc.ArgumentError( "Mapper '%s' does not have a mapped_table specified." @@ -560,33 +561,65 @@ class Mapper(object): init=False, setparent=True) + def _configure_polymorphic_setter(self): + """Configure an attribute on the mapper representing the + 'polymorphic_on' column, if applicable, and not + already generated by _configure_properties (which is typical). + + Also create a setter function which will assign this + attribute to the value of the 'polymorphic_identity' + upon instance construction, also if applicable. This + routine will run when an instance is created. + + """ # do a special check for the "discriminiator" column, as it # may only be present in the 'with_polymorphic' selectable # but we need it for the base mapper - if self.polymorphic_on is not None and \ - self.polymorphic_on not in self._columntoproperty: - col = self.mapped_table.corresponding_column(self.polymorphic_on) - if col is None: - instrument = False - col = self.polymorphic_on - if self.with_polymorphic is None \ - or self.with_polymorphic[1].corresponding_column(col) \ - is None: - util.warn("Could not map polymorphic_on column " - "'%s' to the mapped table - polymorphic " - "loads will not function properly" - % col.description) - else: - instrument = True - if self._should_exclude(col.key, col.key, local=False, column=col): - raise sa_exc.InvalidRequestError( - "Cannot exclude or override the discriminator column %r" % - col.key) + setter = False + + if self.polymorphic_on is not None: + setter = True + + if self.polymorphic_on not in self._columntoproperty: + col = self.mapped_table.corresponding_column(self.polymorphic_on) + if col is None: + setter = False + instrument = False + col = self.polymorphic_on + if self.with_polymorphic is None \ + or self.with_polymorphic[1].corresponding_column(col) \ + is None: + util.warn("Could not map polymorphic_on column " + "'%s' to the mapped table - polymorphic " + "loads will not function properly" + % col.description) + else: + instrument = True + + if self._should_exclude(col.key, col.key, False, col): + raise sa_exc.InvalidRequestError( + "Cannot exclude or override the discriminator column %r" % + col.key) - self._configure_property( - col.key, - properties.ColumnProperty(col, _instrument=instrument), - init=False, setparent=True) + self._configure_property( + col.key, + properties.ColumnProperty(col, _instrument=instrument), + init=False, setparent=True) + polymorphic_key = col.key + else: + polymorphic_key = self._columntoproperty[self.polymorphic_on].key + + if setter: + def _set_polymorphic_identity(state): + dict_ = state.dict + state.get_impl(polymorphic_key).set(state, dict_, + self.polymorphic_identity, None) + + self._set_polymorphic_identity = _set_polymorphic_identity + else: + self._set_polymorphic_identity = None + + def _adapt_inherited_property(self, key, prop, init): if not self.concrete: @@ -1656,13 +1689,6 @@ class Mapper(object): if col is mapper.version_id_col: params[col.key] = \ mapper.version_id_generator(None) - elif mapper.polymorphic_on is not None and \ - mapper.polymorphic_on.shares_lineage(col): - value = mapper.polymorphic_identity - if ((col.default is None and - col.server_default is None) or - value is not None): - params[col.key] = value elif col in pks: value = \ mapper._get_state_attr_by_column( @@ -1715,10 +1741,6 @@ class Mapper(object): passive=True) if history.added: hasdata = True - elif mapper.polymorphic_on is not None and \ - mapper.polymorphic_on.shares_lineage(col) and \ - col not in pks: - pass else: prop = mapper._columntoproperty[col] history = attributes.get_state_history( @@ -1890,11 +1912,6 @@ class Mapper(object): postfetch_cols = resultproxy.postfetch_cols() generated_cols = list(resultproxy.prefetch_cols()) - if self.polymorphic_on is not None: - po = table.corresponding_column(self.polymorphic_on) - if po is not None: - generated_cols.append(po) - if self.version_id_col is not None: generated_cols.append(self.version_id_col) @@ -2390,6 +2407,9 @@ def _event_on_init(state, args, kwargs): if instrumenting_mapper: # compile() always compiles all mappers instrumenting_mapper.compile() + + if instrumenting_mapper._set_polymorphic_identity: + instrumenting_mapper._set_polymorphic_identity(state) def _event_on_resurrect(state): # re-populate the primary key elements diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index 3e4d50d96b..896c7618f0 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -201,6 +201,75 @@ class PolymorphicSynonymTest(_base.MappedTest): eq_(sess.query(T2).filter(T2.info=='at2').one(), at2) eq_(at2.info, "THE INFO IS:at2") +class PolymorphicAttributeManagementTest(_base.MappedTest): + """Test polymorphic_on can be assigned, can be mirrored, etc.""" + + run_setup_mappers = 'once' + + @classmethod + def define_tables(cls, metadata): + Table('table_a', metadata, + Column('id', Integer, primary_key=True), + Column('class_name', String(50)) + ) + Table('table_b', metadata, + Column('id', Integer, ForeignKey('table_a.id'), primary_key=True), + Column('class_name', String(50)) + ) + Table('table_c', metadata, + Column('id', Integer, ForeignKey('table_b.id'),primary_key=True) + ) + + @classmethod + @testing.resolve_artifact_names + def setup_classes(cls): + class A(_base.ComparableEntity): + pass + class B(A): + pass + class C(B): + pass + + mapper(A, table_a, + polymorphic_on=table_a.c.class_name, + polymorphic_identity='a') + mapper(B, table_b, inherits=A, + polymorphic_on=table_b.c.class_name, + polymorphic_identity='b') + mapper(C, table_c, inherits=B, + polymorphic_identity='c') + + @testing.resolve_artifact_names + def test_poly_configured_immediate(self): + a = A() + b = B() + c = C() + eq_(a.class_name, 'a') + eq_(b.class_name, 'b') + eq_(c.class_name, 'c') + + @testing.resolve_artifact_names + def test_base_class(self): + sess = Session() + c1 = C() + sess.add(c1) + sess.commit() + + assert isinstance(sess.query(B).first(), C) + + sess.close() + + assert isinstance(sess.query(A).first(), C) + + @testing.resolve_artifact_names + def test_assignment(self): + sess = Session() + b1 = B() + b1.class_name = 'c' + sess.add(b1) + sess.commit() + sess.close() + assert isinstance(sess.query(B).first(), C) class CascadeTest(_base.MappedTest): """that cascades on polymorphic relationships continue