]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix with_expression() cache leak; don't adapt singletons
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Apr 2021 22:53:25 +0000 (18:53 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Apr 2021 23:41:02 +0000 (19:41 -0400)
Fixed a cache leak involving the :func:`_orm.with_expression` loader
option, where the given SQL expression would not be correctly considered as
part of the cache key.

Additionally, fixed regression involving the corresponding
:func:`_orm.query_expression` feature. While the bug technically exists in
1.3 as well, it was not exposed until 1.4. The "default expr" value of
``null()`` would be rendered when not needed, and additionally was also not
adapted correctly when the ORM rewrites statements such as when using
joined eager loading. The fix ensures "singleton" expressions like ``NULL``
and ``true`` aren't "adapted" to refer to columns in ORM statements, and
additionally ensures that a :func:`_orm.query_expression` with no default
expression doesn't render in the statement if a
:func:`_orm.with_expression` isn't used.

Fixes: #6259
Change-Id: I5a70bc12dadad125bbc4324b64048c8d4a18916c

doc/build/changelog/unreleased_14/6259.rst [new file with mode: 0644]
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/util.py
test/orm/test_cache_key.py
test/orm/test_core_compilation.py
test/orm/test_deferred.py
test/sql/test_external_traversal.py

diff --git a/doc/build/changelog/unreleased_14/6259.rst b/doc/build/changelog/unreleased_14/6259.rst
new file mode 100644 (file)
index 0000000..d827a74
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, regression, orm
+    :tickets: 6259
+
+    Fixed a cache leak involving the :func:`_orm.with_expression` loader
+    option, where the given SQL expression would not be correctly considered as
+    part of the cache key.
+
+    Additionally, fixed regression involving the corresponding
+    :func:`_orm.query_expression` feature. While the bug technically exists in
+    1.3 as well, it was not exposed until 1.4. The "default expr" value of
+    ``null()`` would be rendered when not needed, and additionally was also not
+    adapted correctly when the ORM rewrites statements such as when using
+    joined eager loading. The fix ensures "singleton" expressions like ``NULL``
+    and ``true`` aren't "adapted" to refer to columns in ORM statements, and
+    additionally ensures that a :func:`_orm.query_expression` with no default
+    expression doesn't render in the statement if a
+    :func:`_orm.with_expression` isn't used.
index efaf77a4f1b343b19ad44a0000ab44d2e4b144fd..4936049d4f9704c41290200eee3f9e39872145fb 100644 (file)
@@ -264,7 +264,10 @@ class ExpressionColumnLoader(ColumnLoader):
     def __init__(self, parent, strategy_key):
         super(ExpressionColumnLoader, self).__init__(parent, strategy_key)
 
-        null = sql.null()
+        # compare to the "default" expression that is mapped in
+        # the column.   If it's sql.null, we don't need to render
+        # unless an expr is passed in the options.
+        null = sql.null().label(None)
         self._have_default_expression = any(
             not c.compare(null) for c in self.parent_property.columns
         )
index 4827b375243ce0fd5fa45a4e2c7621bfbd5cf986..2cab0d0f0fa0000627d05b9421e14a0c2e5b427b 100644 (file)
@@ -90,7 +90,10 @@ class Load(Generative, LoaderOption):
             "_context_cache_key",
             visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples,
         ),
-        ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
+        (
+            "local_opts",
+            visitors.ExtendedInternalTraversal.dp_string_multi_dict,
+        ),
     ]
 
     def __init__(self, entity):
@@ -601,7 +604,10 @@ class _UnboundLoad(Load):
         ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
         ("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list),
         ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
-        ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
+        (
+            "local_opts",
+            visitors.ExtendedInternalTraversal.dp_string_multi_dict,
+        ),
     ]
 
     _is_chain_link = False
index 81685dfe0c10dfc66d2ce668e07c325ae5f397f5..d9f05e823bed56f7e583e310eb3f20284c97839c 100644 (file)
@@ -54,6 +54,8 @@ class Immutable(object):
 
 
 class SingletonConstant(Immutable):
+    """Represent SQL constants like NULL, TRUE, FALSE"""
+
     def __new__(cls, *arg, **kw):
         return cls._singleton
 
@@ -63,6 +65,13 @@ class SingletonConstant(Immutable):
         obj.__init__()
         cls._singleton = obj
 
