From 584094a4384e305093dfffc56648626b9659cdf7 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Wed, 24 Apr 2024 21:47:01 +0200 Subject: [PATCH] Improve typing to the count function. Improve typing to allow `'*'` and 1 in the count function. Fixes: #11316 Change-Id: Iaafdb779b6baa70504154099f0b9554c612a9ffa (cherry picked from commit 55fb04f10c0aeee7ace984dbe66642a1286594de) --- .gitignore | 1 + lib/sqlalchemy/sql/_typing.py | 6 ++++-- lib/sqlalchemy/sql/functions.py | 5 ++++- test/typing/plain_files/sql/functions.py | 7 +++++++ 4 files changed, 16 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 13b40c819a..d2ee9a2f4a 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,4 @@ test/test_schema.db /db_idents.txt .DS_Store .vs +/scratch diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index c861bae6e0..0d8f464467 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -117,10 +117,12 @@ _NOT_ENTITY = TypeVar( "Decimal", ) +_StarOrOne = Literal["*", 1] + _MAYBE_ENTITY = TypeVar( "_MAYBE_ENTITY", roles.ColumnsClauseRole, - Literal["*", 1], + _StarOrOne, Type[Any], Inspectable[_HasClauseElement[Any]], _HasClauseElement[Any], @@ -145,7 +147,7 @@ _ColumnsClauseArgument = Union[ roles.TypedColumnsClauseRole[_T], roles.ColumnsClauseRole, "SQLCoreOperations[_T]", - Literal["*", 1], + _StarOrOne, Type[_T], Inspectable[_HasClauseElement[_T]], _HasClauseElement[_T], diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index afb2b1d9b9..8ef7f75bc2 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -69,6 +69,7 @@ if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrLiteralArgument from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _StarOrOne from ._typing import _TypeEngineArgument from .base import _EntityNamespace from .elements import ClauseElement @@ -1721,7 +1722,9 @@ class count(GenericFunction[int]): def __init__( self, - expression: Optional[_ColumnExpressionArgument[Any]] = None, + expression: Union[ + _ColumnExpressionArgument[Any], _StarOrOne, None + ] = None, **kwargs: Any, ): if expression is None: diff --git a/test/typing/plain_files/sql/functions.py b/test/typing/plain_files/sql/functions.py index 6a345fcf6e..f657a48571 100644 --- a/test/typing/plain_files/sql/functions.py +++ b/test/typing/plain_files/sql/functions.py @@ -1,8 +1,11 @@ """this file is generated by tools/generate_sql_functions.py""" +from typing import Tuple + from sqlalchemy import column from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import String @@ -150,3 +153,7 @@ stmt23 = select(func.user()) reveal_type(stmt23) # END GENERATED FUNCTION TYPING TESTS + +stmt_count: Select[Tuple[int, int, int]] = select( + func.count(), func.count("*"), func.count(1) +) -- 2.47.2