]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
apply bindparam escape name to processors dictionary
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 25 May 2022 12:47:29 +0000 (08:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 25 May 2022 14:07:31 +0000 (10:07 -0400)
Fixed SQL compiler issue where the "bind processing" function for a bound
parameter would not be correctly applied to a bound value if the bound
parameter's name were "escaped". Concretely, this applies, among other
cases, to Oracle when a :class:`.Column` has a name that itself requires
quoting, such that the quoting-required name is then used for the bound
parameters generated within DML statements, and the datatype in use
requires bind processing, such as the :class:`.Enum` datatype.

Fixes: #8053
Change-Id: I39d060a87e240b4ebcfccaa9c535e971b7255d99

doc/build/changelog/unreleased_14/8053.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/dialect/oracle/test_dialect.py
test/ext/mypy/plain_files/sql_operations.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_14/8053.rst b/doc/build/changelog/unreleased_14/8053.rst
new file mode 100644 (file)
index 0000000..316b638
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, oracle
+    :tickets: 8053
+
+    Fixed SQL compiler issue where the "bind processing" function for a bound
+    parameter would not be correctly applied to a bound value if the bound
+    parameter's name were "escaped". Concretely, this applies, among other
+    cases, to Oracle when a :class:`.Column` has a name that itself requires
+    quoting, such that the quoting-required name is then used for the bound
+    parameters generated within DML statements, and the datatype in use
+    requires bind processing, such as the :class:`.Enum` datatype.
index 0eae31a1a4c8ce36926ef90cc5bab293b7347956..63ed45a969e2bd0819728ba62b61d8dc84ebccb5 100644 (file)
@@ -1143,10 +1143,18 @@ class SQLCompiler(Compiled):
         str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]]
     ]:
 
+        _escaped_bind_names = self.escaped_bind_names
+        has_escaped_names = bool(_escaped_bind_names)
+
         # mypy is not able to see the two value types as the above Union,
         # it just sees "object".  don't know how to resolve
         return dict(
-            (key, value)  # type: ignore
+            (
+                _escaped_bind_names.get(key, key)
+                if has_escaped_names
+                else key,
+                value,
+            )  # type: ignore
             for key, value in (
                 (
                     self.bind_names[bindparam],
index 26a29b73e620837428156b43f646e51ac7ae64fc..8d74c1f4890eed07ca8a88e6e445943d48552dc4 100644 (file)
@@ -7,6 +7,7 @@ from unittest.mock import Mock
 from sqlalchemy import bindparam
 from sqlalchemy import Computed
 from sqlalchemy import create_engine
+from sqlalchemy import Enum
 from sqlalchemy import exc
 from sqlalchemy import Float
 from sqlalchemy import func
@@ -33,6 +34,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import pep435_enum
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.suite import test_select
 
@@ -527,6 +529,23 @@ class QuotedBindRoundTripTest(fixtures.TestBase):
             4,
         )
 
+    def test_param_w_processors(self, metadata, connection):
+        """test #8053"""
+
+        SomeEnum = pep435_enum("SomeEnum")
+        one = SomeEnum("one", 1)
+        SomeEnum("two", 2)
+
+        t = Table(
+            "t",
+            metadata,
+            Column("_id", Integer, primary_key=True),
+            Column("_data", Enum(SomeEnum)),
+        )
+        t.create(connection)
+        connection.execute(t.insert(), {"_id": 1, "_data": one})
+        eq_(connection.scalar(select(t.c._data)), one)
+
     def test_numeric_bind_in_crud(self, metadata, connection):
         t = Table("asfd", metadata, Column("100K", Integer))
         t.create(connection)
index b4d0bd0060eeb3e630a9d8d3c6b844f762f0f282..0ed0df661d111f825f8f7522047468b3ab813b8e 100644 (file)
@@ -56,7 +56,7 @@ if typing.TYPE_CHECKING:
     # EXPECTED_RE_TYPE: sqlalchemy..*BinaryExpression\[builtins.bool\]
     reveal_type(expr2)
 
-    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, decimal.Decimal\]\]
+    # EXPECTED_RE_TYPE: sqlalchemy..*ColumnElement\[Union\[builtins.float, .*\.Decimal\]\]
     reveal_type(expr3)
 
     # EXPECTED_RE_TYPE: sqlalchemy..*UnaryExpression\[builtins.int.?\]
index 4e8e2ac139a361d1f5c89052aae338dda9d48a8f..4e40ae0a22f52e527551d639c2be9a2142056df9 100644 (file)
@@ -25,6 +25,7 @@ from sqlalchemy import Column
 from sqlalchemy import Date
 from sqlalchemy import desc
 from sqlalchemy import distinct
+from sqlalchemy import Enum
 from sqlalchemy import exc
 from sqlalchemy import except_
 from sqlalchemy import exists
@@ -96,6 +97,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing.schema import pep435_enum
 
 table1 = table(
     "mytable",
@@ -3745,6 +3747,46 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
             s,
         )
 
+    def test_bind_param_escaping(self):
+        """general bind param escape unit tests added as a result of
+        #8053
+        #
+        #"""
+
+        SomeEnum = pep435_enum("SomeEnum")
+        one = SomeEnum("one", 1)
+        SomeEnum("two", 2)
+
+        t = Table(
+            "t",
+            MetaData(),
+            Column("_id", Integer, primary_key=True),
+            Column("_data", Enum(SomeEnum)),
+        )
+
+        class MyCompiler(compiler.SQLCompiler):
+            def bindparam_string(self, name, **kw):
+                kw["escaped_from"] = name
+                return super(MyCompiler, self).bindparam_string(
+                    '"%s"' % name, **kw
+                )
+
+        dialect = default.DefaultDialect()
+        dialect.statement_compiler = MyCompiler
+
+        self.assert_compile(
+            t.insert(),
+            'INSERT INTO t (_id, _data) VALUES (:"_id", :"_data")',
+            dialect=dialect,
+        )
+
+        compiled = t.insert().compile(
+            dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data"))
+        )
+        params = compiled.construct_params({"_id": 1, "_data": one})
+        eq_(params, {'"_id"': 1, '"_data"': one})
+        eq_(compiled._bind_processors, {'"_data"': mock.ANY})
+
     def test_expanding_non_expanding_conflict(self):
         """test #8018"""