]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug in query.update() where 'evaluate' or 'fetch'
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Oct 2010 00:55:42 +0000 (20:55 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Oct 2010 00:55:42 +0000 (20:55 -0400)
expiration would fail if the column expression key was
a class attribute with a different keyname as the
actual column name.  [ticket:1935]

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/util.py
test/orm/test_query.py

diff --git a/CHANGES b/CHANGES
index a46b74b939ca9105daa73f6f4f71dc7528b9bb36..fdf0ce75e4523512b818c22388c80ad9ebc7e18e 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -47,6 +47,11 @@ CHANGES
     itself multiple times - you get the same message for
     each attempt at usage. The misnomer "compiles" is being
     traded out for "initialize".
+
+  - Fixed bug in query.update() where 'evaluate' or 'fetch'
+    expiration would fail if the column expression key was
+    a class attribute with a different keyname as the 
+    actual column name.  [ticket:1935]
     
   - Added an assertion during flush which ensures
     that no NULL-holding identity keys were generated
index c72c1fe9d3b79d7054bc82cbd43a868ac61be695..605f391aae0b65f35cdec067f49ac565e4b93733 100644 (file)
@@ -32,7 +32,7 @@ from sqlalchemy.orm import (
 from sqlalchemy.orm.util import (
     AliasedClass, ORMAdapter, _entity_descriptor, _entity_info,
     _is_aliased_class, _is_mapped_class, _orm_columns, _orm_selectable,
-    join as orm_join,with_parent
+    join as orm_join,with_parent, _attr_as_key
     )
 
 
@@ -2204,7 +2204,7 @@ class Query(object):
 
                 value_evaluators = {}
                 for key,value in values.iteritems():
-                    key = expression._column_as_key(key)
+                    key = _attr_as_key(key)
                     value_evaluators[key] = evaluator_compiler.process(
                                         expression._literal_as_binds(value))
             except evaluator.UnevaluatableError:
@@ -2259,7 +2259,7 @@ class Query(object):
                 if identity_key in session.identity_map:
                     session.expire(
                                 session.identity_map[identity_key], 
-                                [expression._column_as_key(k) for k in values]
+                                [_attr_as_key(k) for k in values]
                                 )
 
         for ext in session.extensions:
index d68ff4473f40f420f368dd5c63c9f43b186e6c74..f79a8449fc4a5a179bcc9a78772278193163043e 100644 (file)
@@ -582,6 +582,12 @@ def _orm_selectable(entity):
     mapper, selectable, is_aliased_class = _entity_info(entity)
     return selectable
 
+def _attr_as_key(attr):
+    if hasattr(attr, 'key'):
+        return attr.key
+    else:
+        return expression._column_as_key(attr)
+
 def _is_aliased_class(entity):
     return isinstance(entity, AliasedClass)
 
index 3ba1f0fa09641ec0c0af90e480a5daff9ef6ca58..91c09be63f13e469c097a5b6387d026dba7be0da 100644 (file)
@@ -4358,7 +4358,8 @@ class UpdateDeleteTest(_base.MappedTest):
     def setup_mappers(cls):
         mapper(User, users)
         mapper(Document, documents, properties={
-            'user': relationship(User, lazy='joined', backref=backref('documents', lazy='select'))
+            'user': relationship(User, lazy='joined', 
+                        backref=backref('documents', lazy='select'))
         })
 
     @testing.resolve_artifact_names
@@ -4477,6 +4478,34 @@ class UpdateDeleteTest(_base.MappedTest):
         eq_([john.age, jack.age, jill.age, jane.age], [15,27,19,27])
         eq_(sess.query(User.age).order_by(User.id).all(), zip([15,27,19,27]))
 
+    @testing.resolve_artifact_names
+    @testing.provide_metadata
+    def test_update_attr_names(self):
+        data = Table('data', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('counter', Integer, nullable=False, default=0)
+        )
+        class Data(_base.ComparableEntity):
+            pass
+        
+        mapper(Data, data, properties={'cnt':data.c.counter})
+        metadata.create_all()
+        d1 = Data()
+        sess = Session()
+        sess.add(d1)
+        sess.commit()
+        eq_(d1.cnt, 0)
+
+        sess.query(Data).update({Data.cnt:Data.cnt + 1})
+        sess.flush()
+        
+        eq_(d1.cnt, 1)
+
+        sess.query(Data).update({Data.cnt:Data.cnt + 1}, 'fetch')
+        sess.flush()
+        
+        eq_(d1.cnt, 2)
+        sess.close()
 
     @testing.resolve_artifact_names
     def test_update_with_bindparams(self):