]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure anon_map is passed for most annotated traversals
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 10 Nov 2022 22:01:58 +0000 (17:01 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 11 Nov 2022 20:25:35 +0000 (15:25 -0500)
We can cache the annotated cache key for Table, but
for selectables it's not safe, as it fails to pass the
anon_map along and creates many redudant structures in
observed test scenario.  It is likely safe for a
Column that's mapped to a Table also, however this is
not implemented here.   Will have to see if that part
needs adjusting.

Fixed critical memory issue identified in cache key generation, where for
very large and complex ORM statements that make use of lots of ORM aliases
with subqueries, cache key generation could produce excessively large keys
that were orders of magnitude bigger than the statement itself. Much thanks
to Rollo Konig Brock for their very patient, long term help in finally
identifying this issue.

Also within TypeEngine objects, when we generate elements
for instance variables, skip the None elements at least.
this also saves on tuple complexity.

Fixes: #8790
Change-Id: I448ddbfb45ae0a648815be8dad4faad7d1977427

12 files changed:
doc/build/changelog/unreleased_14/8790.rst [new file with mode: 0644]
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/sql/cache_key.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/util.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
test/orm/test_cache_key.py

diff --git a/doc/build/changelog/unreleased_14/8790.rst b/doc/build/changelog/unreleased_14/8790.rst
new file mode 100644 (file)
index 0000000..a321480
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8790
+
+    Fixed critical memory issue identified in cache key generation, where for
+    very large and complex ORM statements that make use of lots of ORM aliases
+    with subqueries, cache key generation could produce excessively large keys
+    that were orders of magnitude bigger than the statement itself. Much thanks
+    to Rollo Konig Brock for their very patient, long term help in finally
+    identifying this issue.
index 262048bd1d5ccf623247a8f6a5ef0e575d10e21c..43ca84abb3d0c2bf3ff351415c7e2dc892eedbe3 100644 (file)
@@ -94,12 +94,18 @@ class SupportsAnnotations(ExternallyTraversible):
     @util.memoized_property
     def _annotations_cache_key(self) -> Tuple[Any, ...]:
         anon_map_ = anon_map()
+
+        return self._gen_annotations_cache_key(anon_map_)
+
+    def _gen_annotations_cache_key(
+        self, anon_map: anon_map
+    ) -> Tuple[Any, ...]:
         return (
             "_annotations",
             tuple(
                 (
                     key,
-                    value._gen_cache_key(anon_map_, [])
+                    value._gen_cache_key(anon_map, [])
                     if isinstance(value, HasCacheKey)
                     else value,
                 )
index 88148285c58df055352592bf8f8d059fe1d90c72..39d09d3ab53b66d95ff3ce3a088ebc720b37215a 100644 (file)
@@ -297,12 +297,17 @@ class HasCacheKey:
                             else None,
                         )
                     elif meth is InternalTraversal.dp_annotations_key:
-                        # obj is here is the _annotations dict.   however, we
-                        # want to use the memoized cache key version of it. for
-                        # Columns, this should be long lived.   For select()
-                        # statements, not so much, but they usually won't have
-                        # annotations.
-                        result += self._annotations_cache_key  # type: ignore
+                        # obj is here is the _annotations dict.  Table uses
+                        # a memoized version of it.  however in other cases,
+                        # we generate it given anon_map as we may be from a
+                        # Join, Aliased, etc.
+                        # see #8790
+
+                        if self._gen_static_annotations_cache_key:  # type: ignore  # noqa: E501
+                            result += self._annotations_cache_key  # type: ignore  # noqa: E501
+                        else:
+                            result += self._gen_annotations_cache_key(anon_map)  # type: ignore  # noqa: E501
+
                     elif (
                         meth is InternalTraversal.dp_clauseelement_list
                         or meth is InternalTraversal.dp_clauseelement_tuple
index 3b70e8d4e99491a74113f16e971b0662e319478a..6a5aa7db928d37fe76b0666d9e45ad4a2a4833db 100644 (file)
@@ -334,6 +334,7 @@ class ClauseElement(
     _is_column_element = False
     _is_keyed_column_element = False
     _is_table = False
+    _gen_static_annotations_cache_key = False
     _is_textual = False
     _is_from_clause = False
     _is_returns_rows = False
@@ -3224,7 +3225,7 @@ class Cast(WrapsColumnExpression[_T]):
 
     _traverse_internals: _TraverseInternalsType = [
         ("clause", InternalTraversal.dp_clauseelement),
-        ("typeclause", InternalTraversal.dp_clauseelement),
+        ("type", InternalTraversal.dp_type),
     ]
 
     clause: ColumnElement[Any]
@@ -3631,7 +3632,20 @@ class BinaryExpression(OperatorExpression[_T]):
         (
             "type",
             InternalTraversal.dp_type,
-        ),  # affects JSON CAST operators
+        ),
+    ]
+
+    _cache_key_traversal = [
+        ("left", InternalTraversal.dp_clauseelement),
+        ("right", InternalTraversal.dp_clauseelement),
+        ("operator", InternalTraversal.dp_operator),
+        ("modifiers", InternalTraversal.dp_plain_dict),
+        # "type" affects JSON CAST operators, so while redundant in most cases,
+        # is needed for that one
+        (
+            "type",
+            InternalTraversal.dp_type,
+        ),
     ]
 
     _is_implicitly_boolean = True
@@ -3816,6 +3830,10 @@ class Grouping(GroupedElement, ColumnElement[_T]):
         ("type", InternalTraversal.dp_type),
     ]
 
+    _cache_key_traversal = [
+        ("element", InternalTraversal.dp_clauseelement),
+    ]
+
     element: Union[TextClause, ClauseList, ColumnElement[_T]]
 
     def __init__(
@@ -4322,6 +4340,11 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
         ("_element", InternalTraversal.dp_clauseelement),
     ]
 
+    _cache_key_traversal = [
+        ("name", InternalTraversal.dp_anon_name),
+        ("_element", InternalTraversal.dp_clauseelement),
+    ]
+
     _element: ColumnElement[_T]
     name: str
 
index 36c33868af40c72b95c52785be6da2fd19191869..dde5cd37299573fc5b57fb33c6f51fcab738e062 100644 (file)
@@ -2023,6 +2023,17 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
 
     identity: Optional[Identity]
 
+    @util.memoized_property
+    def _gen_static_annotations_cache_key(self) -> bool:  # type: ignore
+        """special attribute used by cache key gen, if true, we will
+        use a static cache key for the annotations dictionary, else we
+        will generate a new cache key for annotations each time.
+
+        Added for #8790
+
+        """
+        return self.table is not None and self.table._is_table
+
     def _extra_kwargs(self, **kwargs: Any) -> None:
         self._validate_dialect_kwargs(kwargs)
 
index 90320701ea6d64b42e23455a07600032a0385efc..cd57ee3b6436ad3957c41a2c0cdd1a30499baef3 100644 (file)
@@ -940,7 +940,9 @@ class TypeEngine(Visitable, Generic[_T]):
                 else self.__dict__[k],
             )
             for k in names
