]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added support for version_id_col in conjunction with inheriting mappers.
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Nov 2007 00:40:56 +0000 (00:40 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Nov 2007 00:40:56 +0000 (00:40 +0000)
version_id_col is typically set on the base mapper in an inheritance
relationship where it takes effect for all inheriting mappers.
[ticket:883]
- a little rearrangement of save_obj()

CHANGES
lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/basic.py

diff --git a/CHANGES b/CHANGES
index dded89564d3decc2938d056e7213fc4d52d6ebf3..1dc797d77edd6c38729a74d822149bddabdc50e0 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -28,6 +28,11 @@ CHANGES
      i.e.: 'somename':synonym('_somename', map_column=True) will map the
      column named 'somename' to the attribute '_somename'. See the example
      in the mapper docs. [ticket:801]
+
+   - added support for version_id_col in conjunction with inheriting mappers.
+     version_id_col is typically set on the base mapper in an inheritance
+     relationship where it takes effect for all inheriting mappers. 
+     [ticket:883]
      
    - fixed endless loop issue when using lazy="dynamic" on both 
      sides of a bi-directional relationship [ticket:872]
index 570251c69cc891c6311e628a296e619a769ce1bb..52523d787677f7e676628688af37a18fbc9b0ef8 100644 (file)
@@ -338,6 +338,9 @@ class Mapper(object):
                 self._identity_class = self.inherits._identity_class
             else:
                 self._identity_class = self.class_
+            
+            if self.version_id_col is None:
+                self.version_id_col = self.inherits.version_id_col
                 
             if self.order_by is False:
                 self.order_by = self.inherits.order_by
@@ -405,12 +408,9 @@ class Mapper(object):
         # and assemble primary key columns
         for t in self.tables + [self.mapped_table]:
             self._all_tables.add(t)
-            try:
-                l = self.pks_by_table[t]
-            except KeyError:
-                l = self.pks_by_table.setdefault(t, util.OrderedSet())
-            for k in t.primary_key:
-                l.add(k)
+            if t not in self.pks_by_table:
+                self.pks_by_table[t] = util.OrderedSet()
+            self.pks_by_table[t].update(t.primary_key)
                 
         if self.primary_key_argument is not None:
             for k in self.primary_key_argument:
@@ -974,8 +974,7 @@ class Mapper(object):
             # UPDATE if so.
             mapper = object_mapper(obj)
             instance_key = mapper.identity_key_from_instance(obj)
-            is_row_switch = not postupdate and not has_identity(obj) and instance_key in uowtransaction.uow.identity_map
-            if is_row_switch:
+            if not postupdate and not has_identity(obj) and instance_key in uowtransaction.uow.identity_map:
                 existing = uowtransaction.uow.identity_map[instance_key]
                 if not uowtransaction.is_deleted(existing):
                     raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (mapperutil.instance_str(obj), str(instance_key), mapperutil.instance_str(existing)))
@@ -1003,6 +1002,7 @@ class Mapper(object):
                 mapper = object_mapper(obj)
                 if table not in mapper.tables or not mapper._has_pks(table):
                     continue
+                pks = mapper.pks_by_table[table]
                 instance_key = mapper.identity_key_from_instance(obj)
                 if self.__should_log_debug:
                     self.__log_debug("save_obj() table '%s' instance %s identity %s" % (table.name, mapperutil.instance_str(obj), str(instance_key)))
@@ -1011,46 +1011,46 @@ class Mapper(object):
                 params = {}
                 value_params = {}
                 hasdata = False
-                for col in table.columns:
-                    if col is mapper.version_id_col:
-                        if not isinsert:
-                            params[col._label] = mapper.get_attr_by_column(obj, col)
-                            params[col.key] = params[col._label] + 1
-                        else:
+
+                if isinsert:
+                    for col in table.columns:
+                        if col is mapper.version_id_col:
                             params[col.key] = 1
