]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add type annotations to `postgresql.json`
authorDenis Laxalde <denis@laxalde.org>
Tue, 4 Mar 2025 20:28:47 +0000 (15:28 -0500)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 5 Mar 2025 19:45:09 +0000 (20:45 +0100)
(Same as https://github.com/sqlalchemy/sqlalchemy/pull/12384, but for `json`.)

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [x] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

Related to #6810

**Have a nice day!**

Closes: #12391
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12391
Pull-request-sha: 0a43724f1737a4519629a13e2d6bf33f7aecb9ac

Change-Id: I2a0e88effccf351de7fa72389ee646532ce9cf69
(cherry picked from commit c7f4e8b9370487135777677eaf4d8992825c24aa)

lib/sqlalchemy/dialects/postgresql/json.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py

index 2f26b39e31e5aa6ae92bfdcc830a98839eec773f..663be8b7a2b8830f45da58d2202b7f01953dd377 100644 (file)
@@ -4,8 +4,15 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .array import ARRAY
 from .array import array as _pg_array
@@ -21,13 +28,23 @@ from .operators import PATH_EXISTS
 from .operators import PATH_MATCH
 from ... import types as sqltypes
 from ...sql import cast
+from ...sql._typing import _T
+
+if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
+    from ...sql.elements import ColumnElement
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _LiteralProcessorType
+    from ...sql.type_api import TypeEngine
 
 __all__ = ("JSON", "JSONB")
 
 
 class JSONPathType(sqltypes.JSON.JSONPathType):
-    def _processor(self, dialect, super_proc):
-        def process(value):
+    def _processor(
+        self, dialect: Dialect, super_proc: Optional[Callable[[Any], Any]]
+    ) -> Callable[[Any], Any]:
+        def process(value: Any) -> Any:
             if isinstance(value, str):
                 # If it's already a string assume that it's in json path
                 # format. This allows using cast with json paths literals
@@ -44,11 +61,13 @@ class JSONPathType(sqltypes.JSON.JSONPathType):
 
         return process
 
-    def bind_processor(self, dialect):
-        return self._processor(dialect, self.string_bind_processor(dialect))
+    def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
+        return self._processor(dialect, self.string_bind_processor(dialect))  # type: ignore[return-value]  # noqa: E501
 
-    def literal_processor(self, dialect):
-        return self._processor(dialect, self.string_literal_processor(dialect))
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> _LiteralProcessorType[Any]:
+        return self._processor(dialect, self.string_literal_processor(dialect))  # type: ignore[return-value]  # noqa: E501
 
 
 class JSONPATH(JSONPathType):
@@ -148,9 +167,13 @@ class JSON(sqltypes.JSON):
     """  # noqa
 
     render_bind_cast = True
-    astext_type = sqltypes.Text()
+    astext_type: TypeEngine[str] = sqltypes.Text()
 
-    def __init__(self, none_as_null=False, astext_type=None):
+    def __init__(
+        self,
+        none_as_null: bool = False,
+        astext_type: Optional[TypeEngine[str]] = None,
+    ):
         """Construct a :class:`_types.JSON` type.
 
         :param none_as_null: if True, persist the value ``None`` as a
@@ -175,11 +198,13 @@ class JSON(sqltypes.JSON):
         if astext_type is not None:
             self.astext_type = astext_type
 
-    class Comparator(sqltypes.JSON.Comparator):
+    class Comparator(sqltypes.JSON.Comparator[_T]):
         """Define comparison operations for :class:`_types.JSON`."""
 
+        type: JSON
+
         @property
-        def astext(self):
+        def astext(self) -> ColumnElement[str]:
             """On an indexed expression, use the "astext" (e.g. "->>")
             conversion when rendered in SQL.
 
@@ -193,13 +218,13 @@ class JSON(sqltypes.JSON):
 
             """
             if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
-                return self.expr.left.operate(
+                return self.expr.left.operate(  # type: ignore[no-any-return]
                     JSONPATH_ASTEXT,
                     self.expr.right,
                     result_type=self.type.astext_type,
                 )
             else:
-                return self.expr.left.operate(
+                return self.expr.left.operate(  # type: ignore[no-any-return]
                     ASTEXT, self.expr.right, result_type=self.type.astext_type
                 )
 
@@ -258,28 +283,30 @@ class JSONB(JSON):
 
     __visit_name__ = "JSONB"
 
-    class Comparator(JSON.Comparator):
+    class Comparator(JSON.Comparator[_T]):
         """Define comparison operations for :class:`_types.JSON`."""
 
-        def has_key(self, other):
+        type: JSONB
+
+        def has_key(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression.  Test for presence of a key (equivalent of
             the ``?`` operator).  Note that the key may be a SQLA expression.
             """
             return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
 
-        def has_all(self, other):
+        def has_all(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression.  Test for presence of all keys in jsonb
             (equivalent of the ``?&`` operator)
             """
             return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
 
-        def has_any(self, other):
+        def has_any(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression.  Test for presence of any key in jsonb
             (equivalent of the ``?|`` operator)
             """
             return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
 
-        def contains(self, other, **kwargs):
+        def contains(self, other: Any, **kwargs: Any) -> ColumnElement[bool]:
             """Boolean expression.  Test if keys (or array) are a superset
             of/contained the keys of the argument jsonb expression
             (equivalent of the ``@>`` operator).
@@ -289,7 +316,7 @@ class JSONB(JSON):
             """
             return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
 
-        def contained_by(self, other):
+        def contained_by(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression.  Test if keys are a proper subset of the
             keys of the argument jsonb expression
             (equivalent of the ``<@`` operator).
@@ -298,7 +325,9 @@ class JSONB(JSON):
                 CONTAINED_BY, other, result_type=sqltypes.Boolean
             )
 
-        def delete_path(self, array):
+        def delete_path(
+            self, array: Union[List[str], _pg_array[str]]
+        ) -> ColumnElement[JSONB]:
             """JSONB expression. Deletes field or array element specified in
             the argument array (equivalent of the ``#-`` operator).
 
@@ -308,11 +337,11 @@ class JSONB(JSON):
             .. versionadded:: 2.0
             """
             if not isinstance(array, _pg_array):
-                array = _pg_array(array)
+                array = _pg_array(array)  # type: ignore[no-untyped-call]
             right_side = cast(array, ARRAY(sqltypes.TEXT))
             return self.operate(DELETE_PATH, right_side, result_type=JSONB)
 
-        def path_exists(self, other):
+        def path_exists(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression. Test for presence of item given by the
             argument JSONPath expression (equivalent of the ``@?`` operator).
 
@@ -322,7 +351,7 @@ class JSONB(JSON):
                 PATH_EXISTS, other, result_type=sqltypes.Boolean
             )
 
-        def path_match(self, other):
+        def path_match(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression. Test if JSONPath predicate given by the
             argument JSONPath expression matches
             (equivalent of the ``@@`` operator).
index ee471a6c4ecdfc9f1da1b8a4d12240bec01aa889..ad220356f046e17167088c321fcc7356e4b5275c 100644 (file)
@@ -72,6 +72,7 @@ if TYPE_CHECKING:
     from .schema import MetaData
     from .type_api import _BindProcessorType
     from .type_api import _ComparatorFactory
+    from .type_api import _LiteralProcessorType
     from .type_api import _MatchedOnType
     from .type_api import _ResultProcessorType
     from ..engine.interfaces import Dialect
@@ -2465,17 +2466,21 @@ class JSON(Indexable, TypeEngine[Any]):
         _integer = Integer()
         _string = String()
 
-        def string_bind_processor(self, dialect):
+        def string_bind_processor(
+            self, dialect: Dialect
+        ) -> Optional[_BindProcessorType[str]]:
             return self._string._cached_bind_processor(dialect)
 
-        def string_literal_processor(self, dialect):
+        def string_literal_processor(
+            self, dialect: Dialect
+        ) -> Optional[_LiteralProcessorType[str]]:
             return self._string._cached_literal_processor(dialect)
 
-        def bind_processor(self, dialect):
+        def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
             int_processor = self._integer._cached_bind_processor(dialect)
             string_processor = self.string_bind_processor(dialect)
 
-            def process(value):
+            def process(value: Optional[Any]) -> Any:
                 if int_processor and isinstance(value, int):
                     value = int_processor(value)
                 elif string_processor and isinstance(value, str):
@@ -2484,11 +2489,13 @@ class JSON(Indexable, TypeEngine[Any]):
 
             return process
 
-        def literal_processor(self, dialect):
+        def literal_processor(
+            self, dialect: Dialect
+        ) -> _LiteralProcessorType[Any]:
             int_processor = self._integer._cached_literal_processor(dialect)
             string_processor = self.string_literal_processor(dialect)
 
-            def process(value):
+            def process(value: Optional[Any]) -> Any:
                 if int_processor and isinstance(value, int):
                     value = int_processor(value)
                 elif string_processor and isinstance(value, str):
@@ -2539,6 +2546,8 @@ class JSON(Indexable, TypeEngine[Any]):
 
         __slots__ = ()
 
+        type: JSON
+
         def _setup_getitem(self, index):
             if not isinstance(index, str) and isinstance(
                 index, collections_abc.Sequence
index aeb804d3f9b6992d83a25177adba43c64254f097..8cdb323b2a62c70f21e4f1343a1ced0d281637e5 100644 (file)
@@ -67,6 +67,7 @@ _T_con = TypeVar("_T_con", bound=Any, contravariant=True)
 _O = TypeVar("_O", bound=object)
 _TE = TypeVar("_TE", bound="TypeEngine[Any]")
 _CT = TypeVar("_CT", bound=Any)
+_RT = TypeVar("_RT", bound=Any)
 
 _MatchedOnType = Union[
     "GenericProtocol[Any]", TypeAliasType, NewType, Type[Any]
@@ -186,10 +187,24 @@ class TypeEngine(Visitable, Generic[_T]):
         def __reduce__(self) -> Any:
             return self.__class__, (self.expr,)
 
+        @overload
+        def operate(
+            self,
+            op: OperatorType,
+            *other: Any,
+            result_type: Type[TypeEngine[_RT]],
+            **kwargs: Any,
+        ) -> ColumnElement[_RT]: ...
+
+        @overload
+        def operate(
+            self, op: OperatorType, *other: Any, **kwargs: Any
+        ) -> ColumnElement[_CT]: ...
+
         @util.preload_module("sqlalchemy.sql.default_comparator")
         def operate(
             self, op: OperatorType, *other: Any, **kwargs: Any
-        ) -> ColumnElement[_CT]:
+        ) -> ColumnElement[Any]:
             default_comparator = util.preloaded.sql_default_comparator
             op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__]
             if kwargs: