]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support parameters in all ORM insert modes
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Apr 2023 15:56:56 +0000 (11:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 26 Apr 2023 19:45:04 +0000 (15:45 -0400)
Fixed 2.0 regression where use of :func:`_sql.bindparam()` inside of
:meth:`_dml.Insert.values` would fail to be interpreted correctly when
executing the :class:`_dml.Insert` statement using the ORM
:class:`_orm.Session`, due to the new ORM-enabled insert feature not
implementing this use case.

In addition, the bulk INSERT and UPDATE features now add these
capabilities:

* The requirement that extra parameters aren't passed when using ORM
INSERT using the "orm" dml_strategy setting is lifted.
* The requirement that additional WHERE criteria is not passed when using
ORM UPDATE using the "bulk" dml_strategy setting is lifted.  Note that
in this case, the check for expected row count is turned off.

Fixes: #9583
Change-Id: I539c18893b697caeab5a5f0195a27d4f0487e728

doc/build/changelog/unreleased_20/9583.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/elements.py
test/orm/dml/test_bulk_statements.py
test/orm/dml/test_update_delete_where.py
test/sql/test_utils.py

diff --git a/doc/build/changelog/unreleased_20/9583.rst b/doc/build/changelog/unreleased_20/9583.rst
new file mode 100644 (file)
index 0000000..81555f4
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9583, 9595
+
+    Fixed 2.0 regression where use of :func:`_sql.bindparam()` inside of
+    :meth:`_dml.Insert.values` would fail to be interpreted correctly when
+    executing the :class:`_dml.Insert` statement using the ORM
+    :class:`_orm.Session`, due to the new ORM-enabled insert feature not
+    implementing this use case.
+
+    In addition, the bulk INSERT and UPDATE features now add these
+    capabilities:
+
+    * The requirement that extra parameters aren't passed when using ORM
+      INSERT using the "orm" dml_strategy setting is lifted.
+    * The requirement that additional WHERE criteria is not passed when using
+      ORM UPDATE using the "bulk" dml_strategy setting is lifted.  Note that
+      in this case, the check for expected row count is turned off.
index 8388d398048b931b02cde95472e522ebf9474ae9..cb416d69e1e9df955aaa0c2306fe0fed9a9b763e 100644 (file)
@@ -150,6 +150,23 @@ def _bulk_insert(
 
     for table, super_mapper in mappers_to_run:
 
+        # find bindparams in the statement. For bulk, we don't really know if
+        # a key in the params applies to a different table since we are
+        # potentially inserting for multiple tables here; looking at the
+        # bindparam() is a lot more direct.   in most cases this will
+        # use _generate_cache_key() which is memoized, although in practice
+        # the ultimate statement that's executed is probably not the same
+        # object so that memoization might not matter much.
+        extra_bp_names = (
+            [
+                b.key
+                for b in use_orm_insert_stmt._get_embedded_bindparams()
+                if b.key in mappings[0]
+            ]
+            if use_orm_insert_stmt is not None
+            else ()
+        )
+
         records = (
             (
                 None,
@@ -176,6 +193,7 @@ def _bulk_insert(
                 bulk=True,
                 return_defaults=bookkeeping,
                 render_nulls=render_nulls,
+                include_bulk_keys=extra_bp_names,
             )
         )
 
@@ -218,6 +236,7 @@ def _bulk_update(
     isstates: bool,
     update_changed_only: bool,
     use_orm_update_stmt: Literal[None] = ...,
+    enable_check_rowcount: bool = True,
 ) -> None:
     ...
 
@@ -230,6 +249,7 @@ def _bulk_update(
     isstates: bool,
     update_changed_only: bool,
     use_orm_update_stmt: Optional[dml.Update] = ...,
+    enable_check_rowcount: bool = True,
 ) -> _result.Result[Any]:
     ...
 
@@ -241,6 +261,7 @@ def _bulk_update(
     isstates: bool,
     update_changed_only: bool,
     use_orm_update_stmt: Optional[dml.Update] = None,
+    enable_check_rowcount: bool = True,
 ) -> Optional[_result.Result[Any]]:
     base_mapper = mapper.base_mapper
 
@@ -272,6 +293,18 @@ def _bulk_update(
 
     connection = session_transaction.connection(base_mapper)
 
+    # find bindparams in the statement. see _bulk_insert for similar
+    # notes for the insert case
+    extra_bp_names = (
+        [
+            b.key
+            for b in use_orm_update_stmt._get_embedded_bindparams()
+            if b.key in mappings[0]
+        ]
+        if use_orm_update_stmt is not None
+        else ()
+    )
+
     for table, super_mapper in base_mapper._sorted_tables.items():
         if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
             continue
@@ -295,6 +328,7 @@ def _bulk_update(
             ),
             bulk=True,
             use_orm_update_stmt=use_orm_update_stmt,
+            include_bulk_keys=extra_bp_names,
         )
         persistence._emit_update_statements(
             base_mapper,
@@ -304,6 +338,7 @@ def _bulk_update(
             records,
             bookkeeping=False,
             use_orm_update_stmt=use_orm_update_stmt,
+            enable_check_rowcount=enable_check_rowcount,
         )
 
     if use_orm_update_stmt is not None:
@@ -588,6 +623,7 @@ class BulkUDCompileState(ORMDMLState):
         is_multitable: bool = False,
         is_update_from: bool = False,
         is_delete_using: bool = False,
+        is_executemany: bool = False,
     ) -> bool:
         raise NotImplementedError()
 
@@ -639,11 +675,6 @@ class BulkUDCompileState(ORMDMLState):
         else:
             if update_options._dml_strategy == "auto":
                 update_options += {"_dml_strategy": "bulk"}
-            elif update_options._dml_strategy == "orm":
-                raise sa_exc.InvalidRequestError(
-                    'Can\'t use "orm" ORM insert strategy with a '
-                    "separate parameter list"
-                )
 
         sync = update_options._synchronize_session
         if sync is not None:
@@ -1062,6 +1093,7 @@ class BulkUDCompileState(ORMDMLState):
                 mapper,
                 is_update_from=update_options._is_update_from,
                 is_delete_using=update_options._is_delete_using,
+                is_executemany=orm_context.is_executemany,
             )
 
             if can_use_returning is not None:
@@ -1071,6 +1103,12 @@ class BulkUDCompileState(ORMDMLState):
                         "backends where some support RETURNING and others "
                         "don't"
                     )
+            elif orm_context.is_executemany and not per_bind_result:
+                raise sa_exc.InvalidRequestError(
+                    "For synchronize_session='fetch', can't use multiple "
+                    "parameter sets in ORM mode, which this backend does not "
+                    "support with RETURNING"
+                )
             else:
                 can_use_returning = per_bind_result
 
@@ -1146,11 +1184,6 @@ class BulkORMInsert(ORMDMLState, InsertDMLState):
         else:
             if insert_options._dml_strategy == "auto":
                 insert_options += {"_dml_strategy": "bulk"}
-            elif insert_options._dml_strategy == "orm":
-                raise sa_exc.InvalidRequestError(
-                    'Can\'t use "orm" ORM insert strategy with a '
-                    "separate parameter list"
-                )
 
         if insert_options._dml_strategy != "raw":
             # for ORM object loading, like ORMContext, we have to disable
@@ -1512,12 +1545,20 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
         result: _result.Result[Any]
 
         if update_options._dml_strategy == "bulk":
-            if statement._where_criteria:
+            enable_check_rowcount = not statement._where_criteria
+
+            assert update_options._synchronize_session != "fetch"
+
+            if (
+                statement._where_criteria
+                and update_options._synchronize_session == "evaluate"
+            ):
                 raise sa_exc.InvalidRequestError(
-                    "WHERE clause with bulk ORM UPDATE not "
-                    "supported right now.   Statement may be invoked at the "
-                    "Core level using "
-                    "session.connection().execute(stmt, parameters)"
+                    "bulk synchronize of persistent objects not supported "
+                    "when using bulk update with additional WHERE "
+                    "criteria right now.  add synchronize_session=None "
+                    "execution option to bypass synchronize of persistent "
+                    "objects."
                 )
             mapper = update_options._subject_mapper
             assert mapper is not None
@@ -1532,6 +1573,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
                 isstates=False,
                 update_changed_only=False,
                 use_orm_update_stmt=statement,
+                enable_check_rowcount=enable_check_rowcount,
             )
             return cls.orm_setup_cursor_result(
                 session,
@@ -1560,6 +1602,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
         is_multitable: bool = False,
         is_update_from: bool = False,
         is_delete_using: bool = False,
+        is_executemany: bool = False,
     ) -> bool:
 
         # normal answer for "should we use RETURNING" at all.
