From: Federico Caselli Date: Wed, 24 Apr 2024 19:47:01 +0000 (+0200) Subject: Improve typing to the count function. X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=980cfc5bdfaa1f379922f21f995fc6df3f65a872;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve typing to the count function. Improve typing to allow `'*'` and 1 in the count function. Fixes: #11316 Change-Id: Iaafdb779b6baa70504154099f0b9554c612a9ffa --- diff --git a/.gitignore b/.gitignore index f2544502f3..2fdd7eb951 100644 --- a/.gitignore +++ b/.gitignore @@ -40,6 +40,7 @@ test/test_schema.db /db_idents.txt .DS_Store .vs +/scratch # cython complied files /lib/**/*.c diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 6d54f415fc..bef7e6e7b7 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -118,10 +118,12 @@ _NOT_ENTITY = TypeVar( "Decimal", ) +_StarOrOne = Literal["*", 1] + _MAYBE_ENTITY = TypeVar( "_MAYBE_ENTITY", roles.ColumnsClauseRole, - Literal["*", 1], + _StarOrOne, Type[Any], Inspectable[_HasClauseElement[Any]], _HasClauseElement[Any], @@ -146,7 +148,7 @@ _ColumnsClauseArgument = Union[ roles.TypedColumnsClauseRole[_T], roles.ColumnsClauseRole, "SQLCoreOperations[_T]", - Literal["*", 1], + _StarOrOne, Type[_T], Inspectable[_HasClauseElement[_T]], _HasClauseElement[_T], diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 088b506c76..3ebf5c0a1e 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -69,6 +69,7 @@ if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrLiteralArgument from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _StarOrOne from ._typing import _TypeEngineArgument from .base import _EntityNamespace from .elements import ClauseElement @@ -1721,7 +1722,9 @@ class count(GenericFunction[int]): def __init__( self, - expression: Optional[_ColumnExpressionArgument[Any]] = None, + expression: Union[ + _ColumnExpressionArgument[Any], _StarOrOne, None + ] = None, **kwargs: Any, ): if expression is None: diff --git a/test/typing/plain_files/sql/functions.py b/test/typing/plain_files/sql/functions.py index 726c24b3f1..9f307e5d92 100644 --- a/test/typing/plain_files/sql/functions.py +++ b/test/typing/plain_files/sql/functions.py @@ -3,6 +3,7 @@ from sqlalchemy import column from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import Sequence from sqlalchemy import String @@ -150,3 +151,7 @@ stmt23 = select(func.user()) reveal_type(stmt23) # END GENERATED FUNCTION TYPING TESTS + +stmt_count: Select[int, int, int] = select( + func.count(), func.count("*"), func.count(1) +)