From c106bc43046a2bbfd5894cba9f2789bf4c197b01 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 2 Jan 2024 13:03:40 -0500 Subject: [PATCH] allow literals for function arguments * Fixed the argument types passed to functions so that literal expressions like strings and ints are again interpreted correctly (:ticket:`10818`) this includes a reformatting of the changelog message from #10801 to read as a general "fixed regressions" list. Fixes: #10818 Change-Id: I65ad86e096241863e833608d45f0bdb6069f5896 (cherry picked from commit cc26af00e7483289cb2c2fb7c03e2d0c8fb63362) --- doc/build/changelog/unreleased_20/10801.rst | 15 ++- lib/sqlalchemy/sql/functions.py | 115 +++++++++++++----- .../typing/plain_files/sql/functions_again.py | 13 ++ tools/generate_sql_functions.py | 25 +++- 4 files changed, 129 insertions(+), 39 deletions(-) diff --git a/doc/build/changelog/unreleased_20/10801.rst b/doc/build/changelog/unreleased_20/10801.rst index a35a5485d5..a485e1babb 100644 --- a/doc/build/changelog/unreleased_20/10801.rst +++ b/doc/build/changelog/unreleased_20/10801.rst @@ -1,7 +1,14 @@ .. change:: :tags: bug, typing - :tickets: 10801 + :tickets: 10801, 10818 + + Fixed regressions caused by typing added to the ``sqlalchemy.sql.functions`` + module in version 2.0.24, as part of :ticket:`6810`: + + * Further enhancements to pep-484 typing to allow SQL functions from + :attr:`_sql.func` derived elements to work more effectively with ORM-mapped + attributes (:ticket:`10801`) + + * Fixed the argument types passed to functions so that literal expressions + like strings and ints are again interpreted correctly (:ticket:`10818`) - Further enhancements to pep-484 typing to allow SQL functions from - :attr:`_sql.func` derived elements to work more effectively with ORM-mapped - attributes. diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index dfa6f9df5c..5cb5812d69 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -999,14 +999,16 @@ class _FunctionGenerator: def char_length(self) -> Type[char_length]: ... - # appease mypy which seems to not want to accept _T from - # _ColumnExpressionArgument, as it includes non-generic types + # set ColumnElement[_T] as a separate overload, to appease mypy + # which seems to not want to accept _T from _ColumnExpressionArgument. + # this is even if all non-generic types are removed from it, so + # reasons remain unclear for why this does not work @overload def coalesce( self, col: ColumnElement[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> coalesce[_T]: ... @@ -1015,15 +1017,24 @@ class _FunctionGenerator: def coalesce( self, col: _ColumnExpressionArgument[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> coalesce[_T]: ... + @overload def coalesce( self, - col: _ColumnExpressionArgument[_T], - *args: _ColumnExpressionArgument[Any], + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> coalesce[_T]: + ... + + def coalesce( + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> coalesce[_T]: ... @@ -1080,14 +1091,16 @@ class _FunctionGenerator: def localtimestamp(self) -> Type[localtimestamp]: ... - # appease mypy which seems to not want to accept _T from - # _ColumnExpressionArgument, as it includes non-generic types + # set ColumnElement[_T] as a separate overload, to appease mypy + # which seems to not want to accept _T from _ColumnExpressionArgument. + # this is even if all non-generic types are removed from it, so + # reasons remain unclear for why this does not work @overload def max( # noqa: A001 self, col: ColumnElement[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> max[_T]: ... @@ -1096,27 +1109,38 @@ class _FunctionGenerator: def max( # noqa: A001 self, col: _ColumnExpressionArgument[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> max[_T]: ... + @overload def max( # noqa: A001 self, - col: _ColumnExpressionArgument[_T], - *args: _ColumnExpressionArgument[Any], + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> max[_T]: + ... + + def max( # noqa: A001 + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> max[_T]: ... - # appease mypy which seems to not want to accept _T from - # _ColumnExpressionArgument, as it includes non-generic types + # set ColumnElement[_T] as a separate overload, to appease mypy + # which seems to not want to accept _T from _ColumnExpressionArgument. + # this is even if all non-generic types are removed from it, so + # reasons remain unclear for why this does not work @overload def min( # noqa: A001 self, col: ColumnElement[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> min[_T]: ... @@ -1125,15 +1149,24 @@ class _FunctionGenerator: def min( # noqa: A001 self, col: _ColumnExpressionArgument[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> min[_T]: ... + @overload def min( # noqa: A001 self, - col: _ColumnExpressionArgument[_T], - *args: _ColumnExpressionArgument[Any], + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> min[_T]: + ... + + def min( # noqa: A001 + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> min[_T]: ... @@ -1182,14 +1215,16 @@ class _FunctionGenerator: def session_user(self) -> Type[session_user]: ... - # appease mypy which seems to not want to accept _T from - # _ColumnExpressionArgument, as it includes non-generic types + # set ColumnElement[_T] as a separate overload, to appease mypy + # which seems to not want to accept _T from _ColumnExpressionArgument. + # this is even if all non-generic types are removed from it, so + # reasons remain unclear for why this does not work @overload def sum( # noqa: A001 self, col: ColumnElement[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> sum[_T]: ... @@ -1198,15 +1233,24 @@ class _FunctionGenerator: def sum( # noqa: A001 self, col: _ColumnExpressionArgument[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> sum[_T]: ... + @overload def sum( # noqa: A001 self, - col: _ColumnExpressionArgument[_T], - *args: _ColumnExpressionArgument[Any], + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ) -> sum[_T]: + ... + + def sum( # noqa: A001 + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> sum[_T]: ... @@ -1576,14 +1620,16 @@ class ReturnTypeFromArgs(GenericFunction[_T]): inherit_cache = True - # appease mypy which seems to not want to accept _T from - # _ColumnExpressionArgument, as it includes non-generic types + # set ColumnElement[_T] as a separate overload, to appease mypy which seems + # to not want to accept _T from _ColumnExpressionArgument. this is even if + # all non-generic types are removed from it, so reasons remain unclear for + # why this does not work @overload def __init__( self, col: ColumnElement[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ): ... @@ -1592,12 +1638,23 @@ class ReturnTypeFromArgs(GenericFunction[_T]): def __init__( self, col: _ColumnExpressionArgument[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ): ... - def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any): + @overload + def __init__( + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, + ): + ... + + def __init__( + self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any + ): fn_args: Sequence[ColumnElement[Any]] = [ coercions.expect( roles.ExpressionElementRole, diff --git a/test/typing/plain_files/sql/functions_again.py b/test/typing/plain_files/sql/functions_again.py index 87ade92246..da656f2d1d 100644 --- a/test/typing/plain_files/sql/functions_again.py +++ b/test/typing/plain_files/sql/functions_again.py @@ -15,6 +15,7 @@ class Foo(Base): id: Mapped[int] = mapped_column(primary_key=True) a: Mapped[int] b: Mapped[int] + c: Mapped[str] func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc()) @@ -41,3 +42,15 @@ stmt1 = select( ).group_by(Foo.a) # EXPECTED_TYPE: Select[Tuple[int, int]] reveal_type(stmt1) + +# test #10818 +# EXPECTED_TYPE: coalesce[str] +reveal_type(func.coalesce(Foo.c, "a", "b")) + + +stmt2 = select( + Foo.a, + func.coalesce(Foo.c, "a", "b"), +).group_by(Foo.a) +# EXPECTED_TYPE: Select[Tuple[int, str]] +reveal_type(stmt2) diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py index 348b334484..51422dc7e6 100644 --- a/tools/generate_sql_functions.py +++ b/tools/generate_sql_functions.py @@ -62,14 +62,16 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str: textwrap.indent( f""" -# appease mypy which seems to not want to accept _T from -# _ColumnExpressionArgument, as it includes non-generic types +# set ColumnElement[_T] as a separate overload, to appease mypy +# which seems to not want to accept _T from _ColumnExpressionArgument. +# this is even if all non-generic types are removed from it, so +# reasons remain unclear for why this does not work @overload def {key}( {' # noqa: A001' if is_reserved_word else ''} self, col: ColumnElement[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> {fn_class.__name__}[_T]: ... @@ -78,15 +80,26 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''} def {key}( {' # noqa: A001' if is_reserved_word else ''} self, col: _ColumnExpressionArgument[_T], - *args: _ColumnExpressionArgument[Any], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> {fn_class.__name__}[_T]: ... + +@overload def {key}( {' # noqa: A001' if is_reserved_word else ''} self, - col: _ColumnExpressionArgument[_T], - *args: _ColumnExpressionArgument[Any], + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], + **kwargs: Any, +) -> {fn_class.__name__}[_T]: + ... + + +def {key}( {' # noqa: A001' if is_reserved_word else ''} + self, + col: _ColumnExpressionOrLiteralArgument[_T], + *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, ) -> {fn_class.__name__}[_T]: ... -- 2.47.2