]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
allow JSON, JSONB, etc. to be parameterized, type HSTORE
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 18 Feb 2026 15:12:52 +0000 (10:12 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 6 Mar 2026 19:25:03 +0000 (14:25 -0500)
Improved typing of :class:`_sqltypes.JSON` as well as dialect specific
variants like :class:`_postgresql.JSON` to include generic capabilities, so
that the types may be parameterized to indicate any specific type of
contents expected, e.g. ``JSONB[list[str]]()``.

Also types HSTORE

Fixes: #13131
Change-Id: Ia089ba4e3cebf6339a5420b2923cd267c4e6891a

15 files changed:
doc/build/changelog/unreleased_21/13131.rst [new file with mode: 0644]
lib/sqlalchemy/__init__.py
lib/sqlalchemy/dialects/mssql/json.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/json.py
lib/sqlalchemy/dialects/postgresql/hstore.py
lib/sqlalchemy/dialects/postgresql/json.py
lib/sqlalchemy/dialects/sqlite/json.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/util/typing.py
test/typing/plain_files/dialects/mysql/mysql_stuff.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py
test/typing/plain_files/sql/sqltypes.py

diff --git a/doc/build/changelog/unreleased_21/13131.rst b/doc/build/changelog/unreleased_21/13131.rst
new file mode 100644 (file)
index 0000000..978bc06
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 13131
+
+    Improved typing of :class:`_sqltypes.JSON` as well as dialect specific
+    variants like :class:`_postgresql.JSON` to include generic capabilities, so
+    that the types may be parameterized to indicate any specific type of
+    contents expected, e.g. ``JSONB[list[str]]()``.
+
index b594e3c6652f370189dcfb70d1cbdef38770560e..b74bae21af86d3c319c137a191974149c93c221e 100644 (file)
@@ -204,6 +204,7 @@ from .sql.expression import TableClause as TableClause
 from .sql.expression import TableSample as TableSample
 from .sql.expression import tablesample as tablesample
 from .sql.expression import TableValuedAlias as TableValuedAlias
+from .sql.expression import TableValuedColumn as TableValuedColumn
 from .sql.expression import text as text
 from .sql.expression import TextAsFrom as TextAsFrom
 from .sql.expression import TextClause as TextClause
index cb071884ace13c2bd1f0a35d4ca2d2f95f083b43..4c128326a404e6acd1b8963af7a4b2c24a36d94b 100644 (file)
@@ -4,9 +4,18 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
+
+from typing import Any
+from typing import TYPE_CHECKING
 
 from ... import types as sqltypes
+from ...sql.sqltypes import _T_JSON
+
+if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _LiteralProcessorType
 
 # technically, all the dialect-specific datatypes that don't have any special
 # behaviors would be private with names like _MSJson. However, we haven't been
@@ -16,7 +25,7 @@ from ... import types as sqltypes
 # package-private at once.
 
 
-class JSON(sqltypes.JSON):
+class JSON(sqltypes.JSON[_T_JSON]):
     """MSSQL JSON type.
 
     MSSQL supports JSON-formatted data as of SQL Server 2016.
@@ -82,13 +91,13 @@ class JSON(sqltypes.JSON):
 # these are not generalizable to all JSON implementations, remain separately
 # implemented for each dialect.
 class _FormatTypeMixin:
-    def _format_value(self, value):
+    def _format_value(self, value: Any) -> str:
         raise NotImplementedError()
 
-    def bind_processor(self, dialect):
-        super_proc = self.string_bind_processor(dialect)
+    def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
+        super_proc = self.string_bind_processor(dialect)  # type: ignore[attr-defined]  # noqa: E501
 
-        def process(value):
+        def process(value: Any) -> Any:
             value = self._format_value(value)
             if super_proc:
                 value = super_proc(value)
@@ -96,29 +105,31 @@ class _FormatTypeMixin:
 
         return process
 
-    def literal_processor(self, dialect):
-        super_proc = self.string_literal_processor(dialect)
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> _LiteralProcessorType[Any]:
+        super_proc = self.string_literal_processor(dialect)  # type: ignore[attr-defined]  # noqa: E501
 
-        def process(value):
+        def process(value: Any) -> str:
             value = self._format_value(value)
             if super_proc:
                 value = super_proc(value)
-            return value
+            return value  # type: ignore[no-any-return]
 
         return process
 
 
 class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
-    def _format_value(self, value):
+    def _format_value(self, value: Any) -> str:
         if isinstance(value, int):
-            value = "$[%s]" % value
+            formatted_value = "$[%s]" % value
         else:
-            value = '$."%s"' % value
-        return value
+            formatted_value = '$."%s"' % value
+        return formatted_value
 
 
 class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
-    def _format_value(self, value):
+    def _format_value(self, value: Any) -> str:
         return "$%s" % (
             "".join(
                 [
index 75ec79baac18f44caefb9e019ff3ef7f4da24165..9cfd643610c9d396ccc8e402d87f8d5300cf86a0 100644 (file)
@@ -2631,7 +2631,7 @@ class MySQLTypeCompiler(
     def visit_VARBINARY(self, type_: VARBINARY, **kw: Any) -> str:
         return "VARBINARY(%d)" % type_.length  # type: ignore[str-format]
 
-    def visit_JSON(self, type_: JSON, **kw: Any) -> str:
+    def visit_JSON(self, type_: JSON[Any], **kw: Any) -> str:
         return "JSON"
 
     def visit_large_binary(self, type_: LargeBinary, **kw: Any) -> str:
index 7e2606ccf96bd9285d10d1731021ff28df56f91a..5c564d73b9e2c84c71b26235649f9ccd51769a3f 100644 (file)
@@ -10,6 +10,7 @@ from typing import Any
 from typing import TYPE_CHECKING
 
 from ... import types as sqltypes
+from ...sql.sqltypes import _T_JSON
 
 if TYPE_CHECKING:
     from ...engine.interfaces import Dialect
@@ -17,7 +18,7 @@ if TYPE_CHECKING:
     from ...sql.type_api import _LiteralProcessorType
 
 
-class JSON(sqltypes.JSON):
+class JSON(sqltypes.JSON[_T_JSON]):
     """MySQL JSON type.
 
     MySQL supports JSON as of version 5.7.
index 4c9db15d82492e49bae29686112672e465ac03d4..91666e71ceaed4c30d8e36e4c37827f1ce20c9a2 100644 (file)
@@ -4,10 +4,12 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
+from __future__ import annotations
 
 import re
+from typing import Any
+from typing import Optional
 
 from .array import ARRAY
 from .operators import CONTAINED_BY
@@ -20,10 +22,17 @@ from ... import types as sqltypes
 from ...sql import functions as sqlfunc
 from ...types import OperatorClass
 
+
 __all__ = ("HSTORE", "hstore")
 
+_HSTORE_VAL = dict[str, str | None]
+
 
-class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
+class HSTORE(
+    sqltypes.Indexable,
+    sqltypes.Concatenable,
+    sqltypes.TypeEngine[_HSTORE_VAL],
+):
     """Represent the PostgreSQL HSTORE type.
 
     The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::
@@ -112,7 +121,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
         | OperatorClass.CONCATENABLE
     )
 
-    def __init__(self, text_type=None):
+    def __init__(self, text_type: Optional[Any] = None) -> None:
         """Construct a new :class:`.HSTORE`.
 
         :param text_type: the type that should be used for indexed values.
@@ -123,25 +132,26 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
             self.text_type = text_type
 
     class Comparator(
-        sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator
+        sqltypes.Indexable.Comparator[_HSTORE_VAL],
+        sqltypes.Concatenable.Comparator[_HSTORE_VAL],
     ):
         """Define comparison operations for :class:`.HSTORE`."""
 
-        def has_key(self, other):
+        def has_key(self, other: Any) -> Any:
             """Boolean expression.  Test for presence of a key.  Note that the
             key may be a SQLA expression.
             """
             return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
 
-        def has_all(self, other):
+        def has_all(self, other: Any) -> Any:
             """Boolean expression.  Test for presence of all keys in jsonb"""
             return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
 
-        def has_any(self, other):
+        def has_any(self, other: Any) -> Any:
             """Boolean expression.  Test for presence of any key in jsonb"""
             return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
 
-        def contains(self, other, **kwargs):
+        def contains(self, other: Any, **kwargs: Any) -> Any:
             """Boolean expression.  Test if keys (or array) are a superset
             of/contained the keys of the argument jsonb expression.
 
@@ -150,7 +160,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
             """
             return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
 
-        def contained_by(self, other):
+        def contained_by(self, other: Any) -> Any:
             """Boolean expression.  Test if keys are a proper subset of the
             keys of the argument jsonb expression.
             """
@@ -158,16 +168,16 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
                 CONTAINED_BY, other, result_type=sqltypes.Boolean
             )
 
