From: Mike Bayer Date: Sun, 17 May 2009 22:20:28 +0000 (+0000) Subject: - The "polymorphic discriminator" column may be part of a X-Git-Tag: rel_0_5_4~2 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=eb30cb1febce323d3647527d8a63a8267c943832;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - The "polymorphic discriminator" column may be part of a primary key, and it will be populated with the correct discriminator value. [ticket:1300] --- diff --git a/CHANGES b/CHANGES index 5e0bc71e57..ae4fdcadf6 100644 --- a/CHANGES +++ b/CHANGES @@ -64,6 +64,10 @@ CHANGES mixin version of the AppenderQuery, which allows subclassing the AppenderMixin. + - The "polymorphic discriminator" column may be part of a + primary key, and it will be populated with the correct + discriminator value. [ticket:1300] + - Fixed the evaluator not being able to evaluate IS NULL clauses. - Fixed the "set collection" function on "dynamic" relations to diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 87c4c8100f..b84f0166a4 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1297,10 +1297,6 @@ class Mapper(object): for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: params[col.key] = 1 - elif col in pks: - value = mapper._get_state_attr_by_column(state, col) - if value is not None: - params[col.key] = value elif mapper.polymorphic_on and mapper.polymorphic_on.shares_lineage(col): if self._should_log_debug: self._log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key)) @@ -1309,6 +1305,10 @@ class Mapper(object): col.server_default is None) or value is not None): params[col.key] = value + elif col in pks: + value = mapper._get_state_attr_by_column(state, col) + if value is not None: + params[col.key] = value else: value = mapper._get_state_attr_by_column(state, col) if ((col.default is None and diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 859419022d..65c3c2135d 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1587,7 +1587,7 @@ class ColumnElement(ClauseElement, _CompareMixin): def shares_lineage(self, othercolumn): """Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``.""" - return len(self.proxy_set.intersection(othercolumn.proxy_set)) > 0 + return bool(self.proxy_set.intersection(othercolumn.proxy_set)) def _make_proxy(self, selectable, name=None): """Create a new ``ColumnElement`` representing this diff --git a/test/orm/inheritance/basic.py b/test/orm/inheritance/basic.py index ddb4fa4ba5..d7f19a2cc0 100644 --- a/test/orm/inheritance/basic.py +++ b/test/orm/inheritance/basic.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import * from sqlalchemy.orm import exc as orm_exc from testlib import * from testlib import fixtures +from orm import _base, _fixtures class O2MTest(ORMTest): """deals with inheritance and one-to-many relationships""" @@ -924,6 +925,49 @@ class OptimizedLoadTest(ORMTest): # the optimized load needs to return "None" so regular full-row loading proceeds s1 = sess.query(Base).get(s1.id) assert s1.sub == 's1sub' + +class PKDiscriminatorTest(_base.MappedTest): + def define_tables(self, metadata): + parents = Table('parents', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(60))) + + children = Table('children', metadata, + Column('id', Integer, ForeignKey('parents.id'), primary_key=True), + Column('type', Integer,primary_key=True), + Column('name', String(60))) + + @testing.resolve_artifact_names + def test_pk_as_discriminator(self): + class Parent(object): + def __init__(self, name=None): + self.name = name + + class Child(object): + def __init__(self, name=None): + self.name = name + + class A(Child): + pass + + mapper(Parent, parents, properties={ + 'children': relation(Child, backref='parent'), + }) + mapper(Child, children, polymorphic_on=children.c.type, + polymorphic_identity=1) + + mapper(A, inherits=Child, polymorphic_identity=2) + + s = create_session() + p = Parent('p1') + a = A('a1') + p.children.append(a) + s.add(p) + s.flush() + + assert a.id + assert a.type == 2 + class DeleteOrphanTest(ORMTest): def define_tables(self, metadata):