@@ -1569,6 +1612,9 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
         if not normal_answer:
             return False
 
+        if is_executemany:
+            return dialect.update_executemany_returning
+
         # these workarounds are currently hypothetical for UPDATE,
         # unlike DELETE where they impact MariaDB
         if is_update_from:
@@ -1869,6 +1915,7 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
         is_multitable: bool = False,
         is_update_from: bool = False,
         is_delete_using: bool = False,
+        is_executemany: bool = False,
     ) -> bool:
 
         # normal answer for "should we use RETURNING" at all.
index 1af55df00b7beeea3d9b49d8f91a114dc9c69047..6fa338ced6a445c3eec6ad8f7130843a62cef3a5 100644 (file)
@@ -326,9 +326,11 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction):
 def _collect_insert_commands(
     table,
     states_to_insert,
+    *,
     bulk=False,
     return_defaults=False,
     render_nulls=False,
+    include_bulk_keys=(),
 ):
     """Identify sets of values to use in INSERT statements for a
     list of states.
@@ -401,10 +403,14 @@ def _collect_insert_commands(
                 None
             )
 
-        if bulk and mapper._set_polymorphic_identity:
-            params.setdefault(
-                mapper._polymorphic_attr_key, mapper.polymorphic_identity
-            )
+        if bulk:
+            if mapper._set_polymorphic_identity:
+                params.setdefault(
+                    mapper._polymorphic_attr_key, mapper.polymorphic_identity
+                )
+
+            if include_bulk_keys:
+                params.update((k, state_dict[k]) for k in include_bulk_keys)
 
         yield (
             state,
@@ -422,8 +428,10 @@ def _collect_update_commands(
     uowtransaction,
     table,
     states_to_update,
+    *,
     bulk=False,
     use_orm_update_stmt=None,
+    include_bulk_keys=(),
 ):
     """Identify sets of values to use in UPDATE statements for a
     list of states.