-        def _setup_getitem(self, index):
-            return GETITEM, index, self.type.text_type
+        def _setup_getitem(self, index: Any) -> Any:
+            return GETITEM, index, self.type.text_type  # type: ignore
 
-        def defined(self, key):
+        def defined(self, key: Any) -> Any:
             """Boolean expression.  Test for presence of a non-NULL value for
             the key.  Note that the key may be a SQLA expression.
             """
             return _HStoreDefinedFunction(self.expr, key)
 
-        def delete(self, key):
+        def delete(self, key: Any) -> Any:
             """HStore expression.  Returns the contents of this hstore with the
             given key deleted.  Note that the key may be a SQLA expression.
             """
@@ -175,37 +185,37 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
                 key = _serialize_hstore(key)
             return _HStoreDeleteFunction(self.expr, key)
 
-        def slice(self, array):
+        def slice(self, array: Any) -> Any:
             """HStore expression.  Returns a subset of an hstore defined by
             array of keys.
             """
             return _HStoreSliceFunction(self.expr, array)
 
-        def keys(self):
+        def keys(self) -> Any:
             """Text array expression.  Returns array of keys."""
             return _HStoreKeysFunction(self.expr)
 
-        def vals(self):
+        def vals(self) -> Any:
             """Text array expression.  Returns array of values."""
             return _HStoreValsFunction(self.expr)
 
-        def array(self):
+        def array(self) -> Any:
             """Text array expression.  Returns array of alternating keys and
             values.
             """
             return _HStoreArrayFunction(self.expr)
 
-        def matrix(self):
+        def matrix(self) -> Any:
             """Text array expression.  Returns array of [key, value] pairs."""
             return _HStoreMatrixFunction(self.expr)
 
     comparator_factory = Comparator
 
-    def bind_processor(self, dialect):
+    def bind_processor(self, dialect: Any) -> Any:
         # note that dialect-specific types like that of psycopg and
         # psycopg2 will override this method to allow driver-level conversion
         # instead, see _PsycopgHStore
