From: Mike Bayer Date: Fri, 6 May 2022 14:13:03 +0000 (-0400) Subject: accept for literal coercions X-Git-Tag: rel_2_0_0b1~321^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=457902ccc3a3aa2c435887ffd308b4d86b522df2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git accept for literal coercions this may be needed in many more places but cast() is a prominent one. Change-Id: I5331edd2d34c54910e4ca16b0553f64fc9167af7 --- diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 605f75ec4f..0eaeae66ed 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -47,8 +47,8 @@ from .functions import FunctionElement from ..util.typing import Literal if typing.TYPE_CHECKING: - from . import sqltypes from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnExpressionOrLiteralArgument from ._typing import _TypeEngineArgument from .elements import BinaryExpression from .functions import FunctionElement @@ -289,7 +289,7 @@ def collate( def between( - expr: _ColumnExpressionArgument[_T], + expr: _ColumnExpressionOrLiteralArgument[_T], lower_bound: Any, upper_bound: Any, symmetric: bool = False, @@ -782,7 +782,7 @@ def case( def cast( - expression: _ColumnExpressionArgument[Any], + expression: _ColumnExpressionOrLiteralArgument[Any], type_: _TypeEngineArgument[_T], ) -> Cast[_T]: r"""Produce a ``CAST`` expression. @@ -1544,7 +1544,7 @@ def tuple_( def type_coerce( - expression: _ColumnExpressionArgument[Any], + expression: _ColumnExpressionOrLiteralArgument[Any], type_: _TypeEngineArgument[_T], ) -> TypeCoerce[_T]: r"""Associate a SQL expression with a particular type, without rendering diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 1df530dbd6..f49a6d3ec5 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -162,6 +162,9 @@ overall which brings in the TextClause object also. """ +_ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]] + + _InfoType = Dict[Any, Any] """the .info dictionary accepted and used throughout Core /ORM""" diff --git a/test/ext/mypy/plain_files/pg_stuff.py b/test/ext/mypy/plain_files/pg_stuff.py new file mode 100644 index 0000000000..ce02723972 --- /dev/null +++ b/test/ext/mypy/plain_files/pg_stuff.py @@ -0,0 +1,37 @@ +from sqlalchemy import cast +from sqlalchemy import Column +from sqlalchemy import func +from sqlalchemy import Integer +from sqlalchemy import or_ +from sqlalchemy import select +from sqlalchemy import Text +from sqlalchemy.dialects.postgresql import ARRAY +from sqlalchemy.dialects.postgresql import array +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import DeclarativeBase + + +class Base(DeclarativeBase): + pass + + +class Test(Base): + __tablename__ = "test_table_json" + + id = Column(Integer, primary_key=True) + data = Column(JSONB) + + +elem = func.jsonb_array_elements(Test.data, type_=JSONB).column_valued("elem") + +stmt = select(Test).where( + or_( + cast("example code", ARRAY(Text)).contained_by( + array([select(elem["code"].astext).scalar_subquery()]) + ), + cast("stefan", ARRAY(Text)).contained_by( + array([select(elem["code"]["new_value"].astext).scalar_subquery()]) + ), + ) +) +print(stmt)