From 43d94273a5b13a89226e60de4b958d5b4ac7ff78 Mon Sep 17 00:00:00 2001 From: Masterchen09 <13187726+Masterchen09@users.noreply.github.com> Date: Sun, 4 Aug 2024 14:08:56 +0200 Subject: [PATCH] handle quoted_name instances separately in engine.reflection.cache (Fixes: #11687) --- lib/sqlalchemy/engine/reflection.py | 13 ++- test/engine/test_reflection.py | 162 ++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 2 deletions(-) diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 02a757379a..58e3aa390f 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -56,6 +56,7 @@ from .. import util from ..sql import operators from ..sql import schema as sa_schema from ..sql.cache_key import _ad_hoc_cache_key_from_args +from ..sql.elements import quoted_name from ..sql.elements import TextClause from ..sql.type_api import TypeEngine from ..sql.visitors import InternalTraversal @@ -89,8 +90,16 @@ def cache( exclude = {"info_cache", "unreflectable"} key = ( fn.__name__, - tuple(a for a in args if isinstance(a, str)), - tuple((k, v) for k, v in kw.items() if k not in exclude), + tuple( + (str(a), a.quote) if isinstance(a, quoted_name) else a + for a in args + if isinstance(a, str) + ), + tuple( + (k, (str(v), v.quote) if isinstance(v, quoted_name) else v) + for k, v in kw.items() + if k not in exclude + ), ) ret: _R = info_cache.get(key) if ret is None: diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 003b457a51..adb4037065 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -1,3 +1,4 @@ +import itertools import unicodedata import sqlalchemy as sa @@ -19,6 +20,8 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import UniqueConstraint from sqlalchemy.engine import Inspector +from sqlalchemy.engine.reflection import cache +from sqlalchemy.sql.elements import quoted_name from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -2494,3 +2497,162 @@ class IncludeColsFksTest(AssertsCompiledSQL, fixtures.TestBase): "SELECT b_1.x, b_1.q, b_1.p, b_1.r, b_1.s, b_1.t " "FROM b AS b_1 JOIN a ON a.x = b_1.r", ) + + +class ReflectionCacheTest(fixtures.TestBase): + @testing.fixture(params=["arg", "kwarg"]) + def cache(self, connection, request): + dialect = connection.dialect + info_cache = {} + counter = itertools.count(1) + + @cache + def get_cached_name(self, connection, *args, **kw): + return next(counter) + + def get_cached_name_via_arg(name): + return get_cached_name( + dialect, connection, name, info_cache=info_cache + ) + + def get_cached_name_via_kwarg(name): + return get_cached_name( + dialect, connection, name=name, info_cache=info_cache + ) + + if request.param == "arg": + yield get_cached_name_via_arg + elif request.param == "kwarg": + yield get_cached_name_via_kwarg + else: + assert False + + @testing.fixture(params=[False, True]) + def quote(self, request): + yield request.param + + def test_single_string(self, cache): + # new value + eq_(cache("name1"), 1) + + # same value, counter not incremented + eq_(cache("name1"), 1) + + def test_multiple_string(self, cache): + # new value + eq_(cache("name1"), 1) + eq_(cache("name2"), 2) + + # same values, counter not incremented + eq_(cache("name1"), 1) + eq_(cache("name2"), 2) + + def test_single_quoted_name(self, cache, quote): + # new value + eq_(cache(quoted_name("name1", quote=quote)), 1) + + # same value, counter not incremented + eq_(cache(quoted_name("name1", quote=quote)), 1) + + def test_multiple_quoted_name(self, cache, quote): + # new value + eq_(cache(quoted_name("name1", quote=quote)), 1) + eq_(cache(quoted_name("name2", quote=quote)), 2) + + # same values, counter not incremented + eq_(cache(quoted_name("name1", quote=quote)), 1) + eq_(cache(quoted_name("name2", quote=quote)), 2) + + def test_single_quoted_name_and_string(self, cache, quote): + # new values + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache("n1"), 2) + + # same values, counter not incremented + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache("n1"), 2) + + def test_multiple_quoted_name_and_string(self, cache, quote): + # new values + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n2", quote=quote)), 2) + eq_(cache("n1"), 3) + eq_(cache("n2"), 4) + + # same values, counter not incremented + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n2", quote=quote)), 2) + eq_(cache("n1"), 3) + eq_(cache("n2"), 4) + + def test_single_quoted_name_false_true_and_string(self, cache, quote): + # new values + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n1", quote=not quote)), 2) + eq_(cache("n1"), 3) + + # same values, counter not incremented + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n1", quote=not quote)), 2) + eq_(cache("n1"), 3) + + def test_multiple_quoted_name_false_true_and_string(self, cache, quote): + # new values + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n2", quote=quote)), 2) + eq_(cache(quoted_name("n1", quote=not quote)), 3) + eq_(cache(quoted_name("n2", quote=not quote)), 4) + eq_(cache("n1"), 5) + eq_(cache("n2"), 6) + + # same values, counter not incremented + eq_(cache(quoted_name("n1", quote=quote)), 1) + eq_(cache(quoted_name("n2", quote=quote)), 2) + eq_(cache(quoted_name("n1", quote=not quote)), 3) + eq_(cache(quoted_name("n2", quote=not quote)), 4) + eq_(cache("n1"), 5) + eq_(cache("n2"), 6) + + def test_multiple_quoted_name_false_true_and_string_arg_and_kwarg( + self, connection, quote + ): + dialect = connection.dialect + info_cache = {} + counter = itertools.count(1) + + @cache + def get_cached_name(self, connection, *args, **kw): + return next(counter) + + def cache_(*args, **kw): + return get_cached_name( + dialect, connection, *args, **kw, info_cache=info_cache + ) + + # new values + eq_(cache_(quoted_name("n1", quote=quote)), 1) + eq_(cache_(name=quoted_name("n1", quote=quote)), 2) + eq_(cache_(quoted_name("n2", quote=quote)), 3) + eq_(cache_(name=quoted_name("n2", quote=quote)), 4) + eq_(cache_(quoted_name("n1", quote=not quote)), 5) + eq_(cache_(name=quoted_name("n1", quote=not quote)), 6) + eq_(cache_(quoted_name("n2", quote=not quote)), 7) + eq_(cache_(name=quoted_name("n2", quote=not quote)), 8) + eq_(cache_("n1"), 9) + eq_(cache_(name="n1"), 10) + eq_(cache_("n2"), 11) + eq_(cache_(name="n2"), 12) + + # same values, counter not incremented + eq_(cache_(quoted_name("n1", quote=quote)), 1) + eq_(cache_(name=quoted_name("n1", quote=quote)), 2) + eq_(cache_(quoted_name("n2", quote=quote)), 3) + eq_(cache_(name=quoted_name("n2", quote=quote)), 4) + eq_(cache_(quoted_name("n1", quote=not quote)), 5) + eq_(cache_(name=quoted_name("n1", quote=not quote)), 6) + eq_(cache_(quoted_name("n2", quote=not quote)), 7) + eq_(cache_(name=quoted_name("n2", quote=not quote)), 8) + eq_(cache_("n1"), 9) + eq_(cache_(name="n1"), 10) + eq_(cache_("n2"), 11) + eq_(cache_(name="n2"), 12) -- 2.47.2