From: Yossi <54272821+Apakottur@users.noreply.github.com> Date: Mon, 1 Dec 2025 17:06:12 +0000 (-0500) Subject: [typing] Fix type error when passing Mapped columns to values() X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=40c2400af7d44a528358ea1d73c275a85bb75616;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git [typing] Fix type error when passing Mapped columns to values() This adjusts the _DMLOnlyColumnArgument type to be a more focused _OnlyColumnArgument type where we also add a more tightly focused coercion, while still allowing ORM attributes to be used as arguments. Co-authored-by: Mike Bayer Closes: #13012 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13012 Pull-request-sha: 5ebb402c686abf1090e5b83e3489dfca4908efdf Change-Id: I8bbccaf556ec5ecb2f5cfdd2030bcfa4eb5ce125 --- diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 6fddb590b5..354db4e903 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -59,7 +59,7 @@ if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrLiteralArgument from ._typing import _ColumnExpressionOrStrLabelArgument - from ._typing import _DMLOnlyColumnArgument + from ._typing import _OnlyColumnArgument from ._typing import _TypeEngineArgument from .elements import _FrameIntTuple from .elements import BinaryExpression @@ -490,7 +490,7 @@ def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: return coercions.expect(roles.ExpressionElementRole, clause).__invert__() -def from_dml_column(column: _DMLOnlyColumnArgument[_T]) -> DMLTargetCopy[_T]: +def from_dml_column(column: _OnlyColumnArgument[_T]) -> DMLTargetCopy[_T]: r"""A placeholder that may be used in compiled INSERT or UPDATE expressions to refer to the SQL expression or value being applied to another column. diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 129806204b..3aa2de1bd4 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -36,6 +36,7 @@ from ..util.typing import Unpack if TYPE_CHECKING: from ._typing import _FromClauseArgument from ._typing import _OnClauseArgument + from ._typing import _OnlyColumnArgument from ._typing import _SelectStatementForCompoundArgument from ._typing import _T0 from ._typing import _T1 @@ -684,14 +685,13 @@ def union_all( def values( - *columns: ColumnClause[Any], + *columns: _OnlyColumnArgument[Any], name: Optional[str] = None, literal_binds: bool = False, ) -> Values: r"""Construct a :class:`_expression.Values` construct representing the SQL ``VALUES`` clause. - The column expressions and the actual data for :class:`_expression.Values` are given in two separate steps. The constructor receives the column expressions typically as :func:`_expression.column` constructs, and the diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index b4af798dbd..71ad7e13ee 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -187,6 +187,19 @@ _T8 = TypeVar("_T8", bound=Any) _T9 = TypeVar("_T9", bound=Any) +_OnlyColumnArgument = Union[ + "ColumnElement[_T]", + _HasClauseElement[_T], + roles.DMLColumnRole, +] +"""A narrow type that is looking for a ColumnClause (e.g. table column with a +name) or an ORM element that produces this. + +This is used for constructs that need a named column to represent a +position in a selectable, like TextClause().columns() or values(...). + +""" + _ColumnExpressionArgument = Union[ "ColumnElement[_T]", _HasClauseElement[_T], @@ -274,12 +287,6 @@ the DMLColumnRole to be able to accommodate. """ -_DMLOnlyColumnArgument = Union[ - _HasClauseElement[_T], - roles.DMLColumnRole, - "SQLCoreOperations[_T]", -] - _DMLKey = TypeVar("_DMLKey", bound=_DMLColumnArgument) _DMLColumnKeyMapping = Mapping[_DMLKey, Any] diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 7d51056422..f22b08a08a 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -52,6 +52,7 @@ if typing.TYPE_CHECKING: from ._typing import _DDLColumnArgument from ._typing import _DMLTableArgument from ._typing import _FromClauseArgument + from ._typing import _OnlyColumnArgument from .base import SyntaxExtension from .dml import _DMLTableElement from .elements import BindParameter @@ -221,7 +222,7 @@ def expect( @overload def expect( role: Type[roles.LabeledColumnExprRole[Any]], - element: _ColumnExpressionArgument[_T], + element: _ColumnExpressionArgument[_T] | _OnlyColumnArgument[_T], **kw: Any, ) -> NamedColumn[_T]: ... diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 8f0d7e0a28..674560d7e1 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -91,9 +91,9 @@ if typing.TYPE_CHECKING: from ._typing import _ByArgument from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrStrLabelArgument - from ._typing import _DMLOnlyColumnArgument from ._typing import _HasDialect from ._typing import _InfoType + from ._typing import _OnlyColumnArgument from ._typing import _PropagateAttrsType from ._typing import _TypeEngineArgument from .base import _EntityNamespace @@ -2002,7 +2002,7 @@ class DMLTargetCopy(roles.InElementRole, KeyedColumnElement[_T]): """ - def __init__(self, column: _DMLOnlyColumnArgument[_T]): + def __init__(self, column: _OnlyColumnArgument[_T]): self.column = coercions.expect(roles.ColumnArgumentRole, column) self.type = self.column.type @@ -2401,7 +2401,7 @@ class AbstractTextClause( @util.preload_module("sqlalchemy.sql.selectable") def columns( self, - *cols: _ColumnExpressionArgument[Any], + *cols: _OnlyColumnArgument[Any], **types: _TypeEngineArgument[Any], ) -> TextualSelect: r"""Turn this :class:`_expression.AbstractTextClause` object into a diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 5342cd012a..8113944caa 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -123,6 +123,7 @@ if TYPE_CHECKING: from ._typing import _MAYBE_ENTITY from ._typing import _NOT_ENTITY from ._typing import _OnClauseArgument + from ._typing import _OnlyColumnArgument from ._typing import _SelectStatementForCompoundArgument from ._typing import _T0 from ._typing import _T1 @@ -3355,6 +3356,7 @@ class Values(roles.InElementRole, HasCTE, Generative, LateralFromClause): __visit_name__ = "values" _data: Tuple[Sequence[Tuple[Any, ...]], ...] = () + _column_args: Tuple[NamedColumn[Any], ...] _unnamed: bool _traverse_internals: _TraverseInternalsType = [ @@ -3368,12 +3370,15 @@ class Values(roles.InElementRole, HasCTE, Generative, LateralFromClause): def __init__( self, - *columns: ColumnClause[Any], + *columns: _OnlyColumnArgument[Any], name: Optional[str] = None, literal_binds: bool = False, ): super().__init__() - self._column_args = columns + self._column_args = tuple( + coercions.expect(roles.LabeledColumnExprRole, col) + for col in columns + ) if name is None: self._unnamed = True @@ -3517,7 +3522,7 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]): def __init__( self, - columns: Sequence[ColumnClause[Any]], + columns: Sequence[NamedColumn[Any]], data: Tuple[Sequence[Tuple[Any, ...]], ...], literal_binds: bool, ): diff --git a/test/typing/plain_files/sql/dml.py b/test/typing/plain_files/sql/dml.py index 5381a1f07f..06363e2c72 100644 --- a/test/typing/plain_files/sql/dml.py +++ b/test/typing/plain_files/sql/dml.py @@ -11,6 +11,7 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import update +from sqlalchemy import values from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -60,3 +61,20 @@ stmt5 = update(User).values({User.id: 123, User.data: "value"}) stmt6 = user_table.update().values( {user_table.c.d: 123, user_table.c.data: "value"} ) + + +update_values = values( + User.id, + User.name, + name="update_values", +).data([(1, "Alice"), (2, "Bob")]) + +query = ( + update(User) + .values( + { + User.name: update_values.c.name, + } + ) + .where(User.id == update_values.c.id) +)