From: Mike Bayer Date: Wed, 18 Feb 2026 15:12:52 +0000 (-0500) Subject: allow JSON, JSONB, etc. to be parameterized, type HSTORE X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=b88d9e0969129dfc3efd5bcad617bd4872e6ea9c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git allow JSON, JSONB, etc. to be parameterized, type HSTORE Improved typing of :class:`_sqltypes.JSON` as well as dialect specific variants like :class:`_postgresql.JSON` to include generic capabilities, so that the types may be parameterized to indicate any specific type of contents expected, e.g. ``JSONB[list[str]]()``. Also types HSTORE Fixes: #13131 Change-Id: Ia089ba4e3cebf6339a5420b2923cd267c4e6891a --- diff --git a/doc/build/changelog/unreleased_21/13131.rst b/doc/build/changelog/unreleased_21/13131.rst new file mode 100644 index 0000000000..978bc06b0c --- /dev/null +++ b/doc/build/changelog/unreleased_21/13131.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, typing + :tickets: 13131 + + Improved typing of :class:`_sqltypes.JSON` as well as dialect specific + variants like :class:`_postgresql.JSON` to include generic capabilities, so + that the types may be parameterized to indicate any specific type of + contents expected, e.g. ``JSONB[list[str]]()``. + diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index b594e3c665..b74bae21af 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -204,6 +204,7 @@ from .sql.expression import TableClause as TableClause from .sql.expression import TableSample as TableSample from .sql.expression import tablesample as tablesample from .sql.expression import TableValuedAlias as TableValuedAlias +from .sql.expression import TableValuedColumn as TableValuedColumn from .sql.expression import text as text from .sql.expression import TextAsFrom as TextAsFrom from .sql.expression import TextClause as TextClause diff --git a/lib/sqlalchemy/dialects/mssql/json.py b/lib/sqlalchemy/dialects/mssql/json.py index cb071884ac..4c128326a4 100644 --- a/lib/sqlalchemy/dialects/mssql/json.py +++ b/lib/sqlalchemy/dialects/mssql/json.py @@ -4,9 +4,18 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations + +from typing import Any +from typing import TYPE_CHECKING from ... import types as sqltypes +from ...sql.sqltypes import _T_JSON + +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _LiteralProcessorType # technically, all the dialect-specific datatypes that don't have any special # behaviors would be private with names like _MSJson. However, we haven't been @@ -16,7 +25,7 @@ from ... import types as sqltypes # package-private at once. -class JSON(sqltypes.JSON): +class JSON(sqltypes.JSON[_T_JSON]): """MSSQL JSON type. MSSQL supports JSON-formatted data as of SQL Server 2016. @@ -82,13 +91,13 @@ class JSON(sqltypes.JSON): # these are not generalizable to all JSON implementations, remain separately # implemented for each dialect. class _FormatTypeMixin: - def _format_value(self, value): + def _format_value(self, value: Any) -> str: raise NotImplementedError() - def bind_processor(self, dialect): - super_proc = self.string_bind_processor(dialect) + def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]: + super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501 - def process(value): + def process(value: Any) -> Any: value = self._format_value(value) if super_proc: value = super_proc(value) @@ -96,29 +105,31 @@ class _FormatTypeMixin: return process - def literal_processor(self, dialect): - super_proc = self.string_literal_processor(dialect) + def literal_processor( + self, dialect: Dialect + ) -> _LiteralProcessorType[Any]: + super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501 - def process(value): + def process(value: Any) -> str: value = self._format_value(value) if super_proc: value = super_proc(value) - return value + return value # type: ignore[no-any-return] return process class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): - def _format_value(self, value): + def _format_value(self, value: Any) -> str: if isinstance(value, int): - value = "$[%s]" % value + formatted_value = "$[%s]" % value else: - value = '$."%s"' % value - return value + formatted_value = '$."%s"' % value + return formatted_value class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): - def _format_value(self, value): + def _format_value(self, value: Any) -> str: return "$%s" % ( "".join( [ diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 75ec79baac..9cfd643610 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -2631,7 +2631,7 @@ class MySQLTypeCompiler( def visit_VARBINARY(self, type_: VARBINARY, **kw: Any) -> str: return "VARBINARY(%d)" % type_.length # type: ignore[str-format] - def visit_JSON(self, type_: JSON, **kw: Any) -> str: + def visit_JSON(self, type_: JSON[Any], **kw: Any) -> str: return "JSON" def visit_large_binary(self, type_: LargeBinary, **kw: Any) -> str: diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py index 7e2606ccf9..5c564d73b9 100644 --- a/lib/sqlalchemy/dialects/mysql/json.py +++ b/lib/sqlalchemy/dialects/mysql/json.py @@ -10,6 +10,7 @@ from typing import Any from typing import TYPE_CHECKING from ... import types as sqltypes +from ...sql.sqltypes import _T_JSON if TYPE_CHECKING: from ...engine.interfaces import Dialect @@ -17,7 +18,7 @@ if TYPE_CHECKING: from ...sql.type_api import _LiteralProcessorType -class JSON(sqltypes.JSON): +class JSON(sqltypes.JSON[_T_JSON]): """MySQL JSON type. MySQL supports JSON as of version 5.7. diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index 4c9db15d82..91666e71ce 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -4,10 +4,12 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations import re +from typing import Any +from typing import Optional from .array import ARRAY from .operators import CONTAINED_BY @@ -20,10 +22,17 @@ from ... import types as sqltypes from ...sql import functions as sqlfunc from ...types import OperatorClass + __all__ = ("HSTORE", "hstore") +_HSTORE_VAL = dict[str, str | None] + -class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): +class HSTORE( + sqltypes.Indexable, + sqltypes.Concatenable, + sqltypes.TypeEngine[_HSTORE_VAL], +): """Represent the PostgreSQL HSTORE type. The :class:`.HSTORE` type stores dictionaries containing strings, e.g.:: @@ -112,7 +121,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): | OperatorClass.CONCATENABLE ) - def __init__(self, text_type=None): + def __init__(self, text_type: Optional[Any] = None) -> None: """Construct a new :class:`.HSTORE`. :param text_type: the type that should be used for indexed values. @@ -123,25 +132,26 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): self.text_type = text_type class Comparator( - sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator + sqltypes.Indexable.Comparator[_HSTORE_VAL], + sqltypes.Concatenable.Comparator[_HSTORE_VAL], ): """Define comparison operations for :class:`.HSTORE`.""" - def has_key(self, other): + def has_key(self, other: Any) -> Any: """Boolean expression. Test for presence of a key. Note that the key may be a SQLA expression. """ return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) - def has_all(self, other): + def has_all(self, other: Any) -> Any: """Boolean expression. Test for presence of all keys in jsonb""" return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) - def has_any(self, other): + def has_any(self, other: Any) -> Any: """Boolean expression. Test for presence of any key in jsonb""" return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) - def contains(self, other, **kwargs): + def contains(self, other: Any, **kwargs: Any) -> Any: """Boolean expression. Test if keys (or array) are a superset of/contained the keys of the argument jsonb expression. @@ -150,7 +160,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): """ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) - def contained_by(self, other): + def contained_by(self, other: Any) -> Any: """Boolean expression. Test if keys are a proper subset of the keys of the argument jsonb expression. """ @@ -158,16 +168,16 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): CONTAINED_BY, other, result_type=sqltypes.Boolean ) - def _setup_getitem(self, index): - return GETITEM, index, self.type.text_type + def _setup_getitem(self, index: Any) -> Any: + return GETITEM, index, self.type.text_type # type: ignore - def defined(self, key): + def defined(self, key: Any) -> Any: """Boolean expression. Test for presence of a non-NULL value for the key. Note that the key may be a SQLA expression. """ return _HStoreDefinedFunction(self.expr, key) - def delete(self, key): + def delete(self, key: Any) -> Any: """HStore expression. Returns the contents of this hstore with the given key deleted. Note that the key may be a SQLA expression. """ @@ -175,37 +185,37 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): key = _serialize_hstore(key) return _HStoreDeleteFunction(self.expr, key) - def slice(self, array): + def slice(self, array: Any) -> Any: """HStore expression. Returns a subset of an hstore defined by array of keys. """ return _HStoreSliceFunction(self.expr, array) - def keys(self): + def keys(self) -> Any: """Text array expression. Returns array of keys.""" return _HStoreKeysFunction(self.expr) - def vals(self): + def vals(self) -> Any: """Text array expression. Returns array of values.""" return _HStoreValsFunction(self.expr) - def array(self): + def array(self) -> Any: """Text array expression. Returns array of alternating keys and values. """ return _HStoreArrayFunction(self.expr) - def matrix(self): + def matrix(self) -> Any: """Text array expression. Returns array of [key, value] pairs.""" return _HStoreMatrixFunction(self.expr) comparator_factory = Comparator - def bind_processor(self, dialect): + def bind_processor(self, dialect: Any) -> Any: # note that dialect-specific types like that of psycopg and # psycopg2 will override this method to allow driver-level conversion # instead, see _PsycopgHStore - def process(value): + def process(value: Any) -> Any: if isinstance(value, dict): return _serialize_hstore(value) else: @@ -213,11 +223,11 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): return process - def result_processor(self, dialect, coltype): + def result_processor(self, dialect: Any, coltype: Any) -> Any: # note that dialect-specific types like that of psycopg and # psycopg2 will override this method to allow driver-level conversion # instead, see _PsycopgHStore - def process(value): + def process(value: Any) -> Any: if value is not None: return _parse_hstore(value) else: @@ -226,7 +236,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): return process -class hstore(sqlfunc.GenericFunction): +class hstore(sqlfunc.GenericFunction[_HSTORE_VAL]): """Construct an hstore value within a SQL expression using the PostgreSQL ``hstore()`` function. @@ -252,48 +262,48 @@ class hstore(sqlfunc.GenericFunction): """ - type = HSTORE + type = HSTORE() name = "hstore" inherit_cache = True -class _HStoreDefinedFunction(sqlfunc.GenericFunction): - type = sqltypes.Boolean +class _HStoreDefinedFunction(sqlfunc.GenericFunction[bool]): + type = sqltypes.Boolean() name = "defined" inherit_cache = True -class _HStoreDeleteFunction(sqlfunc.GenericFunction): - type = HSTORE +class _HStoreDeleteFunction(sqlfunc.GenericFunction[_HSTORE_VAL]): + type = HSTORE() name = "delete" inherit_cache = True -class _HStoreSliceFunction(sqlfunc.GenericFunction): - type = HSTORE +class _HStoreSliceFunction(sqlfunc.GenericFunction[_HSTORE_VAL]): + type = HSTORE() name = "slice" inherit_cache = True -class _HStoreKeysFunction(sqlfunc.GenericFunction): +class _HStoreKeysFunction(sqlfunc.GenericFunction[Any]): type = ARRAY(sqltypes.Text) name = "akeys" inherit_cache = True -class _HStoreValsFunction(sqlfunc.GenericFunction): +class _HStoreValsFunction(sqlfunc.GenericFunction[Any]): type = ARRAY(sqltypes.Text) name = "avals" inherit_cache = True -class _HStoreArrayFunction(sqlfunc.GenericFunction): +class _HStoreArrayFunction(sqlfunc.GenericFunction[Any]): type = ARRAY(sqltypes.Text) name = "hstore_to_array" inherit_cache = True -class _HStoreMatrixFunction(sqlfunc.GenericFunction): +class _HStoreMatrixFunction(sqlfunc.GenericFunction[Any]): type = ARRAY(sqltypes.Text) name = "hstore_to_matrix" inherit_cache = True @@ -329,7 +339,7 @@ HSTORE_DELIMITER_RE = re.compile( ) -def _parse_error(hstore_str, pos): +def _parse_error(hstore_str: str, pos: int) -> str: """format an unmarshalling error.""" ctx = 20 @@ -350,7 +360,7 @@ def _parse_error(hstore_str, pos): ) -def _parse_hstore(hstore_str): +def _parse_hstore(hstore_str: str) -> _HSTORE_VAL: """Parse an hstore from its literal string representation. Attempts to approximate PG's hstore input parsing rules as closely as @@ -362,7 +372,7 @@ def _parse_hstore(hstore_str): """ - result = {} + result: _HSTORE_VAL = {} pos = 0 pair_match = HSTORE_PAIR_RE.match(hstore_str) @@ -392,13 +402,13 @@ def _parse_hstore(hstore_str): return result -def _serialize_hstore(val): +def _serialize_hstore(val: _HSTORE_VAL) -> str: """Serialize a dictionary into an hstore literal. Keys and values must both be strings (except None for values). """ - def esc(s, position): + def esc(s: Optional[str], position: str) -> str: if position == "value" and s is None: return "NULL" elif isinstance(s, str): diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index e50b1f3364..54b33fc65a 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -28,8 +28,9 @@ from .operators import PATH_EXISTS from .operators import PATH_MATCH from ... import types as sqltypes from ...sql import cast -from ...sql._typing import _T from ...sql.operators import OperatorClass +from ...sql.sqltypes import _CT_JSON +from ...sql.sqltypes import _T_JSON if TYPE_CHECKING: from ...engine.interfaces import Dialect @@ -90,7 +91,7 @@ class JSONPATH(JSONPathType): __visit_name__ = "JSONPATH" -class JSON(sqltypes.JSON): +class JSON(sqltypes.JSON[_T_JSON]): """Represent the PostgreSQL JSON type. :class:`_postgresql.JSON` is used automatically whenever the base @@ -200,10 +201,10 @@ class JSON(sqltypes.JSON): if astext_type is not None: self.astext_type = astext_type - class Comparator(sqltypes.JSON.Comparator[_T]): + class Comparator(sqltypes.JSON.Comparator[_CT_JSON]): """Define comparison operations for :class:`_types.JSON`.""" - type: JSON + type: JSON[_CT_JSON] @property def astext(self) -> ColumnElement[str]: @@ -233,7 +234,7 @@ class JSON(sqltypes.JSON): comparator_factory = Comparator -class JSONB(JSON): +class JSONB(JSON[_T_JSON]): """Represent the PostgreSQL JSONB type. The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data, @@ -323,10 +324,10 @@ class JSONB(JSON): else: return super().coerce_compared_value(op, value) - class Comparator(JSON.Comparator[_T]): + class Comparator(JSON.Comparator[_CT_JSON]): """Define comparison operations for :class:`_types.JSON`.""" - type: JSONB + type: JSONB[_CT_JSON] def has_key(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Test for presence of a key (equivalent of @@ -367,7 +368,7 @@ class JSONB(JSON): def delete_path( self, array: Union[List[str], _pg_array[str]] - ) -> ColumnElement[JSONB]: + ) -> ColumnElement[_CT_JSON]: """JSONB expression. Deletes field or array element specified in the argument array (equivalent of the ``#-`` operator). diff --git a/lib/sqlalchemy/dialects/sqlite/json.py b/lib/sqlalchemy/dialects/sqlite/json.py index 1a1ee049c6..ac705d661d 100644 --- a/lib/sqlalchemy/dialects/sqlite/json.py +++ b/lib/sqlalchemy/dialects/sqlite/json.py @@ -4,12 +4,21 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations + +from typing import Any +from typing import TYPE_CHECKING from ... import types as sqltypes +from ...sql.sqltypes import _T_JSON + +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql.type_api import _BindProcessorType + from ...sql.type_api import _LiteralProcessorType -class JSON(sqltypes.JSON): +class JSON(sqltypes.JSON[_T_JSON]): """SQLite JSON type. SQLite supports JSON as of version 3.9 through its JSON1_ extension. Note @@ -42,13 +51,13 @@ class JSON(sqltypes.JSON): # these are not generalizable to all JSON implementations, remain separately # implemented for each dialect. class _FormatTypeMixin: - def _format_value(self, value): + def _format_value(self, value: Any) -> str: raise NotImplementedError() - def bind_processor(self, dialect): - super_proc = self.string_bind_processor(dialect) + def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]: + super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501 - def process(value): + def process(value: Any) -> Any: value = self._format_value(value) if super_proc: value = super_proc(value) @@ -56,29 +65,31 @@ class _FormatTypeMixin: return process - def literal_processor(self, dialect): - super_proc = self.string_literal_processor(dialect) + def literal_processor( + self, dialect: Dialect + ) -> _LiteralProcessorType[Any]: + super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501 - def process(value): + def process(value: Any) -> str: value = self._format_value(value) if super_proc: value = super_proc(value) - return value + return value # type: ignore[no-any-return] return process class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): - def _format_value(self, value): + def _format_value(self, value: Any) -> str: if isinstance(value, int): - value = "$[%s]" % value + formatted_value = "$[%s]" % value else: - value = '$."%s"' % value - return value + formatted_value = '$."%s"' % value + return formatted_value class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): - def _format_value(self, value): + def _format_value(self, value: Any) -> str: return "$%s" % ( "".join( [ diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index ae4d32e36a..ade97f1be7 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -97,6 +97,8 @@ from .expression import table as table from .expression import TableClause as TableClause from .expression import TableSample as TableSample from .expression import tablesample as tablesample +from .expression import TableValuedAlias as TableValuedAlias +from .expression import TableValuedColumn as TableValuedColumn from .expression import text as text from .expression import TextClause as TextClause from .expression import true as true diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 37bb6383a5..06e20c1a4e 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -106,6 +106,7 @@ from .elements import ReleaseSavepointClause as ReleaseSavepointClause from .elements import RollbackToSavepointClause as RollbackToSavepointClause from .elements import SavepointClause as SavepointClause from .elements import SQLColumnExpression as SQLColumnExpression +from .elements import TableValuedColumn as TableValuedColumn from .elements import TextClause as TextClause from .elements import True_ as True_ from .elements import TryCast as TryCast diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 22fcfb0a69..21ce3ae1f0 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -32,7 +32,7 @@ from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING -from typing import TypeVar +from typing import TypeAlias from typing import Union from uuid import UUID as _python_UUID @@ -70,6 +70,7 @@ from ..util import warn_deprecated from ..util.typing import is_literal from ..util.typing import is_pep695 from ..util.typing import TupleAny +from ..util.typing import TypeVar if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument @@ -85,9 +86,25 @@ if TYPE_CHECKING: from ..engine.interfaces import Dialect from ..util.typing import _MatchedOnType -_T = TypeVar("_T", bound="Any") +_T = TypeVar("_T", bound=Any) + +_JSON_VALUE: TypeAlias = ( + str | int | bool | None | dict[str, "_JSON_VALUE"] | list["_JSON_VALUE"] +) + +_T_JSON = TypeVar( + "_T_JSON", + bound=Any, + default=_JSON_VALUE, +) +_CT_JSON = TypeVar( + "_CT_JSON", + bound=Any, + default=_JSON_VALUE, +) + _CT = TypeVar("_CT", bound=Any) -_TE = TypeVar("_TE", bound="TypeEngine[Any]") +_TE = TypeVar("_TE", bound=TypeEngine[Any]) _P = TypeVar("_P") @@ -2361,7 +2378,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): return process -class JSON(Indexable, TypeEngine[Any]): +class JSON(Indexable, TypeEngine[_T_JSON]): """Represent a SQL JSON type. .. note:: :class:`_types.JSON` @@ -2713,12 +2730,14 @@ class JSON(Indexable, TypeEngine[Any]): __visit_name__ = "json_path" - class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]): + class Comparator( + Indexable.Comparator[_CT_JSON], Concatenable.Comparator[_CT_JSON] + ): """Define comparison operations for :class:`_types.JSON`.""" __slots__ = () - type: JSON + type: JSON[_CT_JSON] def _setup_getitem(self, index): if not isinstance(index, str) and isinstance( diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 6a0d2ed85c..01bf0a7b3a 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -35,7 +35,6 @@ from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeGuard -from typing import TypeVar from typing import Union import typing_extensions @@ -53,6 +52,7 @@ if True: # zimports removes the tailing comments from typing_extensions import Unpack as Unpack # 3.11 from typing_extensions import Never as Never # 3.11 from typing_extensions import LiteralString as LiteralString # 3.11 + from typing_extensions import TypeVar as TypeVar # 3.13 for default _T = TypeVar("_T", bound=Any) diff --git a/test/typing/plain_files/dialects/mysql/mysql_stuff.py b/test/typing/plain_files/dialects/mysql/mysql_stuff.py index 3fcdc75a97..222becb267 100644 --- a/test/typing/plain_files/dialects/mysql/mysql_stuff.py +++ b/test/typing/plain_files/dialects/mysql/mysql_stuff.py @@ -1,21 +1,69 @@ +from typing import Any +from typing import assert_type +from typing import Dict +from typing import List +from typing import Tuple + +from sqlalchemy import Column +from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy.dialects.mysql import insert +from sqlalchemy.dialects.mysql import JSON from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.sql.functions import Function class Base(DeclarativeBase): pass -class Test(Base): - __tablename__ = "test_table_json" +def test_insert_on_duplicate_key_update() -> None: + """Test INSERT with ON DUPLICATE KEY UPDATE.""" + + class Test(Base): + __tablename__ = "test_table_json" + id = mapped_column(Integer, primary_key=True) + data: Mapped[str] = mapped_column() + + insert(Test).on_duplicate_key_update( + {"id": 42, Test.data: 99}, [("foo", 44)], data=99, id="foo" + ).inserted.foo.desc() + + +def test_json_column_types() -> None: + """Test JSON column type inference with type parameters.""" + c_json: Column[Any] = Column(JSON()) + assert_type(c_json, Column[Any]) + + c_json_dict = Column(JSON[Dict[str, Any]]()) + assert_type(c_json_dict, Column[Dict[str, Any]]) + + +def test_json_orm_mapping() -> None: + """Test JSON type in ORM mapped columns.""" + + class JSONTest(Base): + __tablename__ = "test_json" + id = mapped_column(Integer, primary_key=True) + json_data: Mapped[Dict[str, Any]] = mapped_column(JSON) + json_list: Mapped[List[Any]] = mapped_column(JSON) + + json_obj = JSONTest() + assert_type(json_obj.json_data, Dict[str, Any]) + assert_type(json_obj.json_list, List[Any]) + - id = mapped_column(Integer, primary_key=True) - data: Mapped[str] = mapped_column() +def test_json_func_with_type_param() -> None: + """Test func with parameterized JSON types (issue #13131).""" + class JSONTest(Base): + __tablename__ = "test_json_func" + id = mapped_column(Integer, primary_key=True) + data: Mapped[Dict[str, Any]] = mapped_column(JSON) -insert(Test).on_duplicate_key_update( - {"id": 42, Test.data: 99}, [("foo", 44)], data=99, id="foo" -).inserted.foo.desc() + json_func_result = func.json_extract( + JSONTest.data, "$.items", type_=JSON[List[Tuple[int, int]]] + ) + assert_type(json_func_result, Function[List[Tuple[int, int]]]) diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index 1499251103..088cb64414 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -3,16 +3,17 @@ from datetime import datetime from typing import Any from typing import assert_type from typing import Dict +from typing import List from typing import Sequence +from typing import Tuple from uuid import UUID as _py_uuid -from sqlalchemy import cast from sqlalchemy import Column from sqlalchemy import func from sqlalchemy import Integer -from sqlalchemy import or_ from sqlalchemy import Select from sqlalchemy import select +from sqlalchemy import TableValuedColumn from sqlalchemy import Text from sqlalchemy import UniqueConstraint from sqlalchemy.dialects.postgresql import aggregate_order_by @@ -20,9 +21,11 @@ from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import array from sqlalchemy.dialects.postgresql import DATERANGE from sqlalchemy.dialects.postgresql import ExcludeConstraint +from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import INT4RANGE from sqlalchemy.dialects.postgresql import INT8MULTIRANGE +from sqlalchemy.dialects.postgresql import JSON from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import Range from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE @@ -30,118 +33,256 @@ from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.sql.functions import Function +from sqlalchemy.sql.sqltypes import _JSON_VALUE -# test #6402 -c1 = Column(UUID()) +class Base(DeclarativeBase): + pass -assert_type(c1, Column[_py_uuid]) -c2 = Column(UUID(as_uuid=False)) +def test_uuid_column_types() -> None: + """Test UUID column type inference with as_uuid parameter.""" + c1 = Column(UUID()) + assert_type(c1, Column[_py_uuid]) -assert_type(c2, Column[str]) + c2 = Column(UUID(as_uuid=False)) + assert_type(c2, Column[str]) -class Base(DeclarativeBase): - pass +def test_range_column_types() -> None: + """Test Range and MultiRange column type inference.""" + assert_type(Column(INT4RANGE()), Column[Range[int]]) + assert_type(Column("foo", DATERANGE()), Column[Range[date]]) + assert_type(Column(INT8MULTIRANGE()), Column[Sequence[Range[int]]]) + assert_type( + Column("foo", TSTZMULTIRANGE()), Column[Sequence[Range[datetime]]] + ) + + +def test_range_in_select() -> None: + """Test Range types in SELECT statements.""" + range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE())) + assert_type(range_col_stmt, Select[Range[int], Sequence[Range[int]]]) + + +def test_array_type_inference() -> None: + array_from_ints = array(range(2)) + assert_type(array_from_ints, array[int]) + + array_of_strings = array([], type_=Text) + assert_type(array_of_strings, array[str]) + + array_of_ints = array([0], type_=Integer) + assert_type(array_of_ints, array[int]) + + # EXPECTED_MYPY_RE: Cannot infer .* of "array" + array([0], type_=Text) -class Test(Base): - __tablename__ = "test_table_json" +def test_array_column_types() -> None: + """Test ARRAY column type inference.""" + assert_type(ARRAY(Text), ARRAY[str]) + assert_type(Column(type_=ARRAY(Integer)), Column[Sequence[int]]) - id = mapped_column(Integer, primary_key=True) - data: Mapped[Dict[str, Any]] = mapped_column(JSONB) - ident: Mapped[_py_uuid] = mapped_column(UUID()) +def test_array_agg_functions() -> None: + """Test array_agg function type inference.""" + stmt_array_agg = select(func.array_agg(Column("num", type_=Integer))) + assert_type(stmt_array_agg, Select[Sequence[int]]) - ident_str: Mapped[str] = mapped_column(UUID(as_uuid=False)) - __table_args__ = (ExcludeConstraint((Column("ident_str"), "=")),) +def test_array_agg_with_aggregate_order_by() -> None: + """Test array_agg with aggregate_order_by.""" + class Test(Base): + __tablename__ = "test_array_agg" + id = mapped_column(Integer, primary_key=True) + ident_str: Mapped[str] = mapped_column() + ident: Mapped[_py_uuid] = mapped_column(UUID()) -elem = func.jsonb_array_elements(Test.data, type_=JSONB).column_valued("elem") + assert_type(select(func.array_agg(Test.ident_str)), Select[Sequence[str]]) -stmt = select(Test).where( - or_( - cast("example code", ARRAY(Text)).contained_by( - array([select(elem["code"].astext).scalar_subquery()]) - ), - cast("stefan", ARRAY(Text)).contained_by( - array([select(elem["code"]["new_value"].astext).scalar_subquery()]) - ), + stmt_array_agg_order_by_1 = select( + func.array_agg( + aggregate_order_by( + Column("title", type_=Text), + Column("date", type_=DATERANGE).desc(), + Column("id", type_=Integer), + ), + ) ) -) -print(stmt) + assert_type(stmt_array_agg_order_by_1, Select[Sequence[str]]) + stmt_array_agg_order_by_2 = select( + func.array_agg( + aggregate_order_by(Test.ident_str, Test.id.desc(), Test.ident), + ) + ) + assert_type(stmt_array_agg_order_by_2, Select[Sequence[str]]) -t1 = Test() -assert_type(t1.data, dict[str, Any]) +def test_json_parameterization() -> None: -assert_type(t1.ident, _py_uuid) + # test default type + x: JSON = JSON() -unique = UniqueConstraint(name="my_constraint") -insert(Test).on_conflict_do_nothing( - "foo", [Test.id], Test.id > 0 -).on_conflict_do_update( - unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22 -).excluded.foo.desc() + assert_type(x, JSON[_JSON_VALUE]) -s1 = insert(Test) -s1.on_conflict_do_update(set_=s1.excluded) + # test column values + s1 = select(Column(JSON())) -assert_type(Column(INT4RANGE()), Column[Range[int]]) -assert_type(Column("foo", DATERANGE()), Column[Range[date]]) -assert_type(Column(INT8MULTIRANGE()), Column[Sequence[Range[int]]]) -assert_type(Column("foo", TSTZMULTIRANGE()), Column[Sequence[Range[datetime]]]) + assert_type(s1, Select[_JSON_VALUE]) + c1: Column[list[int]] = Column(JSON()) + s2 = select(c1) -range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE())) + assert_type(s2, Select[list[int]]) -assert_type(range_col_stmt, Select[Range[int], Sequence[Range[int]]]) -array_from_ints = array(range(2)) +def test_jsonb_parameterization() -> None: -assert_type(array_from_ints, array[int]) + # test default type + x: JSONB = JSONB() -array_of_strings = array([], type_=Text) + assert_type(x, JSONB[_JSON_VALUE]) -assert_type(array_of_strings, array[str]) + # test column values -array_of_ints = array([0], type_=Integer) + s1 = select(Column(JSONB())) -assert_type(array_of_ints, array[int]) + assert_type(s1, Select[_JSON_VALUE]) -# EXPECTED_MYPY_RE: Cannot infer .* of "array" -array([0], type_=Text) + c1: Column[list[int]] = Column(JSONB()) + s2 = select(c1) -assert_type(ARRAY(Text), ARRAY[str]) + assert_type(s2, Select[list[int]]) -assert_type(Column(type_=ARRAY(Integer)), Column[Sequence[int]]) -stmt_array_agg = select(func.array_agg(Column("num", type_=Integer))) +def test_json_column_types() -> None: + """Test JSON column type inference with type parameters.""" -assert_type(stmt_array_agg, Select[Sequence[int]]) + c_json: Column[Any] = Column(JSON()) + assert_type(c_json, Column[Any]) -assert_type(select(func.array_agg(Test.ident_str)), Select[Sequence[str]]) + c_json_dict = Column(JSON[Dict[str, Any]]()) + assert_type(c_json_dict, Column[Dict[str, Any]]) -stmt_array_agg_order_by_1 = select( - func.array_agg( - aggregate_order_by( - Column("title", type_=Text), - Column("date", type_=DATERANGE).desc(), - Column("id", type_=Integer), - ), - ) -) -assert_type(stmt_array_agg_order_by_1, Select[Sequence[str]]) +def test_json_orm_mapping() -> None: + """Test JSON type in ORM mapped columns.""" + + class JSONTest(Base): + __tablename__ = "test_json" + id = mapped_column(Integer, primary_key=True) + json_data: Mapped[Dict[str, Any]] = mapped_column(JSON) + json_list: Mapped[List[Any]] = mapped_column(JSON) + + json_obj = JSONTest() + assert_type(json_obj.json_data, Dict[str, Any]) + assert_type(json_obj.json_list, List[Any]) + + +def test_jsonb_column_types() -> None: + """Test JSONB column type inference with type parameters.""" + c_jsonb: Column[Any] = Column(JSONB()) + assert_type(c_jsonb, Column[Any]) + + c_jsonb_dict = Column(JSONB[Dict[str, Any]]()) + assert_type(c_jsonb_dict, Column[Dict[str, Any]]) + + +def test_jsonb_orm_mapping() -> None: + """Test JSONB type in ORM mapped columns.""" + + class JSONBTest(Base): + __tablename__ = "test_jsonb" + id = mapped_column(Integer, primary_key=True) + jsonb_data: Mapped[Dict[str, Any]] = mapped_column(JSONB) + + jsonb_obj = JSONBTest() + assert_type(jsonb_obj.jsonb_data, Dict[str, Any]) -stmt_array_agg_order_by_2 = select( - func.array_agg( - aggregate_order_by(Test.ident_str, Test.id.desc(), Test.ident), + +def test_jsonb_func_with_type_param() -> None: + """Test func with parameterized JSONB types (issue #13131).""" + + class JSONBTest(Base): + __tablename__ = "test_jsonb_func" + id = mapped_column(Integer, primary_key=True) + data: Mapped[Dict[str, Any]] = mapped_column(JSONB) + + json_func_result = func.jsonb_path_query_array( + JSONBTest.data, "$.items", type_=JSONB[List[Tuple[int, int]]] ) -) + assert_type(json_func_result, Function[List[Tuple[int, int]]]) + + +def test_hstore_column_types() -> None: + """Test HSTORE column type inference.""" + c_hstore: Column[dict[str, str | None]] = Column(HSTORE()) + assert_type(c_hstore, Column[dict[str, str | None]]) + + +def test_hstore_orm_mapping() -> None: + """Test HSTORE type in ORM mapped columns.""" + + class HSTORETest(Base): + __tablename__ = "test_hstore" + id = mapped_column(Integer, primary_key=True) + hstore_data: Mapped[dict[str, str | None]] = mapped_column(HSTORE) + + hstore_obj = HSTORETest() + assert_type(hstore_obj.hstore_data, dict[str, str | None]) + + +def test_hstore_func() -> None: + + my_func = func.foobar(type_=HSTORE) + + stmt = select(my_func) + assert_type(stmt, Select[dict[str, str | None]]) + + +def test_insert_on_conflict() -> None: + """Test INSERT with ON CONFLICT clauses.""" + + class Test(Base): + __tablename__ = "test_dml" + id = mapped_column(Integer, primary_key=True) + data: Mapped[Dict[str, Any]] = mapped_column(JSONB) + ident: Mapped[_py_uuid] = mapped_column(UUID()) + ident_str: Mapped[str] = mapped_column(UUID(as_uuid=False)) + __table_args__ = (ExcludeConstraint((Column("ident_str"), "=")),) + + unique = UniqueConstraint(name="my_constraint") + insert(Test).on_conflict_do_nothing( + "foo", [Test.id], Test.id > 0 + ).on_conflict_do_update( + unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22 + ).excluded.foo.desc() + + s1 = insert(Test) + s1.on_conflict_do_update(set_=s1.excluded) + + +def test_complex_jsonb_query() -> None: + """Test complex query with JSONB array elements.""" + + class Test(Base): + __tablename__ = "test_complex" + id = mapped_column(Integer, primary_key=True) + data: Mapped[Dict[str, Any]] = mapped_column(JSONB) + ident: Mapped[_py_uuid] = mapped_column(UUID()) + ident_str: Mapped[str] = mapped_column(UUID(as_uuid=False)) + + elem = func.jsonb_array_elements( + Test.data, type_=JSONB[list[str]] + ).column_valued("elem") + + assert_type(elem, TableValuedColumn[list[str]]) -assert_type(stmt_array_agg_order_by_2, Select[Sequence[str]]) + t1 = Test() + assert_type(t1.data, dict[str, Any]) + assert_type(t1.ident, _py_uuid) diff --git a/test/typing/plain_files/sql/sqltypes.py b/test/typing/plain_files/sql/sqltypes.py index 0b5cc1bc92..7240b7c0d5 100644 --- a/test/typing/plain_files/sql/sqltypes.py +++ b/test/typing/plain_files/sql/sqltypes.py @@ -1,11 +1,63 @@ from decimal import Decimal from typing import assert_type +from sqlalchemy import Column from sqlalchemy import Float +from sqlalchemy import JSON from sqlalchemy import Numeric +from sqlalchemy import Select +from sqlalchemy import select +from sqlalchemy.sql.sqltypes import _JSON_VALUE + assert_type(Float(), Float[float]) assert_type(Float(asdecimal=True), Float[Decimal]) assert_type(Numeric(), Numeric[Decimal]) assert_type(Numeric(asdecimal=False), Numeric[float]) + + +def test_json_value_type() -> None: + + j1: _JSON_VALUE = { + "foo": "bar", + "bat": {"value1": True}, + "hoho": [1, 2, 3], + } + j2: _JSON_VALUE = "foo" + j3: _JSON_VALUE = 5 + j4: _JSON_VALUE = False + j5: _JSON_VALUE = None + j6: _JSON_VALUE = [None, 5, "foo", False] + j7: _JSON_VALUE = { # noqa: F841 + "j1": j1, + "j2": j2, + "j3": j3, + "j4": j4, + "j5": j5, + "j6": j6, + } + + +def test_json_parameterization() -> None: + + # test default type + x: JSON = JSON() + + assert_type(x, JSON[_JSON_VALUE]) + + # test column values + + s1 = select(Column(JSON())) + + assert_type(s1, Select[_JSON_VALUE]) + + c1: Column[list[int]] = Column(JSON()) + s2 = select(c1) + + assert_type(s2, Select[list[int]]) + + c2 = Column(JSON[list[int]]()) + s3 = select(c2) + + assert_type(s3, Select[list[int]])