]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix with_loader_criteria and LoaderCriteriaOption
authorTomasz Nowacki <t.nowacki87@gmail.com>
Mon, 4 Mar 2024 14:10:33 +0000 (15:10 +0100)
committerTomasz Nowacki <t.nowacki87@gmail.com>
Mon, 4 Mar 2024 14:10:33 +0000 (15:10 +0100)
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/_typing.py
test/typing/plain_files/orm/orm_querying.py
test/typing/plain_files/sql/common_sql_element.py

index 3a7f826e1d19c4a5e1ecdbe401fd7bd73f95a643..b3c67a76a4c2a83484007ff090860259a9ffb7ad 100644 (file)
@@ -725,7 +725,10 @@ def composite(
 
 def with_loader_criteria(
     entity_or_base: _EntityType[Any],
-    where_criteria: _ColumnExpressionArgument[bool],
+    where_criteria: Union[
+        _ColumnExpressionArgument[bool],
+        Callable[[Any], _ColumnExpressionArgument[bool]],
+    ],
     loader_only: bool = False,
     include_aliases: bool = False,
     propagate_to_loaders: bool = True,
index 4309cb119e27afa1d3fbe2cbb21a96b427d023b4..55c7085f0996eb16db6bc2aaba294c7970e2d682 100644 (file)
@@ -1382,7 +1382,10 @@ class LoaderCriteriaOption(CriteriaOption):
     def __init__(
         self,
         entity_or_base: _EntityType[Any],
-        where_criteria: _ColumnExpressionArgument[bool],
+        where_criteria: Union[
+            _ColumnExpressionArgument[bool],
+            Callable[[Any], _ColumnExpressionArgument[bool]],
+        ],
         loader_only: bool = False,
         include_aliases: bool = False,
         propagate_to_loaders: bool = True,
index dc05eec25db264edc9bbebc463b63c2e57ea82b3..689ed19a9f8ce55ace6516b666c139a563c19c94 100644 (file)
@@ -176,7 +176,6 @@ _ColumnExpressionArgument = Union[
     "SQLCoreOperations[_T]",
     roles.ExpressionElementRole[_T],
     Callable[[], "ColumnElement[_T]"],
-    Callable[[Any], "ColumnElement[_T]"],
     "LambdaElement",
 ]
 "See docs in public alias ColumnExpressionArgument."
index fa59baad43a9185d3db3aea550834f0ac3801d79..da78bc72ae1000b85667300eb45f3e8ca9032d88 100644 (file)
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from sqlalchemy import ColumnElement
 from sqlalchemy import ForeignKey
 from sqlalchemy import orm
 from sqlalchemy import select
@@ -124,3 +125,13 @@ def load_options_error() -> None:
         # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
         orm.undefer(B.a).undefer("bar"),
     )
+
+
+# test 10959
+def test_10959_with_loader_criteria() -> None:
+    def where_criteria(cls_: type[A]) -> ColumnElement[bool]:
+        return cls_.data == "some data"
+
+    orm.with_loader_criteria(A, lambda cls: cls.public == "some data")
+
+    orm.with_loader_criteria(A, where_criteria)
index d817443af1e11a3479ecbafe808ea1904c39a380..730d99bc1512470178d29003d90842e9f40a4d08 100644 (file)
@@ -12,8 +12,6 @@ from __future__ import annotations
 from sqlalchemy import asc
 from sqlalchemy import Column
 from sqlalchemy import column
-from sqlalchemy import ColumnElement
-from sqlalchemy import ColumnExpressionArgument
 from sqlalchemy import desc
 from sqlalchemy import Integer
 from sqlalchemy import literal
@@ -174,14 +172,3 @@ mydict = {
     literal("5"): "q",
     column("q"): "q",
 }
-
-
-# test 10959
-def where_criteria(cls_: type[User]) -> ColumnElement[bool]:
-    return cls_.email == "test"
-
-
-column_expression: ColumnExpressionArgument[bool] = where_criteria
-column_expression_lambda: ColumnExpressionArgument[bool] = (
-    lambda cls_: cls_.email == "test"
-)