]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve typing to the count function.
authorFederico Caselli <cfederico87@gmail.com>
Wed, 24 Apr 2024 19:47:01 +0000 (21:47 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 25 Apr 2024 19:43:05 +0000 (19:43 +0000)
Improve typing to allow `'*'` and 1 in the count function.

Fixes: #11316
Change-Id: Iaafdb779b6baa70504154099f0b9554c612a9ffa
(cherry picked from commit 55fb04f10c0aeee7ace984dbe66642a1286594de)

.gitignore
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/functions.py
test/typing/plain_files/sql/functions.py

index 13b40c819ad969a81278766b882308fa8135d41a..d2ee9a2f4add17e9a7d735a4a29fe47ae9f20e05 100644 (file)
@@ -40,3 +40,4 @@ test/test_schema.db
 /db_idents.txt
 .DS_Store
 .vs
+/scratch
index c861bae6e0ffa4563431f06c44819605286f1705..0d8f464467ee052b6e5d1c9f802d10a11aff91a6 100644 (file)
@@ -117,10 +117,12 @@ _NOT_ENTITY = TypeVar(
     "Decimal",
 )
 
+_StarOrOne = Literal["*", 1]
+
 _MAYBE_ENTITY = TypeVar(
     "_MAYBE_ENTITY",
     roles.ColumnsClauseRole,
-    Literal["*", 1],
+    _StarOrOne,
     Type[Any],
     Inspectable[_HasClauseElement[Any]],
     _HasClauseElement[Any],
@@ -145,7 +147,7 @@ _ColumnsClauseArgument = Union[
     roles.TypedColumnsClauseRole[_T],
     roles.ColumnsClauseRole,
     "SQLCoreOperations[_T]",
-    Literal["*", 1],
+    _StarOrOne,
     Type[_T],
     Inspectable[_HasClauseElement[_T]],
     _HasClauseElement[_T],
index afb2b1d9b992c4298422a198fb296d3294db3882..8ef7f75bc2142e9fcf21591ff96bc794bfd20fa0 100644 (file)
@@ -69,6 +69,7 @@ if TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
     from ._typing import _ColumnExpressionOrLiteralArgument
     from ._typing import _ColumnExpressionOrStrLabelArgument
+    from ._typing import _StarOrOne
     from ._typing import _TypeEngineArgument
     from .base import _EntityNamespace
     from .elements import ClauseElement
@@ -1721,7 +1722,9 @@ class count(GenericFunction[int]):
 
     def __init__(
         self,
-        expression: Optional[_ColumnExpressionArgument[Any]] = None,
+        expression: Union[
+            _ColumnExpressionArgument[Any], _StarOrOne, None
+        ] = None,
         **kwargs: Any,
     ):
         if expression is None:
index 6a345fcf6ec6463e753e7ab0e347db0661422876..f657a48571aa9c7d22c940eb3cce18ce8dcb378c 100644 (file)
@@ -1,8 +1,11 @@
 """this file is generated by tools/generate_sql_functions.py"""
 
+from typing import Tuple
+
 from sqlalchemy import column
 from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy import Sequence
 from sqlalchemy import String
@@ -150,3 +153,7 @@ stmt23 = select(func.user())
 reveal_type(stmt23)
 
 # END GENERATED FUNCTION TYPING TESTS
+
+stmt_count: Select[Tuple[int, int, int]] = select(
+    func.count(), func.count("*"), func.count(1)
+)