]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
apply bindparam escape name to processors dictionary
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 25 May 2022 12:47:29 +0000 (08:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 25 May 2022 12:53:47 +0000 (08:53 -0400)
Fixed SQL compiler issue where the "bind processing" function for a bound
parameter would not be correctly applied to a bound value if the bound
parameter's name were "escaped". Concretely, this applies, among other
cases, to Oracle when a :class:`.Column` has a name that itself requires
quoting, such that the quoting-required name is then used for the bound
parameters generated within DML statements, and the datatype in use
requires bind processing, such as the :class:`.Enum` datatype.

Fixes: #8053
Change-Id: I39d060a87e240b4ebcfccaa9c535e971b7255d99
(cherry picked from commit 4d58ca05e83048e999059a8c2c2e67cb77abf976)

doc/build/changelog/unreleased_14/8053.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/dialect/oracle/test_dialect.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_14/8053.rst b/doc/build/changelog/unreleased_14/8053.rst
new file mode 100644 (file)
index 0000000..316b638
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, oracle
+    :tickets: 8053
+
+    Fixed SQL compiler issue where the "bind processing" function for a bound
+    parameter would not be correctly applied to a bound value if the bound
+    parameter's name were "escaped". Concretely, this applies, among other
+    cases, to Oracle when a :class:`.Column` has a name that itself requires
+    quoting, such that the quoting-required name is then used for the bound
+    parameters generated within DML statements, and the datatype in use
+    requires bind processing, such as the :class:`.Enum` datatype.
index bc2d657fb51ed2b55bd3634511e9ec2039131eab..fa158863da9dd5d6995f836f61769d7bcf62a9ab 100644 (file)
@@ -898,8 +898,16 @@ class SQLCompiler(Compiled):
 
     @util.memoized_property
     def _bind_processors(self):
+        _escaped_bind_names = self.escaped_bind_names
+        has_escaped_names = bool(_escaped_bind_names)
+
         return dict(
-            (key, value)
+            (
+                _escaped_bind_names.get(key, key)
+                if has_escaped_names
+                else key,
+                value,
+            )
             for key, value in (
                 (
                     self.bind_names[bindparam],
index d65a6d2b53aa110b62471a95b49bf81ebfceddef..f494b59aeff371d76e65406b9feeb603ea2405af 100644 (file)
@@ -5,6 +5,7 @@ import re
 from sqlalchemy import bindparam
 from sqlalchemy import Computed
 from sqlalchemy import create_engine
+from sqlalchemy import Enum
 from sqlalchemy import exc
 from sqlalchemy import Float
 from sqlalchemy import func
@@ -32,6 +33,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.mock import Mock
 from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import pep435_enum
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.suite import test_select
 from sqlalchemy.util import u
@@ -564,6 +566,23 @@ class QuotedBindRoundTripTest(fixtures.TestBase):
             4,
         )
 
+    def test_param_w_processors(self, metadata, connection):
+        """test #8053"""
+
+        SomeEnum = pep435_enum("SomeEnum")
+        one = SomeEnum("one", 1)
+        SomeEnum("two", 2)
+
+        t = Table(
+            "t",
+            metadata,
+            Column("_id", Integer, primary_key=True),
+            Column("_data", Enum(SomeEnum)),
+        )
+        t.create(connection)
+        connection.execute(t.insert(), {"_id": 1, "_data": one})
+        eq_(connection.scalar(select(t.c._data)), one)
+
     def test_numeric_bind_in_crud(self, metadata, connection):
         t = Table("asfd", metadata, Column("100K", Integer))
         t.create(connection)
index 33f84142bc80bde8204e131f983239af8a404267..99addb986d3563fd01da12eb337767a847ff2fa0 100644 (file)
@@ -25,6 +25,7 @@ from sqlalchemy import Column
 from sqlalchemy import Date
 from sqlalchemy import desc
 from sqlalchemy import distinct
+from sqlalchemy import Enum
 from sqlalchemy import exc
 from sqlalchemy import except_
 from sqlalchemy import exists
@@ -96,6 +97,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing.schema import pep435_enum
 from sqlalchemy.util import u
 
 table1 = table(
@@ -3655,6 +3657,46 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
             s,
         )
 
+    def test_bind_param_escaping(self):
+        """general bind param escape unit tests added as a result of
+        #8053
+        #
+        #"""
+
+        SomeEnum = pep435_enum("SomeEnum")
+        one = SomeEnum("one", 1)
+        SomeEnum("two", 2)
+
+        t = Table(
+            "t",
+            MetaData(),
+            Column("_id", Integer, primary_key=True),
+            Column("_data", Enum(SomeEnum)),
+        )
+
+        class MyCompiler(compiler.SQLCompiler):
+            def bindparam_string(self, name, **kw):
+                kw["escaped_from"] = name
+                return super(MyCompiler, self).bindparam_string(
+                    '"%s"' % name, **kw
+                )
+
+        dialect = default.DefaultDialect()
+        dialect.statement_compiler = MyCompiler
+
+        self.assert_compile(
+            t.insert(),
+            'INSERT INTO t (_id, _data) VALUES (:"_id", :"_data")',
+            dialect=dialect,
+        )
+
+        compiled = t.insert().compile(
+            dialect=dialect, compile_kwargs=dict(compile_keys=("_id", "_data"))
+        )
+        params = compiled.construct_params({"_id": 1, "_data": one})
+        eq_(params, {'"_id"': 1, '"_data"': one})
+        eq_(compiled._bind_processors, {'"_data"': mock.ANY})
+
     def test_expanding_non_expanding_conflict(self):
         """test #8018"""