-        def process(value):
+        def process(value: Any) -> Any:
             if isinstance(value, dict):
                 return _serialize_hstore(value)
             else:
@@ -213,11 +223,11 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
 
         return process
 
-    def result_processor(self, dialect, coltype):
+    def result_processor(self, dialect: Any, coltype: Any) -> Any:
         # note that dialect-specific types like that of psycopg and
         # psycopg2 will override this method to allow driver-level conversion
         # instead, see _PsycopgHStore
-        def process(value):
+        def process(value: Any) -> Any:
             if value is not None:
                 return _parse_hstore(value)
             else:
@@ -226,7 +236,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
         return process
 
 
-class hstore(sqlfunc.GenericFunction):
+class hstore(sqlfunc.GenericFunction[_HSTORE_VAL]):
     """Construct an hstore value within a SQL expression using the
     PostgreSQL ``hstore()`` function.
 
@@ -252,48 +262,48 @@ class hstore(sqlfunc.GenericFunction):
 
     """
 
-    type = HSTORE
+    type = HSTORE()
     name = "hstore"
     inherit_cache = True
 
 
-class _HStoreDefinedFunction(sqlfunc.GenericFunction):
-    type = sqltypes.Boolean
+class _HStoreDefinedFunction(sqlfunc.GenericFunction[bool]):
+    type = sqltypes.Boolean()
     name = "defined"
     inherit_cache = True
 
 
-class _HStoreDeleteFunction(sqlfunc.GenericFunction):
-    type = HSTORE
+class _HStoreDeleteFunction(sqlfunc.GenericFunction[_HSTORE_VAL]):
+    type = HSTORE()
     name = "delete"
     inherit_cache = True
 
 
-class _HStoreSliceFunction(sqlfunc.GenericFunction):
-    type = HSTORE
+class _HStoreSliceFunction(sqlfunc.GenericFunction[_HSTORE_VAL]):
+    type = HSTORE()
     name = "slice"
     inherit_cache = True
 
 
-class _HStoreKeysFunction(sqlfunc.GenericFunction):
+class _HStoreKeysFunction(sqlfunc.GenericFunction[Any]):
     type = ARRAY(sqltypes.Text)
     name = "akeys"
     inherit_cache = True
 
 
-class _HStoreValsFunction(sqlfunc.GenericFunction):
+class _HStoreValsFunction(sqlfunc.GenericFunction[Any]):
     type = ARRAY(sqltypes.Text)
     name = "avals"
     inherit_cache = True
 
 
-class _HStoreArrayFunction(sqlfunc.GenericFunction):
+class _HStoreArrayFunction(sqlfunc.GenericFunction[Any]):
     type = ARRAY(sqltypes.Text)
     name = "hstore_to_array"
     inherit_cache = True
 
 
-class _HStoreMatrixFunction(sqlfunc.GenericFunction):
+class _HStoreMatrixFunction(sqlfunc.GenericFunction[Any]):
     type = ARRAY(sqltypes.Text)
     name = "hstore_to_matrix"
     inherit_cache = True
@@ -329,7 +339,7 @@ HSTORE_DELIMITER_RE = re.compile(
 )
 
 
-def _parse_error(hstore_str, pos):
+def _parse_error(hstore_str: str, pos: int) -> str:
     """format an unmarshalling error."""
 
     ctx = 20
@@ -350,7 +360,7 @@ def _parse_error(hstore_str, pos):
     )
 
 
-def _parse_hstore(hstore_str):
+def _parse_hstore(hstore_str: str) -> _HSTORE_VAL:
     """Parse an hstore from its literal string representation.
 
     Attempts to approximate PG's hstore input parsing rules as closely as
@@ -362,7 +372,7 @@ def _parse_hstore(hstore_str):
 
 
     """
-    result = {}
+    result: _HSTORE_VAL = {}
     pos = 0
     pair_match = HSTORE_PAIR_RE.match(hstore_str)
 
@@ -392,13 +402,13 @@ def _parse_hstore(hstore_str):
     return result
 
 
-def _serialize_hstore(val):
+def _serialize_hstore(val: _HSTORE_VAL) -> str:
     """Serialize a dictionary into an hstore literal.  Keys and values must
     both be strings (except None for values).
 
     """
 
-    def esc(s, position):
+    def esc(s: Optional[str], position: str) -> str:
         if position == "value" and s is None:
             return "NULL"
         elif isinstance(s, str):
index e50b1f3364252dfec35a71261b24c8d5b520873c..54b33fc65a5c6a4086f47d02b9dd5c487c9263ce 100644 (file)
@@ -28,8 +28,9 @@ from .operators import PATH_EXISTS
 from .operators import PATH_MATCH
 from ... import types as sqltypes
 from ...sql import cast
-from ...sql._typing import _T
 from ...sql.operators import OperatorClass
+from ...sql.sqltypes import _CT_JSON
+from ...sql.sqltypes import _T_JSON
 
 if TYPE_CHECKING:
     from ...engine.interfaces import Dialect
@@ -90,7 +91,7 @@ class JSONPATH(JSONPathType):
     __visit_name__ = "JSONPATH"
 
 
