]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
do not allow non-cache-key entity objects in annotations
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Apr 2023 17:56:33 +0000 (13:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 30 Apr 2023 18:53:31 +0000 (14:53 -0400)
Fixed critical caching issue where combination of :func:`_orm.aliased()`
:func:`_sql.case` and :func:`_hybrid.hybrid_property` expressions would
cause a cache key mismatch, leading to cache keys that held onto the actual
:func:`_orm.aliased` object while also not matching each other, filling up
the cache.

Fixes: #9728
Change-Id: I700645b5629a81a0104cf923db72a7421fa43ff4
(cherry picked from commit 4d69d83530666f9aaf3fb327d8c63110ef5e7ff5)

doc/build/changelog/unreleased_14/9728.rst [new file with mode: 0644]
lib/sqlalchemy/orm/attributes.py
test/orm/test_cache_key.py

diff --git a/doc/build/changelog/unreleased_14/9728.rst b/doc/build/changelog/unreleased_14/9728.rst
new file mode 100644 (file)
index 0000000..a8bced3
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9728
+    :versions: 2.0.12
+
+    Fixed critical caching issue where combination of :func:`_orm.aliased()`
+    :func:`_sql.case` and :func:`_hybrid.hybrid_property` expressions would
+    cause a cache key mismatch, leading to cache keys that held onto the actual
+    :func:`_orm.aliased` object while also not matching each other, filling up
+    the cache.
index c6be3e6d0cf953ec64aa52ff330c8e2bee8e1b78..2e82851a23fedcbe808d3d69070b126b75558e02 100644 (file)
@@ -54,6 +54,8 @@ from ..sql import base as sql_base
 from ..sql import roles
 from ..sql import traversals
 from ..sql import visitors
+from ..sql.traversals import HasCacheKey
+from ..sql.visitors import InternalTraversal
 
 
 class NoKey(str):
@@ -223,13 +225,16 @@ class QueryableAttribute(
         subclass representing a column expression.
 
         """
+        entity_namespace = self._entity_namespace
+        assert isinstance(entity_namespace, HasCacheKey)
+
         if self.key is NO_KEY:
-            annotations = {"entity_namespace": self._entity_namespace}
+            annotations = {"entity_namespace": entity_namespace}
         else:
             annotations = {
                 "proxy_key": self.key,
                 "proxy_owner": self._parententity,
-                "entity_namespace": self._entity_namespace,
+                "entity_namespace": entity_namespace,
             }
 
         ce = self.comparator.__clause_element__()
@@ -482,10 +487,22 @@ class InstrumentedAttribute(Mapped):
             return self.impl.get(state, dict_)
 
 
-HasEntityNamespace = util.namedtuple(
-    "HasEntityNamespace", ["entity_namespace"]
-)
-HasEntityNamespace.is_mapper = HasEntityNamespace.is_aliased_class = False
+class HasEntityNamespace(HasCacheKey):
+    __slots__ = ("_entity_namespace",)
+
+    is_mapper = False
+    is_aliased_class = False
+
+    _traverse_internals = [
+        ("_entity_namespace", InternalTraversal.dp_has_cache_key),
+    ]
+
+    def __init__(self, ent):
+        self._entity_namespace = ent
+
+    @property
+    def entity_namespace(self):
+        return self._entity_namespace.entity_namespace
 
 
 def create_proxied_attribute(descriptor):
@@ -550,7 +567,7 @@ def create_proxied_attribute(descriptor):
             else:
                 # used by hybrid attributes which try to remain
                 # agnostic of any ORM concepts like mappers
-                return HasEntityNamespace(self.class_)
+                return HasEntityNamespace(self._parententity)
 
         @property
         def property(self):
index 6720baf024fce4db41a67a04fdf797dea81f04e1..93d980e00a5534afa2b4dcad75c6657b6562be2c 100644 (file)
@@ -2,6 +2,7 @@ import random
 
 import sqlalchemy as sa
 from sqlalchemy import Column
+from sqlalchemy import column
 from sqlalchemy import func
 from sqlalchemy import inspect
 from sqlalchemy import Integer
@@ -16,6 +17,7 @@ from sqlalchemy import true
 from sqlalchemy import update
 from sqlalchemy import util
 from sqlalchemy.ext.declarative import ConcreteBase
+from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import Bundle
 from sqlalchemy.orm import defaultload
@@ -785,6 +787,55 @@ class PolyCacheKeyTest(fixtures.CacheKeyFixture, _poly_fixtures._Polymorphic):
             compare_values=True,
         )
 
+    @testing.variation(
+        "exprtype", ["plain_column", "self_standing_case", "case_w_columns"]
+    )
+    def test_hybrid_w_case_ac(self, decl_base, exprtype):
+        """test #9728"""
+
+        class Employees(decl_base):
+            __tablename__ = "employees"
+            id = Column(String(128), primary_key=True)
+            first_name = Column(String(length=64))
+
+            @hybrid_property
+            def name(self):
+                return self.first_name
+
+            @name.expression
+            def name(
+                cls,
+            ):
+                if exprtype.plain_column:
+                    return cls.first_name
+                elif exprtype.self_standing_case:
+                    return case(
+                        (column("x") == 1, column("q")),
+                        else_=column("q"),
+                    )
+                elif exprtype.case_w_columns:
+                    return case(
+                        (column("x") == 1, column("q")),
+                        else_=cls.first_name,
+                    )
+                else:
+                    exprtype.fail()
+
+        def go1():
+            employees_2 = aliased(Employees, name="employees_2")
+            stmt = select(employees_2.name)
+            return stmt
+
+        def go2():
+            employees_2 = aliased(Employees, name="employees_2")
+            stmt = select(employees_2)
+            return stmt
+
+        self._run_cache_key_fixture(
+            lambda: stmt_20(go1(), go2()),
+            compare_values=True,
+        )
+
 
 class RoundTripTest(QueryTest, AssertsCompiledSQL):
     __dialect__ = "default"