]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fix to "row switch" behavior, i.e. when an INSERT/DELETE is combined into a
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 30 Oct 2007 18:04:00 +0000 (18:04 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 30 Oct 2007 18:04:00 +0000 (18:04 +0000)
  single UPDATE; many-to-many relations on the parent object update properly.
  [ticket:841]
- it's an error to session.save() an object which is already persistent
  [ticket:840]
- changed a bunch of repr(obj) calls in session.py exceptions to use mapperutil.instance_str()

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/collection.py
test/orm/eager_relations.py
test/orm/inheritance/productspec.py
test/orm/session.py
test/orm/unitofwork.py

diff --git a/CHANGES b/CHANGES
index 576aecf5d6e717885f70abf386bd5c00b70f8fdc..e44dd18aa730373f62c53957bfcce5de66b4f507 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -37,6 +37,13 @@ CHANGES
 
 - fixed clear_mappers() behavior to better clean up after itself
 
+- fix to "row switch" behavior, i.e. when an INSERT/DELETE is combined into a
+  single UPDATE; many-to-many relations on the parent object update properly. 
+  [ticket:841]
+
+- it's an error to session.save() an object which is already persistent
+  [ticket:840]
+  
 - behavior of query.options() is now fully based on paths, i.e. an option
   such as eagerload_all('x.y.z.y.x') will apply eagerloading to only
   those paths, i.e. and not 'x.y.x'; eagerload('children.children') applies
index 2f3c365156211ab674d0239912b1a3423d1fed45..94ef04300d4c3ab1b487250315f93ef1d949e276 100644 (file)
@@ -968,7 +968,7 @@ class Mapper(object):
                     raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (mapperutil.instance_str(obj), str(instance_key), mapperutil.instance_str(existing)))
                 if self.__should_log_debug:
                     self.__log_debug("detected row switch for identity %s.  will update %s, remove %s from transaction" % (instance_key, mapperutil.instance_str(obj), mapperutil.instance_str(existing)))
-                uowtransaction.unregister_object(existing)
+                uowtransaction.set_row_switch(existing)
             if has_identity(obj):
                 if obj._instance_key != instance_key:
                     raise exceptions.FlushError("Can't change the identity of instance %s in session (existing identity: %s; new identity: %s)" % (mapperutil.instance_str(obj), obj._instance_key, instance_key))
index 04aaa120954e28033ecf9da2df9d30adaca886a3..42ad30b6a2c9a278b2537dbaa4cdbd23894644a9 100644 (file)
@@ -631,6 +631,8 @@ class Session(object):
                         
         if self.bind is not None:
             return self.bind
+        elif mapper is None:
+            raise exceptions.InvalidRequestError("Could not locate any mapper associated with SQL expression")
         else:
             if isinstance(mapper, type):
                 mapper = _class_mapper(mapper)
@@ -721,7 +723,7 @@ class Session(object):
 
         self._validate_persistent(obj)
         if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None:
-            raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % repr(obj))
+            raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj))
 
     def expire(self, obj):
         """Mark the given object as expired.
@@ -753,7 +755,7 @@ class Session(object):
 
         def exp():
             if self.query(obj.__class__)._get(obj._instance_key, reload=True) is None:
-                raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % repr(obj))
+                raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj))
 
         attribute_manager.trigger_history(obj, exp)
 
@@ -954,10 +956,7 @@ class Session(object):
     
     def _save_impl(self, obj, **kwargs):
         if hasattr(obj, '_instance_key'):
-            if obj._instance_key not in self.identity_map:
-                raise exceptions.InvalidRequestError("Instance '%s' is a detached instance "
-                                                     "or is already persistent in a "
-                                                     "different Session" % repr(obj))
+            raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(obj))
         else:
             # TODO: consolidate the steps here
             attribute_manager.manage(obj)
@@ -969,7 +968,7 @@ class Session(object):
         if self._is_attached(obj) and obj not in self.deleted:
             return
         if not hasattr(obj, '_instance_key'):
-            raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % repr(obj))
+            raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(obj))
         self._attach(obj)
 
     def _register_persistent(self, obj):
@@ -983,7 +982,7 @@ class Session(object):
             if old_id is not None and old_id in _sessions:
                 raise exceptions.InvalidRequestError("Object '%s' is already attached "
                                                      "to session '%s' (this is '%s')" %
-                                                     (repr(obj), old_id, id(self)))
+                                                     (mapperutil.instance_str(obj), old_id, id(self)))
 
                 # auto-removal from the old session is disabled.  but if we decide to
                 # turn it back on, do it as below: gingerly since _sessions is a WeakValueDict
@@ -1001,7 +1000,7 @@ class Session(object):
 
     def _unattach(self, obj):
         if not self._is_attached(obj):
-            raise exceptions.InvalidRequestError("Instance '%s' not attached to this Session" % repr(obj))
+            raise exceptions.InvalidRequestError("Instance '%s' not attached to this Session" % mapperutil.instance_str(obj))
         del obj._sa_session_id
 
     def _validate_persistent(self, obj):
index 7a443b331427ba8516bc8c245b5f5e72a4e0b0cc..43f0d46d915fa3ca9615914995273ef3c0316c41 100644 (file)
@@ -304,6 +304,17 @@ class UOWTransaction(object):
 
         task.append(obj, listonly, isdelete=isdelete, **kwargs)
 
+    def set_row_switch(self, obj):
+        """mark a deleted object as a 'row switch'.
+        
+        this indicates that an INSERT statement elsewhere corresponds to this DELETE;
+        the INSERT is converted to an UPDATE and the DELETE does not occur.
+        """
+        mapper = object_mapper(obj)
+        task = self.get_task_by_mapper(mapper)
+        taskelement = task._objects[obj]
+        taskelement.isdelete = "rowswitch"
+        
     def unregister_object(self, obj):
         """remove an object from its parent UOWTask.
         
