]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Accommodate escaped_bind_names for defaults/insert params
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 8 Feb 2022 15:12:33 +0000 (10:12 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 8 Feb 2022 15:12:33 +0000 (10:12 -0500)
Fixed issue in Oracle dialect where using a column name that requires
quoting when written as a bound parameter, such as ``"_id"``, would not
correctly track a Python generated default value due to the bound-parameter
rewriting missing this value, causing an Oracle error to be raised.

Fixes: #7676
Change-Id: I5a54426d24f2f9b336e3597d5595fb3e031aad97

doc/build/changelog/unreleased_14/7676.rst [new file with mode: 0644]
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/dialect/oracle/test_dialect.py

diff --git a/doc/build/changelog/unreleased_14/7676.rst b/doc/build/changelog/unreleased_14/7676.rst
new file mode 100644 (file)
index 0000000..ec6275f
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, oracle
+    :tickets: 7676
+
+    Fixed issue in Oracle dialect where using a column name that requires
+    quoting when written as a bound parameter, such as ``"_id"``, would not
+    correctly track a Python generated default value due to the bound-parameter
+    rewriting missing this value, causing an Oracle error to be raised.
index 539af2507ba726e2f9ef7bf3c4381dfcb9be1665..4861214c4afc58df5025532dd3d5eede4bab7fc0 100644 (file)
@@ -1389,7 +1389,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
 
     def _setup_ins_pk_from_empty(self):
         getter = self.compiled._inserted_primary_key_from_lastrowid_getter
-
         return [getter(None, param) for param in self.compiled_parameters]
 
     def _setup_ins_pk_from_implicit_returning(self, result, rows):
@@ -1664,7 +1663,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             return self._exec_default(column, column.onupdate, column.type)
 
     def _process_executemany_defaults(self):
-        key_getter = self.compiled._key_getters_for_crud_column[2]
+        key_getter = self.compiled._within_exec_param_key_getter
 
         scalar_defaults = {}
 
@@ -1702,7 +1701,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         del self.current_parameters
 
     def _process_executesingle_defaults(self):
-        key_getter = self.compiled._key_getters_for_crud_column[2]
+        key_getter = self.compiled._within_exec_param_key_getter
         self.current_parameters = (
             compiled_parameters
         ) = self.compiled_parameters[0]
index 8a3f264255f26bbb06f232948db2cb8b69343690..9cf4d8397414f135f5e5f4f5363fcb1d8311cc59 100644 (file)
@@ -1254,16 +1254,29 @@ class SQLCompiler(Compiled):
             self._result_columns
         )
 
+    @util.memoized_property
+    def _within_exec_param_key_getter(self):
+        getter = self._key_getters_for_crud_column[2]
+        if self.escaped_bind_names:
+
+            def _get(obj):
+                key = getter(obj)
+                return self.escaped_bind_names.get(key, key)
+
+            return _get
+        else:
+            return getter
+
     @util.memoized_property
     @util.preload_module("sqlalchemy.engine.result")
     def _inserted_primary_key_from_lastrowid_getter(self):
         result = util.preloaded.engine_result
 
-        key_getter = self._key_getters_for_crud_column[2]
+        param_key_getter = self._within_exec_param_key_getter
         table = self.statement.table
 
         getters = [
-            (operator.methodcaller("get", key_getter(col), None), col)
+            (operator.methodcaller("get", param_key_getter(col), None), col)
             for col in table.primary_key
         ]
 
@@ -1279,6 +1292,12 @@ class SQLCompiler(Compiled):
         row_fn = result.result_tuple([col.key for col in table.primary_key])
 
         def get(lastrowid, parameters):
+            """given cursor.lastrowid value and the parameters used for INSERT,
+            return a "row" that represents the primary key, either by
+            using the "lastrowid" or by extracting values from the parameters
+            that were sent along with the INSERT.
+
+            """
             if proc is not None:
                 lastrowid = proc(lastrowid)
 
@@ -1297,7 +1316,7 @@ class SQLCompiler(Compiled):
     def _inserted_primary_key_from_returning_getter(self):
         result = util.preloaded.engine_result
 
-        key_getter = self._key_getters_for_crud_column[2]
+        param_key_getter = self._within_exec_param_key_getter
         table = self.statement.table
 
         ret = {col: idx for idx, col in enumerate(self.returning)}
@@ -1305,7 +1324,10 @@ class SQLCompiler(Compiled):
         getters = [
             (operator.itemgetter(ret[col]), True)
             if col in ret
-            else (operator.methodcaller("get", key_getter(col), None), False)
+            else (
+                operator.methodcaller("get", param_key_getter(col), None),
+                False,
+            )
             for col in table.primary_key
         ]
 
index c06baace03b6421f6550949aecc959ac51b64e09..5383ffc0c8f83baf1a15326187b3aa19ff0d3824 100644 (file)
@@ -490,6 +490,35 @@ class QuotedBindRoundTripTest(fixtures.TestBase):
             dict(uid=[1, 2, 3]),
         )
 
+    @testing.combinations(True, False, argnames="executemany")
+    def test_python_side_default(self, metadata, connection, executemany):
+        """test #7676"""
+
+        ids = ["a", "b", "c"]
+
+        def gen_id():
+            return ids.pop(0)
+
+        t = Table(
+            "has_id",
+            metadata,
+            Column("_id", String(50), default=gen_id, primary_key=True),
+            Column("_data", Integer),
+        )
+        metadata.create_all(connection)
+
+        if executemany:
+            result = connection.execute(
+                t.insert(), [{"_data": 27}, {"_data": 28}, {"_data": 29}]
+            )
+            eq_(
+                connection.execute(t.select().order_by(t.c._id)).all(),
+                [("a", 27), ("b", 28), ("c", 29)],
+            )
+        else:
+            result = connection.execute(t.insert(), {"_data": 27})
+            eq_(result.inserted_primary_key, ("a",))
+
 
 class CompatFlagsTest(fixtures.TestBase, AssertsCompiledSQL):
     def _dialect(self, server_version, **kw):