From: Denis Laxalde Date: Fri, 14 Mar 2025 21:01:50 +0000 (-0400) Subject: Add type annotations to `postgresql.array` X-Git-Tag: rel_2_0_40~18^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6883304451acfa020c47e37d942c2c8023f2dbbc;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add type annotations to `postgresql.array` Improved static typing for `postgresql.array()` by making the type parameter (the type of array's elements) inferred from the `clauses` and `type_` arguments while also ensuring they are consistent. Also completed type annotations of `postgresql.ARRAY` following commit 0bf7e02afbec557eb3a5607db407f27deb7aac77 and added type annotations for functions `postgresql.Any()` and `postgresql.All()`. Finally, fixed shadowing `typing.Any` by the `Any()` function through aliasing as `typing_Any`. Related to #6810 Closes: #12384 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12384 Pull-request-sha: 78eea29f1de850afda036502974521969629de7e Change-Id: I5d35d15ec8ba4d58eeb9bf00abb710e2e585731f (cherry picked from commit 75c8e112c9362f89787d8fc25a6a200700052450) --- diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 7708769cb5..8cbe0c48cf 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -4,15 +4,18 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors from __future__ import annotations import re -from typing import Any +from typing import Any as typing_Any +from typing import Iterable from typing import Optional +from typing import Sequence +from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from .operators import CONTAINED_BY from .operators import CONTAINS @@ -21,28 +24,50 @@ from ... import types as sqltypes from ... import util from ...sql import expression from ...sql import operators -from ...sql._typing import _TypeEngineArgument - -_T = TypeVar("_T", bound=Any) - - -def Any(other, arrexpr, operator=operators.eq): +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql._typing import _ColumnExpressionArgument + from ...sql._typing import _TypeEngineArgument + from ...sql.elements import ColumnElement + from ...sql.elements import Grouping + from ...sql.expression import BindParameter + from ...sql.operators import OperatorType + from ...sql.selectable import _SelectIterable + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _LiteralProcessorType + from ...sql.type_api import _ResultProcessorType + from ...sql.type_api import TypeEngine + from ...util.typing import Self + + +_T = TypeVar("_T", bound=typing_Any) + + +def Any( + other: typing_Any, + arrexpr: _ColumnExpressionArgument[_T], + operator: OperatorType = operators.eq, +) -> ColumnElement[bool]: """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method. See that method for details. """ - return arrexpr.any(other, operator) + return arrexpr.any(other, operator) # type: ignore[no-any-return, union-attr] # noqa: E501 -def All(other, arrexpr, operator=operators.eq): +def All( + other: typing_Any, + arrexpr: _ColumnExpressionArgument[_T], + operator: OperatorType = operators.eq, +) -> ColumnElement[bool]: """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method. See that method for details. """ - return arrexpr.all(other, operator) + return arrexpr.all(other, operator) # type: ignore[no-any-return, union-attr] # noqa: E501 class array(expression.ExpressionClauseList[_T]): @@ -107,16 +132,19 @@ class array(expression.ExpressionClauseList[_T]): stringify_dialect = "postgresql" inherit_cache = True - def __init__(self, clauses, **kw): - type_arg = kw.pop("type_", None) + def __init__( + self, + clauses: Iterable[_T], + *, + type_: Optional[_TypeEngineArgument[_T]] = None, + **kw: typing_Any, + ): super().__init__(operators.comma_op, *clauses, **kw) - self._type_tuple = [arg.type for arg in self.clauses] - main_type = ( - type_arg - if type_arg is not None - else self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE + type_ + if type_ is not None + else self.clauses[0].type if self.clauses else sqltypes.NULLTYPE ) if isinstance(main_type, ARRAY): @@ -127,15 +155,21 @@ class array(expression.ExpressionClauseList[_T]): if main_type.dimensions is not None else 2 ), - ) + ) # type: ignore[assignment] else: - self.type = ARRAY(main_type) + self.type = ARRAY(main_type) # type: ignore[assignment] @property - def _select_iterable(self): + def _select_iterable(self) -> _SelectIterable: return (self,) - def _bind_param(self, operator, obj, _assume_scalar=False, type_=None): + def _bind_param( + self, + operator: OperatorType, + obj: typing_Any, + type_: Optional[TypeEngine[_T]] = None, + _assume_scalar: bool = False, + ) -> BindParameter[_T]: if _assume_scalar or operator is operators.getitem: return expression.BindParameter( None, @@ -154,9 +188,11 @@ class array(expression.ExpressionClauseList[_T]): ) for o in obj ] - ) + ) # type: ignore[return-value] - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[_T]]: if against in (operators.any_op, operators.all_op, operators.getitem): return expression.Grouping(self) else: @@ -237,7 +273,7 @@ class ARRAY(sqltypes.ARRAY): def __init__( self, - item_type: _TypeEngineArgument[Any], + item_type: _TypeEngineArgument[typing_Any], as_tuple: bool = False, dimensions: Optional[int] = None, zero_indexes: bool = False, @@ -296,7 +332,9 @@ class ARRAY(sqltypes.ARRAY): """ - def contains(self, other, **kwargs): + def contains( + self, other: typing_Any, **kwargs: typing_Any + ) -> ColumnElement[bool]: """Boolean expression. Test if elements are a superset of the elements of the argument array expression. @@ -305,7 +343,7 @@ class ARRAY(sqltypes.ARRAY): """ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) - def contained_by(self, other): + def contained_by(self, other: typing_Any) -> ColumnElement[bool]: """Boolean expression. Test if elements are a proper subset of the elements of the argument array expression. """ @@ -313,7 +351,7 @@ class ARRAY(sqltypes.ARRAY): CONTAINED_BY, other, result_type=sqltypes.Boolean ) - def overlap(self, other): + def overlap(self, other: typing_Any) -> ColumnElement[bool]: """Boolean expression. Test if array has elements in common with an argument array expression. """ @@ -321,35 +359,26 @@ class ARRAY(sqltypes.ARRAY): comparator_factory = Comparator - @property - def hashable(self): - return self.as_tuple - - @property - def python_type(self): - return list - - def compare_values(self, x, y): - return x == y - @util.memoized_property - def _against_native_enum(self): + def _against_native_enum(self) -> bool: return ( isinstance(self.item_type, sqltypes.Enum) and self.item_type.native_enum ) - def literal_processor(self, dialect): + def literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[_T]]: item_proc = self.item_type.dialect_impl(dialect).literal_processor( dialect ) if item_proc is None: return None - def to_str(elements): + def to_str(elements: Iterable[typing_Any]) -> str: return f"ARRAY[{', '.join(elements)}]" - def process(value): + def process(value: Sequence[typing_Any]) -> str: inner = self._apply_item_processor( value, item_proc, self.dimensions, to_str ) @@ -357,12 +386,16 @@ class ARRAY(sqltypes.ARRAY): return process - def bind_processor(self, dialect): + def bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[Sequence[typing_Any]]]: item_proc = self.item_type.dialect_impl(dialect).bind_processor( dialect ) - def process(value): + def process( + value: Optional[Sequence[typing_Any]], + ) -> Optional[list[typing_Any]]: if value is None: return value else: @@ -372,12 +405,16 @@ class ARRAY(sqltypes.ARRAY): return process - def result_processor(self, dialect, coltype): + def result_processor( + self, dialect: Dialect, coltype: object + ) -> _ResultProcessorType[Sequence[typing_Any]]: item_proc = self.item_type.dialect_impl(dialect).result_processor( dialect, coltype ) - def process(value): + def process( + value: Sequence[typing_Any], + ) -> Optional[Sequence[typing_Any]]: if value is None: return value else: @@ -392,11 +429,13 @@ class ARRAY(sqltypes.ARRAY): super_rp = process pattern = re.compile(r"^{(.*)}$") - def handle_raw_string(value): - inner = pattern.match(value).group(1) + def handle_raw_string(value: str) -> list[str]: + inner = pattern.match(value).group(1) # type: ignore[union-attr] # noqa: E501 return _split_enum_values(inner) - def process(value): + def process( + value: Sequence[typing_Any], + ) -> Optional[Sequence[typing_Any]]: if value is None: return value # isinstance(value, str) is required to handle @@ -411,7 +450,7 @@ class ARRAY(sqltypes.ARRAY): return process -def _split_enum_values(array_string): +def _split_enum_values(array_string: str) -> list[str]: if '"' not in array_string: # no escape char is present so it can just split on the comma return array_string.split(",") if array_string else [] diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 663be8b7a2..06f8db5b2a 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -337,7 +337,7 @@ class JSONB(JSON): .. versionadded:: 2.0 """ if not isinstance(array, _pg_array): - array = _pg_array(array) # type: ignore[no-untyped-call] + array = _pg_array(array) right_side = cast(array, ARRAY(sqltypes.TEXT)) return self.operate(DELETE_PATH, right_side, result_type=JSONB) diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index 5e56efba98..45ec981bba 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -99,3 +99,21 @@ range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE())) # EXPECTED_TYPE: Select[Tuple[Range[int], Sequence[Range[int]]]] reveal_type(range_col_stmt) + +array_from_ints = array(range(2)) + +# EXPECTED_TYPE: array[int] +reveal_type(array_from_ints) + +array_of_strings = array([], type_=Text) + +# EXPECTED_TYPE: array[str] +reveal_type(array_of_strings) + +array_of_ints = array([0], type_=Integer) + +# EXPECTED_TYPE: array[int] +reveal_type(array_of_ints) + +# EXPECTED_MYPY: Cannot infer type argument 1 of "array" +array([0], type_=Text)