]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
allow literals for function arguments
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Jan 2024 18:03:40 +0000 (13:03 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Jan 2024 00:07:34 +0000 (19:07 -0500)
* Fixed the argument types passed to functions so that literal expressions
like strings and ints are again interpreted correctly (:ticket:`10818`)

this includes a reformatting of the changelog message from #10801
to read as a general "fixed regressions" list.

Fixes: #10818
Change-Id: I65ad86e096241863e833608d45f0bdb6069f5896

doc/build/changelog/unreleased_20/10801.rst
lib/sqlalchemy/sql/functions.py
test/typing/plain_files/sql/functions_again.py
tools/generate_sql_functions.py

index a35a5485d58868adc95f626a55ab7b66551e736d..a485e1babbab656b2eec9b15eee885dc3372fa93 100644 (file)
@@ -1,7 +1,14 @@
 .. change::
     :tags: bug, typing
-    :tickets: 10801
+    :tickets: 10801, 10818
+
+    Fixed regressions caused by typing added to the ``sqlalchemy.sql.functions``
+    module in version 2.0.24, as part of :ticket:`6810`:
+
+    * Further enhancements to pep-484 typing to allow SQL functions from
+      :attr:`_sql.func` derived elements to work more effectively with ORM-mapped
+      attributes (:ticket:`10801`)
+
+    * Fixed the argument types passed to functions so that literal expressions
+      like strings and ints are again interpreted correctly (:ticket:`10818`)
 
-    Further enhancements to pep-484 typing to allow SQL functions from
-    :attr:`_sql.func` derived elements to work more effectively with ORM-mapped
-    attributes.
index dfa6f9df5caedc282d3cc6249d3a45a0992820ec..5cb5812d692431eab9417d45d43fa0ca9930dc68 100644 (file)
@@ -999,14 +999,16 @@ class _FunctionGenerator:
         def char_length(self) -> Type[char_length]:
             ...
 
-        # appease mypy which seems to not want to accept _T from
-        # _ColumnExpressionArgument, as it includes non-generic types
+        # set ColumnElement[_T] as a separate overload, to appease mypy
+        # which seems to not want to accept _T from _ColumnExpressionArgument.
+        # this is even if all non-generic types are removed from it, so
+        # reasons remain unclear for why this does not work
 
         @overload
         def coalesce(
             self,
             col: ColumnElement[_T],
-            *args: _ColumnExpressionArgument[Any],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> coalesce[_T]:
             ...
@@ -1015,15 +1017,24 @@ class _FunctionGenerator:
         def coalesce(
             self,
             col: _ColumnExpressionArgument[_T],
-            *args: _ColumnExpressionArgument[Any],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> coalesce[_T]:
             ...
 
+        @overload
         def coalesce(
             self,
-            col: _ColumnExpressionArgument[_T],
-            *args: _ColumnExpressionArgument[Any],
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> coalesce[_T]:
+            ...
+
+        def coalesce(
+            self,
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> coalesce[_T]:
             ...
@@ -1080,14 +1091,16 @@ class _FunctionGenerator:
         def localtimestamp(self) -> Type[localtimestamp]:
             ...
 
-        # appease mypy which seems to not want to accept _T from
-        # _ColumnExpressionArgument, as it includes non-generic types
+        # set ColumnElement[_T] as a separate overload, to appease mypy
+        # which seems to not want to accept _T from _ColumnExpressionArgument.
+        # this is even if all non-generic types are removed from it, so
+        # reasons remain unclear for why this does not work
 
         @overload
         def max(  # noqa: A001
             self,
             col: ColumnElement[_T],
-            *args: _ColumnExpressionArgument[Any],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> max[_T]:
             ...
@@ -1096,27 +1109,38 @@ class _FunctionGenerator:
         def max(  # noqa: A001
             self,
             col: _ColumnExpressionArgument[_T],
-            *args: _ColumnExpressionArgument[Any],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> max[_T]:
             ...
 
+        @overload
         def max(  # noqa: A001
             self,
-            col: _ColumnExpressionArgument[_T],
-            *args: _ColumnExpressionArgument[Any],
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> max[_T]:
+            ...
+
+        def max(  # noqa: A001
+            self,
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> max[_T]:
             ...
 
-        # appease mypy which seems to not want to accept _T from
-        # _ColumnExpressionArgument, as it includes non-generic types
+        # set ColumnElement[_T] as a separate overload, to appease mypy
+        # which seems to not want to accept _T from _ColumnExpressionArgument.
+        # this is even if all non-generic types are removed from it, so
+        # reasons remain unclear for why this does not work
 
         @overload
         def min(  # noqa: A001
             self,
             col: ColumnElement[_T],
-            *args: _ColumnExpressionArgument[Any],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> min[_T]:
             ...
@@ -1125,15 +1149,24 @@ class _FunctionGenerator:
         def min(  # noqa: A001
             self,
             col: _ColumnExpressionArgument[_T],
-            *args: _ColumnExpressionArgument[Any],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> min[_T]:
             ...
 
+        @overload
         def min(  # noqa: A001
             self,
-            col: _ColumnExpressionArgument[_T],
-            *args: _ColumnExpressionArgument[Any],
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> min[_T]:
+            ...
+
+        def min(  # noqa: A001
+            self,
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> min[_T]:
             ...
@@ -1182,14 +1215,16 @@ class _FunctionGenerator:
         def session_user(self) -> Type[session_user]:
             ...
 
-        # appease mypy which seems to not want to accept _T from
-        # _ColumnExpressionArgument, as it includes non-generic types
+        # set ColumnElement[_T] as a separate overload, to appease mypy
+        # which seems to not want to accept _T from _ColumnExpressionArgument.
+        # this is even if all non-generic types are removed from it, so
+        # reasons remain unclear for why this does not work
 
         @overload
         def sum(  # noqa: A001
             self,
             col: ColumnElement[_T],
-            *args: _ColumnExpressionArgument[Any],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> sum[_T]:
             ...
@@ -1198,15 +1233,24 @@ class _FunctionGenerator:
         def sum(  # noqa: A001
             self,
             col: _ColumnExpressionArgument[_T],
-            *args: _ColumnExpressionArgument[Any],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> sum[_T]:
             ...
 
+        @overload
         def sum(  # noqa: A001
             self,
-            col: _ColumnExpressionArgument[_T],
-            *args: _ColumnExpressionArgument[Any],
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
+            **kwargs: Any,
+        ) -> sum[_T]:
+            ...
+
+        def sum(  # noqa: A001
+            self,
+            col: _ColumnExpressionOrLiteralArgument[_T],
+            *args: _ColumnExpressionOrLiteralArgument[Any],
             **kwargs: Any,
         ) -> sum[_T]:
             ...
@@ -1576,14 +1620,16 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
 
     inherit_cache = True
 
-    # appease mypy which seems to not want to accept _T from
-    # _ColumnExpressionArgument, as it includes non-generic types
+    # set ColumnElement[_T] as a separate overload, to appease mypy which seems
+    # to not want to accept _T from _ColumnExpressionArgument.  this is even if
+    # all non-generic types are removed from it, so reasons remain unclear for
+    # why this does not work
 
     @overload
     def __init__(
         self,
         col: ColumnElement[_T],
-        *args: _ColumnExpressionArgument[Any],
+        *args: _ColumnExpressionOrLiteralArgument[Any],
         **kwargs: Any,
     ):
         ...
@@ -1592,12 +1638,23 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
     def __init__(
         self,
         col: _ColumnExpressionArgument[_T],
-        *args: _ColumnExpressionArgument[Any],
+        *args: _ColumnExpressionOrLiteralArgument[Any],
         **kwargs: Any,
     ):
         ...
 
-    def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
+    @overload
+    def __init__(
+        self,
+        col: _ColumnExpressionOrLiteralArgument[_T],
+        *args: _ColumnExpressionOrLiteralArgument[Any],
+        **kwargs: Any,
+    ):
+        ...
+
+    def __init__(
+        self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any
+    ):
         fn_args: Sequence[ColumnElement[Any]] = [
             coercions.expect(
                 roles.ExpressionElementRole,
index 87ade922468259c1df632c7490ace4502eba951e..da656f2d1d98395787902cbb560b6406739fa87d 100644 (file)
@@ -15,6 +15,7 @@ class Foo(Base):
     id: Mapped[int] = mapped_column(primary_key=True)
     a: Mapped[int]
     b: Mapped[int]
+    c: Mapped[str]
 
 
 func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc())
@@ -41,3 +42,15 @@ stmt1 = select(
 ).group_by(Foo.a)
 # EXPECTED_TYPE: Select[Tuple[int, int]]
 reveal_type(stmt1)
+
+# test #10818
+# EXPECTED_TYPE: coalesce[str]
+reveal_type(func.coalesce(Foo.c, "a", "b"))
+
+
+stmt2 = select(
+    Foo.a,
+    func.coalesce(Foo.c, "a", "b"),
+).group_by(Foo.a)
+# EXPECTED_TYPE: Select[Tuple[int, str]]
+reveal_type(stmt2)
index 348b3344845e18d820efc6cf6f2274e88c3d6974..51422dc7e6b835a514246b5ee78c7eb4d05605f5 100644 (file)
@@ -62,14 +62,16 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str:
                             textwrap.indent(
                                 f"""
 
-# appease mypy which seems to not want to accept _T from
-# _ColumnExpressionArgument, as it includes non-generic types
+# set ColumnElement[_T] as a separate overload, to appease mypy
+# which seems to not want to accept _T from _ColumnExpressionArgument.
+# this is even if all non-generic types are removed from it, so
+# reasons remain unclear for why this does not work
 
 @overload
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
     self,
     col: ColumnElement[_T],
-    *args: _ColumnExpressionArgument[Any],
+    *args: _ColumnExpressionOrLiteralArgument[Any],
     **kwargs: Any,
 ) -> {fn_class.__name__}[_T]:
     ...
@@ -78,15 +80,26 @@ def {key}( {'  # noqa: A001' if is_reserved_word else ''}
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
     self,
     col: _ColumnExpressionArgument[_T],
-    *args: _ColumnExpressionArgument[Any],
+    *args: _ColumnExpressionOrLiteralArgument[Any],
     **kwargs: Any,
 ) -> {fn_class.__name__}[_T]:
         ...
 
+
+@overload
 def {key}( {'  # noqa: A001' if is_reserved_word else ''}
     self,
-    col: _ColumnExpressionArgument[_T],
-    *args: _ColumnExpressionArgument[Any],
+    col: _ColumnExpressionOrLiteralArgument[_T],
+    *args: _ColumnExpressionOrLiteralArgument[Any],
+    **kwargs: Any,
+) -> {fn_class.__name__}[_T]:
+        ...
+
+
+def {key}( {'  # noqa: A001' if is_reserved_word else ''}
+    self,
+    col: _ColumnExpressionOrLiteralArgument[_T],
+    *args: _ColumnExpressionOrLiteralArgument[Any],
     **kwargs: Any,
 ) -> {fn_class.__name__}[_T]:
     ...