From: Mike Bayer Date: Sun, 28 Dec 2008 21:48:12 +0000 (+0000) Subject: - Can pass mapped attributes and column objects as keys X-Git-Tag: rel_0_5_0~56 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=bd23baf4ac0f9dd520120445594bd00f1b760f4b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Can pass mapped attributes and column objects as keys to query.update({}). [ticket:1262] - Mapped attributes passed to the values() of an expression level insert() or update() will use the keys of the mapped columns, not that of the mapped attribute. --- diff --git a/CHANGES b/CHANGES index 791833d6fa..082e9e7b2b 100644 --- a/CHANGES +++ b/CHANGES @@ -64,6 +64,14 @@ CHANGES when using declarative. - Added ScopedSession.is_active accessor. [ticket:976] + + - Can pass mapped attributes and column objects as keys + to query.update({}). [ticket:1262] + + - Mapped attributes passed to the values() of an + expression level insert() or update() will use the + keys of the mapped columns, not that of the mapped + attribute. - Corrected problem with Query.delete() and Query.update() not working properly with bind diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 876471c132..c61c8d04c8 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1448,6 +1448,7 @@ class Query(object): value_evaluators = {} for key,value in values.items(): + key = expression._column_as_key(key) value_evaluators[key] = evaluator_compiler.process(expression._literal_as_binds(value)) except evaluator.UnevaluatableError: synchronize_session = 'expire' diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 5b62e1db6d..31fc9ae1e6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -653,12 +653,12 @@ class DefaultCompiler(engine.Compiled): if self.column_keys is None: parameters = {} else: - parameters = dict((getattr(key, 'key', key), None) + parameters = dict((sql._column_as_key(key), None) for key in self.column_keys) if stmt.parameters is not None: for k, v in stmt.parameters.iteritems(): - parameters.setdefault(getattr(k, 'key', k), v) + parameters.setdefault(sql._column_as_key(k), v) # create a list of column assignment clauses as tuples values = [] diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index b7d4965dd5..07df207dd9 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -901,6 +901,13 @@ def _labeled(element): else: return element +def _column_as_key(element): + if isinstance(element, basestring): + return element + if hasattr(element, '__clause_element__'): + element = element.__clause_element__() + return element.key + def _literal_as_text(element): if hasattr(element, '__clause_element__'): return element.__clause_element__() @@ -3496,10 +3503,6 @@ class _UpdateBase(ClauseElement): return s def _process_colparams(self, parameters): - - if parameters is None: - return None - if isinstance(parameters, (list, tuple)): pp = {} for i, c in enumerate(self.table.c): diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 64a329f842..35d6272699 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -24,6 +24,21 @@ class MapperTest(_fixtures.FixtureTest): }) self.assertRaises(sa.exc.ArgumentError, sa.orm.compile_mappers) + @testing.resolve_artifact_names + def test_update_attr_keys(self): + """test that update()/insert() use the correct key when given InstrumentedAttributes.""" + + mapper(User, users, properties={ + 'foobar':users.c.name + }) + + users.insert().values({User.foobar:'name1'}).execute() + eq_(sa.select([User.foobar]).where(User.foobar=='name1').execute().fetchall(), [('name1',)]) + + users.update().values({User.foobar:User.foobar + 'foo'}).execute() + eq_(sa.select([User.foobar]).where(User.foobar=='name1foo').execute().fetchall(), [('name1foo',)]) + + @testing.resolve_artifact_names def test_prop_accessor(self): mapper(User, users) diff --git a/test/orm/query.py b/test/orm/query.py index c617d860a9..076c1c9406 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -42,7 +42,6 @@ class QueryTest(FixtureTest): mapper(Keyword, keywords) compile_mappers() - #class_mapper(User).add_property('addresses', relation(Address, primaryjoin=User.id==Address.user_id, order_by=Address.id, backref='user')) class UnicodeSchemaTest(QueryTest): keep_mappers = False @@ -2664,6 +2663,15 @@ class UpdateDeleteTest(_base.MappedTest): eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27]) eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27])) + sess.query(User).filter(User.age > 29).update({User.age: User.age - 10}, synchronize_session='evaluate') + eq_([john.age, jack.age, jill.age, jane.age], [25,27,29,27]) + eq_(sess.query(User.age).order_by(User.id).all(), zip([25,27,29,27])) + + sess.query(User).filter(User.age > 27).update({users.c.age: User.age - 10}, synchronize_session='evaluate') + eq_([john.age, jack.age, jill.age, jane.age], [25,27,19,27]) + eq_(sess.query(User.age).order_by(User.id).all(), zip([25,27,19,27])) + + @testing.resolve_artifact_names def test_update_with_bindparams(self): sess = create_session(bind=testing.db, autocommit=False)