From da5e2c2512c00ae10bea6b1ccb39d0ccf0e7e475 Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Fri, 28 Feb 2025 13:01:54 +0100 Subject: [PATCH] Add type annotations for postgresql.array() The type argument of array is inferred from the 'clauses' argument of the constructor; hence array([1, 2]) has type array[int]. We explicitly define the 'type_' parameter of __init__() as it helps type inference when using this argument, e.g. array([], type_=CHAR) is inferred as array[str]. Consistency between 'clauses' and 'type_' is also ensured. --- lib/sqlalchemy/dialects/postgresql/array.py | 46 ++++++++++++++----- lib/sqlalchemy/dialects/postgresql/json.py | 2 +- .../dialects/postgresql/pg_stuff.py | 18 ++++++++ 3 files changed, 54 insertions(+), 12 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 3861995ce0..2a72609fd0 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -11,8 +11,11 @@ from __future__ import annotations import re from typing import Any as typing_Any +from typing import Iterable from typing import Optional +from typing import TYPE_CHECKING from typing import TypeVar +from typing import Union from .operators import CONTAINED_BY from .operators import CONTAINS @@ -21,7 +24,15 @@ from ... import types as sqltypes from ... import util from ...sql import expression from ...sql import operators -from ...sql._typing import _TypeEngineArgument + +if TYPE_CHECKING: + from ...sql._typing import _TypeEngineArgument + from ...sql.elements import Grouping + from ...sql.expression import BindParameter + from ...sql.operators import OperatorType + from ...sql.selectable import _SelectIterable + from ...sql.type_api import TypeEngine + from ...util.typing import Self _T = TypeVar("_T", bound=typing_Any) @@ -107,15 +118,20 @@ class array(expression.ExpressionClauseList[_T]): stringify_dialect = "postgresql" inherit_cache = True - def __init__(self, clauses, **kw): - type_arg = kw.pop("type_", None) + def __init__( + self, + clauses: Iterable[_T], + *, + type_: Optional[_TypeEngineArgument[_T]] = None, + **kw: typing_Any, + ): super().__init__(operators.comma_op, *clauses, **kw) self._type_tuple = [arg.type for arg in self.clauses] main_type = ( - type_arg - if type_arg is not None + type_ + if type_ is not None else self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE ) @@ -127,15 +143,21 @@ class array(expression.ExpressionClauseList[_T]): if main_type.dimensions is not None else 2 ), - ) + ) # type: ignore[assignment] else: - self.type = ARRAY(main_type) + self.type = ARRAY(main_type) # type: ignore[assignment] @property - def _select_iterable(self): + def _select_iterable(self) -> _SelectIterable: return (self,) - def _bind_param(self, operator, obj, _assume_scalar=False, type_=None): + def _bind_param( + self, + operator: OperatorType, + obj: typing_Any, + type_: Optional[TypeEngine[_T]] = None, + _assume_scalar: bool = False, + ) -> BindParameter[_T]: if _assume_scalar or operator is operators.getitem: return expression.BindParameter( None, @@ -154,9 +176,11 @@ class array(expression.ExpressionClauseList[_T]): ) for o in obj ] - ) + ) # type: ignore[return-value] - def self_group(self, against=None): + def self_group( + self, against: Optional[OperatorType] = None + ) -> Union[Self, Grouping[_T]]: if against in (operators.any_op, operators.all_op, operators.getitem): return expression.Grouping(self) else: diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 663be8b7a2..06f8db5b2a 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -337,7 +337,7 @@ class JSONB(JSON): .. versionadded:: 2.0 """ if not isinstance(array, _pg_array): - array = _pg_array(array) # type: ignore[no-untyped-call] + array = _pg_array(array) right_side = cast(array, ARRAY(sqltypes.TEXT)) return self.operate(DELETE_PATH, right_side, result_type=JSONB) diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index e65cef65ab..9981e4a4fc 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -99,3 +99,21 @@ range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE())) # EXPECTED_TYPE: Select[Range[int], Sequence[Range[int]]] reveal_type(range_col_stmt) + +array_from_ints = array(range(2)) + +# EXPECTED_TYPE: array[int] +reveal_type(array_from_ints) + +array_of_strings = array([], type_=Text) + +# EXPECTED_TYPE: array[str] +reveal_type(array_of_strings) + +array_of_ints = array([0], type_=Integer) + +# EXPECTED_TYPE: array[int] +reveal_type(array_of_ints) + +# EXPECTED_MYPY: Cannot infer type argument 1 of "array" +array([0], type_=Text) -- 2.47.3