]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
handle quoted_name instances separately in engine.reflection.cache (Fixes: #11687) 11688/head
authorMasterchen09 <13187726+Masterchen09@users.noreply.github.com>
Sun, 4 Aug 2024 12:08:56 +0000 (14:08 +0200)
committerMasterchen09 <13187726+Masterchen09@users.noreply.github.com.>
Tue, 6 Aug 2024 17:37:04 +0000 (19:37 +0200)
lib/sqlalchemy/engine/reflection.py
test/engine/test_reflection.py

index 02a757379a888c8a8eb8e4382e6bb85f7c7a4fa7..58e3aa390fc8ab576221979ccc93b1a370289699 100644 (file)
@@ -56,6 +56,7 @@ from .. import util
 from ..sql import operators
 from ..sql import schema as sa_schema
 from ..sql.cache_key import _ad_hoc_cache_key_from_args
+from ..sql.elements import quoted_name
 from ..sql.elements import TextClause
 from ..sql.type_api import TypeEngine
 from ..sql.visitors import InternalTraversal
@@ -89,8 +90,16 @@ def cache(
     exclude = {"info_cache", "unreflectable"}
     key = (
         fn.__name__,
-        tuple(a for a in args if isinstance(a, str)),
-        tuple((k, v) for k, v in kw.items() if k not in exclude),
+        tuple(
+            (str(a), a.quote) if isinstance(a, quoted_name) else a
+            for a in args
+            if isinstance(a, str)
+        ),
+        tuple(
+            (k, (str(v), v.quote) if isinstance(v, quoted_name) else v)
+            for k, v in kw.items()
+            if k not in exclude
+        ),
     )
     ret: _R = info_cache.get(key)
     if ret is None:
index 003b457a51a8a232b8b7c36d471663e7ce66eb09..adb4037065512c6f6d3637ffdb568e74e43941ad 100644 (file)
@@ -1,3 +1,4 @@
+import itertools
 import unicodedata
 
 import sqlalchemy as sa
@@ -19,6 +20,8 @@ from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.engine import Inspector
+from sqlalchemy.engine.reflection import cache
+from sqlalchemy.sql.elements import quoted_name
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -2494,3 +2497,162 @@ class IncludeColsFksTest(AssertsCompiledSQL, fixtures.TestBase):
             "SELECT b_1.x, b_1.q, b_1.p, b_1.r, b_1.s, b_1.t "
             "FROM b AS b_1 JOIN a ON a.x = b_1.r",
         )
+
+
+class ReflectionCacheTest(fixtures.TestBase):
+    @testing.fixture(params=["arg", "kwarg"])
+    def cache(self, connection, request):
+        dialect = connection.dialect
+        info_cache = {}
+        counter = itertools.count(1)
+
+        @cache
+        def get_cached_name(self, connection, *args, **kw):
+            return next(counter)
+
+        def get_cached_name_via_arg(name):
+            return get_cached_name(
+                dialect, connection, name, info_cache=info_cache
+            )
+
+        def get_cached_name_via_kwarg(name):
+            return get_cached_name(
+                dialect, connection, name=name, info_cache=info_cache
+            )
+
+        if request.param == "arg":
+            yield get_cached_name_via_arg
+        elif request.param == "kwarg":
+            yield get_cached_name_via_kwarg
+        else:
+            assert False
+
+    @testing.fixture(params=[False, True])
+    def quote(self, request):
+        yield request.param
+
+    def test_single_string(self, cache):
+        # new value
+        eq_(cache("name1"), 1)
+
+        # same value, counter not incremented
+        eq_(cache("name1"), 1)
+
+    def test_multiple_string(self, cache):
+        # new value
+        eq_(cache("name1"), 1)
+        eq_(cache("name2"), 2)
+
+        # same values, counter not incremented
+        eq_(cache("name1"), 1)
+        eq_(cache("name2"), 2)
+
+    def test_single_quoted_name(self, cache, quote):
+        # new value
+        eq_(cache(quoted_name("name1", quote=quote)), 1)
+
+        # same value, counter not incremented
+        eq_(cache(quoted_name("name1", quote=quote)), 1)
+
+    def test_multiple_quoted_name(self, cache, quote):
+        # new value
+        eq_(cache(quoted_name("name1", quote=quote)), 1)
+        eq_(cache(quoted_name("name2", quote=quote)), 2)
+
+        # same values, counter not incremented
+        eq_(cache(quoted_name("name1", quote=quote)), 1)
+        eq_(cache(quoted_name("name2", quote=quote)), 2)
+
+    def test_single_quoted_name_and_string(self, cache, quote):
+        # new values
+        eq_(cache(quoted_name("n1", quote=quote)), 1)
+        eq_(cache("n1"), 2)
+
+        # same values, counter not incremented
+        eq_(cache(quoted_name("n1", quote=quote)), 1)
+        eq_(cache("n1"), 2)
+
+    def test_multiple_quoted_name_and_string(self, cache, quote):
+        # new values
+        eq_(cache(quoted_name("n1", quote=quote)), 1)
+        eq_(cache(quoted_name("n2", quote=quote)), 2)
+        eq_(cache("n1"), 3)
+        eq_(cache("n2"), 4)
+
+        # same values, counter not incremented
+        eq_(cache(quoted_name("n1", quote=quote)), 1)
+        eq_(cache(quoted_name("n2", quote=quote)), 2)
+        eq_(cache("n1"), 3)
+        eq_(cache("n2"), 4)
+
+    def test_single_quoted_name_false_true_and_string(self, cache, quote):
+        # new values
+        eq_(cache(quoted_name("n1", quote=quote)), 1)
+        eq_(cache(quoted_name("n1", quote=not quote)), 2)
+        eq_(cache("n1"), 3)
+
+        # same values, counter not incremented
+        eq_(cache(quoted_name("n1", quote=quote)), 1)
+        eq_(cache(quoted_name("n1", quote=not quote)), 2)
+        eq_(cache("n1"), 3)
+
+    def test_multiple_quoted_name_false_true_and_string(self, cache, quote):
+        # new values
+        eq_(cache(quoted_name("n1", quote=quote)), 1)
+        eq_(cache(quoted_name("n2", quote=quote)), 2)
+        eq_(cache(quoted_name("n1", quote=not quote)), 3)
+        eq_(cache(quoted_name("n2", quote=not quote)), 4)
+        eq_(cache("n1"), 5)
+        eq_(cache("n2"), 6)
+
+        # same values, counter not incremented
+        eq_(cache(quoted_name("n1", quote=quote)), 1)
+        eq_(cache(quoted_name("n2", quote=quote)), 2)
+        eq_(cache(quoted_name("n1", quote=not quote)), 3)
+        eq_(cache(quoted_name("n2", quote=not quote)), 4)
+        eq_(cache("n1"), 5)
+        eq_(cache("n2"), 6)
+
+    def test_multiple_quoted_name_false_true_and_string_arg_and_kwarg(
+        self, connection, quote
+    ):
+        dialect = connection.dialect
+        info_cache = {}
+        counter = itertools.count(1)
+
+        @cache
+        def get_cached_name(self, connection, *args, **kw):
+            return next(counter)
+
+        def cache_(*args, **kw):
+            return get_cached_name(
+                dialect, connection, *args, **kw, info_cache=info_cache
+            )
+
+        # new values
+        eq_(cache_(quoted_name("n1", quote=quote)), 1)
+        eq_(cache_(name=quoted_name("n1", quote=quote)), 2)
+        eq_(cache_(quoted_name("n2", quote=quote)), 3)
+        eq_(cache_(name=quoted_name("n2", quote=quote)), 4)
+        eq_(cache_(quoted_name("n1", quote=not quote)), 5)
+        eq_(cache_(name=quoted_name("n1", quote=not quote)), 6)
+        eq_(cache_(quoted_name("n2", quote=not quote)), 7)
+        eq_(cache_(name=quoted_name("n2", quote=not quote)), 8)
+        eq_(cache_("n1"), 9)
+        eq_(cache_(name="n1"), 10)
+        eq_(cache_("n2"), 11)
+        eq_(cache_(name="n2"), 12)
+
+        # same values, counter not incremented
+        eq_(cache_(quoted_name("n1", quote=quote)), 1)
+        eq_(cache_(name=quoted_name("n1", quote=quote)), 2)
+        eq_(cache_(quoted_name("n2", quote=quote)), 3)
+        eq_(cache_(name=quoted_name("n2", quote=quote)), 4)
+        eq_(cache_(quoted_name("n1", quote=not quote)), 5)
+        eq_(cache_(name=quoted_name("n1", quote=not quote)), 6)
+        eq_(cache_(quoted_name("n2", quote=not quote)), 7)
+        eq_(cache_(name=quoted_name("n2", quote=not quote)), 8)
+        eq_(cache_("n1"), 9)
+        eq_(cache_(name="n1"), 10)
+        eq_(cache_("n2"), 11)
+        eq_(cache_(name="n2"), 12)