]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Session.refresh() now does an equivalent expire()
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Mar 2010 21:56:02 +0000 (17:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Mar 2010 21:56:02 +0000 (17:56 -0400)
on the given instance first, so that the "refresh-expire"
cascade is propagated.   Previously, refresh() was
not affected in any way by the presence of "refresh-expire"
cascade.   This is a change in behavior versus that
of 0.6beta2, where the "lockmode" flag passed to refresh()
would cause a version check to occur.  Since the instance
is first expired, refresh() always upgrades the object
to the most recent version.

- The 'refresh-expire' cascade, when reaching a pending object,
will expunge the object if the cascade also includes
"delete-orphan", or will simply detach it otherwise.
[ticket:1754]

CHANGES
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/test/requires.py
test/orm/test_expire.py
test/orm/test_versioning.py

diff --git a/CHANGES b/CHANGES
index a7c0ef4d485688240877331fcc4383fa76c568ba..dd6211fc1d994b8a74483663958175c065853d21 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,21 @@ CHANGES
     eagerloading on the reverse many-to-one side, since 
     that loading is by definition unnecessary.  [ticket:1495]
 
+  - Session.refresh() now does an equivalent expire()
+    on the given instance first, so that the "refresh-expire"
+    cascade is propagated.   Previously, refresh() was
+    not affected in any way by the presence of "refresh-expire"
+    cascade.   This is a change in behavior versus that
+    of 0.6beta2, where the "lockmode" flag passed to refresh()
+    would cause a version check to occur.  Since the instance
+    is first expired, refresh() always upgrades the object
+    to the most recent version.
+    
+  - The 'refresh-expire' cascade, when reaching a pending object,
+    will expunge the object if the cascade also includes
+    "delete-orphan", or will simply detach it otherwise.
+    [ticket:1754]
+    
 0.6beta3
 ========
 
index a8295e2cda13bbaa1df6376721c8aa863fb01ba6..2a5e92c1a8f6a76eae9ebebb942ad2ffc733ea2b 100644 (file)
@@ -742,6 +742,8 @@ class RelationshipProperty(StrategizedProperty):
         else:
             instances = state.value_as_iterable(self.key, passive=passive)
         
+        skip_pending = type_ == 'refresh-expire' and 'delete-orphan' not in self.cascade
+        
         if instances:
             for c in instances:
                 if c is not None and \
@@ -757,12 +759,17 @@ class RelationshipProperty(StrategizedProperty):
                                                 str(self.parent.class_), 
                                                 str(c.__class__)
                                             ))
+                    instance_state = attributes.instance_state(c)
+                    
+                    if skip_pending and not instance_state.key:
+                        continue
+                        
                     visited_instances.add(c)
 
                     # cascade using the mapper local to this 
                     # object, so that its individual properties are located
-                    instance_mapper = object_mapper(c)
-                    yield (c, instance_mapper, attributes.instance_state(c))
+                    instance_mapper = instance_state.manager.mapper
+                    yield (c, instance_mapper, instance_state)
 
     def _add_reverse_property(self, key):
         other = self.mapper._get_property(key)
index 0a3fbe79e26175dc81481ed978cbbd30494db017..0810175bf868d2b82d0d450d0484a2f039f65176 100644 (file)
@@ -883,7 +883,7 @@ class Session(object):
             state.commit_all(dict_, self.identity_map)
 
     def refresh(self, instance, attribute_names=None, lockmode=None):
-        """Refresh the attributes on the given instance.
+        """Expire and refresh the attributes on the given instance.
 
         A query will be issued to the database and all attributes will be
         refreshed with their current database value.
@@ -907,7 +907,9 @@ class Session(object):
             state = attributes.instance_state(instance)
         except exc.NO_STATE:
             raise exc.UnmappedInstanceError(instance)
