]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Complement type annotations for ARRAY
authorDenis Laxalde <denis@laxalde.org>
Wed, 5 Mar 2025 20:59:39 +0000 (15:59 -0500)
committersqla-tester <sqla-tester@sqlalchemy.org>
Wed, 5 Mar 2025 20:59:39 +0000 (15:59 -0500)
### Description

This complements the type annotations of the `ARRAY` class, in preparation of #12384.

### Checklist

This pull request is:

- [ ] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [x] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

Related to https://github.com/sqlalchemy/sqlalchemy/issues/6810

Closes: #12386
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12386
Pull-request-sha: c9513ce729fa1116b46b02336d4e2cda3d096fee

Change-Id: If9df4708c8e597eedc79ee3990792fa6c72f1afe

lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/sqltypes.py

index bd92f6aa854e0ae8c64174c26a5e3174c6f67243..520e4af8662e7fca84a9356cb56654532984013a 100644 (file)
@@ -3798,7 +3798,9 @@ class CollectionAggregate(UnaryExpression[_T]):
     # operate and reverse_operate are hardwired to
     # dispatch onto the type comparator directly, so that we can
     # ensure "reversed" behavior.
-    def operate(self, op, *other, **kwargs):
+    def operate(
+        self, op: OperatorType, *other: Any, **kwargs: Any
+    ) -> ColumnElement[_T]:
         if not operators.is_comparison(op):
             raise exc.ArgumentError(
                 "Only comparison operators may be used with ANY/ALL"
@@ -3806,7 +3808,9 @@ class CollectionAggregate(UnaryExpression[_T]):
         kwargs["reverse"] = True
         return self.comparator.operate(operators.mirror(op), *other, **kwargs)
 
-    def reverse_operate(self, op, other, **kwargs):
+    def reverse_operate(
+        self, op: OperatorType, other: Any, **kwargs: Any
+    ) -> ColumnElement[_T]:
         # comparison operators should never call reverse_operate
         assert not operators.is_comparison(op)
         raise exc.ArgumentError(
index ec382c2f14747e237037f3a78fd9144cf108fc7f..7a40c7ef6f38506b3f0f5097cf1c66e8bdbee0e4 100644 (file)
@@ -22,6 +22,7 @@ from typing import Callable
 from typing import cast
 from typing import Dict
 from typing import Generic
+from typing import Iterable
 from typing import List
 from typing import Optional
 from typing import overload
@@ -69,10 +70,12 @@ from ..util.typing import TupleAny
 if TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
     from ._typing import _TypeEngineArgument
+    from .elements import ColumnElement
     from .operators import OperatorType
     from .schema import MetaData
     from .type_api import _BindProcessorType
     from .type_api import _ComparatorFactory
+    from .type_api import _LiteralProcessorType
     from .type_api import _MatchedOnType
     from .type_api import _ResultProcessorType
     from ..engine.interfaces import Dialect
@@ -80,6 +83,7 @@ if TYPE_CHECKING:
 _T = TypeVar("_T", bound="Any")
 _CT = TypeVar("_CT", bound=Any)
 _TE = TypeVar("_TE", bound="TypeEngine[Any]")
+_P = TypeVar("_P")
 
 
 class HasExpressionLookup(TypeEngineMixin):
@@ -2987,7 +2991,20 @@ class ARRAY(
 
         type: ARRAY
 
-        def _setup_getitem(self, index):
+        @overload
+        def _setup_getitem(
+            self, index: int
+        ) -> Tuple[OperatorType, int, TypeEngine[Any]]: ...
+
+        @overload
+        def _setup_getitem(
+            self, index: slice
+        ) -> Tuple[OperatorType, Slice, TypeEngine[Any]]: ...
+
+        def _setup_getitem(self, index: Union[int, slice]) -> Union[
+            Tuple[OperatorType, int, TypeEngine[Any]],
+            Tuple[OperatorType, Slice, TypeEngine[Any]],
+        ]:
             arr_type = self.type
 
             return_type: TypeEngine[Any]
@@ -3013,7 +3030,7 @@ class ARRAY(
 
                 return operators.getitem, index, return_type
 
-        def contains(self, *arg, **kw):
+        def contains(self, *arg: Any, **kw: Any) -> ColumnElement[bool]:
             """``ARRAY.contains()`` not implemented for the base ARRAY type.
             Use the dialect-specific ARRAY type.
 
@@ -3027,7 +3044,9 @@ class ARRAY(
             )
 
         @util.preload_module("sqlalchemy.sql.elements")
-        def any(self, other, operator=None):
+        def any(
+            self, other: Any, operator: Optional[OperatorType] = None
+        ) -> ColumnElement[bool]:
             """Return ``other operator ANY (array)`` clause.
 
             .. legacy:: This method is an :class:`_types.ARRAY` - specific
@@ -3074,7 +3093,9 @@ class ARRAY(
             )
 
         @util.preload_module("sqlalchemy.sql.elements")
-        def all(self, other, operator=None):
+        def all(
+            self, other: Any, operator: Optional[OperatorType] = None
+        ) -> ColumnElement[bool]:
             """Return ``other operator ALL (array)`` clause.
 
             .. legacy:: This method is an :class:`_types.ARRAY` - specific
@@ -3123,23 +3144,27 @@ class ARRAY(
     comparator_factory = Comparator
 
     @property
-    def hashable(self):
+    def hashable(self) -> bool:  # type: ignore[override]
         return self.as_tuple
 
     @property
-    def python_type(self):
+    def python_type(self) -> Type[Any]:
         return list
 
-    def compare_values(self, x, y):
-        return x == y
+    def compare_values(self, x: Any, y: Any) -> bool:
+        return x == y  # type: ignore[no-any-return]
 
-    def _set_parent(self, parent, outer=False, **kw):
+    def _set_parent(
+        self, parent: SchemaEventTarget, outer: bool = False, **kw: Any
+    ) -> None:
         """Support SchemaEventTarget"""
 
         if not outer and isinstance(self.item_type, SchemaEventTarget):
             self.item_type._set_parent(parent, **kw)
 
-    def _set_parent_with_dispatch(self, parent, **kw):
+    def _set_parent_with_dispatch(
+        self, parent: SchemaEventTarget, **kw: Any
+    ) -> None:
         """Support SchemaEventTarget"""
 
         super()._set_parent_with_dispatch(parent, outer=True)
@@ -3147,17 +3172,19 @@ class ARRAY(
         if isinstance(self.item_type, SchemaEventTarget):
             self.item_type._set_parent_with_dispatch(parent)
 
-    def literal_processor(self, dialect):
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> Optional[_LiteralProcessorType[_T]]:
         item_proc = self.item_type.dialect_impl(dialect).literal_processor(
             dialect
         )
         if item_proc is None:
             return None
 
-        def to_str(elements):
+        def to_str(elements: Iterable[Any]) -> str:
             return f"[{', '.join(elements)}]"
 
-        def process(value):
+        def process(value: Sequence[Any]) -> str:
             inner = self._apply_item_processor(
                 value, item_proc, self.dimensions, to_str
             )
@@ -3165,7 +3192,13 @@ class ARRAY(
 
         return process
 
-    def _apply_item_processor(self, arr, itemproc, dim, collection_callable):
+    def _apply_item_processor(
+        self,
+        arr: Sequence[Any],
+        itemproc: Optional[Callable[[Any], Any]],
+        dim: Optional[int],
+        collection_callable: Callable[[Iterable[Any]], _P],
+    ) -> _P:
         """Helper method that can be used by bind_processor(),
         literal_processor(), etc. to apply an item processor to elements of
         an array value, taking into account the 'dimensions' for this