-class JSON(sqltypes.JSON):
+class JSON(sqltypes.JSON[_T_JSON]):
     """Represent the PostgreSQL JSON type.
 
     :class:`_postgresql.JSON` is used automatically whenever the base
@@ -200,10 +201,10 @@ class JSON(sqltypes.JSON):
         if astext_type is not None:
             self.astext_type = astext_type
 
-    class Comparator(sqltypes.JSON.Comparator[_T]):
+    class Comparator(sqltypes.JSON.Comparator[_CT_JSON]):
         """Define comparison operations for :class:`_types.JSON`."""
 
-        type: JSON
+        type: JSON[_CT_JSON]
 
         @property
         def astext(self) -> ColumnElement[str]:
@@ -233,7 +234,7 @@ class JSON(sqltypes.JSON):
     comparator_factory = Comparator
 
 
-class JSONB(JSON):
+class JSONB(JSON[_T_JSON]):
     """Represent the PostgreSQL JSONB type.
 
     The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data,
@@ -323,10 +324,10 @@ class JSONB(JSON):
         else:
             return super().coerce_compared_value(op, value)
 
-    class Comparator(JSON.Comparator[_T]):
+    class Comparator(JSON.Comparator[_CT_JSON]):
         """Define comparison operations for :class:`_types.JSON`."""
 
