]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix `RowMapping`'s `Mapping` type to reflect that it supports `Column`s or strings
authorAndy Freeland <andy@andyfreeland.net>
Thu, 20 Apr 2023 17:41:39 +0000 (13:41 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 20 Apr 2023 18:55:21 +0000 (20:55 +0200)
### Description

I ran into this originally in sqlalchemy2-stubs: https://github.com/sqlalchemy/sqlalchemy2-stubs/pull/251, where `RowMapping` only supported string keys according to the type hints. I ran into a similar issue here upgrading our application where because `RowMapping` subclassed `Mapping[str, Any]`, `Row._mapping.get()` would fail to typecheck when used with `Column` objects.

This patch adds a test to verify that `Row._mapping.get()` continues to work with both strings and `Column`s, though it doesn't look like mypy checks types in the tests.

Fixes #9644.

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [x] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #9643
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9643
Pull-request-sha: 6c33fe534cf457d6b5c73f4830a64880830f0f56

Change-Id: I1009c6defff109d73f13a9e8c51641009e6a79e2

doc/build/changelog/unreleased_20/9644.rst [new file with mode: 0644]
lib/sqlalchemy/engine/row.py
test/ext/mypy/plain_files/typed_results.py
test/sql/test_resultset.py

diff --git a/doc/build/changelog/unreleased_20/9644.rst b/doc/build/changelog/unreleased_20/9644.rst
new file mode 100644 (file)
index 0000000..f40c779
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 9644
+
+    Improved typing of :class:`_engine.RowMapping` to indicate that it
+    support also :class:`_schema.Column` as index objects, not only
+    string names.
+    Pull request curtesy or Andy Freeland.
index e2690ac2d5a3b6c7159e88bbdd8578626d8fbd83..e15ea7b176763e9827c012e85c28720a3575c096 100644 (file)
@@ -271,9 +271,11 @@ class ROMappingView(ABC):
     __slots__ = ()
 
     _items: Sequence[Any]
-    _mapping: Mapping[str, Any]
+    _mapping: Mapping["_KeyType", Any]
 
-    def __init__(self, mapping: Mapping[str, Any], items: Sequence[Any]):
+    def __init__(
+        self, mapping: Mapping["_KeyType", Any], items: Sequence[Any]
+    ):
         self._mapping = mapping
         self._items = items
 
@@ -297,16 +299,16 @@ class ROMappingView(ABC):
 
 
 class ROMappingKeysValuesView(
-    ROMappingView, typing.KeysView[str], typing.ValuesView[Any]
+    ROMappingView, typing.KeysView["_KeyType"], typing.ValuesView[Any]
 ):
     __slots__ = ("_items",)
 
 
-class ROMappingItemsView(ROMappingView, typing.ItemsView[str, Any]):
+class ROMappingItemsView(ROMappingView, typing.ItemsView["_KeyType", Any]):
     __slots__ = ("_items",)
 
 
-class RowMapping(BaseRow, typing.Mapping[str, Any]):
+class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]):
     """A ``Mapping`` that maps column names and objects to :class:`.Row`
     values.
 
index 2e42bb655b024d1534dcd858a21685b209941f91..12bfcddf0cef7a1a4a85f189e514c6d035c84720 100644 (file)
@@ -8,7 +8,10 @@ from sqlalchemy import column
 from sqlalchemy import create_engine
 from sqlalchemy import insert
 from sqlalchemy import Integer
+from sqlalchemy import MetaData
 from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
 from sqlalchemy import table
 from sqlalchemy.ext.asyncio import AsyncConnection
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -31,6 +34,14 @@ class User(Base):
     name: Mapped[str]
 
 
+t_user = Table(
+    "user",
+    MetaData(),
+    Column("id", Integer, primary_key=True),
+    Column("name", String),
+)
+
+
 e = create_engine("sqlite://")
 ae = create_async_engine("sqlite+aiosqlite://")
 
@@ -100,6 +111,11 @@ def t_result_ctxmanager() -> None:
         reveal_type(r4)
 
 
+def t_core_mappings() -> None:
+    r = connection.execute(select(t_user)).mappings().one()
+    r.get(t_user.c.id)
+
+
 def t_entity_varieties() -> None:
 
     a1 = aliased(User)
index 0537dc22819961a52d10d0889686e4433c053c10..e382a7fb66104bb801de13149af7cb4db8571fa8 100644 (file)
@@ -1595,6 +1595,15 @@ class CursorResultTest(fixtures.TablesTest):
         r = connection.exec_driver_sql("select user_name from users").first()
         eq_(len(r), 1)
 
+    def test_row_mapping_get(self, connection):
+        users = self.tables.users
+
+        connection.execute(users.insert(), dict(user_id=1, user_name="foo"))
+        result = connection.execute(users.select())
+        row = result.first()
+        eq_(row._mapping.get("user_id"), 1)
+        eq_(row._mapping.get(users.c.user_id), 1)
+
     def test_sorting_in_python(self, connection):
         users = self.tables.users