From: Mike Bayer Date: Tue, 8 Feb 2022 15:12:33 +0000 (-0500) Subject: Accommodate escaped_bind_names for defaults/insert params X-Git-Tag: rel_2_0_0b1~494 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c2aa6374f3965c28aa2d56cbddf6dab3e1de18a2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Accommodate escaped_bind_names for defaults/insert params Fixed issue in Oracle dialect where using a column name that requires quoting when written as a bound parameter, such as ``"_id"``, would not correctly track a Python generated default value due to the bound-parameter rewriting missing this value, causing an Oracle error to be raised. Fixes: #7676 Change-Id: I5a54426d24f2f9b336e3597d5595fb3e031aad97 --- diff --git a/doc/build/changelog/unreleased_14/7676.rst b/doc/build/changelog/unreleased_14/7676.rst new file mode 100644 index 0000000000..ec6275fb40 --- /dev/null +++ b/doc/build/changelog/unreleased_14/7676.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, oracle + :tickets: 7676 + + Fixed issue in Oracle dialect where using a column name that requires + quoting when written as a bound parameter, such as ``"_id"``, would not + correctly track a Python generated default value due to the bound-parameter + rewriting missing this value, causing an Oracle error to be raised. diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 539af2507b..4861214c4a 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1389,7 +1389,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext): def _setup_ins_pk_from_empty(self): getter = self.compiled._inserted_primary_key_from_lastrowid_getter - return [getter(None, param) for param in self.compiled_parameters] def _setup_ins_pk_from_implicit_returning(self, result, rows): @@ -1664,7 +1663,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): return self._exec_default(column, column.onupdate, column.type) def _process_executemany_defaults(self): - key_getter = self.compiled._key_getters_for_crud_column[2] + key_getter = self.compiled._within_exec_param_key_getter scalar_defaults = {} @@ -1702,7 +1701,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): del self.current_parameters def _process_executesingle_defaults(self): - key_getter = self.compiled._key_getters_for_crud_column[2] + key_getter = self.compiled._within_exec_param_key_getter self.current_parameters = ( compiled_parameters ) = self.compiled_parameters[0] diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8a3f264255..9cf4d83974 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1254,16 +1254,29 @@ class SQLCompiler(Compiled): self._result_columns ) + @util.memoized_property + def _within_exec_param_key_getter(self): + getter = self._key_getters_for_crud_column[2] + if self.escaped_bind_names: + + def _get(obj): + key = getter(obj) + return self.escaped_bind_names.get(key, key) + + return _get + else: + return getter + @util.memoized_property @util.preload_module("sqlalchemy.engine.result") def _inserted_primary_key_from_lastrowid_getter(self): result = util.preloaded.engine_result - key_getter = self._key_getters_for_crud_column[2] + param_key_getter = self._within_exec_param_key_getter table = self.statement.table getters = [ - (operator.methodcaller("get", key_getter(col), None), col) + (operator.methodcaller("get", param_key_getter(col), None), col) for col in table.primary_key ] @@ -1279,6 +1292,12 @@ class SQLCompiler(Compiled): row_fn = result.result_tuple([col.key for col in table.primary_key]) def get(lastrowid, parameters): + """given cursor.lastrowid value and the parameters used for INSERT, + return a "row" that represents the primary key, either by + using the "lastrowid" or by extracting values from the parameters + that were sent along with the INSERT. + + """ if proc is not None: lastrowid = proc(lastrowid) @@ -1297,7 +1316,7 @@ class SQLCompiler(Compiled): def _inserted_primary_key_from_returning_getter(self): result = util.preloaded.engine_result - key_getter = self._key_getters_for_crud_column[2] + param_key_getter = self._within_exec_param_key_getter table = self.statement.table ret = {col: idx for idx, col in enumerate(self.returning)} @@ -1305,7 +1324,10 @@ class SQLCompiler(Compiled): getters = [ (operator.itemgetter(ret[col]), True) if col in ret - else (operator.methodcaller("get", key_getter(col), None), False) + else ( + operator.methodcaller("get", param_key_getter(col), None), + False, + ) for col in table.primary_key ] diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index c06baace03..5383ffc0c8 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -490,6 +490,35 @@ class QuotedBindRoundTripTest(fixtures.TestBase): dict(uid=[1, 2, 3]), ) + @testing.combinations(True, False, argnames="executemany") + def test_python_side_default(self, metadata, connection, executemany): + """test #7676""" + + ids = ["a", "b", "c"] + + def gen_id(): + return ids.pop(0) + + t = Table( + "has_id", + metadata, + Column("_id", String(50), default=gen_id, primary_key=True), + Column("_data", Integer), + ) + metadata.create_all(connection) + + if executemany: + result = connection.execute( + t.insert(), [{"_data": 27}, {"_data": 28}, {"_data": 29}] + ) + eq_( + connection.execute(t.select().order_by(t.c._id)).all(), + [("a", 27), ("b", 28), ("c", 29)], + ) + else: + result = connection.execute(t.insert(), {"_data": 27}) + eq_(result.inserted_primary_key, ("a",)) + class CompatFlagsTest(fixtures.TestBase, AssertsCompiledSQL): def _dialect(self, server_version, **kw):