]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Primary key values can now be changed on a joined-table inheritance
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Feb 2010 22:56:19 +0000 (22:56 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Feb 2010 22:56:19 +0000 (22:56 +0000)
object, and ON UPDATE CASCADE will be taken into account when
the flush happens.  Set the new "passive_updates" flag to False
on mapper() when using SQLite or MySQL/MyISAM. [ticket:1362]

- flush() now detects when a primary key column was updated by
an ON UPDATE CASCADE operation from another primary key, and
can then locate the row for a subsequent UPDATE on the new PK
value.  This occurs when a relation() is there to establish
the relationship as well as passive_updates=True.  [ticket:1671]

CHANGES
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/inheritance/test_basic.py
test/orm/test_naturalpks.py
test/orm/test_unitofwork.py

diff --git a/CHANGES b/CHANGES
index cd7eca06e08fb05ad1b3d3c31441e682062b54c6..3540b3c131ee60f7786d46698a940487f93f626f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -91,6 +91,17 @@ CHANGES
        example of how to integrate Beaker with SQLAlchemy.  See
        the notes in the "examples" note below.
   
+  - Primary key values can now be changed on a joined-table inheritance
+    object, and ON UPDATE CASCADE will be taken into account when
+    the flush happens.  Set the new "passive_updates" flag to False
+    on mapper() when using SQLite or MySQL/MyISAM. [ticket:1362]
+    
+  - flush() now detects when a primary key column was updated by
+    an ON UPDATE CASCADE operation from another primary key, and
+    can then locate the row for a subsequent UPDATE on the new PK
+    value.  This occurs when a relation() is there to establish
+    the relationship as well as passive_updates=True.  [ticket:1671]
+    
   - the "save-update" cascade will now cascade the pending *removed*
     values from a scalar or collection attribute into the new session 
     during an add() operation.  This so that the flush() operation
index 596e43353eec30a79aa5ea0805d136d3830f6df2..cc34c22e23138bf85a70bc94aa24842e1ad0d0ee 100644 (file)
@@ -358,6 +358,11 @@ def relation(argument, secondary=None, **kwargs):
       are expected and the database in use doesn't support CASCADE
       (i.e. SQLite, MySQL MyISAM tables).
 
+      Also see the passive_updates flag on ``mapper()``.
+      
+      A future SQLAlchemy release will provide a "detect" feature for
+      this flag.
+
     :param post_update:
       this indicates that the relationship should be handled by a
       second UPDATE statement after an INSERT or before a
@@ -672,6 +677,35 @@ def mapper(class_, local_table=None, *args, **params):
         instances, not their persistence.  Any number of non_primary mappers
         may be created for a particular class.
 
+      passive_updates
+        Indicates UPDATE behavior of foreign keys when a primary key changes 
+        on a joined-table inheritance or other joined table mapping.
+
+        When True, it is assumed that ON UPDATE CASCADE is configured on
+        the foreign key in the database, and that the database will
+        handle propagation of an UPDATE from a source column to
+        dependent rows.  Note that with databases which enforce
+        referential integrity (i.e. PostgreSQL, MySQL with InnoDB tables),
+        ON UPDATE CASCADE is required for this operation.  The
+        relation() will update the value of the attribute on related
+        items which are locally present in the session during a flush.
+
+        When False, it is assumed that the database does not enforce
+        referential integrity and will not be issuing its own CASCADE
+        operation for an update.  The relation() will issue the
+        appropriate UPDATE statements to the database in response to the
+        change of a referenced key, and items locally present in the
+        session during a flush will also be refreshed.
+
+        This flag should probably be set to False if primary key changes
+        are expected and the database in use doesn't support CASCADE
+        (i.e. SQLite, MySQL MyISAM tables).
+        
+        Also see the passive_updates flag on ``relation()``.
+
+        A future SQLAlchemy release will provide a "detect" feature for
+        this flag.
+
       polymorphic_on
         Used with mappers in an inheritance relationship, a ``Column`` which
         will identify the class/mapper combination to be used with a
index 6a9c80ebf5d8a9dd59d196fae3a9d867e053b54c..46dc6301a39cf10833f22ea9d467aa10ce491a32 100644 (file)
@@ -229,7 +229,7 @@ class OneToManyDP(DependencyProcessor):
                     if self._pks_changed(uowcommit, state):
                         for child in history.unchanged:
                             self._synchronize(state, child, None, False, uowcommit)
-
+                            
     def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
         if delete:
             # head object is being deleted, and we manage its list of child objects
@@ -237,7 +237,8 @@ class OneToManyDP(DependencyProcessor):
             if not self.post_update:
                 should_null_fks = not self.cascade.delete and not self.passive_deletes == 'all'
                 for state in deplist:
-                    history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
+                    history = uowcommit.get_attribute_history(
+                                                state, self.key, passive=self.passive_deletes)
                     if history:
                         for child in history.deleted:
                             if child is not None and self.hasparent(child) is False:
@@ -283,7 +284,9 @@ class OneToManyDP(DependencyProcessor):
         if clearkeys:
             sync.clear(dest, self.mapper, self.prop.synchronize_pairs)
         else:
-            sync.populate(source, self.parent, dest, self.mapper, self.prop.synchronize_pairs)
+            sync.populate(source, self.parent, dest, self.mapper, 
+                                    self.prop.synchronize_pairs, uowcommit,
+                                    self.passive_updates)
 
     def _pks_changed(self, uowcommit, state):
         return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
@@ -329,7 +332,10 @@ class DetectKeySwitch(DependencyProcessor):
                     attributes.instance_state(elem.dict[self.key]) in switchers
                 ]:
                 uowcommit.register_object(s)