-                    elif col in mapper.pks_by_table[table]:
-                        # column is a primary key ?
-                        if not isinsert:
-                            # doing an UPDATE?  put primary key values as "WHERE" parameters
-                            # matching the bindparam we are creating below, i.e. "<tablename>_<colname>"
-                            params[col._label] = mapper.get_attr_by_column(obj, col)
-                        else:
-                            # doing an INSERT, primary key col ?
-                            # if the primary key values are not populated,
-                            # leave them out of the INSERT altogether, since PostGres doesn't want
-                            # them to be present for SERIAL to take effect.  A SQLEngine that uses
-                            # explicit sequences will put them back in if they are needed
+                        elif col in pks:
                             value = mapper.get_attr_by_column(obj, col)
                             if value is not None:
                                 params[col.key] = value
-                    elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
-                        if isinsert:
+                        elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
                             if self.__should_log_debug:
                                 self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key))
                             value = mapper.polymorphic_identity
                             if col.default is None or value is not None:
                                 params[col.key] = value
-                    else:
-                        # column is not a primary key ?
-                        if not isinsert:
-                            # doing an UPDATE ? get the history for the attribute, with "passive"
-                            # so as not to trigger any deferred loads.  if there is a new
-                            # value, add it to the bind parameters
-                            if post_update_cols is not None and col not in post_update_cols:
+                        else:
+                            value = mapper.get_attr_by_column(obj, col, False)
+                            if value is NO_ATTRIBUTE:
                                 continue
-                            elif is_row_switch:
-                                params[col.key] = self.get_attr_by_column(obj, col)
-                                hasdata = True
+                            if col.default is None or value is not None:
+                                if isinstance(value, sql.ClauseElement):
+                                    value_params[col] = value
+                                else:
+                                    params[col.key] = value
+                    insert.append((obj, params, mapper, connection, value_params))
+                else:
+                    for col in table.columns:
+                        if col is mapper.version_id_col:
+                            params[col._label] = mapper.get_attr_by_column(obj, col)
+                            params[col.key] = params[col._label] + 1
+                            for prop in mapper._columntoproperty.values():
+                                history = prop.get_history(obj, passive=True)
+                                if history and history.added_items():
+                                    hasdata = True
+                        elif col in pks:
+                            params[col._label] = mapper.get_attr_by_column(obj, col)
+                        elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
+                            pass
+                        else:
+                            if post_update_cols is not None and col not in post_update_cols:
                                 continue
                             prop = mapper._getpropbycolumn(col, False)
                             if prop is None:
@@ -1064,38 +1064,17 @@ class Mapper(object):
                                     else:
                                         params[col.key] = prop.get_col_value(col, a[0])
                                     hasdata = True
-                        else:
-                            # doing an INSERT, non primary key col ?
-                            # add the attribute's value to the
-                            # bind parameters, unless its None and the column has a
-                            # default.  if its None and theres no default, we still might
-                            # not want to put it in the col list but SQLIte doesnt seem to like that
-                            # if theres no columns at all
-                            value = mapper.get_attr_by_column(obj, col, False)
-                            if value is NO_ATTRIBUTE:
-                                continue
-                            if col.default is None or value is not None:
-                                # TODO: clauseelments as bind params should 
-                                # be handled by Insert/Update expression upon execute()
-                                if isinstance(value, sql.ClauseElement):
-                                    value_params[col] = value
-                                else:
-                                    params[col.key] = value
-
-                if not isinsert:
                     if hasdata:
                         # if none of the attributes changed, dont even
                         # add the row to be updated.
                         update.append((obj, params, mapper, connection, value_params))
-                else:
-                    insert.append((obj, params, mapper, connection, value_params))
 
             if update:
                 mapper = table_to_mapper[table]
                 clause = sql.and_()
                 for col in mapper.pks_by_table[table]:
                     clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True))
-                if mapper.version_id_col is not None:
+                if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
                     clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True))
                 statement = table.update(clause)
                 rows = 0
@@ -1230,7 +1209,7 @@ class Mapper(object):
                     delete.setdefault(connection, []).append(params)
                 for col in mapper.pks_by_table[table]:
                     params[col.key] = mapper.get_attr_by_column(obj, col)
-                if mapper.version_id_col is not None:
+                if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
                     params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col)
                 # testlib.pragma exempt:__hash__
                 deleted_objects.add((id(obj), obj, connection))
@@ -1246,7 +1225,7 @@ class Mapper(object):
                 clause = sql.and_()
                 for col in mapper.pks_by_table[table]:
                     clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True))
