From 79bde753e47bd86f0199c4aa6a5c2ead1e4aec95 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 9 Mar 2021 13:36:34 -0500 Subject: [PATCH] Apply percent sign escaping to op(), custom_op() Fixed bug where the "percent escaping" feature that occurs with dialects that use the "format" or "pyformat" bound parameter styles was not enabled for the :meth:`.Operations.op` and :meth:`.Operations.custom_op` methods, for custom operators that use percent signs. The percent sign will now be automatically doubled based on the paramstyle as necessary. Fixes: #6016 Change-Id: I285c5fc082481c2ee989edf1b02a83a6087ea26a --- doc/build/changelog/unreleased_14/6016.rst | 11 ++++ lib/sqlalchemy/sql/compiler.py | 8 ++- test/sql/test_operators.py | 69 ++++++++++++++++++++++ test/sql/test_query.py | 59 +++++++++--------- 4 files changed, 116 insertions(+), 31 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6016.rst diff --git a/doc/build/changelog/unreleased_14/6016.rst b/doc/build/changelog/unreleased_14/6016.rst new file mode 100644 index 0000000000..7dd7db7134 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6016.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, sql + :tickets: 6016 + + Fixed bug where the "percent escaping" feature that occurs with dialects + that use the "format" or "pyformat" bound parameter styles was not enabled + for the :meth:`.Operations.op` and :meth:`.Operations.custom_op` methods, + for custom operators that use percent signs. The percent sign will now be + automatically doubled based on the paramstyle as necessary. + + diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8f046a8029..0ea251fb4b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2055,17 +2055,19 @@ class SQLCompiler(Compiled): def visit_custom_op_binary(self, element, operator, **kw): kw["eager_grouping"] = operator.eager_grouping return self._generate_generic_binary( - element, " " + operator.opstring + " ", **kw + element, + " " + self.escape_literal_column(operator.opstring) + " ", + **kw ) def visit_custom_op_unary_operator(self, element, operator, **kw): return self._generate_generic_unary_operator( - element, operator.opstring + " ", **kw + element, self.escape_literal_column(operator.opstring) + " ", **kw ) def visit_custom_op_unary_modifier(self, element, operator, **kw): return self._generate_generic_unary_modifier( - element, " " + operator.opstring, **kw + element, " " + self.escape_literal_column(operator.opstring), **kw ) def _generate_generic_binary( diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index a19eb20bc0..f6a13f8ca0 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -326,6 +326,60 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): return MyInteger + @testing.fixture + def modulus(self): + class MyInteger(Integer): + class comparator_factory(Integer.Comparator): + def modulus(self): + return UnaryExpression( + self.expr, + modifier=operators.custom_op("%"), + type_=MyInteger, + ) + + def modulus_prefix(self): + return UnaryExpression( + self.expr, + operator=operators.custom_op("%"), + type_=MyInteger, + ) + + return MyInteger + + @testing.combinations( + ("format",), + ("qmark",), + ("named",), + ("pyformat",), + argnames="paramstyle", + ) + def test_modulus(self, modulus, paramstyle): + col = column("somecol", modulus()) + self.assert_compile( + col.modulus(), + "somecol %%" + if paramstyle in ("format", "pyformat") + else "somecol %", + dialect=default.DefaultDialect(paramstyle=paramstyle), + ) + + @testing.combinations( + ("format",), + ("qmark",), + ("named",), + ("pyformat",), + argnames="paramstyle", + ) + def test_modulus_prefix(self, modulus, paramstyle): + col = column("somecol", modulus()) + self.assert_compile( + col.modulus_prefix(), + "%% somecol" + if paramstyle in ("format", "pyformat") + else "% somecol", + dialect=default.DefaultDialect(paramstyle=paramstyle), + ) + def test_factorial(self, factorial): col = column("somecol", factorial()) self.assert_compile(col.factorial(), "somecol !") @@ -1950,6 +2004,21 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): ): self.assert_compile(py_op(lhs, rhs), res % sql_op) + @testing.combinations( + ("format", "mytable.myid %% %s"), + ("qmark", "mytable.myid % ?"), + ("named", "mytable.myid % :myid_1"), + ("pyformat", "mytable.myid %% %(myid_1)s"), + ) + def test_custom_op_percent_escaping(self, paramstyle, expected): + expr = self.table1.c.myid.op("%")(5) + + self.assert_compile( + expr, + expected, + dialect=default.DefaultDialect(paramstyle=paramstyle), + ) + class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 913b7f4d1f..33245bfbce 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -34,6 +34,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table +from sqlalchemy.testing.util import resolve_lambda class QueryTest(fixtures.TablesTest): @@ -173,7 +174,34 @@ class QueryTest(fixtures.TablesTest): select(tuple_(users.c.user_id, users.c.user_name)), ) - def test_like_ops(self, connection): + @testing.combinations( + ( + lambda users: select(users.c.user_id).where( + users.c.user_name.startswith("apple") + ), + [(1,)], + ), + ( + lambda users: select(users.c.user_id).where( + users.c.user_name.contains("i % t") + ), + [(5,)], + ), + ( + lambda users: select(users.c.user_id).where( + users.c.user_name.endswith("anas") + ), + [(3,)], + ), + ( + lambda users: select(users.c.user_id).where( + users.c.user_name.contains("i % t", escape="&") + ), + [(5,)], + ), + argnames="expr,result", + ) + def test_like_ops(self, connection, expr, result): users = self.tables.users connection.execute( users.insert(), @@ -186,33 +214,8 @@ class QueryTest(fixtures.TablesTest): ], ) - for expr, result in ( - ( - select(users.c.user_id).where( - users.c.user_name.startswith("apple") - ), - [(1,)], - ), - ( - select(users.c.user_id).where( - users.c.user_name.contains("i % t") - ), - [(5,)], - ), - ( - select(users.c.user_id).where( - users.c.user_name.endswith("anas") - ), - [(3,)], - ), - ( - select(users.c.user_id).where( - users.c.user_name.contains("i % t", escape="&") - ), - [(5,)], - ), - ): - eq_(connection.execute(expr).fetchall(), result) + expr = resolve_lambda(expr, users=users) + eq_(connection.execute(expr).fetchall(), result) @testing.requires.mod_operator_as_percent_sign @testing.emits_warning(".*now automatically escapes.*") -- 2.47.2