-                sync.populate(attributes.instance_state(s.dict[self.key]), self.mapper, s, self.parent, self.prop.synchronize_pairs)
+                sync.populate(
+                            attributes.instance_state(s.dict[self.key]), 
+                            self.mapper, s, self.parent, self.prop.synchronize_pairs, 
+                            uowcommit, self.passive_updates)
 
     def _pks_changed(self, uowcommit, state):
         return sync.source_modified(uowcommit, state, self.mapper, self.prop.synchronize_pairs)
@@ -412,7 +418,10 @@ class ManyToOneDP(DependencyProcessor):
             sync.clear(state, self.parent, self.prop.synchronize_pairs)
         else:
             self._verify_canload(child)
-            sync.populate(child, self.mapper, state, self.parent, self.prop.synchronize_pairs)
+            sync.populate(child, self.mapper, state, 
+                            self.parent, self.prop.synchronize_pairs, uowcommit,
+                            self.passive_updates
+                            )
 
 class ManyToManyDP(DependencyProcessor):
     def register_dependencies(self, uowcommit):
@@ -517,8 +526,10 @@ class ManyToManyDP(DependencyProcessor):
             return
         self._verify_canload(child)
         
-        sync.populate_dict(state, self.parent, associationrow, self.prop.synchronize_pairs)
-        sync.populate_dict(child, self.mapper, associationrow, self.prop.secondary_synchronize_pairs)
+        sync.populate_dict(state, self.parent, associationrow, 
+                                        self.prop.synchronize_pairs)
+        sync.populate_dict(child, self.mapper, associationrow,
+                                        self.prop.secondary_synchronize_pairs)
 
     def _pks_changed(self, uowcommit, state):
         return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
index bde7fd2e1ac427530d9f655285cb1f7a5c4ca46c..941c303d777f2cda90fe28eea27c7b4b0905e7b1 100644 (file)
@@ -94,6 +94,7 @@ class Mapper(object):
                  column_prefix=None,
                  include_properties=None,
                  exclude_properties=None,