-        self._validate_persistent(state)
+
+        self._expire_state(state, attribute_names)
+
         if self.query(_object_mapper(instance))._get(
                 state.key, refresh_state=state,
                 lockmode=lockmode,
@@ -939,18 +941,31 @@ class Session(object):
             state = attributes.instance_state(instance)
         except exc.NO_STATE:
             raise exc.UnmappedInstanceError(instance)
+        self._expire_state(state, attribute_names)
+        
+    def _expire_state(self, state, attribute_names):
         self._validate_persistent(state)
         if attribute_names:
             _expire_state(state, state.dict, 
-                                attribute_names=attribute_names, instance_dict=self.identity_map)
+                                attribute_names=attribute_names, 
+                                instance_dict=self.identity_map)
         else:
             # pre-fetch the full cascade since the expire is going to
             # remove associations
             cascaded = list(_cascade_state_iterator('refresh-expire', state))
-            _expire_state(state, state.dict, None, instance_dict=self.identity_map)
+            self._conditional_expire(state)
             for (state, m, o) in cascaded:
-                _expire_state(state, state.dict, None, instance_dict=self.identity_map)
-
+                self._conditional_expire(state)
+        
+    def _conditional_expire(self, state):
+        """Expire a state if persistent, else expunge if pending"""
+        
+        if state.key:
+            _expire_state(state, state.dict, None, instance_dict=self.identity_map)
+        elif state in self._new:
+            self._new.pop(state)
+            state.detach()
+        
     def prune(self):
         """Remove unreferenced instances cached in the identity map.
 
index 73b2120959dd7e91954714ec6d134e141c300737..bf911c2c2201425b98d38d0057fa9d4adae00e8f 100644 (file)
@@ -149,6 +149,18 @@ def sequences(fn):
         no_support('sybase', 'no SEQUENCE support'),
         )
 
+def update_nowait(fn):
+    """Target database must support SELECT...FOR UPDATE NOWAIT"""
+    return _chain_decorators_on(
+        fn,
+        no_support('access', 'no FOR UPDATE NOWAIT support'),
+        no_support('firebird', 'no FOR UPDATE NOWAIT support'),
+        no_support('mssql', 'no FOR UPDATE NOWAIT support'),
+        no_support('mysql', 'no FOR UPDATE NOWAIT support'),
+        no_support('sqlite', 'no FOR UPDATE NOWAIT support'),
+        no_support('sybase', 'no FOR UPDATE NOWAIT support'),
+    )
+    
 def subqueries(fn):
     """Target database must support subqueries."""
     return _chain_decorators_on(
index 2fe4bb15a31cac17a6068d6f118b435d7855cf8f..0b3e09a830d8272fc66c96d3fe11c2e649b8c73f 100644 (file)
@@ -9,7 +9,7 @@ from sqlalchemy.test.schema import Table
 from sqlalchemy.test.schema import Column
 from sqlalchemy.orm import mapper, relationship, create_session, \
                         attributes, deferred, exc as orm_exc, defer, undefer,\
-                        strategies, state, lazyload
+                        strategies, state, lazyload, backref
 from test.orm import _base, _fixtures
 
 
@@ -295,10 +295,62 @@ class ExpireTest(_fixtures.FixtureTest):
 
         u.addresses[0].email_address = 'someotheraddress'
         s.expire(u)
-        u.name
-        print attributes.instance_state(u).dict
         assert u.addresses[0].email_address == 'ed@wood.com'
 
+    @testing.resolve_artifact_names
+    def test_refresh_cascade(self):
+        mapper(User, users, properties={
+            'addresses':relationship(Address, cascade="all, refresh-expire")
+        })
+        mapper(Address, addresses)
+        s = create_session()
+        u = s.query(User).get(8)
+        assert u.addresses[0].email_address == 'ed@wood.com'
+
+        u.addresses[0].email_address = 'someotheraddress'
+        s.refresh(u)
+        assert u.addresses[0].email_address == 'ed@wood.com'
+
+    def test_expire_cascade_pending_orphan(self):
+        cascade = 'save-update, refresh-expire, delete, delete-orphan'
+        self._test_cascade_to_pending(cascade, True)
+
+    def test_refresh_cascade_pending_orphan(self):
+        cascade = 'save-update, refresh-expire, delete, delete-orphan'
+        self._test_cascade_to_pending(cascade, False)
+
+    def test_expire_cascade_pending(self):
+        cascade = 'save-update, refresh-expire'
+        self._test_cascade_to_pending(cascade, True)
+
+    def test_refresh_cascade_pending(self):
+        cascade = 'save-update, refresh-expire'
+        self._test_cascade_to_pending(cascade, False)
+        
+    @testing.resolve_artifact_names
+    def _test_cascade_to_pending(self, cascade, expire_or_refresh):
+        mapper(User, users, properties={
+            'addresses':relationship(Address, cascade=cascade)
+        })
+        mapper(Address, addresses)
+        s = create_session()
+
+        u = s.query(User).get(8)
+        a = Address(email_address='foobar')
+        
+        u.addresses.append(a)
+        if expire_or_refresh:
+            s.expire(u)
+        else:
+            s.refresh(u)
+        if "delete-orphan" in cascade:
+            assert a not in s
+        else:
+            assert a in s
+        
+        assert a not in u.addresses
+        s.flush()
+
     @testing.resolve_artifact_names
     def test_expired_lazy(self):
         mapper(User, users, properties={
index f146e57b8a36bb840be95fceb76f943983d27215..07e545bd16dacf1c56ac12e070afb0c6c8d76731 100644 (file)
@@ -1,6 +1,6 @@
 import sqlalchemy as sa
 from sqlalchemy.test import engines, testing
-from sqlalchemy import Integer, String, ForeignKey, literal_column, orm
+from sqlalchemy import Integer, String, ForeignKey, literal_column, orm, exc
 from sqlalchemy.test.schema import Table, Column
 from sqlalchemy.orm import mapper, relationship, create_session, column_property, sessionmaker
 from sqlalchemy.test.testing import eq_, ne_, assert_raises, assert_raises_message
@@ -19,6 +19,7 @@ def make_uuid():
     return _uuids.pop(0)
 
 class VersioningTest(_base.MappedTest):
+    
     @classmethod
     def define_tables(cls, metadata):
         Table('version_table', metadata,
@@ -130,14 +131,8 @@ class VersioningTest(_base.MappedTest):
                 s1.query(Foo).with_lockmode('read').get, f1s1.id
             )
 
-        # load, version is wrong
-        assert_raises(
-                sa.orm.exc.ConcurrentModificationError, 
-                s1.refresh, f1s1, lockmode='read'
-            )
-
-        # reload it
-        s1.query(Foo).populate_existing().get(f1s1.id)
+        # reload it - this expires the old version first
+        s1.refresh(f1s1, lockmode='read')
         
         # now assert version OK
         s1.query(Foo).with_lockmode('read').get(f1s1.id)
@@ -145,9 +140,36 @@ class VersioningTest(_base.MappedTest):
         # assert brand new load is OK too
         s1.close()
         s1.query(Foo).with_lockmode('read').get(f1s1.id)
+
+
+    @testing.emits_warning(r'.*does not support updated rowcount')
+    @engines.close_open_connections
+    @testing.requires.update_nowait
+    @testing.resolve_artifact_names
+    def test_versioncheck_for_update(self):
+        """query.with_lockmode performs a 'version check' on an already loaded instance"""
+
+        s1 = create_session(autocommit=False)
+
+        mapper(Foo, version_table, version_id_col=version_table.c.version_id)
+        f1s1 = Foo(value='f1 value')
+        s1.add(f1s1)
+        s1.commit()
+
+        s2 = create_session(autocommit=False)
+        f1s2 = s2.query(Foo).get(f1s1.id)
+        s2.refresh(f1s2, lockmode='update')
+        f1s2.value='f1 new value'
         
+        assert_raises(
+            exc.DBAPIError,
+            s1.refresh, f1s1, lockmode='update_nowait'
+        )
+        s1.rollback()
         
-        
+        s2.commit()
+        s1.refresh(f1s1, lockmode='update_nowait')
+        assert f1s1.version_id == f1s2.version_id
 
     @testing.emits_warning(r'.*does not support updated rowcount')
     @engines.close_open_connections