--- /dev/null
+.. change::
+ :tags: bug, orm
+ :tickets: 9583, 9595
+
+ Fixed 2.0 regression where use of :func:`_sql.bindparam()` inside of
+ :meth:`_dml.Insert.values` would fail to be interpreted correctly when
+ executing the :class:`_dml.Insert` statement using the ORM
+ :class:`_orm.Session`, due to the new ORM-enabled insert feature not
+ implementing this use case.
+
+ In addition, the bulk INSERT and UPDATE features now add these
+ capabilities:
+
+ * The requirement that extra parameters aren't passed when using ORM
+ INSERT using the "orm" dml_strategy setting is lifted.
+ * The requirement that additional WHERE criteria is not passed when using
+ ORM UPDATE using the "bulk" dml_strategy setting is lifted. Note that
+ in this case, the check for expected row count is turned off.
for table, super_mapper in mappers_to_run:
+ # find bindparams in the statement. For bulk, we don't really know if
+ # a key in the params applies to a different table since we are
+ # potentially inserting for multiple tables here; looking at the
+ # bindparam() is a lot more direct. in most cases this will
+ # use _generate_cache_key() which is memoized, although in practice
+ # the ultimate statement that's executed is probably not the same
+ # object so that memoization might not matter much.
+ extra_bp_names = (
+ [
+ b.key
+ for b in use_orm_insert_stmt._get_embedded_bindparams()
+ if b.key in mappings[0]
+ ]
+ if use_orm_insert_stmt is not None
+ else ()
+ )
+
records = (
(
None,
bulk=True,
return_defaults=bookkeeping,
render_nulls=render_nulls,
+ include_bulk_keys=extra_bp_names,
)
)
isstates: bool,
update_changed_only: bool,
use_orm_update_stmt: Literal[None] = ...,
+ enable_check_rowcount: bool = True,
) -> None:
...
isstates: bool,
update_changed_only: bool,
use_orm_update_stmt: Optional[dml.Update] = ...,
+ enable_check_rowcount: bool = True,
) -> _result.Result[Any]:
...
isstates: bool,
update_changed_only: bool,
use_orm_update_stmt: Optional[dml.Update] = None,
+ enable_check_rowcount: bool = True,
) -> Optional[_result.Result[Any]]:
base_mapper = mapper.base_mapper
connection = session_transaction.connection(base_mapper)
+ # find bindparams in the statement. see _bulk_insert for similar
+ # notes for the insert case
+ extra_bp_names = (
+ [
+ b.key
+ for b in use_orm_update_stmt._get_embedded_bindparams()
+ if b.key in mappings[0]
+ ]
+ if use_orm_update_stmt is not None
+ else ()
+ )
+
for table, super_mapper in base_mapper._sorted_tables.items():
if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
continue
),
bulk=True,
use_orm_update_stmt=use_orm_update_stmt,
+ include_bulk_keys=extra_bp_names,
)
persistence._emit_update_statements(
base_mapper,
records,
bookkeeping=False,
use_orm_update_stmt=use_orm_update_stmt,
+ enable_check_rowcount=enable_check_rowcount,
)
if use_orm_update_stmt is not None:
is_multitable: bool = False,
is_update_from: bool = False,
is_delete_using: bool = False,
+ is_executemany: bool = False,
) -> bool:
raise NotImplementedError()
else:
if update_options._dml_strategy == "auto":
update_options += {"_dml_strategy": "bulk"}
- elif update_options._dml_strategy == "orm":
- raise sa_exc.InvalidRequestError(
- 'Can\'t use "orm" ORM insert strategy with a '
- "separate parameter list"
- )
sync = update_options._synchronize_session
if sync is not None:
mapper,
is_update_from=update_options._is_update_from,
is_delete_using=update_options._is_delete_using,
+ is_executemany=orm_context.is_executemany,
)
if can_use_returning is not None:
"backends where some support RETURNING and others "
"don't"
)
+ elif orm_context.is_executemany and not per_bind_result:
+ raise sa_exc.InvalidRequestError(
+ "For synchronize_session='fetch', can't use multiple "
+ "parameter sets in ORM mode, which this backend does not "
+ "support with RETURNING"
+ )
else:
can_use_returning = per_bind_result
else:
if insert_options._dml_strategy == "auto":
insert_options += {"_dml_strategy": "bulk"}
- elif insert_options._dml_strategy == "orm":
- raise sa_exc.InvalidRequestError(
- 'Can\'t use "orm" ORM insert strategy with a '
- "separate parameter list"
- )
if insert_options._dml_strategy != "raw":
# for ORM object loading, like ORMContext, we have to disable
result: _result.Result[Any]
if update_options._dml_strategy == "bulk":
- if statement._where_criteria:
+ enable_check_rowcount = not statement._where_criteria
+
+ assert update_options._synchronize_session != "fetch"
+
+ if (
+ statement._where_criteria
+ and update_options._synchronize_session == "evaluate"
+ ):
raise sa_exc.InvalidRequestError(
- "WHERE clause with bulk ORM UPDATE not "
- "supported right now. Statement may be invoked at the "
- "Core level using "
- "session.connection().execute(stmt, parameters)"
+ "bulk synchronize of persistent objects not supported "
+ "when using bulk update with additional WHERE "
+ "criteria right now. add synchronize_session=None "
+ "execution option to bypass synchronize of persistent "
+ "objects."
)
mapper = update_options._subject_mapper
assert mapper is not None
isstates=False,
update_changed_only=False,
use_orm_update_stmt=statement,
+ enable_check_rowcount=enable_check_rowcount,
)
return cls.orm_setup_cursor_result(
session,
is_multitable: bool = False,
is_update_from: bool = False,
is_delete_using: bool = False,
+ is_executemany: bool = False,
) -> bool:
# normal answer for "should we use RETURNING" at all.
if not normal_answer:
return False
+ if is_executemany:
+ return dialect.update_executemany_returning
+
# these workarounds are currently hypothetical for UPDATE,
# unlike DELETE where they impact MariaDB
if is_update_from:
is_multitable: bool = False,
is_update_from: bool = False,
is_delete_using: bool = False,
+ is_executemany: bool = False,
) -> bool:
# normal answer for "should we use RETURNING" at all.
def _collect_insert_commands(
table,
states_to_insert,
+ *,
bulk=False,
return_defaults=False,
render_nulls=False,
+ include_bulk_keys=(),
):
"""Identify sets of values to use in INSERT statements for a
list of states.
None
)
- if bulk and mapper._set_polymorphic_identity:
- params.setdefault(
- mapper._polymorphic_attr_key, mapper.polymorphic_identity
- )
+ if bulk:
+ if mapper._set_polymorphic_identity:
+ params.setdefault(
+ mapper._polymorphic_attr_key, mapper.polymorphic_identity
+ )
+
+ if include_bulk_keys:
+ params.update((k, state_dict[k]) for k in include_bulk_keys)
yield (
state,
uowtransaction,
table,
states_to_update,
+ *,
bulk=False,
use_orm_update_stmt=None,
+ include_bulk_keys=(),
):
"""Identify sets of values to use in UPDATE statements for a
list of states.
"key value on column %s" % (table, col)
)
+ if include_bulk_keys:
+ params.update((k, state_dict[k]) for k in include_bulk_keys)
+
if params or value_params:
params.update(pk_params)
yield (
mapper,
table,
update,
+ *,
bookkeeping=True,
use_orm_update_stmt=None,
+ enable_check_rowcount=True,
):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
c.returned_defaults,
)
rows += c.rowcount
- check_rowcount = assert_singlerow
+ check_rowcount = enable_check_rowcount and assert_singlerow
else:
if not allow_executemany:
- check_rowcount = assert_singlerow
+ check_rowcount = enable_check_rowcount and assert_singlerow
for (
state,
state_dict,
else:
multiparams = [rec[2] for rec in records]
- check_rowcount = assert_multirow or (
- assert_singlerow and len(multiparams) == 1
+ check_rowcount = enable_check_rowcount and (
+ assert_multirow
+ or (assert_singlerow and len(multiparams) == 1)
)
c = connection.execute(
mapper,
table,
insert,
+ *,
bookkeeping=True,
use_orm_insert_stmt=None,
execution_options=None,
connection, distilled_params, execution_options
).scalar()
+ def _get_embedded_bindparams(self) -> Sequence[BindParameter[Any]]:
+ """Return the list of :class:`.BindParameter` objects embedded in the
+ object.
+
+ This accomplishes the same purpose as ``visitors.traverse()`` or
+ similar would provide, however by making use of the cache key
+ it takes advantage of memoization of the key to result in fewer
+ net method calls, assuming the statement is also going to be
+ executed.
+
+ """
+
+ key = self._generate_cache_key()
+ if key is None:
+ bindparams: List[BindParameter[Any]] = []
+
+ traverse(self, {}, {"bindparam": bindparams.append})
+ return bindparams
+
+ else:
+ return key.bindparams
+
def unique_params(
self,
__optionaldict: Optional[Dict[str, Any]] = None,
from typing import Set
import uuid
+from sqlalchemy import bindparam
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import Identity
from sqlalchemy import insert
from sqlalchemy import inspect
+from sqlalchemy import Integer
from sqlalchemy import literal
from sqlalchemy import literal_column
from sqlalchemy import select
eq_(result.all(), [User(id=1, name="John", age=30)])
+ @testing.variation(
+ "use_returning", [(True, testing.requires.insert_returning), False]
+ )
+ @testing.variation("use_multiparams", [True, False])
+ @testing.variation("bindparam_in_expression", [True, False])
+ @testing.combinations(
+ "auto", "raw", "bulk", "orm", argnames="dml_strategy"
+ )
+ def test_alt_bindparam_names(
+ self,
+ use_returning,
+ decl_base,
+ use_multiparams,
+ dml_strategy,
+ bindparam_in_expression,
+ ):
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+
+ x: Mapped[int]
+ y: Mapped[int]
+
+ decl_base.metadata.create_all(testing.db)
+
+ s = fixture_session()
+
+ if bindparam_in_expression:
+ stmt = insert(A).values(y=literal(3) * (bindparam("q") + 15))
+ else:
+ stmt = insert(A).values(y=bindparam("q"))
+
+ if dml_strategy != "auto":
+ # it really should work with any strategy
+ stmt = stmt.execution_options(dml_strategy=dml_strategy)
+
+ if use_returning:
+ stmt = stmt.returning(A.x, A.y)
+
+ if use_multiparams:
+ if bindparam_in_expression:
+ expected_qs = [60, 69, 81]
+ else:
+ expected_qs = [5, 8, 12]
+
+ result = s.execute(
+ stmt,
+ [
+ {"q": 5, "x": 10},
+ {"q": 8, "x": 11},
+ {"q": 12, "x": 12},
+ ],
+ )
+ else:
+ if bindparam_in_expression:
+ expected_qs = [60]
+ else:
+ expected_qs = [5]
+
+ result = s.execute(stmt, {"q": 5, "x": 10})
+ if use_returning:
+ if use_multiparams:
+ eq_(
+ result.all(),
+ [
+ (10, expected_qs[0]),
+ (11, expected_qs[1]),
+ (12, expected_qs[2]),
+ ],
+ )
+ else:
+ eq_(result.first(), (10, expected_qs[0]))
+
+
+class UpdateStmtTest(fixtures.TestBase):
+ __backend__ = True
+
+ @testing.variation(
+ "returning_executemany",
+ [
+ ("returning", testing.requires.update_returning),
+ "executemany",
+ "plain",
+ ],
+ )
+ @testing.variation("bindparam_in_expression", [True, False])
+ # TODO: setting "bulk" here is all over the place as well, UPDATE is not
+ # too settled
+ @testing.combinations("auto", "orm", argnames="dml_strategy")
+ @testing.combinations(
+ "evaluate", "fetch", None, argnames="synchronize_strategy"
+ )
+ def test_alt_bindparam_names(
+ self,
+ decl_base,
+ returning_executemany,
+ dml_strategy,
+ bindparam_in_expression,
+ synchronize_strategy,
+ ):
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(
+ primary_key=True, autoincrement=False
+ )
+
+ x: Mapped[int]
+ y: Mapped[int]
+
+ decl_base.metadata.create_all(testing.db)
+
+ s = fixture_session()
+
+ s.add_all(
+ [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)],
+ )
+ s.commit()
+
+ if bindparam_in_expression:
+ stmt = (
+ update(A)
+ .values(y=literal(3) * (bindparam("q") + 15))
+ .where(A.id == bindparam("b_id"))
+ )
+ else:
+ stmt = (
+ update(A)
+ .values(y=bindparam("q"))
+ .where(A.id == bindparam("b_id"))
+ )
+
+ if dml_strategy != "auto":
+ # it really should work with any strategy
+ stmt = stmt.execution_options(dml_strategy=dml_strategy)
+
+ if returning_executemany.returning:
+ stmt = stmt.returning(A.x, A.y)
+
+ if synchronize_strategy in (None, "evaluate", "fetch"):
+ stmt = stmt.execution_options(
+ synchronize_session=synchronize_strategy
+ )
+
+ if returning_executemany.executemany:
+ if bindparam_in_expression:
+ expected_qs = [60, 69, 81]
+ else:
+ expected_qs = [5, 8, 12]
+
+ if dml_strategy != "orm":
+ params = [
+ {"id": 1, "b_id": 1, "q": 5, "x": 10},
+ {"id": 2, "b_id": 2, "q": 8, "x": 11},
+ {"id": 3, "b_id": 3, "q": 12, "x": 12},
+ ]
+ else:
+ params = [
+ {"b_id": 1, "q": 5, "x": 10},
+ {"b_id": 2, "q": 8, "x": 11},
+ {"b_id": 3, "q": 12, "x": 12},
+ ]
+
+ _expect_raises = None
+
+ if synchronize_strategy == "fetch":
+ if dml_strategy != "orm":
+ _expect_raises = expect_raises_message(
+ exc.InvalidRequestError,
+ r"The 'fetch' synchronization strategy is not "
+ r"available for 'bulk' ORM updates "
+ r"\(i.e. multiple parameter sets\)",
+ )
+ elif not testing.db.dialect.update_executemany_returning:
+ # no backend supports this except Oracle
+ _expect_raises = expect_raises_message(
+ exc.InvalidRequestError,
+ r"For synchronize_session='fetch', can't use multiple "
+ r"parameter sets in ORM mode, which this backend does "
+ r"not support with RETURNING",
+ )
+
+ elif synchronize_strategy == "evaluate" and dml_strategy != "orm":
+ _expect_raises = expect_raises_message(
+ exc.InvalidRequestError,
+ "bulk synchronize of persistent objects not supported",
+ )
+
+ if _expect_raises:
+ with _expect_raises:
+ result = s.execute(stmt, params)
+ return
+
+ result = s.execute(stmt, params)
+ else:
+ if bindparam_in_expression:
+ expected_qs = [60]
+ else:
+ expected_qs = [5]
+
+ result = s.execute(stmt, {"b_id": 1, "q": 5, "x": 10})
+
+ if returning_executemany.returning:
+ eq_(result.first(), (10, expected_qs[0]))
+
+ elif returning_executemany.executemany:
+ eq_(
+ s.execute(select(A.x, A.y).order_by(A.id)).all(),
+ [
+ (10, expected_qs[0]),
+ (11, expected_qs[1]),
+ (12, expected_qs[2]),
+ ],
+ )
+
+ def test_bulk_update_w_where_one(self, decl_base):
+ """test use case in #9595"""
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(
+ primary_key=True, autoincrement=False
+ )
+
+ x: Mapped[int]
+ y: Mapped[int]
+
+ decl_base.metadata.create_all(testing.db)
+
+ s = fixture_session()
+
+ s.add_all(
+ [A(id=1, x=1, y=1), A(id=2, x=2, y=2), A(id=3, x=3, y=3)],
+ )
+ s.commit()
+
+ stmt = (
+ update(A)
+ .where(A.x > 1)
+ .execution_options(synchronize_session=None)
+ )
+
+ s.execute(
+ stmt,
+ [
+ {"id": 1, "x": 3, "y": 8},
+ {"id": 2, "x": 5, "y": 9},
+ {"id": 3, "x": 12, "y": 15},
+ ],
+ )
+
+ eq_(
+ s.execute(select(A.id, A.x, A.y).order_by(A.id)).all(),
+ [(1, 1, 1), (2, 5, 9), (3, 12, 15)],
+ )
+
+ def test_bulk_update_w_where_two(self, decl_base):
+ class User(decl_base):
+ __tablename__ = "user"
+
+ id: Mapped[int] = mapped_column(
+ primary_key=True, autoincrement=False
+ )
+ name: Mapped[str]
+ age: Mapped[int]
+
+ decl_base.metadata.create_all(testing.db)
+
+ sess = fixture_session()
+ sess.execute(
+ insert(User),
+ [
+ dict(id=1, name="john", age=25),
+ dict(id=2, name="jack", age=47),
+ dict(id=3, name="jill", age=29),
+ dict(id=4, name="jane", age=37),
+ ],
+ )
+
+ sess.execute(
+ update(User)
+ .where(User.age > bindparam("gtage"))
+ .values(age=bindparam("dest_age"))
+ .execution_options(synchronize_session=None),
+ [
+ {"id": 1, "gtage": 28, "dest_age": 40},
+ {"id": 2, "gtage": 20, "dest_age": 45},
+ ],
+ )
+
+ eq_(
+ sess.execute(
+ select(User.id, User.name, User.age).order_by(User.id)
+ ).all(),
+ [
+ (1, "john", 25),
+ (2, "jack", 45),
+ (3, "jill", 29),
+ (4, "jane", 37),
+ ],
+ )
+
class BulkDMLReturningInhTest:
use_sentinel = False
eq_(coll(ids), coll(actual_ids))
- @testing.variation("insert_strategy", ["orm", "bulk", "bulk_ordered"])
+ @testing.variation(
+ "insert_strategy",
+ ["orm", "bulk", "bulk_ordered", "bulk_w_embedded_bindparam"],
+ )
@testing.requires.provisioned_upsert
def test_base_class_upsert(self, insert_strategy):
"""upsert is really tricky. if you dont have any data updated,
sort_by_parameter_order=insert_strategy.bulk_ordered
):
result = s.scalars(stmt, upsert_data)
+ elif insert_strategy.bulk_w_embedded_bindparam:
+ # test related to #9583, specific user case in
+ # https://github.com/sqlalchemy/sqlalchemy/discussions/9581#discussioncomment-5504077 # noqa: E501
+ stmt = stmt.values(
+ y=select(bindparam("qq1", type_=Integer)).scalar_subquery()
+ )
+ for d in upsert_data:
+ d["qq1"] = d.pop("y")
+ result = s.scalars(stmt, upsert_data)
else:
insert_strategy.fail()
-from sqlalchemy import bindparam
from sqlalchemy import Boolean
from sqlalchemy import case
from sqlalchemy import column
eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane])
- def test_update_multirow_not_supported(self):
- User = self.classes.User
-
- sess = fixture_session()
-
- with expect_raises_message(
- exc.InvalidRequestError,
- "WHERE clause with bulk ORM UPDATE not supported " "right now.",
- ):
- sess.execute(
- update(User).where(User.id == bindparam("id")),
- [{"id": 1, "age": 27}, {"id": 2, "age": 37}],
- )
-
def test_delete_bulk_not_supported(self):
User = self.classes.User
from itertools import zip_longest
+from sqlalchemy import bindparam
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
+from sqlalchemy import TypeDecorator
from sqlalchemy.sql import base as sql_base
from sqlalchemy.sql import coercions
from sqlalchemy.sql import column
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_not_none
class MiscTest(fixtures.TestBase):
eq_(set(sql_util.find_tables(subset_select)), {common})
+ @testing.variation("has_cache_key", [True, False])
+ def test_get_embedded_bindparams(self, has_cache_key):
+ bp = bindparam("x")
+
+ if not has_cache_key:
+
+ class NotCacheable(TypeDecorator):
+ impl = String
+ cache_ok = False
+
+ stmt = select(column("q", NotCacheable())).where(column("y") == bp)
+
+ else:
+ stmt = select(column("q")).where(column("y") == bp)
+
+ eq_(stmt._get_embedded_bindparams(), [bp])
+
+ if not has_cache_key:
+ is_(stmt._generate_cache_key(), None)
+ else:
+ is_not_none(stmt._generate_cache_key())
+
def test_find_tables_aliases(self):
metadata = MetaData()
common = Table(