]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Can pass mapped attributes and column objects as keys
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Dec 2008 21:48:12 +0000 (21:48 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Dec 2008 21:48:12 +0000 (21:48 +0000)
to query.update({}).  [ticket:1262]

- Mapped attributes passed to the values() of an
expression level insert() or update() will use the
keys of the mapped columns, not that of the mapped
attribute.

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/orm/mapper.py
test/orm/query.py

diff --git a/CHANGES b/CHANGES
index 791833d6fa52dc349c39ac92a7ca3397a2756d25..082e9e7b2b3aba2bad7563dd67447f895586b0fc 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -64,6 +64,14 @@ CHANGES
       when using declarative.
 
     - Added ScopedSession.is_active accessor. [ticket:976]
+    
+    - Can pass mapped attributes and column objects as keys
+      to query.update({}).  [ticket:1262]
+      
+    - Mapped attributes passed to the values() of an 
+      expression level insert() or update() will use the 
+      keys of the mapped columns, not that of the mapped 
+      attribute.
       
     - Corrected problem with Query.delete() and
       Query.update() not working properly with bind 
index 876471c1326e3f29476b197e2b167259ff7dafe3..c61c8d04c829c930c683ac8d67acc1bd74d278ec 100644 (file)
@@ -1448,6 +1448,7 @@ class Query(object):
 
                 value_evaluators = {}
                 for key,value in values.items():
+                    key = expression._column_as_key(key)
                     value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value))
             except evaluator.UnevaluatableError:
                 synchronize_session = 'expire'
index 5b62e1db6d55891ba7ad5f6fefad55f607a52ef5..31fc9ae1e60a5c4773d335979c60881101a592c1 100644 (file)
@@ -653,12 +653,12 @@ class DefaultCompiler(engine.Compiled):
         if self.column_keys is None:
             parameters = {}
         else:
-            parameters = dict((getattr(key, 'key', key), None)
+            parameters = dict((sql._column_as_key(key), None)
                               for key in self.column_keys)
 
         if stmt.parameters is not None:
             for k, v in stmt.parameters.iteritems():
-                parameters.setdefault(getattr(k, 'key', k), v)
+                parameters.setdefault(sql._column_as_key(k), v)
 
         # create a list of column assignment clauses as tuples
         values = []
index b7d4965dd5e7f7694661c19d3f511731a210ac8f..07df207dd971d1b4eb74fd910a115e4e553f1779 100644 (file)
@@ -901,6 +901,13 @@ def _labeled(element):
     else:
         return element
 
+def _column_as_key(element):
+    if isinstance(element, basestring):
+        return element
+    if hasattr(element, '__clause_element__'):
+        element = element.__clause_element__()
+    return element.key
+    
 def _literal_as_text(element):
     if hasattr(element, '__clause_element__'):
         return element.__clause_element__()
@@ -3496,10 +3503,6 @@ class _UpdateBase(ClauseElement):
         return s
 
     def _process_colparams(self, parameters):
-
-        if parameters is None:
-            return None
-
         if isinstance(parameters, (list, tuple)):
             pp = {}
             for i, c in enumerate(self.table.c):
index 64a329f8420259266d60052b50b1dd22e7e7cdcf..35d6272699caf471f6ec16778165e36a578d28bf 100644 (file)
@@ -24,6 +24,21 @@ class MapperTest(_fixtures.FixtureTest):
         })
         self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers)
 
+    @testing.resolve_artifact_names
+    def test_update_attr_keys(self):
+        """test that update()/insert() use the correct key when given InstrumentedAttributes."""
+        
+        mapper(User, users, properties={
+            'foobar':users.c.name
+        })
+
+        users.insert().values({User.foobar:'name1'}).execute()
+        eq_(sa.select([User.foobar]).where(User.foobar=='name1').execute().fetchall(), [('name1',)])
+
+        users.update().values({User.foobar:User.foobar + 'foo'}).execute()
+        eq_(sa.select([User.foobar]).where(User.foobar=='name1foo').execute().fetchall(), [('name1foo',)])
+        
+
     @testing.resolve_artifact_names
     def test_prop_accessor(self):
         mapper(User, users)
index c617d860a9f5871636f66726f0d0cfe0f1ecb111..076c1c9406dd9ece32b51bd58d5acba71848d7f5 100644 (file)
@@ -42,7 +42,6 @@ class QueryTest(FixtureTest):
         mapper(Keyword, keywords)
 
         compile_mappers()
-        #class_mapper(User).add_property('addresses', relation(Address, primaryjoin=User.id==Address.user_id, order_by=Address.id, backref='user'))
 
 class UnicodeSchemaTest(QueryTest):
     keep_mappers = False
@@ -2664,6 +2663,15 @@ class UpdateDeleteTest(_base.MappedTest):
         eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
         eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
 
+        sess.query(User).filter(User.age > 29).update({User.age: User.age - 10}, synchronize_session='evaluate')
+        eq_([john.age, jack.age, jill.age, jane.age], [25,27,29,27])
+        eq_(sess.query(User.age).order_by(User.id).all(), zip([25,27,29,27]))
+
+        sess.query(User).filter(User.age > 27).update({users.c.age: User.age - 10}, synchronize_session='evaluate')
+        eq_([john.age, jack.age, jill.age, jane.age], [25,27,19,27])
+        eq_(sess.query(User.age).order_by(User.id).all(), zip([25,27,19,27]))
+
+
     @testing.resolve_artifact_names
     def test_update_with_bindparams(self):
         sess = create_session(bind=testing.db, autocommit=False)