]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add type annotations to postgresql.ARRAY
authorDenis Laxalde <denis@laxalde.org>
Fri, 28 Feb 2025 12:01:58 +0000 (13:01 +0100)
committerDenis Laxalde <denis@laxalde.org>
Thu, 13 Mar 2025 20:03:20 +0000 (21:03 +0100)
lib/sqlalchemy/dialects/postgresql/array.py

index 2a72609fd01a82fe22998e8837d6630f3318da02..814586efe2b6ebb73476fbc0e0990502e02bfb02 100644 (file)
@@ -13,6 +13,8 @@ import re
 from typing import Any as typing_Any
 from typing import Iterable
 from typing import Optional
+from typing import Sequence
+from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
@@ -26,11 +28,16 @@ from ...sql import expression
 from ...sql import operators
 
 if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
     from ...sql._typing import _TypeEngineArgument
+    from ...sql.elements import ColumnElement
     from ...sql.elements import Grouping
     from ...sql.expression import BindParameter
     from ...sql.operators import OperatorType
     from ...sql.selectable import _SelectIterable
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _LiteralProcessorType
+    from ...sql.type_api import _ResultProcessorType
     from ...sql.type_api import TypeEngine
     from ...util.typing import Self
 
@@ -320,7 +327,9 @@ class ARRAY(sqltypes.ARRAY):
 
         """
 
-        def contains(self, other, **kwargs):
+        def contains(
+            self, other: typing_Any, **kwargs: typing_Any
+        ) -> ColumnElement[bool]:
             """Boolean expression.  Test if elements are a superset of the
             elements of the argument array expression.
 
@@ -329,7 +338,7 @@ class ARRAY(sqltypes.ARRAY):
             """
             return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
 
-        def contained_by(self, other):
+        def contained_by(self, other: typing_Any) -> ColumnElement[bool]:
             """Boolean expression.  Test if elements are a proper subset of the
             elements of the argument array expression.
             """
@@ -337,7 +346,7 @@ class ARRAY(sqltypes.ARRAY):
                 CONTAINED_BY, other, result_type=sqltypes.Boolean
             )
 
-        def overlap(self, other):
+        def overlap(self, other: typing_Any) -> ColumnElement[bool]:
             """Boolean expression.  Test if array has elements in common with
             an argument array expression.
             """
@@ -346,34 +355,36 @@ class ARRAY(sqltypes.ARRAY):
     comparator_factory = Comparator
 
     @property
-    def hashable(self):
+    def hashable(self) -> bool:  # type: ignore[override]
         return self.as_tuple
 
     @property
-    def python_type(self):
+    def python_type(self) -> Type[typing_Any]:
         return list
 
-    def compare_values(self, x, y):
-        return x == y
+    def compare_values(self, x: typing_Any, y: typing_Any) -> bool:
+        return x == y  # type: ignore[no-any-return]
 
     @util.memoized_property
-    def _against_native_enum(self):
+    def _against_native_enum(self) -> bool:
         return (
             isinstance(self.item_type, sqltypes.Enum)
             and self.item_type.native_enum
         )
 
-    def literal_processor(self, dialect):
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> Optional[_LiteralProcessorType[_T]]:
         item_proc = self.item_type.dialect_impl(dialect).literal_processor(
             dialect
         )
         if item_proc is None:
             return None
 
-        def to_str(elements):
+        def to_str(elements: Iterable[typing_Any]) -> str:
             return f"ARRAY[{', '.join(elements)}]"
 
-        def process(value):
+        def process(value: Sequence[typing_Any]) -> str:
             inner = self._apply_item_processor(
                 value, item_proc, self.dimensions, to_str
             )
@@ -381,12 +392,16 @@ class ARRAY(sqltypes.ARRAY):
 
         return process
 
-    def bind_processor(self, dialect):
+    def bind_processor(
+        self, dialect: Dialect
+    ) -> Optional[_BindProcessorType[Sequence[typing_Any]]]:
         item_proc = self.item_type.dialect_impl(dialect).bind_processor(
             dialect
         )
 
-        def process(value):
+        def process(
+            value: Optional[Sequence[typing_Any]],
+        ) -> Optional[list[typing_Any]]:
             if value is None:
                 return value
             else:
@@ -396,12 +411,16 @@ class ARRAY(sqltypes.ARRAY):
 
         return process
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(
+        self, dialect: Dialect, coltype: object
+    ) -> _ResultProcessorType[Sequence[typing_Any]]:
         item_proc = self.item_type.dialect_impl(dialect).result_processor(
             dialect, coltype
         )
 
-        def process(value):
+        def process(
+            value: Sequence[typing_Any],
+        ) -> Optional[Sequence[typing_Any]]:
             if value is None:
                 return value
             else:
@@ -416,11 +435,13 @@ class ARRAY(sqltypes.ARRAY):
             super_rp = process
             pattern = re.compile(r"^{(.*)}$")
 
-            def handle_raw_string(value):
-                inner = pattern.match(value).group(1)
+            def handle_raw_string(value: str) -> list[str]:
+                inner = pattern.match(value).group(1)  # type: ignore[union-attr]  # noqa: E501
                 return _split_enum_values(inner)
 
-            def process(value):
+            def process(
+                value: Sequence[typing_Any],
+            ) -> Optional[Sequence[typing_Any]]:
                 if value is None:
                     return value
                 # isinstance(value, str) is required to handle
@@ -435,7 +456,7 @@ class ARRAY(sqltypes.ARRAY):
         return process
 
 
-def _split_enum_values(array_string):
+def _split_enum_values(array_string: str) -> list[str]:
     if '"' not in array_string:
         # no escape char is present so it can just split on the comma
         return array_string.split(",") if array_string else []