+    # don't proxy singletons.   this means that a SingletonConstant
+    # will never be a "corresponding column" in a statement; the constant
+    # can be named directly and as it is often/usually compared against using
+    # "IS", it can't be adapted to a subquery column in any case.
+    # see :ticket:`6259`.
+    proxy_set = frozenset()
+
 
 def _from_objects(*elements):
     return itertools.chain.from_iterable(
index 4dec30a80cc3167d5f83defe9dc30953b42c0566..85b20a5682189c0ee9a922f2470cc9e4c74d19ca 100644 (file)
@@ -829,6 +829,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
     def _corresponding_column(
         self, col, require_embedded, _seen=util.EMPTY_SET
     ):
+
         newcol = self.selectable.corresponding_column(
             col, require_embedded=require_embedded
         )
index cd06ce56a12ba3d8e272550e8a7075bcb5703ff0..d120b05c05dd7f209e314d00fb87ad19c7302be0 100644 (file)
@@ -1,9 +1,12 @@
 import random
 
+from sqlalchemy import func
 from sqlalchemy import inspect
+from sqlalchemy import null
 from sqlalchemy import select
 from sqlalchemy import testing
 from sqlalchemy import text
+from sqlalchemy import true
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import defaultload
 from sqlalchemy.orm import defer
@@ -16,6 +19,7 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import subqueryload
+from sqlalchemy.orm import with_expression
 from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.sql.base import CacheableOptions
@@ -72,6 +76,23 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
             compare_values=True,
         )
 
+    def test_query_expr(self):
+        (User,) = self.classes("User")
+
+        self._run_cache_key_fixture(
+            lambda: (
+                with_expression(User.name, true()),
+                with_expression(User.name, null()),
+                with_expression(User.name, func.foobar()),
+                with_expression(User.name, User.name == "test"),
+                Load(User).with_expression(User.name, true()),
+                Load(User).with_expression(User.name, null()),
+                Load(User).with_expression(User.name, func.foobar()),
+                Load(User).with_expression(User.name, User.name == "test"),
+            ),
+            compare_values=True,
+        )
+
     def test_loader_criteria(self):
         User, Address = self.classes("User", "Address")
 
index a53b15bcbf6b6723f85484d7ffac12da7f1ab264..cbc069a85c99ff77505dacff0e443dafeff1bd2d 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import exc
 from sqlalchemy import func
 from sqlalchemy import insert
 from sqlalchemy import literal_column
+from sqlalchemy import null
 from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy import testing
@@ -17,6 +18,7 @@ from sqlalchemy.orm import query_expression
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import with_expression
 from sqlalchemy.orm import with_polymorphic
+from sqlalchemy.sql import and_
 from sqlalchemy.sql import sqltypes
 from sqlalchemy.sql.selectable import Join as core_join
 from sqlalchemy.sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
@@ -405,12 +407,50 @@ class ExtraColsTest(QueryTest, AssertsCompiledSQL):
             self.tables.users,
             self.classes.User,
         )
+        addresses, Address = (self.tables.addresses, self.classes.Address)
 
         mapper(
             User,
             users,
-            properties=util.OrderedDict([("value", query_expression())]),
+            properties=util.OrderedDict(
+                [
+                    ("value", query_expression()),
+                ]
+            ),
         )
+        mapper(Address, addresses)
+
+        return User
+
+    @testing.fixture
+    def query_expression_w_joinedload_fixture(self):
+        users, User = (
+            self.tables.users,
+            self.classes.User,
+        )
+        addresses, Address = (self.tables.addresses, self.classes.Address)
+
+        mapper(
+            User,
+            users,
+            properties=util.OrderedDict(
+                [
+                    ("value", query_expression()),
+                    (
+                        "addresses",
+                        relationship(
+                            Address,
+                            primaryjoin=and_(
+                                addresses.c.user_id == users.c.id,
+                                addresses.c.email_address != None,
+                            ),
+                        ),
+                    ),
+                ]
+            ),
+        )
+        mapper(Address, addresses)
+
         return User
 
     @testing.fixture
@@ -528,6 +568,49 @@ class ExtraColsTest(QueryTest, AssertsCompiledSQL):
             "users.name || :name_1 AS foo FROM users) AS anon_1",
         )
 
+    def test_with_expr_three(self, query_expression_w_joinedload_fixture):
+        """test :ticket:`6259`"""
+        User = query_expression_w_joinedload_fixture
+
+        stmt = select(User).options(joinedload(User.addresses)).limit(1)
+
+        # test that the outer IS NULL is rendered
+        # test that the inner query does not include a NULL default
+        self.assert_compile(
+            stmt,
+            "SELECT anon_1.id, anon_1.name, addresses_1.id AS id_1, "
+            "addresses_1.user_id, addresses_1.email_address FROM "
+            "(SELECT users.id AS id, users.name AS name FROM users "
+            "LIMIT :param_1) AS anon_1 LEFT OUTER "
+            "JOIN addresses AS addresses_1 ON addresses_1.user_id = anon_1.id "
+            "AND addresses_1.email_address IS NOT NULL",
+        )
+
+    def test_with_expr_four(self, query_expression_w_joinedload_fixture):
+        """test :ticket:`6259`"""
+        User = query_expression_w_joinedload_fixture
+
+        stmt = (
+            select(User)
+            .options(
+                with_expression(User.value, null()), joinedload(User.addresses)
+            )
+            .limit(1)
+        )
+
+        # test that the outer IS NULL is rendered, not adapted
+        # test that the inner query includes the NULL we asked for
+        self.assert_compile(
+            stmt,
+            "SELECT anon_2.anon_1, anon_2.id, anon_2.name, "
+            "addresses_1.id AS id_1, addresses_1.user_id, "
+            "addresses_1.email_address FROM (SELECT NULL AS anon_1, "
+            "users.id AS id, users.name AS name FROM users LIMIT :param_1) "
+            "AS anon_2 LEFT OUTER JOIN addresses AS addresses_1 "
+            "ON addresses_1.user_id = anon_2.id "
+            "AND addresses_1.email_address IS NOT NULL",
+        )
+
     def test_joinedload_outermost(self, plain_fixture):
         User, Address = plain_fixture
 