-            if k in self.__dict__ and not k.startswith("_")
+            if k in self.__dict__
+            and not k.startswith("_")
+            and self.__dict__[k] is not None
         )
 
     @overload
index 0c83cb469afe63c1af0e887facf1b4879e8414aa..76445a44425dec17171efc1de6767539b82cdb1c 100644 (file)
@@ -31,6 +31,7 @@ from .assertions import expect_raises
 from .assertions import expect_raises_message
 from .assertions import expect_warnings
 from .assertions import in_
+from .assertions import int_within_variance
 from .assertions import is_
 from .assertions import is_false
 from .assertions import is_instance_of
index 44e7e892f8c66bb1eca4a9def1396c7a39353087..321c05b4465821d38f4a3f93456e6fb51dcb5a6e 100644 (file)
@@ -236,6 +236,17 @@ def _assert_no_stray_pool_connections():
     engines.testing_reaper.assert_all_closed()
 
 
+def int_within_variance(expected, received, variance):
+    deviance = int(expected * variance)
+    assert (
+        abs(received - expected) < deviance
+    ), "Given int value %s is not within %d%% of expected value %s" % (
+        received,
+        variance * 100,
+        expected,
+    )
+
+
 def eq_regex(a, b, msg=None):
     assert re.match(b, a), msg or "%r !~ %r" % (a, b)
 
index 6fd42af702c1a70a13965812ed8a88611e078f9d..74b1ca99289f327cade5f68388305ce16fd43299 100644 (file)
@@ -9,10 +9,13 @@
 
 from __future__ import annotations
 
+from collections import deque
 import decimal
 import gc
+from itertools import chain
 import random
 import sys
+from sys import getsizeof
 import types
 
 from . import config
@@ -459,3 +462,63 @@ def teardown_events(event_cls):
             event_cls._clear()
 
     return decorate
+
+
+def total_size(o):
+    """Returns the approximate memory footprint an object and all of its
+    contents.
+
+    source: https://code.activestate.com/recipes/577504/
+
+
+    """
+
+    def dict_handler(d):
+        return chain.from_iterable(d.items())
+
+    all_handlers = {
+        tuple: iter,
+        list: iter,
+        deque: iter,
+        dict: dict_handler,
+        set: iter,
+        frozenset: iter,
+    }
+    seen = set()  # track which object id's have already been seen
+    default_size = getsizeof(0)  # estimate sizeof object without __sizeof__
+
+    def sizeof(o):
+        if id(o) in seen:  # do not double count the same object
+            return 0
+        seen.add(id(o))
+        s = getsizeof(o, default_size)
+
+        for typ, handler in all_handlers.items():
+            if isinstance(o, typ):
+                s += sum(map(sizeof, handler(o)))
+                break
+        return s
+
+    return sizeof(o)
+
+
+def count_cache_key_tuples(tup):
+    """given a cache key tuple, counts how many instances of actual
+    tuples are found.
+
+    used to alert large jumps in cache key complexity.
+
+    """
+    stack = [tup]
+
+    sentinel = object()
+    num_elements = 0
+
+    while stack:
+        elem = stack.pop(0)
+        if elem is sentinel:
+            num_elements += 1
+        elif isinstance(elem, tuple):
+            if elem:
+                stack = list(elem) + [sentinel] + stack
+    return num_elements
index 4952cb5011e96c35ab99c2b926f56030980856e5..bb4642a4ff66239ce524d9e97ccbccb032effe10 100644 (file)
@@ -62,6 +62,7 @@ from .compat import local_dataclass_fields as local_dataclass_fields
 from .compat import osx as osx
 from .compat import py310 as py310
 from .compat import py311 as py311
+from .compat import py312 as py312
 from .compat import py38 as py38
 from .compat import py39 as py39
 from .compat import pypy as pypy
index cda5ab6c12551b85c2a7f820f5de474092168b22..2899b425867341c57a40f4105cd2dd8d084cafe7 100644 (file)
@@ -30,6 +30,7 @@ from typing import Tuple
 from typing import Type
 
 
+py312 = sys.version_info >= (3, 12)
 py311 = sys.version_info >= (3, 11)
 py310 = sys.version_info >= (3, 10)
 py39 = sys.version_info >= (3, 9)
index 70770089c3faf918db018bf2d00ebbd11a7b0296..3106a71ad76d1039c65318fd89da35593fd7e9d2 100644 (file)
@@ -8,11 +8,14 @@ from sqlalchemy import Integer
 from sqlalchemy import literal_column
 from sqlalchemy import null
 from sqlalchemy import select
+from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import true
 from sqlalchemy import update
+from sqlalchemy import util
+from sqlalchemy.ext.declarative import ConcreteBase
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import Bundle
 from sqlalchemy.orm import defaultload
@@ -37,10 +40,15 @@ from sqlalchemy.sql.visitors import InternalTraversal
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import int_within_variance
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing.fixtures import DeclarativeMappedTest
 from sqlalchemy.testing.fixtures import fixture_session
+from sqlalchemy.testing.util import count_cache_key_tuples
+from sqlalchemy.testing.util import total_size
 from test.orm import _fixtures
 from .inheritance import _poly_fixtures
+from .test_events import _RemoveListeners
 from .test_query import QueryTest
 
 
@@ -1037,3 +1045,80 @@ class CompositeTest(fixtures.MappedTest):
         )
 
         eq_(stmt._generate_cache_key(), stmt2._generate_cache_key())
+
+
+class EmbeddedSubqTest(_RemoveListeners, DeclarativeMappedTest):
+    """test #8790.
+
+    it's expected that cache key structures will change, this test is here
+    testing something fairly similar to the issue we had (though vastly
+    smaller scale) so we mostly want to look for surprise jumps here.
+
+    """
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Employee(ConcreteBase, Base):
+            __tablename__ = "employee"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+
+            __mapper_args__ = {
+                "polymorphic_identity": "employee",
+                "concrete": True,
+            }
+
+        class Manager(Employee):
+            __tablename__ = "manager"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+            manager_data = Column(String(40))
+
+            __mapper_args__ = {
+                "polymorphic_identity": "manager",
+                "concrete": True,
+            }
+
+        class Engineer(Employee):
+            __tablename__ = "engineer"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+            engineer_info = Column(String(40))
+
+            __mapper_args__ = {
+                "polymorphic_identity": "engineer",
+                "concrete": True,
+            }
+
+    @testing.combinations("tuples", "memory", argnames="assert_on")
+    def test_cache_key_gen(self, assert_on):
+        Employee = self.classes.Employee
+
+        e1 = aliased(Employee)
+        e2 = aliased(Employee)
+
+        subq = select(e1).union_all(select(e2)).subquery()
+
+        anno = aliased(Employee, subq)
+
+        stmt = select(anno)
+
+        ck = stmt._generate_cache_key()
+
+        if assert_on == "tuples":
+            # before the fix for #8790 this was 700
+            int_within_variance(142, count_cache_key_tuples(ck), 0.05)
+
+        elif assert_on == "memory":
+            # before the fix for #8790 this was 55154
+
+            if util.py312:
+                testing.skip_test("python platform not available")
+            elif util.py311:
+                int_within_variance(39996, total_size(ck), 0.05)
+            elif util.py310:
+                int_within_variance(29796, total_size(ck), 0.05)
+            else:
+                testing.skip_test("python platform not available")