]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Pass DDLCompiler IdentifierPreparer to visit_ENUM
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 17 Feb 2020 20:21:59 +0000 (15:21 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 17 Feb 2020 20:36:16 +0000 (15:36 -0500)
Fixed issue where the "schema_translate_map" feature would not work with a
PostgreSQL native enumeration type (i.e. :class:`.Enum`,
:class:`.postgresql.ENUM`) in that while the "CREATE TYPE" statement would
be emitted with the correct schema, the schema would not be rendered in
the CREATE TABLE statement at the point at which the enumeration was
referenced.

Fixes: #5158
Change-Id: I41529785de2e736c70a142c2ae5705060bfed73e
(cherry picked from commit 89b8c343ed6247a562e0bcd53ef3fc180d0d4e46)

doc/build/changelog/unreleased_13/5158.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_13/5158.rst b/doc/build/changelog/unreleased_13/5158.rst
new file mode 100644 (file)
index 0000000..adab86d
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 5158
+
+    Fixed issue where the "schema_translate_map" feature would not work with a
+    PostgreSQL native enumeration type (i.e. :class:`.Enum`,
+    :class:`.postgresql.ENUM`) in that while the "CREATE TYPE" statement would
+    be emitted with the correct schema, the schema would not be rendered in
+    the CREATE TABLE statement at the point at which the enumeration was
+    referenced.
+
index 0bd2403ef76cea6359682533ec22c341875f11b2..11b93f391c76be3b847e27f9d471aec6c20c9c8e 100644 (file)
@@ -1890,7 +1890,9 @@ class PGDDLCompiler(compiler.DDLCompiler):
                 colspec += " SERIAL"
         else:
             colspec += " " + self.dialect.type_compiler.process(
-                column.type, type_expression=column
+                column.type,
+                type_expression=column,
+                identifier_preparer=self.preparer,
             )
             default = self.get_column_default_string(column)
             if default is not None:
@@ -2152,8 +2154,11 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
         else:
             return self.visit_ENUM(type_, **kw)
 
-    def visit_ENUM(self, type_, **kw):
-        return self.dialect.identifier_preparer.format_type(type_)
+    def visit_ENUM(self, type_, identifier_preparer=None, **kw):
+        if identifier_preparer is None:
+            identifier_preparer = self.dialect.identifier_preparer
+
+        return identifier_preparer.format_type(type_)
 
     def visit_TIMESTAMP(self, type_, **kw):
         return "TIMESTAMP%s %s" % (
index 4c4c43281e0f822fa22591d0a8d6cc3a28c04c50..aabbc3ac3b0aebbeac79b85e4b1201e1702c410f 100644 (file)
@@ -237,6 +237,22 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             schema_translate_map=schema_translate_map,
         )
 
+    def test_create_table_with_schema_type_schema_translate(self):
+        e1 = Enum("x", "y", "z", name="somename")
+        e2 = Enum("x", "y", "z", name="somename", schema="someschema")
+        schema_translate_map = {None: "foo", "someschema": "bar"}
+
+        table = Table(
+            "some_table", MetaData(), Column("q", e1), Column("p", e2)
+        )
+        from sqlalchemy.schema import CreateTable
+
+        self.assert_compile(
+            CreateTable(table),
+            "CREATE TABLE foo.some_table (q foo.somename, p bar.somename)",
+            schema_translate_map=schema_translate_map,
+        )
+
     def test_create_table_with_tablespace(self):
         m = MetaData()
         tbl = Table(
index 6bcfb17364c3a07675fb6a647482f34ec9a6d269..61fb9d67940e4e86e33acd3c46956535552a6040 100644 (file)
@@ -190,6 +190,58 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
                 [(1, "two"), (2, "three"), (3, "three")],
             )
 
+    @testing.combinations(None, "foo")
+    def test_create_table_schema_translate_map(self, symbol_name):
+        # note we can't use the fixture here because it will not drop
+        # from the correct schema
+        metadata = MetaData()
+
+        t1 = Table(
+            "table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column(
+                "value",
+                Enum(
+                    "one",
+                    "two",
+                    "three",
+                    name="schema_enum",
+                    schema=symbol_name,
+                ),
+            ),
+            schema=symbol_name,
+        )
+        with testing.db.connect() as conn:
+            conn = conn.execution_options(
+                schema_translate_map={symbol_name: testing.config.test_schema}
+            )
+            t1.create(conn)
+            assert "schema_enum" in [
+                e["name"]
+                for e in inspect(conn).get_enums(
+                    schema=testing.config.test_schema
+                )
+            ]
+            t1.create(conn, checkfirst=True)
+
+            conn.execute(t1.insert(), value="two")
+            conn.execute(t1.insert(), value="three")
+            conn.execute(t1.insert(), value="three")
+            eq_(
+                conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
+                [(1, "two"), (2, "three"), (3, "three")],
+            )
+
+            t1.drop(conn)
+            assert "schema_enum" not in [
+                e["name"]
+                for e in inspect(conn).get_enums(
+                    schema=testing.config.test_schema
+                )
+            ]
+            t1.drop(conn, checkfirst=True)
+
     def test_name_required(self):
         metadata = MetaData(testing.db)
         etype = Enum("four", "five", "six", metadata=metadata)