]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed the multiple-table "UPDATE..FROM" construct, only usable on
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Jan 2014 02:01:35 +0000 (21:01 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Jan 2014 02:17:42 +0000 (21:17 -0500)
MySQL, to correctly render the SET clause among multiple columns
with the same name across tables.  This also changes the name used for
the bound parameter in the SET clause to "<tablename>_<colname>" for
the non-primary table only; as this parameter is typically specified
using the :class:`.Column` object directly this should not have an
impact on applications.   The fix takes effect for both
:meth:`.Table.update` as well as :meth:`.Query.update` in the ORM.
[ticket:2912]

doc/build/changelog/changelog_09.rst
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/orm/test_update_delete.py
test/sql/test_update.py

index 0efffce6285ffe317277926d0b8aa8f280d3df52..d59f3ec604a5aaefe06dc2d48be8fe501a49bc7d 100644 (file)
         the ``CLUSTERED`` keyword to the constraint construct within DDL.
         Pullreq courtesy Derek Harland.
 
+    .. change::
+        :tags: bug, sql, orm
+        :tickets: 2912
+
+        Fixed the multiple-table "UPDATE..FROM" construct, only usable on
+        MySQL, to correctly render the SET clause among multiple columns
+        with the same name across tables.  This also changes the name used for
+        the bound parameter in the SET clause to "<tablename>_<colname>" for
+        the non-primary table only; as this parameter is typically specified
+        using the :class:`.Column` object directly this should not have an
+        impact on applications.   The fix takes effect for both
+        :meth:`.Table.update` as well as :meth:`.Query.update` in the ORM.
+
     .. change::
         :tags: bug, oracle
         :tickets: 2911
index e507885facd6f2d7569ecdb0fdf5378465f909c1..ed975b8cf3aa31213d9f2c564bff318e8b4c2777 100644 (file)
@@ -895,6 +895,8 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         and generate inserted_primary_key collection.
         """
 
+        key_getter = self.compiled._key_getters_for_crud_column[2]
+
         if self.executemany:
             if len(self.compiled.prefetch):
                 scalar_defaults = {}
@@ -918,7 +920,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                         else:
                             val = self.get_update_default(c)
                         if val is not None:
-                            param[c.key] = val
+                            param[key_getter(c)] = val
                 del self.current_parameters
         else:
             self.current_parameters = compiled_parameters = \
@@ -931,12 +933,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                     val = self.get_update_default(c)
 
                 if val is not None:
-                    compiled_parameters[c.key] = val
+                    compiled_parameters[key_getter(c)] = val
             del self.current_parameters
 
             if self.isinsert:
                 self.inserted_primary_key = [
-                                self.compiled_parameters[0].get(c.key, None)
+                                self.compiled_parameters[0].get(key_getter(c), None)
                                         for c in self.compiled.\
                                                 statement.table.primary_key
                                 ]
index 5c5bfad55136b688e00e45e7793a665a6af6ab21..4448f7c7b09cafcf2b32374464a36b3efd8d82d1 100644 (file)
@@ -28,6 +28,7 @@ from . import schema, sqltypes, operators, functions, \
 from .. import util, exc
 import decimal
 import itertools
+import operator
 
 RESERVED_WORDS = set([
     'all', 'analyse', 'analyze', 'and', 'any', 'array',
@@ -1771,7 +1772,7 @@ class SQLCompiler(Compiled):
         table_text = self.update_tables_clause(update_stmt, update_stmt.table,
                                                extra_froms, **kw)
 
-        colparams = self._get_colparams(update_stmt, extra_froms, **kw)
+        colparams = self._get_colparams(update_stmt, **kw)
 
         if update_stmt._hints:
             dialect_hints = dict([
@@ -1840,7 +1841,40 @@ class SQLCompiler(Compiled):
         bindparam._is_crud = True
         return bindparam._compiler_dispatch(self)
 
-    def _get_colparams(self, stmt, extra_tables=None, **kw):
+    @util.memoized_property
+    def _key_getters_for_crud_column(self):
+        if self.isupdate and self.statement._extra_froms:
+            # when extra tables are present, refer to the columns
+            # in those extra tables as table-qualified, including in
+            # dictionaries and when rendering bind param names.
+            # the "main" table of the statement remains unqualified,
+            # allowing the most compatibility with a non-multi-table
+            # statement.
+            _et = set(self.statement._extra_froms)
+            def _column_as_key(key):
+                str_key = elements._column_as_key(key)
+                if hasattr(key, 'table') and key.table in _et:
+                    return (key.table.name, str_key)
+                else:
+                    return str_key
+            def _getattr_col_key(col):
+                if col.table in _et:
+                    return (col.table.name, col.key)
+                else:
+                    return col.key
+            def _col_bind_name(col):
+                if col.table in _et:
+                    return "%s_%s" % (col.table.name, col.key)
+                else:
+                    return col.key
+
+        else:
+            _column_as_key = elements._column_as_key
+            _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+
+        return _column_as_key, _getattr_col_key, _col_bind_name
+
+    def _get_colparams(self, stmt, **kw):
         """create a set of tuples representing column/string pairs for use
         in an INSERT or UPDATE statement.
 
@@ -1869,12 +1903,18 @@ class SQLCompiler(Compiled):
         else:
             stmt_parameters = stmt.parameters
 
+        # getters - these are normally just column.key,
+        # but in the case of mysql multi-table update, the rules for
+        # .key must conditionally take tablename into account
+        _column_as_key, _getattr_col_key, _col_bind_name = \
+                                self._key_getters_for_crud_column
+
         # if we have statement parameters - set defaults in the
         # compiled params
         if self.column_keys is None:
             parameters = {}
         else:
-            parameters = dict((elements._column_as_key(key), REQUIRED)
+            parameters = dict((_column_as_key(key), REQUIRED)
                               for key in self.column_keys
                               if not stmt_parameters or
                               key not in stmt_parameters)
@@ -1884,7 +1924,7 @@ class SQLCompiler(Compiled):
 
         if stmt_parameters is not None:
             for k, v in stmt_parameters.items():
-                colkey = elements._column_as_key(k)
+                colkey = _column_as_key(k)
                 if colkey is not None:
                     parameters.setdefault(colkey, v)
                 else:
@@ -1892,7 +1932,9 @@ class SQLCompiler(Compiled):
                     # add it to values() in an "as-is" state,
                     # coercing right side to bound param
                     if elements._is_literal(v):
-                        v = self.process(elements.BindParameter(None, v, type_=k.type), **kw)
+                        v = self.process(
+                                elements.BindParameter(None, v, type_=k.type),
+                                **kw)
                     else:
                         v = self.process(v.self_group(), **kw)
 
@@ -1922,24 +1964,25 @@ class SQLCompiler(Compiled):
         postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
 
         check_columns = {}
+
         # special logic that only occurs for multi-table UPDATE
         # statements
-        if extra_tables and stmt_parameters:
+        if self.isupdate and stmt._extra_froms and stmt_parameters:
             normalized_params = dict(
                 (elements._clause_element_as_expr(c), param)
                 for c, param in stmt_parameters.items()
             )
-            assert self.isupdate
             affected_tables = set()
-            for t in extra_tables:
+            for t in stmt._extra_froms:
                 for c in t.c:
                     if c in normalized_params:
                         affected_tables.add(t)
-                        check_columns[c.key] = c
+                        check_columns[_getattr_col_key(c)] = c
                         value = normalized_params[c]
                         if elements._is_literal(value):
                             value = self._create_crud_bind_param(
-                                c, value, required=value is REQUIRED)
+                                c, value, required=value is REQUIRED,
+                                name=_col_bind_name(c))
                         else:
                             self.postfetch.append(c)
                             value = self.process(value.self_group(), **kw)
@@ -1954,12 +1997,18 @@ class SQLCompiler(Compiled):
                     elif c.onupdate is not None and not c.onupdate.is_sequence:
                         if c.onupdate.is_clause_element:
                             values.append(
-                                (c, self.process(c.onupdate.arg.self_group(), **kw))
+                                (c, self.process(
+                                            c.onupdate.arg.self_group(),
+                                            **kw)
+                                )
                             )
                             self.postfetch.append(c)
                         else:
                             values.append(
-                                (c, self._create_crud_bind_param(c, None))
+                                (c, self._create_crud_bind_param(
+                                        c, None, name=_col_bind_name(c)
+                                    )
+                                )
                             )
                             self.prefetch.append(c)
                     elif c.server_onupdate is not None:
@@ -1968,7 +2017,7 @@ class SQLCompiler(Compiled):
         if self.isinsert and stmt.select_names:
             # for an insert from select, we can only use names that
             # are given, so only select for those names.
-            cols = (stmt.table.c[elements._column_as_key(name)]
+            cols = (stmt.table.c[_column_as_key(name)]
                         for name in stmt.select_names)
         else:
             # iterate through all table columns to maintain
@@ -1976,14 +2025,15 @@ class SQLCompiler(Compiled):
             cols = stmt.table.columns
 
         for c in cols:
-            if c.key in parameters and c.key not in check_columns:
-                value = parameters.pop(c.key)
+            col_key = _getattr_col_key(c)
+            if col_key in parameters and col_key not in check_columns:
+                value = parameters.pop(col_key)
                 if elements._is_literal(value):
                     value = self._create_crud_bind_param(
                                     c, value, required=value is REQUIRED,
-                                    name=c.key
+                                    name=_col_bind_name(c)
                                         if not stmt._has_multi_parameters
-                                        else "%s_0" % c.key
+                                        else "%s_0" % _col_bind_name(c)
                                     )
                 else:
                     if isinstance(value, elements.BindParameter) and \
@@ -2119,12 +2169,12 @@ class SQLCompiler(Compiled):
 
         if parameters and stmt_parameters:
             check = set(parameters).intersection(
-                elements._column_as_key(k) for k in stmt.parameters
+                _column_as_key(k) for k in stmt.parameters
             ).difference(check_columns)
             if check:
                 raise exc.CompileError(
                     "Unconsumed column names: %s" %
-                    (", ".join(check))
+                    (", ".join("%s" % c for c in check))
                 )
 
         if stmt._has_multi_parameters:
@@ -2133,17 +2183,17 @@ class SQLCompiler(Compiled):
 
             values.extend(
                 [
-                        (
-                            c,
-                                self._create_crud_bind_param(
-                                        c, row[c.key],
-                                        name="%s_%d" % (c.key, i + 1)
-                                )
-                                if c.key in row else param
-                        )
-                        for (c, param) in values_0
-                    ]
-                    for i, row in enumerate(stmt.parameters[1:])
+                    (
+                        c,
+                            self._create_crud_bind_param(
+                                    c, row[c.key],
+                                    name="%s_%d" % (c.key, i + 1)
+                            )
+                            if c.key in row else param
+                    )
+                    for (c, param) in values_0
+                ]
+                for i, row in enumerate(stmt.parameters[1:])
             )
 
         return values
index 6915ac8a253082d8a6e955335a1c76b3ad3f02c4..ac94fde2faa2163e4d44662bb33ede1ccc67f326 100644 (file)
@@ -545,12 +545,14 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
     def define_tables(cls, metadata):
         Table('users', metadata,
               Column('id', Integer, primary_key=True),
+              Column('samename', String(10)),
             )
         Table('documents', metadata,
               Column('id', Integer, primary_key=True),
               Column('user_id', None, ForeignKey('users.id')),
               Column('title', String(32)),
-              Column('flag', Boolean)
+              Column('flag', Boolean),
+              Column('samename', String(10)),
         )
 
     @classmethod
@@ -659,6 +661,34 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
                 ])
         )
 
