From: Mike Bayer Date: Mon, 8 Feb 2021 16:58:15 +0000 (-0500) Subject: Add identifier_preparer per-execution context for schema translates X-Git-Tag: rel_1_3_24~23 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=b348e82dcf5eca1fc8496c941dc1ac2ffd60eed9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add identifier_preparer per-execution context for schema translates Fixed bug where the "schema_translate_map" feature failed to be taken into account for the use case of direct execution of :class:`_schema.DefaultGenerator` objects such as sequences, which included the case where they were "pre-executed" in order to generate primary key values when implicit_returning was disabled. Fixes: #5929 Change-Id: I3fed1d0af28be5ce9c9bb572524dcc8411633f60 (cherry picked from commit 2385ebb19366efeb35415298166ac18668864c51) --- diff --git a/doc/build/changelog/unreleased_13/5929.rst b/doc/build/changelog/unreleased_13/5929.rst new file mode 100644 index 0000000000..9b9b6214c4 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5929.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, engine + :tickets: 5929 + + Fixed bug where the "schema_translate_map" feature failed to be taken into + account for the use case of direct execution of + :class:`_schema.DefaultGenerator` objects such as sequences, which included + the case where they were "pre-executed" in order to generate primary key + values when implicit_returning was disabled. diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 28fefa5b72..9138a81a96 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -612,7 +612,7 @@ class FBExecutionContext(default.DefaultExecutionContext): return self._execute_scalar( "SELECT gen_id(%s, 1) FROM rdb$database" - % self.dialect.identifier_preparer.format_sequence(seq), + % self.identifier_preparer.format_sequence(seq), type_, ) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 22f3297308..debfa55b17 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1495,7 +1495,7 @@ class MSExecutionContext(default.DefaultExecutionContext): self.cursor, self._opt_encode( "SET IDENTITY_INSERT %s ON" - % self.dialect.identifier_preparer.format_table(tbl) + % self.identifier_preparer.format_table(tbl) ), (), self, @@ -1531,7 +1531,7 @@ class MSExecutionContext(default.DefaultExecutionContext): self.cursor, self._opt_encode( "SET IDENTITY_INSERT %s OFF" - % self.dialect.identifier_preparer.format_table( + % self.identifier_preparer.format_table( self.compiled.statement.table ) ), @@ -1548,7 +1548,7 @@ class MSExecutionContext(default.DefaultExecutionContext): self.cursor.execute( self._opt_encode( "SET IDENTITY_INSERT %s OFF" - % self.dialect.identifier_preparer.format_table( + % self.identifier_preparer.format_table( self.compiled.statement.table ) ) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index c41d6acf7a..47e4dff944 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -901,9 +901,13 @@ from ...types import BLOB from ...types import BOOLEAN from ...types import DATE from ...types import VARBINARY +from ...util import compat from ...util import topological +if compat.TYPE_CHECKING: + from typing import Any + RESERVED_WORDS = set( [ "accessible", @@ -3053,6 +3057,7 @@ class MySQLDialect(default.DefaultDialect): return parser.parse(sql, charset) def _detect_charset(self, connection): + # type: (Any) -> str raise NotImplementedError() def _detect_casing(self, connection): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index c476554bd1..c621165720 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -958,9 +958,7 @@ class OracleCompiler(compiler.SQLCompiler): return self.process(vc.column, **kw) + "(+)" def visit_sequence(self, seq, **kw): - return ( - self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" - ) + return self.preparer.format_sequence(seq) + ".nextval" def get_render_as_alias_suffix(self, alias_name_text): """Oracle doesn't like ``FROM table AS alias``""" @@ -1281,7 +1279,7 @@ class OracleExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): return self._execute_scalar( "SELECT " - + self.dialect.identifier_preparer.format_sequence(seq) + + self.identifier_preparer.format_sequence(seq) + ".nextval FROM DUAL", type_, ) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 39e11aa616..7dec6d8182 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2487,7 +2487,7 @@ class PGExecutionContext(default.DefaultExecutionContext): return self._execute_scalar( ( "select nextval('%s')" - % self.dialect.identifier_preparer.format_sequence(seq) + % self.identifier_preparer.format_sequence(seq) ), type_, ) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 1c0a87b4ac..59eac7e0d5 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1056,6 +1056,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.cursor = self.create_cursor() return self + @util.memoized_property + def identifier_preparer(self): + if self.compiled: + return self.compiled.preparer + elif "schema_translate_map" in self.execution_options: + return self.dialect.identifier_preparer._with_schema_translate( + self.execution_options["schema_translate_map"] + ) + else: + return self.dialect.identifier_preparer + @util.memoized_property def engine(self): return self.root_connection.engine diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index 22ae7d43cf..6c80f94879 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -39,6 +39,34 @@ class SequenceTest(fixtures.TablesTest): Column("data", String(50)), ) + Table( + "seq_no_returning", + metadata, + Column( + "id", + Integer, + Sequence("noret_id_seq"), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + ) + + if testing.requires.schemas.enabled: + Table( + "seq_no_returning_sch", + metadata, + Column( + "id", + Integer, + Sequence("noret_sch_id_seq", schema=config.test_schema), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + schema=config.test_schema, + ) + def test_insert_roundtrip(self): config.db.execute(self.tables.seq_pk.insert(), data="some data") self._assert_round_trip(self.tables.seq_pk, config.db) @@ -62,6 +90,46 @@ class SequenceTest(fixtures.TablesTest): row = conn.execute(table.select()).first() eq_(row, (1, "some data")) + def test_insert_roundtrip_no_implicit_returning(self, connection): + connection.execute( + self.tables.seq_no_returning.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.seq_no_returning, connection) + + @testing.combinations((True,), (False,), argnames="implicit_returning") + @testing.requires.schemas + def test_insert_roundtrip_translate(self, connection, implicit_returning): + + seq_no_returning = Table( + "seq_no_returning_sch", + MetaData(), + Column( + "id", + Integer, + Sequence("noret_sch_id_seq", schema="alt_schema"), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=implicit_returning, + schema="alt_schema", + ) + + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + connection.execute(seq_no_returning.insert(), dict(data="some data")) + self._assert_round_trip(seq_no_returning, connection) + + @testing.requires.schemas + def test_nextval_direct_schema_translate(self, connection): + seq = Sequence("noret_sch_id_seq", schema="alt_schema") + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + + r = connection.execute(seq) + eq_(r, testing.db.dialect.default_sequence_base) + class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase): __requires__ = ("sequences",) diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index aed39366d5..a1d55376dd 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -186,6 +186,8 @@ if py3k: # as the __traceback__ object creates a cycle del exception, replace_context, from_, with_traceback + from typing import TYPE_CHECKING + def u(s): return s @@ -299,6 +301,7 @@ else: " raise exception\n" ) + TYPE_CHECKING = False if py35: