]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
[typing] Fix type error when passing Mapped columns to values()
authorYossi <54272821+Apakottur@users.noreply.github.com>
Mon, 1 Dec 2025 17:06:12 +0000 (12:06 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Dec 2025 03:37:57 +0000 (22:37 -0500)
This adjusts the _DMLOnlyColumnArgument type to be a more
focused _OnlyColumnArgument type where we also add a more tightly
focused coercion, while still allowing ORM attributes to be used
as arguments.

Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Closes: #13012
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13012
Pull-request-sha: 5ebb402c686abf1090e5b83e3489dfca4908efdf

Change-Id: I8bbccaf556ec5ecb2f5cfdd2030bcfa4eb5ce125

lib/sqlalchemy/sql/_elements_constructors.py
lib/sqlalchemy/sql/_selectable_constructors.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
test/typing/plain_files/sql/dml.py

index 6fddb590b54807664da7de489c975753912ff5ef..354db4e90335b98487ff5bc04581732e5d1e5623 100644 (file)
@@ -59,7 +59,7 @@ if typing.TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
     from ._typing import _ColumnExpressionOrLiteralArgument
     from ._typing import _ColumnExpressionOrStrLabelArgument
-    from ._typing import _DMLOnlyColumnArgument
+    from ._typing import _OnlyColumnArgument
     from ._typing import _TypeEngineArgument
     from .elements import _FrameIntTuple
     from .elements import BinaryExpression
@@ -490,7 +490,7 @@ def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]:
     return coercions.expect(roles.ExpressionElementRole, clause).__invert__()
 
 
-def from_dml_column(column: _DMLOnlyColumnArgument[_T]) -> DMLTargetCopy[_T]:
+def from_dml_column(column: _OnlyColumnArgument[_T]) -> DMLTargetCopy[_T]:
     r"""A placeholder that may be used in compiled INSERT or UPDATE expressions
     to refer to the SQL expression or value being applied to another column.
 
index 129806204bb65690309a48b8a62d1853db4eb86b..3aa2de1bd45dfdfdb624b26db6edd000d790ce74 100644 (file)
@@ -36,6 +36,7 @@ from ..util.typing import Unpack
 if TYPE_CHECKING:
     from ._typing import _FromClauseArgument
     from ._typing import _OnClauseArgument
+    from ._typing import _OnlyColumnArgument
     from ._typing import _SelectStatementForCompoundArgument
     from ._typing import _T0
     from ._typing import _T1
@@ -684,14 +685,13 @@ def union_all(
 
 
 def values(
-    *columns: ColumnClause[Any],
+    *columns: _OnlyColumnArgument[Any],
     name: Optional[str] = None,
     literal_binds: bool = False,
 ) -> Values:
     r"""Construct a :class:`_expression.Values` construct representing the
     SQL ``VALUES`` clause.
 
-
     The column expressions and the actual data for :class:`_expression.Values`
     are given in two separate steps.  The constructor receives the column
     expressions typically as :func:`_expression.column` constructs, and the
index b4af798dbd92e06722a6a59db6fa5b8a2cf20fac..71ad7e13eefebee989ca021244b3c52c798b1f20 100644 (file)
@@ -187,6 +187,19 @@ _T8 = TypeVar("_T8", bound=Any)
 _T9 = TypeVar("_T9", bound=Any)
 
 
