]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use concat() directly for contains, startswith, endswith
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2022 15:32:27 +0000 (11:32 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2022 15:32:27 +0000 (11:32 -0400)
Adjusted the SQL compilation for string containment functions
``.contains()``, ``.startswith()``, ``.endswith()`` to force the use of the
string concatenation operator, rather than relying upon the overload of the
addition operator, so that non-standard use of these operators with for
example bytestrings still produces string concatenation operators.

To accommodate this, needed to add a new _rconcat operator function,
which is private, as well as a fallback in concat_op() that works
similarly to Python builtin ops.

Fixes: #8253
Change-Id: I2b7f56492f765742d88cb2a7834ded6a2892bd7e

doc/build/changelog/unreleased_14/8253.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/operators.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_14/8253.rst b/doc/build/changelog/unreleased_14/8253.rst
new file mode 100644 (file)
index 0000000..7496ae9
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8253
+
+    Adjusted the SQL compilation for string containment functions
+    ``.contains()``, ``.startswith()``, ``.endswith()`` to force the use of the
+    string concatenation operator, rather than relying upon the overload of the
+    addition operator, so that non-standard use of these operators with for
+    example bytestrings still produces string concatenation operators.
+
index a11d83b11cd7e9bf98dfe41fe9c70b0f733a8779..87d031cc269640bcb353beea1a886713f4c7008b 100644 (file)
@@ -2691,37 +2691,37 @@ class SQLCompiler(Compiled):
     def visit_contains_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__add__(binary.right).__add__(percent)
+        binary.right = percent.concat(binary.right).concat(percent)
         return self.visit_like_op_binary(binary, operator, **kw)
 
     def visit_not_contains_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__add__(binary.right).__add__(percent)
+        binary.right = percent.concat(binary.right).concat(percent)
         return self.visit_not_like_op_binary(binary, operator, **kw)
 
     def visit_startswith_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__radd__(binary.right)
+        binary.right = percent._rconcat(binary.right)
         return self.visit_like_op_binary(binary, operator, **kw)
 
     def visit_not_startswith_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__radd__(binary.right)
+        binary.right = percent._rconcat(binary.right)
         return self.visit_not_like_op_binary(binary, operator, **kw)
 
     def visit_endswith_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__add__(binary.right)
+        binary.right = percent.concat(binary.right)
         return self.visit_like_op_binary(binary, operator, **kw)
 
     def visit_not_endswith_op_binary(self, binary, operator, **kw):
         binary = binary._clone()
         percent = self._like_percent_literal
-        binary.right = percent.__add__(binary.right)
+        binary.right = percent.concat(binary.right)
         return self.visit_not_like_op_binary(binary, operator, **kw)
 
     def visit_like_op_binary(self, binary, operator, **kw):
index 2b888769a1486d32403cde9603ec3e82e096c960..44d63b3987556aa01c27c3edf2cb3b441076f1ec 100644 (file)
@@ -615,6 +615,16 @@ class ColumnOperators(Operators):
         """
         return self.operate(concat_op, other)
 
+    def _rconcat(self, other: Any) -> ColumnOperators:
+        """Implement an 'rconcat' operator.
+
+        this is for internal use at the moment
+
+        .. versionadded:: 1.4.40
+
+        """
+        return self.reverse_operate(concat_op, other)
+
     def like(
         self, other: Any, escape: Optional[str] = None
     ) -> ColumnOperators:
@@ -1764,7 +1774,12 @@ def filter_op(a: Any, b: Any) -> Any:
 
 @_operator_fn
 def concat_op(a: Any, b: Any) -> Any:
-    return a.concat(b)
+    try:
+        concat = a.concat
+    except AttributeError:
+        return b._rconcat(a)
+    else:
+        return concat(b)
 
 
 @_operator_fn
index 2411fb0a336fcc71aa0fb1bc713e2f344c4196c1..830a5eb0f1203171865037a9016ceff4ca61b1bc 100644 (file)
@@ -3087,6 +3087,36 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"x_1": "y"},
         )
 
+    def test_contains_encoded(self):
+        self.assert_compile(
+            column("x").contains(b"y"),
+            "x LIKE '%' || :x_1 || '%'",
+            checkparams={"x_1": b"y"},
+        )
+
+    def test_not_contains_encoded(self):
+        self.assert_compile(
+            ~column("x").contains(b"y"),
+            "x NOT LIKE '%' || :x_1 || '%'",
+            checkparams={"x_1": b"y"},
+        )
+
+    def test_contains_encoded_mysql(self):
+        self.assert_compile(
+            column("x").contains(b"y"),
+            "x LIKE concat('%%', %s, '%%')",
+            checkparams={"x_1": b"y"},
+            dialect="mysql",
+        )
+
+    def test_not_contains_encoded_mysql(self):
+        self.assert_compile(
+            ~column("x").contains(b"y"),
+            "x NOT LIKE concat('%%', %s, '%%')",
+            checkparams={"x_1": b"y"},
+            dialect="mysql",
+        )
+
     def test_contains_escape(self):
         self.assert_compile(
             column("x").contains("a%b_c", escape="\\"),
@@ -3250,6 +3280,36 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"x_1": "a^%b^_c/d^^e"},
         )
 
+    def test_startswith_encoded(self):
+        self.assert_compile(
+            column("x").startswith(b"y"),
+            "x LIKE :x_1 || '%'",
+            checkparams={"x_1": b"y"},
+        )
+
+    def test_startswith_encoded_mysql(self):
+        self.assert_compile(
+            column("x").startswith(b"y"),
+            "x LIKE concat(%s, '%%')",
+            checkparams={"x_1": b"y"},
+            dialect="mysql",
+        )
+
+    def test_not_startswith_encoded(self):
+        self.assert_compile(
+            ~column("x").startswith(b"y"),
+            "x NOT LIKE :x_1 || '%'",
+            checkparams={"x_1": b"y"},
+        )
+
+    def test_not_startswith_encoded_mysql(self):
+        self.assert_compile(
+            ~column("x").startswith(b"y"),
+            "x NOT LIKE concat(%s, '%%')",
+            checkparams={"x_1": b"y"},
+            dialect="mysql",
+        )
+
     def test_not_startswith(self):
         self.assert_compile(
             ~column("x").startswith("y"),
@@ -3324,6 +3384,28 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             checkparams={"x_1": "y"},
         )
 
+    def test_endswith_encoded(self):
+        self.assert_compile(
+            column("x").endswith(b"y"),
+            "x LIKE '%' || :x_1",
+            checkparams={"x_1": b"y"},
+        )
+
+    def test_endswith_encoded_mysql(self):
+        self.assert_compile(
+            column("x").endswith(b"y"),
+            "x LIKE concat('%%', %s)",
+            checkparams={"x_1": b"y"},
+            dialect="mysql",
+        )
+
+    def test_not_endswith_encoded(self):
+        self.assert_compile(
+            ~column("x").endswith(b"y"),
+            "x NOT LIKE '%' || :x_1",
+            checkparams={"x_1": b"y"},
+        )
+
     def test_endswith_escape(self):
         self.assert_compile(
             column("x").endswith("a%b_c", escape="\\"),