]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add type annotations to `postgresql.array`
authorDenis Laxalde <denis@laxalde.org>
Fri, 14 Mar 2025 21:01:50 +0000 (17:01 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Sat, 15 Mar 2025 15:21:27 +0000 (16:21 +0100)
Improved static typing for `postgresql.array()` by making the type parameter (the type of array's elements) inferred from the `clauses` and `type_` arguments while also ensuring they are consistent.

Also completed type annotations of `postgresql.ARRAY` following commit 0bf7e02afbec557eb3a5607db407f27deb7aac77 and added type annotations for functions `postgresql.Any()` and `postgresql.All()`.

Finally, fixed shadowing `typing.Any` by the `Any()` function through aliasing as `typing_Any`.

Related to #6810

Closes: #12384
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12384
Pull-request-sha: 78eea29f1de850afda036502974521969629de7e

Change-Id: I5d35d15ec8ba4d58eeb9bf00abb710e2e585731f

lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/dialects/postgresql/json.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py

index 7708769cb5380dde321e1f8948d2da48d2aeff09..8cbe0c48cf9a617e5252fe1c29c73545d9087f30 100644 (file)
@@ -4,15 +4,18 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 
 from __future__ import annotations
 
 import re
-from typing import Any
+from typing import Any as typing_Any
+from typing import Iterable
 from typing import Optional
+from typing import Sequence
+from typing import TYPE_CHECKING
 from typing import TypeVar
+from typing import Union
 
 from .operators import CONTAINED_BY
 from .operators import CONTAINS
@@ -21,28 +24,50 @@ from ... import types as sqltypes
 from ... import util
 from ...sql import expression
 from ...sql import operators
-from ...sql._typing import _TypeEngineArgument
 
-
-_T = TypeVar("_T", bound=Any)
-
-
-def Any(other, arrexpr, operator=operators.eq):
+if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
+    from ...sql._typing import _ColumnExpressionArgument
+    from ...sql._typing import _TypeEngineArgument
+    from ...sql.elements import ColumnElement
+    from ...sql.elements import Grouping
+    from ...sql.expression import BindParameter
+    from ...sql.operators import OperatorType
+    from ...sql.selectable import _SelectIterable
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _LiteralProcessorType
+    from ...sql.type_api import _ResultProcessorType
+    from ...sql.type_api import TypeEngine
+    from ...util.typing import Self
+
+
+_T = TypeVar("_T", bound=typing_Any)
+
+
+def Any(
+    other: typing_Any,
+    arrexpr: _ColumnExpressionArgument[_T],
+    operator: OperatorType = operators.eq,
+) -> ColumnElement[bool]:
     """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
     See that method for details.
 
     """
 
-    return arrexpr.any(other, operator)
+    return arrexpr.any(other, operator)  # type: ignore[no-any-return, union-attr]  # noqa: E501
 
 
-def All(other, arrexpr, operator=operators.eq):
+def All(
+    other: typing_Any,
+    arrexpr: _ColumnExpressionArgument[_T],
+    operator: OperatorType = operators.eq,
+) -> ColumnElement[bool]:
     """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
     See that method for details.
 
     """
 
-    return arrexpr.all(other, operator)
+    return arrexpr.all(other, operator)  # type: ignore[no-any-return, union-attr]  # noqa: E501
 
 
 class array(expression.ExpressionClauseList[_T]):
@@ -107,16 +132,19 @@ class array(expression.ExpressionClauseList[_T]):
     stringify_dialect = "postgresql"
     inherit_cache = True
 
-    def __init__(self, clauses, **kw):
-        type_arg = kw.pop("type_", None)
+    def __init__(
+        self,
+        clauses: Iterable[_T],
+        *,
+        type_: Optional[_TypeEngineArgument[_T]] = None,
+        **kw: typing_Any,
+    ):
         super().__init__(operators.comma_op, *clauses, **kw)
 
-        self._type_tuple = [arg.type for arg in self.clauses]
-
         main_type = (
-            type_arg
-            if type_arg is not None
-            else self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE
+            type_
+            if type_ is not None
+            else self.clauses[0].type if self.clauses else sqltypes.NULLTYPE
         )
 
         if isinstance(main_type, ARRAY):
@@ -127,15 +155,21 @@ class array(expression.ExpressionClauseList[_T]):
                     if main_type.dimensions is not None
                     else 2
                 ),
-            )
+            )  # type: ignore[assignment]
         else:
-            self.type = ARRAY(main_type)
+            self.type = ARRAY(main_type)  # type: ignore[assignment]
 
     @property
-    def _select_iterable(self):
+    def _select_iterable(self) -> _SelectIterable:
         return (self,)
 
-    def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
+    def _bind_param(
+        self,
+        operator: OperatorType,
+        obj: typing_Any,
+        type_: Optional[TypeEngine[_T]] = None,
+        _assume_scalar: bool = False,
+    ) -> BindParameter[_T]:
         if _assume_scalar or operator is operators.getitem:
             return expression.BindParameter(
                 None,
@@ -154,9 +188,11 @@ class array(expression.ExpressionClauseList[_T]):
                     )
                     for o in obj
                 ]
-            )
+            )  # type: ignore[return-value]
 
-    def self_group(self, against=None):
+    def self_group(
+        self, against: Optional[OperatorType] = None
+    ) -> Union[Self, Grouping[_T]]:
         if against in (operators.any_op, operators.all_op, operators.getitem):
             return expression.Grouping(self)
         else:
