]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use concat() directly for contains, startswith, endswith
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2022 15:32:27 +0000 (11:32 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2022 15:35:19 +0000 (11:35 -0400)
Adjusted the SQL compilation for string containment functions
``.contains()``, ``.startswith()``, ``.endswith()`` to force the use of the
string concatenation operator, rather than relying upon the overload of the
addition operator, so that non-standard use of these operators with for
example bytestrings still produces string concatenation operators.

To accommodate this, needed to add a new _rconcat operator function,
which is private, as well as a fallback in concat_op() that works
similarly to Python builtin ops.

Fixes: #8253
Change-Id: I2b7f56492f765742d88cb2a7834ded6a2892bd7e
(cherry picked from commit 85a88df13ab8d217331cf98392544a888b4d7df3)

doc/build/changelog/unreleased_14/8253.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/operators.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_14/8253.rst b/doc/build/changelog/unreleased_14/8253.rst
new file mode 100644 (file)
index 0000000..7496ae9
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8253
+
+    Adjusted the SQL compilation for string containment functions
+    ``.contains()``, ``.startswith()``, ``.endswith()`` to force the use of the
+    string concatenation operator, rather than relying upon the overload of the
+    addition operator, so that non-standard use of these operators with for
+    example bytestrings still produces string concatenation operators.
+
index 667dd7d3de5fc1bc06232d5192337b33ef888089..330f3c3bc86672f19f6a2cb9917af1b0f9644e05 100644 (file)
@@ -2295,37 +2295,37 @@ class SQLCompiler(Compiled):
     def visit_contains_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__add__(binary.right).__add__(percent)
+        binary.right = percent.concat(binary.right).concat(percent)
         return self.visit_like_op_binary(binary, operator, **kw)
 
     def visit_not_contains_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__add__(binary.right).__add__(percent)
+        binary.right = percent.concat(binary.right).concat(percent)
         return self.visit_not_like_op_binary(binary, operator, **kw)
 
     def visit_startswith_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__radd__(binary.right)
+        binary.right = percent._rconcat(binary.right)
         return self.visit_like_op_binary(binary, operator, **kw)
 
     def visit_not_startswith_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__radd__(binary.right)
+        binary.right = percent._rconcat(binary.right)
         return self.visit_not_like_op_binary(binary, operator, **kw)
 
     def visit_endswith_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__add__(binary.right)
+        binary.right = percent.concat(binary.right)
         return self.visit_like_op_binary(binary, operator, **kw)
 
     def visit_not_endswith_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__add__(binary.right)
+        binary.right = percent.concat(binary.right)
         return self.visit_not_like_op_binary(binary, operator, **kw)
 
     def visit_like_op_binary(self, binary, operator, **kw):
index 826b3129384e15aae15aad99049f9ff1fc7a41d2..1da50322967d580d961f9ed69b4e45b48112c6c0 100644 (file)
@@ -466,6 +466,16 @@ class ColumnOperators(Operators):
         """
         return self.operate(concat_op, other)
 
+    def _rconcat(self, other):
+        """Implement an 'rconcat' operator.
+
+        this is for internal use at the moment
+
+        .. versionadded:: 1.4.40
+
+        """
+        return self.reverse_operate(concat_op, other)
+
     def like(self, other, escape=None):
         r"""Implement the ``like`` operator.
 
@@ -1512,7 +1522,12 @@ def filter_op(a, b):
 
 
 def concat_op(a, b):
-    return a.concat(b)
+    try:
+        concat = a.concat
+    except AttributeError:
+        return b._rconcat(a)
+    else:
+        return concat(b)
 
 
 def desc_op(a):
index 116d6b792321bf2592f9bc34674015f83c56948f..62f33c2ec24d04ba63f0d81d4ee3eeffaed51eb8 100644 (file)
@@ -2841,6 +2841,36 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"x_1": "y"},
         )
 
+    def test_contains_encoded(self):
+        self.assert_compile(
+            column("x").contains(b"y"),
+            "x LIKE '%' || :x_1 || '%'",
+            checkparams={"x_1": b"y"},
+        )
+
+    def test_not_contains_encoded(self):
+        self.assert_compile(
+            ~column("x").contains(b"y"),
+            "x NOT LIKE '%' || :x_1 || '%'",
+            checkparams={"x_1": b"y"},
+        )
+
+    def test_contains_encoded_mysql(self):
+        self.assert_compile(
+            column("x").contains(b"y"),
+            "x LIKE concat(concat('%%', %s), '%%')",
+            checkparams={"x_1": b"y"},
+            dialect="mysql",
+        )
+
+    def test_not_contains_encoded_mysql(self):
+        self.assert_compile(
+            ~column("x").contains(b"y"),
+            "x NOT LIKE concat(concat('%%', %s), '%%')",
+            checkparams={"x_1": b"y"},
+            dialect="mysql",
+        )
+
     def test_contains_escape(self):
         self.assert_compile(
             column("x").contains("a%b_c", escape="\\"),
@@ -3004,6 +3034,36 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"x_1": "a^%b^_c/d^^e"},
         )
 
+    def test_startswith_encoded(self):
+        self.assert_compile(
+            column("x").startswith(b"y"),
+            "x LIKE :x_1 || '%'",
+            checkparams={"x_1": b"y"},
+        )
+
+    def test_startswith_encoded_mysql(self):
+        self.assert_compile(
+            column("x").startswith(b"y"),
+            "x LIKE concat(%s, '%%')",
+            checkparams={"x_1": b"y"},
+            dialect="mysql",
+        )
+
+    def test_not_startswith_encoded(self):
+        self.assert_compile(
+            ~column("x").startswith(b"y"),
+            "x NOT LIKE :x_1 || '%'",
+            checkparams={"x_1": b"y"},
+        )
+
+    def test_not_startswith_encoded_mysql(self):
+        self.assert_compile(
+            ~column("x").startswith(b"y"),
+            "x NOT LIKE concat(%s, '%%')",
+            checkparams={"x_1": b"y"},
+            dialect="mysql",
+        )
+
     def test_not_startswith(self):
         self.assert_compile(
             ~column("x").startswith("y"),
@@ -3094,6 +3154,28 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"x_1": "y"},
         )
 
+    def test_endswith_encoded(self):
+        self.assert_compile(
+            column("x").endswith(b"y"),
+            "x LIKE '%' || :x_1",
+            checkparams={"x_1": b"y"},
+        )
+
+    def test_endswith_encoded_mysql(self):
+        self.assert_compile(
+            column("x").endswith(b"y"),
+            "x LIKE concat('%%', %s)",
+            checkparams={"x_1": b"y"},
+            dialect="mysql",
+        )
+
+    def test_not_endswith_encoded(self):
+        self.assert_compile(
+            ~column("x").endswith(b"y"),
+            "x NOT LIKE '%' || :x_1",
+            checkparams={"x_1": b"y"},
+        )
+
     def test_endswith_escape(self):
         self.assert_compile(
             column("x").endswith("a%b_c", escape="\\"),