+_OnlyColumnArgument = Union[
+    "ColumnElement[_T]",
+    _HasClauseElement[_T],
+    roles.DMLColumnRole,
+]
+"""A narrow type that is looking for a ColumnClause (e.g. table column with a
+name) or an ORM element that produces this.
+
+This is used for constructs that need a named column to represent a
+position in a selectable, like TextClause().columns() or values(...).
+
+"""
+
 _ColumnExpressionArgument = Union[
     "ColumnElement[_T]",
     _HasClauseElement[_T],
@@ -274,12 +287,6 @@ the DMLColumnRole to be able to accommodate.
 
 """
 
-_DMLOnlyColumnArgument = Union[
-    _HasClauseElement[_T],
-    roles.DMLColumnRole,
-    "SQLCoreOperations[_T]",
-]
-
 
 _DMLKey = TypeVar("_DMLKey", bound=_DMLColumnArgument)
 _DMLColumnKeyMapping = Mapping[_DMLKey, Any]
index 7d510564227e7b5dbf86f64c9c86d3723f78adae..f22b08a08af64c2f60f07def6999b4b043b4cdcf 100644 (file)
@@ -52,6 +52,7 @@ if typing.TYPE_CHECKING:
     from ._typing import _DDLColumnArgument
     from ._typing import _DMLTableArgument
     from ._typing import _FromClauseArgument
+    from ._typing import _OnlyColumnArgument
     from .base import SyntaxExtension
     from .dml import _DMLTableElement
     from .elements import BindParameter
@@ -221,7 +222,7 @@ def expect(
 @overload
 def expect(
     role: Type[roles.LabeledColumnExprRole[Any]],
-    element: _ColumnExpressionArgument[_T],
+    element: _ColumnExpressionArgument[_T] | _OnlyColumnArgument[_T],
     **kw: Any,
 ) -> NamedColumn[_T]: ...
 
index 8f0d7e0a283b4f54a3b269f6a756bf7d738725c6..674560d7e190ff63f1c095781add7b6ce17e7f83 100644 (file)
@@ -91,9 +91,9 @@ if typing.TYPE_CHECKING:
     from ._typing import _ByArgument
     from ._typing import _ColumnExpressionArgument
     from ._typing import _ColumnExpressionOrStrLabelArgument
-    from ._typing import _DMLOnlyColumnArgument
     from ._typing import _HasDialect
     from ._typing import _InfoType
+    from ._typing import _OnlyColumnArgument
     from ._typing import _PropagateAttrsType
     from ._typing import _TypeEngineArgument
     from .base import _EntityNamespace
@@ -2002,7 +2002,7 @@ class DMLTargetCopy(roles.InElementRole, KeyedColumnElement[_T]):
 
     """
 
-    def __init__(self, column: _DMLOnlyColumnArgument[_T]):
+    def __init__(self, column: _OnlyColumnArgument[_T]):
         self.column = coercions.expect(roles.ColumnArgumentRole, column)
         self.type = self.column.type
 
@@ -2401,7 +2401,7 @@ class AbstractTextClause(
     @util.preload_module("sqlalchemy.sql.selectable")
     def columns(
         self,
-        *cols: _ColumnExpressionArgument[Any],
+        *cols: _OnlyColumnArgument[Any],
         **types: _TypeEngineArgument[Any],
     ) -> TextualSelect:
         r"""Turn this :class:`_expression.AbstractTextClause` object into a
index 5342cd012add630bf173436b9bea4876a4e33661..8113944caaeb94da600de5f72615d7b8ed190298 100644 (file)
@@ -123,6 +123,7 @@ if TYPE_CHECKING:
     from ._typing import _MAYBE_ENTITY
     from ._typing import _NOT_ENTITY
     from ._typing import _OnClauseArgument
+    from ._typing import _OnlyColumnArgument
     from ._typing import _SelectStatementForCompoundArgument
     from ._typing import _T0
     from ._typing import _T1
@@ -3355,6 +3356,7 @@ class Values(roles.InElementRole, HasCTE, Generative, LateralFromClause):
     __visit_name__ = "values"
 
     _data: Tuple[Sequence[Tuple[Any, ...]], ...] = ()
+    _column_args: Tuple[NamedColumn[Any], ...]
 
     _unnamed: bool
     _traverse_internals: _TraverseInternalsType = [
@@ -3368,12 +3370,15 @@ class Values(roles.InElementRole, HasCTE, Generative, LateralFromClause):
 
     def __init__(
         self,
-        *columns: ColumnClause[Any],
+        *columns: _OnlyColumnArgument[Any],
         name: Optional[str] = None,
         literal_binds: bool = False,
     ):
         super().__init__()
-        self._column_args = columns
+        self._column_args = tuple(
+            coercions.expect(roles.LabeledColumnExprRole, col)
+            for col in columns
+        )
 
         if name is None:
             self._unnamed = True
@@ -3517,7 +3522,7 @@ class ScalarValues(roles.InElementRole, GroupedElement, ColumnElement[Any]):
 
     def __init__(
         self,
-        columns: Sequence[ColumnClause[Any]],
+        columns: Sequence[NamedColumn[Any]],
         data: Tuple[Sequence[Tuple[Any, ...]], ...],
         literal_binds: bool,
     ):
index 5381a1f07f1fc47ddeee650ab390abb46f5130bc..06363e2c7231d651d3e2b7eafb3e268dba15f0eb 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import update
+from sqlalchemy import values
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
@@ -60,3 +61,20 @@ stmt5 = update(User).values({User.id: 123, User.data: "value"})
 stmt6 = user_table.update().values(
     {user_table.c.d: 123, user_table.c.data: "value"}
 )
+
+
+update_values = values(
+    User.id,
+    User.name,
+    name="update_values",
+).data([(1, "Alice"), (2, "Bob")])
+
+query = (
+    update(User)
+    .values(
+        {
+            User.name: update_values.c.name,
+        }
+    )
+    .where(User.id == update_values.c.id)
+)