]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the "primary_key" argument to mapper() is propigated to the "polymorphic"
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 17 May 2007 22:33:52 +0000 (22:33 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 17 May 2007 22:33:52 +0000 (22:33 +0000)
mapper.  primary key columns in this list get normalized to that of the mapper's
local table.

CHANGES
lib/sqlalchemy/orm/mapper.py
test/orm/inheritance5.py

diff --git a/CHANGES b/CHANGES
index 332611431ae354d8e8fb902f35b1f33650b70e72..ee08b9f7f160403a26a49f881a979f063ba83ef3 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -30,6 +30,9 @@
       to be embedded into a correlated subquery [ticket:577]
     - fix to select_by(<propname>=<object instance>) -style joins in conjunction
       with many-to-many relationships, bug introduced in r2556 
+    - the "primary_key" argument to mapper() is propigated to the "polymorphic"
+      mapper.  primary key columns in this list get normalized to that of the mapper's 
+      local table.
 - mysql
     - support for column-level CHARACTER SET and COLLATE declarations,
       as well as ASCII, UNICODE, NATIONAL and BINARY shorthand.
index af9d53ac9cdce71c1e0516f1cf86e629468e0b7d..7e44d8a42eaf0234a0f867c9688b7cb8d47ba62d 100644 (file)
@@ -199,7 +199,7 @@ class Mapper(object):
         self.class_ = class_
         self.entity_name = entity_name
         self.class_key = ClassKey(class_, entity_name)
-        self.primary_key = primary_key
+        self.primary_key_argument = primary_key
         self.non_primary = non_primary
         self.order_by = order_by
         self.always_refresh = always_refresh
@@ -486,17 +486,19 @@ class Mapper(object):
 
         # determine primary key columns, either passed in, or get them from our set of tables
         self.pks_by_table = {}
-        if self.primary_key is not None:
+        if self.primary_key_argument is not None:
             # determine primary keys using user-given list of primary key columns as a guide
             #
             # TODO: this might not work very well for joined-table and/or polymorphic
             # inheritance mappers since local_table isnt taken into account nor is select_table
             # need to test custom primary key columns used with inheriting mappers
-            for k in self.primary_key:
+            for k in self.primary_key_argument:
                 self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k)
                 if k.table != self.mapped_table:
                     # associate pk cols from subtables to the "main" table
-                    self.pks_by_table.setdefault(self.mapped_table, util.OrderedSet()).add(k)
+                    corr = self.mapped_table.corresponding_column(k, raiseerr=False)
+                    if corr is not None:
+                        self.pks_by_table.setdefault(self.mapped_table, util.OrderedSet()).add(corr)
         else:
             # no user-defined primary key columns - go through all of our represented tables
             # and assemble primary key columns
@@ -515,7 +517,7 @@ class Mapper(object):
         if len(self.pks_by_table[self.mapped_table]) == 0:
             raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
         self.primary_key = self.pks_by_table[self.mapped_table]
-
+        
     def _compile_properties(self):
         """Inspect the properties dictionary sent to the Mapper's
         constructor as well as the mapped_table, and create
@@ -615,7 +617,7 @@ class Mapper(object):
                         props[key] = self.select_table.corresponding_column(prop)
                     elif (isinstance(prop, list) and sql.is_column(prop[0])):
                         props[key] = [self.select_table.corresponding_column(c) for c in prop]
-            self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on))
+            self.__surrogate_mapper = Mapper(self.class_, self.select_table, non_primary=True, properties=props, _polymorphic_map=self.polymorphic_map, polymorphic_on=self.select_table.corresponding_column(self.polymorphic_on), primary_key=self.primary_key_argument)
 
     def _compile_class(self):
         """If this mapper is to be a primary mapper (i.e. the
index 43f1fa3ca848068ea18a7241a9a54e967e90338d..1c42146356493bc72962de6daaa352b765c25685 100644 (file)
@@ -835,6 +835,54 @@ class ManyToManyPolyTest(testbase.ORMTest):
         mapper(Collection, collection_table)
         
         class_mapper(BaseItem)
+
+class CustomPKTest(testbase.ORMTest):
+    def define_tables(self, metadata):
+        global t1, t2
+        t1 = Table('t1', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('type', String(30), nullable=False),
+            Column('data', String(30)))
+        t2 = Table('t2', metadata,
+            Column('t2id', Integer, ForeignKey('t1.id'), primary_key=True),
+            Column('t2data', String(30)))
+    def test_custompk(self):
+        """test that the primary_key attribute is propigated to the polymorphic mapper"""
+        
+        class T1(object):pass
+        class T2(T1):pass
+        
+        # create a polymorphic union with the select against the base table first.
+        # with the join being second, the alias of the union will 
+        # pick up two "primary key" columns.  technically the alias should have a
+        # 2-col pk in any case but the leading select has a NULL for the "t2id" column
+        d = util.OrderedDict()
+        d['t1'] = t1.select(t1.c.type=='t1')
+        d['t2'] = t1.join(t2)
+        pjoin = polymorphic_union(d, None, 'pjoin')
+        
+        #print pjoin.original.primary_key
+        #print pjoin.primary_key
+        assert len(pjoin.primary_key) == 2
+        
+        mapper(T1, t1, polymorphic_on=t1.c.type, polymorphic_identity='t1', select_table=pjoin, primary_key=[pjoin.c.id])
+        mapper(T2, t2, inherits=T1, polymorphic_identity='t2')
+        print [str(c) for c in class_mapper(T1).primary_key]
+        ot1 = T1()
+        ot2 = T2()
+        sess = create_session()
+        sess.save(ot1)
+        sess.save(ot2)
+        sess.flush()
+        sess.clear()
+        
+        # query using get(), using only one value.  this requires the select_table mapper
+        # has the same single-col primary key.
+        assert sess.query(T1).get(ot1.id).id is ot1.id
+        
+        ot1 = sess.query(T1).get(ot1.id)
+        ot1.data = 'hi'
+        sess.flush()
         
 if __name__ == "__main__":    
     testbase.main()