-        type: JSONB
+        type: JSONB[_CT_JSON]
 
         def has_key(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression.  Test for presence of a key (equivalent of
@@ -367,7 +368,7 @@ class JSONB(JSON):
 
         def delete_path(
             self, array: Union[List[str], _pg_array[str]]
-        ) -> ColumnElement[JSONB]:
+        ) -> ColumnElement[_CT_JSON]:
             """JSONB expression. Deletes field or array element specified in
             the argument array (equivalent of the ``#-`` operator).
 
index 1a1ee049c6d16bd84ad57e52bbc0c8d5944e299f..ac705d661d5b71c6f46123e70ee734d8c98ee899 100644 (file)
@@ -4,12 +4,21 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
+
+from typing import Any
+from typing import TYPE_CHECKING
 
 from ... import types as sqltypes
+from ...sql.sqltypes import _T_JSON
+
+if TYPE_CHECKING:
+    from ...engine.interfaces import Dialect
+    from ...sql.type_api import _BindProcessorType
+    from ...sql.type_api import _LiteralProcessorType
 
 
-class JSON(sqltypes.JSON):
+class JSON(sqltypes.JSON[_T_JSON]):
     """SQLite JSON type.
 
     SQLite supports JSON as of version 3.9 through its JSON1_ extension. Note
@@ -42,13 +51,13 @@ class JSON(sqltypes.JSON):
 # these are not generalizable to all JSON implementations, remain separately
 # implemented for each dialect.
 class _FormatTypeMixin:
-    def _format_value(self, value):
+    def _format_value(self, value: Any) -> str:
         raise NotImplementedError()
 
-    def bind_processor(self, dialect):
-        super_proc = self.string_bind_processor(dialect)
+    def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
+        super_proc = self.string_bind_processor(dialect)  # type: ignore[attr-defined]  # noqa: E501
 
-        def process(value):
+        def process(value: Any) -> Any:
             value = self._format_value(value)
             if super_proc:
                 value = super_proc(value)
@@ -56,29 +65,31 @@ class _FormatTypeMixin:
 
         return process
 
-    def literal_processor(self, dialect):
-        super_proc = self.string_literal_processor(dialect)
+    def literal_processor(
+        self, dialect: Dialect
+    ) -> _LiteralProcessorType[Any]:
+        super_proc = self.string_literal_processor(dialect)  # type: ignore[attr-defined]  # noqa: E501
 
-        def process(value):
+        def process(value: Any) -> str:
             value = self._format_value(value)
             if super_proc:
                 value = super_proc(value)
-            return value
+            return value  # type: ignore[no-any-return]
 
         return process
 
 
 class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
-    def _format_value(self, value):
+    def _format_value(self, value: Any) -> str:
         if isinstance(value, int):
-            value = "$[%s]" % value
+            formatted_value = "$[%s]" % value
         else:
-            value = '$."%s"' % value
-        return value
+            formatted_value = '$."%s"' % value
+        return formatted_value
 
 
 class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
-    def _format_value(self, value):
+    def _format_value(self, value: Any) -> str:
         return "$%s" % (
             "".join(
                 [
index ae4d32e36ac9495076be9efb9f961d5acdaac691..ade97f1be79681635ec7c8a1d0ec28204bb67e75 100644 (file)
@@ -97,6 +97,8 @@ from .expression import table as table
 from .expression import TableClause as TableClause
 from .expression import TableSample as TableSample
 from .expression import tablesample as tablesample
+from .expression import TableValuedAlias as TableValuedAlias
+from .expression import TableValuedColumn as TableValuedColumn
 from .expression import text as text
 from .expression import TextClause as TextClause
 from .expression import true as true
index 37bb6383a5c63bf4ed98fa9fe00c7b6f997f404e..06e20c1a4e5db144feca8e7a5bdbe9413422c0c7 100644 (file)
@@ -106,6 +106,7 @@ from .elements import ReleaseSavepointClause as ReleaseSavepointClause
 from .elements import RollbackToSavepointClause as RollbackToSavepointClause
 from .elements import SavepointClause as SavepointClause
 from .elements import SQLColumnExpression as SQLColumnExpression
+from .elements import TableValuedColumn as TableValuedColumn
 from .elements import TextClause as TextClause
 from .elements import True_ as True_
 from .elements import TryCast as TryCast
index 22fcfb0a697022c9747c5b52ea3119018a226ed2..21ce3ae1f066edd2722798d15280a39f8df91752 100644 (file)
@@ -32,7 +32,7 @@ from typing import Sequence
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
-from typing import TypeVar
+from typing import TypeAlias
 from typing import Union
 from uuid import UUID as _python_UUID
 
@@ -70,6 +70,7 @@ from ..util import warn_deprecated
 from ..util.typing import is_literal
 from ..util.typing import is_pep695
 from ..util.typing import TupleAny
+from ..util.typing import TypeVar
 
 if TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
@@ -85,9 +86,25 @@ if TYPE_CHECKING:
     from ..engine.interfaces import Dialect
     from ..util.typing import _MatchedOnType
 
-_T = TypeVar("_T", bound="Any")
+_T = TypeVar("_T", bound=Any)
+
+_JSON_VALUE: TypeAlias = (
+    str | int | bool | None | dict[str, "_JSON_VALUE"] | list["_JSON_VALUE"]
+)
+
+_T_JSON = TypeVar(
+    "_T_JSON",
+    bound=Any,
+    default=_JSON_VALUE,
+)
+_CT_JSON = TypeVar(
+    "_CT_JSON",
+    bound=Any,
+    default=_JSON_VALUE,
+)
+
 _CT = TypeVar("_CT", bound=Any)
-_TE = TypeVar("_TE", bound="TypeEngine[Any]")
+_TE = TypeVar("_TE", bound=TypeEngine[Any])
 _P = TypeVar("_P")
 
 
@@ -2361,7 +2378,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]):
         return process
 
 
-class JSON(Indexable, TypeEngine[Any]):
+class JSON(Indexable, TypeEngine[_T_JSON]):
     """Represent a SQL JSON type.
 
     .. note::  :class:`_types.JSON`
@@ -2713,12 +2730,14 @@ class JSON(Indexable, TypeEngine[Any]):
 
         __visit_name__ = "json_path"
 
-    class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]):
+    class Comparator(
+        Indexable.Comparator[_CT_JSON], Concatenable.Comparator[_CT_JSON]
+    ):
         """Define comparison operations for :class:`_types.JSON`."""
 
         __slots__ = ()
 
-        type: JSON
+        type: JSON[_CT_JSON]
 
         def _setup_getitem(self, index):
             if not isinstance(index, str) and isinstance(
index 6a0d2ed85c750dd034e948c120ad85e7ebd1f41f..01bf0a7b3a49acfb0141ae20502935319890ab1d 100644 (file)
@@ -35,7 +35,6 @@ from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeGuard
-from typing import TypeVar
 from typing import Union
 
 import typing_extensions
@@ -53,6 +52,7 @@ if True:  # zimports removes the tailing comments
     from typing_extensions import Unpack as Unpack  # 3.11
     from typing_extensions import Never as Never  # 3.11
     from typing_extensions import LiteralString as LiteralString  # 3.11
+    from typing_extensions import TypeVar as TypeVar  # 3.13 for default
 
 
 _T = TypeVar("_T", bound=Any)
index 3fcdc75a9715ffadc40dd22a9f72afd6924b3178..222becb2674098d6d99ffaa0a91742c01b97b72d 100644 (file)
@@ -1,21 +1,69 @@
+from typing import Any
+from typing import assert_type
+from typing import Dict
+from typing import List
+from typing import Tuple
+
+from sqlalchemy import Column
+from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy.dialects.mysql import insert
+from sqlalchemy.dialects.mysql import JSON
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql.functions import Function
 
 
 class Base(DeclarativeBase):
     pass
 
 
-class Test(Base):
-    __tablename__ = "test_table_json"
+def test_insert_on_duplicate_key_update() -> None:
+    """Test INSERT with ON DUPLICATE KEY UPDATE."""
+
+    class Test(Base):
+        __tablename__ = "test_table_json"
+        id = mapped_column(Integer, primary_key=True)
+        data: Mapped[str] = mapped_column()
+
+    insert(Test).on_duplicate_key_update(
+        {"id": 42, Test.data: 99}, [("foo", 44)], data=99, id="foo"
+    ).inserted.foo.desc()
+
+
+def test_json_column_types() -> None:
+    """Test JSON column type inference with type parameters."""
+    c_json: Column[Any] = Column(JSON())
+    assert_type(c_json, Column[Any])
+
+    c_json_dict = Column(JSON[Dict[str, Any]]())
+    assert_type(c_json_dict, Column[Dict[str, Any]])
+
+
+def test_json_orm_mapping() -> None:
+    """Test JSON type in ORM mapped columns."""
+
+    class JSONTest(Base):
+        __tablename__ = "test_json"
+        id = mapped_column(Integer, primary_key=True)
+        json_data: Mapped[Dict[str, Any]] = mapped_column(JSON)
+        json_list: Mapped[List[Any]] = mapped_column(JSON)
+
+    json_obj = JSONTest()
+    assert_type(json_obj.json_data, Dict[str, Any])
+    assert_type(json_obj.json_list, List[Any])
+
 
-    id = mapped_column(Integer, primary_key=True)
-    data: Mapped[str] = mapped_column()
+def test_json_func_with_type_param() -> None:
+    """Test func with parameterized JSON types (issue #13131)."""
 
+    class JSONTest(Base):
+        __tablename__ = "test_json_func"
+        id = mapped_column(Integer, primary_key=True)
+        data: Mapped[Dict[str, Any]] = mapped_column(JSON)
 
-insert(Test).on_duplicate_key_update(
-    {"id": 42, Test.data: 99}, [("foo", 44)], data=99, id="foo"
-).inserted.foo.desc()
+    json_func_result = func.json_extract(
+        JSONTest.data, "$.items", type_=JSON[List[Tuple[int, int]]]
+    )
+    assert_type(json_func_result, Function[List[Tuple[int, int]]])
index 14992511037485c1d1ab35c125d815c1d4aa313d..088cb644146aaf544e8edfbb3053bbb7181f798f 100644 (file)
@@ -3,16 +3,17 @@ from datetime import datetime
 from typing import Any
 from typing import assert_type
 from typing import Dict
+from typing import List
 from typing import Sequence
+from typing import Tuple
 from uuid import UUID as _py_uuid
 
-from sqlalchemy import cast
 from sqlalchemy import Column
 from sqlalchemy import func
 from sqlalchemy import Integer
-from sqlalchemy import or_
 from sqlalchemy import Select
 from sqlalchemy import select
+from sqlalchemy import TableValuedColumn
 from sqlalchemy import Text
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.dialects.postgresql import aggregate_order_by
@@ -20,9 +21,11 @@ from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.dialects.postgresql import array
 from sqlalchemy.dialects.postgresql import DATERANGE
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
+from sqlalchemy.dialects.postgresql import HSTORE
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.dialects.postgresql import INT4RANGE
 from sqlalchemy.dialects.postgresql import INT8MULTIRANGE
+from sqlalchemy.dialects.postgresql import JSON
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.dialects.postgresql import Range
 from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE
@@ -30,118 +33,256 @@ from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql.functions import Function
+from sqlalchemy.sql.sqltypes import _JSON_VALUE
 
-# test #6402
 
-c1 = Column(UUID())
+class Base(DeclarativeBase):
+    pass
 
-assert_type(c1, Column[_py_uuid])
 
-c2 = Column(UUID(as_uuid=False))
+def test_uuid_column_types() -> None:
+    """Test UUID column type inference with as_uuid parameter."""
+    c1 = Column(UUID())
+    assert_type(c1, Column[_py_uuid])
 
-assert_type(c2, Column[str])
+    c2 = Column(UUID(as_uuid=False))
+    assert_type(c2, Column[str])
 
 
-class Base(DeclarativeBase):
-    pass
+def test_range_column_types() -> None:
+    """Test Range and MultiRange column type inference."""
+    assert_type(Column(INT4RANGE()), Column[Range[int]])
+    assert_type(Column("foo", DATERANGE()), Column[Range[date]])
+    assert_type(Column(INT8MULTIRANGE()), Column[Sequence[Range[int]]])
+    assert_type(
+        Column("foo", TSTZMULTIRANGE()), Column[Sequence[Range[datetime]]]
+    )
+
+
+def test_range_in_select() -> None:
+    """Test Range types in SELECT statements."""
+    range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE()))
+    assert_type(range_col_stmt, Select[Range[int], Sequence[Range[int]]])
+
+
+def test_array_type_inference() -> None:
+    array_from_ints = array(range(2))
+    assert_type(array_from_ints, array[int])
+
+    array_of_strings = array([], type_=Text)
+    assert_type(array_of_strings, array[str])
+
+    array_of_ints = array([0], type_=Integer)
+    assert_type(array_of_ints, array[int])
+
+    # EXPECTED_MYPY_RE: Cannot infer .* of "array"
+    array([0], type_=Text)
 
 
-class Test(Base):
-    __tablename__ = "test_table_json"
+def test_array_column_types() -> None:
+    """Test ARRAY column type inference."""
+    assert_type(ARRAY(Text), ARRAY[str])
+    assert_type(Column(type_=ARRAY(Integer)), Column[Sequence[int]])
 
-    id = mapped_column(Integer, primary_key=True)
-    data: Mapped[Dict[str, Any]] = mapped_column(JSONB)
 
-    ident: Mapped[_py_uuid] = mapped_column(UUID())
+def test_array_agg_functions() -> None:
+    """Test array_agg function type inference."""
+    stmt_array_agg = select(func.array_agg(Column("num", type_=Integer)))
+    assert_type(stmt_array_agg, Select[Sequence[int]])
 
-    ident_str: Mapped[str] = mapped_column(UUID(as_uuid=False))
 
-    __table_args__ = (ExcludeConstraint((Column("ident_str"), "=")),)
+def test_array_agg_with_aggregate_order_by() -> None:
+    """Test array_agg with aggregate_order_by."""
 
+    class Test(Base):
+        __tablename__ = "test_array_agg"
+        id = mapped_column(Integer, primary_key=True)
+        ident_str: Mapped[str] = mapped_column()
+        ident: Mapped[_py_uuid] = mapped_column(UUID())
 
-elem = func.jsonb_array_elements(Test.data, type_=JSONB).column_valued("elem")
+    assert_type(select(func.array_agg(Test.ident_str)), Select[Sequence[str]])
 
-stmt = select(Test).where(
-    or_(
-        cast("example code", ARRAY(Text)).contained_by(
-            array([select(elem["code"].astext).scalar_subquery()])
-        ),
-        cast("stefan", ARRAY(Text)).contained_by(
-            array([select(elem["code"]["new_value"].astext).scalar_subquery()])
-        ),
+    stmt_array_agg_order_by_1 = select(
+        func.array_agg(
+            aggregate_order_by(
+                Column("title", type_=Text),
+                Column("date", type_=DATERANGE).desc(),
+                Column("id", type_=Integer),
+            ),
+        )
     )
-)
-print(stmt)
+    assert_type(stmt_array_agg_order_by_1, Select[Sequence[str]])
 
+    stmt_array_agg_order_by_2 = select(
+        func.array_agg(
+            aggregate_order_by(Test.ident_str, Test.id.desc(), Test.ident),
+        )
+    )
+    assert_type(stmt_array_agg_order_by_2, Select[Sequence[str]])
 
-t1 = Test()
 
-assert_type(t1.data, dict[str, Any])
+def test_json_parameterization() -> None:
 
-assert_type(t1.ident, _py_uuid)
+    # test default type
+    x: JSON = JSON()
 
-unique = UniqueConstraint(name="my_constraint")
-insert(Test).on_conflict_do_nothing(
-    "foo", [Test.id], Test.id > 0
-).on_conflict_do_update(
-    unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22
-).excluded.foo.desc()
+    assert_type(x, JSON[_JSON_VALUE])
 
-s1 = insert(Test)
-s1.on_conflict_do_update(set_=s1.excluded)
+    # test column values
 
+    s1 = select(Column(JSON()))
 
-assert_type(Column(INT4RANGE()), Column[Range[int]])
-assert_type(Column("foo", DATERANGE()), Column[Range[date]])
-assert_type(Column(INT8MULTIRANGE()), Column[Sequence[Range[int]]])
-assert_type(Column("foo", TSTZMULTIRANGE()), Column[Sequence[Range[datetime]]])
+    assert_type(s1, Select[_JSON_VALUE])
 
+    c1: Column[list[int]] = Column(JSON())
+    s2 = select(c1)
 
-range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE()))
+    assert_type(s2, Select[list[int]])
 
