]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
[typing] Fix type error when passing Mapped columns to values()
authorYossi <54272821+Apakottur@users.noreply.github.com>
Mon, 1 Dec 2025 17:06:12 +0000 (12:06 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Dec 2025 03:46:33 +0000 (22:46 -0500)
This adjusts the _DMLOnlyColumnArgument type to be a more
focused _OnlyColumnArgument type where we also add a more tightly
focused coercion, while still allowing ORM attributes to be used
as arguments.

Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Closes: #13012
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13012
Pull-request-sha: 5ebb402c686abf1090e5b83e3489dfca4908efdf

Change-Id: I8bbccaf556ec5ecb2f5cfdd2030bcfa4eb5ce125
(cherry picked from commit 40c2400af7d44a528358ea1d73c275a85bb75616)

lib/sqlalchemy/sql/_selectable_constructors.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/typing/plain_files/sql/dml.py

index dfb5ad02aaf0901f714e25df1f377c176bd4e576..6dfed7d8abace7bea8260e220a989f93034a9d3c 100644 (file)
@@ -35,6 +35,7 @@ from .selectable import Values
 if TYPE_CHECKING:
     from ._typing import _FromClauseArgument
     from ._typing import _OnClauseArgument
+    from ._typing import _OnlyColumnArgument
     from ._typing import _SelectStatementForCompoundArgument
     from ._typing import _T0
     from ._typing import _T1
@@ -674,14 +675,13 @@ def union_all(
 
 
 def values(
-    *columns: ColumnClause[Any],
+    *columns: _OnlyColumnArgument[Any],
     name: Optional[str] = None,
     literal_binds: bool = False,
 ) -> Values:
     r"""Construct a :class:`_expression.Values` construct representing the
     SQL ``VALUES`` clause.
 
-
     The column expressions and the actual data for :class:`_expression.Values`
     are given in two separate steps.  The constructor receives the column
     expressions typically as :func:`_expression.column` constructs, and the
index 8e3c66e553f77b58d34ad612b466435cc203aacd..18639df2f5c1b0300fc282af89ee6fcf0b1fc8ab 100644 (file)
@@ -186,6 +186,19 @@ _T8 = TypeVar("_T8", bound=Any)
 _T9 = TypeVar("_T9", bound=Any)
 
 
+_OnlyColumnArgument = Union[
+    "ColumnElement[_T]",
+    _HasClauseElement[_T],
+    roles.DMLColumnRole,
+]
+"""A narrow type that is looking for a ColumnClause (e.g. table column with a
+name) or an ORM element that produces this.
+
+This is used for constructs that need a named column to represent a
+position in a selectable, like TextClause().columns() or values(...).
+
+"""
+
 _ColumnExpressionArgument = Union[
     "ColumnElement[_T]",
     _HasClauseElement[_T],
index ac0393a6056263a58e64eb495b8f41dab448156e..8ffaa7a2a6d02f0a32f63bef469758b446785d8d 100644 (file)
@@ -52,6 +52,7 @@ if typing.TYPE_CHECKING:
     from ._typing import _DDLColumnArgument
     from ._typing import _DMLTableArgument
     from ._typing import _FromClauseArgument
+    from ._typing import _OnlyColumnArgument
     from .dml import _DMLTableElement
     from .elements import BindParameter
     from .elements import ClauseElement
@@ -212,7 +213,7 @@ def expect(
 @overload
 def expect(
     role: Type[roles.LabeledColumnExprRole[Any]],
-    element: _ColumnExpressionArgument[_T],
+    element: Union[_ColumnExpressionArgument[_T], _OnlyColumnArgument[_T]],
     **kw: Any,
 ) -> NamedColumn[_T]: ...
 
index a52c9b30a96e2d3b21a6c08ee4be8ac3fd4c4682..9ba190fcabe971d7a219f67c2760d614197a0d9e 100644 (file)
@@ -87,6 +87,7 @@ if typing.TYPE_CHECKING:
     from ._typing import _ColumnExpressionOrStrLabelArgument
     from ._typing import _HasDialect
     from ._typing import _InfoType
+    from ._typing import _OnlyColumnArgument
     from ._typing import _PropagateAttrsType
     from ._typing import _TypeEngineArgument
     from .base import _EntityNamespace
@@ -2459,7 +2460,7 @@ class TextClause(
     @util.preload_module("sqlalchemy.sql.selectable")
     def columns(
         self,
-        *cols: _ColumnExpressionArgument[Any],
+        *cols: _OnlyColumnArgument[Any],
         **types: _TypeEngineArgument[Any],
     ) -> TextualSelect:
         r"""Turn this :class:`_expression.TextClause` object into a
index d3c7b8eca88a5d1a0740460083429da420279473..4540cb1b3d20e88b466d8a809c4b1ab8d48fd676 100644 (file)
@@ -116,6 +116,7 @@ if TYPE_CHECKING:
     from ._typing import _MAYBE_ENTITY
     from ._typing import _NOT_ENTITY
     from ._typing import _OnClauseArgument
+    from ._typing import _OnlyColumnArgument
     from ._typing import _SelectStatementForCompoundArgument
     from ._typing import _T0
     from ._typing import _T1
@@ -3337,6 +3338,7 @@ class Values(roles.InElementRole, HasCTE, Generative, LateralFromClause):
     __visit_name__ = "values"
 
     _data: Tuple[Sequence[Tuple[Any, ...]], ...] = ()
+    _column_args: Tuple[NamedColumn[Any], ...]
 
     _unnamed: bool
     _traverse_internals: _TraverseInternalsType = [
@@ -3350,12 +3352,15 @@ class Values(roles.InElementRole, HasCTE, Generative, LateralFromClause):
 
     def __init__(
         self,
-        *columns: ColumnClause[Any],
+        *columns: _OnlyColumnArgument[Any],
         name: Optional[str] = None,
         literal_binds: bool = False,
     ):
         super().__init__()
-        self._column_args = columns
+        self._column_args = tuple(
+            coercions.expect(roles.LabeledColumnExprRole, col)
+            for col in columns
+        )
 
         if name is None:
             self._unnamed = True
@@ -3499,7 +3504,7 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]):
 
     def __init__(
         self,
-        columns: Sequence[ColumnClause[Any]],
+        columns: Sequence[NamedColumn[Any]],
         data: Tuple[Sequence[Tuple[Any, ...]], ...],
         literal_binds: bool,
     ):
index 5381a1f07f1fc47ddeee650ab390abb46f5130bc..06363e2c7231d651d3e2b7eafb3e268dba15f0eb 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import update
+from sqlalchemy import values
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -60,3 +61,20 @@ stmt5 = update(User).values({User.id: 123, User.data: "value"})
 stmt6 = user_table.update().values(
     {user_table.c.d: 123, user_table.c.data: "value"}
 )
+
+
+update_values = values(
+    User.id,
+    User.name,
+    name="update_values",
+).data([(1, "Alice"), (2, "Bob")])
+
+query = (
+    update(User)
+    .values(
+        {
+            User.name: update_values.c.name,
+        }
+    )
+    .where(User.id == update_values.c.id)
+)