From: Mike Bayer Date: Thu, 27 Oct 2022 13:28:02 +0000 (-0400) Subject: apply basic escaping to anon_labels unconditionally X-Git-Tag: rel_2_0_0b3~23^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=caa9f0ff98d44359f5162bca8e7fe7bcaa2989a7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git apply basic escaping to anon_labels unconditionally Fixed issue which prevented the :func:`_sql.literal_column` construct from working properly within the context of a :class:`.Select` construct as well as other potential places where "anonymized labels" might be generated, if the literal expression contained characters which could interfere with format strings, such as open parenthesis, due to an implementation detail of the "anonymous label" structure. Fixes: #8724 Change-Id: I3089124fbd055a011c8a245964258503b717d941 --- diff --git a/doc/build/changelog/unreleased_14/8724.rst b/doc/build/changelog/unreleased_14/8724.rst new file mode 100644 index 0000000000..8329697cee --- /dev/null +++ b/doc/build/changelog/unreleased_14/8724.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, sql + :tickets: 8724 + + Fixed issue which prevented the :func:`_sql.literal_column` construct from + working properly within the context of a :class:`.Select` construct as well + as other potential places where "anonymized labels" might be generated, if + the literal expression contained characters which could interfere with + format strings, such as open parenthesis, due to an implementation detail + of the "anonymous label" structure. + diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 8167dc7e45..3f4381c1a8 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -5063,8 +5063,13 @@ class _anonymous_label(_truncated_label): sanitize_key: bool = False, ) -> _anonymous_label: + # need to escape chars that interfere with format + # strings in any case, issue #8724 + body = re.sub(r"[%\(\) \$]+", "_", body) + if sanitize_key: - body = re.sub(r"[%\(\) \$]+", "_", body).strip("_") + # sanitize_key is then an extra step used by BindParameter + body = body.strip("_") label = "%%(%d %s)s" % (seed, body.replace("%", "%%")) if enclosing_label: diff --git a/test/sql/test_labels.py b/test/sql/test_labels.py index d385b9e8d1..42d9c5f003 100644 --- a/test/sql/test_labels.py +++ b/test/sql/test_labels.py @@ -3,6 +3,7 @@ from sqlalchemy import Boolean from sqlalchemy import cast from sqlalchemy import exc as exceptions from sqlalchemy import Integer +from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import or_ from sqlalchemy import select @@ -20,6 +21,7 @@ from sqlalchemy.sql.elements import _truncated_label from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import WrapsColumnExpression from sqlalchemy.sql.selectable import LABEL_STYLE_NONE +from sqlalchemy.sql.visitors import prefix_anon_map from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -1038,3 +1040,35 @@ class ColExprLabelTest(fixtures.TestBase, AssertsCompiledSQL): "SOME_COL_THING(some_table.value) " "AS some_table_value FROM some_table", ) + + @testing.combinations( + # the resulting strings are completely arbitrary and are not + # exposed in SQL with current implementations. we want to + # only assert that the operation doesn't fail. It's safe to + # change the assertion cases for this test if the label escaping + # format changes + (literal_column("'(1,2]'"), "'_1,2]'_1"), + (literal_column("))"), "__1"), + (literal_column("'%('"), "'_'_1"), + ) + def test_labels_w_strformat_chars_in_isolation(self, test_case, expected): + """test #8724""" + + pa = prefix_anon_map() + eq_(test_case._anon_key_label % pa, expected) + + @testing.combinations( + ( + select(literal_column("'(1,2]'"), literal_column("'(1,2]'")), + "SELECT '(1,2]', '(1,2]'", + ), + (select(literal_column("))"), literal_column("))")), "SELECT )), ))"), + ( + select(literal_column("'%('"), literal_column("'%('")), + "SELECT '%(', '%('", + ), + ) + def test_labels_w_strformat_chars_in_statements(self, test_case, expected): + """test #8724""" + + self.assert_compile(test_case, expected)