]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixed session.refresh() with instance that has custom entity_name
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Dec 2007 18:22:21 +0000 (18:22 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Dec 2007 18:22:21 +0000 (18:22 +0000)
[ticket:914]

CHANGES
lib/sqlalchemy/orm/session.py
test/orm/entity.py

diff --git a/CHANGES b/CHANGES
index 68a73eb8e9da0d4ef66ce1e59709970ef388463d..2e4c88642b3cbe9372f981d7ef5614cf9efeb0b6 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -249,6 +249,9 @@ CHANGES
    - fixed bug which could arise when using session.begin_nested() in conjunction
      with more than one level deep of enclosing session.begin() statements
 
+   - fixed session.refresh() with instance that has custom entity_name 
+     [ticket:914]
+     
 - dialects
 
    - sqlite SLDate type will not erroneously render "microseconds" portion 
index f75d5c36c0e634e7c1b5729c009590bc0ee5be0c..dd9de2cd2c4b0cecdd48d6f2212520d3354e0a06 100644 (file)
@@ -747,7 +747,7 @@ class Session(object):
 
         self._validate_persistent(instance)
 
-        if self.query(instance.__class__)._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
+        if self.query(_object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
             raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
 
     def expire(self, instance, attribute_names=None):
index ce267189f24c83f0df1d449a664232926ebae55b..5cea83e8a98651b733ae76965809989ef31a2e62 100644 (file)
@@ -83,12 +83,17 @@ class EntityTest(AssertMixin):
         assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')]
 
         ctx.current.clear()
-        u1list = ctx.current.query(User, entity_name='user1').select()
-        u2list = ctx.current.query(User, entity_name='user2').select()
+        u1list = ctx.current.query(User, entity_name='user1').all()
+        u2list = ctx.current.query(User, entity_name='user2').all()
         assert len(u1list) == len(u2list) == 1
         assert u1list[0] is not u2list[0]
         assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
 
+        u1 = ctx.current.query(User, entity_name='user1').first()
+        ctx.current.refresh(u1)
+        ctx.current.expire(u1)
+
+
     def testcascade(self):
         """same as testbasic but relies on session cascading"""
         class User(object):pass