#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
from __future__ import annotations
import re
-from typing import Any
+from typing import Any as typing_Any
+from typing import Iterable
from typing import Optional
+from typing import Sequence
+from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from .operators import CONTAINED_BY
from .operators import CONTAINS
from ... import util
from ...sql import expression
from ...sql import operators
-from ...sql._typing import _TypeEngineArgument
-
-_T = TypeVar("_T", bound=Any)
-
-
-def Any(other, arrexpr, operator=operators.eq):
+if TYPE_CHECKING:
+ from ...engine.interfaces import Dialect
+ from ...sql._typing import _ColumnExpressionArgument
+ from ...sql._typing import _TypeEngineArgument
+ from ...sql.elements import ColumnElement
+ from ...sql.elements import Grouping
+ from ...sql.expression import BindParameter
+ from ...sql.operators import OperatorType
+ from ...sql.selectable import _SelectIterable
+ from ...sql.type_api import _BindProcessorType
+ from ...sql.type_api import _LiteralProcessorType
+ from ...sql.type_api import _ResultProcessorType
+ from ...sql.type_api import TypeEngine
+ from ...util.typing import Self
+
+
+_T = TypeVar("_T", bound=typing_Any)
+
+
+def Any(
+ other: typing_Any,
+ arrexpr: _ColumnExpressionArgument[_T],
+ operator: OperatorType = operators.eq,
+) -> ColumnElement[bool]:
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
See that method for details.
"""
- return arrexpr.any(other, operator)
+ return arrexpr.any(other, operator) # type: ignore[no-any-return, union-attr] # noqa: E501
-def All(other, arrexpr, operator=operators.eq):
+def All(
+ other: typing_Any,
+ arrexpr: _ColumnExpressionArgument[_T],
+ operator: OperatorType = operators.eq,
+) -> ColumnElement[bool]:
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
See that method for details.
"""
- return arrexpr.all(other, operator)
+ return arrexpr.all(other, operator) # type: ignore[no-any-return, union-attr] # noqa: E501
class array(expression.ExpressionClauseList[_T]):
stringify_dialect = "postgresql"
inherit_cache = True
- def __init__(self, clauses, **kw):
- type_arg = kw.pop("type_", None)
+ def __init__(
+ self,
+ clauses: Iterable[_T],
+ *,
+ type_: Optional[_TypeEngineArgument[_T]] = None,
+ **kw: typing_Any,
+ ):
super().__init__(operators.comma_op, *clauses, **kw)
- self._type_tuple = [arg.type for arg in self.clauses]
-
main_type = (
- type_arg
- if type_arg is not None
- else self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE
+ type_
+ if type_ is not None
+ else self.clauses[0].type if self.clauses else sqltypes.NULLTYPE
)
if isinstance(main_type, ARRAY):
if main_type.dimensions is not None
else 2
),
- )
+ ) # type: ignore[assignment]
else:
- self.type = ARRAY(main_type)
+ self.type = ARRAY(main_type) # type: ignore[assignment]
@property
- def _select_iterable(self):
+ def _select_iterable(self) -> _SelectIterable:
return (self,)
- def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
+ def _bind_param(
+ self,
+ operator: OperatorType,
+ obj: typing_Any,
+ type_: Optional[TypeEngine[_T]] = None,
+ _assume_scalar: bool = False,
+ ) -> BindParameter[_T]:
if _assume_scalar or operator is operators.getitem:
return expression.BindParameter(
None,
)
for o in obj
]
- )
+ ) # type: ignore[return-value]
- def self_group(self, against=None):
+ def self_group(
+ self, against: Optional[OperatorType] = None
+ ) -> Union[Self, Grouping[_T]]:
if against in (operators.any_op, operators.all_op, operators.getitem):
return expression.Grouping(self)
else:
def __init__(
self,
- item_type: _TypeEngineArgument[Any],
+ item_type: _TypeEngineArgument[typing_Any],
as_tuple: bool = False,
dimensions: Optional[int] = None,
zero_indexes: bool = False,
"""
- def contains(self, other, **kwargs):
+ def contains(
+ self, other: typing_Any, **kwargs: typing_Any
+ ) -> ColumnElement[bool]:
"""Boolean expression. Test if elements are a superset of the
elements of the argument array expression.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
- def contained_by(self, other):
+ def contained_by(self, other: typing_Any) -> ColumnElement[bool]:
"""Boolean expression. Test if elements are a proper subset of the
elements of the argument array expression.
"""
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
- def overlap(self, other):
+ def overlap(self, other: typing_Any) -> ColumnElement[bool]:
"""Boolean expression. Test if array has elements in common with
an argument array expression.
"""
comparator_factory = Comparator
- @property
- def hashable(self):
- return self.as_tuple
-
- @property
- def python_type(self):
- return list
-
- def compare_values(self, x, y):
- return x == y
-
@util.memoized_property
- def _against_native_enum(self):
+ def _against_native_enum(self) -> bool:
return (
isinstance(self.item_type, sqltypes.Enum)
and self.item_type.native_enum
)
- def literal_processor(self, dialect):
+ def literal_processor(
+ self, dialect: Dialect
+ ) -> Optional[_LiteralProcessorType[_T]]:
item_proc = self.item_type.dialect_impl(dialect).literal_processor(
dialect
)
if item_proc is None:
return None
- def to_str(elements):
+ def to_str(elements: Iterable[typing_Any]) -> str:
return f"ARRAY[{', '.join(elements)}]"
- def process(value):
+ def process(value: Sequence[typing_Any]) -> str:
inner = self._apply_item_processor(
value, item_proc, self.dimensions, to_str
)
return process
- def bind_processor(self, dialect):
+ def bind_processor(
+ self, dialect: Dialect
+ ) -> Optional[_BindProcessorType[Sequence[typing_Any]]]:
item_proc = self.item_type.dialect_impl(dialect).bind_processor(
dialect
)
- def process(value):
+ def process(
+ value: Optional[Sequence[typing_Any]],
+ ) -> Optional[list[typing_Any]]:
if value is None:
return value
else:
return process
- def result_processor(self, dialect, coltype):
+ def result_processor(
+ self, dialect: Dialect, coltype: object
+ ) -> _ResultProcessorType[Sequence[typing_Any]]:
item_proc = self.item_type.dialect_impl(dialect).result_processor(
dialect, coltype
)
- def process(value):
+ def process(
+ value: Sequence[typing_Any],
+ ) -> Optional[Sequence[typing_Any]]:
if value is None:
return value
else:
super_rp = process
pattern = re.compile(r"^{(.*)}$")
- def handle_raw_string(value):
- inner = pattern.match(value).group(1)
+ def handle_raw_string(value: str) -> list[str]:
+ inner = pattern.match(value).group(1) # type: ignore[union-attr] # noqa: E501
return _split_enum_values(inner)
- def process(value):
+ def process(
+ value: Sequence[typing_Any],
+ ) -> Optional[Sequence[typing_Any]]:
if value is None:
return value
# isinstance(value, str) is required to handle
return process
-def _split_enum_values(array_string):
+def _split_enum_values(array_string: str) -> list[str]:
if '"' not in array_string:
# no escape char is present so it can just split on the comma
return array_string.split(",") if array_string else []