+                 passive_updates=True,
                  eager_defaults=False):
         """Construct a new mapper.
 
@@ -131,6 +132,7 @@ class Mapper(object):
         self.polymorphic_on = polymorphic_on
         self._dependency_processors = []
         self._validators = {}
+        self.passive_updates = passive_updates
         self._clause_adapter = None
         self._requires_row_aliasing = False
         self._inherits_equated_pairs = None
@@ -231,7 +233,9 @@ class Mapper(object):
                     self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition)
 
                     fks = util.to_set(self.inherit_foreign_keys)
-                    self._inherits_equated_pairs = sqlutil.criterion_as_pairs(self.mapped_table.onclause, consider_as_foreign_keys=fks)
+                    self._inherits_equated_pairs = \
+                                sqlutil.criterion_as_pairs(self.mapped_table.onclause,
+                                                            consider_as_foreign_keys=fks)
             else:
                 self.mapped_table = self.local_table
 
@@ -254,6 +258,7 @@ class Mapper(object):
             self.batch = self.inherits.batch
             self.inherits._inheriting_mappers.add(self)
             self.base_mapper = self.inherits.base_mapper
+            self.passive_updates = self.inherits.passive_updates
             self._all_tables = self.inherits._all_tables
 
             if self.polymorphic_identity is not None:
@@ -1385,7 +1390,8 @@ class Mapper(object):
                                 history = attributes.get_state_history(state, prop.key, passive=True)
                                 if history.added:
                                     hasdata = True
-                        elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col) and col not in pks:
+                        elif mapper.polymorphic_on is not None and \
+                                mapper.polymorphic_on.shares_lineage(col) and col not in pks:
                             pass
                         else:
                             if post_update_cols is not None and col not in post_update_cols:
@@ -1402,12 +1408,16 @@ class Mapper(object):
                                     params[col.key] = prop.get_col_value(col, history.added[0])
 
                                 if col in pks:
-                                    # TODO: there is one case we want to use history.added for
-                                    # the PK value - when we know that the PK has already been
-                                    # updated via CASCADE.   This information needs to get here
-                                    # somehow.  see [ticket:1671]
-                                    
-                                    if history.deleted:
+                                    # if passive_updates and sync detected this was a 
+                                    # pk->pk sync, use the new value to locate the row, 
+                                    # since the DB would already have set this
+                                    if ("pk_cascaded", state, col) in \
+                                                    uowtransaction.attributes:
+                                        params[col._label] = \
+                                                prop.get_col_value(col, history.added[0])
+                                        hasdata = True
+                                        
+                                    elif history.deleted:
                                         # PK is changing - use the old value to locate the row
                                         params[col._label] = \
                                                 prop.get_col_value(col, history.deleted[0])
@@ -1433,14 +1443,19 @@ class Mapper(object):
                 for col in mapper._pks_by_table[table]:
                     clause.clauses.append(col == sql.bindparam(col._label, type_=col.type))
 
-                if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
-                    clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type))
+                if mapper.version_id_col is not None and \
+                            table.c.contains_column(mapper.version_id_col):
+                            
+                    clause.clauses.append(mapper.version_id_col ==\
+                            sql.bindparam(mapper.version_id_col._label, type_=col.type))
 
                 statement = table.update(clause)
+                
                 rows = 0
                 for state, params, mapper, connection, value_params in update:
                     c = connection.execute(statement.values(value_params), params)
-                    mapper._postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params)
+                    mapper._postfetch(uowtransaction, connection, table, 
+                                        state, c, c.last_updated_params(), value_params)
 
                     rows += c.rowcount
 
@@ -1464,19 +1479,14 @@ class Mapper(object):
                     if primary_key is not None:
                         # set primary key attributes
                         for i, col in enumerate(mapper._pks_by_table[table]):
-                            if mapper._get_state_attr_by_column(state, col) is None and len(primary_key) > i:
+                            if mapper._get_state_attr_by_column(state, col) is None and \
+                                                                len(primary_key) > i:
                                 mapper._set_state_attr_by_column(state, col, primary_key[i])
-                    mapper._postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params)
-
-                    # synchronize newly inserted ids from one table to the next
-                    # TODO: this performs some unnecessary attribute transfers
-                    # from an attribute to itself, since the attribute is often mapped
-                    # to multiple, equivalent columns.  it also may fire off more
-                    # than needed overall.
-                    for m in mapper.iterate_to_root():
-                        if m._inherits_equated_pairs:
-                            sync.populate(state, m, state, m, m._inherits_equated_pairs)
+                                
+                    mapper._postfetch(uowtransaction, connection, table, 
+                                        state, c, c.last_inserted_params(), value_params)
 
+                        
         if not postupdate:
             for state, mapper, connection, has_identity, instance_key in tups:
 
@@ -1504,7 +1514,8 @@ class Mapper(object):
                     if 'after_update' in mapper.extension:
                         mapper.extension.after_update(mapper, connection, state.obj())
 
-    def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
+    def _postfetch(self, uowtransaction, connection, table, 
+                                state, resultproxy, params, value_params):
         """Expire attributes in need of newly persisted database state."""
 
         postfetch_cols = resultproxy.postfetch_cols()
@@ -1527,6 +1538,18 @@ class Mapper(object):
         if deferred_props:
             _expire_state(state, deferred_props)
 
+        # synchronize newly inserted ids from one table to the next
+        # TODO: this still goes a little too often.  would be nice to
+        # have definitive list of "columns that changed" here
+        cols = set(table.c)
+        for m in self.iterate_to_root():
+            if m._inherits_equated_pairs and \
+                        cols.intersection([l for l, r in m._inherits_equated_pairs]):
+                sync.populate(state, m, state, m, 
+                                                m._inherits_equated_pairs, 
+                                                uowtransaction,
+                                                self.passive_updates)
+
     def _delete_obj(self, states, uowtransaction):
         """Issue ``DELETE`` statements for a list of objects.
 
