]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add identifier_preparer per-execution context for schema translates
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Feb 2021 16:58:15 +0000 (11:58 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Feb 2021 18:14:52 +0000 (13:14 -0500)
Fixed bug where the "schema_translate_map" feature failed to be taken into
account for the use case of direct execution of
:class:`_schema.DefaultGenerator` objects such as sequences, which included
the case where they were "pre-executed" in order to generate primary key
values when implicit_returning was disabled.

Fixes: #5929
Change-Id: I3fed1d0af28be5ce9c9bb572524dcc8411633f60

doc/build/changelog/unreleased_13/5929.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/testing/provision.py
lib/sqlalchemy/testing/suite/test_sequence.py

diff --git a/doc/build/changelog/unreleased_13/5929.rst b/doc/build/changelog/unreleased_13/5929.rst
new file mode 100644 (file)
index 0000000..9b9b621
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 5929
+
+    Fixed bug where the "schema_translate_map" feature failed to be taken into
+    account for the use case of direct execution of
+    :class:`_schema.DefaultGenerator` objects such as sequences, which included
+    the case where they were "pre-executed" in order to generate primary key
+    values when implicit_returning was disabled.
index 82861e30fd76f94a7ceaee86a5559cbe9bcda318..7fc914f1b091c6769a4ae392666797ee584b17fd 100644 (file)
@@ -614,7 +614,7 @@ class FBExecutionContext(default.DefaultExecutionContext):
 
         return self._execute_scalar(
             "SELECT gen_id(%s, 1) FROM rdb$database"
-            % self.dialect.identifier_preparer.format_sequence(seq),
+            % self.identifier_preparer.format_sequence(seq),
             type_,
         )
 
index 9d0e5d3222a0b7596c8435b648492ea66cd45524..674d5417949f39fcbe00797afb246d6f2beacc2b 100644 (file)
@@ -17,7 +17,7 @@ External Dialects
 In addition to the above DBAPI layers with native SQLAlchemy support, there
 are third-party dialects for other DBAPI layers that are compatible
 with SQL Server. See the "External Dialects" list on the
-:ref:`dialect_toplevel` page. 
+:ref:`dialect_toplevel` page.
 
 .. _mssql_identity:
 
@@ -1560,7 +1560,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
                     self.cursor,
                     self._opt_encode(
                         "SET IDENTITY_INSERT %s ON"
-                        % self.dialect.identifier_preparer.format_table(tbl)
+                        % self.identifier_preparer.format_table(tbl)
                     ),
                     (),
                     self,
@@ -1606,7 +1606,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
                 self.cursor,
                 self._opt_encode(
                     "SET IDENTITY_INSERT %s OFF"
-                    % self.dialect.identifier_preparer.format_table(
+                    % self.identifier_preparer.format_table(
                         self.compiled.statement.table
                     )
                 ),
@@ -1630,7 +1630,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
                 self.cursor.execute(
                     self._opt_encode(
                         "SET IDENTITY_INSERT %s OFF"
-                        % self.dialect.identifier_preparer.format_table(
+                        % self.identifier_preparer.format_table(
                             self.compiled.statement.table
                         )
                     )
@@ -1650,7 +1650,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
         return self._execute_scalar(
             (
                 "SELECT NEXT VALUE FOR %s"
-                % self.dialect.identifier_preparer.format_sequence(seq)
+                % self.identifier_preparer.format_sequence(seq)
             ),
             type_,
         )
index 063f750faf988aa7c759e020054cf98ccd3bf5ef..c80ff3f19a44c9d48788b97fa4822de5cf2fea60 100644 (file)
@@ -988,9 +988,13 @@ from ...types import BLOB
 from ...types import BOOLEAN
 from ...types import DATE
 from ...types import VARBINARY
+from ...util import compat
 from ...util import topological
 
 
+if compat.TYPE_CHECKING:
+    from typing import Any
+
 RESERVED_WORDS = set(
     [
         "accessible",
@@ -1394,7 +1398,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
         return self._execute_scalar(
             (
                 "select nextval(%s)"
-                % self.dialect.identifier_preparer.format_sequence(seq)
+                % self.identifier_preparer.format_sequence(seq)
             ),
             type_,
         )
@@ -3263,6 +3267,7 @@ class MySQLDialect(default.DefaultDialect):
         return parser.parse(sql, charset)
 
     def _detect_charset(self, connection):
+        # type: (Any) -> str
         raise NotImplementedError()
 
     def _detect_casing(self, connection):
index 9344abeeefe475060d2b8cd4751533c70eecb848..f9805abeb2bd70fbcd3f9d36fb1f47e6501a8bb4 100644 (file)
@@ -1012,9 +1012,7 @@ class OracleCompiler(compiler.SQLCompiler):
         return self.process(vc.column, **kw) + "(+)"
 
     def visit_sequence(self, seq, **kw):
-        return (
-            self.dialect.identifier_preparer.format_sequence(seq) + ".nextval"
-        )
+        return self.preparer.format_sequence(seq) + ".nextval"
 
     def get_render_as_alias_suffix(self, alias_name_text):
         """Oracle doesn't like ``FROM table AS alias``"""
@@ -1441,7 +1439,7 @@ class OracleExecutionContext(default.DefaultExecutionContext):
     def fire_sequence(self, seq, type_):
         return self._execute_scalar(
             "SELECT "
-            + self.dialect.identifier_preparer.format_sequence(seq)
+            + self.identifier_preparer.format_sequence(seq)
             + ".nextval FROM DUAL",
             type_,
         )
index f067e6537e3efe3966948dc408217a33938301c3..7e821acde308ea094871c6d85f04bf20aa3cf2b4 100644 (file)
@@ -2936,7 +2936,7 @@ class PGExecutionContext(default.DefaultExecutionContext):
         return self._execute_scalar(
             (
                 "select nextval('%s')"
-                % self.dialect.identifier_preparer.format_sequence(seq)
+                % self.identifier_preparer.format_sequence(seq)
             ),
             type_,
         )
index 7fddf2814acac8353162f1ffc3feb57f3b67e75b..0c48fcba3fa89e5615aab1ee04e52233553662e4 100644 (file)
@@ -1140,6 +1140,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         else:
             return "unknown"
 
+    @util.memoized_property
+    def identifier_preparer(self):
+        if self.compiled:
+            return self.compiled.preparer
+        elif "schema_translate_map" in self.execution_options:
+            return self.dialect.identifier_preparer._with_schema_translate(
+                self.execution_options["schema_translate_map"]
+            )
+        else:
+            return self.dialect.identifier_preparer
+
     @util.memoized_property
     def engine(self):
         return self.root_connection.engine
@@ -1197,6 +1208,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         ):
             stmt = self.dialect._encoder(stmt)[0]
 
+        if "schema_translate_map" in self.execution_options:
+            schema_translate_map = self.execution_options.get(
+                "schema_translate_map", {}
+            )
+
+            rst = self.identifier_preparer._render_schema_translates
+            stmt = rst(stmt, schema_translate_map)
+
         if not parameters:
             if self.dialect.positional:
                 parameters = self.dialect.execute_sequence_format()
index 2fade1c32d3524748a8d19ad235853d4882b6c0f..a976abee02945306a7de980aeb5dc5b191efe8da 100644 (file)
@@ -262,7 +262,6 @@ def drop_all_schema_objects(cfg, eng):
                     )
 
     util.drop_all_tables(eng, inspector)
-
     if config.requirements.schemas.enabled_for_config(cfg):
         util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
         util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)
@@ -273,6 +272,16 @@ def drop_all_schema_objects(cfg, eng):
         with eng.begin() as conn:
             for seq in inspector.get_sequence_names():
                 conn.execute(ddl.DropSequence(schema.Sequence(seq)))
+            if config.requirements.schemas.enabled_for_config(cfg):
+                for schema_name in [cfg.test_schema, cfg.test_schema_2]:
+                    for seq in inspector.get_sequence_names(
+                        schema=schema_name
+                    ):
+                        conn.execute(
+                            ddl.DropSequence(
+                                schema.Sequence(seq, schema=schema_name)
+                            )
+                        )
 
 
 @register.init
index 7445ade00c85c94deedb9306bc0775db26e10ea5..d6747d2538651556cceca5089a2f96dd80cfd13a 100644 (file)
@@ -45,6 +45,34 @@ class SequenceTest(fixtures.TablesTest):
             Column("data", String(50)),
         )
 
+        Table(
+            "seq_no_returning",
+            metadata,
+            Column(
+                "id",
+                Integer,
+                Sequence("noret_id_seq"),
+                primary_key=True,
+            ),
+            Column("data", String(50)),
+            implicit_returning=False,
+        )
+
+        if testing.requires.schemas.enabled:
+            Table(
+                "seq_no_returning_sch",
+                metadata,
+                Column(
+                    "id",
+                    Integer,
+                    Sequence("noret_sch_id_seq", schema=config.test_schema),
+                    primary_key=True,
+                ),
+                Column("data", String(50)),
+                implicit_returning=False,
+                schema=config.test_schema,
+            )
+
     def test_insert_roundtrip(self, connection):
         connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
         self._assert_round_trip(self.tables.seq_pk, connection)
@@ -72,6 +100,46 @@ class SequenceTest(fixtures.TablesTest):
         row = conn.execute(table.select()).first()
         eq_(row, (testing.db.dialect.default_sequence_base, "some data"))
 
+    def test_insert_roundtrip_no_implicit_returning(self, connection):
+        connection.execute(
+            self.tables.seq_no_returning.insert(), dict(data="some data")
+        )
+        self._assert_round_trip(self.tables.seq_no_returning, connection)
+
+    @testing.combinations((True,), (False,), argnames="implicit_returning")
+    @testing.requires.schemas
+    def test_insert_roundtrip_translate(self, connection, implicit_returning):
+
+        seq_no_returning = Table(
+            "seq_no_returning_sch",
+            MetaData(),
+            Column(
+                "id",
+                Integer,
+                Sequence("noret_sch_id_seq", schema="alt_schema"),
+                primary_key=True,
+            ),
+            Column("data", String(50)),
+            implicit_returning=implicit_returning,
+            schema="alt_schema",
+        )
+
+        connection = connection.execution_options(
+            schema_translate_map={"alt_schema": config.test_schema}
+        )
+        connection.execute(seq_no_returning.insert(), dict(data="some data"))
+        self._assert_round_trip(seq_no_returning, connection)
+
+    @testing.requires.schemas
+    def test_nextval_direct_schema_translate(self, connection):
+        seq = Sequence("noret_sch_id_seq", schema="alt_schema")
+        connection = connection.execution_options(
+            schema_translate_map={"alt_schema": config.test_schema}
+        )
+
+        r = connection.execute(seq)
+        eq_(r, testing.db.dialect.default_sequence_base)
+
 
 class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
     __requires__ = ("sequences",)