+    @testing.only_on('mysql', 'Multi table update')
+    def test_update_from_multitable_same_names(self):
+        Document = self.classes.Document
+        User = self.classes.User
+
+        s = Session()
+
+        s.query(Document).\
+            filter(User.id == Document.user_id).\
+            filter(User.id == 2).update({
+                    Document.samename: 'd_samename',
+                    User.samename: 'u_samename'
+                }
+            )
+        eq_(
+            s.query(User.id, Document.samename, User.samename).
+                filter(User.id == Document.user_id).
+                order_by(User.id).all(),
+            [
+                (1, None, None),
+                (1, None, None),
+                (2, 'd_samename', 'u_samename'),
+                (2, 'd_samename', 'u_samename'),
+                (3, None, None),
+                (3, None, None),
+            ]
+        )
+
 class ExpressionUpdateTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
@@ -786,3 +816,5 @@ class InheritTest(fixtures.DeclarativeMappedTest):
             set(s.query(Person.name, Engineer.engineer_name)),
             set([('e1', 'e1', ), ('e22', 'e55')])
         )
+
+
index a8510f3747a08bc4681c5528b21ee2c5256205d1..10306372b2078d5d958d4348034669d9c90f4661 100644 (file)
@@ -192,22 +192,6 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             'UPDATE A B C D mytable SET myid=%s, name=%s, description=%s',
             dialect=mysql.dialect())
 
