From: Mike Bayer Date: Wed, 19 Aug 2020 02:53:09 +0000 (-0400) Subject: Implement DDL visitor for PG ENUM with schema translate support X-Git-Tag: rel_1_3_20~38 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0a234fee4bed27036f7d75b659b6a2ed0512efe6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Implement DDL visitor for PG ENUM with schema translate support Fixed issue where the :class:`_postgresql.ENUM` type would not consult the schema translate map when emitting a CREATE TYPE or DROP TYPE during the test to see if the type exists or not. Additionally, repaired an issue where if the same enum were encountered multiple times in a single DDL sequence, the "check" query would run repeatedly rather than relying upon a cached value. Fixes: #5520 Change-Id: I79f46e29ac0168e873ff178c242f8d78f6679aeb (cherry picked from commit c290a40a543f8355ee712e2e565698b6ebdb162f) --- diff --git a/doc/build/changelog/unreleased_13/5520.rst b/doc/build/changelog/unreleased_13/5520.rst new file mode 100644 index 0000000000..5dd7477fe2 --- /dev/null +++ b/doc/build/changelog/unreleased_13/5520.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, postgresql + :tickets: 5520 + + Fixed issue where the :class:`_postgresql.ENUM` type would not consult the + schema translate map when emitting a CREATE TYPE or DROP TYPE during the + test to see if the type exists or not. Additionally, repaired an issue + where if the same enum were encountered multiple times in a single DDL + sequence, the "check" query would run repeatedly rather than relying upon a + cached value. + diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 4ad6cb11cf..86a857f138 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1008,6 +1008,7 @@ from ...sql import elements from ...sql import expression from ...sql import sqltypes from ...sql import util as sql_util +from ...sql.ddl import DDLBase from ...types import BIGINT from ...types import BOOLEAN from ...types import CHAR @@ -1502,10 +1503,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - if not checkfirst or not bind.dialect.has_type( - bind, self.name, schema=self.schema - ): - bind.execute(CreateEnumType(self)) + bind._run_visitor(self.EnumGenerator, self, checkfirst=checkfirst) def drop(self, bind=None, checkfirst=True): """Emit ``DROP TYPE`` for this @@ -1525,10 +1523,49 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - if not checkfirst or bind.dialect.has_type( - bind, self.name, schema=self.schema - ): - bind.execute(DropEnumType(self)) + bind._run_visitor(self.EnumDropper, self, checkfirst=checkfirst) + + class EnumGenerator(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_create_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return not self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_create_enum(enum): + return + + self.connection.execute(CreateEnumType(enum)) + + class EnumDropper(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_drop_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_drop_enum(enum): + return + + self.connection.execute(DropEnumType(enum)) def _check_for_name_in_memos(self, checkfirst, kw): """Look in the 'ddl runner' for 'memos', then @@ -1554,14 +1591,14 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): return False def _on_table_create(self, target, bind, checkfirst=False, **kw): + if ( checkfirst or ( not self.metadata and not kw.get("_is_metadata_operation", False) ) - and not self._check_for_name_in_memos(checkfirst, kw) - ): + ) and not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) def _on_table_drop(self, target, bind, checkfirst=False, **kw): diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index 560261dfc8..f1353b0e52 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -194,12 +194,12 @@ class CompiledSQL(SQLMatchRule): class RegexSQL(CompiledSQL): - def __init__(self, regex, params=None): + def __init__(self, regex, params=None, dialect="default"): SQLMatchRule.__init__(self) self.regex = re.compile(regex) self.orig_regex = regex self.params = params - self.dialect = "default" + self.dialect = dialect def _failure_message(self, expected_params): return ( diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 3255ca8251..9735540ee4 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -59,6 +59,7 @@ from sqlalchemy.testing.assertions import AssertsExecutionResults from sqlalchemy.testing.assertions import ComparesTables from sqlalchemy.testing.assertions import eq_ from sqlalchemy.testing.assertions import is_ +from sqlalchemy.testing.assertsql import RegexSQL from sqlalchemy.testing.suite import test_types as suite from sqlalchemy.testing.util import round_decimal @@ -377,6 +378,40 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): t1.drop(testing.db) e1.drop(bind=testing.db) + @testing.provide_metadata + def test_dont_keep_checking(self, connection): + metadata = self.metadata + + e1 = postgresql.ENUM("one", "two", "three", name="myenum") + + Table("t", metadata, Column("a", e1), Column("b", e1), Column("c", e1)) + + with self.sql_execution_asserter(connection) as asserter: + metadata.create_all(connection) + + asserter.assert_( + RegexSQL( + "select relname from pg_class c join pg_namespace.*", + dialect="postgresql", + ), + RegexSQL(r".*SELECT EXISTS ", dialect="postgresql"), + RegexSQL("CREATE TYPE myenum AS ENUM .*", dialect="postgresql"), + RegexSQL(r"CREATE TABLE t .*", dialect="postgresql"), + ) + + with self.sql_execution_asserter(connection) as asserter: + metadata.drop_all(connection) + + asserter.assert_( + RegexSQL( + "select relname from pg_class c join pg_namespace.*", + dialect="postgresql", + ), + RegexSQL("DROP TABLE t", dialect="postgresql"), + RegexSQL(r".*SELECT EXISTS ", dialect="postgresql"), + RegexSQL("DROP TYPE myenum", dialect="postgresql"), + ) + @testing.provide_metadata def test_generate_multiple(self): """Test that the same enum twice only generates once @@ -488,6 +523,41 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): e["name"] for e in inspect(testing.db).get_enums() ] + def test_create_drop_schema_translate_map(self, connection): + + conn = connection.execution_options( + schema_translate_map={None: testing.config.test_schema} + ) + + e1 = Enum("one", "two", "three", name="myenum") + + assert "myenum" not in [ + e["name"] + for e in inspect(connection).get_enums(testing.config.test_schema) + ] + + e1.create(conn, checkfirst=True) + e1.create(conn, checkfirst=True) + + assert "myenum" in [ + e["name"] + for e in inspect(connection).get_enums(testing.config.test_schema) + ] + + s1 = conn.begin_nested() + assert_raises(exc.ProgrammingError, e1.create, conn, checkfirst=False) + s1.rollback() + + e1.drop(conn, checkfirst=True) + e1.drop(conn, checkfirst=True) + + assert "myenum" not in [ + e["name"] + for e in inspect(connection).get_enums(testing.config.test_schema) + ] + + assert_raises(exc.ProgrammingError, e1.drop, conn, checkfirst=False) + @testing.provide_metadata def test_remain_on_table_metadata_wide(self): metadata = self.metadata