]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement content hashing for custom_op, not identity
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 18 Mar 2023 15:43:47 +0000 (11:43 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 18 Mar 2023 15:43:47 +0000 (11:43 -0400)
Fixed critical SQL caching issue where use of the :meth:`_sql.Operators.op`
custom operator function would not produce an appropriate cache key,
leading to reduce the effectiveness of the SQL cache.

Fixes: #9506
Change-Id: I3eab1ddb5e09a811ad717161a59df0884cdf70ed

doc/build/changelog/unreleased_14/9506.rst [new file with mode: 0644]
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/sql/traversals.py
test/sql/test_compare.py

diff --git a/doc/build/changelog/unreleased_14/9506.rst b/doc/build/changelog/unreleased_14/9506.rst
new file mode 100644 (file)
index 0000000..2533a98
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 9506
+
+    Fixed critical SQL caching issue where use of the
+    :meth:`_sql.Operators.op` custom operator function would not produce an appropriate
+    cache key, leading to reduce the effectiveness of the SQL cache.
+
index c973126ca48aab0f6569fd8bf01a9c18897aa2d4..ab9ddf85c797f1066695313e6f0e7269aaf51c29 100644 (file)
@@ -41,6 +41,7 @@ from typing import Dict
 from typing import Generic
 from typing import Optional
 from typing import Set
+from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
@@ -52,6 +53,7 @@ from ..util.typing import Literal
 from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
+    from .cache_key import CacheConst
     from .type_api import TypeEngine
 
 _T = TypeVar("_T", bound=Any)
@@ -415,10 +417,24 @@ class custom_op(OperatorType, Generic[_T]):
         self.python_impl = python_impl
 
     def __eq__(self, other: Any) -> bool:
-        return isinstance(other, custom_op) and other.opstring == self.opstring
+        return (
+            isinstance(other, custom_op)
+            and other._hash_key() == self._hash_key()
+        )
 
     def __hash__(self) -> int:
-        return id(self)
+        return hash(self._hash_key())
+
+    def _hash_key(self) -> Union[CacheConst, Tuple[Any, ...]]:
+        return (
+            self.__class__,
+            self.opstring,
+            self.precedence,
+            self.is_comparison,
+            self.natural_self_precedent,
+            self.eager_grouping,
+            self.return_type._static_cache_key if self.return_type else None,
+        )
 
     def __call__(
         self,
index 4b55560ec92145c837a4c0d348c9734c08717fc5..96758a7ad463df92036b60c230de0af347346c11 100644 (file)
@@ -767,7 +767,7 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
     def visit_operator(
         self, attrname, left_parent, left, right_parent, right, **kw
     ):
-        return left is right
+        return left == right
 
     def visit_type(
         self, attrname, left_parent, left, right_parent, right, **kw
index 87710fdd9ed40d984ee40b1acb9700504d60ad94..187d797297346c1f65232ba082c88659d4cf8203 100644 (file)
@@ -13,6 +13,7 @@ from sqlalchemy import exists
 from sqlalchemy import extract
 from sqlalchemy import Float
 from sqlalchemy import Integer
+from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import MetaData
 from sqlalchemy import or_
@@ -203,6 +204,15 @@ class CoreFixtures:
                 bindparam("bar", type_=String)
             ),
         ),
+        lambda: (
+            literal(1).op("+")(literal(1)),
+            literal(1).op("-")(literal(1)),
+            column("q").op("-")(literal(1)),
+            UnaryExpression(table_a.c.b, modifier=operators.neg),
+            UnaryExpression(table_a.c.b, modifier=operators.desc_op),
+            UnaryExpression(table_a.c.b, modifier=operators.custom_op("!")),
+            UnaryExpression(table_a.c.b, modifier=operators.custom_op("~")),
+        ),
         lambda: (
             column("q") == column("x"),
             column("q") == column("y"),
@@ -210,6 +220,9 @@ class CoreFixtures:
             (column("z") == column("x")).self_group(),
             (column("q") == column("x")).self_group(),
             column("z") + column("x"),
+            column("z").op("foo")(column("x")),
+            column("z").op("foo")(literal(1)),
+            column("z").op("bar")(column("x")),
             column("z") - column("x"),
             column("x") - column("z"),
             column("z") > column("x"),