From: Mike Bayer Date: Sat, 29 Dec 2007 18:22:21 +0000 (+0000) Subject: - fixed session.refresh() with instance that has custom entity_name X-Git-Tag: rel_0_4_2~13 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=b528d58d497c100c92c1600349d587f2146dd6c2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - fixed session.refresh() with instance that has custom entity_name [ticket:914] --- diff --git a/CHANGES b/CHANGES index 68a73eb8e9..2e4c88642b 100644 --- a/CHANGES +++ b/CHANGES @@ -249,6 +249,9 @@ CHANGES - fixed bug which could arise when using session.begin_nested() in conjunction with more than one level deep of enclosing session.begin() statements + - fixed session.refresh() with instance that has custom entity_name + [ticket:914] + - dialects - sqlite SLDate type will not erroneously render "microseconds" portion diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index f75d5c36c0..dd9de2cd2c 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -747,7 +747,7 @@ class Session(object): self._validate_persistent(instance) - if self.query(instance.__class__)._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None: + if self.query(_object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None: raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance)) def expire(self, instance, attribute_names=None): diff --git a/test/orm/entity.py b/test/orm/entity.py index ce267189f2..5cea83e8a9 100644 --- a/test/orm/entity.py +++ b/test/orm/entity.py @@ -83,12 +83,17 @@ class EntityTest(AssertMixin): assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')] ctx.current.clear() - u1list = ctx.current.query(User, entity_name='user1').select() - u2list = ctx.current.query(User, entity_name='user2').select() + u1list = ctx.current.query(User, entity_name='user1').all() + u2list = ctx.current.query(User, entity_name='user2').all() assert len(u1list) == len(u2list) == 1 assert u1list[0] is not u2list[0] assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 + u1 = ctx.current.query(User, entity_name='user1').first() + ctx.current.refresh(u1) + ctx.current.expire(u1) + + def testcascade(self): """same as testbasic but relies on session cascading""" class User(object):pass