From: Federico Caselli Date: Thu, 12 Mar 2026 22:34:27 +0000 (+0100) Subject: ensure function classes are not shadowed X-Git-Tag: rel_2_0_49~16^2 X-Git-Url: http://git.ipfire.org/gitweb/?a=commitdiff_plain;h=9873053ec311b150ba7cfb20913e6ea03af87c1f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git ensure function classes are not shadowed Ensure the _FunctionGenerator method do not shadow the function class of the same name Fixed a typing issue where the typed members of :data:`.func` would return the appropriate class of the same name, however this creates an issue for typecheckers such as Zuban and pyrefly that assume :pep:`749` style typechecking even if the file states that it's a :pep:`563` file; they see the returned name as indicating the method object and not the class object. These typecheckers are actually following along with an upcoming test harness that insists on :pep:`749` style name resolution for this case unconditionally. Since :pep:`749` is the way of the future regardless, differently-named type aliases have been added for these return types. Fixes: #13167 Change-Id: If58a3858001c78ab21b2ed343205dfd9ce868576 (cherry picked from commit 0a185a3bb6347719ffab60012db8fbbc23eb29e4) --- diff --git a/doc/build/changelog/unreleased_20/13167.rst b/doc/build/changelog/unreleased_20/13167.rst new file mode 100644 index 0000000000..e874b406d0 --- /dev/null +++ b/doc/build/changelog/unreleased_20/13167.rst @@ -0,0 +1,14 @@ +.. change:: + :tags: bug, typing + :tickets: 13167 + + Fixed a typing issue where the typed members of :data:`.func` would return + the appropriate class of the same name, however this creates an issue for + typecheckers such as Zuban and pyrefly that assume :pep:`749` style + typechecking even if the file states that it's a :pep:`563` file; they see + the returned name as indicating the method object and not the class object. + These typecheckers are actually following along with an upcoming test + harness that insists on :pep:`749` style name resolution for this case + unconditionally. Since :pep:`749` is the way of the future regardless, + differently-named type aliases have been added for these return types. + diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 663ced0e43..498caf134f 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -59,7 +59,7 @@ from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import InternalTraversal from .. import util - +from ..util.typing import TypeAlias if TYPE_CHECKING: from ._typing import _ByArgument @@ -1003,10 +1003,10 @@ class _FunctionGenerator: # statically generated** by tools/generate_sql_functions.py @property - def aggregate_strings(self) -> Type[aggregate_strings]: ... + def aggregate_strings(self) -> Type[_aggregate_strings_func]: ... @property - def ansifunction(self) -> Type[AnsiFunction[Any]]: ... + def ansifunction(self) -> Type[_AnsiFunction_func[Any]]: ... # set ColumnElement[_T] as a separate overload, to appease # mypy which seems to not want to accept _T from @@ -1019,7 +1019,7 @@ class _FunctionGenerator: col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> array_agg[_T]: ... + ) -> _array_agg_func[_T]: ... @overload def array_agg( @@ -1027,7 +1027,7 @@ class _FunctionGenerator: col: _ColumnExpressionArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> array_agg[_T]: ... + ) -> _array_agg_func[_T]: ... @overload def array_agg( @@ -1035,20 +1035,20 @@ class _FunctionGenerator: col: _T, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> array_agg[_T]: ... + ) -> _array_agg_func[_T]: ... def array_agg( self, col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> array_agg[_T]: ... + ) -> _array_agg_func[_T]: ... @property - def cast(self) -> Type[Cast[Any]]: ... + def cast(self) -> Type[_Cast_func[Any]]: ... @property - def char_length(self) -> Type[char_length]: ... + def char_length(self) -> Type[_char_length_func]: ... # set ColumnElement[_T] as a separate overload, to appease # mypy which seems to not want to accept _T from @@ -1061,7 +1061,7 @@ class _FunctionGenerator: col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> coalesce[_T]: ... + ) -> _coalesce_func[_T]: ... @overload def coalesce( @@ -1069,7 +1069,7 @@ class _FunctionGenerator: col: _ColumnExpressionArgument[Optional[_T]], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> coalesce[_T]: ... + ) -> _coalesce_func[_T]: ... @overload def coalesce( @@ -1077,53 +1077,53 @@ class _FunctionGenerator: col: Optional[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> coalesce[_T]: ... + ) -> _coalesce_func[_T]: ... def coalesce( self, col: _ColumnExpressionOrLiteralArgument[Optional[_T]], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> coalesce[_T]: ... + ) -> _coalesce_func[_T]: ... @property - def concat(self) -> Type[concat]: ... + def concat(self) -> Type[_concat_func]: ... @property - def count(self) -> Type[count]: ... + def count(self) -> Type[_count_func]: ... @property - def cube(self) -> Type[cube[Any]]: ... + def cube(self) -> Type[_cube_func[Any]]: ... @property - def cume_dist(self) -> Type[cume_dist]: ... + def cume_dist(self) -> Type[_cume_dist_func]: ... @property - def current_date(self) -> Type[current_date]: ... + def current_date(self) -> Type[_current_date_func]: ... @property - def current_time(self) -> Type[current_time]: ... + def current_time(self) -> Type[_current_time_func]: ... @property - def current_timestamp(self) -> Type[current_timestamp]: ... + def current_timestamp(self) -> Type[_current_timestamp_func]: ... @property - def current_user(self) -> Type[current_user]: ... + def current_user(self) -> Type[_current_user_func]: ... @property - def dense_rank(self) -> Type[dense_rank]: ... + def dense_rank(self) -> Type[_dense_rank_func]: ... @property - def extract(self) -> Type[Extract]: ... + def extract(self) -> Type[_Extract_func]: ... @property - def grouping_sets(self) -> Type[grouping_sets[Any]]: ... + def grouping_sets(self) -> Type[_grouping_sets_func[Any]]: ... @property - def localtime(self) -> Type[localtime]: ... + def localtime(self) -> Type[_localtime_func]: ... @property - def localtimestamp(self) -> Type[localtimestamp]: ... + def localtimestamp(self) -> Type[_localtimestamp_func]: ... # set ColumnElement[_T] as a separate overload, to appease # mypy which seems to not want to accept _T from @@ -1136,7 +1136,7 @@ class _FunctionGenerator: col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> max[_T]: ... + ) -> _max_func[_T]: ... @overload def max( # noqa: A001 @@ -1144,7 +1144,7 @@ class _FunctionGenerator: col: _ColumnExpressionArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> max[_T]: ... + ) -> _max_func[_T]: ... @overload def max( # noqa: A001 @@ -1152,14 +1152,14 @@ class _FunctionGenerator: col: _T, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> max[_T]: ... + ) -> _max_func[_T]: ... def max( # noqa: A001 self, col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> max[_T]: ... + ) -> _max_func[_T]: ... # set ColumnElement[_T] as a separate overload, to appease # mypy which seems to not want to accept _T from @@ -1172,7 +1172,7 @@ class _FunctionGenerator: col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> min[_T]: ... + ) -> _min_func[_T]: ... @overload def min( # noqa: A001 @@ -1180,7 +1180,7 @@ class _FunctionGenerator: col: _ColumnExpressionArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> min[_T]: ... + ) -> _min_func[_T]: ... @overload def min( # noqa: A001 @@ -1188,47 +1188,47 @@ class _FunctionGenerator: col: _T, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> min[_T]: ... + ) -> _min_func[_T]: ... def min( # noqa: A001 self, col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> min[_T]: ... + ) -> _min_func[_T]: ... @property - def mode(self) -> Type[mode[Any]]: ... + def mode(self) -> Type[_mode_func[Any]]: ... @property - def next_value(self) -> Type[next_value]: ... + def next_value(self) -> Type[_next_value_func]: ... @property - def now(self) -> Type[now]: ... + def now(self) -> Type[_now_func]: ... @property - def orderedsetagg(self) -> Type[OrderedSetAgg[Any]]: ... + def orderedsetagg(self) -> Type[_OrderedSetAgg_func[Any]]: ... @property - def percent_rank(self) -> Type[percent_rank]: ... + def percent_rank(self) -> Type[_percent_rank_func]: ... @property - def percentile_cont(self) -> Type[percentile_cont[Any]]: ... + def percentile_cont(self) -> Type[_percentile_cont_func[Any]]: ... @property - def percentile_disc(self) -> Type[percentile_disc[Any]]: ... + def percentile_disc(self) -> Type[_percentile_disc_func[Any]]: ... @property - def random(self) -> Type[random]: ... + def random(self) -> Type[_random_func]: ... @property - def rank(self) -> Type[rank]: ... + def rank(self) -> Type[_rank_func]: ... @property - def rollup(self) -> Type[rollup[Any]]: ... + def rollup(self) -> Type[_rollup_func[Any]]: ... @property - def session_user(self) -> Type[session_user]: ... + def session_user(self) -> Type[_session_user_func]: ... # set ColumnElement[_T] as a separate overload, to appease # mypy which seems to not want to accept _T from @@ -1241,7 +1241,7 @@ class _FunctionGenerator: col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> sum[_T]: ... + ) -> _sum_func[_T]: ... @overload def sum( # noqa: A001 @@ -1249,7 +1249,7 @@ class _FunctionGenerator: col: _ColumnExpressionArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> sum[_T]: ... + ) -> _sum_func[_T]: ... @overload def sum( # noqa: A001 @@ -1257,20 +1257,20 @@ class _FunctionGenerator: col: _T, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> sum[_T]: ... + ) -> _sum_func[_T]: ... def sum( # noqa: A001 self, col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> sum[_T]: ... + ) -> _sum_func[_T]: ... @property - def sysdate(self) -> Type[sysdate]: ... + def sysdate(self) -> Type[_sysdate_func]: ... @property - def user(self) -> Type[user]: ... + def user(self) -> Type[_user_func]: ... # END GENERATED FUNCTION ACCESSORS @@ -2156,3 +2156,44 @@ class aggregate_strings(GenericFunction[str]): self, clause: _ColumnExpressionArgument[Any], separator: str ) -> None: super().__init__(clause, separator) + + +# These aliases are required to avoid shadowing the class with the function +# name. See https://github.com/sqlalchemy/sqlalchemy/issues/13167 +# START GENERATED FUNCTION ALIASES +_aggregate_strings_func: TypeAlias = aggregate_strings +_AnsiFunction_func: TypeAlias = AnsiFunction[_T] +_array_agg_func: TypeAlias = array_agg[_T] +_Cast_func: TypeAlias = Cast[_T] +_char_length_func: TypeAlias = char_length +_coalesce_func: TypeAlias = coalesce[_T] +_concat_func: TypeAlias = concat +_count_func: TypeAlias = count +_cube_func: TypeAlias = cube[_T] +_cume_dist_func: TypeAlias = cume_dist +_current_date_func: TypeAlias = current_date +_current_time_func: TypeAlias = current_time +_current_timestamp_func: TypeAlias = current_timestamp +_current_user_func: TypeAlias = current_user +_dense_rank_func: TypeAlias = dense_rank +_Extract_func: TypeAlias = Extract +_grouping_sets_func: TypeAlias = grouping_sets[_T] +_localtime_func: TypeAlias = localtime +_localtimestamp_func: TypeAlias = localtimestamp +_max_func: TypeAlias = max[_T] +_min_func: TypeAlias = min[_T] +_mode_func: TypeAlias = mode[_T] +_next_value_func: TypeAlias = next_value +_now_func: TypeAlias = now +_OrderedSetAgg_func: TypeAlias = OrderedSetAgg[_T] +_percent_rank_func: TypeAlias = percent_rank +_percentile_cont_func: TypeAlias = percentile_cont[_T] +_percentile_disc_func: TypeAlias = percentile_disc[_T] +_random_func: TypeAlias = random +_rank_func: TypeAlias = rank +_rollup_func: TypeAlias = rollup[_T] +_session_user_func: TypeAlias = session_user +_sum_func: TypeAlias = sum[_T] +_sysdate_func: TypeAlias = sysdate +_user_func: TypeAlias = user +# END GENERATED FUNCTION ALIASES diff --git a/test/typing/plain_files/sql/functions.py b/test/typing/plain_files/sql/functions.py index e1cea4193e..afa9297abe 100644 --- a/test/typing/plain_files/sql/functions.py +++ b/test/typing/plain_files/sql/functions.py @@ -1,5 +1,6 @@ """this file is generated by tools/generate_sql_functions.py""" +from typing import assert_type from typing import Tuple from sqlalchemy import column @@ -9,152 +10,271 @@ from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import String +from sqlalchemy.sql import functions # START GENERATED FUNCTION TYPING TESTS # code within this block is **programmatically, # statically generated** by tools/generate_sql_functions.py -stmt1 = select(func.aggregate_strings(column("x", String), ",")) +# test the aggregate_strings() function. +# this function is somewhat special case. + +stmt1 = select(func.aggregate_strings(column("x", String), ",")) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt1) -stmt2 = select(func.array_agg(column("x", Integer))) +# test the array_agg() function. +# this function is a ReturnTypeFromArgs type. + +fn2 = func.array_agg(column("x", Integer)) +assert_type(fn2, functions.array_agg[int]) +stmt2 = select(func.array_agg(column("x", Integer))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Sequence\[.*int\]\]\] reveal_type(stmt2) -stmt3 = select(func.char_length(column("x"))) +# test the char_length() function. +# this function is fixed to the SQL INTEGER class, or the Tuple\[.*int\] type. +fn3 = func.char_length(column("x")) +assert_type(fn3, functions.char_length) + +stmt3 = select(func.char_length(column("x"))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt3) -stmt4 = select(func.coalesce(column("x", Integer))) +# test the coalesce() function. +# this function is a ReturnTypeFromArgs type. + +fn4 = func.coalesce(column("x", Integer)) +assert_type(fn4, functions.coalesce[int]) +stmt4 = select(func.coalesce(column("x", Integer))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt4) -stmt5 = select(func.concat()) +# test the concat() function. +# this function is fixed to the SQL VARCHAR class, or the Tuple\[.*str\] type. + +fn5 = func.concat() +assert_type(fn5, functions.concat) +stmt5 = select(func.concat()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt5) -stmt6 = select(func.count(column("x"))) +# test the count() function. +# this function is fixed to the SQL INTEGER class, or the Tuple\[.*int\] type. +fn6 = func.count(column("x")) +assert_type(fn6, functions.count) + +stmt6 = select(func.count(column("x"))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt6) -stmt7 = select(func.cume_dist()) +# test the cume_dist() function. +# this function is fixed to the SQL NUMERIC class, or the Tuple\[.*Decimal\] type. + +fn7 = func.cume_dist() +assert_type(fn7, functions.cume_dist) +stmt7 = select(func.cume_dist()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt7) -stmt8 = select(func.current_date()) +# test the current_date() function. +# this function is fixed to the SQL DATE class, or the Tuple\[.*date\] type. + +fn8 = func.current_date() +assert_type(fn8, functions.current_date) +stmt8 = select(func.current_date()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*date\]\] reveal_type(stmt8) -stmt9 = select(func.current_time()) +# test the current_time() function. +# this function is fixed to the SQL TIME class, or the Tuple\[.*time\] type. +fn9 = func.current_time() +assert_type(fn9, functions.current_time) + +stmt9 = select(func.current_time()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*time\]\] reveal_type(stmt9) -stmt10 = select(func.current_timestamp()) +# test the current_timestamp() function. +# this function is fixed to the SQL DATETIME class, or the Tuple\[.*datetime\] type. + +fn10 = func.current_timestamp() +assert_type(fn10, functions.current_timestamp) +stmt10 = select(func.current_timestamp()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt10) -stmt11 = select(func.current_user()) +# test the current_user() function. +# this function is fixed to the SQL VARCHAR class, or the Tuple\[.*str\] type. + +fn11 = func.current_user() +assert_type(fn11, functions.current_user) +stmt11 = select(func.current_user()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt11) -stmt12 = select(func.dense_rank()) +# test the dense_rank() function. +# this function is fixed to the SQL INTEGER class, or the Tuple\[.*int\] type. +fn12 = func.dense_rank() +assert_type(fn12, functions.dense_rank) + +stmt12 = select(func.dense_rank()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt12) -stmt13 = select(func.localtime()) +# test the localtime() function. +# this function is fixed to the SQL DATETIME class, or the Tuple\[.*datetime\] type. + +fn13 = func.localtime() +assert_type(fn13, functions.localtime) +stmt13 = select(func.localtime()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt13) -stmt14 = select(func.localtimestamp()) +# test the localtimestamp() function. +# this function is fixed to the SQL DATETIME class, or the Tuple\[.*datetime\] type. + +fn14 = func.localtimestamp() +assert_type(fn14, functions.localtimestamp) +stmt14 = select(func.localtimestamp()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt14) -stmt15 = select(func.max(column("x", Integer))) +# test the max() function. +# this function is a ReturnTypeFromArgs type. +fn15 = func.max(column("x", Integer)) +assert_type(fn15, functions.max[int]) + +stmt15 = select(func.max(column("x", Integer))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt15) -stmt16 = select(func.min(column("x", Integer))) +# test the min() function. +# this function is a ReturnTypeFromArgs type. + +fn16 = func.min(column("x", Integer)) +assert_type(fn16, functions.min[int]) +stmt16 = select(func.min(column("x", Integer))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt16) -stmt17 = select(func.next_value(Sequence("x_seq"))) +# test the next_value() function. +# this function is fixed to the SQL INTEGER class, or the Tuple\[.*int\] type. + +fn17 = func.next_value(Sequence("x_seq")) +assert_type(fn17, functions.next_value) +stmt17 = select(func.next_value(Sequence("x_seq"))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt17) -stmt18 = select(func.now()) +# test the now() function. +# this function is fixed to the SQL DATETIME class, or the Tuple\[.*datetime\] type. +fn18 = func.now() +assert_type(fn18, functions.now) + +stmt18 = select(func.now()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt18) -stmt19 = select(func.percent_rank()) +# test the percent_rank() function. +# this function is fixed to the SQL NUMERIC class, or the Tuple\[.*Decimal\] type. + +fn19 = func.percent_rank() +assert_type(fn19, functions.percent_rank) +stmt19 = select(func.percent_rank()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*Decimal\]\] reveal_type(stmt19) -stmt20 = select(func.rank()) +# test the rank() function. +# this function is fixed to the SQL INTEGER class, or the Tuple\[.*int\] type. + +fn20 = func.rank() +assert_type(fn20, functions.rank) +stmt20 = select(func.rank()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt20) -stmt21 = select(func.session_user()) +# test the session_user() function. +# this function is fixed to the SQL VARCHAR class, or the Tuple\[.*str\] type. +fn21 = func.session_user() +assert_type(fn21, functions.session_user) + +stmt21 = select(func.session_user()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt21) -stmt22 = select(func.sum(column("x", Integer))) +# test the sum() function. +# this function is a ReturnTypeFromArgs type. + +fn22 = func.sum(column("x", Integer)) +assert_type(fn22, functions.sum[int]) +stmt22 = select(func.sum(column("x", Integer))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*int\]\] reveal_type(stmt22) -stmt23 = select(func.sysdate()) +# test the sysdate() function. +# this function is fixed to the SQL DATETIME class, or the Tuple\[.*datetime\] type. + +fn23 = func.sysdate() +assert_type(fn23, functions.sysdate) +stmt23 = select(func.sysdate()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*datetime\]\] reveal_type(stmt23) -stmt24 = select(func.user()) +# test the user() function. +# this function is fixed to the SQL VARCHAR class, or the Tuple\[.*str\] type. +fn24 = func.user() +assert_type(fn24, functions.user) + +stmt24 = select(func.user()) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt24) diff --git a/tools/generate_sql_functions.py b/tools/generate_sql_functions.py index 49844947bb..e6ee3e12fb 100644 --- a/tools/generate_sql_functions.py +++ b/tools/generate_sql_functions.py @@ -39,6 +39,7 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str: ): indent = "" in_block = False + alias_mapping: dict[str, str] = {} for line in orig_py: m = re.match( @@ -63,7 +64,9 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str: for key, fn_class in _fns_in_deterministic_order(): is_reserved_word = key in builtins + class_name = f"_{fn_class.__name__}_func" if issubclass(fn_class, ReturnTypeFromArgs): + guess_its_generic = True if issubclass(fn_class, ReturnTypeFromOptionalArgs): _TEE = "Optional[_T]" else: @@ -84,7 +87,7 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''} col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, -) -> {fn_class.__name__}[_T]: +) -> {class_name}[_T]: ... @overload @@ -93,7 +96,7 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''} col: _ColumnExpressionArgument[{_TEE}], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, -) -> {fn_class.__name__}[_T]: +) -> {class_name}[_T]: ... @overload @@ -102,7 +105,7 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''} col: {_TEE}, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, -) -> {fn_class.__name__}[_T]: +) -> {class_name}[_T]: ... def {key}( {' # noqa: A001' if is_reserved_word else ''} @@ -110,7 +113,7 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''} col: _ColumnExpressionOrLiteralArgument[{_TEE}], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, -) -> {fn_class.__name__}[_T]: +) -> {class_name}[_T]: ... """, @@ -130,7 +133,7 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''} # to get around the indentation errors # 4. Therefore here I have to concat part of the # string outside of the f-string - _type = fn_class.__name__ + _type = class_name _type += "[Any]" if guess_its_generic else "" _reserved_word = ( " # noqa: A001" if is_reserved_word else "" @@ -148,6 +151,11 @@ def {key}(self) -> Type[{_type}]:{_reserved_word} indent, ) ) + orig_name = fn_class.__name__ + alias_name = class_name + if guess_its_generic: + orig_name += "[_T]" + alias_mapping[orig_name] = alias_name m = re.match( r"^( *)# START GENERATED FUNCTION TYPING TESTS", @@ -189,11 +197,18 @@ def {key}(self) -> Type[{_type}]:{_reserved_word} buf.write( textwrap.indent( rf""" -stmt{count} = select(func.{key}(column('x', Integer))) +# test the {key}() function. +# this function is a ReturnTypeFromArgs type. + +fn{count} = func.{key}(column('x', Integer)) +assert_type(fn{count}, functions.{key}[int]) + +stmt{count} = select(func.{key}(column('x', Integer))) # EXPECTED_RE_TYPE: .*Select\[Tuple\[{coltype}\]\] reveal_type(stmt{count}) + """, indent, ) @@ -203,8 +218,11 @@ reveal_type(stmt{count}) buf.write( textwrap.indent( rf""" -stmt{count} = select(func.{key}(column('x', String), ',')) +# test the aggregate_strings() function. +# this function is somewhat special case. + +stmt{count} = select(func.{key}(column('x', String), ',')) # EXPECTED_RE_TYPE: .*Select\[Tuple\[.*str\]\] reveal_type(stmt{count}) @@ -230,16 +248,34 @@ reveal_type(stmt{count}) buf.write( textwrap.indent( rf""" -stmt{count} = select(func.{key}({args})) +# test the {key}() function. +# this function is fixed to the SQL {fn_class.type} class, or the {python_expr} type. + +fn{count} = func.{key}({args}) +assert_type(fn{count}, functions.{key}) + +stmt{count} = select(func.{key}({args})) # EXPECTED_RE_TYPE: .*Select\[{python_expr}\] reveal_type(stmt{count}) -""", +""", # noqa: E501 indent, ) ) + m = re.match( + r"^( *)# START GENERATED FUNCTION ALIASES", + line, + ) + if m: + in_block = True + buf.write(line) + indent = m.group(1) + + for name, alias in alias_mapping.items(): + buf.write(f"{indent}{alias}: TypeAlias = {name}\n") + if in_block and line.startswith( f"{indent}# END GENERATED FUNCTION" ):