]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
move bindparam quote application from compiler to default
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 29 May 2022 16:07:46 +0000 (12:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 29 May 2022 18:34:25 +0000 (14:34 -0400)
in 296c84313ab29bf9599634f3 for #5653 we generalized Oracle's
parameter escaping feature into the compiler, so that it could also
work for PostgreSQL.  The compiler used quoted names within parameter
dictionaries, which then led to the complexity that all functions
which interpreted keys from the compiled_params dict had to
also quote the param names to use the dictionary.  This
extra complexity was not added to the ORM peristence.py however,
which led to the versioning id feature being broken as well as
other areas where persistence.py relies on naming schemes present
in context.compiled_params.  It also was not added to the
"processors" lookup which led to #8053, that added this escaping
to that part of the compiler.

To both solve the whole problem as well as simplify the compiler
quite a bit, move the actual application of the escaped names
to be as late as possible, when default.py builds the final list
of parameters.  This is more similar to how it worked previously
where OracleExecutionContext would be late-applying these
escaped names.   This re-establishes context.compiled_params as
deterministically named regardless of dialect in use and moves
out the complexity of the quoted param names to be only at the
cursor.execute stage.

Fixed bug, likely a regression from 1.3, where usage of column names that
require bound parameter escaping, more concretely when using Oracle with
column names that require quoting such as those that start with an
underscore, or in less common cases with some PostgreSQL drivers when using
column names that contain percent signs, would cause the ORM versioning
feature to not work correctly if the versioning column itself had such a
name, as the ORM assumes certain bound parameter naming conventions that
were being interfered with via the quotes. This issue is related to
:ticket:`8053` and essentially revises the approach towards fixing this,
revising the original issue :ticket:`5653` that created the initial
implementation for generalized bound-parameter name quoting.

Fixes: #8056
Change-Id: I57b064e8f0d070e328b65789c30076f6a0ca0fef
(cherry picked from commit a48b597d0cafa1dd7fc46be99eb808fd4cb0a347)

doc/build/changelog/unreleased_14/8056.rst [new file with mode: 0644]
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/fixtures.py
test/orm/test_versioning.py
test/sql/test_compiler.py
test/sql/test_external_traversal.py

diff --git a/doc/build/changelog/unreleased_14/8056.rst b/doc/build/changelog/unreleased_14/8056.rst
new file mode 100644 (file)
index 0000000..a5a61fa
--- /dev/null
@@ -0,0 +1,15 @@
+.. change::
+    :tags: bug, orm, oracle, postgresql
+    :tickets: 8056
+
+    Fixed bug, likely a regression from 1.3, where usage of column names that
+    require bound parameter escaping, more concretely when using Oracle with
+    column names that require quoting such as those that start with an
+    underscore, or in less common cases with some PostgreSQL drivers when using
+    column names that contain percent signs, would cause the ORM versioning
+    feature to not work correctly if the versioning column itself had such a
+    name, as the ORM assumes certain bound parameter naming conventions that
+    were being interfered with via the quotes. This issue is related to
+    :ticket:`8053` and essentially revises the approach towards fixing this,
+    revising the original issue :ticket:`5653` that created the initial
+    implementation for generalized bound-parameter name quoting.
index 5a1443ecbc1991b7701aeb20e8086913a977e927..cc0844e1c3fa95cd850be87eb2663cef954386c3 100644 (file)
@@ -1079,21 +1079,44 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             if encode:
                 encoder = dialect._encoder
             for compiled_params in self.compiled_parameters:
+                escaped_bind_names = compiled.escaped_bind_names
 
                 if encode:
-                    param = {
-                        encoder(key)[0]: processors[key](compiled_params[key])
-                        if key in processors
-                        else compiled_params[key]
-                        for key in compiled_params
-                    }
+                    if escaped_bind_names:
+                        param = {
+                            encoder(escaped_bind_names.get(key, key))[
+                                0
+                            ]: processors[key](compiled_params[key])
+                            if key in processors
+                            else compiled_params[key]
+                            for key in compiled_params
+                        }
+                    else:
+                        param = {
+                            encoder(key)[0]: processors[key](
+                                compiled_params[key]
+                            )
+                            if key in processors
+                            else compiled_params[key]
+                            for key in compiled_params
+                        }
                 else:
-                    param = {
-                        key: processors[key](compiled_params[key])
-                        if key in processors
-                        else compiled_params[key]
-                        for key in compiled_params
-                    }
+                    if escaped_bind_names:
+                        param = {
+                            escaped_bind_names.get(key, key): processors[key](
+                                compiled_params[key]
+                            )
+                            if key in processors
+                            else compiled_params[key]
+                            for key in compiled_params
+                        }
+                    else:
+                        param = {
+                            key: processors[key](compiled_params[key])
+                            if key in processors
+                            else compiled_params[key]
+                            for key in compiled_params
+                        }
 
                 parameters.append(param)
 
index fa158863da9dd5d6995f836f61769d7bcf62a9ab..2f3033d7058880850efeaa97300588c10236b745 100644 (file)
@@ -898,14 +898,10 @@ class SQLCompiler(Compiled):
 
     @util.memoized_property
     def _bind_processors(self):
-        _escaped_bind_names = self.escaped_bind_names
-        has_escaped_names = bool(_escaped_bind_names)
 
         return dict(
             (
-                _escaped_bind_names.get(key, key)
-                if has_escaped_names
-                else key,
+                key,
                 value,
             )
             for key, value in (
@@ -939,8 +935,6 @@ class SQLCompiler(Compiled):
     ):
         """return a dictionary of bind parameter keys and values"""
 
-        has_escaped_names = bool(self.escaped_bind_names)
-
         if extracted_parameters:
             # related the bound parameters collected in the original cache key
             # to those collected in the incoming cache key.  They will not have
@@ -971,16 +965,10 @@ class SQLCompiler(Compiled):
         if params:
             pd = {}
             for bindparam, name in self.bind_names.items():
-                escaped_name = (
-                    self.escaped_bind_names.get(name, name)
-                    if has_escaped_names
-                    else name
-                )
-
                 if bindparam.key in params:
-                    pd[escaped_name] = params[bindparam.key]
+                    pd[name] = params[bindparam.key]
                 elif name in params:
-                    pd[escaped_name] = params[name]
+                    pd[name] = params[name]
 
                 elif _check and bindparam.required:
                     if _group_number:
@@ -1005,19 +993,13 @@ class SQLCompiler(Compiled):
                         value_param = bindparam
 
                     if bindparam.callable:
-                        pd[escaped_name] = value_param.effective_value
+                        pd[name] = value_param.effective_value
                     else:
-                        pd[escaped_name] = value_param.value
+                        pd[name] = value_param.value
             return pd
         else:
             pd = {}
             for bindparam, name in self.bind_names.items():
-                escaped_name = (
-                    self.escaped_bind_names.get(name, name)
-                    if has_escaped_names
-                    else name
-                )
-
                 if _check and bindparam.required:
                     if _group_number:
                         raise exc.InvalidRequestError(
@@ -1039,9 +1021,9 @@ class SQLCompiler(Compiled):
                     value_param = bindparam
 
                 if bindparam.callable:
-                    pd[escaped_name] = value_param.effective_value
+                    pd[name] = value_param.effective_value
                 else:
-                    pd[escaped_name] = value_param.value
+                    pd[name] = value_param.value
             return pd
 
     @util.memoized_instancemethod
@@ -1139,6 +1121,7 @@ class SQLCompiler(Compiled):
           N as a bound parameter.
 
         """
+
         if parameters is None:
             parameters = self.construct_params()
 
@@ -1181,10 +1164,11 @@ class SQLCompiler(Compiled):
                 if self.escaped_bind_names
                 else name
             )
+
             parameter = self.binds[name]
             if parameter in self.literal_execute_params:
                 if escaped_name not in replacement_expressions:
-                    value = parameters.pop(escaped_name)
+                    value = parameters.pop(name)
 
                 replacement_expressions[
                     escaped_name
@@ -1203,7 +1187,12 @@ class SQLCompiler(Compiled):
                     # process it. the single name is being replaced with
                     # individual numbered parameters for each value in the
                     # param.
-                    values = parameters.pop(escaped_name)
+                    #
+                    # note we are also inserting *escaped* parameter names
+                    # into the given dictionary.   default dialect will
+                    # use these param names directly as they will not be
+                    # in the escaped_bind_names dictionary.
+                    values = parameters.pop(name)
 
                     leep = self._literal_execute_expanding_parameter
                     to_update, replacement_expr = leep(
@@ -1301,15 +1290,7 @@ class SQLCompiler(Compiled):
     @util.memoized_property
     def _within_exec_param_key_getter(self):
         getter = self._key_getters_for_crud_column[2]
-        if self.escaped_bind_names:
-
-            def _get(obj):
-                key = getter(obj)
-                return self.escaped_bind_names.get(key, key)
-
-            return _get
-        else:
-            return getter
+        return getter
 
     @util.memoized_property
     @util.preload_module("sqlalchemy.engine.result")
index 8c2e9d8de6caadbc66bf778286fb7f6a58f54737..3f8f749bfe85ae6e45dcd92983bd36cdca3bdfdc 100644 (file)
@@ -138,6 +138,10 @@ class TestBase(object):
 
         return go
 
+    @config.fixture
+    def fixture_session(self):
+        return fixture_session()
+
     @config.fixture()
     def metadata(self, request):
         """Provide bound MetaData for a single test, dropping afterwards."""
index ce01cace7f921bc832d2d292502ef8e8fcbc181f..3073012241020c49444d49a8cca0abee4b8a5a9f 100644 (file)
@@ -2006,3 +2006,55 @@ class VersioningMappedSelectTest(fixtures.MappedTest):
             f1.value = "f2"
             f1.version_id = 2
             s1.flush()
+
+
+class QuotedBindVersioningTest(fixtures.MappedTest):
+    """test for #8056"""
+
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "version_table",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            # will need parameter quoting for Oracle and PostgreSQL
+            # dont use 'key' to make sure no the awkward name is definitely
+            # in the params
+            Column("_version%id", Integer, nullable=False),
+            Column("value", String(40), nullable=False),
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class Foo(cls.Basic):
+            pass
+
+    @classmethod
+    def setup_mappers(cls):
+        Foo = cls.classes.Foo
+        vt = cls.tables.version_table
+        cls.mapper_registry.map_imperatively(
+            Foo,
+            vt,
+            version_id_col=vt.c["_version%id"],
+            properties={"version": vt.c["_version%id"]},
+        )
+
+    def test_round_trip(self, fixture_session):
+        Foo = self.classes.Foo
+
+        f1 = Foo(value="v1")
+        fixture_session.add(f1)
+        fixture_session.commit()
+
+        f1.value = "v2"
+        with conditional_sane_rowcount_warnings(
+            update=True, only_returning=True
+        ):
+            fixture_session.commit()
+
+        eq_(f1.version, 2)
index 99addb986d3563fd01da12eb337767a847ff2fa0..250e8b30cf99895489d4a8075a2644d2726020e0 100644 (file)
@@ -3659,9 +3659,13 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
 
     def test_bind_param_escaping(self):
         """general bind param escape unit tests added as a result of
-        #8053
-        #
-        #"""
+        #8053.
+
+        However, note that the final application of an escaped param name
+        was moved out of compiler and into DefaultExecutionContext in
+        related issue #8056.
+
+        """
 
         SomeEnum = pep435_enum("SomeEnum")
         one = SomeEnum("one", 1)
@@ -3694,8 +3698,13 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
             dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data"))
         )
         params = compiled.construct_params({"_id": 1, "_data": one})
-        eq_(params, {'"_id"': 1, '"_data"': one})
-        eq_(compiled._bind_processors, {'"_data"': mock.ANY})
+
+        eq_(params, {"_id": 1, "_data": one})
+        eq_(compiled._bind_processors, {"_data": mock.ANY})
+
+        # previously, this was:
+        # eq_(params, {'"_id"': 1, '"_data"': one})
+        # eq_(compiled._bind_processors, {'"_data"': mock.ANY})
 
     def test_expanding_non_expanding_conflict(self):
         """test #8018"""
index c14b8b4c68bd413e76da5d19bab3ae090e5aa524..1695771486a16f9f3010396e0427dfb705487554 100644 (file)
@@ -196,6 +196,8 @@ class TraversalTest(
     def test_bindparam_key_proc_for_copies(self, meth, name):
         r"""test :ticket:`6249`.
 
+        Revised for :ticket:`8056`.
+
         The key of the bindparam needs spaces and other characters
         escaped out for the POSTCOMPILE regex to work correctly.