-assert_type(range_col_stmt, Select[Range[int], Sequence[Range[int]]])
 
-array_from_ints = array(range(2))
+def test_jsonb_parameterization() -> None:
 
-assert_type(array_from_ints, array[int])
+    # test default type
+    x: JSONB = JSONB()
 
-array_of_strings = array([], type_=Text)
+    assert_type(x, JSONB[_JSON_VALUE])
 
-assert_type(array_of_strings, array[str])
+    # test column values
 
-array_of_ints = array([0], type_=Integer)
+    s1 = select(Column(JSONB()))
 
-assert_type(array_of_ints, array[int])
+    assert_type(s1, Select[_JSON_VALUE])
 
-# EXPECTED_MYPY_RE: Cannot infer .* of "array"
-array([0], type_=Text)
+    c1: Column[list[int]] = Column(JSONB())
+    s2 = select(c1)
 
-assert_type(ARRAY(Text), ARRAY[str])
+    assert_type(s2, Select[list[int]])
 
-assert_type(Column(type_=ARRAY(Integer)), Column[Sequence[int]])
 
-stmt_array_agg = select(func.array_agg(Column("num", type_=Integer)))
+def test_json_column_types() -> None:
+    """Test JSON column type inference with type parameters."""
 
-assert_type(stmt_array_agg, Select[Sequence[int]])
+    c_json: Column[Any] = Column(JSON())
+    assert_type(c_json, Column[Any])
 
