]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Merge "Fixes: #10933 typing in ColumnExpressionArgument" into main
authorFederico Caselli <cfederico87@gmail.com>
Mon, 11 Mar 2024 22:10:55 +0000 (22:10 +0000)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 11 Mar 2024 22:13:01 +0000 (23:13 +0100)
(cherry picked from commit 716189460f69a9f44dce3af1d47eab4560def86b)

lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/lambdas.py
test/typing/plain_files/orm/orm_querying.py

index f74de91c1d9771dfe24b57f641c1ceda95a92d3d..b9f618af0d7059ee5182e7801979159c4e37dd47 100644 (file)
@@ -716,7 +716,10 @@ def composite(
 
 def with_loader_criteria(
     entity_or_base: _EntityType[Any],
-    where_criteria: _ColumnExpressionArgument[bool],
+    where_criteria: Union[
+        _ColumnExpressionArgument[bool],
+        Callable[[Any], _ColumnExpressionArgument[bool]],
+    ],
     loader_only: bool = False,
     include_aliases: bool = False,
     propagate_to_loaders: bool = True,
index d2e3f5302aeab5b1082aff050ccc86b29026f920..f8431386e4ee5571543c1fba30042c16db437e7c 100644 (file)
@@ -1378,7 +1378,10 @@ class LoaderCriteriaOption(CriteriaOption):
     def __init__(
         self,
         entity_or_base: _EntityType[Any],
-        where_criteria: _ColumnExpressionArgument[bool],
+        where_criteria: Union[
+            _ColumnExpressionArgument[bool],
+            Callable[[Any], _ColumnExpressionArgument[bool]],
+        ],
         loader_only: bool = False,
         include_aliases: bool = False,
         propagate_to_loaders: bool = True,
index 726fa2411f86b779badec9d0c1073ecb7de54df2..7a6b7b8f776a34e2448c27d6c9dbd2df16563e5f 100644 (file)
@@ -437,7 +437,7 @@ class DeferredLambdaElement(LambdaElement):
 
     def __init__(
         self,
-        fn: _LambdaType,
+        fn: _AnyLambdaType,
         role: Type[roles.SQLRole],
         opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions,
         lambda_args: Tuple[Any, ...] = (),
index fa59baad43a9185d3db3aea550834f0ac3801d79..3251147dd6876e04f4263762af2902512a5b4991 100644 (file)
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from sqlalchemy import ColumnElement
 from sqlalchemy import ForeignKey
 from sqlalchemy import orm
 from sqlalchemy import select
@@ -124,3 +125,12 @@ def load_options_error() -> None:
         # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
         orm.undefer(B.a).undefer("bar"),
     )
+
+
+# test 10959
+def test_10959_with_loader_criteria() -> None:
+    def where_criteria(cls_: type[A]) -> ColumnElement[bool]:
+        return cls_.data == "some data"
+
+    orm.with_loader_criteria(A, lambda cls: cls.data == "some data")
+    orm.with_loader_criteria(A, where_criteria)