]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Apply percent sign escaping to op(), custom_op()
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 9 Mar 2021 18:36:34 +0000 (13:36 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 9 Mar 2021 18:36:34 +0000 (13:36 -0500)
Fixed bug where the "percent escaping" feature that occurs with dialects
that use the "format" or "pyformat" bound parameter styles was not enabled
for the :meth:`.Operations.op` and :meth:`.Operations.custom_op` methods,
for custom operators that use percent signs. The percent sign will now be
automatically doubled based on the paramstyle as necessary.

Fixes: #6016
Change-Id: I285c5fc082481c2ee989edf1b02a83a6087ea26a

doc/build/changelog/unreleased_14/6016.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/sql/test_operators.py
test/sql/test_query.py

diff --git a/doc/build/changelog/unreleased_14/6016.rst b/doc/build/changelog/unreleased_14/6016.rst
new file mode 100644 (file)
index 0000000..7dd7db7
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 6016
+
+    Fixed bug where the "percent escaping" feature that occurs with dialects
+    that use the "format" or "pyformat" bound parameter styles was not enabled
+    for the :meth:`.Operations.op` and :meth:`.Operations.custom_op` methods,
+    for custom operators that use percent signs. The percent sign will now be
+    automatically doubled based on the paramstyle as necessary.
+
+
index 8f046a8029d8dd5c5df7212a9a82c1c35f1d139f..0ea251fb4b217a28d86b418a76670b4b2deed8b0 100644 (file)
@@ -2055,17 +2055,19 @@ class SQLCompiler(Compiled):
     def visit_custom_op_binary(self, element, operator, **kw):
         kw["eager_grouping"] = operator.eager_grouping
         return self._generate_generic_binary(
-            element, " " + operator.opstring + " ", **kw
+            element,
+            " " + self.escape_literal_column(operator.opstring) + " ",
+            **kw
         )
 
     def visit_custom_op_unary_operator(self, element, operator, **kw):
         return self._generate_generic_unary_operator(
-            element, operator.opstring + " ", **kw
+            element, self.escape_literal_column(operator.opstring) + " ", **kw
         )
 
     def visit_custom_op_unary_modifier(self, element, operator, **kw):
         return self._generate_generic_unary_modifier(
-            element, " " + operator.opstring, **kw
+            element, " " + self.escape_literal_column(operator.opstring), **kw
         )
 
     def _generate_generic_binary(
index a19eb20bc0c09365f719654d9528dbf4a12fc333..f6a13f8ca0f010e64c93034da2f37bb9f2dfca94 100644 (file)
@@ -326,6 +326,60 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         return MyInteger
 
+    @testing.fixture
+    def modulus(self):
+        class MyInteger(Integer):
+            class comparator_factory(Integer.Comparator):
+                def modulus(self):
+                    return UnaryExpression(
+                        self.expr,
+                        modifier=operators.custom_op("%"),
+                        type_=MyInteger,
+                    )
+
+                def modulus_prefix(self):
+                    return UnaryExpression(
+                        self.expr,
+                        operator=operators.custom_op("%"),
+                        type_=MyInteger,
+                    )
+
+        return MyInteger
+
+    @testing.combinations(
+        ("format",),
+        ("qmark",),
+        ("named",),
+        ("pyformat",),
+        argnames="paramstyle",
+    )
+    def test_modulus(self, modulus, paramstyle):
+        col = column("somecol", modulus())
+        self.assert_compile(
+            col.modulus(),
+            "somecol %%"
+            if paramstyle in ("format", "pyformat")
+            else "somecol %",
+            dialect=default.DefaultDialect(paramstyle=paramstyle),
+        )
+
+    @testing.combinations(
+        ("format",),
+        ("qmark",),
+        ("named",),
+        ("pyformat",),
+        argnames="paramstyle",
+    )
+    def test_modulus_prefix(self, modulus, paramstyle):
+        col = column("somecol", modulus())
+        self.assert_compile(
+            col.modulus_prefix(),
+            "%% somecol"
+            if paramstyle in ("format", "pyformat")
+            else "% somecol",
+            dialect=default.DefaultDialect(paramstyle=paramstyle),
+        )
+
     def test_factorial(self, factorial):
         col = column("somecol", factorial())
         self.assert_compile(col.factorial(), "somecol !")
@@ -1950,6 +2004,21 @@ class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         ):
             self.assert_compile(py_op(lhs, rhs), res % sql_op)
 
+    @testing.combinations(
+        ("format", "mytable.myid %% %s"),
+        ("qmark", "mytable.myid % ?"),
+        ("named", "mytable.myid % :myid_1"),
+        ("pyformat", "mytable.myid %% %(myid_1)s"),
+    )
+    def test_custom_op_percent_escaping(self, paramstyle, expected):
+        expr = self.table1.c.myid.op("%")(5)
+
+        self.assert_compile(
+            expr,
+            expected,
+            dialect=default.DefaultDialect(paramstyle=paramstyle),
+        )
+
 
 class ComparisonOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
index 913b7f4d1f4925c7929bd61ed7fe109c6e3b4554..33245bfbcedba33dfa4c1ff46f75a4bef370dee1 100644 (file)
@@ -34,6 +34,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
+from sqlalchemy.testing.util import resolve_lambda
 
 
 class QueryTest(fixtures.TablesTest):
@@ -173,7 +174,34 @@ class QueryTest(fixtures.TablesTest):
             select(tuple_(users.c.user_id, users.c.user_name)),
         )
 
-    def test_like_ops(self, connection):
+    @testing.combinations(
+        (
+            lambda users: select(users.c.user_id).where(
+                users.c.user_name.startswith("apple")
+            ),
+            [(1,)],
+        ),
+        (
+            lambda users: select(users.c.user_id).where(
+                users.c.user_name.contains("i % t")
+            ),
+            [(5,)],
+        ),
+        (
+            lambda users: select(users.c.user_id).where(
+                users.c.user_name.endswith("anas")
+            ),
+            [(3,)],
+        ),
+        (
+            lambda users: select(users.c.user_id).where(
+                users.c.user_name.contains("i % t", escape="&")
+            ),
+            [(5,)],
+        ),
+        argnames="expr,result",
+    )
+    def test_like_ops(self, connection, expr, result):
         users = self.tables.users
         connection.execute(
             users.insert(),
@@ -186,33 +214,8 @@ class QueryTest(fixtures.TablesTest):
             ],
         )
 
-        for expr, result in (
-            (
-                select(users.c.user_id).where(
-                    users.c.user_name.startswith("apple")
-                ),
-                [(1,)],
-            ),
-            (
-                select(users.c.user_id).where(
-                    users.c.user_name.contains("i % t")
-                ),
-                [(5,)],
-            ),
-            (
-                select(users.c.user_id).where(
-                    users.c.user_name.endswith("anas")
-                ),
-                [(3,)],
-            ),
-            (
-                select(users.c.user_id).where(
-                    users.c.user_name.contains("i % t", escape="&")
-                ),
-                [(5,)],
-            ),
-        ):
-            eq_(connection.execute(expr).fetchall(), result)
+        expr = resolve_lambda(expr, users=users)
+        eq_(connection.execute(expr).fetchall(), result)
 
     @testing.requires.mod_operator_as_percent_sign
     @testing.emits_warning(".*now automatically escapes.*")