-assert_type(select(func.array_agg(Test.ident_str)), Select[Sequence[str]])
+    c_json_dict = Column(JSON[Dict[str, Any]]())
+    assert_type(c_json_dict, Column[Dict[str, Any]])
 
-stmt_array_agg_order_by_1 = select(
-    func.array_agg(
-        aggregate_order_by(
-            Column("title", type_=Text),
-            Column("date", type_=DATERANGE).desc(),
-            Column("id", type_=Integer),
-        ),
-    )
-)
 
-assert_type(stmt_array_agg_order_by_1, Select[Sequence[str]])
+def test_json_orm_mapping() -> None:
+    """Test JSON type in ORM mapped columns."""
+
+    class JSONTest(Base):
+        __tablename__ = "test_json"
+        id = mapped_column(Integer, primary_key=True)
+        json_data: Mapped[Dict[str, Any]] = mapped_column(JSON)
+        json_list: Mapped[List[Any]] = mapped_column(JSON)
+
+    json_obj = JSONTest()
+    assert_type(json_obj.json_data, Dict[str, Any])
+    assert_type(json_obj.json_list, List[Any])
+
+
+def test_jsonb_column_types() -> None:
+    """Test JSONB column type inference with type parameters."""
+    c_jsonb: Column[Any] = Column(JSONB())
+    assert_type(c_jsonb, Column[Any])
+
+    c_jsonb_dict = Column(JSONB[Dict[str, Any]]())
+    assert_type(c_jsonb_dict, Column[Dict[str, Any]])
+
+
+def test_jsonb_orm_mapping() -> None:
+    """Test JSONB type in ORM mapped columns."""
+
+    class JSONBTest(Base):
+        __tablename__ = "test_jsonb"
+        id = mapped_column(Integer, primary_key=True)
+        jsonb_data: Mapped[Dict[str, Any]] = mapped_column(JSONB)
+
+    jsonb_obj = JSONBTest()
+    assert_type(jsonb_obj.jsonb_data, Dict[str, Any])
 
