]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The "polymorphic discriminator" column may be part of a
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 May 2009 22:20:28 +0000 (22:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 May 2009 22:20:28 +0000 (22:20 +0000)
primary key, and it will be populated with the correct
discriminator value.  [ticket:1300]

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/sql/expression.py
test/orm/inheritance/basic.py

diff --git a/CHANGES b/CHANGES
index 5e0bc71e57b4a67303662257f71e1173b74ae390..ae4fdcadf66d6173c2a373994d61ccfa1f741e4b 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -64,6 +64,10 @@ CHANGES
       mixin version of the AppenderQuery, which allows subclassing
       the AppenderMixin.
 
+    - The "polymorphic discriminator" column may be part of a 
+      primary key, and it will be populated with the correct 
+      discriminator value.  [ticket:1300]
+      
     - Fixed the evaluator not being able to evaluate IS NULL clauses.
 
     - Fixed the "set collection" function on "dynamic" relations to
index 87c4c8100fa6176ba591fd2f6a3f9103e8d3e299..b84f0166a422e40047cd5ce04fad0ad034eff2d2 100644 (file)
@@ -1297,10 +1297,6 @@ class Mapper(object):
                     for col in mapper._cols_by_table[table]:
                         if col is mapper.version_id_col:
                             params[col.key] = 1
-                        elif col in pks:
-                            value = mapper._get_state_attr_by_column(state, col)
-                            if value is not None:
-                                params[col.key] = value
                         elif mapper.polymorphic_on and mapper.polymorphic_on.shares_lineage(col):
                             if self._should_log_debug:
                                 self._log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key))
@@ -1309,6 +1305,10 @@ class Mapper(object):
                                  col.server_default is None) or
                                 value is not None):
                                 params[col.key] = value
+                        elif col in pks:
+                            value = mapper._get_state_attr_by_column(state, col)
+                            if value is not None:
+                                params[col.key] = value
                         else:
                             value = mapper._get_state_attr_by_column(state, col)
                             if ((col.default is None and
index 859419022de2616a5b222ea5e68175e8218ad442..65c3c2135d74afdb44d99733e2a7835954c787fb 100644 (file)
@@ -1587,7 +1587,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
     def shares_lineage(self, othercolumn):
         """Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``."""
 
-        return len(self.proxy_set.intersection(othercolumn.proxy_set)) > 0
+        return bool(self.proxy_set.intersection(othercolumn.proxy_set))
 
     def _make_proxy(self, selectable, name=None):
         """Create a new ``ColumnElement`` representing this
index ddb4fa4ba5f74f361c98983a06f9e712ccdb29c5..d7f19a2cc0c55caece6a2a93e9dfae60a96de836 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy.orm import *
 from sqlalchemy.orm import exc as orm_exc
 from testlib import *
 from testlib import fixtures
+from orm import _base, _fixtures
 
 class O2MTest(ORMTest):
     """deals with inheritance and one-to-many relationships"""
@@ -924,6 +925,49 @@ class OptimizedLoadTest(ORMTest):
         # the optimized load needs to return "None" so regular full-row loading proceeds
         s1 = sess.query(Base).get(s1.id)
         assert s1.sub == 's1sub'
+
+class PKDiscriminatorTest(_base.MappedTest):
+    def define_tables(self, metadata):
+        parents = Table('parents', metadata,
+                           Column('id', Integer, primary_key=True),
+                           Column('name', String(60)))
+                           
+        children = Table('children', metadata,
+                        Column('id', Integer, ForeignKey('parents.id'), primary_key=True),
+                        Column('type', Integer,primary_key=True),
+                        Column('name', String(60)))
+
+    @testing.resolve_artifact_names
+    def test_pk_as_discriminator(self):
+        class Parent(object):
+                def __init__(self, name=None):
+                    self.name = name
+
+        class Child(object):
+            def __init__(self, name=None):
+                self.name = name
+
+        class A(Child):
+            pass
+            
+        mapper(Parent, parents, properties={
+            'children': relation(Child, backref='parent'),
+        })
+        mapper(Child, children, polymorphic_on=children.c.type,
+            polymorphic_identity=1)
+            
+        mapper(A, inherits=Child, polymorphic_identity=2)
+
+        s = create_session()
+        p = Parent('p1')
+        a = A('a1')
+        p.children.append(a)
+        s.add(p)
+        s.flush()
+
+        assert a.id
+        assert a.type == 2
+        
         
 class DeleteOrphanTest(ORMTest):
     def define_tables(self, metadata):