]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add type annotations to postgresql.json module 12391/head
authorDenis Laxalde <denis@laxalde.org>
Mon, 3 Mar 2025 07:22:54 +0000 (08:22 +0100)
committerDenis Laxalde <denis@laxalde.org>
Tue, 4 Mar 2025 10:13:45 +0000 (11:13 +0100)
lib/sqlalchemy/dialects/postgresql/json.py
lib/sqlalchemy/sql/sqltypes.py

index 2f26b39e31e5aa6ae92bfdcc830a98839eec773f..663be8b7a2b8830f45da58d2202b7f01953dd377 100644 (file)
@@ -4,8 +4,15 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
+from __future__ import annotations
+
+from typing import Any
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import Union
 
 from .array import ARRAY
 from .array import array as _pg_array
@@ -21,13 +28,23 @@ from .operators import PATH_EXISTS
 from .operators import PATH_MATCH
 from ... import types as sqltypes
 from ...sql import cast
+from ...sql._typing import _T
+
+if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
+    from ...sql.elements import ColumnElement
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _LiteralProcessorType
+    from ...sql.type_api import TypeEngine
 
 __all__ = ("JSON", "JSONB")
 
 
 class JSONPathType(sqltypes.JSON.JSONPathType):
-    def _processor(self, dialect, super_proc):
-        def process(value):
+    def _processor(
+        self, dialect: Dialect, super_proc: Optional[Callable[[Any], Any]]
+    ) -> Callable[[Any], Any]:
+        def process(value: Any) -> Any:
             if isinstance(value, str):
                 # If it's already a string assume that it's in json path
                 # format. This allows using cast with json paths literals
@@ -44,11 +61,13 @@ class JSONPathType(sqltypes.JSON.JSONPathType):
 
         return process
 
-    def bind_processor(self, dialect):
-        return self._processor(dialect, self.string_bind_processor(dialect))
+    def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
+        return self._processor(dialect, self.string_bind_processor(dialect))  # type: ignore[return-value]  # noqa: E501
 
-    def literal_processor(self, dialect):
-        return self._processor(dialect, self.string_literal_processor(dialect))
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> _LiteralProcessorType[Any]:
+        return self._processor(dialect, self.string_literal_processor(dialect))  # type: ignore[return-value]  # noqa: E501
 
 
 class JSONPATH(JSONPathType):
@@ -148,9 +167,13 @@ class JSON(sqltypes.JSON):
     """  # noqa
 
     render_bind_cast = True
-    astext_type = sqltypes.Text()
+    astext_type: TypeEngine[str] = sqltypes.Text()
 
-    def __init__(self, none_as_null=False, astext_type=None):
+    def __init__(
+        self,
+        none_as_null: bool = False,
+        astext_type: Optional[TypeEngine[str]] = None,
+    ):
         """Construct a :class:`_types.JSON` type.
 
         :param none_as_null: if True, persist the value ``None`` as a
@@ -175,11 +198,13 @@ class JSON(sqltypes.JSON):
         if astext_type is not None:
             self.astext_type = astext_type
 
-    class Comparator(sqltypes.JSON.Comparator):
+    class Comparator(sqltypes.JSON.Comparator[_T]):
         """Define comparison operations for :class:`_types.JSON`."""
 
+        type: JSON
+
         @property
-        def astext(self):
+        def astext(self) -> ColumnElement[str]:
             """On an indexed expression, use the "astext" (e.g. "->>")
             conversion when rendered in SQL.
 
@@ -193,13 +218,13 @@ class JSON(sqltypes.JSON):
 
             """
             if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
-                return self.expr.left.operate(
+                return self.expr.left.operate(  # type: ignore[no-any-return]
                     JSONPATH_ASTEXT,
                     self.expr.right,
                     result_type=self.type.astext_type,
                 )
             else:
-                return self.expr.left.operate(
+                return self.expr.left.operate(  # type: ignore[no-any-return]
                     ASTEXT, self.expr.right, result_type=self.type.astext_type
                 )
 
@@ -258,28 +283,30 @@ class JSONB(JSON):
 
     __visit_name__ = "JSONB"
 
-    class Comparator(JSON.Comparator):
+    class Comparator(JSON.Comparator[_T]):
         """Define comparison operations for :class:`_types.JSON`."""
 
-        def has_key(self, other):
+        type: JSONB
+
+        def has_key(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression.  Test for presence of a key (equivalent of
             the ``?`` operator).  Note that the key may be a SQLA expression.
             """
             return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
 
-        def has_all(self, other):
+        def has_all(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression.  Test for presence of all keys in jsonb
             (equivalent of the ``?&`` operator)
             """
             return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
 
-        def has_any(self, other):
+        def has_any(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression.  Test for presence of any key in jsonb
             (equivalent of the ``?|`` operator)
             """
             return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
 
-        def contains(self, other, **kwargs):
+        def contains(self, other: Any, **kwargs: Any) -> ColumnElement[bool]:
             """Boolean expression.  Test if keys (or array) are a superset
             of/contained the keys of the argument jsonb expression
             (equivalent of the ``@>`` operator).
@@ -289,7 +316,7 @@ class JSONB(JSON):
             """
             return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
 
-        def contained_by(self, other):
+        def contained_by(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression.  Test if keys are a proper subset of the
             keys of the argument jsonb expression
             (equivalent of the ``<@`` operator).
@@ -298,7 +325,9 @@ class JSONB(JSON):
                 CONTAINED_BY, other, result_type=sqltypes.Boolean
             )
 
-        def delete_path(self, array):
+        def delete_path(
+            self, array: Union[List[str], _pg_array[str]]
+        ) -> ColumnElement[JSONB]:
             """JSONB expression. Deletes field or array element specified in
             the argument array (equivalent of the ``#-`` operator).
 
@@ -308,11 +337,11 @@ class JSONB(JSON):
             .. versionadded:: 2.0
             """
             if not isinstance(array, _pg_array):
-                array = _pg_array(array)
+                array = _pg_array(array)  # type: ignore[no-untyped-call]
             right_side = cast(array, ARRAY(sqltypes.TEXT))
             return self.operate(DELETE_PATH, right_side, result_type=JSONB)
 
-        def path_exists(self, other):
+        def path_exists(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression. Test for presence of item given by the
             argument JSONPath expression (equivalent of the ``@?`` operator).
 
@@ -322,7 +351,7 @@ class JSONB(JSON):
                 PATH_EXISTS, other, result_type=sqltypes.Boolean
             )
 
-        def path_match(self, other):
+        def path_match(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression. Test if JSONPath predicate given by the
             argument JSONPath expression matches
             (equivalent of the ``@@`` operator).
index b141da188db355c7def432c727089e0c0f4a695a..3fcf22ee6865ecdf5cd21b26d43a913635582346 100644 (file)
@@ -2591,6 +2591,8 @@ class JSON(Indexable, TypeEngine[Any]):
 
         __slots__ = ()
 
+        type: JSON
+
         def _setup_getitem(self, index):
             if not isinstance(index, str) and isinstance(
                 index, collections_abc.Sequence