From: Mark Elliot <123787712+mark-thm@users.noreply.github.com> Date: Sun, 28 Apr 2024 00:32:32 +0000 (-0400) Subject: Fixes: #11328: Add overload for ColumnCollection.get(col, default) X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=039f996decf59343ed6ed227379bee6897e1cc3f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fixes: #11328: Add overload for ColumnCollection.get(col, default) --- diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 923e849589..d84ee50b8c 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -1641,6 +1641,9 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): def __eq__(self, other: Any) -> bool: return self.compare(other) + @overload + def get(self, key: str, default: _COL_co) -> _COL_co: ... + def get( self, key: str, default: Optional[_COL_co] = None ) -> Optional[_COL_co]: diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index 74cf1eb4f2..b45a972af9 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -14,6 +14,7 @@ from sqlalchemy.sql import coercions from sqlalchemy.sql import column from sqlalchemy.sql import ColumnElement from sqlalchemy.sql import roles +from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -174,3 +175,15 @@ class MiscTest(fixtures.TestBase): for a, b in zip_longest(unwrapped, expected): assert a is not None and a.compare(b) + + def test_column_collection_get(self): + col_id = column("id", Integer) + col_alt = column("alt", Integer) + table1 = table( + "mytable", + col_id, + ) + + assert table1.columns.get("id") == col_id + assert table1.columns.get("alt") is None + assert table1.columns.get("alt", col_alt) == col_alt