From: Mike Bayer Date: Wed, 16 Jul 2025 16:14:27 +0000 (-0400) Subject: add OperatorClasses to gate mismatched operator use X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=34740d33abffc8e5ec11dd1f2ed98bf42ea078f6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add OperatorClasses to gate mismatched operator use Added a new concept of "operator classes" to the SQL operators supported by SQLAlchemy, represented within the enum :class:`.OperatorClass`. The purpose of this structure is to provide an extra layer of validation when a particular kind of SQL operation is used with a particular datatype, to catch early the use of an operator that does not have any relevance to the datatype in use; a simple example is an integer or numeric column used with a "string match" operator. Fixes: #12736 Change-Id: I44f46d7326aef6847dbf0cf7a325833f8e347da6 --- diff --git a/doc/build/changelog/migration_21.rst b/doc/build/changelog/migration_21.rst index a1e4d67bdf..278c446f04 100644 --- a/doc/build/changelog/migration_21.rst +++ b/doc/build/changelog/migration_21.rst @@ -590,3 +590,98 @@ the existing ``asyncpg.BitString`` type. :ticket:`10556` +.. _change_12736: + +Operator classes added to validate operator usage with datatypes +---------------------------------------------------------------- + +SQLAlchemy 2.1 introduces a new "operator classes" system that provides +validation when SQL operators are used with specific datatypes. This feature +helps catch usage of operators that are not appropriate for a given datatype +during the initial construction of expression objects. A simple example is an +integer or numeric column used with a "string match" operator. When an +incompatible operation is used, a deprecation warning is emitted; in a future +major release this will raise :class:`.InvalidRequestError`. + +The initial motivation for this new system is to revise the use of the +:meth:`.ColumnOperators.contains` method when used with :class:`_types.JSON` columns. +The :meth:`.ColumnOperators.contains` method in the case of the :class:`_types.JSON` +datatype makes use of the string-oriented version of the method, that +assumes string data and uses LIKE to match substrings. This is not compatible +with the same-named method that is defined by the PostgreSQL +:class:`_postgresql.JSONB` type, which uses PostgreSQL's native JSONB containment +operators. Because :class:`_types.JSON` data is normally stored as a plain string, +:meth:`.ColumnOperators.contains` would "work", and even in trivial cases +behave similarly to that of :class:`_postgresql.JSONB`. However, since the two +operations are not actually compatible at all, this mis-use can easily lead to +unexpected inconsistencies. + +Code that uses :meth:`.ColumnOperators.contains` with :class:`_types.JSON` columns will +now emit a deprecation warning:: + + from sqlalchemy import JSON, select, Column + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + + + class Base(DeclarativeBase): + pass + + + class MyTable(Base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + json_column: Mapped[dict] = mapped_column(JSON) + + + # This will now emit a deprecation warning + select(MyTable).filter(MyTable.json_column.contains("some_value")) + +Above, using :meth:`.ColumnOperators.contains` with :class:`_types.JSON` columns +is considered to be inappropriate, since :meth:`.ColumnOperators.contains` +works as a simple string search without any awareness of JSON structuring. +To explicitly indicate that the JSON data should be searched as a string +using LIKE, the +column should first be cast (using either :func:`_sql.cast` for a full CAST, +or :func:`_sql.type_coerce` for a Python-side cast) to :class:`.String`:: + + from sqlalchemy import type_coerce, String + + # Explicit string-based matching + select(MyTable).filter(type_coerce(MyTable.json_column, String).contains("some_value")) + +This change forces code to distinguish between using string-based "contains" +with a :class:`_types.JSON` column and using PostgreSQL's JSONB containment +operator with :class:`_postgresql.JSONB` columns as separate, explicitly-stated operations. + +The operator class system involves a mapping of SQLAlchemy operators listed +out in :mod:`sqlalchemy.sql.operators` to operator class combinations that come +from the :class:`.OperatorClass` enumeration, which are reconciled at +expression construction time with datatypes using the +:attr:`.TypeEngine.operator_classes` attribute. A custom user defined type +may want to set this attribute to indicate the kinds of operators that make +sense:: + + from sqlalchemy.types import UserDefinedType + from sqlalchemy.sql.sqltypes import OperatorClass + + + class ComplexNumber(UserDefinedType): + operator_classes = OperatorClass.MATH + +The above ``ComplexNumber`` datatype would then validate that operators +used are included in the "math" operator class. By default, user defined +types made with :class:`.UserDefinedType` are left open to accept all +operators by default, whereas classes defined with :class:`.TypeDecorator` +will make use of the operator classes declared by the "impl" type. + +.. seealso:: + + :paramref:`.Operators.op.operator_class` - define an operator class when creating custom operators + + :class:`.OperatorClass` + +:ticket:`12736` + + +` diff --git a/doc/build/changelog/unreleased_21/12736.rst b/doc/build/changelog/unreleased_21/12736.rst new file mode 100644 index 0000000000..c16c9c17d3 --- /dev/null +++ b/doc/build/changelog/unreleased_21/12736.rst @@ -0,0 +1,17 @@ +.. change:: + :tags: bug, sql + :tickets: 12736 + + Added a new concept of "operator classes" to the SQL operators supported by + SQLAlchemy, represented within the enum :class:`.OperatorClass`. The + purpose of this structure is to provide an extra layer of validation when a + particular kind of SQL operation is used with a particular datatype, to + catch early the use of an operator that does not have any relevance to the + datatype in use; a simple example is an integer or numeric column used with + a "string match" operator. + + .. seealso:: + + :ref:`change_12736` + + diff --git a/doc/build/core/operators.rst b/doc/build/core/operators.rst index 7fa163d6e6..b21953200e 100644 --- a/doc/build/core/operators.rst +++ b/doc/build/core/operators.rst @@ -1,5 +1,7 @@ .. highlight:: pycon+sql +.. module:: sqlalchemy.sql.operators + Operator Reference =============================== diff --git a/doc/build/core/sqlelement.rst b/doc/build/core/sqlelement.rst index 7e7da36af5..5e8299ab34 100644 --- a/doc/build/core/sqlelement.rst +++ b/doc/build/core/sqlelement.rst @@ -193,6 +193,10 @@ The classes here are generated using the constructors listed at .. autoclass:: Null :members: +.. autoclass:: OperatorClass + :members: + :undoc-members: + .. autoclass:: Operators :members: :special-members: diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py index 06aeaace2f..4ad624475c 100644 --- a/lib/sqlalchemy/dialects/oracle/types.py +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -13,6 +13,7 @@ from typing import Type from typing import TYPE_CHECKING from ... import exc +from ...sql import operators from ...sql import sqltypes from ...types import NVARCHAR from ...types import VARCHAR @@ -309,6 +310,7 @@ class ROWID(sqltypes.TypeEngine): """ __visit_name__ = "ROWID" + operator_classes = operators.OperatorClass.ANY class _OracleBoolean(sqltypes.Boolean): diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index 0a915b17df..e7cac4cb4d 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -18,7 +18,7 @@ from .operators import HAS_ANY from .operators import HAS_KEY from ... import types as sqltypes from ...sql import functions as sqlfunc - +from ...types import OperatorClass __all__ = ("HSTORE", "hstore") @@ -105,6 +105,13 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): hashable = False text_type = sqltypes.Text() + operator_classes = ( + OperatorClass.BASE + | OperatorClass.CONTAINS + | OperatorClass.INDEXABLE + | OperatorClass.CONCATENABLE + ) + def __init__(self, text_type=None): """Construct a new :class:`.HSTORE`. diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 06f8db5b2a..9aa805a0fc 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -29,6 +29,7 @@ from .operators import PATH_MATCH from ... import types as sqltypes from ...sql import cast from ...sql._typing import _T +from ...sql.operators import OperatorClass if TYPE_CHECKING: from ...engine.interfaces import Dialect @@ -283,6 +284,8 @@ class JSONB(JSON): __visit_name__ = "JSONB" + operator_classes = OperatorClass.JSON | OperatorClass.CONCATENABLE + class Comparator(JSON.Comparator[_T]): """Define comparison operations for :class:`_types.JSON`.""" diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index ea25ed5caf..10d70cc770 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -36,6 +36,7 @@ from .operators import STRICTLY_LEFT_OF from .operators import STRICTLY_RIGHT_OF from ... import types as sqltypes from ...sql import operators +from ...sql.operators import OperatorClass from ...sql.type_api import TypeEngine if TYPE_CHECKING: @@ -711,6 +712,8 @@ class AbstractRange(sqltypes.TypeEngine[_T]): render_bind_cast = True + operator_classes = OperatorClass.NUMERIC + __abstract__ = True @overload diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index 96e5644572..49226b94bd 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -19,6 +19,7 @@ from .bitstring import BitString from ...sql import sqltypes from ...sql import type_api from ...sql.type_api import TypeEngine +from ...types import OperatorClass if TYPE_CHECKING: from ...engine.interfaces import Dialect @@ -57,6 +58,7 @@ class BYTEA(sqltypes.LargeBinary): class _NetworkAddressTypeMixin: + operator_classes = OperatorClass.BASE | OperatorClass.COMPARISON def coerce_compared_value( self, op: Optional[OperatorType], value: Any @@ -144,6 +146,8 @@ class OID(sqltypes.TypeEngine[int]): __visit_name__ = "OID" + operator_classes = OperatorClass.BASE | OperatorClass.COMPARISON + class REGCONFIG(sqltypes.TypeEngine[str]): """Provide the PostgreSQL REGCONFIG type. @@ -154,6 +158,8 @@ class REGCONFIG(sqltypes.TypeEngine[str]): __visit_name__ = "REGCONFIG" + operator_classes = OperatorClass.BASE | OperatorClass.COMPARISON + class TSQUERY(sqltypes.TypeEngine[str]): """Provide the PostgreSQL TSQUERY type. @@ -164,12 +170,16 @@ class TSQUERY(sqltypes.TypeEngine[str]): __visit_name__ = "TSQUERY" + operator_classes = OperatorClass.BASE | OperatorClass.COMPARISON + class REGCLASS(sqltypes.TypeEngine[str]): """Provide the PostgreSQL REGCLASS type.""" __visit_name__ = "REGCLASS" + operator_classes = OperatorClass.BASE | OperatorClass.COMPARISON + class TIMESTAMP(sqltypes.TIMESTAMP): """Provide the PostgreSQL TIMESTAMP type.""" @@ -274,6 +284,10 @@ class BIT(sqltypes.TypeEngine[BitString]): render_bind_cast = True __visit_name__ = "BIT" + operator_classes = ( + OperatorClass.BASE | OperatorClass.COMPARISON | OperatorClass.BITWISE + ) + def __init__( self, length: Optional[int] = None, varying: bool = False ) -> None: @@ -356,6 +370,8 @@ class TSVECTOR(sqltypes.TypeEngine[str]): __visit_name__ = "TSVECTOR" + operator_classes = OperatorClass.STRING + class CITEXT(sqltypes.TEXT): """Provide the PostgreSQL CITEXT type. diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index eba769f892..ae7fb5ab4e 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -12,7 +12,6 @@ from __future__ import annotations import typing from typing import Any from typing import Callable -from typing import Dict from typing import NoReturn from typing import Optional from typing import Tuple @@ -423,145 +422,233 @@ def _regexp_replace_impl( ) -# a mapping of operators with the method they use, along with -# additional keyword arguments to be passed -operator_lookup: Dict[ +operator_lookup: util.immutabledict[ str, Tuple[ - Callable[..., ColumnElement[Any]], + Callable[..., "ColumnElement[Any]"], util.immutabledict[ - str, Union[OperatorType, Callable[..., ColumnElement[Any]]] + str, Union["OperatorType", Callable[..., "ColumnElement[Any]"]] ], ], -] = { - "and_": (_conjunction_operate, util.EMPTY_DICT), - "or_": (_conjunction_operate, util.EMPTY_DICT), - "inv": (_inv_impl, util.EMPTY_DICT), - "add": (_binary_operate, util.EMPTY_DICT), - "mul": (_binary_operate, util.EMPTY_DICT), - "sub": (_binary_operate, util.EMPTY_DICT), - "div": (_binary_operate, util.EMPTY_DICT), - "mod": (_binary_operate, util.EMPTY_DICT), - "bitwise_xor_op": (_binary_operate, util.EMPTY_DICT), - "bitwise_or_op": (_binary_operate, util.EMPTY_DICT), - "bitwise_and_op": (_binary_operate, util.EMPTY_DICT), - "bitwise_not_op": (_bitwise_not_impl, util.EMPTY_DICT), - "bitwise_lshift_op": (_binary_operate, util.EMPTY_DICT), - "bitwise_rshift_op": (_binary_operate, util.EMPTY_DICT), - "truediv": (_binary_operate, util.EMPTY_DICT), - "floordiv": (_binary_operate, util.EMPTY_DICT), - "custom_op": (_custom_op_operate, util.EMPTY_DICT), - "json_path_getitem_op": (_binary_operate, util.EMPTY_DICT), - "json_getitem_op": (_binary_operate, util.EMPTY_DICT), - "concat_op": (_binary_operate, util.EMPTY_DICT), - "any_op": ( - _scalar, - util.immutabledict({"fn": CollectionAggregate._create_any}), - ), - "all_op": ( - _scalar, - util.immutabledict({"fn": CollectionAggregate._create_all}), - ), - "lt": (_boolean_compare, util.immutabledict({"negate_op": operators.ge})), - "le": (_boolean_compare, util.immutabledict({"negate_op": operators.gt})), - "ne": (_boolean_compare, util.immutabledict({"negate_op": operators.eq})), - "gt": (_boolean_compare, util.immutabledict({"negate_op": operators.le})), - "ge": (_boolean_compare, util.immutabledict({"negate_op": operators.lt})), - "eq": (_boolean_compare, util.immutabledict({"negate_op": operators.ne})), - "is_distinct_from": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.is_not_distinct_from}), - ), - "is_not_distinct_from": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.is_distinct_from}), - ), - "like_op": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.not_like_op}), - ), - "ilike_op": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.not_ilike_op}), - ), - "not_like_op": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.like_op}), - ), - "not_ilike_op": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.ilike_op}), - ), - "contains_op": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.not_contains_op}), - ), - "icontains_op": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.not_icontains_op}), - ), - "startswith_op": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.not_startswith_op}), - ), - "istartswith_op": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.not_istartswith_op}), - ), - "endswith_op": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.not_endswith_op}), - ), - "iendswith_op": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.not_iendswith_op}), - ), - "desc_op": ( - _scalar, - util.immutabledict({"fn": UnaryExpression._create_desc}), - ), - "asc_op": ( - _scalar, - util.immutabledict({"fn": UnaryExpression._create_asc}), - ), - "nulls_first_op": ( - _scalar, - util.immutabledict({"fn": UnaryExpression._create_nulls_first}), - ), - "nulls_last_op": ( - _scalar, - util.immutabledict({"fn": UnaryExpression._create_nulls_last}), - ), - "in_op": ( - _in_impl, - util.immutabledict({"negate_op": operators.not_in_op}), - ), - "not_in_op": ( - _in_impl, - util.immutabledict({"negate_op": operators.in_op}), - ), - "is_": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.is_}), - ), - "is_not": ( - _boolean_compare, - util.immutabledict({"negate_op": operators.is_not}), - ), - "collate": (_collate_impl, util.EMPTY_DICT), - "match_op": (_match_impl, util.EMPTY_DICT), - "not_match_op": (_match_impl, util.EMPTY_DICT), - "distinct_op": (_distinct_impl, util.EMPTY_DICT), - "between_op": (_between_impl, util.EMPTY_DICT), - "not_between_op": (_between_impl, util.EMPTY_DICT), - "neg": (_neg_impl, util.EMPTY_DICT), - "getitem": (_getitem_impl, util.EMPTY_DICT), - "lshift": (_unsupported_impl, util.EMPTY_DICT), - "rshift": (_unsupported_impl, util.EMPTY_DICT), - "matmul": (_unsupported_impl, util.EMPTY_DICT), - "contains": (_unsupported_impl, util.EMPTY_DICT), - "regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT), - "not_regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT), - "regexp_replace_op": (_regexp_replace_impl, util.EMPTY_DICT), - "pow": (_pow_impl, util.EMPTY_DICT), -} +] = util.immutabledict( + { + "any_op": ( + _scalar, + util.immutabledict({"fn": CollectionAggregate._create_any}), + ), + "all_op": ( + _scalar, + util.immutabledict({"fn": CollectionAggregate._create_all}), + ), + "lt": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.ge}), + ), + "le": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.gt}), + ), + "ne": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.eq}), + ), + "gt": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.le}), + ), + "ge": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.lt}), + ), + "eq": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.ne}), + ), + "is_distinct_from": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_not_distinct_from}), + ), + "is_not_distinct_from": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_distinct_from}), + ), + "in_op": ( + _in_impl, + util.immutabledict({"negate_op": operators.not_in_op}), + ), + "not_in_op": ( + _in_impl, + util.immutabledict({"negate_op": operators.in_op}), + ), + "is_": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_}), + ), + "is_not": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_not}), + ), + "between_op": ( + _between_impl, + util.EMPTY_DICT, + ), + "not_between_op": ( + _between_impl, + util.EMPTY_DICT, + ), + "desc_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_desc}), + ), + "asc_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_asc}), + ), + "nulls_first_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_nulls_first}), + ), + "nulls_last_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_nulls_last}), + ), + "distinct_op": ( + _distinct_impl, + util.EMPTY_DICT, + ), + "null_op": (_binary_operate, util.EMPTY_DICT), + "custom_op": (_custom_op_operate, util.EMPTY_DICT), + "and_": ( + _conjunction_operate, + util.EMPTY_DICT, + ), + "or_": ( + _conjunction_operate, + util.EMPTY_DICT, + ), + "inv": ( + _inv_impl, + util.EMPTY_DICT, + ), + "add": ( + _binary_operate, + util.EMPTY_DICT, + ), + "concat_op": ( + _binary_operate, + util.EMPTY_DICT, + ), + "getitem": (_getitem_impl, util.EMPTY_DICT), + "contains_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_contains_op}), + ), + "icontains_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_icontains_op}), + ), + "contains": ( + _unsupported_impl, + util.EMPTY_DICT, + ), + "like_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_like_op}), + ), + "ilike_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_ilike_op}), + ), + "not_like_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.like_op}), + ), + "not_ilike_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.ilike_op}), + ), + "startswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_startswith_op}), + ), + "istartswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_istartswith_op}), + ), + "endswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_endswith_op}), + ), + "iendswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_iendswith_op}), + ), + "collate": ( + _collate_impl, + util.EMPTY_DICT, + ), + "match_op": (_match_impl, util.EMPTY_DICT), + "not_match_op": ( + _match_impl, + util.EMPTY_DICT, + ), + "regexp_match_op": ( + _regexp_match_impl, + util.EMPTY_DICT, + ), + "not_regexp_match_op": ( + _regexp_match_impl, + util.EMPTY_DICT, + ), + "regexp_replace_op": ( + _regexp_replace_impl, + util.EMPTY_DICT, + ), + "lshift": (_unsupported_impl, util.EMPTY_DICT), + "rshift": (_unsupported_impl, util.EMPTY_DICT), + "bitwise_xor_op": ( + _binary_operate, + util.EMPTY_DICT, + ), + "bitwise_or_op": ( + _binary_operate, + util.EMPTY_DICT, + ), + "bitwise_and_op": ( + _binary_operate, + util.EMPTY_DICT, + ), + "bitwise_not_op": ( + _bitwise_not_impl, + util.EMPTY_DICT, + ), + "bitwise_lshift_op": ( + _binary_operate, + util.EMPTY_DICT, + ), + "bitwise_rshift_op": ( + _binary_operate, + util.EMPTY_DICT, + ), + "matmul": (_unsupported_impl, util.EMPTY_DICT), + "pow": (_pow_impl, util.EMPTY_DICT), + "neg": (_neg_impl, util.EMPTY_DICT), + "mul": (_binary_operate, util.EMPTY_DICT), + "sub": ( + _binary_operate, + util.EMPTY_DICT, + ), + "div": (_binary_operate, util.EMPTY_DICT), + "mod": (_binary_operate, util.EMPTY_DICT), + "truediv": (_binary_operate, util.EMPTY_DICT), + "floordiv": (_binary_operate, util.EMPTY_DICT), + "json_path_getitem_op": ( + _binary_operate, + util.EMPTY_DICT, + ), + "json_getitem_op": ( + _binary_operate, + util.EMPTY_DICT, + ), + } +) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 2b0fa95875..8f68e520b8 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -67,6 +67,7 @@ from .cache_key import MemoizedHasCacheKey from .cache_key import NO_CACHE from .coercions import _document_text_coercion # noqa from .operators import ColumnOperators +from .operators import OperatorClass from .traversals import HasCopyInternals from .visitors import cloned_traverse from .visitors import ExternallyTraversible @@ -851,6 +852,7 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): *, return_type: _TypeEngineArgument[_OPT], python_impl: Optional[Callable[..., Any]] = None, + operator_class: OperatorClass = ..., ) -> Callable[[Any], BinaryExpression[_OPT]]: ... @overload @@ -861,6 +863,7 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): is_comparison: bool = ..., return_type: Optional[_TypeEngineArgument[Any]] = ..., python_impl: Optional[Callable[..., Any]] = ..., + operator_class: OperatorClass = ..., ) -> Callable[[Any], BinaryExpression[Any]]: ... def op( @@ -870,6 +873,7 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): is_comparison: bool = False, return_type: Optional[_TypeEngineArgument[Any]] = None, python_impl: Optional[Callable[..., Any]] = None, + operator_class: OperatorClass = OperatorClass.BASE, ) -> Callable[[Any], BinaryExpression[Any]]: ... def bool_op( diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 3d0ff7d7ba..f7847bf7e6 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -116,6 +116,7 @@ from .lambdas import LambdaElement as LambdaElement from .lambdas import StatementLambdaElement as StatementLambdaElement from .operators import ColumnOperators as ColumnOperators from .operators import custom_op as custom_op +from .operators import OperatorClass as OperatorClass from .operators import Operators as Operators from .selectable import Alias as Alias from .selectable import AliasedReturnsRows as AliasedReturnsRows diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 7917c9d283..9d4d86a341 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -12,6 +12,8 @@ from __future__ import annotations +from enum import auto +from enum import Flag from enum import IntEnum from operator import add as _uncast_add from operator import and_ as _uncast_and_ @@ -41,6 +43,7 @@ from typing import Callable from typing import cast from typing import Dict from typing import Generic +from typing import Hashable from typing import Literal from typing import Optional from typing import overload @@ -65,7 +68,59 @@ _T = TypeVar("_T", bound=Any) _FN = TypeVar("_FN", bound=Callable[..., Any]) -class OperatorType(Protocol): +class OperatorClass(Flag): + """Describes a class of SQLAlchemy built-in operators that should be + available on a particular type. + + The :class:`.OperatorClass` should be present on the + :attr:`.TypeEngine.operator_classes` attribute of any particular type. + + The enums here can be ORed together to provide sets of operators merged + together. + + .. versionadded:: 2.1 + + """ + + UNSPECIFIED = auto() + BASE = auto() + BOOLEAN_ALGEBRA = auto() + COMPARISON = auto() + INDEXABLE = auto() # noqa: F811 + CONTAINS = auto() + CONCATENABLE = auto() + STRING_MATCH = auto() + MATH = auto() + BITWISE = auto() + DATE_ARITHEMETIC = auto() + JSON_GETITEM = auto() + + STRING = ( + BASE | COMPARISON | STRING_MATCH | CONTAINS | CONCATENABLE | INDEXABLE + ) + INTEGER = BASE | COMPARISON | MATH | BITWISE + NUMERIC = BASE | COMPARISON | MATH | BITWISE + BOOLEAN = BASE | COMPARISON | BOOLEAN_ALGEBRA | COMPARISON + BINARY = BASE | COMPARISON | CONTAINS | CONCATENABLE | INDEXABLE + DATETIME = BASE | COMPARISON | DATE_ARITHEMETIC + JSON = BASE | COMPARISON | INDEXABLE | JSON_GETITEM + ARRAY = BASE | COMPARISON | CONTAINS | CONCATENABLE | INDEXABLE + TUPLE = BASE | COMPARISON | CONTAINS | CONCATENABLE | INDEXABLE + + ANY = ( + STRING + | INTEGER + | NUMERIC + | BOOLEAN + | DATETIME + | BINARY + | JSON + | ARRAY + | TUPLE + ) + + +class OperatorType(Hashable, Protocol): """describe an op() function.""" __slots__ = () @@ -208,6 +263,11 @@ class Operators: """ return self.operate(inv) + def _null_operate(self, other: Any) -> Operators: + """A 'null' operation available on all types, used for testing.""" + + return self.operate(null_op, other) + def op( self, opstring: str, @@ -217,6 +277,7 @@ class Operators: Union[Type[TypeEngine[Any]], TypeEngine[Any]] ] = None, python_impl: Optional[Callable[..., Any]] = None, + operator_class: OperatorClass = OperatorClass.BASE, ) -> Callable[[Any], Operators]: """Produce a generic operator function. @@ -292,6 +353,13 @@ class Operators: .. versionadded:: 2.0 + :param operator_class: optional :class:`.OperatorClass` which will be + applied to the :class:`.custom_op` created, which provides hints + as to which datatypes are appropriate for this operator. Defaults + to :attr:`.OperatorClass.BASE` which is appropriate for all + datatypes. + + .. versionadded:: 2.1 .. seealso:: @@ -304,10 +372,11 @@ class Operators: """ operator = custom_op( opstring, - precedence, - is_comparison, - return_type, + precedence=precedence, + is_comparison=is_comparison, + return_type=return_type, python_impl=python_impl, + operator_class=operator_class, ) def against(other: Any) -> Operators: @@ -418,11 +487,13 @@ class custom_op(OperatorType, Generic[_T]): "eager_grouping", "return_type", "python_impl", + "operator_class", ) def __init__( self, opstring: str, + *, precedence: int = 0, is_comparison: bool = False, return_type: Optional[ @@ -431,6 +502,7 @@ class custom_op(OperatorType, Generic[_T]): natural_self_precedent: bool = False, eager_grouping: bool = False, python_impl: Optional[Callable[..., Any]] = None, + operator_class: OperatorClass = OperatorClass.BASE, ): self.opstring = opstring self.precedence = precedence @@ -441,6 +513,7 @@ class custom_op(OperatorType, Generic[_T]): return_type._to_instance(return_type) if return_type else None ) self.python_impl = python_impl + self.operator_class = operator_class def __eq__(self, other: Any) -> bool: return ( @@ -460,6 +533,7 @@ class custom_op(OperatorType, Generic[_T]): self.natural_self_precedent, self.eager_grouping, self.return_type._static_cache_key if self.return_type else None, + self.operator_class, ) @overload @@ -2544,6 +2618,18 @@ def bitwise_rshift_op(a: Any, b: Any) -> Any: return a.bitwise_rshift(b) +@_operator_fn +def null_op(a: Any, b: Any) -> Any: + """a 'null' operator that provides a boolean operation. + + Does not compile in a SQL context, used for testing operators only. + + .. versionadded:: 2.1 + + """ + return a._null_operate(b) + + def is_comparison(op: OperatorType) -> bool: return op in _comparison or isinstance(op, custom_op) and op.is_comparison @@ -2637,6 +2723,7 @@ _PRECEDENCE: Dict[OperatorType, int] = { bitwise_and_op: 7, bitwise_lshift_op: 7, bitwise_rshift_op: 7, + null_op: 7, filter_op: 6, concat_op: 5, match_op: 5, @@ -2678,6 +2765,99 @@ _PRECEDENCE: Dict[OperatorType, int] = { } +# Mapping of OperatorType objects to their corresponding OperatorClass +# Derived from unified_operator_lookup in default_comparator.py +_OPERATOR_CLASSES: util.immutabledict[OperatorType, OperatorClass] = ( + util.immutabledict( + { + # BASE operators + null_op: OperatorClass.BASE, + # COMPARISON operators + lt: OperatorClass.COMPARISON, + le: OperatorClass.COMPARISON, + ne: OperatorClass.COMPARISON, + gt: OperatorClass.COMPARISON, + ge: OperatorClass.COMPARISON, + eq: OperatorClass.COMPARISON, + is_distinct_from: OperatorClass.COMPARISON, + is_not_distinct_from: OperatorClass.COMPARISON, + in_op: OperatorClass.COMPARISON, + not_in_op: OperatorClass.COMPARISON, + is_: OperatorClass.COMPARISON, + is_not: OperatorClass.COMPARISON, + between_op: OperatorClass.COMPARISON, + not_between_op: OperatorClass.COMPARISON, + desc_op: OperatorClass.COMPARISON, + asc_op: OperatorClass.COMPARISON, + nulls_first_op: OperatorClass.COMPARISON, + nulls_last_op: OperatorClass.COMPARISON, + distinct_op: OperatorClass.COMPARISON, + any_op: OperatorClass.COMPARISON, + all_op: OperatorClass.COMPARISON, + # BOOLEAN_ALGEBRA operators + and_: OperatorClass.BOOLEAN_ALGEBRA, + or_: OperatorClass.BOOLEAN_ALGEBRA, + inv: OperatorClass.BOOLEAN_ALGEBRA | OperatorClass.BITWISE, + # CONCATENABLE | MATH | DATE_ARITHMETIC | BITWISE operators + add: OperatorClass.CONCATENABLE + | OperatorClass.MATH + | OperatorClass.DATE_ARITHEMETIC + | OperatorClass.BITWISE, + # CONCATENABLE | BITWISE operators + concat_op: OperatorClass.CONCATENABLE | OperatorClass.BITWISE, + # INDEXABLE operators + getitem: OperatorClass.INDEXABLE, + # CONTAINS operators + contains_op: OperatorClass.CONTAINS, + icontains_op: OperatorClass.CONTAINS, + contains: OperatorClass.CONTAINS, + not_contains_op: OperatorClass.CONTAINS, + not_icontains_op: OperatorClass.CONTAINS, + # STRING_MATCH operators + like_op: OperatorClass.STRING_MATCH, + ilike_op: OperatorClass.STRING_MATCH, + not_like_op: OperatorClass.STRING_MATCH, + not_ilike_op: OperatorClass.STRING_MATCH, + startswith_op: OperatorClass.STRING_MATCH, + istartswith_op: OperatorClass.STRING_MATCH, + endswith_op: OperatorClass.STRING_MATCH, + iendswith_op: OperatorClass.STRING_MATCH, + not_startswith_op: OperatorClass.STRING_MATCH, + not_istartswith_op: OperatorClass.STRING_MATCH, + not_endswith_op: OperatorClass.STRING_MATCH, + not_iendswith_op: OperatorClass.STRING_MATCH, + collate: OperatorClass.STRING_MATCH, + match_op: OperatorClass.STRING_MATCH, + not_match_op: OperatorClass.STRING_MATCH, + regexp_match_op: OperatorClass.STRING_MATCH, + not_regexp_match_op: OperatorClass.STRING_MATCH, + regexp_replace_op: OperatorClass.STRING_MATCH, + # BITWISE operators + lshift: OperatorClass.BITWISE, + rshift: OperatorClass.BITWISE, + bitwise_xor_op: OperatorClass.BITWISE, + bitwise_or_op: OperatorClass.BITWISE, + bitwise_and_op: OperatorClass.BITWISE, + bitwise_not_op: OperatorClass.BITWISE, + bitwise_lshift_op: OperatorClass.BITWISE, + bitwise_rshift_op: OperatorClass.BITWISE, + # MATH operators + matmul: OperatorClass.MATH, + pow_: OperatorClass.MATH, + neg: OperatorClass.MATH, + mul: OperatorClass.MATH, + sub: OperatorClass.MATH | OperatorClass.DATE_ARITHEMETIC, + truediv: OperatorClass.MATH, + floordiv: OperatorClass.MATH, + mod: OperatorClass.MATH, + # JSON_GETITEM operators + json_path_getitem_op: OperatorClass.JSON_GETITEM, + json_getitem_op: OperatorClass.JSON_GETITEM, + } + ) +) + + def is_precedent( operator: OperatorType, against: Optional[OperatorType] ) -> bool: diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 81a2bbf67d..916e6444e5 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -49,6 +49,7 @@ from .cache_key import HasCacheKey from .elements import quoted_name from .elements import Slice from .elements import TypeCoerce as type_coerce # noqa +from .operators import OperatorClass from .type_api import Emulated from .type_api import NativeForEmulated # noqa from .type_api import to_instance as to_instance @@ -188,6 +189,8 @@ class String(Concatenable, TypeEngine[str]): __visit_name__ = "string" + operator_classes = OperatorClass.STRING + def __init__( self, length: Optional[int] = None, @@ -345,6 +348,8 @@ class Integer(HasExpressionLookup, TypeEngine[int]): __visit_name__ = "integer" + operator_classes = OperatorClass.INTEGER + if TYPE_CHECKING: @util.ro_memoized_property @@ -433,6 +438,8 @@ class NumericCommon(HasExpressionLookup, TypeEngineMixin, Generic[_N]): _default_decimal_return_scale = 10 + operator_classes = OperatorClass.NUMERIC + if TYPE_CHECKING: @util.ro_memoized_property @@ -811,6 +818,8 @@ class DateTime( __visit_name__ = "datetime" + operator_classes = OperatorClass.DATETIME + def __init__(self, timezone: bool = False): """Construct a new :class:`.DateTime`. @@ -859,6 +868,8 @@ class Date(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.date]): __visit_name__ = "date" + operator_classes = OperatorClass.DATETIME + def get_dbapi_type(self, dbapi): return dbapi.DATETIME @@ -899,6 +910,8 @@ class Time(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.time]): __visit_name__ = "time" + operator_classes = OperatorClass.DATETIME + def __init__(self, timezone: bool = False): self.timezone = timezone @@ -933,6 +946,8 @@ class Time(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.time]): class _Binary(TypeEngine[bytes]): """Define base behavior for binary types.""" + operator_classes = OperatorClass.BINARY + length: Optional[int] def __init__(self, length: Optional[int] = None): @@ -2023,6 +2038,8 @@ class Boolean(SchemaType, Emulated, TypeEngine[bool]): __visit_name__ = "boolean" native = True + operator_classes = OperatorClass.BOOLEAN + def __init__( self, create_constraint: bool = False, @@ -2143,6 +2160,8 @@ class Boolean(SchemaType, Emulated, TypeEngine[bool]): class _AbstractInterval(HasExpressionLookup, TypeEngine[dt.timedelta]): + operator_classes = OperatorClass.DATETIME + @util.memoized_property def _expression_adaptations(self): # Based on @@ -2480,6 +2499,8 @@ class JSON(Indexable, TypeEngine[Any]): __visit_name__ = "JSON" + operator_classes = OperatorClass.JSON + hashable = False NULL = util.symbol("JSON_NULL") """Describe the json value of NULL. @@ -2970,6 +2991,8 @@ class ARRAY( __visit_name__ = "ARRAY" + operator_classes = OperatorClass.ARRAY + _is_array = True zero_indexes = False @@ -3298,6 +3321,8 @@ class TupleType(TypeEngine[TupleAny]): _is_tuple_type = True + operator_classes = OperatorClass.TUPLE + types: List[TypeEngine[Any]] def __init__(self, *types: _TypeEngineArgument[Any]): @@ -3593,6 +3618,8 @@ class NullType(TypeEngine[None]): _isnull = True + operator_classes = OperatorClass.ANY + def literal_processor(self, dialect): return None @@ -3619,6 +3646,8 @@ class TableValueType(HasCacheKey, TypeEngine[Any]): _is_table_value = True + operator_classes = OperatorClass.BASE + _traverse_internals = [ ("_elements", InternalTraversal.dp_clauseelement_list), ] @@ -3699,6 +3728,8 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]): __visit_name__ = "uuid" + operator_classes = OperatorClass.BASE | OperatorClass.COMPARISON + length: Optional[int] = None collation: Optional[str] = None diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 8cfc72c88b..2e88542c98 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -14,6 +14,7 @@ import typing from typing import Any from typing import Callable from typing import cast +from typing import ClassVar from typing import Dict from typing import Generic from typing import Mapping @@ -33,7 +34,10 @@ from typing import Union from .base import SchemaEventTarget from .cache_key import CacheConst from .cache_key import NO_CACHE +from .operators import _OPERATOR_CLASSES from .operators import ColumnOperators +from .operators import custom_op +from .operators import OperatorClass from .visitors import Visitable from .. import exc from .. import util @@ -160,6 +164,17 @@ class TypeEngine(Visitable, Generic[_T]): """ + operator_classes: ClassVar[OperatorClass] = OperatorClass.UNSPECIFIED + """Indicate categories of operators that should be available on this type. + + .. versionadded:: 2.1 + + .. seealso:: + + :class:`.OperatorClass` + + """ + class Comparator( ColumnOperators, Generic[_CT], @@ -185,6 +200,56 @@ class TypeEngine(Visitable, Generic[_T]): def __reduce__(self) -> Any: return self.__class__, (self.expr,) + @util.preload_module("sqlalchemy.sql.default_comparator") + def _resolve_operator_lookup(self, op: OperatorType) -> Tuple[ + Callable[..., "ColumnElement[Any]"], + util.immutabledict[ + str, Union["OperatorType", Callable[..., "ColumnElement[Any]"]] + ], + ]: + default_comparator = util.preloaded.sql_default_comparator + + op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__] + + if op_fn is default_comparator._custom_op_operate: + if TYPE_CHECKING: + assert isinstance(op, custom_op) + operator_class = op.operator_class + else: + try: + operator_class = _OPERATOR_CLASSES[op] + except KeyError: + operator_class = OperatorClass.UNSPECIFIED + + if not operator_class & self.type.operator_classes: + + if self.type.operator_classes is OperatorClass.UNSPECIFIED: + util.warn_deprecated( + f"Type object {self.type.__class__} does not refer " + "to an OperatorClass in its operator_classes " + "attribute. This attribute will be required in a " + "future release.", + "2.1", + ) + else: + if isinstance(op, custom_op): + op_description = f"custom operator {op.opstring!r}" + else: + op_description = f"operator {op.__name__!r}" + + util.warn_deprecated( + f"Type object {self.type.__class__!r} does not " + "include " + f"{op_description} in its operator classes. " + "Using built-in operators (not including custom or " + "overridden operators) outside of " + "a type's stated operator classes is deprecated and " + "will raise InvalidRequestError in a future release", + "2.1", + ) + + return op_fn, addtl_kw + @overload def operate( self, @@ -199,22 +264,19 @@ class TypeEngine(Visitable, Generic[_T]): self, op: OperatorType, *other: Any, **kwargs: Any ) -> ColumnElement[_CT]: ... - @util.preload_module("sqlalchemy.sql.default_comparator") def operate( self, op: OperatorType, *other: Any, **kwargs: Any ) -> ColumnElement[Any]: - default_comparator = util.preloaded.sql_default_comparator - op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__] + op_fn, addtl_kw = self._resolve_operator_lookup(op) if kwargs: addtl_kw = addtl_kw.union(kwargs) return op_fn(self.expr, op, *other, **addtl_kw) - @util.preload_module("sqlalchemy.sql.default_comparator") def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any ) -> ColumnElement[_CT]: - default_comparator = util.preloaded.sql_default_comparator - op_fn, addtl_kw = default_comparator.operator_lookup[op.__name__] + op_fn, addtl_kw = self._resolve_operator_lookup(op) + if kwargs: addtl_kw = addtl_kw.union(kwargs) return op_fn(self.expr, op, other, reverse=True, **addtl_kw) @@ -1381,6 +1443,8 @@ class UserDefinedType( ensure_kwarg = "get_col_spec" + operator_classes = OperatorClass.ANY + def coerce_compared_value( self, op: Optional[OperatorType], value: Any ) -> TypeEngine[Any]: @@ -1719,6 +1783,12 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]): """ + if not TYPE_CHECKING: + + @property + def operator_classes(self) -> OperatorClass: + return self.impl_instance.operator_classes + class Comparator(TypeEngine.Comparator[_CT]): """A :class:`.TypeEngine.Comparator` that is specific to :class:`.TypeDecorator`. diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index c803bc9d91..88bd4aa6c5 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -10,6 +10,7 @@ from __future__ import annotations +from .sql.operators import OperatorClass as OperatorClass from .sql.sqltypes import _Binary as _Binary from .sql.sqltypes import ARRAY as ARRAY from .sql.sqltypes import BIGINT as BIGINT diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index d458449f09..b0f933364c 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -1575,7 +1575,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): class RegexpCommon(testing.AssertsCompiledSQL): def setup_test(self): self.table = table( - "mytable", column("myid", Integer), column("name", String) + "mytable", column("myid", String), column("name", String) ) def test_regexp_match(self): diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index 625547efb1..b9897b007f 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -1764,7 +1764,7 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): def setup_test(self): self.table = table( - "mytable", column("myid", Integer), column("name", String) + "mytable", column("myid", String), column("name", String) ) def test_regexp_match(self): diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 5be149cf6a..d1b753d54f 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -2677,29 +2677,21 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): @testing.combinations( ( lambda col: col["foo"] + " ", - "(x -> %(x_1)s) || %(param_1)s", "x[%(x_1)s] || %(param_1)s", ), ( lambda col: col["foo"] + " " + col["bar"], - "(x -> %(x_1)s) || %(param_1)s || (x -> %(x_2)s)", "x[%(x_1)s] || %(param_1)s || x[%(x_2)s]", ), - argnames="expr, json_expected, jsonb_expected", + argnames="expr, expected", ) - @testing.combinations((JSON(),), (JSONB(),), argnames="type_") - def test_eager_grouping_flag( - self, expr, json_expected, jsonb_expected, type_ - ): + def test_eager_grouping_flag(self, expr, expected): """test #10479""" - col = Column("x", type_) + col = Column("x", JSONB) expr = testing.resolve_lambda(expr, col=col) # Choose expected result based on type - expected = ( - jsonb_expected if isinstance(type_, JSONB) else json_expected - ) self.assert_compile(expr, expected) @testing.variation("pgversion", ["pg14", "pg13"]) @@ -4258,7 +4250,7 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): def setup_test(self): self.table = table( - "mytable", column("myid", Integer), column("name", String) + "mytable", column("myid", String), column("name", String) ) def test_regexp_match(self): diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index b78c35d359..42b537e8da 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -3399,8 +3399,8 @@ class TimestampTest( expr = column("bar", postgresql.INTERVAL) + column("foo", types.Date) eq_(expr.type._type_affinity, types.DateTime) - expr = column("bar", postgresql.INTERVAL) * column( - "foo", types.Numeric + expr = operators.null_op( + column("bar", postgresql.INTERVAL), column("foo", types.Numeric) ) eq_(expr.type._type_affinity, types.Interval) assert isinstance(expr.type, postgresql.INTERVAL) diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 17c0eb8d71..05c8ea250d 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -2885,7 +2885,7 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): def setup_test(self): self.table = table( - "mytable", column("myid", Integer), column("name", String) + "mytable", column("myid", String), column("name", String) ) @testing.only_on("sqlite >= 3.9") diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index ac4274dd67..97c81fd532 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -724,13 +724,13 @@ class PropertyValueTest(fixtures.TestBase, AssertsCompiledSQL): @hybrid.hybrid_property def value(self): - return self._value - 5 + return self._value + "18" if assignable: @value.setter def value(self, v): - self._value = v + 5 + self._value = v + "5" return A @@ -750,9 +750,9 @@ class PropertyValueTest(fixtures.TestBase, AssertsCompiledSQL): def test_set_get(self): A = self._fixture(True) - a1 = A(value=5) - eq_(a1.value, 5) - eq_(a1._value, 10) + a1 = A(value="5") + eq_(a1.value, "5518") + eq_(a1._value, "55") class PropertyOverrideTest(fixtures.TestBase, AssertsCompiledSQL): diff --git a/test/orm/dml/test_evaluator.py b/test/orm/dml/test_evaluator.py index 3fc82db694..c1f0b23dbd 100644 --- a/test/orm/dml/test_evaluator.py +++ b/test/orm/dml/test_evaluator.py @@ -15,6 +15,7 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import evaluator from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import relationship +from sqlalchemy.sql.operators import OperatorClass from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ @@ -48,13 +49,16 @@ def eval_eq(clause, testcases=None): class EvaluateTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): + class LiberalJson(JSON): + operator_classes = JSON.operator_classes | OperatorClass.MATH + Table( "users", metadata, Column("id", Integer, primary_key=True), Column("name", String(64)), Column("othername", String(64)), - Column("json", JSON), + Column("json", LiberalJson), ) @classmethod @@ -368,7 +372,7 @@ class EvaluateTest(fixtures.MappedTest): {"foo": "bar"}, evaluator.UnevaluatableError, r"Cannot evaluate math operator \"add\" for " - r"datatypes JSON, INTEGER", + r"datatypes LiberalJson\(\), INTEGER", ), ( lambda User: User.json + {"bar": "bat"}, @@ -376,7 +380,7 @@ class EvaluateTest(fixtures.MappedTest): {"foo": "bar"}, evaluator.UnevaluatableError, r"Cannot evaluate concatenate operator \"concat_op\" for " - r"datatypes JSON, JSON", + r"datatypes LiberalJson\(\), LiberalJson\(\)", ), ( lambda User: User.json - 12, @@ -384,7 +388,7 @@ class EvaluateTest(fixtures.MappedTest): {"foo": "bar"}, evaluator.UnevaluatableError, r"Cannot evaluate math operator \"sub\" for " - r"datatypes JSON, INTEGER", + r"datatypes LiberalJson\(\), INTEGER", ), ( lambda User: User.json - "foo", @@ -392,7 +396,7 @@ class EvaluateTest(fixtures.MappedTest): {"foo": "bar"}, evaluator.UnevaluatableError, r"Cannot evaluate math operator \"sub\" for " - r"datatypes JSON, VARCHAR", + r"datatypes LiberalJson\(\), VARCHAR", ), ) def test_math_op_type_exclusions( diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 42cd4aedd2..b201e4343c 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -1997,9 +1997,13 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): def test_collate(self): User = self.classes.User - self._test(collate(User.id, "utf8_bin"), "users.id COLLATE utf8_bin") + self._test( + collate(User.name, "utf8_bin"), "users.name COLLATE utf8_bin" + ) - self._test(User.id.collate("utf8_bin"), "users.id COLLATE utf8_bin") + self._test( + User.name.collate("utf8_bin"), "users.name COLLATE utf8_bin" + ) def test_selfref_between(self): User = self.classes.User diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index 9eb20dd4e5..c357b92dff 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -1523,7 +1523,7 @@ class LambdaElementTest( x = {"foo": "bar"} def mylambda(): - return tt.c.q + x + return tt.c.q._null_operate(x) expr = coercions.expect(roles.WhereHavingRole, mylambda) is_(expr._resolved.right.type._type_affinity, JSON) diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index fd1fc64117..7ce305de01 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -28,11 +28,13 @@ from sqlalchemy import SQLColumnExpression from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import type_coerce from sqlalchemy.dialects import mssql from sqlalchemy.dialects import mysql from sqlalchemy.dialects import oracle from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import sqlite +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.engine import default from sqlalchemy.schema import Column from sqlalchemy.schema import MetaData @@ -71,6 +73,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not from sqlalchemy.testing import mock +from sqlalchemy.testing import ne_ from sqlalchemy.testing import resolve_lambda from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.types import ARRAY @@ -81,6 +84,7 @@ from sqlalchemy.types import Indexable from sqlalchemy.types import JSON from sqlalchemy.types import MatchType from sqlalchemy.types import NullType +from sqlalchemy.types import OperatorClass from sqlalchemy.types import TypeDecorator from sqlalchemy.types import TypeEngine from sqlalchemy.types import UserDefinedType @@ -315,23 +319,23 @@ class DefaultColumnComparatorTest( def test_default_adapt(self): class TypeOne(TypeEngine): - pass + operator_classes = OperatorClass.ANY class TypeTwo(TypeEngine): - pass + operator_classes = OperatorClass.ANY expr = column("x", TypeOne()) - column("y", TypeTwo()) is_(expr.type._type_affinity, TypeOne) def test_concatenable_adapt(self): class TypeOne(Concatenable, TypeEngine): - pass + operator_classes = OperatorClass.ANY class TypeTwo(Concatenable, TypeEngine): - pass + operator_classes = OperatorClass.ANY class TypeThree(TypeEngine): - pass + operator_classes = OperatorClass.ANY expr = column("x", TypeOne()) - column("y", TypeTwo()) is_(expr.type._type_affinity, TypeOne) @@ -352,8 +356,7 @@ class DefaultColumnComparatorTest( def test_contains_override_raises(self): for col in [ Column("x", String), - Column("x", Integer), - Column("x", DateTime), + Column("x", ARRAY(Integer)), ]: assert_raises_message( NotImplementedError, @@ -776,7 +779,14 @@ class _CustomComparatorTests: def test_no_boolean_propagate(self): c1 = Column("foo", self._add_override_factory()) - self._assert_not_add_override(c1 == 56) + + class Nonsensical(Boolean): + operator_classes = OperatorClass.BOOLEAN | OperatorClass.NUMERIC + + expr = c1 == 56 + expr.type = Nonsensical() + self._assert_not_add_override(expr) + self._assert_not_and_override(c1 == 56) def _assert_and_override(self, expr): @@ -931,6 +941,70 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase): pass +class OperatorClassTest(fixtures.TestBase, testing.AssertsCompiledSQL): + """test operator classes introduced in #12736""" + + __dialect__ = "default" + + def test_no_class(self): + class MyType(TypeEngine): + pass + + with expect_deprecated( + r"Type object .*.MyType.* does not refer to an OperatorClass" + ): + column("q", MyType()) + 5 + + @testing.variation("json_type", ["plain", "with_variant"]) + def test_json_cant_contains(self, json_type): + """test the original case for #12736""" + + if json_type.plain: + type_ = JSON() + else: + type_ = JSON().with_variant(JSONB(), "postgresql") + + with expect_deprecated( + r"Type object .*.JSON.* does not include operator " + r"'contains_op' in its operator classes." + ): + self.assert_compile( + column("xyz", type_).contains("{'foo': 'bar'}"), + "xyz LIKE '%' || :xyz_1 || '%'", + ) + + def test_invalid_op(self): + with expect_deprecated( + r"Type object .*.Integer.* does not include " + "operator 'like_op' in its operator classes." + ): + expr = column("q", Integer).like("hi") + + self.assert_compile(expr, "q LIKE :q_1", checkparams={"q_1": "hi"}) + + def test_invalid_op_custom(self): + class MyType(Integer): + pass + + with expect_deprecated( + r"Type object .*.MyType.* does not include " + "operator 'like_op' in its operator classes." + ): + expr = column("q", MyType).like("hi") + + self.assert_compile(expr, "q LIKE :q_1", checkparams={"q_1": "hi"}) + + def test_add_in_classes(self): + class MyType(Integer): + operator_classes = ( + Integer.operator_classes | OperatorClass.STRING_MATCH + ) + + expr = column("q", MyType).like("hi") + + self.assert_compile(expr, "q LIKE :q_1", checkparams={"q_1": "hi"}) + + class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" @@ -964,6 +1038,35 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): col = Column("x", MyType()) assert not isinstance(col, collections_abc.Iterable) + @testing.combinations( + (operators.lshift, OperatorClass.BITWISE), + (operators.rshift, OperatorClass.BITWISE), + (operators.matmul, OperatorClass.MATH), + (operators.getitem, OperatorClass.INDEXABLE), + ) + def test_not_implemented_operators(self, op, operator_class): + """test operators that are availble but not implemented by default. + + this might be semantically different from the operator not being + present in the operator class though the effect is the same (that is, + we could just not include lshift/rshift/matmul in any operator class, + do away with _unsupported_impl() and the path to implement them would + be the same). So it's not totally clear if we should keep using + _unsupported_impl() long term. However at least for now because we + only emit a deprecation warning in the other case, this is still + appropriately a separate concept. + + """ + + class MyType(TypeEngine): + operator_classes = operator_class + + with expect_raises_message( + NotImplementedError, + f"Operator {op.__name__!r} is not supported on this expression", + ): + op(column("q", MyType()), "test") + def test_lshift(self): class MyType(UserDefinedType): cache_ok = True @@ -1057,7 +1160,7 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): class MyType(JSON): __visit_name__ = "mytype" - pass + operator_classes = OperatorClass.JSON | OperatorClass.MATH self.MyType = MyType self.__dialect__ = MyDialect() @@ -1180,7 +1283,7 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_cast_ops_unsupported_on_non_json_binary( self, caster, expected_type ): - expr = Column("x", JSON) + {"foo": "bar"} + expr = Column("x", self.MyType) + {"foo": "bar"} meth = getattr(expr, "as_%s" % caster) @@ -1320,6 +1423,8 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): class MyOtherType(Indexable, TypeEngine): __visit_name__ = "myothertype" + operator_classes = OperatorClass.ANY + class Comparator(TypeEngine.Comparator): def _adapt_expression(self, op, other_comparator): return special_index_op, MyOtherType() @@ -1752,7 +1857,9 @@ class OperatorPrecedenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_operator_precedence_5(self): self.assert_compile( - self.table2.select().where(5 + self.table2.c.field.in_([5, 6])), + self.table2.select().where( + 5 + type_coerce(self.table2.c.field.in_([5, 6]), Integer) + ), "SELECT op.field FROM op WHERE :param_1 + " "(op.field IN (__[POSTCOMPILE_field_1]))", ) @@ -1832,7 +1939,9 @@ class OperatorPrecedenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_operator_precedence_collate_2(self): self.assert_compile( - (self.table1.c.name == literal("foo")).collate("utf-8"), + type_coerce(self.table1.c.name == literal("foo"), String).collate( + "utf-8" + ), 'mytable.name = :param_1 COLLATE "utf-8"', ) @@ -1843,10 +1952,17 @@ class OperatorPrecedenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_operator_precedence_collate_4(self): + class Nonsensical(Boolean): + operator_classes = OperatorClass.BOOLEAN | OperatorClass.STRING + self.assert_compile( and_( - (self.table1.c.name == literal("foo")).collate("utf-8"), - (self.table2.c.field == literal("bar")).collate("utf-8"), + type_coerce( + self.table1.c.name == literal("foo"), Nonsensical + ).collate("utf-8"), + type_coerce( + self.table2.c.field == literal("bar"), Nonsensical + ).collate("utf-8"), ), 'mytable.name = :param_1 COLLATE "utf-8" ' 'AND op.field = :param_2 COLLATE "utf-8"', @@ -1880,8 +1996,12 @@ class OperatorPrecedenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) def test_commutative_operators(self): + class Nonsensical(String): + operator_classes = OperatorClass.STRING | OperatorClass.NUMERIC + self.assert_compile( - literal("a") + literal("b") * literal("c"), + literal("x", Nonsensical) + + literal("y", Nonsensical) * literal("q", Nonsensical), ":param_1 || :param_2 * :param_3", ) @@ -2705,12 +2825,12 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" - table1 = table("mytable", column("myid", Integer)) + table1 = table("mytable", column("myid", String)) def test_pickle_operators_one(self): clause = ( - (self.table1.c.myid == 12) - & self.table1.c.myid.between(15, 20) + (self.table1.c.myid == "12") + & self.table1.c.myid.between("15", "20") & self.table1.c.myid.like("hoho") ) eq_(str(clause), str(pickle.loads(pickle.dumps(clause)))) @@ -2841,7 +2961,8 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL): table1 = table("mytable", column("myid", Integer), column("name", String)) @testing.combinations( - (~literal(5), "NOT :param_1"), (~-literal(5), "NOT -:param_1") + (~literal(5, NullType), "NOT :param_1"), + (~-literal(5, NullType), "NOT -:param_1"), ) def test_nonsensical_negates(self, expr, expected): """exercise codepaths in the UnaryExpression._negate() method where the @@ -2853,7 +2974,7 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL): for py_op, op in ((operator.neg, "-"), (operator.inv, "NOT ")): for expr, expected in ( (self.table1.c.myid, "mytable.myid"), - (literal("foo"), ":param_1"), + (literal(5, Integer), ":param_1"), ): self.assert_compile(py_op(expr), "%s%s" % (op, expected)) @@ -2897,10 +3018,26 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_negate_operators_5(self): self.assert_compile( self.table1.select().where( - (self.table1.c.myid != 12) & ~self.table1.c.name + (self.table1.c.myid != "12") + & ~and_( + literal("somethingboolean", Boolean), literal("q", Boolean) + ) ), "SELECT mytable.myid, mytable.name FROM " - "mytable WHERE mytable.myid != :myid_1 AND NOT mytable.name", + "mytable WHERE mytable.myid != :myid_1 AND NOT " + "(:param_1 = 1 AND :param_2 = 1)", + ) + + def test_negate_operators_6(self): + self.assert_compile( + self.table1.select().where( + (self.table1.c.myid != "12") + & ~literal("somethingboolean", Boolean) + ), + "SELECT mytable.myid, mytable.name FROM " + "mytable WHERE mytable.myid != :myid_1 AND NOT :param_1", + supports_native_boolean=True, + use_default_dialect=True, ) def test_negate_operator_type(self): @@ -2980,7 +3117,7 @@ class NegationTest(fixtures.TestBase, testing.AssertsCompiledSQL): class LikeTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" - table1 = table("mytable", column("myid", Integer), column("name", String)) + table1 = table("mytable", column("myid", String), column("name", String)) def test_like_1(self): self.assert_compile( @@ -3108,7 +3245,7 @@ class BetweenTest(fixtures.TestBase, testing.AssertsCompiledSQL): class MatchTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" - table1 = table("mytable", column("myid", Integer), column("name", String)) + table1 = table("mytable", column("myid", String), column("name", String)) def test_match_1(self): self.assert_compile( @@ -3179,7 +3316,7 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): def setup_test(self): self.table = table( - "mytable", column("myid", Integer), column("name", String) + "mytable", column("myid", String), column("name", String) ) def test_regexp_match(self): @@ -3212,7 +3349,10 @@ class RegexpTestStrCompiler(fixtures.TestBase, testing.AssertsCompiledSQL): def setup_test(self): self.table = table( - "mytable", column("myid", Integer), column("name", String) + "mytable", + column("myid", String), + column("name", String), + column("myinteger", Integer), ) def test_regexp_match(self): @@ -3335,16 +3475,17 @@ class RegexpTestStrCompiler(fixtures.TestBase, testing.AssertsCompiledSQL): def test_regexp_precedence_2(self): self.assert_compile( - self.table.c.myid + self.table.c.myid.regexp_match("xx"), - "mytable.myid + (mytable.myid :myid_1)", + self.table.c.myinteger + self.table.c.myid.regexp_match("xx"), + "mytable.myinteger + (mytable.myid :myid_1)", ) self.assert_compile( - self.table.c.myid + ~self.table.c.myid.regexp_match("xx"), - "mytable.myid + (mytable.myid :myid_1)", + self.table.c.myinteger + ~self.table.c.myid.regexp_match("xx"), + "mytable.myinteger + (mytable.myid :myid_1)", ) self.assert_compile( - self.table.c.myid + self.table.c.myid.regexp_replace("xx", "yy"), - "mytable.myid + (" + self.table.c.myinteger + + self.table.c.myid.regexp_replace("xx", "yy"), + "mytable.myinteger + (" "(mytable.myid, :myid_1, :myid_2))", ) @@ -4441,6 +4582,76 @@ class CustomOpTest(fixtures.TestBase): ): op1(3, 5) + def test_operator_class_default(self): + """Test that custom_op defaults to OperatorClass.BASE""" + op = operators.custom_op("++") + eq_(op.operator_class, OperatorClass.BASE) + + def test_operator_class_explicit(self): + """Test that custom_op accepts an explicit operator_class parameter""" + op = operators.custom_op("++", operator_class=OperatorClass.MATH) + eq_(op.operator_class, OperatorClass.MATH) + + def test_operator_class_combined(self): + """Test that custom_op accepts combined operator classes""" + op = operators.custom_op( + "++", operator_class=OperatorClass.MATH | OperatorClass.BITWISE + ) + eq_(op.operator_class, OperatorClass.MATH | OperatorClass.BITWISE) + + def test_operator_class_with_column_op(self): + """Test that operator_class is passed through when using column.op()""" + c = column("x", Integer) + + expr1 = c.op("++")("value") + eq_(expr1.operator.operator_class, OperatorClass.BASE) + + expr2 = c.op("++", operator_class=OperatorClass.MATH)("value") + eq_(expr2.operator.operator_class, OperatorClass.MATH) + + with expect_deprecated( + r"Type object .*Integer.* does not include custom " + r"operator '\+\+' in its operator classes." + ): + expr3 = c.op("++", operator_class=OperatorClass.STRING_MATCH)( + "value" + ) + eq_(expr3.operator.operator_class, OperatorClass.STRING_MATCH) + + def test_operator_class_hash_and_equality(self): + op1 = operators.custom_op("++", operator_class=OperatorClass.MATH) + op2 = operators.custom_op("++", operator_class=OperatorClass.MATH) + op3 = operators.custom_op("++", operator_class=OperatorClass.BITWISE) + + # Same opstring and same operator_class should be equal + eq_(op1, op2) + eq_(hash(op1), hash(op2)) + + # Same opstring but different operator_class should be different + ne_(op1, op3) + ne_(hash(op1), hash(op3)) + + def test_operator_class_warning_unspecified_type(self): + """Test warning when type has UNSPECIFIED operator_classes""" + + # Create a custom type with UNSPECIFIED operator_classes + class UnspecifiedType(TypeEngine): + operator_classes = OperatorClass.UNSPECIFIED + + metadata = MetaData() + test_table = Table( + "test", metadata, Column("value", UnspecifiedType()) + ) + col = test_table.c.value + + # Use a builtin operator that should not be compatible + # This should trigger the first deprecation warning + with expect_deprecated( + "Type object .* does not refer to an OperatorClass in " + "its operator_classes attribute" + ): + col == "test" + class TupleTypingTest(fixtures.TestBase): def _assert_types(self, expr): @@ -4864,7 +5075,7 @@ class BitOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): """test for #12681""" if named.column: - expr = py_op(column("q", String)) + expr = py_op(column("q", Integer)) assert isinstance(expr, UnaryExpression) self.assert_compile( @@ -4873,7 +5084,7 @@ class BitOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) elif named.unnamed: - expr = py_op(literal("x", String)) + expr = py_op(literal("x", Integer)) assert isinstance(expr, UnaryExpression) self.assert_compile( @@ -4881,7 +5092,7 @@ class BitOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): f"SELECT {sql_op}:param_1 AS anon_1", ) elif named.label: - expr = py_op(literal("x", String).label("z")) + expr = py_op(literal("x", Integer).label("z")) if py_op is operators.inv: # special case for operators.inv due to Label._negate() # not sure if this should be changed but still works out in the diff --git a/test/sql/test_types.py b/test/sql/test_types.py index bd147a415a..9d3e0ae38c 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -3879,7 +3879,7 @@ class ExpressionTest( expr = column("bar", types.Interval) + column("foo", types.Date) eq_(expr.type._type_affinity, types.DateTime) - expr = column("bar", types.Interval) * column("foo", types.Numeric) + expr = column("bar", types.Interval) - column("foo", types.Numeric) eq_(expr.type._type_affinity, types.Interval) @testing.combinations(