]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure bindparam key escaping applied in all cases
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 12 Apr 2021 18:33:50 +0000 (14:33 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 12 Apr 2021 18:33:50 +0000 (14:33 -0400)
Fixed regression where the :class:`_sql.BindParameter` object would not
properly render for an IN expression (i.e. using the "post compile" feature
in 1.4) if the object were copied from either an internal cloning
operation, or from a pickle operation, and the parameter name contained
spaces or other special characters.

Fixes: #6249
Change-Id: Icd0d4096c8fa4eb1a1d4c20f8a96d8b1ae439f0a

doc/build/changelog/unreleased_14/6249.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
test/sql/test_external_traversal.py

diff --git a/doc/build/changelog/unreleased_14/6249.rst b/doc/build/changelog/unreleased_14/6249.rst
new file mode 100644 (file)
index 0000000..7ac94c3
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, regression, sql
+    :tickets: 6249
+
+    Fixed regression where the :class:`_sql.BindParameter` object would not
+    properly render for an IN expression (i.e. using the "post compile" feature
+    in 1.4) if the object were copied from either an internal cloning
+    operation, or from a pickle operation, and the parameter name contained
+    spaces or other special characters.
index 9e1b690888abd744de55edd45ddf734793eb10c1..e97ed252ebab77d270253610806b96ec5c744f7d 100644 (file)
@@ -1363,9 +1363,10 @@ class BindParameter(roles.InElementRole, ColumnElement):
         if unique:
             self.key = _anonymous_label.safe_construct(
                 id(self),
-                re.sub(r"[%\(\) \$]+", "_", key).strip("_")
+                key
                 if key is not None and not isinstance(key, _anonymous_label)
                 else "param",
+                sanitize_key=True,
             )
             self._key_is_anon = True
         elif key:
@@ -1479,7 +1480,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
         c = ClauseElement._clone(self, **kw)
         if not maintain_key and self.unique:
             c.key = _anonymous_label.safe_construct(
-                id(c), c._orig_key or "param"
+                id(c), c._orig_key or "param", sanitize_key=True
             )
         return c
 
@@ -1514,7 +1515,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
         if not self.unique:
             self.unique = True
             self.key = _anonymous_label.safe_construct(
-                id(self), self._orig_key or "param"
+                id(self), self._orig_key or "param", sanitize_key=True
             )
 
     def __getstate__(self):
@@ -1531,7 +1532,7 @@ class BindParameter(roles.InElementRole, ColumnElement):
     def __setstate__(self, state):
         if state.get("unique", False):
             state["key"] = _anonymous_label.safe_construct(
-                id(self), state.get("_orig_key", "param")
+                id(self), state.get("_orig_key", "param"), sanitize_key=True
             )
         self.__dict__.update(state)
 
@@ -5048,9 +5049,14 @@ class _anonymous_label(_truncated_label):
     __slots__ = ()
 
     @classmethod
-    def safe_construct(cls, seed, body, enclosing_label=None):
+    def safe_construct(
+        cls, seed, body, enclosing_label=None, sanitize_key=False
+    ):
         # type: (int, str, Optional[_anonymous_label]) -> _anonymous_label
 
+        if sanitize_key:
+            body = re.sub(r"[%\(\) \$]+", "_", body).strip("_")
+
         label = "%%(%d %s)s" % (seed, body.replace("%", "%%"))
         if enclosing_label:
             label = "%s%s" % (enclosing_label, label)
index 21b5b2d27b52e1a84d684707f82a84bf01a14c9a..e7c6cccca570b75c0790d604b2e09823c05c4032 100644 (file)
@@ -1,3 +1,5 @@
+import re
+
 from sqlalchemy import and_
 from sqlalchemy import bindparam
 from sqlalchemy import case
@@ -38,11 +40,14 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
+from sqlalchemy.util import pickle
 
 A = B = t1 = t2 = t3 = table1 = table2 = table3 = table4 = None
 
 
-class TraversalTest(fixtures.TestBase, AssertsExecutionResults):
+class TraversalTest(
+    fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL
+):
 
     """test ClauseVisitor's traversal, particularly its
     ability to copy and modify a ClauseElement in place."""
@@ -175,6 +180,49 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults):
         s2 = vis.traverse(s1)
         eq_(list(s2.selected_columns)[0].anon_label, c1.anon_label)
 
+    @testing.combinations(
+        ("clone",), ("pickle",), ("conv_to_unique"), ("none"), argnames="meth"
+    )
+    @testing.combinations(
+        ("name with space",), ("name with [brackets]",), argnames="name"
+    )
+    def test_bindparam_key_proc_for_copies(self, meth, name):
+        r"""test :ticket:`6249`.
+
+        The key of the bindparam needs spaces and other characters
+        escaped out for the POSTCOMPILE regex to work correctly.
+
+
+        Currently, the bind key reg is::
+
+            re.sub(r"[%\(\) \$]+", "_", body).strip("_")
+
+        and the compiler postcompile reg is::
+
+            re.sub(r"\[POSTCOMPILE_(\S+)\]", process_expanding, self.string)
+
+        Interestingly, brackets in the name seems to work out.
+
+        """
+        expr = column(name).in_([1, 2, 3])
+
+        if meth == "clone":
+            expr = visitors.cloned_traverse(expr, {}, {})
+        elif meth == "pickle":
+            expr = pickle.loads(pickle.dumps(expr))
+        elif meth == "conv_to_unique":
+            expr.right.unique = False
+            expr.right._convert_to_unique()
+
+        token = re.sub(r"[%\(\) \$]+", "_", name).strip("_")
+        self.assert_compile(
+            expr,
+            '"%(name)s" IN (:%(token)s_1_1, '
+            ":%(token)s_1_2, :%(token)s_1_3)" % {"name": name, "token": token},
+            render_postcompile=True,
+            dialect="default",
+        )
+
     def test_change_in_place(self):
         struct = B(
             A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3")