From c39dc697c1598c4a6a934dc0b5a60a0eaae6555d Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Tue, 28 Feb 2023 22:44:27 +0100 Subject: [PATCH] Add missing overload to Numeric Added missing init overload to :class:`_sql.Numeric` to allow type checkers to properly resolve the type var given the ``asdecimal`` parameter. this fortunately fixes a glitch in the generate_sql_functions script also Fixes: #9391 Change-Id: I9cecc40c52711489e9dbe663f110c3b81c7285e4 --- doc/build/changelog/unreleased_20/9391.rst | 7 +++++++ lib/sqlalchemy/sql/functions.py | 14 +++++++------- lib/sqlalchemy/sql/sqltypes.py | 20 ++++++++++++++++++++ test/ext/mypy/plain_files/functions.py | 4 ++-- test/ext/mypy/plain_files/sqltypes.py | 12 ++++++++++++ tools/generate_sql_functions.py | 14 +------------- 6 files changed, 49 insertions(+), 22 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/9391.rst create mode 100644 test/ext/mypy/plain_files/sqltypes.py diff --git a/doc/build/changelog/unreleased_20/9391.rst b/doc/build/changelog/unreleased_20/9391.rst new file mode 100644 index 0000000000..99336a71c7 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9391.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, typing + :tickets: 9391 + + Added missing init overload to :class:`_sql.Numeric` to allow + type checkers to properly resolve the type var given the + ``asdecimal`` parameter. diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 6054be98a7..5f2e67288c 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -13,6 +13,7 @@ from __future__ import annotations import datetime +import decimal from typing import Any from typing import cast from typing import Dict @@ -54,7 +55,6 @@ from .elements import WithinGroup from .selectable import FromClause from .selectable import Select from .selectable import TableValuedAlias -from .sqltypes import _N from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import InternalTraversal @@ -950,7 +950,7 @@ class _FunctionGenerator: ... @property - def cume_dist(self) -> Type[cume_dist[Any]]: + def cume_dist(self) -> Type[cume_dist]: ... @property @@ -1014,7 +1014,7 @@ class _FunctionGenerator: ... @property - def percent_rank(self) -> Type[percent_rank[Any]]: + def percent_rank(self) -> Type[percent_rank]: ... @property @@ -1703,7 +1703,7 @@ class dense_rank(GenericFunction[int]): inherit_cache = True -class percent_rank(GenericFunction[_N]): +class percent_rank(GenericFunction[decimal.Decimal]): """Implement the ``percent_rank`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1715,11 +1715,11 @@ class percent_rank(GenericFunction[_N]): """ - type: sqltypes.Numeric[_N] = sqltypes.Numeric() + type: sqltypes.Numeric[decimal.Decimal] = sqltypes.Numeric() inherit_cache = True -class cume_dist(GenericFunction[_N]): +class cume_dist(GenericFunction[decimal.Decimal]): """Implement the ``cume_dist`` hypothetical-set aggregate function. This function must be used with the :meth:`.FunctionElement.within_group` @@ -1731,7 +1731,7 @@ class cume_dist(GenericFunction[_N]): """ - type: sqltypes.Numeric[_N] = sqltypes.Numeric() + type: sqltypes.Numeric[decimal.Decimal] = sqltypes.Numeric() inherit_cache = True diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 3c6cb0cb55..4583948704 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -470,6 +470,26 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]): _default_decimal_return_scale = 10 + @overload + def __init__( + self: Numeric[decimal.Decimal], + precision: Optional[int] = ..., + scale: Optional[int] = ..., + decimal_return_scale: Optional[int] = ..., + asdecimal: Literal[True] = ..., + ): + ... + + @overload + def __init__( + self: Numeric[float], + precision: Optional[int] = ..., + scale: Optional[int] = ..., + decimal_return_scale: Optional[int] = ..., + asdecimal: Literal[False] = ..., + ): + ... + def __init__( self, precision: Optional[int] = None, diff --git a/test/ext/mypy/plain_files/functions.py b/test/ext/mypy/plain_files/functions.py index ecd404010e..09c2acf057 100644 --- a/test/ext/mypy/plain_files/functions.py +++ b/test/ext/mypy/plain_files/functions.py @@ -29,7 +29,7 @@ reveal_type(stmt3) stmt4 = select(func.cume_dist()) -# EXPECTED_RE_TYPE: .*Select\[Any\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt4) @@ -89,7 +89,7 @@ reveal_type(stmt13) stmt14 = select(func.percent_rank()) -# EXPECTED_RE_TYPE: .*Select\[Any\] +# EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt14) diff --git a/test/ext/mypy/plain_files/sqltypes.py b/test/ext/mypy/plain_files/sqltypes.py new file mode 100644 index 0000000000..230cb957d4 --- /dev/null +++ b/test/ext/mypy/plain_files/sqltypes.py @@ -0,0 +1,12 @@ +from sqlalchemy import Float +from sqlalchemy import Numeric + +# EXPECTED_TYPE: Float[float] +reveal_type(Float()) +# EXPECTED_TYPE: Float[Decimal] +reveal_type(Float(asdecimal=True)) + +# EXPECTED_TYPE: Numeric[Decimal] +reveal_type(Numeric()) +# EXPECTED_TYPE: Numeric[float] +reveal_type(Numeric(asdecimal=False)) diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py index d207c62bcd..794b844879 100644 --- a/tools/generate_sql_functions.py +++ b/tools/generate_sql_functions.py @@ -5,12 +5,10 @@ from __future__ import annotations -from decimal import Decimal import inspect import re from tempfile import NamedTemporaryFile import textwrap -from typing import Any from sqlalchemy.sql.functions import _registry from sqlalchemy.types import TypeEngine @@ -99,17 +97,7 @@ def {key}(self) -> Type[{fn_class.__name__}{ fn_class.type, TypeEngine ): python_type = fn_class.type.python_type - - # TODO: numeric types don't seem to be coming out - # at the moment, because Numeric is typed generically - # in that it can return Decimal or float. We would need - # to further break out Numeric / Float into types - # that type out as returning an exact Decimal or float - if python_type is Decimal: - python_type = Any - python_expr = f"{python_type.__name__}" - else: - python_expr = rf"Tuple\[.*{python_type.__name__}\]" + python_expr = rf"Tuple\[.*{python_type.__name__}\]" argspec = inspect.getfullargspec(fn_class) args = ", ".join( 'column("x")' for elem in argspec.args[1:] -- 2.47.2