From 71cab3ce9566975093111007481b0508b3e60956 Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Fri, 28 Feb 2025 13:01:58 +0100 Subject: [PATCH] Add type annotations to postgresql.ARRAY --- lib/sqlalchemy/dialects/postgresql/array.py | 59 ++++++++++++++------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 2a72609fd0..814586efe2 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -13,6 +13,8 @@ import re from typing import Any as typing_Any from typing import Iterable from typing import Optional +from typing import Sequence +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -26,11 +28,16 @@ from ...sql import expression from ...sql import operators if TYPE_CHECKING: + from ...engine.interfaces import Dialect from ...sql._typing import _TypeEngineArgument + from ...sql.elements import ColumnElement from ...sql.elements import Grouping from ...sql.expression import BindParameter from ...sql.operators import OperatorType from ...sql.selectable import _SelectIterable + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _LiteralProcessorType + from ...sql.type_api import _ResultProcessorType from ...sql.type_api import TypeEngine from ...util.typing import Self @@ -320,7 +327,9 @@ class ARRAY(sqltypes.ARRAY): """ - def contains(self, other, **kwargs): + def contains( + self, other: typing_Any, **kwargs: typing_Any + ) -> ColumnElement[bool]: """Boolean expression. Test if elements are a superset of the elements of the argument array expression. @@ -329,7 +338,7 @@ class ARRAY(sqltypes.ARRAY): """ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) - def contained_by(self, other): + def contained_by(self, other: typing_Any) -> ColumnElement[bool]: """Boolean expression. Test if elements are a proper subset of the elements of the argument array expression. """ @@ -337,7 +346,7 @@ class ARRAY(sqltypes.ARRAY): CONTAINED_BY, other, result_type=sqltypes.Boolean ) - def overlap(self, other): + def overlap(self, other: typing_Any) -> ColumnElement[bool]: """Boolean expression. Test if array has elements in common with an argument array expression. """ @@ -346,34 +355,36 @@ class ARRAY(sqltypes.ARRAY): comparator_factory = Comparator @property - def hashable(self): + def hashable(self) -> bool: # type: ignore[override] return self.as_tuple @property - def python_type(self): + def python_type(self) -> Type[typing_Any]: return list - def compare_values(self, x, y): - return x == y + def compare_values(self, x: typing_Any, y: typing_Any) -> bool: + return x == y # type: ignore[no-any-return] @util.memoized_property - def _against_native_enum(self): + def _against_native_enum(self) -> bool: return ( isinstance(self.item_type, sqltypes.Enum) and self.item_type.native_enum ) - def literal_processor(self, dialect): + def literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[_T]]: item_proc = self.item_type.dialect_impl(dialect).literal_processor( dialect ) if item_proc is None: return None - def to_str(elements): + def to_str(elements: Iterable[typing_Any]) -> str: return f"ARRAY[{', '.join(elements)}]" - def process(value): + def process(value: Sequence[typing_Any]) -> str: inner = self._apply_item_processor( value, item_proc, self.dimensions, to_str ) @@ -381,12 +392,16 @@ class ARRAY(sqltypes.ARRAY): return process - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[Sequence[typing_Any]]]: item_proc = self.item_type.dialect_impl(dialect).bind_processor( dialect ) - def process(value): + def process( + value: Optional[Sequence[typing_Any]], + ) -> Optional[list[typing_Any]]: if value is None: return value else: @@ -396,12 +411,16 @@ class ARRAY(sqltypes.ARRAY): return process - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> _ResultProcessorType[Sequence[typing_Any]]: item_proc = self.item_type.dialect_impl(dialect).result_processor( dialect, coltype ) - def process(value): + def process( + value: Sequence[typing_Any], + ) -> Optional[Sequence[typing_Any]]: if value is None: return value else: @@ -416,11 +435,13 @@ class ARRAY(sqltypes.ARRAY): super_rp = process pattern = re.compile(r"^{(.*)}$") - def handle_raw_string(value): - inner = pattern.match(value).group(1) + def handle_raw_string(value: str) -> list[str]: + inner = pattern.match(value).group(1) # type: ignore[union-attr] # noqa: E501 return _split_enum_values(inner) - def process(value): + def process( + value: Sequence[typing_Any], + ) -> Optional[Sequence[typing_Any]]: if value is None: return value # isinstance(value, str) is required to handle @@ -435,7 +456,7 @@ class ARRAY(sqltypes.ARRAY): return process -def _split_enum_values(array_string): +def _split_enum_values(array_string: str) -> list[str]: if '"' not in array_string: # no escape char is present so it can just split on the comma return array_string.split(",") if array_string else [] -- 2.47.3