From: Yossi <54272821+Apakottur@users.noreply.github.com> Date: Mon, 1 Dec 2025 17:06:12 +0000 (-0500) Subject: [typing] Fix type error when passing Mapped columns to values() X-Git-Tag: rel_2_0_45~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=86860f2dbb0fcb4036b222a4b28a552f2bd1e5cd;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git [typing] Fix type error when passing Mapped columns to values() This adjusts the _DMLOnlyColumnArgument type to be a more focused _OnlyColumnArgument type where we also add a more tightly focused coercion, while still allowing ORM attributes to be used as arguments. Co-authored-by: Mike Bayer Closes: #13012 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13012 Pull-request-sha: 5ebb402c686abf1090e5b83e3489dfca4908efdf Change-Id: I8bbccaf556ec5ecb2f5cfdd2030bcfa4eb5ce125 (cherry picked from commit 40c2400af7d44a528358ea1d73c275a85bb75616) --- diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index dfb5ad02aa..6dfed7d8ab 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -35,6 +35,7 @@ from .selectable import Values if TYPE_CHECKING: from ._typing import _FromClauseArgument from ._typing import _OnClauseArgument + from ._typing import _OnlyColumnArgument from ._typing import _SelectStatementForCompoundArgument from ._typing import _T0 from ._typing import _T1 @@ -674,14 +675,13 @@ def union_all( def values( - *columns: ColumnClause[Any], + *columns: _OnlyColumnArgument[Any], name: Optional[str] = None, literal_binds: bool = False, ) -> Values: r"""Construct a :class:`_expression.Values` construct representing the SQL ``VALUES`` clause. - The column expressions and the actual data for :class:`_expression.Values` are given in two separate steps. The constructor receives the column expressions typically as :func:`_expression.column` constructs, and the diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 8e3c66e553..18639df2f5 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -186,6 +186,19 @@ _T8 = TypeVar("_T8", bound=Any) _T9 = TypeVar("_T9", bound=Any) +_OnlyColumnArgument = Union[ + "ColumnElement[_T]", + _HasClauseElement[_T], + roles.DMLColumnRole, +] +"""A narrow type that is looking for a ColumnClause (e.g. table column with a +name) or an ORM element that produces this. + +This is used for constructs that need a named column to represent a +position in a selectable, like TextClause().columns() or values(...). + +""" + _ColumnExpressionArgument = Union[ "ColumnElement[_T]", _HasClauseElement[_T], diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index ac0393a605..8ffaa7a2a6 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -52,6 +52,7 @@ if typing.TYPE_CHECKING: from ._typing import _DDLColumnArgument from ._typing import _DMLTableArgument from ._typing import _FromClauseArgument + from ._typing import _OnlyColumnArgument from .dml import _DMLTableElement from .elements import BindParameter from .elements import ClauseElement @@ -212,7 +213,7 @@ def expect( @overload def expect( role: Type[roles.LabeledColumnExprRole[Any]], - element: _ColumnExpressionArgument[_T], + element: Union[_ColumnExpressionArgument[_T], _OnlyColumnArgument[_T]], **kw: Any, ) -> NamedColumn[_T]: ... diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index a52c9b30a9..9ba190fcab 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -87,6 +87,7 @@ if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionOrStrLabelArgument from ._typing import _HasDialect from ._typing import _InfoType + from ._typing import _OnlyColumnArgument from ._typing import _PropagateAttrsType from ._typing import _TypeEngineArgument from .base import _EntityNamespace @@ -2459,7 +2460,7 @@ class TextClause( @util.preload_module("sqlalchemy.sql.selectable") def columns( self, - *cols: _ColumnExpressionArgument[Any], + *cols: _OnlyColumnArgument[Any], **types: _TypeEngineArgument[Any], ) -> TextualSelect: r"""Turn this :class:`_expression.TextClause` object into a diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index d3c7b8eca8..4540cb1b3d 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -116,6 +116,7 @@ if TYPE_CHECKING: from ._typing import _MAYBE_ENTITY from ._typing import _NOT_ENTITY from ._typing import _OnClauseArgument + from ._typing import _OnlyColumnArgument from ._typing import _SelectStatementForCompoundArgument from ._typing import _T0 from ._typing import _T1 @@ -3337,6 +3338,7 @@ class Values(roles.InElementRole, HasCTE, Generative, LateralFromClause): __visit_name__ = "values" _data: Tuple[Sequence[Tuple[Any, ...]], ...] = () + _column_args: Tuple[NamedColumn[Any], ...] _unnamed: bool _traverse_internals: _TraverseInternalsType = [ @@ -3350,12 +3352,15 @@ class Values(roles.InElementRole, HasCTE, Generative, LateralFromClause): def __init__( self, - *columns: ColumnClause[Any], + *columns: _OnlyColumnArgument[Any], name: Optional[str] = None, literal_binds: bool = False, ): super().__init__() - self._column_args = columns + self._column_args = tuple( + coercions.expect(roles.LabeledColumnExprRole, col) + for col in columns + ) if name is None: self._unnamed = True @@ -3499,7 +3504,7 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]): def __init__( self, - columns: Sequence[ColumnClause[Any]], + columns: Sequence[NamedColumn[Any]], data: Tuple[Sequence[Tuple[Any, ...]], ...], literal_binds: bool, ): diff --git a/test/typing/plain_files/sql/dml.py b/test/typing/plain_files/sql/dml.py index 5381a1f07f..06363e2c72 100644 --- a/test/typing/plain_files/sql/dml.py +++ b/test/typing/plain_files/sql/dml.py @@ -11,6 +11,7 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import update +from sqlalchemy import values from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -60,3 +61,20 @@ stmt5 = update(User).values({User.id: 123, User.data: "value"}) stmt6 = user_table.update().values( {user_table.c.d: 123, user_table.c.data: "value"} ) + + +update_values = values( + User.id, + User.name, + name="update_values", +).data([(1, "Alice"), (2, "Bob")]) + +query = ( + update(User) + .values( + { + User.name: update_values.c.name, + } + ) + .where(User.id == update_values.c.id) +)