]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement DDL visitor for PG ENUM with schema translate support
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 19 Aug 2020 02:53:09 +0000 (22:53 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 19 Aug 2020 14:43:38 +0000 (10:43 -0400)
Fixed issue where the :class:`_postgresql.ENUM` type would not consult the
schema translate map when emitting a CREATE TYPE or DROP TYPE during the
test to see if the type exists or not.  Additionally, repaired an issue
where if the same enum were encountered multiple times in a single DDL
sequence, the "check" query would run repeatedly rather than relying upon a
cached value.

Fixes: #5520
Change-Id: I79f46e29ac0168e873ff178c242f8d78f6679aeb

doc/build/changelog/unreleased_13/5520.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/testing/assertsql.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_13/5520.rst b/doc/build/changelog/unreleased_13/5520.rst
new file mode 100644 (file)
index 0000000..5dd7477
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 5520
+
+    Fixed issue where the :class:`_postgresql.ENUM` type would not consult the
+    schema translate map when emitting a CREATE TYPE or DROP TYPE during the
+    test to see if the type exists or not.  Additionally, repaired an issue
+    where if the same enum were encountered multiple times in a single DDL
+    sequence, the "check" query would run repeatedly rather than relying upon a
+    cached value.
+
index 7717a2526bf0b428d383c3b97f01b0b027731966..543746687e320dafdb424b221907c36f9ce923aa 100644 (file)
@@ -1011,6 +1011,7 @@ from ...sql import expression
 from ...sql import roles
 from ...sql import sqltypes
 from ...sql import util as sql_util
+from ...sql.ddl import DDLBase
 from ...types import BIGINT
 from ...types import BOOLEAN
 from ...types import CHAR
@@ -1499,10 +1500,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
         if not bind.dialect.supports_native_enum:
             return
 
-        if not checkfirst or not bind.dialect.has_type(
-            bind, self.name, schema=self.schema
-        ):
-            bind.execute(CreateEnumType(self))
+        bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
 
     def drop(self, bind=None, checkfirst=True):
         """Emit ``DROP TYPE`` for this
@@ -1522,10 +1520,49 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
         if not bind.dialect.supports_native_enum:
             return
 
-        if not checkfirst or bind.dialect.has_type(
-            bind, self.name, schema=self.schema
-        ):
-            bind.execute(DropEnumType(self))
+        bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
+
+    class EnumGenerator(DDLBase):
+        def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+            super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
+            self.checkfirst = checkfirst
+
+        def _can_create_enum(self, enum):
+            if not self.checkfirst:
+                return True
+
+            effective_schema = self.connection.schema_for_object(enum)
+
+            return not self.connection.dialect.has_type(
+                self.connection, enum.name, schema=effective_schema
+            )
+
+        def visit_enum(self, enum):
+            if not self._can_create_enum(enum):
+                return
+
+            self.connection.execute(CreateEnumType(enum))
+
+    class EnumDropper(DDLBase):
+        def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+            super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
+            self.checkfirst = checkfirst
+
+        def _can_drop_enum(self, enum):
+            if not self.checkfirst:
+                return True
+
+            effective_schema = self.connection.schema_for_object(enum)
+
+            return self.connection.dialect.has_type(
+                self.connection, enum.name, schema=effective_schema
+            )
+
+        def visit_enum(self, enum):
+            if not self._can_drop_enum(enum):
+                return
+
+            self.connection.execute(DropEnumType(enum))
 
     def _check_for_name_in_memos(self, checkfirst, kw):
         """Look in the 'ddl runner' for 'memos', then
@@ -1551,14 +1588,14 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
             return False
 
     def _on_table_create(self, target, bind, checkfirst=False, **kw):
+
         if (
             checkfirst
             or (
                 not self.metadata
                 and not kw.get("_is_metadata_operation", False)
             )
-            and not self._check_for_name_in_memos(checkfirst, kw)
-        ):
+        ) and not self._check_for_name_in_memos(checkfirst, kw):
             self.create(bind=bind, checkfirst=checkfirst)
 
     def _on_table_drop(self, target, bind, checkfirst=False, **kw):