@@ -237,7 +273,7 @@ class ARRAY(sqltypes.ARRAY):
 
     def __init__(
         self,
-        item_type: _TypeEngineArgument[Any],
+        item_type: _TypeEngineArgument[typing_Any],
         as_tuple: bool = False,
         dimensions: Optional[int] = None,
         zero_indexes: bool = False,
@@ -296,7 +332,9 @@ class ARRAY(sqltypes.ARRAY):
 
         """
 
-        def contains(self, other, **kwargs):
+        def contains(
+            self, other: typing_Any, **kwargs: typing_Any
+        ) -> ColumnElement[bool]:
             """Boolean expression.  Test if elements are a superset of the
             elements of the argument array expression.
 
@@ -305,7 +343,7 @@ class ARRAY(sqltypes.ARRAY):
             """
             return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
 
-        def contained_by(self, other):
+        def contained_by(self, other: typing_Any) -> ColumnElement[bool]:
             """Boolean expression.  Test if elements are a proper subset of the
             elements of the argument array expression.
             """
@@ -313,7 +351,7 @@ class ARRAY(sqltypes.ARRAY):
                 CONTAINED_BY, other, result_type=sqltypes.Boolean
             )
 
-        def overlap(self, other):
+        def overlap(self, other: typing_Any) -> ColumnElement[bool]:
             """Boolean expression.  Test if array has elements in common with
             an argument array expression.
             """
@@ -321,35 +359,26 @@ class ARRAY(sqltypes.ARRAY):
 
     comparator_factory = Comparator
 
-    @property
-    def hashable(self):
-        return self.as_tuple
-
-    @property
-    def python_type(self):
-        return list
-
-    def compare_values(self, x, y):
-        return x == y
-
     @util.memoized_property
-    def _against_native_enum(self):
+    def _against_native_enum(self) -> bool:
         return (
             isinstance(self.item_type, sqltypes.Enum)
             and self.item_type.native_enum
         )
 
-    def literal_processor(self, dialect):
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> Optional[_LiteralProcessorType[_T]]:
         item_proc = self.item_type.dialect_impl(dialect).literal_processor(
             dialect
         )
         if item_proc is None:
             return None
 
-        def to_str(elements):
+        def to_str(elements: Iterable[typing_Any]) -> str:
             return f"ARRAY[{', '.join(elements)}]"
 
-        def process(value):
+        def process(value: Sequence[typing_Any]) -> str:
             inner = self._apply_item_processor(
                 value, item_proc, self.dimensions, to_str
             )
@@ -357,12 +386,16 @@ class ARRAY(sqltypes.ARRAY):
 
         return process
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[Sequence[typing_Any]]]:
         item_proc = self.item_type.dialect_impl(dialect).bind_processor(
             dialect
         )
 
-        def process(value):
+        def process(
+            value: Optional[Sequence[typing_Any]],
+        ) -> Optional[list[typing_Any]]:
             if value is None:
                 return value
             else:
@@ -372,12 +405,16 @@ class ARRAY(sqltypes.ARRAY):
 
         return process
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> _ResultProcessorType[Sequence[typing_Any]]:
         item_proc = self.item_type.dialect_impl(dialect).result_processor(
             dialect, coltype
         )
 
-        def process(value):
+        def process(
+            value: Sequence[typing_Any],
+        ) -> Optional[Sequence[typing_Any]]:
             if value is None:
                 return value
             else:
@@ -392,11 +429,13 @@ class ARRAY(sqltypes.ARRAY):
             super_rp = process
             pattern = re.compile(r"^{(.*)}$")
 
-            def handle_raw_string(value):
-                inner = pattern.match(value).group(1)
+            def handle_raw_string(value: str) -> list[str]:
+                inner = pattern.match(value).group(1)  # type: ignore[union-attr]  # noqa: E501
                 return _split_enum_values(inner)
 
-            def process(value):
+            def process(
+                value: Sequence[typing_Any],
+            ) -> Optional[Sequence[typing_Any]]:
                 if value is None:
                     return value
                 # isinstance(value, str) is required to handle
@@ -411,7 +450,7 @@ class ARRAY(sqltypes.ARRAY):
         return process
 
 
-def _split_enum_values(array_string):
+def _split_enum_values(array_string: str) -> list[str]:
     if '"' not in array_string:
         # no escape char is present so it can just split on the comma
         return array_string.split(",") if array_string else []
index 663be8b7a2b8830f45da58d2202b7f01953dd377..06f8db5b2af8673ec358bc0d17efdecf59df44b1 100644 (file)
@@ -337,7 +337,7 @@ class JSONB(JSON):
             .. versionadded:: 2.0
             """
             if not isinstance(array, _pg_array):
-                array = _pg_array(array)  # type: ignore[no-untyped-call]
+                array = _pg_array(array)
             right_side = cast(array, ARRAY(sqltypes.TEXT))
             return self.operate(DELETE_PATH, right_side, result_type=JSONB)
 
index e65cef65ab9fcc4d45a749c8b53d26072affee9a..9981e4a4fc1a00ba0071c187036d72fdb69e6636 100644 (file)
@@ -99,3 +99,21 @@ range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE()))
 
 # EXPECTED_TYPE: Select[Range[int], Sequence[Range[int]]]
 reveal_type(range_col_stmt)
+
+array_from_ints = array(range(2))
+
+# EXPECTED_TYPE: array[int]
+reveal_type(array_from_ints)
+
+array_of_strings = array([], type_=Text)
+
+# EXPECTED_TYPE: array[str]
+reveal_type(array_of_strings)
+
+array_of_ints = array([0], type_=Integer)
+
+# EXPECTED_TYPE: array[int]
+reveal_type(array_of_ints)
+
+# EXPECTED_MYPY: Cannot infer type argument 1 of "array"
+array([0], type_=Text)