]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix `RowMapping`'s `Mapping` type to reflect that it supports `Column`s or strings
authorAndy Freeland <andy@andyfreeland.net>
Wed, 12 Apr 2023 14:58:39 +0000 (07:58 -0700)
committerAndy Freeland <andy@andyfreeland.net>
Thu, 20 Apr 2023 16:33:46 +0000 (09:33 -0700)
I ran into this originally in sqlalchemy2-stubs:
https://github.com/sqlalchemy/sqlalchemy2-stubs/pull/251, where
`RowMapping` only supported string keys according to the type hints. I
ran into a similar issue here upgrading our application where because
`RowMapping` subclassed `Mapping[str, Any]`, `Row._mapping.get()` would
fail to typecheck when used with `Column` objects.

This patch adds a test to verify that `Row._mapping.get()` continues to
work with both strings and `Column`s, though it doesn't look like mypy
checks types in the tests.

Fixes #9644.

lib/sqlalchemy/engine/row.py
test/sql/test_resultset.py

index e2690ac2d5a3b6c7159e88bbdd8578626d8fbd83..e15ea7b176763e9827c012e85c28720a3575c096 100644 (file)
@@ -271,9 +271,11 @@ class ROMappingView(ABC):
     __slots__ = ()
 
     _items: Sequence[Any]
-    _mapping: Mapping[str, Any]
+    _mapping: Mapping["_KeyType", Any]
 
-    def __init__(self, mapping: Mapping[str, Any], items: Sequence[Any]):
+    def __init__(
+        self, mapping: Mapping["_KeyType", Any], items: Sequence[Any]
+    ):
         self._mapping = mapping
         self._items = items
 
@@ -297,16 +299,16 @@ class ROMappingView(ABC):
 
 
 class ROMappingKeysValuesView(
-    ROMappingView, typing.KeysView[str], typing.ValuesView[Any]
+    ROMappingView, typing.KeysView["_KeyType"], typing.ValuesView[Any]
 ):
     __slots__ = ("_items",)
 
 
-class ROMappingItemsView(ROMappingView, typing.ItemsView[str, Any]):
+class ROMappingItemsView(ROMappingView, typing.ItemsView["_KeyType", Any]):
     __slots__ = ("_items",)
 
 
-class RowMapping(BaseRow, typing.Mapping[str, Any]):
+class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]):
     """A ``Mapping`` that maps column names and objects to :class:`.Row`
     values.
 
index 0537dc22819961a52d10d0889686e4433c053c10..e382a7fb66104bb801de13149af7cb4db8571fa8 100644 (file)
@@ -1595,6 +1595,15 @@ class CursorResultTest(fixtures.TablesTest):
         r = connection.exec_driver_sql("select user_name from users").first()
         eq_(len(r), 1)
 
+    def test_row_mapping_get(self, connection):
+        users = self.tables.users
+
+        connection.execute(users.insert(), dict(user_id=1, user_name="foo"))
+        result = connection.execute(users.select())
+        row = result.first()
+        eq_(row._mapping.get("user_id"), 1)
+        eq_(row._mapping.get(users.c.user_id), 1)
+
     def test_sorting_in_python(self, connection):
         users = self.tables.users