@@ -902,7 +913,7 @@ class UOWTaskElement(object):
         self.childtasks = []
         self.__isdelete = False
         self.__preprocessed = {}
-
+        
     def _get_listonly(self):
         return self.__listonly
 
index 9d5ae7ab92281d35ee093e6c8eb7d5e0a4486324..d421952b53b49bfbe0e48a97123f1e2d543b1a82 100644 (file)
@@ -1155,7 +1155,6 @@ class DictHelpersTest(ORMTest):
         collections.collection_adapter(p.children).append_with_event(
             Child('foo', 'newvalue'))
         
-        session.save(p)
         session.flush()
         session.clear()
         
@@ -1215,7 +1214,6 @@ class DictHelpersTest(ORMTest):
         collections.collection_adapter(p.children).append_with_event(
             Child('foo', '1', 'newvalue'))
         
-        session.save(p)
         session.flush()
         session.clear()
         
index 4e18dda20283cecd6c1aa2697bf08b6f5e10ad9d..8f42b5128da354beae8436033d9d17b1543e0f92 100644 (file)
@@ -250,10 +250,10 @@ class EagerTest(QueryTest):
 
         mapper(Item, items)
         mapper(Order, orders, properties={
-            'items':relation(Item, secondary=order_items, lazy=False)
+            'items':relation(Item, secondary=order_items, lazy=False, order_by=items.c.id)
         })
         mapper(User, users, properties={
-            'addresses':relation(mapper(Address, addresses), lazy=False),
+            'addresses':relation(mapper(Address, addresses), lazy=False, order_by=addresses.c.id),
             'orders':relation(Order, lazy=True)
         })
 
@@ -264,7 +264,9 @@ class EagerTest(QueryTest):
             l = q.limit(2).all()
             assert fixtures.user_all_result[:2] == l
         else:        
-            l = q.limit(2).offset(1).all()
+            l = q.limit(2).offset(1).order_by(User.id).all()
+            print fixtures.user_all_result[1:3]
+            print l
             assert fixtures.user_all_result[1:3] == l
     
     def test_distinct(self):
index 5c8c64bed47c408cb5f96c9a1d2bce5e3109bf8d..6da0b3f162fc5c3ea13def3d85a1821fcc0fa689 100644 (file)
@@ -247,7 +247,6 @@ class InheritTest(ORMTest):
         assert orig == new  == '<Assembly a1> specification=None documents=[<RasterDocument doc2>]'
 
         del a1.documents[0]
-        session.save(a1)
         session.flush()
         session.clear()
 
index 553a14c8ed85f3f2d9b771f03639ca69b675938a..6efc93d50ebbe326469db60aec1b69997ed59207 100644 (file)
@@ -396,14 +396,22 @@ class SessionTest(AssertMixin):
         
         
     @engines.close_open_connections
-    def test_update(self):
-        """test that the update() method functions and doesnet blow away changes"""
+    def test_save_update(self):
+        
         s = create_session()
-        class User(object):pass
+        class User(object):
+            pass
         mapper(User, users)
         
-        # save user
-        s.save(User())
+        user = User()
+
+        try:
+            s.update(user)
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert str(e) == "Instance 'User@%s' is not persisted" % hex(id(user))
+            
+        s.save(user)
         s.flush()
         user = s.query(User).selectone()
         s.expunge(user)
@@ -425,6 +433,12 @@ class SessionTest(AssertMixin):
         s.update(user)
         assert user in s
         assert user not in s.dirty
+        
+        try:
+            s.save(user)
+            assert False
+        except exceptions.InvalidRequestError, e:
+            assert str(e) == "Instance 'User@%s' is already persistent" % hex(id(user))
     
     def test_is_modified(self):
         s = create_session()
