From 89b81ec8c45fae34214657cf46bbc9df158a676a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 3 Apr 2025 10:36:28 -0400 Subject: [PATCH] add CRUD column marker Added new Core feature :func:`_sql.from_dml_column` that may be used in expressions inside of :meth:`.UpdateBase.values` for INSERT or UPDATE; this construct will copy whatever SQL expression is used for the given target column in the statement to be used with additional columns. The construct is mostly intended to be a helper with ORM :class:`.hybrid_property` within DML hooks. This is the Core side of the feature being added to the ORM for #12496 Change-Id: Ic568638a8ce3607deea44af988b6451b30cde36c --- doc/build/changelog/unreleased_21/12496.rst | 11 ++ doc/build/core/sqlelement.rst | 3 + lib/sqlalchemy/__init__.py | 1 + lib/sqlalchemy/sql/__init__.py | 1 + lib/sqlalchemy/sql/_elements_constructors.py | 37 +++++ lib/sqlalchemy/sql/_typing.py | 7 + lib/sqlalchemy/sql/compiler.py | 10 ++ lib/sqlalchemy/sql/crud.py | 68 ++++++++ lib/sqlalchemy/sql/elements.py | 24 +++ lib/sqlalchemy/sql/expression.py | 1 + test/sql/test_compare.py | 5 + test/sql/test_insert.py | 145 +++++++++++++++++ test/sql/test_update.py | 157 +++++++++++++++++++ 13 files changed, 470 insertions(+) create mode 100644 doc/build/changelog/unreleased_21/12496.rst diff --git a/doc/build/changelog/unreleased_21/12496.rst b/doc/build/changelog/unreleased_21/12496.rst new file mode 100644 index 0000000000..77d8ffb7d3 --- /dev/null +++ b/doc/build/changelog/unreleased_21/12496.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: feature, sql + :tickets: 12496 + + Added new Core feature :func:`_sql.from_dml_column` that may be used in + expressions inside of :meth:`.UpdateBase.values` for INSERT or UPDATE; this + construct will copy whatever SQL expression is used for the given target + column in the statement to be used with additional columns. The construct + is mostly intended to be a helper with ORM :class:`.hybrid_property` within + DML hooks. + diff --git a/doc/build/core/sqlelement.rst b/doc/build/core/sqlelement.rst index 9481bf5d9f..8d3d65dda5 100644 --- a/doc/build/core/sqlelement.rst +++ b/doc/build/core/sqlelement.rst @@ -43,6 +43,8 @@ used when building up SQLAlchemy Expression Language constructs. .. autofunction:: false +.. autofunction:: from_dml_column + .. autodata:: func .. autofunction:: lambda_stmt @@ -174,6 +176,7 @@ The classes here are generated using the constructors listed at :special-members: :inherited-members: +.. autoclass:: DMLTargetCopy .. autoclass:: Extract :members: diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index be099c29b3..5e0fb283d5 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -124,6 +124,7 @@ from .sql.expression import Extract as Extract from .sql.expression import extract as extract from .sql.expression import false as false from .sql.expression import False_ as False_ +from .sql.expression import from_dml_column as from_dml_column from .sql.expression import FromClause as FromClause from .sql.expression import FromGrouping as FromGrouping from .sql.expression import func as func diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 4ac8f343d5..a3aa65c2b4 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -47,6 +47,7 @@ from .expression import exists as exists from .expression import extract as extract from .expression import false as false from .expression import False_ as False_ +from .expression import from_dml_column as from_dml_column from .expression import FromClause as FromClause from .expression import func as func from .expression import funcfilter as funcfilter diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index b5f3c74515..abb5b14b4c 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -31,6 +31,7 @@ from .elements import CollationClause from .elements import CollectionAggregate from .elements import ColumnClause from .elements import ColumnElement +from .elements import DMLTargetCopy from .elements import Extract from .elements import False_ from .elements import FunctionFilter @@ -52,6 +53,7 @@ if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrLiteralArgument from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _DMLOnlyColumnArgument from ._typing import _TypeEngineArgument from .elements import BinaryExpression from .selectable import FromClause @@ -459,6 +461,41 @@ def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: return coercions.expect(roles.ExpressionElementRole, clause).__invert__() +def from_dml_column(column: _DMLOnlyColumnArgument[_T]) -> DMLTargetCopy[_T]: + r"""A placeholder that may be used in compiled INSERT or UPDATE expressions + to refer to the SQL expression or value being applied to another column. + + Given a table such as:: + + t = Table( + "t", + MetaData(), + Column("x", Integer), + Column("y", Integer), + ) + + The :func:`_sql.from_dml_column` construct allows automatic copying + of an expression assigned to a different column to be re-used:: + + >>> stmt = t.insert().values(x=func.foobar(3), y=from_dml_column(t.c.x) + 5) + >>> print(stmt) + INSERT INTO t (x, y) VALUES (foobar(:foobar_1), (foobar(:foobar_1) + :param_1)) + + The :func:`_sql.from_dml_column` construct is intended to be useful primarily + with event-based hooks such as those used by ORM hybrids. + + .. seealso:: + + :ref:`hybrid_bulk_update` + + .. versionadded:: 2.1 + + + """ # noqa: E501 + + return DMLTargetCopy(column) + + def bindparam( key: Optional[str], value: Any = _NoArg.NO_ARG, diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 14769dde17..71f54a63f1 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -274,6 +274,13 @@ the DMLColumnRole to be able to accommodate. """ +_DMLOnlyColumnArgument = Union[ + _HasClauseElement[_T], + roles.DMLColumnRole, + "SQLCoreOperations[_T]", +] + + _DMLKey = TypeVar("_DMLKey", bound=_DMLColumnArgument) _DMLColumnKeyMapping = Mapping[_DMLKey, Any] diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a46fcca2d9..043e9ed238 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3640,6 +3640,16 @@ class SQLCompiler(Compiled): % self.dialect.name ) + def visit_dmltargetcopy(self, element, *, bindmarkers=None, **kw): + if bindmarkers is None: + raise exc.CompileError( + "DML target objects may only be used with " + "compiled INSERT or UPDATE statements" + ) + + bindmarkers[element.column.key] = element + return f"__BINDMARKER_~~{element.column.key}~~" + def visit_bindparam( self, bindparam, diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index e75a3ea1c9..51bede81fd 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -14,6 +14,7 @@ from __future__ import annotations import functools import operator +import re from typing import Any from typing import Callable from typing import cast @@ -52,6 +53,7 @@ if TYPE_CHECKING: from .dml import DMLState from .dml import ValuesBase from .elements import ColumnElement + from .elements import DMLTargetCopy from .elements import KeyedColumnElement from .schema import _SQLExprDefault from .schema import Column @@ -167,6 +169,9 @@ def _get_crud_params( "accumulate_bind_names" not in kw ), "Don't know how to handle insert within insert without a CTE" + bindmarkers: MutableMapping[ColumnElement[Any], DMLTargetCopy[Any]] = {} + kw["bindmarkers"] = bindmarkers + # getters - these are normally just column.key, # but in the case of mysql multi-table update, the rules for # .key must conditionally take tablename into account @@ -397,6 +402,26 @@ def _get_crud_params( cast("Callable[..., str]", _column_as_key), kw, ) + + if bindmarkers: + _replace_bindmarkers( + compiler, + _column_as_key, + bindmarkers, + compile_state, + values, + kw, + ) + for m_v in multi_extended_values: + _replace_bindmarkers( + compiler, + _column_as_key, + bindmarkers, + compile_state, + m_v, + kw, + ) + return _CrudParams(values, multi_extended_values) elif ( not values @@ -417,6 +442,10 @@ def _get_crud_params( ] is_default_metavalue_only = True + if bindmarkers: + _replace_bindmarkers( + compiler, _column_as_key, bindmarkers, compile_state, values, kw + ) return _CrudParams( values, [], @@ -426,6 +455,45 @@ def _get_crud_params( ) +def _replace_bindmarkers( + compiler, _column_as_key, bindmarkers, compile_state, values, kw +): + _expr_by_col_key = { + _column_as_key(col): compiled_str for col, _, compiled_str, _ in values + } + + def replace_marker(m): + try: + return _expr_by_col_key[m.group(1)] + except KeyError as ke: + if dml.isupdate(compile_state): + return compiler.process(bindmarkers[m.group(1)].column, **kw) + else: + raise exc.CompileError( + f"Can't resolve referenced column name in " + f"INSERT statement: {m.group(1)!r}" + ) from ke + + values[:] = [ + ( + col, + col_value, + re.sub( + r"__BINDMARKER_~~(.+?)~~", + replace_marker, + compiled_str, + ), + accumulated_bind_names, + ) + for ( + col, + col_value, + compiled_str, + accumulated_bind_names, + ) in values + ] + + @overload def _create_bind_param( compiler: SQLCompiler, diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 84f813be5f..57daa5a5db 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -87,6 +87,7 @@ if typing.TYPE_CHECKING: from ._typing import _ByArgument from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _DMLOnlyColumnArgument from ._typing import _HasDialect from ._typing import _InfoType from ._typing import _PropagateAttrsType @@ -1950,6 +1951,29 @@ class WrapsColumnExpression(ColumnElement[_T]): return super()._proxy_key +class DMLTargetCopy(roles.InElementRole, KeyedColumnElement[_T]): + """Refer to another column's VALUES or SET expression in an INSERT or + UPDATE statement. + + See the public-facing :func:`_sql.from_dml_column` constructor for + background. + + .. versionadded:: 2.1 + + + """ + + def __init__(self, column: _DMLOnlyColumnArgument[_T]): + self.column = coercions.expect(roles.ColumnArgumentRole, column) + self.type = self.column.type + + __visit_name__ = "dmltargetcopy" + + _traverse_internals: _TraverseInternalsType = [ + ("column", InternalTraversal.dp_clauseelement), + ] + + class BindParameter(roles.InElementRole, KeyedColumnElement[_T]): r"""Represent a "bound expression". diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index dc7dee13b1..5abb4e3ec5 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -28,6 +28,7 @@ from ._elements_constructors import desc as desc from ._elements_constructors import distinct as distinct from ._elements_constructors import extract as extract from ._elements_constructors import false as false +from ._elements_constructors import from_dml_column as from_dml_column from ._elements_constructors import funcfilter as funcfilter from ._elements_constructors import label as label from ._elements_constructors import not_ as not_ diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 9c9bde1dac..68741fa2c1 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -53,6 +53,7 @@ from sqlalchemy.sql.elements import BindParameter from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.elements import ClauseList from sqlalchemy.sql.elements import CollationClause +from sqlalchemy.sql.elements import DMLTargetCopy from sqlalchemy.sql.elements import DQLDMLClauseElement from sqlalchemy.sql.elements import ElementList from sqlalchemy.sql.elements import Immutable @@ -367,6 +368,10 @@ class CoreFixtures: bindparam("x", type_=String), bindparam(None), ), + lambda: ( + DMLTargetCopy(table_a.c.a), + DMLTargetCopy(table_a.c.b), + ), lambda: (_OffsetLimitParam("x"), _OffsetLimitParam("y")), lambda: (func.foo(), func.foo(5), func.bar()), lambda: ( diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index a5cfad5b69..c4f15657a6 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -3,9 +3,12 @@ from __future__ import annotations from typing import Tuple from sqlalchemy import bindparam +from sqlalchemy import cast from sqlalchemy import Column from sqlalchemy import column +from sqlalchemy import DateTime from sqlalchemy import exc +from sqlalchemy import from_dml_column from sqlalchemy import func from sqlalchemy import insert from sqlalchemy import Integer @@ -66,6 +69,15 @@ class _InsertTestBase: Column("z", Integer, default=lambda: 10), ) + Table( + "mytable_w_sql_default", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30)), + Column("description", String(30)), + Column("created_at", DateTime, default=func.now()), + ) + class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): __dialect__ = "default_enhanced" @@ -1182,6 +1194,139 @@ class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): ) +class FromDMLInsertTest( + _InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL +): + __dialect__ = "default_enhanced" + + def test_from_bound_col_value(self): + mytable = self.tables.mytable + + # from_dml_column() refers to another column in SET, then the + # same parameter is rendered + stmt = mytable.insert().values( + name="some name", description=from_dml_column(mytable.c.name) + ) + + self.assert_compile( + stmt, + "INSERT INTO mytable (name, description) VALUES (:name, :name)", + checkparams={"name": "some name"}, + ) + + self.assert_compile( + stmt, + "INSERT INTO mytable (name, description) VALUES (?, ?)", + checkpositional=("some name", "some name"), + dialect="sqlite", + ) + + def test_from_static_col_value(self): + mytable = self.tables.mytable + + # from_dml_column() refers to a column not in SET, then it + # raises for INSERT + stmt = mytable.insert().values( + description=from_dml_column(mytable.c.name) + ) + + with expect_raises_message( + exc.CompileError, + "Can't resolve referenced column name in INSERT statement: 'name'", + ): + stmt.compile() + + def test_from_sql_default(self): + """test combinations with a column that has a SQL default""" + + mytable = self.tables.mytable_w_sql_default + stmt = mytable.insert().values( + description=from_dml_column(mytable.c.created_at) + ) + + self.assert_compile( + stmt, + "INSERT INTO mytable_w_sql_default (description, created_at) " + "VALUES (now(), now())", + ) + + stmt = mytable.insert().values( + description=cast(from_dml_column(mytable.c.created_at), String) + + " o clock" + ) + + self.assert_compile( + stmt, + "INSERT INTO mytable_w_sql_default (description, created_at) " + "VALUES ((CAST(now() AS VARCHAR) || :param_1), now())", + ) + + stmt = mytable.insert().values( + name="some name", + description=cast(from_dml_column(mytable.c.created_at), String) + + " " + + from_dml_column(mytable.c.name), + ) + + self.assert_compile( + stmt, + "INSERT INTO mytable_w_sql_default " + "(name, description, created_at) VALUES " + "(:name, (CAST(now() AS VARCHAR) || :param_1 || :name), now())", + checkparams={"name": "some name", "param_1": " "}, + ) + self.assert_compile( + stmt, + "INSERT INTO mytable_w_sql_default " + "(name, description, created_at) VALUES " + "(?, (CAST(CURRENT_TIMESTAMP AS VARCHAR) || ? || ?), " + "CURRENT_TIMESTAMP)", + checkpositional=("some name", " ", "some name"), + dialect="sqlite", + ) + + def test_from_sql_expr(self): + mytable = self.tables.mytable + stmt = mytable.insert().values( + name=mytable.c.name + "lala", + description=from_dml_column(mytable.c.name), + ) + + self.assert_compile( + stmt, + "INSERT INTO mytable (name, description) VALUES " + "((mytable.name || :name_1), (mytable.name || :name_1))", + checkparams={"name_1": "lala"}, + ) + + self.assert_compile( + stmt, + "INSERT INTO mytable (name, description) VALUES " + "((mytable.name || ?), (mytable.name || ?))", + checkpositional=("lala", "lala"), + dialect="sqlite", + ) + + def test_from_sql_expr_multiple_dmlcol(self): + mytable = self.tables.mytable + stmt = mytable.insert().values( + myid=5, + name=mytable.c.name + "lala", + description=from_dml_column(mytable.c.name) + + " " + + cast(from_dml_column(mytable.c.myid), String), + ) + + self.assert_compile( + stmt, + "INSERT INTO mytable (myid, name, description) VALUES " + "(:myid, (mytable.name || :name_1), " + "((mytable.name || :name_1) || :param_1 || " + "CAST(:myid AS VARCHAR)))", + checkparams={"myid": 5, "name_1": "lala", "param_1": " "}, + ) + + class InsertImplicitReturningTest( _InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL ): diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 9a533040e9..5a6133e41c 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -2,11 +2,13 @@ import itertools import random from sqlalchemy import bindparam +from sqlalchemy import cast from sqlalchemy import column from sqlalchemy import DateTime from sqlalchemy import exc from sqlalchemy import exists from sqlalchemy import ForeignKey +from sqlalchemy import from_dml_column from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import literal @@ -1018,6 +1020,161 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): paramstyle.fail() +class FromDMLColumnTest( + _UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL +): + """test the from_dml_column() feature added as part of #12496""" + + __dialect__ = "default_enhanced" + + def test_from_bound_col_value(self): + mytable = self.tables.mytable + + # from_dml_column() refers to another column in SET, then the + # same parameter is rendered + stmt = mytable.update().values( + name="some name", description=from_dml_column(mytable.c.name) + ) + + self.assert_compile( + stmt, + "UPDATE mytable SET name=:name, description=:name", + checkparams={"name": "some name"}, + ) + + self.assert_compile( + stmt, + "UPDATE mytable SET name=?, description=?", + checkpositional=("some name", "some name"), + dialect="sqlite", + ) + + def test_from_static_col_value(self): + mytable = self.tables.mytable + + # from_dml_column() refers to a column not in SET, then the + # column is rendered + stmt = mytable.update().values( + description=from_dml_column(mytable.c.name) + ) + + self.assert_compile( + stmt, + "UPDATE mytable SET description=mytable.name", + checkparams={}, + ) + + self.assert_compile( + stmt, + "UPDATE mytable SET description=mytable.name", + checkpositional=(), + dialect="sqlite", + ) + + def test_from_sql_onupdate(self): + """test combinations with a column that has a SQL onupdate""" + + mytable = self.tables.mytable_with_onupdate + stmt = mytable.update().values( + description=from_dml_column(mytable.c.updated_at) + ) + + self.assert_compile( + stmt, + "UPDATE mytable_with_onupdate SET description=now(), " + "updated_at=now()", + ) + + stmt = mytable.update().values( + description=cast(from_dml_column(mytable.c.updated_at), String) + + " o clock" + ) + + self.assert_compile( + stmt, + "UPDATE mytable_with_onupdate SET " + "description=(CAST(now() AS VARCHAR) || :param_1), " + "updated_at=now()", + ) + + stmt = mytable.update().values( + description=cast(from_dml_column(mytable.c.updated_at), String) + + " " + + from_dml_column(mytable.c.name) + ) + + self.assert_compile( + stmt, + "UPDATE mytable_with_onupdate SET " + "description=(CAST(now() AS VARCHAR) || :param_1 || " + "mytable_with_onupdate.name), updated_at=now()", + ) + + stmt = mytable.update().values( + name="some name", + description=cast(from_dml_column(mytable.c.updated_at), String) + + " " + + from_dml_column(mytable.c.name), + ) + + self.assert_compile( + stmt, + "UPDATE mytable_with_onupdate SET " + "name=:name, " + "description=(CAST(now() AS VARCHAR) || :param_1 || " + ":name), updated_at=now()", + checkparams={"name": "some name", "param_1": " "}, + ) + self.assert_compile( + stmt, + "UPDATE mytable_with_onupdate SET " + "name=?, " + "description=(CAST(CURRENT_TIMESTAMP AS VARCHAR) || ? || " + "?), updated_at=CURRENT_TIMESTAMP", + checkpositional=("some name", " ", "some name"), + dialect="sqlite", + ) + + def test_from_sql_expr(self): + mytable = self.tables.mytable + stmt = mytable.update().values( + name=mytable.c.name + "lala", + description=from_dml_column(mytable.c.name), + ) + + self.assert_compile( + stmt, + "UPDATE mytable SET name=(mytable.name || :name_1), " + "description=(mytable.name || :name_1)", + checkparams={"name_1": "lala"}, + ) + + self.assert_compile( + stmt, + "UPDATE mytable SET name=(mytable.name || ?), " + "description=(mytable.name || ?)", + checkpositional=("lala", "lala"), + dialect="sqlite", + ) + + def test_from_sql_expr_multiple_dmlcol(self): + mytable = self.tables.mytable + stmt = mytable.update().values( + name=mytable.c.name + "lala", + description=from_dml_column(mytable.c.name) + + " " + + cast(from_dml_column(mytable.c.myid), String), + ) + + self.assert_compile( + stmt, + "UPDATE mytable SET name=(mytable.name || :name_1), " + "description=((mytable.name || :name_1) || :param_1 || " + "CAST(mytable.myid AS VARCHAR))", + checkparams={"name_1": "lala", "param_1": " "}, + ) + + class UpdateFromCompileTest( _UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL ): -- 2.47.3