-    def test_alias(self):
-        table1 = self.tables.mytable
-        talias1 = table1.alias('t1')
-
-        self.assert_compile(update(talias1, talias1.c.myid == 7),
-            'UPDATE mytable AS t1 '
-            'SET name=:name '
-            'WHERE t1.myid = :myid_1',
-            params={table1.c.name: 'fred'})
-
-        self.assert_compile(update(talias1, table1.c.myid == 7),
-            'UPDATE mytable AS t1 '
-            'SET name=:name '
-            'FROM mytable '
-            'WHERE mytable.myid = :myid_1',
-            params={table1.c.name: 'fred'})
 
     def test_update_to_expression(self):
         """test update from an expression.
@@ -268,6 +252,64 @@ class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest,
 
     run_create_tables = run_inserts = run_deletes = None
 
+    def test_alias_one(self):
+        table1 = self.tables.mytable
+        talias1 = table1.alias('t1')
+
+        # this case is nonsensical.  the UPDATE is entirely
+        # against the alias, but we name the table-bound column
+        # in values.   The behavior here isn't really defined
+        self.assert_compile(
+            update(talias1, talias1.c.myid == 7).
+                values({table1.c.name: "fred"}),
+            'UPDATE mytable AS t1 '
+            'SET name=:name '
+            'WHERE t1.myid = :myid_1')
+
+    def test_alias_two(self):
+        table1 = self.tables.mytable
+        talias1 = table1.alias('t1')
+
+        # Here, compared to
+        # test_alias_one(), here we actually have UPDATE..FROM,
+        # which is causing the "table1.c.name" param to be handled
+        # as an "extra table", hence we see the full table name rendered.
+        self.assert_compile(
+            update(talias1, table1.c.myid == 7).
+                values({table1.c.name: 'fred'}),
+            'UPDATE mytable AS t1 '
+            'SET name=:mytable_name '
+            'FROM mytable '
+            'WHERE mytable.myid = :myid_1',
+            checkparams={'mytable_name': 'fred', 'myid_1': 7},
+            )
+
+    def test_alias_two_mysql(self):
+        table1 = self.tables.mytable
+        talias1 = table1.alias('t1')
+
+        self.assert_compile(
+            update(talias1, table1.c.myid == 7).
+                values({table1.c.name: 'fred'}),
+            "UPDATE mytable AS t1, mytable SET mytable.name=%s "
+            "WHERE mytable.myid = %s",
+            checkparams={'mytable_name': 'fred', 'myid_1': 7},
+            dialect='mysql')
+
+    def test_update_from_multitable_same_name_mysql(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        self.assert_compile(
+            users.update().
+                values(name='newname').\
+                values({addresses.c.name: "new address"}).\
+                where(users.c.id == addresses.c.user_id),
+            "UPDATE users, addresses SET addresses.name=%s, "
+                "users.name=%s WHERE users.id = addresses.user_id",
+            checkparams={u'addresses_name': 'new address', 'name': 'newname'},
+            dialect='mysql'
+        )
+
     def test_render_table(self):
         users, addresses = self.tables.users, self.tables.addresses
 
@@ -455,6 +497,36 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
             (10, 'chuck')]
         self._assert_users(users, expected)
 
+    @testing.only_on('mysql', 'Multi table update')
+    def test_exec_multitable_same_name(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        values = {
+            addresses.c.name: 'ad_ed2',
+            users.c.name: 'ed2'
+        }
+
+        testing.db.execute(
+            addresses.update().
+                values(values).
+                where(users.c.id == addresses.c.user_id).
+                where(users.c.name == 'ed'))
+
+        expected = [
+            (1, 7, 'x', 'jack@bean.com'),
+            (2, 8, 'ad_ed2', 'ed@wood.com'),
+            (3, 8, 'ad_ed2', 'ed@bettyboop.com'),
+            (4, 8, 'ad_ed2', 'ed@lala.com'),
+            (5, 9, 'x', 'fred@fred.com')]
+        self._assert_addresses(addresses, expected)
+
+        expected = [
+            (7, 'jack'),
+            (8, 'ed2'),
+            (9, 'fred'),
+            (10, 'chuck')]
+        self._assert_users(users, expected)
+
     def _assert_addresses(self, addresses, expected):
         stmt = addresses.select().order_by(addresses.c.id)
         eq_(testing.db.execute(stmt).fetchall(), expected)
@@ -478,7 +550,16 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase,
               Column('id', Integer, primary_key=True,
                      test_needs_autoincrement=True),
               Column('user_id', None, ForeignKey('users.id')),
-              Column('email_address', String(50), nullable=False))
+              Column('email_address', String(50), nullable=False),
+            )
+
+        Table('foobar', metadata,
+              Column('id', Integer, primary_key=True,
+                     test_needs_autoincrement=True),
+              Column('user_id', None, ForeignKey('users.id')),
+              Column('data', String(30)),
+              Column('some_update', String(30), onupdate='im the other update')
+            )
 
     @classmethod
     def fixtures(cls):
@@ -494,6 +575,12 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase,
                 (3, 8, 'ed@bettyboop.com'),
                 (4, 9, 'fred@fred.com')
             ),
+            foobar=(
+                ('id', 'user_id', 'data'),
+                (2, 8, 'd1'),
+                (3, 8, 'd2'),
+                (4, 9, 'd3')
+            )
         )
 
     @testing.only_on('mysql', 'Multi table update')
@@ -524,6 +611,37 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase,
             (9, 'fred', 'value')]
         self._assert_users(users, expected)
 
+    @testing.only_on('mysql', 'Multi table update')
+    def test_defaults_second_table_same_name(self):
+        users, foobar = self.tables.users, self.tables.foobar
+
+        values = {
+            foobar.c.data: foobar.c.data + 'a',
+            users.c.name: 'ed2'
+        }
+
+        ret = testing.db.execute(
+            users.update().
+                values(values).
+                where(users.c.id == foobar.c.user_id).
+                where(users.c.name == 'ed'))
+
+        eq_(
+            set(ret.prefetch_cols()),
+            set([users.c.some_update, foobar.c.some_update])
+        )
+
+        expected = [
+            (2, 8, 'd1a', 'im the other update'),
+            (3, 8, 'd2a', 'im the other update'),
+            (4, 9, 'd3', None)]
+        self._assert_foobar(foobar, expected)
+
+        expected = [
+            (8, 'ed2', 'im the update'),
+            (9, 'fred', 'value')]
+        self._assert_users(users, expected)
+
     @testing.only_on('mysql', 'Multi table update')
     def test_no_defaults_second_table(self):
         users, addresses = self.tables.users, self.tables.addresses
@@ -548,6 +666,10 @@ class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase,
             (9, 'fred', 'value')]
         self._assert_users(users, expected)
 
+    def _assert_foobar(self, foobar, expected):
+        stmt = foobar.select().order_by(foobar.c.id)
+        eq_(testing.db.execute(stmt).fetchall(), expected)
+
     def _assert_addresses(self, addresses, expected):
         stmt = addresses.select().order_by(addresses.c.id)
         eq_(testing.db.execute(stmt).fetchall(), expected)