From: Mike Bayer Date: Wed, 25 May 2022 12:47:29 +0000 (-0400) Subject: apply bindparam escape name to processors dictionary X-Git-Tag: rel_2_0_0b1~289^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a5d481eaa5bff958692fc3b0024f0b9b1c4f56c6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git apply bindparam escape name to processors dictionary Fixed SQL compiler issue where the "bind processing" function for a bound parameter would not be correctly applied to a bound value if the bound parameter's name were "escaped". Concretely, this applies, among other cases, to Oracle when a :class:`.Column` has a name that itself requires quoting, such that the quoting-required name is then used for the bound parameters generated within DML statements, and the datatype in use requires bind processing, such as the :class:`.Enum` datatype. Fixes: #8053 Change-Id: I39d060a87e240b4ebcfccaa9c535e971b7255d99 --- diff --git a/doc/build/changelog/unreleased_14/8053.rst b/doc/build/changelog/unreleased_14/8053.rst new file mode 100644 index 0000000000..316b638594 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8053.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, oracle + :tickets: 8053 + + Fixed SQL compiler issue where the "bind processing" function for a bound + parameter would not be correctly applied to a bound value if the bound + parameter's name were "escaped". Concretely, this applies, among other + cases, to Oracle when a :class:`.Column` has a name that itself requires + quoting, such that the quoting-required name is then used for the bound + parameters generated within DML statements, and the datatype in use + requires bind processing, such as the :class:`.Enum` datatype. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 0eae31a1a4..63ed45a969 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1143,10 +1143,18 @@ class SQLCompiler(Compiled): str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]] ]: + _escaped_bind_names = self.escaped_bind_names + has_escaped_names = bool(_escaped_bind_names) + # mypy is not able to see the two value types as the above Union, # it just sees "object". don't know how to resolve return dict( - (key, value) # type: ignore + ( + _escaped_bind_names.get(key, key) + if has_escaped_names + else key, + value, + ) # type: ignore for key, value in ( ( self.bind_names[bindparam], diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index 26a29b73e6..8d74c1f489 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -7,6 +7,7 @@ from unittest.mock import Mock from sqlalchemy import bindparam from sqlalchemy import Computed from sqlalchemy import create_engine +from sqlalchemy import Enum from sqlalchemy import exc from sqlalchemy import Float from sqlalchemy import func @@ -33,6 +34,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.testing.schema import Table from sqlalchemy.testing.suite import test_select @@ -527,6 +529,23 @@ class QuotedBindRoundTripTest(fixtures.TestBase): 4, ) + def test_param_w_processors(self, metadata, connection): + """test #8053""" + + SomeEnum = pep435_enum("SomeEnum") + one = SomeEnum("one", 1) + SomeEnum("two", 2) + + t = Table( + "t", + metadata, + Column("_id", Integer, primary_key=True), + Column("_data", Enum(SomeEnum)), + ) + t.create(connection) + connection.execute(t.insert(), {"_id": 1, "_data": one}) + eq_(connection.scalar(select(t.c._data)), one) + def test_numeric_bind_in_crud(self, metadata, connection): t = Table("asfd", metadata, Column("100K", Integer)) t.create(connection) diff --git a/test/ext/mypy/plain_files/sql_operations.py b/test/ext/mypy/plain_files/sql_operations.py index b4d0bd0060..0ed0df661d 100644 --- a/test/ext/mypy/plain_files/sql_operations.py +++ b/test/ext/mypy/plain_files/sql_operations.py @@ -56,7 +56,7 @@ if typing.TYPE_CHECKING: # EXPECTED_RE_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\] reveal_type(expr2) - # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, decimal.Decimal\]\] + # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, .*\.Decimal\]\] reveal_type(expr3) # EXPECTED_RE_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\] diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 4e8e2ac139..4e40ae0a22 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -25,6 +25,7 @@ from sqlalchemy import Column from sqlalchemy import Date from sqlalchemy import desc from sqlalchemy import distinct +from sqlalchemy import Enum from sqlalchemy import exc from sqlalchemy import except_ from sqlalchemy import exists @@ -96,6 +97,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing import ne_ +from sqlalchemy.testing.schema import pep435_enum table1 = table( "mytable", @@ -3745,6 +3747,46 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): s, ) + def test_bind_param_escaping(self): + """general bind param escape unit tests added as a result of + #8053 + # + #""" + + SomeEnum = pep435_enum("SomeEnum") + one = SomeEnum("one", 1) + SomeEnum("two", 2) + + t = Table( + "t", + MetaData(), + Column("_id", Integer, primary_key=True), + Column("_data", Enum(SomeEnum)), + ) + + class MyCompiler(compiler.SQLCompiler): + def bindparam_string(self, name, **kw): + kw["escaped_from"] = name + return super(MyCompiler, self).bindparam_string( + '"%s"' % name, **kw + ) + + dialect = default.DefaultDialect() + dialect.statement_compiler = MyCompiler + + self.assert_compile( + t.insert(), + 'INSERT INTO t (_id, _data) VALUES (:"_id", :"_data")', + dialect=dialect, + ) + + compiled = t.insert().compile( + dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data")) + ) + params = compiled.construct_params({"_id": 1, "_data": one}) + eq_(params, {'"_id"': 1, '"_data"': one}) + eq_(compiled._bind_processors, {'"_data"': mock.ANY}) + def test_expanding_non_expanding_conflict(self): """test #8018"""