From: Mike Bayer Date: Sat, 18 Mar 2023 15:43:47 +0000 (-0400) Subject: implement content hashing for custom_op, not identity X-Git-Tag: rel_2_0_7~2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0a0c7c73729152b7606509b6e750371106dfdd46;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement content hashing for custom_op, not identity Fixed critical SQL caching issue where use of the :meth:`_sql.Operators.op` custom operator function would not produce an appropriate cache key, leading to reduce the effectiveness of the SQL cache. Fixes: #9506 Change-Id: I3eab1ddb5e09a811ad717161a59df0884cdf70ed --- diff --git a/doc/build/changelog/unreleased_14/9506.rst b/doc/build/changelog/unreleased_14/9506.rst new file mode 100644 index 0000000000..2533a986b1 --- /dev/null +++ b/doc/build/changelog/unreleased_14/9506.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, sql + :tickets: 9506 + + Fixed critical SQL caching issue where use of the + :meth:`_sql.Operators.op` custom operator function would not produce an appropriate + cache key, leading to reduce the effectiveness of the SQL cache. + diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index c973126ca4..ab9ddf85c7 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -41,6 +41,7 @@ from typing import Dict from typing import Generic from typing import Optional from typing import Set +from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -52,6 +53,7 @@ from ..util.typing import Literal from ..util.typing import Protocol if typing.TYPE_CHECKING: + from .cache_key import CacheConst from .type_api import TypeEngine _T = TypeVar("_T", bound=Any) @@ -415,10 +417,24 @@ class custom_op(OperatorType, Generic[_T]): self.python_impl = python_impl def __eq__(self, other: Any) -> bool: - return isinstance(other, custom_op) and other.opstring == self.opstring + return ( + isinstance(other, custom_op) + and other._hash_key() == self._hash_key() + ) def __hash__(self) -> int: - return id(self) + return hash(self._hash_key()) + + def _hash_key(self) -> Union[CacheConst, Tuple[Any, ...]]: + return ( + self.__class__, + self.opstring, + self.precedence, + self.is_comparison, + self.natural_self_precedent, + self.eager_grouping, + self.return_type._static_cache_key if self.return_type else None, + ) def __call__( self, diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 4b55560ec9..96758a7ad4 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -767,7 +767,7 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): def visit_operator( self, attrname, left_parent, left, right_parent, right, **kw ): - return left is right + return left == right def visit_type( self, attrname, left_parent, left, right_parent, right, **kw diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 87710fdd9e..187d797297 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -13,6 +13,7 @@ from sqlalchemy import exists from sqlalchemy import extract from sqlalchemy import Float from sqlalchemy import Integer +from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import or_ @@ -203,6 +204,15 @@ class CoreFixtures: bindparam("bar", type_=String) ), ), + lambda: ( + literal(1).op("+")(literal(1)), + literal(1).op("-")(literal(1)), + column("q").op("-")(literal(1)), + UnaryExpression(table_a.c.b, modifier=operators.neg), + UnaryExpression(table_a.c.b, modifier=operators.desc_op), + UnaryExpression(table_a.c.b, modifier=operators.custom_op("!")), + UnaryExpression(table_a.c.b, modifier=operators.custom_op("~")), + ), lambda: ( column("q") == column("x"), column("q") == column("y"), @@ -210,6 +220,9 @@ class CoreFixtures: (column("z") == column("x")).self_group(), (column("q") == column("x")).self_group(), column("z") + column("x"), + column("z").op("foo")(column("x")), + column("z").op("foo")(literal(1)), + column("z").op("bar")(column("x")), column("z") - column("x"), column("x") - column("z"), column("z") > column("x"),