index caf61a80616bfb606180263da4fac6df078c0135..73b062b96ea6dfb036d28506f5dcd7908244d0c3 100644 (file)
@@ -195,12 +195,12 @@ class CompiledSQL(SQLMatchRule):
 
 
 class RegexSQL(CompiledSQL):
-    def __init__(self, regex, params=None):
+    def __init__(self, regex, params=None, dialect="default"):
         SQLMatchRule.__init__(self)
         self.regex = re.compile(regex)
         self.orig_regex = regex
         self.params = params
-        self.dialect = "default"
+        self.dialect = dialect
 
     def _failure_message(self, expected_params):
         return (
index 503477833d2b9704daf819e3172d63d0e9bd59fe..bfefc039aa90428e3f58fb70bf7eb05b0a886ad0 100644 (file)
@@ -58,6 +58,7 @@ from sqlalchemy.testing.assertions import AssertsExecutionResults
 from sqlalchemy.testing.assertions import ComparesTables
 from sqlalchemy.testing.assertions import eq_
 from sqlalchemy.testing.assertions import is_
+from sqlalchemy.testing.assertsql import RegexSQL
 from sqlalchemy.testing.suite import test_types as suite
 from sqlalchemy.testing.util import round_decimal
 
@@ -371,6 +372,42 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         t1.drop(testing.db)
         e1.drop(bind=testing.db)
 
+    @testing.provide_metadata
+    def test_dont_keep_checking(self, connection):
+        metadata = self.metadata
+
+        e1 = postgresql.ENUM("one", "two", "three", name="myenum")
+
+        Table("t", metadata, Column("a", e1), Column("b", e1), Column("c", e1))
+
+        with self.sql_execution_asserter(connection) as asserter:
+            metadata.create_all(connection)
+
+        asserter.assert_(
+            # check for table
+            RegexSQL(
+                "select relname from pg_class c join pg_namespace.*",
+                dialect="postgresql",
+            ),
+            # check for enum, just once
+            RegexSQL(r".*SELECT EXISTS ", dialect="postgresql"),
+            RegexSQL("CREATE TYPE myenum AS ENUM .*", dialect="postgresql"),
+            RegexSQL(r"CREATE TABLE t .*", dialect="postgresql"),
+        )
+
+        with self.sql_execution_asserter(connection) as asserter:
+            metadata.drop_all(connection)
+
+        asserter.assert_(
+            RegexSQL(
+                "select relname from pg_class c join pg_namespace.*",
+                dialect="postgresql",
+            ),
+            RegexSQL("DROP TABLE t", dialect="postgresql"),
+            RegexSQL(r".*SELECT EXISTS ", dialect="postgresql"),
+            RegexSQL("DROP TYPE myenum", dialect="postgresql"),
+        )
+
     @testing.provide_metadata
     def test_generate_multiple(self):
         """Test that the same enum twice only generates once
@@ -482,6 +519,41 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
             e["name"] for e in inspect(testing.db).get_enums()
         ]
 
+    def test_create_drop_schema_translate_map(self, connection):
+
+        conn = connection.execution_options(
+            schema_translate_map={None: testing.config.test_schema}
+        )
+
+        e1 = Enum("one", "two", "three", name="myenum")
+
+        assert "myenum" not in [
+            e["name"]
+            for e in inspect(connection).get_enums(testing.config.test_schema)
+        ]
+
+        e1.create(conn, checkfirst=True)
+        e1.create(conn, checkfirst=True)
+
+        assert "myenum" in [
+            e["name"]
+            for e in inspect(connection).get_enums(testing.config.test_schema)
+        ]
+
+        s1 = conn.begin_nested()
+        assert_raises(exc.ProgrammingError, e1.create, conn, checkfirst=False)
+        s1.rollback()
+
+        e1.drop(conn, checkfirst=True)
+        e1.drop(conn, checkfirst=True)
+
+        assert "myenum" not in [
+            e["name"]
+            for e in inspect(connection).get_enums(testing.config.test_schema)
+        ]
+
+        assert_raises(exc.ProgrammingError, e1.drop, conn, checkfirst=False)
+
     @testing.provide_metadata
     def test_remain_on_table_metadata_wide(self):
         metadata = self.metadata