From: Mike Bayer Date: Wed, 25 May 2022 12:47:29 +0000 (-0400) Subject: apply bindparam escape name to processors dictionary X-Git-Tag: rel_1_4_37~9^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=853d6be759ad79d0d3e1d6a52fc7c9c32c0146ec;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git apply bindparam escape name to processors dictionary Fixed SQL compiler issue where the "bind processing" function for a bound parameter would not be correctly applied to a bound value if the bound parameter's name were "escaped". Concretely, this applies, among other cases, to Oracle when a :class:`.Column` has a name that itself requires quoting, such that the quoting-required name is then used for the bound parameters generated within DML statements, and the datatype in use requires bind processing, such as the :class:`.Enum` datatype. Fixes: #8053 Change-Id: I39d060a87e240b4ebcfccaa9c535e971b7255d99 (cherry picked from commit 4d58ca05e83048e999059a8c2c2e67cb77abf976) --- diff --git a/doc/build/changelog/unreleased_14/8053.rst b/doc/build/changelog/unreleased_14/8053.rst new file mode 100644 index 0000000000..316b638594 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8053.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, oracle + :tickets: 8053 + + Fixed SQL compiler issue where the "bind processing" function for a bound + parameter would not be correctly applied to a bound value if the bound + parameter's name were "escaped". Concretely, this applies, among other + cases, to Oracle when a :class:`.Column` has a name that itself requires + quoting, such that the quoting-required name is then used for the bound + parameters generated within DML statements, and the datatype in use + requires bind processing, such as the :class:`.Enum` datatype. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index bc2d657fb5..fa158863da 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -898,8 +898,16 @@ class SQLCompiler(Compiled): @util.memoized_property def _bind_processors(self): + _escaped_bind_names = self.escaped_bind_names + has_escaped_names = bool(_escaped_bind_names) + return dict( - (key, value) + ( + _escaped_bind_names.get(key, key) + if has_escaped_names + else key, + value, + ) for key, value in ( ( self.bind_names[bindparam], diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index d65a6d2b53..f494b59aef 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -5,6 +5,7 @@ import re from sqlalchemy import bindparam from sqlalchemy import Computed from sqlalchemy import create_engine +from sqlalchemy import Enum from sqlalchemy import exc from sqlalchemy import Float from sqlalchemy import func @@ -32,6 +33,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock from sqlalchemy.testing.mock import Mock from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.testing.schema import Table from sqlalchemy.testing.suite import test_select from sqlalchemy.util import u @@ -564,6 +566,23 @@ class QuotedBindRoundTripTest(fixtures.TestBase): 4, ) + def test_param_w_processors(self, metadata, connection): + """test #8053""" + + SomeEnum = pep435_enum("SomeEnum") + one = SomeEnum("one", 1) + SomeEnum("two", 2) + + t = Table( + "t", + metadata, + Column("_id", Integer, primary_key=True), + Column("_data", Enum(SomeEnum)), + ) + t.create(connection) + connection.execute(t.insert(), {"_id": 1, "_data": one}) + eq_(connection.scalar(select(t.c._data)), one) + def test_numeric_bind_in_crud(self, metadata, connection): t = Table("asfd", metadata, Column("100K", Integer)) t.create(connection) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 33f84142bc..99addb986d 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -25,6 +25,7 @@ from sqlalchemy import Column from sqlalchemy import Date from sqlalchemy import desc from sqlalchemy import distinct +from sqlalchemy import Enum from sqlalchemy import exc from sqlalchemy import except_ from sqlalchemy import exists @@ -96,6 +97,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing import ne_ +from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.util import u table1 = table( @@ -3655,6 +3657,46 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): s, ) + def test_bind_param_escaping(self): + """general bind param escape unit tests added as a result of + #8053 + # + #""" + + SomeEnum = pep435_enum("SomeEnum") + one = SomeEnum("one", 1) + SomeEnum("two", 2) + + t = Table( + "t", + MetaData(), + Column("_id", Integer, primary_key=True), + Column("_data", Enum(SomeEnum)), + ) + + class MyCompiler(compiler.SQLCompiler): + def bindparam_string(self, name, **kw): + kw["escaped_from"] = name + return super(MyCompiler, self).bindparam_string( + '"%s"' % name, **kw + ) + + dialect = default.DefaultDialect() + dialect.statement_compiler = MyCompiler + + self.assert_compile( + t.insert(), + 'INSERT INTO t (_id, _data) VALUES (:"_id", :"_data")', + dialect=dialect, + ) + + compiled = t.insert().compile( + dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data")) + ) + params = compiled.construct_params({"_id": 1, "_data": one}) + eq_(params, {'"_id"': 1, '"_data"': one}) + eq_(compiled._bind_processors, {'"_data"': mock.ANY}) + def test_expanding_non_expanding_conflict(self): """test #8018"""