]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add overload for ColumnCollection.get(col, default)
authorMark Elliot <123787712+mark-thm@users.noreply.github.com>
Mon, 29 Apr 2024 21:50:10 +0000 (17:50 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Sat, 4 May 2024 09:33:36 +0000 (11:33 +0200)
### Description
Fixes #11328 by adding an overload to ColumnCollection when a non-None default is provided.

### Checklist
This pull request is:

- [ ] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed
- [x] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

Closes: #11329
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11329
Pull-request-sha: 32db849e0df1db357df79df3a0dc2263a755d04e

Change-Id: I8bef91c423fb7048ec8d4a7c99f70f0b1588c37a
(cherry picked from commit ab6df37dad5cccbd0328e83ed55c7cfed91344cb)

lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql/base.py
test/sql/test_utils.py
test/typing/plain_files/sql/misc.py [new file with mode: 0644]
test/typing/plain_files/sql/selectables.py [deleted file]

index adee44a77e182e2dc9bfadb4d87b07e5e006fc5c..5c49222be1536900b4c12019489b4af496cf2f5c 100644 (file)
@@ -689,7 +689,7 @@ class MappedColumn(
             supercls_mapper = class_mapper(decl_scan.inherits, False)
 
             colname = column.name if column.name is not None else key
-            column = self.column = supercls_mapper.local_table.c.get(  # type: ignore # noqa: E501
+            column = self.column = supercls_mapper.local_table.c.get(  # type: ignore[assignment] # noqa: E501
                 colname, column
             )
 
index 1a65b653ea2dc36a6acf5d874ed02c12a1803d8f..8ad17e2c1a4f0f9e4b481cccf4f151c2bf1c112d 100644 (file)
@@ -72,7 +72,6 @@ if TYPE_CHECKING:
     from .elements import ClauseList
     from .elements import ColumnClause  # noqa
     from .elements import ColumnElement
-    from .elements import KeyedColumnElement
     from .elements import NamedColumn
     from .elements import SQLCoreOperations
     from .elements import TextClause
@@ -1354,7 +1353,7 @@ class _SentinelColumnCharacterization(NamedTuple):
 _COLKEY = TypeVar("_COLKEY", Union[None, str], str)
 
 _COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True)
-_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]")
+_COL = TypeVar("_COL", bound="ColumnElement[Any]")
 
 
 class _ColumnMetrics(Generic[_COL_co]):
@@ -1642,9 +1641,15 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]):
     def __eq__(self, other: Any) -> bool:
         return self.compare(other)
 
+    @overload
+    def get(self, key: str, default: None = None) -> Optional[_COL_co]: ...
+
+    @overload
+    def get(self, key: str, default: _COL) -> Union[_COL_co, _COL]: ...
+
     def get(
-        self, key: str, default: Optional[_COL_co] = None
-    ) -> Optional[_COL_co]:
+        self, key: str, default: Optional[_COL] = None
+    ) -> Optional[Union[_COL_co, _COL]]:
         """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
         based on a string key name from this
         :class:`_expression.ColumnCollection`."""
@@ -1925,16 +1930,15 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
 
     """
 
-    def add(
-        self, column: ColumnElement[Any], key: Optional[str] = None
+    def add(  # type: ignore[override]
+        self, column: _NAMEDCOL, key: Optional[str] = None
     ) -> None:
-        named_column = cast(_NAMEDCOL, column)
-        if key is not None and named_column.key != key:
+        if key is not None and column.key != key:
             raise exc.ArgumentError(
                 "DedupeColumnCollection requires columns be under "
                 "the same key as their .key"
             )
-        key = named_column.key
+        key = column.key
 
         if key is None:
             raise exc.ArgumentError(
@@ -1944,17 +1948,17 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]):
         if key in self._index:
             existing = self._index[key][1]
 
-            if existing is named_column:
+            if existing is column:
                 return
 
-            self.replace(named_column)
+            self.replace(column)
 
             # pop out memoized proxy_set as this
             # operation may very well be occurring
             # in a _make_proxy operation
-            util.memoized_property.reset(named_column, "proxy_set")
+            util.memoized_property.reset(column, "proxy_set")
         else:
-            self._append_new_column(key, named_column)
+            self._append_new_column(key, column)
 
     def _append_new_column(self, key: str, named_column: _NAMEDCOL) -> None:
         l = len(self._collection)
index 74cf1eb4f2e12b3bbc47779bc56dc1722cba9724..b741d5d8c0b566a73a428a8903fac9622c673d80 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy.sql import coercions
 from sqlalchemy.sql import column
 from sqlalchemy.sql import ColumnElement
 from sqlalchemy.sql import roles
+from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
@@ -174,3 +175,12 @@ class MiscTest(fixtures.TestBase):
 
         for a, b in zip_longest(unwrapped, expected):
             assert a is not None and a.compare(b)
+
+    def test_column_collection_get(self):
+        col_id = column("id", Integer)
+        col_alt = column("alt", Integer)
+        table1 = table("mytable", col_id)
+
+        is_(table1.columns.get("id"), col_id)
+        is_(table1.columns.get("alt"), None)
+        is_(table1.columns.get("alt", col_alt), col_alt)
diff --git a/test/typing/plain_files/sql/misc.py b/test/typing/plain_files/sql/misc.py
new file mode 100644 (file)
index 0000000..d598af0
--- /dev/null
@@ -0,0 +1,37 @@
+from typing import Any
+
+from sqlalchemy import column
+from sqlalchemy import ColumnElement
+from sqlalchemy import Integer
+from sqlalchemy import literal
+from sqlalchemy import table
+
+
+def test_col_accessors() -> None:
+    t = table("t", column("a"), column("b"), column("c"))
+
+    t.c.a
+    t.c["a"]
+
+    t.c[2]
+    t.c[0, 1]
+    t.c[0, 1, "b", "c"]
+    t.c[(0, 1, "b", "c")]
+
+    t.c[:-1]
+    t.c[0:2]
+
+
+def test_col_get() -> None:
+    col_id = column("id", Integer)
+    col_alt = column("alt", Integer)
+    tbl = table("mytable", col_id)
+
+    # EXPECTED_TYPE: Union[ColumnClause[Any], None]
+    reveal_type(tbl.c.get("id"))
+    # EXPECTED_TYPE: Union[ColumnClause[Any], None]
+    reveal_type(tbl.c.get("id", None))
+    # EXPECTED_TYPE: Union[ColumnClause[Any], ColumnClause[int]]
+    reveal_type(tbl.c.get("alt", col_alt))
+    col: ColumnElement[Any] = tbl.c.get("foo", literal("bar"))
+    print(col)
diff --git a/test/typing/plain_files/sql/selectables.py b/test/typing/plain_files/sql/selectables.py
deleted file mode 100644 (file)
index 7d31124..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-from sqlalchemy import column
-from sqlalchemy import table
-
-
-def test_col_accessors() -> None:
-    t = table("t", column("a"), column("b"), column("c"))
-
-    t.c.a
-    t.c["a"]
-
-    t.c[2]
-    t.c[0, 1]
-    t.c[0, 1, "b", "c"]
-    t.c[(0, 1, "b", "c")]
-
-    t.c[:-1]
-    t.c[0:2]