]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
restore parameter escaping for public methods
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 9 Jun 2022 01:35:02 +0000 (21:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 9 Jun 2022 14:15:25 +0000 (10:15 -0400)
Adjusted the fix made for :ticket:`8056` which adjusted the escaping of
bound parameter names with special characters such that the escaped names
were translated after the SQL compilation step, which broke a published
recipe on the FAQ illustrating how to merge parameter names into the string
output of a compiled SQL string. The change restores the escaped names that
come from ``compiled.params`` and adds a conditional parameter to
:meth:`.SQLCompiler.construct_params` named ``escape_names`` that defaults
to ``True``, restoring the old behavior by default.

Fixes: #8113
Change-Id: I9cbedb1080bc06d51f287fd2cbf26aaab1c74653
(cherry picked from commit 105cd180856309cf5abf24f59b782a1bcd8210d6)

doc/build/changelog/unreleased_14/8113.rst [new file with mode: 0644]
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_14/8113.rst b/doc/build/changelog/unreleased_14/8113.rst
new file mode 100644 (file)
index 0000000..100f9a7
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8113
+
+    Adjusted the fix made for :ticket:`8056` which adjusted the escaping of
+    bound parameter names with special characters such that the escaped names
+    were translated after the SQL compilation step, which broke a published
+    recipe on the FAQ illustrating how to merge parameter names into the string
+    output of a compiled SQL string. The change restores the escaped names that
+    come from ``compiled.params`` and adds a conditional parameter to
+    :meth:`.SQLCompiler.construct_params` named ``escape_names`` that defaults
+    to ``True``, restoring the old behavior by default.
index cc0844e1c3fa95cd850be87eb2663cef954386c3..028c4b0713ad8ebec850f41ca15b97360583c459 100644 (file)
@@ -988,13 +988,15 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         if not parameters:
             self.compiled_parameters = [
                 compiled.construct_params(
-                    extracted_parameters=extracted_parameters
+                    extracted_parameters=extracted_parameters,
+                    escape_names=False,
                 )
             ]
         else:
             self.compiled_parameters = [
                 compiled.construct_params(
                     m,
+                    escape_names=False,
                     _group_number=grp,
                     extracted_parameters=extracted_parameters,
                 )
index 2f3033d7058880850efeaa97300588c10236b745..477c199c1759539bb7a055cfebbae84c3c1245d6 100644 (file)
@@ -490,7 +490,9 @@ class Compiled(object):
 
         return self.string or ""
 
-    def construct_params(self, params=None, extracted_parameters=None):
+    def construct_params(
+        self, params=None, extracted_parameters=None, escape_names=True
+    ):
         """Return the bind params for this compiled object.
 
         :param params: a dict of string/object pairs whose values will
@@ -932,9 +934,12 @@ class SQLCompiler(Compiled):
         _group_number=None,
         _check=True,
         extracted_parameters=None,
+        escape_names=True,
     ):
         """return a dictionary of bind parameter keys and values"""
 
+        has_escaped_names = escape_names and bool(self.escaped_bind_names)
+
         if extracted_parameters:
             # related the bound parameters collected in the original cache key
             # to those collected in the incoming cache key.  They will not have
@@ -965,10 +970,16 @@ class SQLCompiler(Compiled):
         if params:
             pd = {}
             for bindparam, name in self.bind_names.items():
+                escaped_name = (
+                    self.escaped_bind_names.get(name, name)
+                    if has_escaped_names
+                    else name
+                )
+
                 if bindparam.key in params:
-                    pd[name] = params[bindparam.key]
+                    pd[escaped_name] = params[bindparam.key]
                 elif name in params:
-                    pd[name] = params[name]
+                    pd[escaped_name] = params[name]
 
                 elif _check and bindparam.required:
                     if _group_number:
@@ -993,13 +1004,19 @@ class SQLCompiler(Compiled):
                         value_param = bindparam
 
                     if bindparam.callable:
-                        pd[name] = value_param.effective_value
+                        pd[escaped_name] = value_param.effective_value
                     else:
-                        pd[name] = value_param.value
+                        pd[escaped_name] = value_param.value
             return pd
         else:
             pd = {}
             for bindparam, name in self.bind_names.items():
+                escaped_name = (
+                    self.escaped_bind_names.get(name, name)
+                    if has_escaped_names
+                    else name
+                )
+
                 if _check and bindparam.required:
                     if _group_number:
                         raise exc.InvalidRequestError(
@@ -1021,9 +1038,9 @@ class SQLCompiler(Compiled):
                     value_param = bindparam
 
                 if bindparam.callable:
-                    pd[name] = value_param.effective_value
+                    pd[escaped_name] = value_param.effective_value
                 else:
-                    pd[name] = value_param.value
+                    pd[escaped_name] = value_param.value
             return pd
 
     @util.memoized_instancemethod
@@ -1123,7 +1140,7 @@ class SQLCompiler(Compiled):
         """
 
         if parameters is None:
-            parameters = self.construct_params()
+            parameters = self.construct_params(escape_names=False)
 
         expanded_parameters = {}
         if self.positional:
@@ -4317,7 +4334,9 @@ class DDLCompiler(Compiled):
     def type_compiler(self):
         return self.dialect.type_compiler
 
-    def construct_params(self, params=None, extracted_parameters=None):
+    def construct_params(
+        self, params=None, extracted_parameters=None, escape_names=True
+    ):
         return None
 
     def visit_ddl(self, ddl, **kwargs):
index 4db5f3df9d26f382b855a19945c603d40870a2d3..5953c6449e40451630f460869317d6ea6bccfed6 100644 (file)
@@ -3662,10 +3662,14 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
         """general bind param escape unit tests added as a result of
         #8053.
 
-        However, note that the final application of an escaped param name
+        The final application of an escaped param name
         was moved out of compiler and into DefaultExecutionContext in
         related issue #8056.
 
+        However in #8113 we made this conditional to suit usage recipes
+        posted in the FAQ.
+
+
         """
 
         SomeEnum = pep435_enum("SomeEnum")
@@ -3698,14 +3702,33 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
         compiled = t.insert().compile(
             dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data"))
         )
-        params = compiled.construct_params({"_id": 1, "_data": one})
 
+        # not escaped
+        params = compiled.construct_params(
+            {"_id": 1, "_data": one}, escape_names=False
+        )
         eq_(params, {"_id": 1, "_data": one})
+
+        # escaped by default
+        params = compiled.construct_params({"_id": 1, "_data": one})
+        eq_(params, {'"_id"': 1, '"_data"': one})
+
+        # escaped here as well
+        eq_(compiled.params, {'"_data"': None, '"_id"': None})
+
+        # bind processors aren't part of this
         eq_(compiled._bind_processors, {"_data": mock.ANY})
 
-        # previously, this was:
-        # eq_(params, {'"_id"': 1, '"_data"': one})
-        # eq_(compiled._bind_processors, {'"_data"': mock.ANY})
+        dialect.paramstyle = "pyformat"
+        compiled = t.insert().compile(
+            dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data"))
+        )
+
+        # FAQ recipe works
+        eq_(
+            compiled.string % compiled.params,
+            "INSERT INTO t (_id, _data) VALUES (None, None)",
+        )
 
     def test_expanding_non_expanding_conflict(self):
         """test #8018"""