]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixes: #11328: Add overload for ColumnCollection.get(col, default)
authorMark Elliot <123787712+mark-thm@users.noreply.github.com>
Sun, 28 Apr 2024 00:32:32 +0000 (20:32 -0400)
committerMark Elliot <123787712+mark-thm@users.noreply.github.com>
Sun, 28 Apr 2024 00:34:56 +0000 (20:34 -0400)
lib/sqlalchemy/sql/base.py
test/sql/test_utils.py

index 923e8495899946405c8d3768d2efffc7c33e62aa..d84ee50b8c540d2c2ac8c1037939fd1f87e30b60 100644 (file)
@@ -1641,6 +1641,9 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
     def __eq__(self, other: Any) -> bool:
         return self.compare(other)
 
+    @overload
+    def get(self, key: str, default: _COL_co) -> _COL_co: ...
+
     def get(
         self, key: str, default: Optional[_COL_co] = None
     ) -> Optional[_COL_co]:
index 74cf1eb4f2e12b3bbc47779bc56dc1722cba9724..b45a972af97f7688c78aa868b227965d082e7b26 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy.sql import coercions
 from sqlalchemy.sql import column
 from sqlalchemy.sql import ColumnElement
 from sqlalchemy.sql import roles
+from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
@@ -174,3 +175,15 @@ class MiscTest(fixtures.TestBase):
 
         for a, b in zip_longest(unwrapped, expected):
             assert a is not None and a.compare(b)
+
+    def test_column_collection_get(self):
+        col_id = column("id", Integer)
+        col_alt = column("alt", Integer)
+        table1 = table(
+            "mytable",
+            col_id,
+        )
+
+        assert table1.columns.get("id") == col_id
+        assert table1.columns.get("alt") is None
+        assert table1.columns.get("alt", col_alt) == col_alt