--- /dev/null
+.. change::
+ :tags: bug, typing
+ :tickets: 13131
+
+ Improved typing of :class:`_sqltypes.JSON` as well as dialect-specific
+ variants like :class:`_postgresql.JSON` to include generic capabilities, so
+ that the types may be parameterized to indicate any specific type of
+ contents expected, e.g. ``JSONB[list[str]]()``.
+
from .sql.expression import TableSample as TableSample
from .sql.expression import tablesample as tablesample
from .sql.expression import TableValuedAlias as TableValuedAlias
+from .sql.expression import TableValuedColumn as TableValuedColumn
from .sql.expression import text as text
from .sql.expression import TextAsFrom as TextAsFrom
from .sql.expression import TextClause as TextClause
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
+
+from typing import Any
+from typing import TYPE_CHECKING
from ... import types as sqltypes
+from ...sql.sqltypes import _T_JSON
+
+if TYPE_CHECKING:
+ from ...engine.interfaces import Dialect
+ from ...sql.type_api import _BindProcessorType
+ from ...sql.type_api import _LiteralProcessorType
# technically, all the dialect-specific datatypes that don't have any special
# behaviors would be private with names like _MSJson. However, we haven't been
# package-private at once.
-class JSON(sqltypes.JSON):
+class JSON(sqltypes.JSON[_T_JSON]):
"""MSSQL JSON type.
MSSQL supports JSON-formatted data as of SQL Server 2016.
# these are not generalizable to all JSON implementations, remain separately
# implemented for each dialect.
class _FormatTypeMixin:
- def _format_value(self, value):
+ def _format_value(self, value: Any) -> str:
raise NotImplementedError()
- def bind_processor(self, dialect):
- super_proc = self.string_bind_processor(dialect)
+ def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
+ super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501
- def process(value):
+ def process(value: Any) -> Any:
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return process
- def literal_processor(self, dialect):
- super_proc = self.string_literal_processor(dialect)
+ def literal_processor(
+ self, dialect: Dialect
+ ) -> _LiteralProcessorType[Any]:
+ super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501
- def process(value):
+ def process(value: Any) -> str:
value = self._format_value(value)
if super_proc:
value = super_proc(value)
- return value
+ return value # type: ignore[no-any-return]
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
- def _format_value(self, value):
+ def _format_value(self, value: Any) -> str:
if isinstance(value, int):
- value = "$[%s]" % value
+ formatted_value = "$[%s]" % value
else:
- value = '$."%s"' % value
- return value
+ formatted_value = '$."%s"' % value
+ return formatted_value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
- def _format_value(self, value):
+ def _format_value(self, value: Any) -> str:
return "$%s" % (
"".join(
[
def visit_VARBINARY(self, type_: VARBINARY, **kw: Any) -> str:
return "VARBINARY(%d)" % type_.length # type: ignore[str-format]
- def visit_JSON(self, type_: JSON, **kw: Any) -> str:
+ def visit_JSON(self, type_: JSON[Any], **kw: Any) -> str:
return "JSON"
def visit_large_binary(self, type_: LargeBinary, **kw: Any) -> str:
from typing import TYPE_CHECKING
from ... import types as sqltypes
+from ...sql.sqltypes import _T_JSON
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.type_api import _LiteralProcessorType
-class JSON(sqltypes.JSON):
+class JSON(sqltypes.JSON[_T_JSON]):
"""MySQL JSON type.
MySQL supports JSON as of version 5.7.
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
import re
+from typing import Any
+from typing import Optional
from .array import ARRAY
from .operators import CONTAINED_BY
from ...sql import functions as sqlfunc
from ...types import OperatorClass
+
__all__ = ("HSTORE", "hstore")
+_HSTORE_VAL = dict[str, str | None]
+
-class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
+class HSTORE(
+ sqltypes.Indexable,
+ sqltypes.Concatenable,
+ sqltypes.TypeEngine[_HSTORE_VAL],
+):
"""Represent the PostgreSQL HSTORE type.
The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::
| OperatorClass.CONCATENABLE
)
- def __init__(self, text_type=None):
+ def __init__(self, text_type: Optional[Any] = None) -> None:
"""Construct a new :class:`.HSTORE`.
:param text_type: the type that should be used for indexed values.
self.text_type = text_type
class Comparator(
- sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator
+ sqltypes.Indexable.Comparator[_HSTORE_VAL],
+ sqltypes.Concatenable.Comparator[_HSTORE_VAL],
):
"""Define comparison operations for :class:`.HSTORE`."""
- def has_key(self, other):
+ def has_key(self, other: Any) -> Any:
"""Boolean expression. Test for presence of a key. Note that the
key may be a SQLA expression.
"""
return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
- def has_all(self, other):
+ def has_all(self, other: Any) -> Any:
"""Boolean expression. Test for presence of all keys in jsonb"""
return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
- def has_any(self, other):
+ def has_any(self, other: Any) -> Any:
"""Boolean expression. Test for presence of any key in jsonb"""
return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
- def contains(self, other, **kwargs):
+ def contains(self, other: Any, **kwargs: Any) -> Any:
"""Boolean expression. Test if keys (or array) are a superset
of/contained the keys of the argument jsonb expression.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
- def contained_by(self, other):
+ def contained_by(self, other: Any) -> Any:
"""Boolean expression. Test if keys are a proper subset of the
keys of the argument jsonb expression.
"""
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
- def _setup_getitem(self, index):
- return GETITEM, index, self.type.text_type
+ def _setup_getitem(self, index: Any) -> Any:
+ return GETITEM, index, self.type.text_type # type: ignore
- def defined(self, key):
+ def defined(self, key: Any) -> Any:
"""Boolean expression. Test for presence of a non-NULL value for
the key. Note that the key may be a SQLA expression.
"""
return _HStoreDefinedFunction(self.expr, key)
- def delete(self, key):
+ def delete(self, key: Any) -> Any:
"""HStore expression. Returns the contents of this hstore with the
given key deleted. Note that the key may be a SQLA expression.
"""
key = _serialize_hstore(key)
return _HStoreDeleteFunction(self.expr, key)
- def slice(self, array):
+ def slice(self, array: Any) -> Any:
"""HStore expression. Returns a subset of an hstore defined by
array of keys.
"""
return _HStoreSliceFunction(self.expr, array)
- def keys(self):
+ def keys(self) -> Any:
"""Text array expression. Returns array of keys."""
return _HStoreKeysFunction(self.expr)
- def vals(self):
+ def vals(self) -> Any:
"""Text array expression. Returns array of values."""
return _HStoreValsFunction(self.expr)
- def array(self):
+ def array(self) -> Any:
"""Text array expression. Returns array of alternating keys and
values.
"""
return _HStoreArrayFunction(self.expr)
- def matrix(self):
+ def matrix(self) -> Any:
"""Text array expression. Returns array of [key, value] pairs."""
return _HStoreMatrixFunction(self.expr)
comparator_factory = Comparator
- def bind_processor(self, dialect):
+ def bind_processor(self, dialect: Any) -> Any:
# note that dialect-specific types like that of psycopg and
# psycopg2 will override this method to allow driver-level conversion
# instead, see _PsycopgHStore
- def process(value):
+ def process(value: Any) -> Any:
if isinstance(value, dict):
return _serialize_hstore(value)
else:
return process
- def result_processor(self, dialect, coltype):
+ def result_processor(self, dialect: Any, coltype: Any) -> Any:
# note that dialect-specific types like that of psycopg and
# psycopg2 will override this method to allow driver-level conversion
# instead, see _PsycopgHStore
- def process(value):
+ def process(value: Any) -> Any:
if value is not None:
return _parse_hstore(value)
else:
return process
-class hstore(sqlfunc.GenericFunction):
+class hstore(sqlfunc.GenericFunction[_HSTORE_VAL]):
"""Construct an hstore value within a SQL expression using the
PostgreSQL ``hstore()`` function.
"""
- type = HSTORE
+ type = HSTORE()
name = "hstore"
inherit_cache = True
-class _HStoreDefinedFunction(sqlfunc.GenericFunction):
- type = sqltypes.Boolean
+class _HStoreDefinedFunction(sqlfunc.GenericFunction[bool]):
+ type = sqltypes.Boolean()
name = "defined"
inherit_cache = True
-class _HStoreDeleteFunction(sqlfunc.GenericFunction):
- type = HSTORE
+class _HStoreDeleteFunction(sqlfunc.GenericFunction[_HSTORE_VAL]):
+ type = HSTORE()
name = "delete"
inherit_cache = True
-class _HStoreSliceFunction(sqlfunc.GenericFunction):
- type = HSTORE
+class _HStoreSliceFunction(sqlfunc.GenericFunction[_HSTORE_VAL]):
+ type = HSTORE()
name = "slice"
inherit_cache = True
-class _HStoreKeysFunction(sqlfunc.GenericFunction):
+class _HStoreKeysFunction(sqlfunc.GenericFunction[Any]):
type = ARRAY(sqltypes.Text)
name = "akeys"
inherit_cache = True
-class _HStoreValsFunction(sqlfunc.GenericFunction):
+class _HStoreValsFunction(sqlfunc.GenericFunction[Any]):
type = ARRAY(sqltypes.Text)
name = "avals"
inherit_cache = True
-class _HStoreArrayFunction(sqlfunc.GenericFunction):
+class _HStoreArrayFunction(sqlfunc.GenericFunction[Any]):
type = ARRAY(sqltypes.Text)
name = "hstore_to_array"
inherit_cache = True
-class _HStoreMatrixFunction(sqlfunc.GenericFunction):
+class _HStoreMatrixFunction(sqlfunc.GenericFunction[Any]):
type = ARRAY(sqltypes.Text)
name = "hstore_to_matrix"
inherit_cache = True
)
-def _parse_error(hstore_str, pos):
+def _parse_error(hstore_str: str, pos: int) -> str:
"""format an unmarshalling error."""
ctx = 20
)
-def _parse_hstore(hstore_str):
+def _parse_hstore(hstore_str: str) -> _HSTORE_VAL:
"""Parse an hstore from its literal string representation.
Attempts to approximate PG's hstore input parsing rules as closely as
"""
- result = {}
+ result: _HSTORE_VAL = {}
pos = 0
pair_match = HSTORE_PAIR_RE.match(hstore_str)
return result
-def _serialize_hstore(val):
+def _serialize_hstore(val: _HSTORE_VAL) -> str:
"""Serialize a dictionary into an hstore literal. Keys and values must
both be strings (except None for values).
"""
- def esc(s, position):
+ def esc(s: Optional[str], position: str) -> str:
if position == "value" and s is None:
return "NULL"
elif isinstance(s, str):
from .operators import PATH_MATCH
from ... import types as sqltypes
from ...sql import cast
-from ...sql._typing import _T
from ...sql.operators import OperatorClass
+from ...sql.sqltypes import _CT_JSON
+from ...sql.sqltypes import _T_JSON
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
__visit_name__ = "JSONPATH"
-class JSON(sqltypes.JSON):
+class JSON(sqltypes.JSON[_T_JSON]):
"""Represent the PostgreSQL JSON type.
:class:`_postgresql.JSON` is used automatically whenever the base
if astext_type is not None:
self.astext_type = astext_type
- class Comparator(sqltypes.JSON.Comparator[_T]):
+ class Comparator(sqltypes.JSON.Comparator[_CT_JSON]):
"""Define comparison operations for :class:`_types.JSON`."""
- type: JSON
+ type: JSON[_CT_JSON]
@property
def astext(self) -> ColumnElement[str]:
comparator_factory = Comparator
-class JSONB(JSON):
+class JSONB(JSON[_T_JSON]):
"""Represent the PostgreSQL JSONB type.
The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data,
else:
return super().coerce_compared_value(op, value)
- class Comparator(JSON.Comparator[_T]):
+ class Comparator(JSON.Comparator[_CT_JSON]):
"""Define comparison operations for :class:`_types.JSON`."""
- type: JSONB
+ type: JSONB[_CT_JSON]
def has_key(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Test for presence of a key (equivalent of
def delete_path(
self, array: Union[List[str], _pg_array[str]]
- ) -> ColumnElement[JSONB]:
+ ) -> ColumnElement[_CT_JSON]:
"""JSONB expression. Deletes field or array element specified in
the argument array (equivalent of the ``#-`` operator).
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
+
+from typing import Any
+from typing import TYPE_CHECKING
from ... import types as sqltypes
+from ...sql.sqltypes import _T_JSON
+
+if TYPE_CHECKING:
+ from ...engine.interfaces import Dialect
+ from ...sql.type_api import _BindProcessorType
+ from ...sql.type_api import _LiteralProcessorType
-class JSON(sqltypes.JSON):
+class JSON(sqltypes.JSON[_T_JSON]):
"""SQLite JSON type.
SQLite supports JSON as of version 3.9 through its JSON1_ extension. Note
# these are not generalizable to all JSON implementations, remain separately
# implemented for each dialect.
class _FormatTypeMixin:
- def _format_value(self, value):
+ def _format_value(self, value: Any) -> str:
raise NotImplementedError()
- def bind_processor(self, dialect):
- super_proc = self.string_bind_processor(dialect)
+ def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
+ super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501
- def process(value):
+ def process(value: Any) -> Any:
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return process
- def literal_processor(self, dialect):
- super_proc = self.string_literal_processor(dialect)
+ def literal_processor(
+ self, dialect: Dialect
+ ) -> _LiteralProcessorType[Any]:
+ super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501
- def process(value):
+ def process(value: Any) -> str:
value = self._format_value(value)
if super_proc:
value = super_proc(value)
- return value
+ return value # type: ignore[no-any-return]
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
- def _format_value(self, value):
+ def _format_value(self, value: Any) -> str:
if isinstance(value, int):
- value = "$[%s]" % value
+ formatted_value = "$[%s]" % value
else:
- value = '$."%s"' % value
- return value
+ formatted_value = '$."%s"' % value
+ return formatted_value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
- def _format_value(self, value):
+ def _format_value(self, value: Any) -> str:
return "$%s" % (
"".join(
[
from .expression import TableClause as TableClause
from .expression import TableSample as TableSample
from .expression import tablesample as tablesample
+from .expression import TableValuedAlias as TableValuedAlias
+from .expression import TableValuedColumn as TableValuedColumn
from .expression import text as text
from .expression import TextClause as TextClause
from .expression import true as true
from .elements import RollbackToSavepointClause as RollbackToSavepointClause
from .elements import SavepointClause as SavepointClause
from .elements import SQLColumnExpression as SQLColumnExpression
+from .elements import TableValuedColumn as TableValuedColumn
from .elements import TextClause as TextClause
from .elements import True_ as True_
from .elements import TryCast as TryCast
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
-from typing import TypeVar
+from typing import TypeAlias
from typing import Union
from uuid import UUID as _python_UUID
from ..util.typing import is_literal
from ..util.typing import is_pep695
from ..util.typing import TupleAny
+from ..util.typing import TypeVar
if TYPE_CHECKING:
from ._typing import _ColumnExpressionArgument
from ..engine.interfaces import Dialect
from ..util.typing import _MatchedOnType
-_T = TypeVar("_T", bound="Any")
+_T = TypeVar("_T", bound=Any)
+
+_JSON_VALUE: TypeAlias = (
+ str | int | bool | None | dict[str, "_JSON_VALUE"] | list["_JSON_VALUE"]
+)
+
+_T_JSON = TypeVar(
+ "_T_JSON",
+ bound=Any,
+ default=_JSON_VALUE,
+)
+_CT_JSON = TypeVar(
+ "_CT_JSON",
+ bound=Any,
+ default=_JSON_VALUE,
+)
+
_CT = TypeVar("_CT", bound=Any)
-_TE = TypeVar("_TE", bound="TypeEngine[Any]")
+_TE = TypeVar("_TE", bound=TypeEngine[Any])
_P = TypeVar("_P")
return process
-class JSON(Indexable, TypeEngine[Any]):
+class JSON(Indexable, TypeEngine[_T_JSON]):
"""Represent a SQL JSON type.
.. note:: :class:`_types.JSON`
__visit_name__ = "json_path"
- class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]):
+ class Comparator(
+ Indexable.Comparator[_CT_JSON], Concatenable.Comparator[_CT_JSON]
+ ):
"""Define comparison operations for :class:`_types.JSON`."""
__slots__ = ()
- type: JSON
+ type: JSON[_CT_JSON]
def _setup_getitem(self, index):
if not isinstance(index, str) and isinstance(
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeGuard
-from typing import TypeVar
from typing import Union
import typing_extensions
from typing_extensions import Unpack as Unpack # 3.11
from typing_extensions import Never as Never # 3.11
from typing_extensions import LiteralString as LiteralString # 3.11
+ from typing_extensions import TypeVar as TypeVar # 3.13 for default
_T = TypeVar("_T", bound=Any)
+from typing import Any
+from typing import assert_type
+from typing import Dict
+from typing import List
+from typing import Tuple
+
+from sqlalchemy import Column
+from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy.dialects.mysql import insert
+from sqlalchemy.dialects.mysql import JSON
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql.functions import Function
class Base(DeclarativeBase):
pass
-class Test(Base):
- __tablename__ = "test_table_json"
+def test_insert_on_duplicate_key_update() -> None:
+ """Test INSERT with ON DUPLICATE KEY UPDATE."""
+
+ class Test(Base):
+ __tablename__ = "test_table_json"
+ id = mapped_column(Integer, primary_key=True)
+ data: Mapped[str] = mapped_column()
+
+ insert(Test).on_duplicate_key_update(
+ {"id": 42, Test.data: 99}, [("foo", 44)], data=99, id="foo"
+ ).inserted.foo.desc()
+
+
+def test_json_column_types() -> None:
+ """Test JSON column type inference with type parameters."""
+ c_json: Column[Any] = Column(JSON())
+ assert_type(c_json, Column[Any])
+
+ c_json_dict = Column(JSON[Dict[str, Any]]())
+ assert_type(c_json_dict, Column[Dict[str, Any]])
+
+
+def test_json_orm_mapping() -> None:
+ """Test JSON type in ORM mapped columns."""
+
+ class JSONTest(Base):
+ __tablename__ = "test_json"
+ id = mapped_column(Integer, primary_key=True)
+ json_data: Mapped[Dict[str, Any]] = mapped_column(JSON)
+ json_list: Mapped[List[Any]] = mapped_column(JSON)
+
+ json_obj = JSONTest()
+ assert_type(json_obj.json_data, Dict[str, Any])
+ assert_type(json_obj.json_list, List[Any])
+
- id = mapped_column(Integer, primary_key=True)
- data: Mapped[str] = mapped_column()
+def test_json_func_with_type_param() -> None:
+ """Test func with parameterized JSON types (issue #13131)."""
+ class JSONTest(Base):
+ __tablename__ = "test_json_func"
+ id = mapped_column(Integer, primary_key=True)
+ data: Mapped[Dict[str, Any]] = mapped_column(JSON)
-insert(Test).on_duplicate_key_update(
- {"id": 42, Test.data: 99}, [("foo", 44)], data=99, id="foo"
-).inserted.foo.desc()
+ json_func_result = func.json_extract(
+ JSONTest.data, "$.items", type_=JSON[List[Tuple[int, int]]]
+ )
+ assert_type(json_func_result, Function[List[Tuple[int, int]]])
from typing import Any
from typing import assert_type
from typing import Dict
+from typing import List
from typing import Sequence
+from typing import Tuple
from uuid import UUID as _py_uuid
-from sqlalchemy import cast
from sqlalchemy import Column
from sqlalchemy import func
from sqlalchemy import Integer
-from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
+from sqlalchemy import TableValuedColumn
from sqlalchemy import Text
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects.postgresql import aggregate_order_by
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.dialects.postgresql import DATERANGE
from sqlalchemy.dialects.postgresql import ExcludeConstraint
+from sqlalchemy.dialects.postgresql import HSTORE
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.dialects.postgresql import INT4RANGE
from sqlalchemy.dialects.postgresql import INT8MULTIRANGE
+from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import Range
from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql.functions import Function
+from sqlalchemy.sql.sqltypes import _JSON_VALUE
-# test #6402
-c1 = Column(UUID())
+class Base(DeclarativeBase):
+ pass
-assert_type(c1, Column[_py_uuid])
-c2 = Column(UUID(as_uuid=False))
+def test_uuid_column_types() -> None:
+ """Test UUID column type inference with as_uuid parameter."""
+ c1 = Column(UUID())
+ assert_type(c1, Column[_py_uuid])
-assert_type(c2, Column[str])
+ c2 = Column(UUID(as_uuid=False))
+ assert_type(c2, Column[str])
-class Base(DeclarativeBase):
- pass
+def test_range_column_types() -> None:
+ """Test Range and MultiRange column type inference."""
+ assert_type(Column(INT4RANGE()), Column[Range[int]])
+ assert_type(Column("foo", DATERANGE()), Column[Range[date]])
+ assert_type(Column(INT8MULTIRANGE()), Column[Sequence[Range[int]]])
+ assert_type(
+ Column("foo", TSTZMULTIRANGE()), Column[Sequence[Range[datetime]]]
+ )
+
+
+def test_range_in_select() -> None:
+ """Test Range types in SELECT statements."""
+ range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE()))
+ assert_type(range_col_stmt, Select[Range[int], Sequence[Range[int]]])
+
+
+def test_array_type_inference() -> None:
+ array_from_ints = array(range(2))
+ assert_type(array_from_ints, array[int])
+
+ array_of_strings = array([], type_=Text)
+ assert_type(array_of_strings, array[str])
+
+ array_of_ints = array([0], type_=Integer)
+ assert_type(array_of_ints, array[int])
+
+ # EXPECTED_MYPY_RE: Cannot infer .* of "array"
+ array([0], type_=Text)
-class Test(Base):
- __tablename__ = "test_table_json"
+def test_array_column_types() -> None:
+ """Test ARRAY column type inference."""
+ assert_type(ARRAY(Text), ARRAY[str])
+ assert_type(Column(type_=ARRAY(Integer)), Column[Sequence[int]])
- id = mapped_column(Integer, primary_key=True)
- data: Mapped[Dict[str, Any]] = mapped_column(JSONB)
- ident: Mapped[_py_uuid] = mapped_column(UUID())
+def test_array_agg_functions() -> None:
+ """Test array_agg function type inference."""
+ stmt_array_agg = select(func.array_agg(Column("num", type_=Integer)))
+ assert_type(stmt_array_agg, Select[Sequence[int]])
- ident_str: Mapped[str] = mapped_column(UUID(as_uuid=False))
- __table_args__ = (ExcludeConstraint((Column("ident_str"), "=")),)
+def test_array_agg_with_aggregate_order_by() -> None:
+ """Test array_agg with aggregate_order_by."""
+ class Test(Base):
+ __tablename__ = "test_array_agg"
+ id = mapped_column(Integer, primary_key=True)
+ ident_str: Mapped[str] = mapped_column()
+ ident: Mapped[_py_uuid] = mapped_column(UUID())
-elem = func.jsonb_array_elements(Test.data, type_=JSONB).column_valued("elem")
+ assert_type(select(func.array_agg(Test.ident_str)), Select[Sequence[str]])
-stmt = select(Test).where(
- or_(
- cast("example code", ARRAY(Text)).contained_by(
- array([select(elem["code"].astext).scalar_subquery()])
- ),
- cast("stefan", ARRAY(Text)).contained_by(
- array([select(elem["code"]["new_value"].astext).scalar_subquery()])
- ),
+ stmt_array_agg_order_by_1 = select(
+ func.array_agg(
+ aggregate_order_by(
+ Column("title", type_=Text),
+ Column("date", type_=DATERANGE).desc(),
+ Column("id", type_=Integer),
+ ),
+ )
)
-)
-print(stmt)
+ assert_type(stmt_array_agg_order_by_1, Select[Sequence[str]])
+ stmt_array_agg_order_by_2 = select(
+ func.array_agg(
+ aggregate_order_by(Test.ident_str, Test.id.desc(), Test.ident),
+ )
+ )
+ assert_type(stmt_array_agg_order_by_2, Select[Sequence[str]])
-t1 = Test()
-assert_type(t1.data, dict[str, Any])
+def test_json_parameterization() -> None:
-assert_type(t1.ident, _py_uuid)
+ # test default type
+ x: JSON = JSON()
-unique = UniqueConstraint(name="my_constraint")
-insert(Test).on_conflict_do_nothing(
- "foo", [Test.id], Test.id > 0
-).on_conflict_do_update(
- unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22
-).excluded.foo.desc()
+ assert_type(x, JSON[_JSON_VALUE])
-s1 = insert(Test)
-s1.on_conflict_do_update(set_=s1.excluded)
+ # test column values
+ s1 = select(Column(JSON()))
-assert_type(Column(INT4RANGE()), Column[Range[int]])
-assert_type(Column("foo", DATERANGE()), Column[Range[date]])
-assert_type(Column(INT8MULTIRANGE()), Column[Sequence[Range[int]]])
-assert_type(Column("foo", TSTZMULTIRANGE()), Column[Sequence[Range[datetime]]])
+ assert_type(s1, Select[_JSON_VALUE])
+ c1: Column[list[int]] = Column(JSON())
+ s2 = select(c1)
-range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE()))
+ assert_type(s2, Select[list[int]])
-assert_type(range_col_stmt, Select[Range[int], Sequence[Range[int]]])
-array_from_ints = array(range(2))
+def test_jsonb_parameterization() -> None:
-assert_type(array_from_ints, array[int])
+ # test default type
+ x: JSONB = JSONB()
-array_of_strings = array([], type_=Text)
+ assert_type(x, JSONB[_JSON_VALUE])
-assert_type(array_of_strings, array[str])
+ # test column values
-array_of_ints = array([0], type_=Integer)
+ s1 = select(Column(JSONB()))
-assert_type(array_of_ints, array[int])
+ assert_type(s1, Select[_JSON_VALUE])
-# EXPECTED_MYPY_RE: Cannot infer .* of "array"
-array([0], type_=Text)
+ c1: Column[list[int]] = Column(JSONB())
+ s2 = select(c1)
-assert_type(ARRAY(Text), ARRAY[str])
+ assert_type(s2, Select[list[int]])
-assert_type(Column(type_=ARRAY(Integer)), Column[Sequence[int]])
-stmt_array_agg = select(func.array_agg(Column("num", type_=Integer)))
+def test_json_column_types() -> None:
+ """Test JSON column type inference with type parameters."""
-assert_type(stmt_array_agg, Select[Sequence[int]])
+ c_json: Column[Any] = Column(JSON())
+ assert_type(c_json, Column[Any])
-assert_type(select(func.array_agg(Test.ident_str)), Select[Sequence[str]])
+ c_json_dict = Column(JSON[Dict[str, Any]]())
+ assert_type(c_json_dict, Column[Dict[str, Any]])
-stmt_array_agg_order_by_1 = select(
- func.array_agg(
- aggregate_order_by(
- Column("title", type_=Text),
- Column("date", type_=DATERANGE).desc(),
- Column("id", type_=Integer),
- ),
- )
-)
-assert_type(stmt_array_agg_order_by_1, Select[Sequence[str]])
+def test_json_orm_mapping() -> None:
+ """Test JSON type in ORM mapped columns."""
+
+ class JSONTest(Base):
+ __tablename__ = "test_json"
+ id = mapped_column(Integer, primary_key=True)
+ json_data: Mapped[Dict[str, Any]] = mapped_column(JSON)
+ json_list: Mapped[List[Any]] = mapped_column(JSON)
+
+ json_obj = JSONTest()
+ assert_type(json_obj.json_data, Dict[str, Any])
+ assert_type(json_obj.json_list, List[Any])
+
+
+def test_jsonb_column_types() -> None:
+ """Test JSONB column type inference with type parameters."""
+ c_jsonb: Column[Any] = Column(JSONB())
+ assert_type(c_jsonb, Column[Any])
+
+ c_jsonb_dict = Column(JSONB[Dict[str, Any]]())
+ assert_type(c_jsonb_dict, Column[Dict[str, Any]])
+
+
+def test_jsonb_orm_mapping() -> None:
+ """Test JSONB type in ORM mapped columns."""
+
+ class JSONBTest(Base):
+ __tablename__ = "test_jsonb"
+ id = mapped_column(Integer, primary_key=True)
+ jsonb_data: Mapped[Dict[str, Any]] = mapped_column(JSONB)
+
+ jsonb_obj = JSONBTest()
+ assert_type(jsonb_obj.jsonb_data, Dict[str, Any])
-stmt_array_agg_order_by_2 = select(
- func.array_agg(
- aggregate_order_by(Test.ident_str, Test.id.desc(), Test.ident),
+
+def test_jsonb_func_with_type_param() -> None:
+ """Test func with parameterized JSONB types (issue #13131)."""
+
+ class JSONBTest(Base):
+ __tablename__ = "test_jsonb_func"
+ id = mapped_column(Integer, primary_key=True)
+ data: Mapped[Dict[str, Any]] = mapped_column(JSONB)
+
+ json_func_result = func.jsonb_path_query_array(
+ JSONBTest.data, "$.items", type_=JSONB[List[Tuple[int, int]]]
)
-)
+ assert_type(json_func_result, Function[List[Tuple[int, int]]])
+
+
+def test_hstore_column_types() -> None:
+ """Test HSTORE column type inference."""
+ c_hstore: Column[dict[str, str | None]] = Column(HSTORE())
+ assert_type(c_hstore, Column[dict[str, str | None]])
+
+
+def test_hstore_orm_mapping() -> None:
+ """Test HSTORE type in ORM mapped columns."""
+
+ class HSTORETest(Base):
+ __tablename__ = "test_hstore"
+ id = mapped_column(Integer, primary_key=True)
+ hstore_data: Mapped[dict[str, str | None]] = mapped_column(HSTORE)
+
+ hstore_obj = HSTORETest()
+ assert_type(hstore_obj.hstore_data, dict[str, str | None])
+
+
+def test_hstore_func() -> None:
+
+ my_func = func.foobar(type_=HSTORE)
+
+ stmt = select(my_func)
+ assert_type(stmt, Select[dict[str, str | None]])
+
+
+def test_insert_on_conflict() -> None:
+ """Test INSERT with ON CONFLICT clauses."""
+
+ class Test(Base):
+ __tablename__ = "test_dml"
+ id = mapped_column(Integer, primary_key=True)
+ data: Mapped[Dict[str, Any]] = mapped_column(JSONB)
+ ident: Mapped[_py_uuid] = mapped_column(UUID())
+ ident_str: Mapped[str] = mapped_column(UUID(as_uuid=False))
+ __table_args__ = (ExcludeConstraint((Column("ident_str"), "=")),)
+
+ unique = UniqueConstraint(name="my_constraint")
+ insert(Test).on_conflict_do_nothing(
+ "foo", [Test.id], Test.id > 0
+ ).on_conflict_do_update(
+ unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22
+ ).excluded.foo.desc()
+
+ s1 = insert(Test)
+ s1.on_conflict_do_update(set_=s1.excluded)
+
+
+def test_complex_jsonb_query() -> None:
+ """Test complex query with JSONB array elements."""
+
+ class Test(Base):
+ __tablename__ = "test_complex"
+ id = mapped_column(Integer, primary_key=True)
+ data: Mapped[Dict[str, Any]] = mapped_column(JSONB)
+ ident: Mapped[_py_uuid] = mapped_column(UUID())
+ ident_str: Mapped[str] = mapped_column(UUID(as_uuid=False))
+
+ elem = func.jsonb_array_elements(
+ Test.data, type_=JSONB[list[str]]
+ ).column_valued("elem")
+
+ assert_type(elem, TableValuedColumn[list[str]])
-assert_type(stmt_array_agg_order_by_2, Select[Sequence[str]])
+ t1 = Test()
+ assert_type(t1.data, dict[str, Any])
+ assert_type(t1.ident, _py_uuid)
from decimal import Decimal
from typing import assert_type
+from sqlalchemy import Column
from sqlalchemy import Float
+from sqlalchemy import JSON
from sqlalchemy import Numeric
+from sqlalchemy import Select
+from sqlalchemy import select
+from sqlalchemy.sql.sqltypes import _JSON_VALUE
+
assert_type(Float(), Float[float])
assert_type(Float(asdecimal=True), Float[Decimal])
assert_type(Numeric(), Numeric[Decimal])
assert_type(Numeric(asdecimal=False), Numeric[float])
+
+
+def test_json_value_type() -> None:
+
+ j1: _JSON_VALUE = {
+ "foo": "bar",
+ "bat": {"value1": True},
+ "hoho": [1, 2, 3],
+ }
+ j2: _JSON_VALUE = "foo"
+ j3: _JSON_VALUE = 5
+ j4: _JSON_VALUE = False
+ j5: _JSON_VALUE = None
+ j6: _JSON_VALUE = [None, 5, "foo", False]
+ j7: _JSON_VALUE = { # noqa: F841
+ "j1": j1,
+ "j2": j2,
+ "j3": j3,
+ "j4": j4,
+ "j5": j5,
+ "j6": j6,
+ }
+
+
+def test_json_parameterization() -> None:
+
+ # test default type
+ x: JSON = JSON()
+
+ assert_type(x, JSON[_JSON_VALUE])
+
+ # test column values
+
+ s1 = select(Column(JSON()))
+
+ assert_type(s1, Select[_JSON_VALUE])
+
+ c1: Column[list[int]] = Column(JSON())
+ s2 = select(c1)
+
+ assert_type(s2, Select[list[int]])
+
+ c2 = Column(JSON[list[int]]())
+ s3 = select(c2)
+
+ assert_type(s3, Select[list[int]])