@@ -581,6 +589,9 @@ def _collect_update_commands(
                         "key value on column %s" % (table, col)
                     )
 
+        if include_bulk_keys:
+            params.update((k, state_dict[k]) for k in include_bulk_keys)
+
         if params or value_params:
             params.update(pk_params)
             yield (
@@ -712,8 +723,10 @@ def _emit_update_statements(
     mapper,
     table,
     update,
+    *,
     bookkeeping=True,
     use_orm_update_stmt=None,
+    enable_check_rowcount=True,
 ):
     """Emit UPDATE statements corresponding to value lists collected
     by _collect_update_commands()."""
@@ -847,10 +860,10 @@ def _emit_update_statements(
                         c.returned_defaults,
                     )
                 rows += c.rowcount
-                check_rowcount = assert_singlerow
+                check_rowcount = enable_check_rowcount and assert_singlerow
         else:
             if not allow_executemany:
-                check_rowcount = assert_singlerow
+                check_rowcount = enable_check_rowcount and assert_singlerow
                 for (
                     state,
                     state_dict,
@@ -883,8 +896,9 @@ def _emit_update_statements(
             else:
                 multiparams = [rec[2] for rec in records]
 
-                check_rowcount = assert_multirow or (
-                    assert_singlerow and len(multiparams) == 1
+                check_rowcount = enable_check_rowcount and (
+                    assert_multirow
+                    or (assert_singlerow and len(multiparams) == 1)
                 )
 
                 c = connection.execute(
@@ -941,6 +955,7 @@ def _emit_insert_statements(
     mapper,
     table,
     insert,
+    *,
     bookkeeping=True,
     use_orm_insert_stmt=None,
     execution_options=None,
index ff47ec79d88f40414709a86d43447ca0debb5ddf..2e32da75408a63fc0cce2d7b42b14336efe65f25 100644 (file)
@@ -502,6 +502,28 @@ class ClauseElement(
             connection, distilled_params, execution_options
         ).scalar()
 
+    def _get_embedded_bindparams(self) -> Sequence[BindParameter[Any]]:
+        """Return the list of :class:`.BindParameter` objects embedded in the
+        object.
+
+        This accomplishes the same purpose as ``visitors.traverse()`` or
+        similar would provide, however by making use of the cache key
+        it takes advantage of memoization of the key to result in fewer
+        net method calls, assuming the statement is also going to be
+        executed.
+
+        """
+
+        key = self._generate_cache_key()
+        if key is None:
+            bindparams: List[BindParameter[Any]] = []
+
+            traverse(self, {}, {"bindparam": bindparams.append})
+            return bindparams
+
+        else:
+            return key.bindparams
+
     def unique_params(
         self,
         __optionaldict: Optional[Dict[str, Any]] = None,
index 84ea7c82c9caab7482245a7e3fcbf290dda006f2..ab03b251d1276a6400c9b61a8d14fae52adac231 100644 (file)
@@ -7,6 +7,7 @@ from typing import Optional
 from typing import Set
 import uuid
 
+from sqlalchemy import bindparam
 from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
@@ -14,6 +15,7 @@ from sqlalchemy import func
 from sqlalchemy import Identity
 from sqlalchemy import insert
 from sqlalchemy import inspect
+from sqlalchemy import Integer
 from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import select
@@ -226,6 +228,310 @@ class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
 
         eq_(result.all(), [User(id=1, name="John", age=30)])
 
+    @testing.variation(
+        "use_returning", [(True, testing.requires.insert_returning), False]
+    )
+    @testing.variation("use_multiparams", [True, False])
+    @testing.variation("bindparam_in_expression", [True, False])
+    @testing.combinations(
+        "auto", "raw", "bulk", "orm", argnames="dml_strategy"
+    )
+    def test_alt_bindparam_names(
+        self,
+        use_returning,
+        decl_base,
+        use_multiparams,
+        dml_strategy,
+        bindparam_in_expression,
+    ):
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+
+            x: Mapped[int]
+            y: Mapped[int]
+
+        decl_base.metadata.create_all(testing.db)
+
+        s = fixture_session()
+
+        if bindparam_in_expression:
+            stmt = insert(A).values(y=literal(3) * (bindparam("q") + 15))
+        else:
+            stmt = insert(A).values(y=bindparam("q"))
+
+        if dml_strategy != "auto":
+            # it really should work with any strategy
+            stmt = stmt.execution_options(dml_strategy=dml_strategy)
+
+        if use_returning:
+            stmt = stmt.returning(A.x, A.y)
+
+        if use_multiparams:
+            if bindparam_in_expression:
+                expected_qs = [60, 69, 81]
+            else:
+                expected_qs = [5, 8, 12]
+
+            result = s.execute(
+                stmt,
+                [
+                    {"q": 5, "x": 10},
+                    {"q": 8, "x": 11},
+                    {"q": 12, "x": 12},
+                ],
+            )
+        else:
+            if bindparam_in_expression:
+                expected_qs = [60]
+            else:
+                expected_qs = [5]
+
+            result = s.execute(stmt, {"q": 5, "x": 10})
+        if use_returning:
+            if use_multiparams:
+                eq_(
+                    result.all(),
+                    [
+                        (10, expected_qs[0]),
+                        (11, expected_qs[1]),
+                        (12, expected_qs[2]),
+                    ],
+                )
+            else:
+                eq_(result.first(), (10, expected_qs[0]))
+
+
+class UpdateStmtTest(fixtures.TestBase):
+    __backend__ = True
+
+    @testing.variation(
+        "returning_executemany",
+        [
+            ("returning", testing.requires.update_returning),
+            "executemany",
+            "plain",
+        ],
+    )
+    @testing.variation("bindparam_in_expression", [True, False])
+    # TODO: setting "bulk" here is all over the place as well, UPDATE is not
+    # too settled
+    @testing.combinations("auto", "orm", argnames="dml_strategy")
+    @testing.combinations(
+        "evaluate", "fetch", None, argnames="synchronize_strategy"
+    )
+    def test_alt_bindparam_names(
+        self,
+        decl_base,
+        returning_executemany,
+        dml_strategy,
+        bindparam_in_expression,
+        synchronize_strategy,
+    ):
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=False
+            )
+
+            x: Mapped[int]
+            y: Mapped[int]
+
+        decl_base.metadata.create_all(testing.db)
+
+        s = fixture_session()
+
+        s.add_all(
+            [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)],
+        )
+        s.commit()
+
+        if bindparam_in_expression:
+            stmt = (
+                update(A)
+                .values(y=literal(3) * (bindparam("q") + 15))
+                .where(A.id == bindparam("b_id"))
+            )
+        else:
+            stmt = (
+                update(A)
+                .values(y=bindparam("q"))
+                .where(A.id == bindparam("b_id"))
+            )
+
+        if dml_strategy != "auto":
+            # it really should work with any strategy
+            stmt = stmt.execution_options(dml_strategy=dml_strategy)
+
+        if returning_executemany.returning:
+            stmt = stmt.returning(A.x, A.y)
+
+        if synchronize_strategy in (None, "evaluate", "fetch"):
+            stmt = stmt.execution_options(
+                synchronize_session=synchronize_strategy
+            )
+
+        if returning_executemany.executemany:
+            if bindparam_in_expression:
+                expected_qs = [60, 69, 81]
+            else:
+                expected_qs = [5, 8, 12]
+
+            if dml_strategy != "orm":
+                params = [
+                    {"id": 1, "b_id": 1, "q": 5, "x": 10},
+                    {"id": 2, "b_id": 2, "q": 8, "x": 11},
+                    {"id": 3, "b_id": 3, "q": 12, "x": 12},
+                ]
+            else:
+                params = [
+                    {"b_id": 1, "q": 5, "x": 10},
+                    {"b_id": 2, "q": 8, "x": 11},
+                    {"b_id": 3, "q": 12, "x": 12},
+                ]
+
+            _expect_raises = None
+
+            if synchronize_strategy == "fetch":
+                if dml_strategy != "orm":
+                    _expect_raises = expect_raises_message(
+                        exc.InvalidRequestError,
+                        r"The 'fetch' synchronization strategy is not "
+                        r"available for 'bulk' ORM updates "
+                        r"\(i.e. multiple parameter sets\)",
+                    )
+                elif not testing.db.dialect.update_executemany_returning:
+                    # no backend supports this except Oracle
+                    _expect_raises = expect_raises_message(
+                        exc.InvalidRequestError,
+                        r"For synchronize_session='fetch', can't use multiple "
+                        r"parameter sets in ORM mode, which this backend does "
+                        r"not support with RETURNING",
+                    )
+
+            elif synchronize_strategy == "evaluate" and dml_strategy != "orm":
+                _expect_raises = expect_raises_message(
+                    exc.InvalidRequestError,
+                    "bulk synchronize of persistent objects not supported",
+                )
+
+            if _expect_raises:
+                with _expect_raises:
+                    result = s.execute(stmt, params)
+                return
+
+            result = s.execute(stmt, params)
+        else:
+            if bindparam_in_expression:
+                expected_qs = [60]
+            else:
+                expected_qs = [5]
+
+            result = s.execute(stmt, {"b_id": 1, "q": 5, "x": 10})
+
+        if returning_executemany.returning:
+            eq_(result.first(), (10, expected_qs[0]))
+
+        elif returning_executemany.executemany:
+            eq_(
+                s.execute(select(A.x, A.y).order_by(A.id)).all(),
+                [
+                    (10, expected_qs[0]),
+                    (11, expected_qs[1]),
+                    (12, expected_qs[2]),
+                ],
+            )
+
+    def test_bulk_update_w_where_one(self, decl_base):
+        """test use case in #9595"""
+
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=False
+            )
+
+            x: Mapped[int]
+            y: Mapped[int]
+
+        decl_base.metadata.create_all(testing.db)
+
+        s = fixture_session()
+
+        s.add_all(
+            [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)],
+        )
+        s.commit()
+
+        stmt = (
+            update(A)
+            .where(A.x > 1)
+            .execution_options(synchronize_session=None)
+        )
+
+        s.execute(
+            stmt,
+            [
+                {"id": 1, "x": 3, "y": 8},
+                {"id": 2, "x": 5, "y": 9},
+                {"id": 3, "x": 12, "y": 15},
+            ],
+        )
+
+        eq_(
+            s.execute(select(A.id, A.x, A.y).order_by(A.id)).all(),
+            [(1, 1, 1), (2, 5, 9), (3, 12, 15)],
+        )
+
+    def test_bulk_update_w_where_two(self, decl_base):
+        class User(decl_base):
+            __tablename__ = "user"
+
+            id: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=False
+            )
+            name: Mapped[str]
+            age: Mapped[int]
+
+        decl_base.metadata.create_all(testing.db)
+
+        sess = fixture_session()
+        sess.execute(
+            insert(User),
+            [
+                dict(id=1, name="john", age=25),
+                dict(id=2, name="jack", age=47),
+                dict(id=3, name="jill", age=29),
+                dict(id=4, name="jane", age=37),
+            ],
+        )
+
+        sess.execute(
+            update(User)
+            .where(User.age > bindparam("gtage"))
+            .values(age=bindparam("dest_age"))
+            .execution_options(synchronize_session=None),
+            [
+                {"id": 1, "gtage": 28, "dest_age": 40},
+                {"id": 2, "gtage": 20, "dest_age": 45},
+            ],
+        )
+
+        eq_(
+            sess.execute(
+                select(User.id, User.name, User.age).order_by(User.id)
+            ).all(),
+            [
+                (1, "john", 25),
+                (2, "jack", 45),
+                (3, "jill", 29),
+                (4, "jane", 37),
+            ],
+        )
+
 
 class BulkDMLReturningInhTest:
     use_sentinel = False
@@ -965,7 +1271,10 @@ class BulkDMLReturningInhTest:
 
         eq_(coll(ids), coll(actual_ids))
 
-    @testing.variation("insert_strategy", ["orm", "bulk", "bulk_ordered"])
+    @testing.variation(
+        "insert_strategy",
+        ["orm", "bulk", "bulk_ordered", "bulk_w_embedded_bindparam"],
+    )
     @testing.requires.provisioned_upsert
     def test_base_class_upsert(self, insert_strategy):
         """upsert is really tricky.   if you dont have any data updated,
@@ -1036,6 +1345,15 @@ class BulkDMLReturningInhTest:
                 sort_by_parameter_order=insert_strategy.bulk_ordered
             ):
                 result = s.scalars(stmt, upsert_data)
+        elif insert_strategy.bulk_w_embedded_bindparam:
+            # test related to #9583, specific user case in
+            # https://github.com/sqlalchemy/sqlalchemy/discussions/9581#discussioncomment-5504077  # noqa: E501
+            stmt = stmt.values(
+                y=select(bindparam("qq1", type_=Integer)).scalar_subquery()
+            )
+            for d in upsert_data:
+                d["qq1"] = d.pop("y")
+            result = s.scalars(stmt, upsert_data)
         else:
             insert_strategy.fail()
 
index 19e557fd91e96f24a4fe0cb24cc470226b5ce5bf..e45d92659b1c9e55fb5f23823748202632f604ca 100644 (file)
@@ -1,4 +1,3 @@
-from sqlalchemy import bindparam
 from sqlalchemy import Boolean
 from sqlalchemy import case
 from sqlalchemy import column
@@ -810,20 +809,6 @@ class UpdateDeleteTest(fixtures.MappedTest):
 
         eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane])
 
-    def test_update_multirow_not_supported(self):
-        User = self.classes.User
-
-        sess = fixture_session()
-
-        with expect_raises_message(
-            exc.InvalidRequestError,
-            "WHERE clause with bulk ORM UPDATE not supported " "right now.",
-        ):
-            sess.execute(
-                update(User).where(User.id == bindparam("id")),
-                [{"id": 1, "age": 27}, {"id": 2, "age": 37}],
-            )
-
     def test_delete_bulk_not_supported(self):
         User = self.classes.User
 
index 61777def54f0cc44ac6bb82d0eeeae1c266c6e9f..615995c731238015326e6cd0b05714cae12e9647 100644 (file)
@@ -1,5 +1,6 @@
 from itertools import zip_longest
 
+from sqlalchemy import bindparam
 from sqlalchemy import Column
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
@@ -7,6 +8,7 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
+from sqlalchemy import TypeDecorator
 from sqlalchemy.sql import base as sql_base
 from sqlalchemy.sql import coercions
 from sqlalchemy.sql import column
@@ -18,6 +20,8 @@ from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_not_none
 
 
 class MiscTest(fixtures.TestBase):
@@ -41,6 +45,28 @@ class MiscTest(fixtures.TestBase):
 
         eq_(set(sql_util.find_tables(subset_select)), {common})
 
+    @testing.variation("has_cache_key", [True, False])
+    def test_get_embedded_bindparams(self, has_cache_key):
+        bp = bindparam("x")
+
+        if not has_cache_key:
+
+            class NotCacheable(TypeDecorator):
+                impl = String
+                cache_ok = False
+
+            stmt = select(column("q", NotCacheable())).where(column("y") == bp)
+
+        else:
+            stmt = select(column("q")).where(column("y") == bp)
+
+        eq_(stmt._get_embedded_bindparams(), [bp])
+
+        if not has_cache_key:
+            is_(stmt._generate_cache_key(), None)
+        else:
+            is_not_none(stmt._generate_cache_key())
+
     def test_find_tables_aliases(self):
         metadata = MetaData()
         common = Table(