--- /dev/null
+.. change::
+ :tags: bug, sql
+ :tickets: 6016
+
+ Fixed bug where the "percent escaping" feature that occurs with dialects
+ that use the "format" or "pyformat" bound parameter styles was not enabled
+ for the :meth:`.Operations.op` and :meth:`.Operations.custom_op` methods,
+ for custom operators that use percent signs. The percent sign will now be
+ automatically doubled based on the paramstyle as necessary.
+
+
def visit_custom_op_binary(self, element, operator, **kw):
kw["eager_grouping"] = operator.eager_grouping
return self._generate_generic_binary(
- element, " " + operator.opstring + " ", **kw
+ element,
+ " " + self.escape_literal_column(operator.opstring) + " ",
+ **kw
)
def visit_custom_op_unary_operator(self, element, operator, **kw):
return self._generate_generic_unary_operator(
- element, operator.opstring + " ", **kw
+ element, self.escape_literal_column(operator.opstring) + " ", **kw
)
def visit_custom_op_unary_modifier(self, element, operator, **kw):
return self._generate_generic_unary_modifier(
- element, " " + operator.opstring, **kw
+ element, " " + self.escape_literal_column(operator.opstring), **kw
)
def _generate_generic_binary(
return MyInteger
+ @testing.fixture
+ def modulus(self):
+ class MyInteger(Integer):
+ class comparator_factory(Integer.Comparator):
+ def modulus(self):
+ return UnaryExpression(
+ self.expr,
+ modifier=operators.custom_op("%"),
+ type_=MyInteger,
+ )
+
+ def modulus_prefix(self):
+ return UnaryExpression(
+ self.expr,
+ operator=operators.custom_op("%"),
+ type_=MyInteger,
+ )
+
+ return MyInteger
+
+ @testing.combinations(
+ ("format",),
+ ("qmark",),
+ ("named",),
+ ("pyformat",),
+ argnames="paramstyle",
+ )
+ def test_modulus(self, modulus, paramstyle):
+ col = column("somecol", modulus())
+ self.assert_compile(
+ col.modulus(),
+ "somecol %%"
+ if paramstyle in ("format", "pyformat")
+ else "somecol %",
+ dialect=default.DefaultDialect(paramstyle=paramstyle),
+ )
+
+ @testing.combinations(
+ ("format",),
+ ("qmark",),
+ ("named",),
+ ("pyformat",),
+ argnames="paramstyle",
+ )
+ def test_modulus_prefix(self, modulus, paramstyle):
+ col = column("somecol", modulus())
+ self.assert_compile(
+ col.modulus_prefix(),
+ "%% somecol"
+ if paramstyle in ("format", "pyformat")
+ else "% somecol",
+ dialect=default.DefaultDialect(paramstyle=paramstyle),
+ )
+
def test_factorial(self, factorial):
col = column("somecol", factorial())
self.assert_compile(col.factorial(), "somecol !")
):
self.assert_compile(py_op(lhs, rhs), res % sql_op)
+ @testing.combinations(
+ ("format", "mytable.myid %% %s"),
+ ("qmark", "mytable.myid % ?"),
+ ("named", "mytable.myid % :myid_1"),
+ ("pyformat", "mytable.myid %% %(myid_1)s"),
+ )
+ def test_custom_op_percent_escaping(self, paramstyle, expected):
+ expr = self.table1.c.myid.op("%")(5)
+
+ self.assert_compile(
+ expr,
+ expected,
+ dialect=default.DefaultDialect(paramstyle=paramstyle),
+ )
+
class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "default"
from sqlalchemy.testing import is_
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
+from sqlalchemy.testing.util import resolve_lambda
class QueryTest(fixtures.TablesTest):
select(tuple_(users.c.user_id, users.c.user_name)),
)
- def test_like_ops(self, connection):
+ @testing.combinations(
+ (
+ lambda users: select(users.c.user_id).where(
+ users.c.user_name.startswith("apple")
+ ),
+ [(1,)],
+ ),
+ (
+ lambda users: select(users.c.user_id).where(
+ users.c.user_name.contains("i % t")
+ ),
+ [(5,)],
+ ),
+ (
+ lambda users: select(users.c.user_id).where(
+ users.c.user_name.endswith("anas")
+ ),
+ [(3,)],
+ ),
+ (
+ lambda users: select(users.c.user_id).where(
+ users.c.user_name.contains("i % t", escape="&")
+ ),
+ [(5,)],
+ ),
+ argnames="expr,result",
+ )
+ def test_like_ops(self, connection, expr, result):
users = self.tables.users
connection.execute(
users.insert(),
],
)
- for expr, result in (
- (
- select(users.c.user_id).where(
- users.c.user_name.startswith("apple")
- ),
- [(1,)],
- ),
- (
- select(users.c.user_id).where(
- users.c.user_name.contains("i % t")
- ),
- [(5,)],
- ),
- (
- select(users.c.user_id).where(
- users.c.user_name.endswith("anas")
- ),
- [(3,)],
- ),
- (
- select(users.c.user_id).where(
- users.c.user_name.contains("i % t", escape="&")
- ),
- [(5,)],
- ),
- ):
- eq_(connection.execute(expr).fetchall(), result)
+ expr = resolve_lambda(expr, users=users)
+ eq_(connection.execute(expr).fetchall(), result)
@testing.requires.mod_operator_as_percent_sign
@testing.emits_warning(".*now automatically escapes.*")