index 8826ab3aa33eed1afb7de9d3c48f76cf5959d8d5..8a30a9e623f0dbb1768a341fee5e42ae4fa40333 100644 (file)
@@ -10,7 +10,8 @@ based on join conditions.
 
 from sqlalchemy.orm import exc, util as mapperutil
 
-def populate(source, source_mapper, dest, dest_mapper, synchronize_pairs):
+def populate(source, source_mapper, dest, dest_mapper, 
+                        synchronize_pairs, uowcommit, passive_updates):
     for l, r in synchronize_pairs:
         try:
             value = source_mapper._get_state_attr_by_column(source, l)
@@ -21,6 +22,15 @@ def populate(source, source_mapper, dest, dest_mapper, synchronize_pairs):
             dest_mapper._set_state_attr_by_column(dest, r, value)
         except exc.UnmappedColumnError:
             _raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
+        
+        # technically the "r.primary_key" check isn't
+        # needed here, but we check for this condition to limit
+        # how often this logic is invoked for memory/performance
+        # reasons, since we only need this info for a primary key
+        # destination.
+        if l.primary_key and r.primary_key and \
+                    r.references(l) and passive_updates:
+            uowcommit.attributes[("pk_cascaded", dest, r)] = True
 
 def clear(dest, dest_mapper, synchronize_pairs):
     for l, r in synchronize_pairs:
index 1fffda028be90bc6481efc8836f631c155464700..d2901a49ff14d59a7bf4aeddc3a7b13fd112d4ae 100644 (file)
@@ -99,7 +99,7 @@ class UOWTransaction(object):
         self.attributes = {}
         
         self.processors = set()
-        
+    
     def get_attribute_history(self, state, key, passive=True):
         hashkey = ("history", state, key)
 
index e189159ea970461a5ae24fb0df48448b0fe3f3e8..aed7cf5efaac9aebe3df6cd693b74e2dec6f7431 100644 (file)
@@ -732,6 +732,7 @@ class DistinctPKTest(_base.MappedTest):
 
 class SyncCompileTest(_base.MappedTest):
     """test that syncrules compile properly on custom inherit conds"""
+    
     @classmethod
     def define_tables(cls, metadata):
         global _a_table, _b_table, _c_table
@@ -754,7 +755,8 @@ class SyncCompileTest(_base.MappedTest):
 
     def test_joins(self):
         for j1 in (None, _b_table.c.a_id==_a_table.c.id, _a_table.c.id==_b_table.c.a_id):
-            for j2 in (None, _b_table.c.a_id==_c_table.c.b_a_id, _c_table.c.b_a_id==_b_table.c.a_id):
+            for j2 in (None, _b_table.c.a_id==_c_table.c.b_a_id,
+                                    _c_table.c.b_a_id==_b_table.c.a_id):
                 self._do_test(j1, j2)
                 for t in reversed(_a_table.metadata.sorted_tables):
                     t.delete().execute().close()
index 277d1ef24a6f8705dec95ad8f8439758978f4651..768ffeebadb26550d6ef8c6e71b0459d9ea3bd4a 100644 (file)
@@ -587,13 +587,13 @@ class CascadeToFKPKTest(_base.MappedTest):
         class Address(_base.ComparableEntity):
             pass
     
-        
-    @testing.fails_on_everything_except('sqlite') # Ticket #1671
+    @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
     @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
     def test_onetomany_passive(self):
         self._test_onetomany(True)
 
-    @testing.fails_on_everything_except('sqlite') # Ticket #1671
+    # PG etc. need passive=True to allow PK->PK cascade
+    @testing.fails_on_everything_except('sqlite')
     def test_onetomany_nonpassive(self):
         self._test_onetomany(False)
         