-                if mapper.version_id_col is not None:
+                if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
                     clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True))
                 statement = table.delete(clause)
                 c = connection.execute(statement, del_objects)
@@ -1259,6 +1238,7 @@ class Mapper(object):
                     mapper.extension.after_delete(mapper, connection, obj)
 
     def _has_pks(self, table):
+        # TODO: determine this beforehand
         if self.pks_by_table.get(table, None):
             for k in self.pks_by_table[table]:
                 if k not in self._columntoproperty:
index 68d7956cc5461f773f51a20112608c74447f5324..0301d01c99af69cf5b508d09b197389cc10c6674 100644 (file)
@@ -3,7 +3,7 @@ from sqlalchemy import *
 from sqlalchemy import exceptions, util
 from sqlalchemy.orm import *
 from testlib import *
-
+from testlib import fixtures
 
 class O2MTest(ORMTest):
     """deals with inheritance and one-to-many relationships"""
@@ -332,6 +332,109 @@ class FlushTest(ORMTest):
         sess.flush()
         assert user_roles.count().scalar() == 1
 
+class VersioningTest(ORMTest):
+    def define_tables(self, metadata):
+        global base, subtable, stuff
+        base = Table('base', metadata, 
+            Column('id', Integer, Sequence('version_test_seq', optional=True), primary_key=True ),
+            Column('version_id', Integer, nullable=False),
+            Column('value', String(40)),
+            Column('discriminator', Integer, nullable=False)
+        )
+        subtable = Table('subtable', metadata, 
+            Column('id', None, ForeignKey('base.id'), primary_key=True),
+            Column('subdata', String(50))
+            )
+        stuff = Table('stuff', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('parent', Integer, ForeignKey('base.id'))
+            )
+            
+    @engines.close_open_connections
+    def test_save_update(self):
+        class Base(fixtures.Base):
+            pass
+        class Sub(Base):
+            pass
+        class Stuff(Base):
+            pass
+        mapper(Stuff, stuff)
+        mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1, properties={
+            'stuff':relation(Stuff)
+        })
+        mapper(Sub, subtable, inherits=Base, polymorphic_identity=2)
+        
+        sess = create_session()
+        
+        b1 = Base(value='b1')
+        s1 = Sub(value='sub1', subdata='some subdata')
+        sess.save(b1)
+        sess.save(s1)
+        
+        sess.flush()
+        
+        sess2 = create_session()
+        s2 = sess2.query(Base).get(s1.id)
+        s2.subdata = 'sess2 subdata'
+        
+        s1.subdata = 'sess1 subdata'
+        
+        sess.flush()
+        
+        try:
+            sess2.query(Base).with_lockmode('read').get(s1.id)
+            assert False
+        except exceptions.ConcurrentModificationError, e:
+            assert True
+
+        try:
+            sess2.flush()
+            assert False
+        except exceptions.ConcurrentModificationError, e:
+            assert True
+        
+        sess2.refresh(s2)
+        assert s2.subdata == 'sess1 subdata'
+        s2.subdata = 'sess2 subdata'
+        sess2.flush()
+    
+    def test_delete(self):
+        class Base(fixtures.Base):
+            pass
+        class Sub(Base):
+            pass
+            
+        mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1)
+        mapper(Sub, subtable, inherits=Base, polymorphic_identity=2)
+        
+        sess = create_session()
+        
+        b1 = Base(value='b1')
+        s1 = Sub(value='sub1', subdata='some subdata')
+        s2 = Sub(value='sub2', subdata='some other subdata')
+        sess.save(b1)
+        sess.save(s1)
+        sess.save(s2)
+        
+        sess.flush()
+
+        sess2 = create_session()
+        s3 = sess2.query(Base).get(s1.id)
+        sess2.delete(s3)
+        sess2.flush()
+        
+        s2.subdata = 'some new subdata'
+        sess.flush()
+        
+        try:
+            s1.subdata = 'some new subdata'
+            sess.flush()
+            assert False
+        except exceptions.ConcurrentModificationError, e:
+            assert True
+        
+
+    
 class DistinctPKTest(ORMTest):
     """test the construction of mapper.primary_key when an inheriting relationship
     joins on a column other than primary key column."""