index d528cb9355457e6b3a29b09fd55d5f9c7bb3b2cc..decb456c61a4765f63d29d4889ef8375aea129a7 100644 (file)
@@ -2,6 +2,7 @@ import sqlalchemy as sa
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import null
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
@@ -1824,7 +1825,7 @@ class WithExpressionTest(fixtures.DeclarativeMappedTest):
         A = self.classes.A
 
         s = fixture_session()
-        a1 = s.query(A).first()
+        a1 = s.query(A).options(with_expression(A.my_expr, null())).first()
 
         def go():
             eq_(a1.my_expr, None)
index e7c6cccca570b75c0790d604b2e09823c05c4032..e1490adfd395bf22a2ae04703a65ca6b6e49a83d 100644 (file)
@@ -12,11 +12,13 @@ from sqlalchemy import Integer
 from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import MetaData
+from sqlalchemy import null
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import text
+from sqlalchemy import true
 from sqlalchemy import tuple_
 from sqlalchemy import union
 from sqlalchemy.sql import ClauseElement
@@ -402,6 +404,75 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
             select(f), "SELECT t1_1.col1 * :col1_1 AS anon_1 FROM t1 AS t1_1"
         )
 
+    @testing.combinations((null(),), (true(),))
+    def test_dont_adapt_singleton_elements(self, elem):
+        """test :ticket:`6259`"""
+        t1 = table("t1", column("c1"))
+
+        stmt = select(t1.c.c1, elem)
+
+        wherecond = t1.c.c1.is_(elem)
+
+        subq = stmt.subquery()
+
+        adapted_wherecond = sql_util.ClauseAdapter(subq).traverse(wherecond)
+        stmt = select(subq).where(adapted_wherecond)
+
+        self.assert_compile(
+            stmt,
+            "SELECT anon_1.c1, anon_1.anon_2 FROM (SELECT t1.c1 AS c1, "
+            "%s AS anon_2 FROM t1) AS anon_1 WHERE anon_1.c1 IS %s"
+            % (str(elem), str(elem)),
+            dialect="default_enhanced",
+        )
+
+    def test_adapt_funcs_etc_on_identity_one(self):
+        """Adapting to a function etc. will adapt if its on identity"""
+        t1 = table("t1", column("c1"))
+
+        elem = func.foobar()
+
+        stmt = select(t1.c.c1, elem)
+
+        wherecond = t1.c.c1 == elem
+
+        subq = stmt.subquery()
+
+        adapted_wherecond = sql_util.ClauseAdapter(subq).traverse(wherecond)
+        stmt = select(subq).where(adapted_wherecond)
+
+        self.assert_compile(
+            stmt,
+            "SELECT anon_1.c1, anon_1.foobar_1 FROM (SELECT t1.c1 AS c1, "
+            "foobar() AS foobar_1 FROM t1) AS anon_1 "
+            "WHERE anon_1.c1 = anon_1.foobar_1",
+            dialect="default_enhanced",
+        )
+
+    def test_adapt_funcs_etc_on_identity_two(self):
+        """Adapting to a function etc. will not adapt if they are different"""
+        t1 = table("t1", column("c1"))
+
+        elem = func.foobar()
+        elem2 = func.foobar()
+
+        stmt = select(t1.c.c1, elem)
+
+        wherecond = t1.c.c1 == elem2
+
+        subq = stmt.subquery()
+
+        adapted_wherecond = sql_util.ClauseAdapter(subq).traverse(wherecond)
+        stmt = select(subq).where(adapted_wherecond)
+
+        self.assert_compile(
+            stmt,
+            "SELECT anon_1.c1, anon_1.foobar_1 FROM (SELECT t1.c1 AS c1, "
+            "foobar() AS foobar_1 FROM t1) AS anon_1 "
+            "WHERE anon_1.c1 = foobar()",
+            dialect="default_enhanced",
+        )
+
     def test_join(self):
         clause = t1.join(t2, t1.c.col2 == t2.c.col2)
         c1 = str(clause)