]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
apply basic escaping to anon_labels unconditionally
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 27 Oct 2022 13:28:02 +0000 (09:28 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 27 Oct 2022 14:25:49 +0000 (10:25 -0400)
Fixed issue which prevented the :func:`_sql.literal_column` construct from
working properly within the context of a :class:`.Select` construct as well
as other potential places where "anonymized labels" might be generated, if
the literal expression contained characters which could interfere with
format strings, such as open parenthesis, due to an implementation detail
of the "anonymous label" structure.

Fixes: #8724
Change-Id: I3089124fbd055a011c8a245964258503b717d941
(cherry picked from commit caa9f0ff98d44359f5162bca8e7fe7bcaa2989a7)

doc/build/changelog/unreleased_14/8724.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
test/sql/test_labels.py

diff --git a/doc/build/changelog/unreleased_14/8724.rst b/doc/build/changelog/unreleased_14/8724.rst
new file mode 100644 (file)
index 0000000..8329697
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8724
+
+    Fixed issue which prevented the :func:`_sql.literal_column` construct from
+    working properly within the context of a :class:`.Select` construct as well
+    as other potential places where "anonymized labels" might be generated, if
+    the literal expression contained characters which could interfere with
+    format strings, such as open parenthesis, due to an implementation detail
+    of the "anonymous label" structure.
+
index ace43b3a1d4ca7451de8526c6b5cc9669264d05e..eb5bc5a0087fb35120d07c2782822299f00b50a7 100644 (file)
@@ -5371,8 +5371,13 @@ class _anonymous_label(_truncated_label):
         cls, seed, body, enclosing_label=None, sanitize_key=False
     ):
 
+        # need to escape chars that interfere with format
+        # strings in any case, issue #8724
+        body = re.sub(r"[%\(\) \$]+", "_", body)
+
         if sanitize_key:
-            body = re.sub(r"[%\(\) \$]+", "_", body).strip("_")
+            # sanitize_key is then an extra step used by BindParameter
+            body = body.strip("_")
 
         label = "%%(%d %s)s" % (seed, body.replace("%", "%%"))
         if enclosing_label:
index d385b9e8d14993245c4978a0ed926b0251269daf..a82b0372eaad6772b720eed26b3e8db0a6d4784f 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import Boolean
 from sqlalchemy import cast
 from sqlalchemy import exc as exceptions
 from sqlalchemy import Integer
+from sqlalchemy import literal_column
 from sqlalchemy import MetaData
 from sqlalchemy import or_
 from sqlalchemy import select
@@ -16,6 +17,7 @@ from sqlalchemy.sql import column
 from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL
 from sqlalchemy.sql import roles
 from sqlalchemy.sql import table
+from sqlalchemy.sql.base import prefix_anon_map
 from sqlalchemy.sql.elements import _truncated_label
 from sqlalchemy.sql.elements import ColumnElement
 from sqlalchemy.sql.elements import WrapsColumnExpression
@@ -1038,3 +1040,35 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL):
             "SOME_COL_THING(some_table.value) "
             "AS some_table_value FROM some_table",
         )
+
+    @testing.combinations(
+        # the resulting strings are completely arbitrary and are not
+        # exposed in SQL with current implementations.  we want to
+        # only assert that the operation doesn't fail.  It's safe to
+        # change the assertion cases for this test if the label escaping
+        # format changes
+        (literal_column("'(1,2]'"), "'_1,2]'_1"),
+        (literal_column("))"), "__1"),
+        (literal_column("'%('"), "'_'_1"),
+    )
+    def test_labels_w_strformat_chars_in_isolation(self, test_case, expected):
+        """test #8724"""
+
+        pa = prefix_anon_map()
+        eq_(test_case._anon_key_label % pa, expected)
+
+    @testing.combinations(
+        (
+            select(literal_column("'(1,2]'"), literal_column("'(1,2]'")),
+            "SELECT '(1,2]', '(1,2]'",
+        ),
+        (select(literal_column("))"), literal_column("))")), "SELECT )), ))"),
+        (
+            select(literal_column("'%('"), literal_column("'%('")),
+            "SELECT '%(', '%('",
+        ),
+    )
+    def test_labels_w_strformat_chars_in_statements(self, test_case, expected):
+        """test #8724"""
+
+        self.assert_compile(test_case, expected)