@@ -657,3 +657,115 @@ class CascadeToFKPKTest(_base.MappedTest):
         eq_(a1.username, 'jack')
         eq_(a2.username, 'jack')
         eq_(sa.select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack', )])
+
+
+class JoinedInheritanceTest(_base.MappedTest):
+    """Test cascades of pk->pk/fk on joined table inh."""
+
+    @classmethod
+    def define_tables(cls, metadata):
+        if testing.against('oracle'):
+            fk_args = dict(deferrable=True, initially='deferred')
+        else:
+            fk_args = dict(onupdate='cascade')
+
+        Table('person', metadata,
+            Column('name', String(50), primary_key=True),
+            Column('type', String(50), nullable=False),
+            test_needs_fk=True)
+        
+        Table('engineer', metadata,
+            Column('name', String(50), ForeignKey('person.name', **fk_args), primary_key=True),
+            Column('primary_language', String(50)),
+            Column('boss_name', String(50), ForeignKey('manager.name', **fk_args)),
+            test_needs_fk=True
+        )
+
+        Table('manager', metadata,
+            Column('name', String(50), ForeignKey('person.name', **fk_args), primary_key=True),
+            Column('paperwork', String(50)),
+            test_needs_fk=True
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class Person(_base.ComparableEntity):
+            pass
+        class Engineer(Person):
+            pass
+        class Manager(Person):
+            pass
+
+    @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+    @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
+    def test_pk_passive(self):
+        self._test_pk(True)
+
+    # PG etc. need passive=True to allow PK->PK cascade
+    @testing.fails_on_everything_except('sqlite')
+    def test_pk_nonpassive(self):
+        self._test_pk(False)
+        
+    @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+    @testing.fails_on('oracle', 'oracle doesnt support ON UPDATE CASCADE')
+    def test_fk_passive(self):
+        self._test_fk(True)
+        
+    # PG etc. need passive=True to allow PK->PK cascade
+    @testing.fails_on_everything_except('sqlite')
+    def test_fk_nonpassive(self):
+        self._test_fk(False)
+
+    @testing.resolve_artifact_names
+    def _test_pk(self, passive_updates):
+        mapper(Person, person, polymorphic_on=person.c.type, 
+                polymorphic_identity='person', passive_updates=passive_updates)
+        mapper(Engineer, engineer, inherits=Person, polymorphic_identity='engineer', properties={
+            'boss':relation(Manager, 
+                        primaryjoin=manager.c.name==engineer.c.boss_name,
+                        passive_updates=passive_updates
+                        )
+        })
+        mapper(Manager, manager, inherits=Person, polymorphic_identity='manager')
+
+        sess = sa.orm.sessionmaker()()
+
+        e1 = Engineer(name='dilbert', primary_language='java')
+        sess.add(e1)
+        sess.commit()
+        e1.name = 'wally'
+        e1.primary_language = 'c++'
+        sess.commit()
+        
+    @testing.resolve_artifact_names
+    def _test_fk(self, passive_updates):
+        mapper(Person, person, polymorphic_on=person.c.type, 
+                polymorphic_identity='person', passive_updates=passive_updates)
+        mapper(Engineer, engineer, inherits=Person, polymorphic_identity='engineer', properties={
+            'boss':relation(Manager, 
+                        primaryjoin=manager.c.name==engineer.c.boss_name,
+                        passive_updates=passive_updates
+                        )
+        })
+        mapper(Manager, manager, inherits=Person, polymorphic_identity='manager')
+        
+        sess = sa.orm.sessionmaker()()
+        
+        m1 = Manager(name='dogbert', paperwork='lots')
+        e1, e2 = \
+                Engineer(name='dilbert', primary_language='java', boss=m1),\
+                Engineer(name='wally', primary_language='c++', boss=m1)
+        sess.add_all([
+            e1, e2, m1
+        ])
+        sess.commit()
+        
+        m1.name = 'pointy haired'
+        e1.primary_language = 'scala'
+        e2.primary_language = 'cobol'
+        sess.commit()
+        
+    
+    
+    
+    
\ No newline at end of file
index 22c5f8918fd21726f60837d78dd41a1623ba7d1f..9fed9e1859f8a225c420d18dac390b9afa5b9d56 100644 (file)
@@ -2365,6 +2365,14 @@ class InheritingRowSwitchTest(_base.MappedTest):
         self.assert_sql_execution(testing.db, sess.flush,
             CompiledSQL("UPDATE parent SET pdata=:pdata WHERE parent.id = :parent_id",
                 {'pdata':'c2', 'parent_id':1}
+            ),
+            
+            # this fires as of [ticket:1362], since we synchronize
+            # PK/FKs on UPDATES.  c2 is new so the history shows up as
+            # pure added, update occurs.  If a future change limits the
+            # sync operation during _save_obj().update, this is safe to remove again.
+            CompiledSQL("UPDATE child SET pid=:pid WHERE child.id = :child_id",
+                {'pid':1, 'child_id':1}
             )
         )