]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
restore parameter escaping for public methods
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 9 Jun 2022 01:35:02 +0000 (21:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 9 Jun 2022 14:14:13 +0000 (10:14 -0400)
Adjusted the fix made for :ticket:`8056` which adjusted the escaping of
bound parameter names with special characters such that the escaped names
were translated after the SQL compilation step, which broke a published
recipe on the FAQ illustrating how to merge parameter names into the string
output of a compiled SQL string. The change restores the escaped names that
come from ``compiled.params`` and adds a conditional parameter to
:meth:`.SQLCompiler.construct_params` named ``escape_names`` that defaults
to ``True``, restoring the old behavior by default.

Fixes: #8113
Change-Id: I9cbedb1080bc06d51f287fd2cbf26aaab1c74653

doc/build/changelog/unreleased_14/8113.rst [new file with mode: 0644]
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_14/8113.rst b/doc/build/changelog/unreleased_14/8113.rst
new file mode 100644 (file)
index 0000000..100f9a7
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8113
+
+    Adjusted the fix made for :ticket:`8056` which adjusted the escaping of
+    bound parameter names with special characters such that the escaped names
+    were translated after the SQL compilation step, which broke a published
+    recipe on the FAQ illustrating how to merge parameter names into the string
+    output of a compiled SQL string. The change restores the escaped names that
+    come from ``compiled.params`` and adds a conditional parameter to
+    :meth:`.SQLCompiler.construct_params` named ``escape_names`` that defaults
+    to ``True``, restoring the old behavior by default.
index 6b76601ffe9dcf325e110bfd714b9def1ed42c6f..df35e7128c3d4a1b2ab9c528e3892d6a90df1cac 100644 (file)
@@ -943,13 +943,15 @@ class DefaultExecutionContext(ExecutionContext):
         if not parameters:
             self.compiled_parameters = [
                 compiled.construct_params(
-                    extracted_parameters=extracted_parameters
+                    extracted_parameters=extracted_parameters,
+                    escape_names=False,
                 )
             ]
         else:
             self.compiled_parameters = [
                 compiled.construct_params(
                     m,
+                    escape_names=False,
                     _group_number=grp,
                     extracted_parameters=extracted_parameters,
                 )
index 78c6af38bad8bf46364c24c5e4735f7c5989a2fe..8ce0c65e42ce865ad72ad06c9ac1367c2be9ea42 100644 (file)
@@ -633,6 +633,7 @@ class Compiled:
         self,
         params: Optional[_CoreSingleExecuteParams] = None,
         extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+        escape_names: bool = True,
     ) -> Optional[_MutableCoreSingleExecuteParams]:
         """Return the bind params for this compiled object.
 
@@ -1176,11 +1177,14 @@ class SQLCompiler(Compiled):
         self,
         params: Optional[_CoreSingleExecuteParams] = None,
         extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+        escape_names: bool = True,
         _group_number: Optional[int] = None,
         _check: bool = True,
     ) -> _MutableCoreSingleExecuteParams:
         """return a dictionary of bind parameter keys and values"""
 
+        has_escaped_names = escape_names and bool(self.escaped_bind_names)
+
         if extracted_parameters:
             # related the bound parameters collected in the original cache key
             # to those collected in the incoming cache key.  They will not have
@@ -1210,10 +1214,16 @@ class SQLCompiler(Compiled):
         if params:
             pd = {}
             for bindparam, name in self.bind_names.items():
+                escaped_name = (
+                    self.escaped_bind_names.get(name, name)
+                    if has_escaped_names
+                    else name
+                )
+
                 if bindparam.key in params:
-                    pd[name] = params[bindparam.key]
+                    pd[escaped_name] = params[bindparam.key]
                 elif name in params:
-                    pd[name] = params[name]
+                    pd[escaped_name] = params[name]
 
                 elif _check and bindparam.required:
                     if _group_number:
@@ -1238,13 +1248,19 @@ class SQLCompiler(Compiled):
                         value_param = bindparam
 
                     if bindparam.callable:
-                        pd[name] = value_param.effective_value
+                        pd[escaped_name] = value_param.effective_value
                     else:
-                        pd[name] = value_param.value
+                        pd[escaped_name] = value_param.value
             return pd
         else:
             pd = {}
             for bindparam, name in self.bind_names.items():
+                escaped_name = (
+                    self.escaped_bind_names.get(name, name)
+                    if has_escaped_names
+                    else name
+                )
+
                 if _check and bindparam.required:
                     if _group_number:
                         raise exc.InvalidRequestError(
@@ -1266,9 +1282,9 @@ class SQLCompiler(Compiled):
                     value_param = bindparam
 
                 if bindparam.callable:
-                    pd[name] = value_param.effective_value
+                    pd[escaped_name] = value_param.effective_value
                 else:
-                    pd[name] = value_param.value
+                    pd[escaped_name] = value_param.value
             return pd
 
     @util.memoized_instancemethod
@@ -1342,7 +1358,7 @@ class SQLCompiler(Compiled):
         """
 
         if parameters is None:
-            parameters = self.construct_params()
+            parameters = self.construct_params(escape_names=False)
 
         expanded_parameters = {}
         positiontup: Optional[List[str]]
@@ -4895,6 +4911,7 @@ class DDLCompiler(Compiled):
         self,
         params: Optional[_CoreSingleExecuteParams] = None,
         extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
+        escape_names: bool = True,
     ) -> Optional[_MutableCoreSingleExecuteParams]:
         return None
 
index 94c38548f7a69798ce35ea065fedc98c442f74f2..930f32b7bf43cf8aff6d2b3d3655f2886f0e8342 100644 (file)
@@ -3752,10 +3752,14 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
         """general bind param escape unit tests added as a result of
         #8053.
 
-        However, note that the final application of an escaped param name
+        The final application of an escaped param name
         was moved out of compiler and into DefaultExecutionContext in
         related issue #8056.
 
+        However in #8113 we made this conditional to suit usage recipes
+        posted in the FAQ.
+
+
         """
 
         SomeEnum = pep435_enum("SomeEnum")
@@ -3788,14 +3792,33 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
         compiled = t.insert().compile(
             dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data"))
         )
-        params = compiled.construct_params({"_id": 1, "_data": one})
 
+        # not escaped
+        params = compiled.construct_params(
+            {"_id": 1, "_data": one}, escape_names=False
+        )
         eq_(params, {"_id": 1, "_data": one})
+
+        # escaped by default
+        params = compiled.construct_params({"_id": 1, "_data": one})
+        eq_(params, {'"_id"': 1, '"_data"': one})
+
+        # escaped here as well
+        eq_(compiled.params, {'"_data"': None, '"_id"': None})
+
+        # bind processors aren't part of this
         eq_(compiled._bind_processors, {"_data": mock.ANY})
 
-        # previously, this was:
-        # eq_(params, {'"_id"': 1, '"_data"': one})
-        # eq_(compiled._bind_processors, {'"_data"': mock.ANY})
+        dialect.paramstyle = "pyformat"
+        compiled = t.insert().compile(
+            dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data"))
+        )
+
+        # FAQ recipe works
+        eq_(
+            compiled.string % compiled.params,
+            "INSERT INTO t (_id, _data) VALUES (None, None)",
+        )
 
     def test_expanding_non_expanding_conflict(self):
         """test #8018"""