index 44d229985eaccf8f5f3992405ec9358752389a2e..74b08e2c8bc12ebe0e0c4fd85f90e34ae2cb2a0d 100644 (file)
@@ -1703,6 +1703,138 @@ class SaveTest3(ORMTest):
         Session.commit()
         assert t2.count().scalar() == 0
 
+class RowSwitchTest(ORMTest):
+    def define_tables(self, metadata):
+        global t1, t2, t3, t1t3
+        
+        global T1, T2, T3
+        
+        Session.remove()
+        
+        # parent
+        t1 = Table('t1', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30), nullable=False))
+
+        # onetomany
+        t2 = Table('t2', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30), nullable=False),
+            Column('t1id', Integer, ForeignKey('t1.id'),nullable=False),
+            )
+
+        # associated
+        t3 = Table('t3', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30), nullable=False),
+            )
+
+        #manytomany
+        t1t3 = Table('t1t3', metadata,
+            Column('t1id', Integer, ForeignKey('t1.id'),nullable=False),
+            Column('t3id', Integer, ForeignKey('t3.id'),nullable=False),
+        )
+        
+        class T1(fixtures.Base):
+            pass
+
+        class T2(fixtures.Base):
+            pass
+
+        class T3(fixtures.Base):
+            pass
+    
+    def tearDown(self):
+        Session.remove()
+        super(RowSwitchTest, self).tearDown()
+        
+    def test_onetomany(self):
+        mapper(T1, t1, properties={
+            't2s':relation(T2, cascade="all, delete-orphan")
+        })
+        mapper(T2, t2)
+        
+        sess = Session(autoflush=False)
+        
+        o1 = T1(data='some t1', id=1)
+        o1.t2s.append(T2(data='some t2', id=1))
+        o1.t2s.append(T2(data='some other t2', id=2))
+        
+        sess.save(o1)
+        sess.flush()
+        
+        assert list(sess.execute(t1.select(), mapper=T1)) == [(1, 'some t1')]
+        assert list(sess.execute(t2.select(), mapper=T1)) == [(1, 'some t2', 1), (2, 'some other t2', 1)]
+        
+        o2 = T1(data='some other t1', id=o1.id, t2s=[
+            T2(data='third t2', id=3),
+            T2(data='fourth t2', id=4),
+            ])
+        sess.delete(o1)
+        sess.save(o2)
+        sess.flush()
+
+        assert list(sess.execute(t1.select(), mapper=T1)) == [(1, 'some other t1')]
+        assert list(sess.execute(t2.select(), mapper=T1)) == [(3, 'third t2', 1), (4, 'fourth t2', 1)]
+
+    def test_manytomany(self):
+        mapper(T1, t1, properties={
+            't3s':relation(T3, secondary=t1t3, cascade="all, delete-orphan")
+        })
+        mapper(T3, t3)
+
+        sess = Session(autoflush=False)
+
+        o1 = T1(data='some t1', id=1)
+        o1.t3s.append(T3(data='some t3', id=1))
+        o1.t3s.append(T3(data='some other t3', id=2))
+
+        sess.save(o1)
+        sess.flush()
+
+        assert list(sess.execute(t1.select(), mapper=T1)) == [(1, 'some t1')]
+        assert list(sess.execute(t1t3.select(), mapper=T1)) == [(1,1), (1, 2)]
+        assert list(sess.execute(t3.select(), mapper=T1)) == [(1, 'some t3'), (2, 'some other t3')]
+
+        o2 = T1(data='some other t1', id=1, t3s=[
+            T3(data='third t3', id=3),
+            T3(data='fourth t3', id=4),
+            ])
+        sess.delete(o1)
+        sess.save(o2)
+        sess.flush()
+
+        assert list(sess.execute(t1.select(), mapper=T1)) == [(1, 'some other t1')]
+        assert list(sess.execute(t3.select(), mapper=T1)) == [(3, 'third t3'), (4, 'fourth t3')]
 
+    def test_manytoone(self):
+        
+        mapper(T2, t2, properties={
+            't1':relation(T1)
+        })
+        mapper(T1, t1)
+
+        sess = Session(autoflush=False)
+
+        o1 = T2(data='some t2', id=1)
+        o1.t1 = T1(data='some t1', id=1)
+
+        sess.save(o1)
+        sess.flush()
+
+        assert list(sess.execute(t1.select(), mapper=T1)) == [(1, 'some t1')]
+        assert list(sess.execute(t2.select(), mapper=T1)) == [(1, 'some t2', 1)]
+
+        o2 = T2(data='some other t2', id=1, t1=T1(data='some other t1', id=2))
+        sess.delete(o1)
+        sess.delete(o1.t1)
+        sess.save(o2)
+        sess.flush()
+
+        assert list(sess.execute(t1.select(), mapper=T1)) == [(2, 'some other t1')]
+        assert list(sess.execute(t2.select(), mapper=T1)) == [(1, 'some other t2', 2)]
+        
+        
+        
 if __name__ == "__main__":
     testbase.main()