]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
do not allow non-cache-key entity objects in annotations
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Apr 2023 17:56:33 +0000 (13:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Apr 2023 18:53:08 +0000 (14:53 -0400)
Fixed critical caching issue where combination of :func:`_orm.aliased()`
:func:`_sql.case` and :func:`_hybrid.hybrid_property` expressions would
cause a cache key mismatch, leading to cache keys that held onto the actual
:func:`_orm.aliased` object while also not matching each other, filling up
the cache.

Fixes: #9728
Change-Id: I700645b5629a81a0104cf923db72a7421fa43ff4

doc/build/changelog/unreleased_14/9728.rst [new file with mode: 0644]
lib/sqlalchemy/orm/attributes.py
test/orm/test_cache_key.py

diff --git a/doc/build/changelog/unreleased_14/9728.rst b/doc/build/changelog/unreleased_14/9728.rst
new file mode 100644 (file)
index 0000000..a8bced3
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9728
+    :versions: 2.0.12
+
+    Fixed critical caching issue where combination of :func:`_orm.aliased()`
+    :func:`_sql.case` and :func:`_hybrid.hybrid_property` expressions would
+    cause a cache key mismatch, leading to cache keys that held onto the actual
+    :func:`_orm.aliased` object while also not matching each other, filling up
+    the cache.
index 69ddd33889b0f071eccbc691391ffe47e74ddeaf..6a979219caa9899c9ea483d4549ee49bc69bfd31 100644 (file)
@@ -84,6 +84,9 @@ from ..sql import cache_key
 from ..sql import coercions
 from ..sql import roles
 from ..sql import visitors
+from ..sql.cache_key import HasCacheKey
+from ..sql.visitors import _TraverseInternalsType
+from ..sql.visitors import InternalTraversal
 from ..util.typing import Literal
 from ..util.typing import Self
 from ..util.typing import TypeGuard
@@ -326,13 +329,16 @@ class QueryableAttribute(
         # non-string keys.
         # ideally Proxy() would have a separate set of methods to deal
         # with this case.
+        entity_namespace = self._entity_namespace
+        assert isinstance(entity_namespace, HasCacheKey)
+
         if self.key is _UNKNOWN_ATTR_KEY:  # type: ignore[comparison-overlap]
-            annotations = {"entity_namespace": self._entity_namespace}
+            annotations = {"entity_namespace": entity_namespace}
         else:
             annotations = {
                 "proxy_key": self.key,
                 "proxy_owner": self._parententity,
-                "entity_namespace": self._entity_namespace,
+                "entity_namespace": entity_namespace,
             }
 
         ce = self.comparator.__clause_element__()
@@ -558,13 +564,21 @@ class InstrumentedAttribute(QueryableAttribute[_T]):
 
 
 @dataclasses.dataclass(frozen=True)
-class AdHocHasEntityNamespace:
+class AdHocHasEntityNamespace(HasCacheKey):
+    _traverse_internals: ClassVar[_TraverseInternalsType] = [
+        ("_entity_namespace", InternalTraversal.dp_has_cache_key),
+    ]
+
     # py37 compat, no slots=True on dataclass
-    __slots__ = ("entity_namespace",)
-    entity_namespace: _ExternalEntityType[Any]
+    __slots__ = ("_entity_namespace",)
+    _entity_namespace: _InternalEntityType[Any]
     is_mapper: ClassVar[bool] = False
     is_aliased_class: ClassVar[bool] = False
 
+    @property
+    def entity_namespace(self):
+        return self._entity_namespace.entity_namespace
+
 
 def create_proxied_attribute(
     descriptor: Any,
@@ -638,7 +652,7 @@ def create_proxied_attribute(
             else:
                 # used by hybrid attributes which try to remain
                 # agnostic of any ORM concepts like mappers
-                return AdHocHasEntityNamespace(self.class_)
+                return AdHocHasEntityNamespace(self._parententity)
 
         @property
         def property(self):
index f1c3e6a54e0d5a1b6173b628350d795b5293a4eb..884baed62ba1bc6239678d785fd799de355eb98d 100644 (file)
@@ -2,6 +2,7 @@ import random
 
 import sqlalchemy as sa
 from sqlalchemy import Column
+from sqlalchemy import column
 from sqlalchemy import func
 from sqlalchemy import inspect
 from sqlalchemy import Integer
@@ -16,6 +17,7 @@ from sqlalchemy import true
 from sqlalchemy import update
 from sqlalchemy import util
 from sqlalchemy.ext.declarative import ConcreteBase
+from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import Bundle
 from sqlalchemy.orm import defaultload
@@ -785,6 +787,55 @@ class PolyCacheKeyTest(fixtures.CacheKeyFixture, _poly_fixtures._Polymorphic):
             compare_values=True,
         )
 
+    @testing.variation(
+        "exprtype", ["plain_column", "self_standing_case", "case_w_columns"]
+    )
+    def test_hybrid_w_case_ac(self, decl_base, exprtype):
+        """test #9728"""
+
+        class Employees(decl_base):
+            __tablename__ = "employees"
+            id = Column(String(128), primary_key=True)
+            first_name = Column(String(length=64))
+
+            @hybrid_property
+            def name(self):
+                return self.first_name
+
+            @name.expression
+            def name(
+                cls,
+            ):
+                if exprtype.plain_column:
+                    return cls.first_name
+                elif exprtype.self_standing_case:
+                    return case(
+                        (column("x") == 1, column("q")),
+                        else_=column("q"),
+                    )
+                elif exprtype.case_w_columns:
+                    return case(
+                        (column("x") == 1, column("q")),
+                        else_=cls.first_name,
+                    )
+                else:
+                    exprtype.fail()
+
+        def go1():
+            employees_2 = aliased(Employees, name="employees_2")
+            stmt = select(employees_2.name)
+            return stmt
+
+        def go2():
+            employees_2 = aliased(Employees, name="employees_2")
+            stmt = select(employees_2)
+            return stmt
+
+        self._run_cache_key_fixture(
+            lambda: stmt_20(go1(), go2()),
+            compare_values=True,
+        )
+
 
 class RoundTripTest(QueryTest, AssertsCompiledSQL):
     __dialect__ = "default"