From: Denis Laxalde Date: Tue, 4 Mar 2025 20:28:47 +0000 (-0500) Subject: Add type annotations to `postgresql.json` X-Git-Tag: rel_2_0_39~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1db6ee03c91cdcb618bac3c5119861656ba16521;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add type annotations to `postgresql.json` (Same as https://github.com/sqlalchemy/sqlalchemy/pull/12384, but for `json`.) ### Checklist This pull request is: - [ ] A documentation / typographical / small typing error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [x] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. Related to #6810 **Have a nice day!** Closes: #12391 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12391 Pull-request-sha: 0a43724f1737a4519629a13e2d6bf33f7aecb9ac Change-Id: I2a0e88effccf351de7fa72389ee646532ce9cf69 (cherry picked from commit c7f4e8b9370487135777677eaf4d8992825c24aa) --- diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 2f26b39e31..663be8b7a2 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -4,8 +4,15 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import TYPE_CHECKING +from typing import Union from .array import ARRAY from .array import array as _pg_array @@ -21,13 +28,23 @@ from .operators import PATH_EXISTS from .operators import PATH_MATCH from ... import types as sqltypes from ...sql import cast +from ...sql._typing import _T + +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql.elements import ColumnElement + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _LiteralProcessorType + from ...sql.type_api import TypeEngine __all__ = ("JSON", "JSONB") class JSONPathType(sqltypes.JSON.JSONPathType): - def _processor(self, dialect, super_proc): - def process(value): + def _processor( + self, dialect: Dialect, super_proc: Optional[Callable[[Any], Any]] + ) -> Callable[[Any], Any]: + def process(value: Any) -> Any: if isinstance(value, str): # If it's already a string assume that it's in json path # format. This allows using cast with json paths literals @@ -44,11 +61,13 @@ class JSONPathType(sqltypes.JSON.JSONPathType): return process - def bind_processor(self, dialect): - return self._processor(dialect, self.string_bind_processor(dialect)) + def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]: + return self._processor(dialect, self.string_bind_processor(dialect)) # type: ignore[return-value] # noqa: E501 - def literal_processor(self, dialect): - return self._processor(dialect, self.string_literal_processor(dialect)) + def literal_processor( + self, dialect: Dialect + ) -> _LiteralProcessorType[Any]: + return self._processor(dialect, self.string_literal_processor(dialect)) # type: ignore[return-value] # noqa: E501 class JSONPATH(JSONPathType): @@ -148,9 +167,13 @@ class JSON(sqltypes.JSON): """ # noqa render_bind_cast = True - astext_type = sqltypes.Text() + astext_type: TypeEngine[str] = sqltypes.Text() - def __init__(self, none_as_null=False, astext_type=None): + def __init__( + self, + none_as_null: bool = False, + astext_type: Optional[TypeEngine[str]] = None, + ): """Construct a :class:`_types.JSON` type. :param none_as_null: if True, persist the value ``None`` as a @@ -175,11 +198,13 @@ class JSON(sqltypes.JSON): if astext_type is not None: self.astext_type = astext_type - class Comparator(sqltypes.JSON.Comparator): + class Comparator(sqltypes.JSON.Comparator[_T]): """Define comparison operations for :class:`_types.JSON`.""" + type: JSON + @property - def astext(self): + def astext(self) -> ColumnElement[str]: """On an indexed expression, use the "astext" (e.g. "->>") conversion when rendered in SQL. @@ -193,13 +218,13 @@ class JSON(sqltypes.JSON): """ if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType): - return self.expr.left.operate( + return self.expr.left.operate( # type: ignore[no-any-return] JSONPATH_ASTEXT, self.expr.right, result_type=self.type.astext_type, ) else: - return self.expr.left.operate( + return self.expr.left.operate( # type: ignore[no-any-return] ASTEXT, self.expr.right, result_type=self.type.astext_type ) @@ -258,28 +283,30 @@ class JSONB(JSON): __visit_name__ = "JSONB" - class Comparator(JSON.Comparator): + class Comparator(JSON.Comparator[_T]): """Define comparison operations for :class:`_types.JSON`.""" - def has_key(self, other): + type: JSONB + + def has_key(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Test for presence of a key (equivalent of the ``?`` operator). Note that the key may be a SQLA expression. """ return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) - def has_all(self, other): + def has_all(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Test for presence of all keys in jsonb (equivalent of the ``?&`` operator) """ return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) - def has_any(self, other): + def has_any(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Test for presence of any key in jsonb (equivalent of the ``?|`` operator) """ return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) - def contains(self, other, **kwargs): + def contains(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: """Boolean expression. Test if keys (or array) are a superset of/contained the keys of the argument jsonb expression (equivalent of the ``@>`` operator). @@ -289,7 +316,7 @@ class JSONB(JSON): """ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) - def contained_by(self, other): + def contained_by(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Test if keys are a proper subset of the keys of the argument jsonb expression (equivalent of the ``<@`` operator). @@ -298,7 +325,9 @@ class JSONB(JSON): CONTAINED_BY, other, result_type=sqltypes.Boolean ) - def delete_path(self, array): + def delete_path( + self, array: Union[List[str], _pg_array[str]] + ) -> ColumnElement[JSONB]: """JSONB expression. Deletes field or array element specified in the argument array (equivalent of the ``#-`` operator). @@ -308,11 +337,11 @@ class JSONB(JSON): .. versionadded:: 2.0 """ if not isinstance(array, _pg_array): - array = _pg_array(array) + array = _pg_array(array) # type: ignore[no-untyped-call] right_side = cast(array, ARRAY(sqltypes.TEXT)) return self.operate(DELETE_PATH, right_side, result_type=JSONB) - def path_exists(self, other): + def path_exists(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Test for presence of item given by the argument JSONPath expression (equivalent of the ``@?`` operator). @@ -322,7 +351,7 @@ class JSONB(JSON): PATH_EXISTS, other, result_type=sqltypes.Boolean ) - def path_match(self, other): + def path_match(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Test if JSONPath predicate given by the argument JSONPath expression matches (equivalent of the ``@@`` operator). diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index ee471a6c4e..ad220356f0 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -72,6 +72,7 @@ if TYPE_CHECKING: from .schema import MetaData from .type_api import _BindProcessorType from .type_api import _ComparatorFactory + from .type_api import _LiteralProcessorType from .type_api import _MatchedOnType from .type_api import _ResultProcessorType from ..engine.interfaces import Dialect @@ -2465,17 +2466,21 @@ class JSON(Indexable, TypeEngine[Any]): _integer = Integer() _string = String() - def string_bind_processor(self, dialect): + def string_bind_processor( + self, dialect: Dialect + ) -> Optional[_BindProcessorType[str]]: return self._string._cached_bind_processor(dialect) - def string_literal_processor(self, dialect): + def string_literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[str]]: return self._string._cached_literal_processor(dialect) - def bind_processor(self, dialect): + def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]: int_processor = self._integer._cached_bind_processor(dialect) string_processor = self.string_bind_processor(dialect) - def process(value): + def process(value: Optional[Any]) -> Any: if int_processor and isinstance(value, int): value = int_processor(value) elif string_processor and isinstance(value, str): @@ -2484,11 +2489,13 @@ class JSON(Indexable, TypeEngine[Any]): return process - def literal_processor(self, dialect): + def literal_processor( + self, dialect: Dialect + ) -> _LiteralProcessorType[Any]: int_processor = self._integer._cached_literal_processor(dialect) string_processor = self.string_literal_processor(dialect) - def process(value): + def process(value: Optional[Any]) -> Any: if int_processor and isinstance(value, int): value = int_processor(value) elif string_processor and isinstance(value, str): @@ -2539,6 +2546,8 @@ class JSON(Indexable, TypeEngine[Any]): __slots__ = () + type: JSON + def _setup_getitem(self, index): if not isinstance(index, str) and isinstance( index, collections_abc.Sequence diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index aeb804d3f9..8cdb323b2a 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -67,6 +67,7 @@ _T_con = TypeVar("_T_con", bound=Any, contravariant=True) _O = TypeVar("_O", bound=object) _TE = TypeVar("_TE", bound="TypeEngine[Any]") _CT = TypeVar("_CT", bound=Any) +_RT = TypeVar("_RT", bound=Any) _MatchedOnType = Union[ "GenericProtocol[Any]", TypeAliasType, NewType, Type[Any] @@ -186,10 +187,24 @@ class TypeEngine(Visitable, Generic[_T]): def __reduce__(self) -> Any: return self.__class__, (self.expr,) + @overload + def operate( + self, + op: OperatorType, + *other: Any, + result_type: Type[TypeEngine[_RT]], + **kwargs: Any, + ) -> ColumnElement[_RT]: ... + + @overload + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[_CT]: ... + @util.preload_module("sqlalchemy.sql.default_comparator") def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnElement[_CT]: + ) -> ColumnElement[Any]: default_comparator = util.preloaded.sql_default_comparator op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__] if kwargs: