]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve typing to the count function.
authorFederico Caselli <cfederico87@gmail.com>
Wed, 24 Apr 2024 19:47:01 +0000 (21:47 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 25 Apr 2024 19:42:58 +0000 (19:42 +0000)
Improve typing to allow `'*'` and 1 in the count function.

Fixes: #11316
Change-Id: Iaafdb779b6baa70504154099f0b9554c612a9ffa

.gitignore
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/functions.py
test/typing/plain_files/sql/functions.py

index f2544502f3bc1bd0e3e9a781b48b90244f36f6c2..2fdd7eb95190500f835f2df0d23f383ff1ccec74 100644 (file)
@@ -40,6 +40,7 @@ test/test_schema.db
 /db_idents.txt
 .DS_Store
 .vs
+/scratch
 
 # cython complied files
 /lib/**/*.c
index 6d54f415fc8d5258bafa210da38b3fc70ac07bf1..bef7e6e7b72a83c60d3882a1b0f9c6c25c196b57 100644 (file)
@@ -118,10 +118,12 @@ _NOT_ENTITY = TypeVar(
     "Decimal",
 )
 
+_StarOrOne = Literal["*", 1]
+
 _MAYBE_ENTITY = TypeVar(
     "_MAYBE_ENTITY",
     roles.ColumnsClauseRole,
-    Literal["*", 1],
+    _StarOrOne,
     Type[Any],
     Inspectable[_HasClauseElement[Any]],
     _HasClauseElement[Any],
@@ -146,7 +148,7 @@ _ColumnsClauseArgument = Union[
     roles.TypedColumnsClauseRole[_T],
     roles.ColumnsClauseRole,
     "SQLCoreOperations[_T]",
-    Literal["*", 1],
+    _StarOrOne,
     Type[_T],
     Inspectable[_HasClauseElement[_T]],
     _HasClauseElement[_T],
index 088b506c760f0f40441b92774d464635ffd1461b..3ebf5c0a1ef566d01d74bd13997d83030c02241d 100644 (file)
@@ -69,6 +69,7 @@ if TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
     from ._typing import _ColumnExpressionOrLiteralArgument
     from ._typing import _ColumnExpressionOrStrLabelArgument
+    from ._typing import _StarOrOne
     from ._typing import _TypeEngineArgument
     from .base import _EntityNamespace
     from .elements import ClauseElement
@@ -1721,7 +1722,9 @@ class count(GenericFunction[int]):
 
     def __init__(
         self,
-        expression: Optional[_ColumnExpressionArgument[Any]] = None,
+        expression: Union[
+            _ColumnExpressionArgument[Any], _StarOrOne, None
+        ] = None,
         **kwargs: Any,
     ):
         if expression is None:
index 726c24b3f1d6645653268149d59cbadc6e86d4c3..9f307e5d921b60857425dd6caba5f00e2038266b 100644 (file)
@@ -3,6 +3,7 @@
 from sqlalchemy import column
 from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy import Sequence
 from sqlalchemy import String
@@ -150,3 +151,7 @@ stmt23 = select(func.user())
 reveal_type(stmt23)
 
 # END GENERATED FUNCTION TYPING TESTS
+
+stmt_count: Select[int, int, int] = select(
+    func.count(), func.count("*"), func.count(1)
+)