]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add type annotations for postgresql.array()
authorDenis Laxalde <denis@laxalde.org>
Fri, 28 Feb 2025 12:01:54 +0000 (13:01 +0100)
committerDenis Laxalde <denis@laxalde.org>
Thu, 13 Mar 2025 19:53:47 +0000 (20:53 +0100)
The type argument of array is inferred from the 'clauses' argument of
the constructor; hence array([1, 2]) has type array[int]. We explicitly
define the 'type_' parameter of __init__() as it helps type inference
when using this argument, e.g. array([], type_=CHAR) is inferred as
array[str]. Consistency between 'clauses' and 'type_' is also ensured.

lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/dialects/postgresql/json.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py

index 3861995ce01e7a07c4e85baa090c26827b80424b..2a72609fd01a82fe22998e8837d6630f3318da02 100644 (file)
@@ -11,8 +11,11 @@ from __future__ import annotations
 
 import re
 from typing import Any as typing_Any
+from typing import Iterable
 from typing import Optional
+from typing import TYPE_CHECKING
 from typing import TypeVar
+from typing import Union
 
 from .operators import CONTAINED_BY
 from .operators import CONTAINS
@@ -21,7 +24,15 @@ from ... import types as sqltypes
 from ... import util
 from ...sql import expression
 from ...sql import operators
-from ...sql._typing import _TypeEngineArgument
+
+if TYPE_CHECKING:
+    from ...sql._typing import _TypeEngineArgument
+    from ...sql.elements import Grouping
+    from ...sql.expression import BindParameter
+    from ...sql.operators import OperatorType
+    from ...sql.selectable import _SelectIterable
+    from ...sql.type_api import TypeEngine
+    from ...util.typing import Self
 
 
 _T = TypeVar("_T", bound=typing_Any)
@@ -107,15 +118,20 @@ class array(expression.ExpressionClauseList[_T]):
     stringify_dialect = "postgresql"
     inherit_cache = True
 
-    def __init__(self, clauses, **kw):
-        type_arg = kw.pop("type_", None)
+    def __init__(
+        self,
+        clauses: Iterable[_T],
+        *,
+        type_: Optional[_TypeEngineArgument[_T]] = None,
+        **kw: typing_Any,
+    ):
         super().__init__(operators.comma_op, *clauses, **kw)
 
         self._type_tuple = [arg.type for arg in self.clauses]
 
         main_type = (
-            type_arg
-            if type_arg is not None
+            type_
+            if type_ is not None
             else self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE
         )
 
@@ -127,15 +143,21 @@ class array(expression.ExpressionClauseList[_T]):
                     if main_type.dimensions is not None
                     else 2
                 ),
-            )
+            )  # type: ignore[assignment]
         else:
-            self.type = ARRAY(main_type)
+            self.type = ARRAY(main_type)  # type: ignore[assignment]
 
     @property
-    def _select_iterable(self):
+    def _select_iterable(self) -> _SelectIterable:
         return (self,)
 
-    def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
+    def _bind_param(
+        self,
+        operator: OperatorType,
+        obj: typing_Any,
+        type_: Optional[TypeEngine[_T]] = None,
+        _assume_scalar: bool = False,
+    ) -> BindParameter[_T]:
         if _assume_scalar or operator is operators.getitem:
             return expression.BindParameter(
                 None,
@@ -154,9 +176,11 @@ class array(expression.ExpressionClauseList[_T]):
                     )
                     for o in obj
                 ]
-            )
+            )  # type: ignore[return-value]
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> Union[Self, Grouping[_T]]:
         if against in (operators.any_op, operators.all_op, operators.getitem):
             return expression.Grouping(self)
         else:
index 663be8b7a2b8830f45da58d2202b7f01953dd377..06f8db5b2af8673ec358bc0d17efdecf59df44b1 100644 (file)
@@ -337,7 +337,7 @@ class JSONB(JSON):
             .. versionadded:: 2.0
             """
             if not isinstance(array, _pg_array):
-                array = _pg_array(array)  # type: ignore[no-untyped-call]
+                array = _pg_array(array)
             right_side = cast(array, ARRAY(sqltypes.TEXT))
             return self.operate(DELETE_PATH, right_side, result_type=JSONB)
 
index e65cef65ab9fcc4d45a749c8b53d26072affee9a..9981e4a4fc1a00ba0071c187036d72fdb69e6636 100644 (file)
@@ -99,3 +99,21 @@ range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE()))
 
 # EXPECTED_TYPE: Select[Range[int], Sequence[Range[int]]]
 reveal_type(range_col_stmt)
+
+array_from_ints = array(range(2))
+
+# EXPECTED_TYPE: array[int]
+reveal_type(array_from_ints)
+
+array_of_strings = array([], type_=Text)
+
+# EXPECTED_TYPE: array[str]
+reveal_type(array_of_strings)
+
+array_of_ints = array([0], type_=Integer)
+
+# EXPECTED_TYPE: array[int]
+reveal_type(array_of_ints)
+
+# EXPECTED_MYPY: Cannot infer type argument 1 of "array"
+array([0], type_=Text)