From cac8de9ab2d5fe04954947d96b78ee34522f3a2a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 29 May 2022 12:07:46 -0400 Subject: [PATCH] move bindparam quote application from compiler to default in 296c84313ab29bf9599634f3 for #5653 we generalized Oracle's parameter escaping feature into the compiler, so that it could also work for PostgreSQL. The compiler used quoted names within parameter dictionaries, which then led to the complexity that all functions which interpreted keys from the compiled_params dict had to also quote the param names to use the dictionary. This extra complexity was not added to the ORM persistence.py however, which led to the versioning id feature being broken as well as other areas where persistence.py relies on naming schemes present in context.compiled_params. It also was not added to the "processors" lookup which led to #8053, that added this escaping to that part of the compiler. To both solve the whole problem as well as simplify the compiler quite a bit, move the actual application of the escaped names to be as late as possible, when default.py builds the final list of parameters. This is more similar to how it worked previously where OracleExecutionContext would be late-applying these escaped names. This re-establishes context.compiled_params as deterministically named regardless of dialect in use and moves out the complexity of the quoted param names to be only at the cursor.execute stage. 
Fixed bug, likely a regression from 1.3, where usage of column names that require bound parameter escaping, more concretely when using Oracle with column names that require quoting such as those that start with an underscore, or in less common cases with some PostgreSQL drivers when using column names that contain percent signs, would cause the ORM versioning feature to not work correctly if the versioning column itself had such a name, as the ORM assumes certain bound parameter naming conventions that were being interfered with via the quotes. This issue is related to :ticket:`8053` and essentially revises the approach towards fixing this, revising the original issue :ticket:`5653` that created the initial implementation for generalized bound-parameter name quoting. Fixes: #8056 Change-Id: I57b064e8f0d070e328b65789c30076f6a0ca0fef --- doc/build/changelog/unreleased_14/8056.rst | 15 +++++++ lib/sqlalchemy/engine/default.py | 31 ++++++++++--- lib/sqlalchemy/sql/compiler.py | 51 +++++++-------------- lib/sqlalchemy/testing/fixtures.py | 4 ++ test/orm/test_versioning.py | 52 ++++++++++++++++++++++ test/sql/test_compiler.py | 19 +++++--- test/sql/test_external_traversal.py | 2 + 7 files changed, 126 insertions(+), 48 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/8056.rst diff --git a/doc/build/changelog/unreleased_14/8056.rst b/doc/build/changelog/unreleased_14/8056.rst new file mode 100644 index 0000000000..a5a61fa321 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8056.rst @@ -0,0 +1,15 @@ +.. 
change:: + :tags: bug, orm, oracle, postgresql + :tickets: 8056 + + Fixed bug, likely a regression from 1.3, where usage of column names that + require bound parameter escaping, more concretely when using Oracle with + column names that require quoting such as those that start with an + underscore, or in less common cases with some PostgreSQL drivers when using + column names that contain percent signs, would cause the ORM versioning + feature to not work correctly if the versioning column itself had such a + name, as the ORM assumes certain bound parameter naming conventions that + were being interfered with via the quotes. This issue is related to + :ticket:`8053` and essentially revises the approach towards fixing this, + revising the original issue :ticket:`5653` that created the initial + implementation for generalized bound-parameter name quoting. diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index fc114efa3a..c188e155c7 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1001,14 +1001,31 @@ class DefaultExecutionContext(ExecutionContext): self.parameters = core_positional_parameters else: core_dict_parameters: MutableSequence[Dict[str, Any]] = [] + escaped_names = compiled.escaped_bind_names + + # note that currently, "expanded" parameters will be present + # in self.compiled_parameters in their quoted form. This is + # slightly inconsistent with the approach taken as of + # #8056 where self.compiled_parameters is meant to contain unquoted + # param names. 
+ d_param: Dict[str, Any] for compiled_params in self.compiled_parameters: - - d_param: Dict[str, Any] = { - key: flattened_processors[key](compiled_params[key]) - if key in flattened_processors - else compiled_params[key] - for key in compiled_params - } + if escaped_names: + d_param = { + escaped_names.get(key, key): flattened_processors[key]( + compiled_params[key] + ) + if key in flattened_processors + else compiled_params[key] + for key in compiled_params + } + else: + d_param = { + key: flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] + for key in compiled_params + } core_dict_parameters.append(d_param) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 63ed45a969..12a5987177 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1143,16 +1143,11 @@ class SQLCompiler(Compiled): str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]] ]: - _escaped_bind_names = self.escaped_bind_names - has_escaped_names = bool(_escaped_bind_names) - # mypy is not able to see the two value types as the above Union, # it just sees "object". don't know how to resolve return dict( ( - _escaped_bind_names.get(key, key) - if has_escaped_names - else key, + key, value, ) # type: ignore for key, value in ( @@ -1186,8 +1181,6 @@ class SQLCompiler(Compiled): ) -> _MutableCoreSingleExecuteParams: """return a dictionary of bind parameter keys and values""" - has_escaped_names = bool(self.escaped_bind_names) - if extracted_parameters: # related the bound parameters collected in the original cache key # to those collected in the incoming cache key. 
They will not have @@ -1217,16 +1210,10 @@ class SQLCompiler(Compiled): if params: pd = {} for bindparam, name in self.bind_names.items(): - escaped_name = ( - self.escaped_bind_names.get(name, name) - if has_escaped_names - else name - ) - if bindparam.key in params: - pd[escaped_name] = params[bindparam.key] + pd[name] = params[bindparam.key] elif name in params: - pd[escaped_name] = params[name] + pd[name] = params[name] elif _check and bindparam.required: if _group_number: @@ -1251,19 +1238,13 @@ class SQLCompiler(Compiled): value_param = bindparam if bindparam.callable: - pd[escaped_name] = value_param.effective_value + pd[name] = value_param.effective_value else: - pd[escaped_name] = value_param.value + pd[name] = value_param.value return pd else: pd = {} for bindparam, name in self.bind_names.items(): - escaped_name = ( - self.escaped_bind_names.get(name, name) - if has_escaped_names - else name - ) - if _check and bindparam.required: if _group_number: raise exc.InvalidRequestError( @@ -1285,9 +1266,9 @@ class SQLCompiler(Compiled): value_param = bindparam if bindparam.callable: - pd[escaped_name] = value_param.effective_value + pd[name] = value_param.effective_value else: - pd[escaped_name] = value_param.value + pd[name] = value_param.value return pd @util.memoized_instancemethod @@ -1359,6 +1340,7 @@ class SQLCompiler(Compiled): N as a bound parameter. """ + if parameters is None: parameters = self.construct_params() @@ -1435,7 +1417,12 @@ class SQLCompiler(Compiled): # process it. the single name is being replaced with # individual numbered parameters for each value in the # param. - values = parameters.pop(escaped_name) + # + # note we are also inserting *escaped* parameter names + # into the given dictionary. default dialect will + # use these param names directly as they will not be + # in the escaped_bind_names dictionary. 
+ values = parameters.pop(name) leep = self._literal_execute_expanding_parameter to_update, replacement_expr = leep( @@ -1541,15 +1528,7 @@ class SQLCompiler(Compiled): @util.memoized_property def _within_exec_param_key_getter(self) -> Callable[[Any], str]: getter = self._get_bind_name_for_col - if self.escaped_bind_names: - - def _get(obj): - key = getter(obj) - return self.escaped_bind_names.get(key, key) - - return _get - else: - return getter + return getter @util.memoized_property @util.preload_module("sqlalchemy.engine.result") diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 4b53661860..ae7a42488e 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -178,6 +178,10 @@ class TestBase: return go + @config.fixture + def fixture_session(self): + return fixture_session() + @config.fixture() def metadata(self, request): """Provide bound MetaData for a single test, dropping afterwards.""" diff --git a/test/orm/test_versioning.py b/test/orm/test_versioning.py index 9f23073d41..4898cb1228 100644 --- a/test/orm/test_versioning.py +++ b/test/orm/test_versioning.py @@ -2006,3 +2006,55 @@ class VersioningMappedSelectTest(fixtures.MappedTest): f1.value = "f2" f1.version_id = 2 s1.flush() + + +class QuotedBindVersioningTest(fixtures.MappedTest): + """test for #8056""" + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "version_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + # will need parameter quoting for Oracle and PostgreSQL + # dont use 'key' to make sure no the awkward name is definitely + # in the params + Column("_version%id", Integer, nullable=False), + Column("value", String(40), nullable=False), + ) + + @classmethod + def setup_classes(cls): + class Foo(cls.Basic): + pass + + @classmethod + def setup_mappers(cls): + Foo = cls.classes.Foo + vt = cls.tables.version_table + 
cls.mapper_registry.map_imperatively( + Foo, + vt, + version_id_col=vt.c["_version%id"], + properties={"version": vt.c["_version%id"]}, + ) + + def test_round_trip(self, fixture_session): + Foo = self.classes.Foo + + f1 = Foo(value="v1") + fixture_session.add(f1) + fixture_session.commit() + + f1.value = "v2" + with conditional_sane_rowcount_warnings( + update=True, only_returning=True + ): + fixture_session.commit() + + eq_(f1.version, 2) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 4e40ae0a22..6ad2aa2c14 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -3749,9 +3749,13 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): def test_bind_param_escaping(self): """general bind param escape unit tests added as a result of - #8053 - # - #""" + #8053. + + However, note that the final application of an escaped param name + was moved out of compiler and into DefaultExecutionContext in + related issue #8056. + + """ SomeEnum = pep435_enum("SomeEnum") one = SomeEnum("one", 1) @@ -3784,8 +3788,13 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data")) ) params = compiled.construct_params({"_id": 1, "_data": one}) - eq_(params, {'"_id"': 1, '"_data"': one}) - eq_(compiled._bind_processors, {'"_data"': mock.ANY}) + + eq_(params, {"_id": 1, "_data": one}) + eq_(compiled._bind_processors, {"_data": mock.ANY}) + + # previously, this was: + # eq_(params, {'"_id"': 1, '"_data"': one}) + # eq_(compiled._bind_processors, {'"_data"': mock.ANY}) def test_expanding_non_expanding_conflict(self): """test #8018""" diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 13116225cb..5e46808b3a 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -197,6 +197,8 @@ class TraversalTest( def test_bindparam_key_proc_for_copies(self, meth, name): r"""test :ticket:`6249`. 
+ Revised for :ticket:`8056`. + The key of the bindparam needs spaces and other characters escaped out for the POSTCOMPILE regex to work correctly. -- 2.47.2