From: Mark Elliot <123787712+mark-thm@users.noreply.github.com> Date: Mon, 29 Apr 2024 21:50:10 +0000 (-0400) Subject: Add overload for ColumnCollection.get(col, default) X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ab6df37dad5cccbd0328e83ed55c7cfed91344cb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add overload for ColumnCollection.get(col, default) ### Description Fixes #11328 by adding an overload to ColumnCollection when a non-None default is provided. ### Checklist This pull request is: - [ ] A documentation / typographical / small typing error fix - Good to go, no issue or tests are needed - [x] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. Closes: #11329 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11329 Pull-request-sha: 32db849e0df1db357df79df3a0dc2263a755d04e Change-Id: I8bef91c423fb7048ec8d4a7c99f70f0b1588c37a --- diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index adee44a77e..5c49222be1 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -689,7 +689,7 @@ class MappedColumn( supercls_mapper = class_mapper(decl_scan.inherits, False) colname = column.name if column.name is not None else key - column = self.column = supercls_mapper.local_table.c.get( # type: ignore # noqa: E501 + column = self.column = supercls_mapper.local_table.c.get( # type: ignore[assignment] # noqa: E501 colname, column ) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 923e849589..96a9337f48 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -71,7 +71,6 @@ if TYPE_CHECKING: from .elements import ClauseList from .elements import ColumnClause # noqa from .elements import ColumnElement - from .elements import KeyedColumnElement from .elements import NamedColumn from .elements import SQLCoreOperations from .elements import TextClause @@ -1353,7 +1352,7 @@ class _SentinelColumnCharacterization(NamedTuple): _COLKEY = TypeVar("_COLKEY", Union[None, str], str) _COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True) -_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]") +_COL = TypeVar("_COL", bound="ColumnElement[Any]") class _ColumnMetrics(Generic[_COL_co]): @@ -1641,9 +1640,15 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): def __eq__(self, other: Any) -> bool: return self.compare(other) + @overload + def get(self, key: str, default: None = None) -> Optional[_COL_co]: ... + + @overload + def get(self, key: str, default: _COL) -> Union[_COL_co, _COL]: ... + def get( - self, key: str, default: Optional[_COL_co] = None - ) -> Optional[_COL_co]: + self, key: str, default: Optional[_COL] = None + ) -> Optional[Union[_COL_co, _COL]]: """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object based on a string key name from this :class:`_expression.ColumnCollection`.""" @@ -1924,16 +1929,15 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): """ - def add( - self, column: ColumnElement[Any], key: Optional[str] = None + def add( # type: ignore[override] + self, column: _NAMEDCOL, key: Optional[str] = None ) -> None: - named_column = cast(_NAMEDCOL, column) - if key is not None and named_column.key != key: + if key is not None and column.key != key: raise exc.ArgumentError( "DedupeColumnCollection requires columns be under " "the same key as their .key" ) - key = named_column.key + key = column.key if key is None: raise exc.ArgumentError( @@ -1943,17 +1947,17 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): if key in self._index: existing = self._index[key][1] - if existing is named_column: + if existing is column: return - self.replace(named_column) + self.replace(column) # pop out memoized proxy_set as this # operation may very well be occurring # in a _make_proxy operation - util.memoized_property.reset(named_column, "proxy_set") + util.memoized_property.reset(column, "proxy_set") else: - self._append_new_column(key, named_column) + self._append_new_column(key, column) def _append_new_column(self, key: str, named_column: _NAMEDCOL) -> None: l = len(self._collection) diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index 74cf1eb4f2..b741d5d8c0 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -14,6 +14,7 @@ from sqlalchemy.sql import coercions from sqlalchemy.sql import column from sqlalchemy.sql import ColumnElement from sqlalchemy.sql import roles +from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -174,3 +175,12 @@ class MiscTest(fixtures.TestBase): for a, b in zip_longest(unwrapped, expected): assert a is not None and a.compare(b) + + def test_column_collection_get(self): + col_id = column("id", Integer) + col_alt = column("alt", Integer) + table1 = table("mytable", col_id) + + is_(table1.columns.get("id"), col_id) + is_(table1.columns.get("alt"), None) + is_(table1.columns.get("alt", col_alt), col_alt) diff --git a/test/typing/plain_files/sql/misc.py b/test/typing/plain_files/sql/misc.py new file mode 100644 index 0000000000..d598af06ef --- /dev/null +++ b/test/typing/plain_files/sql/misc.py @@ -0,0 +1,37 @@ +from typing import Any + +from sqlalchemy import column +from sqlalchemy import ColumnElement +from sqlalchemy import Integer +from sqlalchemy import literal +from sqlalchemy import table + + +def test_col_accessors() -> None: + t = table("t", column("a"), column("b"), column("c")) + + t.c.a + t.c["a"] + + t.c[2] + t.c[0, 1] + t.c[0, 1, "b", "c"] + t.c[(0, 1, "b", "c")] + + t.c[:-1] + t.c[0:2] + + +def test_col_get() -> None: + col_id = column("id", Integer) + col_alt = column("alt", Integer) + tbl = table("mytable", col_id) + + # EXPECTED_TYPE: Union[ColumnClause[Any], None] + reveal_type(tbl.c.get("id")) + # EXPECTED_TYPE: Union[ColumnClause[Any], None] + reveal_type(tbl.c.get("id", None)) + # EXPECTED_TYPE: Union[ColumnClause[Any], ColumnClause[int]] + reveal_type(tbl.c.get("alt", col_alt)) + col: ColumnElement[Any] = tbl.c.get("foo", literal("bar")) + print(col) diff --git a/test/typing/plain_files/sql/selectables.py b/test/typing/plain_files/sql/selectables.py deleted file mode 100644 index 7d31124587..0000000000 --- a/test/typing/plain_files/sql/selectables.py +++ /dev/null @@ -1,17 +0,0 @@ -from sqlalchemy import column -from sqlalchemy import table - - -def test_col_accessors() -> None: - t = table("t", column("a"), column("b"), column("c")) - - t.c.a - t.c["a"] - - t.c[2] - t.c[0, 1] - t.c[0, 1, "b", "c"] - t.c[(0, 1, "b", "c")] - - t.c[:-1] - t.c[0:2]