ensure anon_map is passed for most annotated traversals
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Thu, 10 Nov 2022 22:01:58 +0000 (17:01 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Fri, 11 Nov 2022 20:24:23 +0000 (15:24 -0500)
We can cache the annotated cache key for Table, but for selectables
it's not safe, as it fails to pass the anon_map along and creates many
redundant structures in the observed test scenario.  It is likely safe
for a Column that's mapped to a Table as well; however, this is not
implemented here.  Will have to see if that part needs adjusting.

Fixed critical memory issue identified in cache key generation, where for
very large and complex ORM statements that make use of lots of ORM aliases
with subqueries, cache key generation could produce excessively large keys
that were orders of magnitude bigger than the statement itself. Much thanks
to Rollo Konig Brock for their very patient, long term help in finally
identifying this issue.

Also, within TypeEngine objects, when we generate cache key elements
for instance variables, skip the None elements at least.  This also
saves on tuple complexity.
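A minimal sketch of that idea, using an illustrative type class rather than
the real type_api internals: when assembling the cache key tuple from a
type's instance variables, None-valued attributes are simply omitted.

class ExampleType:
    """Illustrative stand-in for a TypeEngine subclass with optional params."""

    def __init__(self, length=None, collation=None):
        self.length = length
        self.collation = collation

    def _cache_key_elements(self, names=("length", "collation")):
        # include only instance variables that are present, public, and not None
        return tuple(
            (k, self.__dict__[k])
            for k in names
            if k in self.__dict__
            and not k.startswith("_")
            and self.__dict__[k] is not None
        )


print(ExampleType(length=50)._cache_key_elements())  # (('length', 50),)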

Fixes: #8790
Change-Id: I448ddbfb45ae0a648815be8dad4faad7d1977427
(cherry picked from commit 88c240d907a9ae3b5caf766009edd196a30cece3)

12 files changed:
doc/build/changelog/unreleased_14/8790.rst [new file with mode: 0644]
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/util.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
test/orm/test_cache_key.py

diff --git a/doc/build/changelog/unreleased_14/8790.rst b/doc/build/changelog/unreleased_14/8790.rst
new file mode 100644
index 0000000..a321480
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8790
+
+    Fixed critical memory issue identified in cache key generation, where for
+    very large and complex ORM statements that make use of lots of ORM aliases
+    with subqueries, cache key generation could produce excessively large keys
+    that were orders of magnitude bigger than the statement itself. Much thanks
+    to Rollo Konig Brock for their very patient, long term help in finally
+    identifying this issue.
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index 5c000ed6c3fdc61d9fa9b756d3659f8d165a1ea3..01b5a53a6e372eb6ab9117bc7674ff4c35720228 100644
@@ -26,12 +26,16 @@ class SupportsAnnotations(object):
     @util.memoized_property
     def _annotations_cache_key(self):
         anon_map_ = anon_map()
+
+        return self._gen_annotations_cache_key(anon_map_)
+
+    def _gen_annotations_cache_key(self, anon_map):
         return (
             "_annotations",
             tuple(
                 (
                     key,
-                    value._gen_cache_key(anon_map_, [])
+                    value._gen_cache_key(anon_map, [])
                     if isinstance(value, HasCacheKey)
                     else value,
                 )
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index eb5bc5a0087fb35120d07c2782822299f00b50a7..72486e749ab558161489a8885e8b8f53b2e0e7c3 100644
@@ -203,7 +203,8 @@ class ClauseElement(
 
     is_clause_element = True
     is_selectable = False
-
+    _gen_static_annotations_cache_key = False
+    _is_table = False
     _is_textual = False
     _is_from_clause = False
     _is_returns_rows = False
@@ -3079,7 +3080,7 @@ class Cast(WrapsColumnExpression, ColumnElement):
 
     _traverse_internals = [
         ("clause", InternalTraversal.dp_clauseelement),
-        ("typeclause", InternalTraversal.dp_clauseelement),
+        ("type", InternalTraversal.dp_type),
     ]
 
     def __init__(self, expression, type_):
@@ -3880,7 +3881,20 @@ class BinaryExpression(ColumnElement):
         (
             "type",
             InternalTraversal.dp_type,
-        ),  # affects JSON CAST operators
+        ),
+    ]
+
+    _cache_key_traversal = [
+        ("left", InternalTraversal.dp_clauseelement),
+        ("right", InternalTraversal.dp_clauseelement),
+        ("operator", InternalTraversal.dp_operator),
+        ("modifiers", InternalTraversal.dp_plain_dict),
+        # "type" affects JSON CAST operators, so while redundant in most cases,
+        # is needed for that one
+        (
+            "type",
+            InternalTraversal.dp_type,
+        ),
     ]
 
     _is_implicitly_boolean = True
@@ -4016,6 +4030,10 @@ class Grouping(GroupedElement, ColumnElement):
         ("type", InternalTraversal.dp_type),
     ]
 
+    _cache_key_traversal = [
+        ("element", InternalTraversal.dp_clauseelement),
+    ]
+
     def __init__(self, element):
         self.element = element
         self.type = getattr(element, "type", type_api.NULLTYPE)
@@ -4516,6 +4534,11 @@ class Label(roles.LabeledColumnExprRole, ColumnElement):
         ("_element", InternalTraversal.dp_clauseelement),
     ]
 
+    _cache_key_traversal = [
+        ("name", InternalTraversal.dp_anon_name),
+        ("_element", InternalTraversal.dp_clauseelement),
+    ]
+
     def __init__(self, name, element, type_=None):
         """Return a :class:`Label` object for the
         given :class:`_expression.ColumnElement`.
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index dde665cbde73becf46cf00503992c838c671f289..8198a829839a26706c2a0e37e36456e1b7d76f9b 100644
@@ -544,6 +544,8 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
         ("schema", InternalTraversal.dp_string)
     ]
 
+    _is_table = True
+
     def _gen_cache_key(self, anon_map, bindparams):
         if self._annotations:
             return (self,) + self._annotations_cache_key
@@ -1810,6 +1812,17 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
 
     """
 
+    @util.memoized_property
+    def _gen_static_annotations_cache_key(self):
+        """special attribute used by cache key gen, if true, we will
+        use a static cache key for the annotations dictionary, else we
+        will generate a new cache key for annotations each time.
+
+        Added for #8790
+
+        """
+        return self.table is not None and self.table._is_table
+
     def _extra_kwargs(self, **kwargs):
         self._validate_dialect_kwargs(kwargs)
 
diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py
index 9da61ab28cb69d75a15a710e054b4eaee4784ff9..21aa17a0a640d1fbb6ff8e23e406141c6dd44228 100644
@@ -246,12 +246,16 @@ class HasCacheKey(object):
                             else None,
                         )
                     elif meth is InternalTraversal.dp_annotations_key:
-                        # obj is here is the _annotations dict.   however, we
-                        # want to use the memoized cache key version of it. for
-                        # Columns, this should be long lived.   For select()
-                        # statements, not so much, but they usually won't have
-                        # annotations.
-                        result += self._annotations_cache_key
+                        # obj is here is the _annotations dict.  Table uses
+                        # a memoized version of it.  however in other cases,
+                        # we generate it given anon_map as we may be from a
+                        # Join, Aliased, etc.
+                        # see #8790
+
+                        if self._gen_static_annotations_cache_key:  # type: ignore  # noqa: E501
+                            result += self._annotations_cache_key  # type: ignore  # noqa: E501
+                        else:
+                            result += self._gen_annotations_cache_key(anon_map)  # type: ignore  # noqa: E501
                     elif (
                         meth is InternalTraversal.dp_clauseelement_list
                         or meth is InternalTraversal.dp_clauseelement_tuple
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 29dc74971c847684d3bc33350c9e1dbb753f9756..30fc4189bba184a40ace5e3e60d984f6d9cdee4e 100644
@@ -745,7 +745,9 @@ class TypeEngine(Traversible):
                 else self.__dict__[k],
             )
             for k in names
-            if k in self.__dict__ and not k.startswith("_")
+            if k in self.__dict__
+            and not k.startswith("_")
+            and self.__dict__[k] is not None
         )
 
     def adapt(self, cls, **kw):
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
index 80d344faf1eff483b2e8f48c384587eba58cc10c..73b43f04bd43e6e2fcc716edf62a93ade3d7a90a 100644
@@ -28,6 +28,7 @@ from .assertions import expect_raises
 from .assertions import expect_raises_message
 from .assertions import expect_warnings
 from .assertions import in_
+from .assertions import int_within_variance
 from .assertions import is_
 from .assertions import is_false
 from .assertions import is_instance_of
@@ -48,6 +49,7 @@ from .config import combinations_list
 from .config import db
 from .config import fixture
 from .config import requirements as requires
+from .config import skip_test
 from .exclusions import _is_excluded
 from .exclusions import _server_version
 from .exclusions import against as _against
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index 9a3c06b029039b9f01946758256a4972bccb7f4e..ba6ee14c3b58ad41dfe68fd33714a0e65f0e0e81 100644
@@ -243,6 +243,17 @@ def _assert_no_stray_pool_connections():
     engines.testing_reaper.assert_all_closed()
 
 
+def int_within_variance(expected, received, variance):
+    deviance = int(expected * variance)
+    assert (
+        abs(received - expected) < deviance
+    ), "Given int value %s is not within %d%% of expected value %s" % (
+        received,
+        variance * 100,
+        expected,
+    )
+
+
 def eq_regex(a, b, msg=None):
     assert re.match(b, a), msg or "%r !~ %r" % (a, b)
 
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
index be89bc6e4488512038ce591e05501962df981af1..9baf1014b0efdfdd4b932d2d8708a70042fe0032 100644
@@ -5,10 +5,13 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
+from collections import deque
 import decimal
 import gc
+from itertools import chain
 import random
 import sys
+from sys import getsizeof
 import types
 
 from . import config
@@ -456,3 +459,63 @@ def teardown_events(event_cls):
             event_cls._clear()
 
     return decorate
+
+
+def total_size(o):
+    """Returns the approximate memory footprint an object and all of its
+    contents.
+
+    source: https://code.activestate.com/recipes/577504/
+
+
+    """
+
+    def dict_handler(d):
+        return chain.from_iterable(d.items())
+
+    all_handlers = {
+        tuple: iter,
+        list: iter,
+        deque: iter,
+        dict: dict_handler,
+        set: iter,
+        frozenset: iter,
+    }
+    seen = set()  # track which object id's have already been seen
+    default_size = getsizeof(0)  # estimate sizeof object without __sizeof__
+
+    def sizeof(o):
+        if id(o) in seen:  # do not double count the same object
+            return 0
+        seen.add(id(o))
+        s = getsizeof(o, default_size)
+
+        for typ, handler in all_handlers.items():
+            if isinstance(o, typ):
+                s += sum(map(sizeof, handler(o)))
+                break
+        return s
+
+    return sizeof(o)
+
+
+def count_cache_key_tuples(tup):
+    """given a cache key tuple, counts how many instances of actual
+    tuples are found.
+
+    used to alert large jumps in cache key complexity.
+
+    """
+    stack = [tup]
+
+    sentinel = object()
+    num_elements = 0
+
+    while stack:
+        elem = stack.pop(0)
+        if elem is sentinel:
+            num_elements += 1
+        elif isinstance(elem, tuple):
+            if elem:
+                stack = list(elem) + [sentinel] + stack
+    return num_elements
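A small usage sketch for the two helpers above, with a hand-built nested
tuple standing in for a generated cache key:

key = ("select", ("col", ("tbl", "a")), ("col", ("tbl", "b")))

# counts every non-empty tuple reachable from the key:
# the outer tuple plus the four nested ones
print(count_cache_key_tuples(key))  # 5

# approximate recursive memory footprint in bytes (platform dependent)
print(total_size(key))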
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 33427e3b504faae34d7b775712a12969f2302734..d6ce649034414a57b31132fb3660e818ea394233 100644
@@ -77,7 +77,9 @@ from .compat import perf_counter
 from .compat import pickle
 from .compat import print_
 from .compat import py2k
+from .compat import py310
 from .compat import py311
+from .compat import py312
 from .compat import py37
 from .compat import py38
 from .compat import py39
diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py
index 21a9491f8e671d52377240cb3b02a824f8fd3d93..2c2a1a77ae838b0389d2610f48f73e33f39ac81b 100644
@@ -14,7 +14,9 @@ import operator
 import platform
 import sys
 
+py312 = sys.version_info >= (3, 12)
 py311 = sys.version_info >= (3, 11)
+py310 = sys.version_info >= (3, 10)
 py39 = sys.version_info >= (3, 9)
 py38 = sys.version_info >= (3, 8)
 py37 = sys.version_info >= (3, 7)
diff --git a/test/orm/test_cache_key.py b/test/orm/test_cache_key.py
index 23fec61d2a09da9c3b5d37ab7e97f6c3bc411d15..169df909ec512e1e441bf4349c3615fa4fdf279f 100644
@@ -8,11 +8,14 @@ from sqlalchemy import Integer
 from sqlalchemy import literal_column
 from sqlalchemy import null
 from sqlalchemy import select
+from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import true
 from sqlalchemy import update
+from sqlalchemy import util
+from sqlalchemy.ext.declarative import ConcreteBase
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import Bundle
 from sqlalchemy.orm import defaultload
@@ -37,10 +40,15 @@ from sqlalchemy.sql.visitors import InternalTraversal
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import int_within_variance
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing.fixtures import DeclarativeMappedTest
 from sqlalchemy.testing.fixtures import fixture_session
+from sqlalchemy.testing.util import count_cache_key_tuples
+from sqlalchemy.testing.util import total_size
 from test.orm import _fixtures
 from .inheritance import _poly_fixtures
+from .test_events import _RemoveListeners
 from .test_query import QueryTest
 
 
@@ -1032,3 +1040,80 @@ class CompositeTest(fixtures.MappedTest):
         )
 
         eq_(stmt._generate_cache_key(), stmt2._generate_cache_key())
+
+
+class EmbeddedSubqTest(_RemoveListeners, DeclarativeMappedTest):
+    """test #8790.
+
+    it's expected that cache key structures will change, this test is here
+    testing something fairly similar to the issue we had (though vastly
+    smaller scale) so we mostly want to look for surprise jumps here.
+
+    """
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Employee(ConcreteBase, Base):
+            __tablename__ = "employee"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+
+            __mapper_args__ = {
+                "polymorphic_identity": "employee",
+                "concrete": True,
+            }
+
+        class Manager(Employee):
+            __tablename__ = "manager"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+            manager_data = Column(String(40))
+
+            __mapper_args__ = {
+                "polymorphic_identity": "manager",
+                "concrete": True,
+            }
+
+        class Engineer(Employee):
+            __tablename__ = "engineer"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+            engineer_info = Column(String(40))
+
+            __mapper_args__ = {
+                "polymorphic_identity": "engineer",
+                "concrete": True,
+            }
+
+    @testing.combinations("tuples", "memory", argnames="assert_on")
+    def test_cache_key_gen(self, assert_on):
+        Employee = self.classes.Employee
+
+        e1 = aliased(Employee)
+        e2 = aliased(Employee)
+
+        subq = select(e1).union_all(select(e2)).subquery()
+
+        anno = aliased(Employee, subq)
+
+        stmt = select(anno)
+
+        ck = stmt._generate_cache_key()
+
+        if assert_on == "tuples":
+            # before the fix for #8790 this was 700
+            int_within_variance(142, count_cache_key_tuples(ck), 0.05)
+
+        elif assert_on == "memory":
+            # before the fix for #8790 this was 55154
+
+            if util.py312:
+                testing.skip_test("python platform not available")
+            elif util.py311:
+                int_within_variance(39996, total_size(ck), 0.05)
+            elif util.py310:
+                int_within_variance(29796, total_size(ck), 0.05)
+            else:
+                testing.skip_test("python platform not available")