]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
improve rowmapping key type
authorFederico Caselli <cfederico87@gmail.com>
Tue, 25 Feb 2025 22:06:55 +0000 (23:06 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 26 Feb 2025 19:32:09 +0000 (20:32 +0100)
the accepted keys are also orm attributes, column elements, functions
etc, not only columns

Change-Id: I354de9b9668bc02b8b305a3c1f065744b28f8030

lib/sqlalchemy/engine/result.py
lib/sqlalchemy/orm/mapper.py
test/typing/plain_files/sql/typed_results.py

index dfe7a617888b82a154951a9629932472cae2739a..d550d8c44165be01b7e5a264da46b97c4b4b4e74 100644 (file)
@@ -51,11 +51,11 @@ from ..util.typing import TypeVarTuple
 from ..util.typing import Unpack
 
 if typing.TYPE_CHECKING:
-    from ..sql.schema import Column
+    from ..sql.elements import SQLCoreOperations
     from ..sql.type_api import _ResultProcessorType
 
-_KeyType = Union[str, "Column[Any]"]
-_KeyIndexType = Union[str, "Column[Any]", int]
+_KeyType = Union[str, "SQLCoreOperations[Any]"]
+_KeyIndexType = Union[_KeyType, int]
 
 # is overridden in cursor using _CursorKeyMapRecType
 _KeyMapRecType = Any
index d879b6dbdafc0057323069f90c008077b41bfe11..3c6821d365683ff99f329fe1ecac3d81c0d23e71 100644 (file)
@@ -3442,7 +3442,7 @@ class Mapper(
 
     def identity_key_from_row(
         self,
-        row: Optional[Union[Row[Unpack[TupleAny]], RowMapping]],
+        row: Union[Row[Unpack[TupleAny]], RowMapping],
         identity_token: Optional[Any] = None,
         adapter: Optional[ORMAdapter] = None,
     ) -> _IdentityKeyType[_O]:
@@ -3461,14 +3461,15 @@ class Mapper(
         if adapter:
             pk_cols = [adapter.columns[c] for c in pk_cols]
 
+        mapping: RowMapping
         if hasattr(row, "_mapping"):
-            mapping = row._mapping  # type: ignore
+            mapping = row._mapping
         else:
-            mapping = cast("Mapping[Any, Any]", row)
+            mapping = row  # type: ignore[assignment]
 
         return (
             self._identity_class,
-            tuple(mapping[column] for column in pk_cols),  # type: ignore
+            tuple(mapping[column] for column in pk_cols),
             identity_token,
         )
 
index 498d2d276a430b4175c4673e14f0740da02ac742..c6c0816cb985601e7b85f6e747df4a9c12236e9a 100644 (file)
@@ -8,6 +8,7 @@ from typing import Type
 from sqlalchemy import Column
 from sqlalchemy import column
 from sqlalchemy import create_engine
+from sqlalchemy import func
 from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
@@ -117,9 +118,22 @@ def t_result_ctxmanager() -> None:
         reveal_type(r4)
 
 
-def t_core_mappings() -> None:
+def t_mappings() -> None:
     r = connection.execute(select(t_user)).mappings().one()
-    r.get(t_user.c.id)
+    r["name"]  # string
+    r.get(t_user.c.id)  # column
+
+    r2 = connection.execute(select(User)).mappings().one()
+    r2[User.id]  # orm attribute
+    r2[User.__table__.c.id]  # form clause column
+
+    m2 = User.id * 2
+    s2 = User.__table__.c.id + 2
+    fn = func.abs(User.id)
+    r3 = connection.execute(select(m2, s2, fn)).mappings().one()
+    r3[m2]  # col element
+    r3[s2]  # also col element
+    r3[fn]  # function
 
 
 def t_entity_varieties() -> None: