From 33b7f38586a054e4cdf9bcb81aef54390bf1f5a1 Mon Sep 17 00:00:00 2001 From: Tomasz Nowacki Date: Mon, 4 Mar 2024 15:10:33 +0100 Subject: [PATCH] fix with_loader_criteria and LoaderCriteriaOption --- lib/sqlalchemy/orm/_orm_constructors.py | 5 ++++- lib/sqlalchemy/orm/util.py | 5 ++++- lib/sqlalchemy/sql/_typing.py | 1 - test/typing/plain_files/orm/orm_querying.py | 11 +++++++++++ test/typing/plain_files/sql/common_sql_element.py | 13 ------------- 5 files changed, 19 insertions(+), 16 deletions(-) diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 3a7f826e1d..b3c67a76a4 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -725,7 +725,10 @@ def composite( def with_loader_criteria( entity_or_base: _EntityType[Any], - where_criteria: _ColumnExpressionArgument[bool], + where_criteria: Union[ + _ColumnExpressionArgument[bool], + Callable[[Any], _ColumnExpressionArgument[bool]], + ], loader_only: bool = False, include_aliases: bool = False, propagate_to_loaders: bool = True, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4309cb119e..55c7085f09 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1382,7 +1382,10 @@ class LoaderCriteriaOption(CriteriaOption): def __init__( self, entity_or_base: _EntityType[Any], - where_criteria: _ColumnExpressionArgument[bool], + where_criteria: Union[ + _ColumnExpressionArgument[bool], + Callable[[Any], _ColumnExpressionArgument[bool]], + ], loader_only: bool = False, include_aliases: bool = False, propagate_to_loaders: bool = True, diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index dc05eec25d..689ed19a9f 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -176,7 +176,6 @@ _ColumnExpressionArgument = Union[ "SQLCoreOperations[_T]", roles.ExpressionElementRole[_T], Callable[[], "ColumnElement[_T]"], - Callable[[Any], "ColumnElement[_T]"], "LambdaElement", ] "See docs in public alias ColumnExpressionArgument." diff --git a/test/typing/plain_files/orm/orm_querying.py b/test/typing/plain_files/orm/orm_querying.py index fa59baad43..da78bc72ae 100644 --- a/test/typing/plain_files/orm/orm_querying.py +++ b/test/typing/plain_files/orm/orm_querying.py @@ -1,5 +1,6 @@ from __future__ import annotations +from sqlalchemy import ColumnElement from sqlalchemy import ForeignKey from sqlalchemy import orm from sqlalchemy import select @@ -124,3 +125,13 @@ def load_options_error() -> None: # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .* orm.undefer(B.a).undefer("bar"), ) + + +# test 10959 +def test_10959_with_loader_criteria() -> None: + def where_criteria(cls_: type[A]) -> ColumnElement[bool]: + return cls_.data == "some data" + + orm.with_loader_criteria(A, lambda cls: cls.public == "some data") + + orm.with_loader_criteria(A, where_criteria) diff --git a/test/typing/plain_files/sql/common_sql_element.py b/test/typing/plain_files/sql/common_sql_element.py index d817443af1..730d99bc15 100644 --- a/test/typing/plain_files/sql/common_sql_element.py +++ b/test/typing/plain_files/sql/common_sql_element.py @@ -12,8 +12,6 @@ from __future__ import annotations from sqlalchemy import asc from sqlalchemy import Column from sqlalchemy import column -from sqlalchemy import ColumnElement -from sqlalchemy import ColumnExpressionArgument from sqlalchemy import desc from sqlalchemy import Integer from sqlalchemy import literal @@ -174,14 +172,3 @@ mydict = { literal("5"): "q", column("q"): "q", } - - -# test 10959 -def where_criteria(cls_: type[User]) -> ColumnElement[bool]: - return cls_.email == "test" - - -column_expression: ColumnExpressionArgument[bool] = where_criteria -column_expression_lambda: ColumnExpressionArgument[bool] = ( - lambda cls_: cls_.email == "test" -) -- 2.47.2