-stmt_array_agg_order_by_2 = select(
-    func.array_agg(
-        aggregate_order_by(Test.ident_str, Test.id.desc(), Test.ident),
+
+def test_jsonb_func_with_type_param() -> None:
+    """Test func with parameterized JSONB types (issue #13131)."""
+
+    class JSONBTest(Base):
+        __tablename__ = "test_jsonb_func"
+        id = mapped_column(Integer, primary_key=True)
+        data: Mapped[Dict[str, Any]] = mapped_column(JSONB)
+
+    json_func_result = func.jsonb_path_query_array(
+        JSONBTest.data, "$.items", type_=JSONB[List[Tuple[int, int]]]
     )
-)
+    assert_type(json_func_result, Function[List[Tuple[int, int]]])
+
+
+def test_hstore_column_types() -> None:
+    """Test HSTORE column type inference."""
+    c_hstore: Column[dict[str, str | None]] = Column(HSTORE())
+    assert_type(c_hstore, Column[dict[str, str | None]])
+
+
+def test_hstore_orm_mapping() -> None:
+    """Test HSTORE type in ORM mapped columns."""
+
+    class HSTORETest(Base):
+        __tablename__ = "test_hstore"
+        id = mapped_column(Integer, primary_key=True)
+        hstore_data: Mapped[dict[str, str | None]] = mapped_column(HSTORE)
+
+    hstore_obj = HSTORETest()
+    assert_type(hstore_obj.hstore_data, dict[str, str | None])
+
+
+def test_hstore_func() -> None:
+
+    my_func = func.foobar(type_=HSTORE)
+
+    stmt = select(my_func)
+    assert_type(stmt, Select[dict[str, str | None]])
+
+
+def test_insert_on_conflict() -> None:
+    """Test INSERT with ON CONFLICT clauses."""
+
+    class Test(Base):
+        __tablename__ = "test_dml"
+        id = mapped_column(Integer, primary_key=True)
+        data: Mapped[Dict[str, Any]] = mapped_column(JSONB)
+        ident: Mapped[_py_uuid] = mapped_column(UUID())
+        ident_str: Mapped[str] = mapped_column(UUID(as_uuid=False))
+        __table_args__ = (ExcludeConstraint((Column("ident_str"), "=")),)
+
+    unique = UniqueConstraint(name="my_constraint")
+    insert(Test).on_conflict_do_nothing(
+        "foo", [Test.id], Test.id > 0
+    ).on_conflict_do_update(
+        unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22
+    ).excluded.foo.desc()
+
+    s1 = insert(Test)
+    s1.on_conflict_do_update(set_=s1.excluded)
+
+
+def test_complex_jsonb_query() -> None:
+    """Test complex query with JSONB array elements."""
+
+    class Test(Base):
+        __tablename__ = "test_complex"
+        id = mapped_column(Integer, primary_key=True)
+        data: Mapped[Dict[str, Any]] = mapped_column(JSONB)
+        ident: Mapped[_py_uuid] = mapped_column(UUID())
+        ident_str: Mapped[str] = mapped_column(UUID(as_uuid=False))
+
+    elem = func.jsonb_array_elements(
+        Test.data, type_=JSONB[list[str]]
+    ).column_valued("elem")
+
+    assert_type(elem, TableValuedColumn[list[str]])
 
-assert_type(stmt_array_agg_order_by_2, Select[Sequence[str]])
+    t1 = Test()
+    assert_type(t1.data, dict[str, Any])
+    assert_type(t1.ident, _py_uuid)
index 0b5cc1bc92c9be1dff6b7cb847f43575d8f2aabe..7240b7c0d58955c5e3c95ec976d561e42a3543e7 100644 (file)
@@ -1,11 +1,63 @@
 from decimal import Decimal
 from typing import assert_type
 
+from sqlalchemy import Column
 from sqlalchemy import Float
+from sqlalchemy import JSON
 from sqlalchemy import Numeric
+from sqlalchemy import Select
+from sqlalchemy import select
+from sqlalchemy.sql.sqltypes import _JSON_VALUE
+
 
 assert_type(Float(), Float[float])
 assert_type(Float(asdecimal=True), Float[Decimal])
 
 assert_type(Numeric(), Numeric[Decimal])
 assert_type(Numeric(asdecimal=False), Numeric[float])
+
+
+def test_json_value_type() -> None:
+
+    j1: _JSON_VALUE = {
+        "foo": "bar",
+        "bat": {"value1": True},
+        "hoho": [1, 2, 3],
+    }
+    j2: _JSON_VALUE = "foo"
+    j3: _JSON_VALUE = 5
+    j4: _JSON_VALUE = False
+    j5: _JSON_VALUE = None
+    j6: _JSON_VALUE = [None, 5, "foo", False]
+    j7: _JSON_VALUE = {  # noqa: F841
+        "j1": j1,
+        "j2": j2,
+        "j3": j3,
+        "j4": j4,
+        "j5": j5,
+        "j6": j6,
+    }
+
+
+def test_json_parameterization() -> None:
+
+    # test default type
+    x: JSON = JSON()
+
+    assert_type(x, JSON[_JSON_VALUE])
+
+    # test column values
+
+    s1 = select(Column(JSON()))
+
+    assert_type(s1, Select[_JSON_VALUE])
+
+    c1: Column[list[int]] = Column(JSON())
+    s2 = select(c1)
+
+    assert_type(s2, Select[list[int]])
+
+    c2 = Column(JSON[list[int]]())
+    s3 = select(c2)
+
+    assert_type(s3, Select[list[int]])