]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add missing overload to Numeric
authorFederico Caselli <cfederico87@gmail.com>
Tue, 28 Feb 2023 21:44:27 +0000 (22:44 +0100)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 3 Mar 2023 15:40:17 +0000 (10:40 -0500)
Added missing init overload to :class:`_sql.Numeric` to allow
type checkers to properly resolve the type var given the
``asdecimal`` parameter.

this fortunately fixes a glitch in the generate_sql_functions script
also

Fixes: #9391
Change-Id: I9cecc40c52711489e9dbe663f110c3b81c7285e4

doc/build/changelog/unreleased_20/9391.rst [new file with mode: 0644]
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/sqltypes.py
test/ext/mypy/plain_files/functions.py
test/ext/mypy/plain_files/sqltypes.py [new file with mode: 0644]
tools/generate_sql_functions.py

diff --git a/doc/build/changelog/unreleased_20/9391.rst b/doc/build/changelog/unreleased_20/9391.rst
new file mode 100644 (file)
index 0000000..99336a7
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 9391
+
+    Added missing init overload to :class:`_sql.Numeric` to allow
+    type checkers to properly resolve the type var given the
+    ``asdecimal`` parameter.
index 6054be98a7675f55fe9aa192d2da9047b8df47fd..5f2e67288cac6996f950883dcc650cb75472d7fd 100644 (file)
@@ -13,6 +13,7 @@
 from __future__ import annotations
 
 import datetime
+import decimal
 from typing import Any
 from typing import cast
 from typing import Dict
@@ -54,7 +55,6 @@ from .elements import WithinGroup
 from .selectable import FromClause
 from .selectable import Select
 from .selectable import TableValuedAlias
-from .sqltypes import _N
 from .sqltypes import TableValueType
 from .type_api import TypeEngine
 from .visitors import InternalTraversal
@@ -950,7 +950,7 @@ class _FunctionGenerator:
             ...
 
         @property
-        def cume_dist(self) -> Type[cume_dist[Any]]:
+        def cume_dist(self) -> Type[cume_dist]:
             ...
 
         @property
@@ -1014,7 +1014,7 @@ class _FunctionGenerator:
             ...
 
         @property
-        def percent_rank(self) -> Type[percent_rank[Any]]:
+        def percent_rank(self) -> Type[percent_rank]:
             ...
 
         @property
@@ -1703,7 +1703,7 @@ class dense_rank(GenericFunction[int]):
     inherit_cache = True
 
 
-class percent_rank(GenericFunction[_N]):
+class percent_rank(GenericFunction[decimal.Decimal]):
     """Implement the ``percent_rank`` hypothetical-set aggregate function.
 
     This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1715,11 +1715,11 @@ class percent_rank(GenericFunction[_N]):
 
     """
 
-    type: sqltypes.Numeric[_N] = sqltypes.Numeric()
+    type: sqltypes.Numeric[decimal.Decimal] = sqltypes.Numeric()
     inherit_cache = True
 
 
-class cume_dist(GenericFunction[_N]):
+class cume_dist(GenericFunction[decimal.Decimal]):
     """Implement the ``cume_dist`` hypothetical-set aggregate function.
 
     This function must be used with the :meth:`.FunctionElement.within_group`
@@ -1731,7 +1731,7 @@ class cume_dist(GenericFunction[_N]):
 
     """
 
-    type: sqltypes.Numeric[_N] = sqltypes.Numeric()
+    type: sqltypes.Numeric[decimal.Decimal] = sqltypes.Numeric()
     inherit_cache = True
 
 
index 3c6cb0cb558740c6308d01f044d7f074f9352fa4..4583948704f33975beb8d5b84a6536a5a4aaeb5f 100644 (file)
@@ -470,6 +470,26 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]):
 
     _default_decimal_return_scale = 10
 
+    @overload
+    def __init__(
+        self: Numeric[decimal.Decimal],
+        precision: Optional[int] = ...,
+        scale: Optional[int] = ...,
+        decimal_return_scale: Optional[int] = ...,
+        asdecimal: Literal[True] = ...,
+    ):
+        ...
+
+    @overload
+    def __init__(
+        self: Numeric[float],
+        precision: Optional[int] = ...,
+        scale: Optional[int] = ...,
+        decimal_return_scale: Optional[int] = ...,
+        asdecimal: Literal[False] = ...,
+    ):
+        ...
+
     def __init__(
         self,
         precision: Optional[int] = None,
index ecd404010e312de24e537bd3c6a2a5f7033886fd..09c2acf057f83062933e095db18d1ae61e0c3c0e 100644 (file)
@@ -29,7 +29,7 @@ reveal_type(stmt3)
 
 stmt4 = select(func.cume_dist())
 
-# EXPECTED_RE_TYPE: .*Select\[Any\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
 reveal_type(stmt4)
 
 
@@ -89,7 +89,7 @@ reveal_type(stmt13)
 
 stmt14 = select(func.percent_rank())
 
-# EXPECTED_RE_TYPE: .*Select\[Any\]
+# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\]
 reveal_type(stmt14)
 
 
diff --git a/test/ext/mypy/plain_files/sqltypes.py b/test/ext/mypy/plain_files/sqltypes.py
new file mode 100644 (file)
index 0000000..230cb95
--- /dev/null
@@ -0,0 +1,12 @@
+from sqlalchemy import Float
+from sqlalchemy import Numeric
+
+# EXPECTED_TYPE: Float[float]
+reveal_type(Float())
+# EXPECTED_TYPE: Float[Decimal]
+reveal_type(Float(asdecimal=True))
+
+# EXPECTED_TYPE: Numeric[Decimal]
+reveal_type(Numeric())
+# EXPECTED_TYPE: Numeric[float]
+reveal_type(Numeric(asdecimal=False))
index d207c62bcd98a2f0e01c916639a478b2d1692bc5..794b8448792d58dd28699762da4d30eab93754ef 100644 (file)
@@ -5,12 +5,10 @@
 
 from __future__ import annotations
 
-from decimal import Decimal
 import inspect
 import re
 from tempfile import NamedTemporaryFile
 import textwrap
-from typing import Any
 
 from sqlalchemy.sql.functions import _registry
 from sqlalchemy.types import TypeEngine
@@ -99,17 +97,7 @@ def {key}(self) -> Type[{fn_class.__name__}{
                         fn_class.type, TypeEngine
                     ):
                         python_type = fn_class.type.python_type
-
-                        # TODO: numeric types don't seem to be coming out
-                        # at the moment, because Numeric is typed generically
-                        # in that it can return Decimal or float. We would need
-                        # to further break out Numeric / Float into types
-                        # that type out as returning an exact Decimal or float
-                        if python_type is Decimal:
-                            python_type = Any
-                            python_expr = f"{python_type.__name__}"
-                        else:
-                            python_expr = rf"Tuple\[.*{python_type.__name__}\]"
+                        python_expr = rf"Tuple\[.*{python_type.__name__}\]"
                         argspec = inspect.getfullargspec(fn_class)
                         args = ", ".join(
                             'column("x")' for elem in argspec.args[1:]