From fbb623d9ceac8f5499d0d40f8db6f5606b26b068 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 21 Sep 2025 12:13:21 -0400 Subject: [PATCH] add create_type to Enum Added new parameter :paramref:`.Enum.create_type` to the Core :class:`.Enum` class. This parameter is automatically passed to the corresponding :class:`_postgresql.ENUM` native type during DDL operations, allowing control over whether the PostgreSQL ENUM type is implicitly created or dropped within DDL operations that are otherwise targeting tables only. This provides control over the :paramref:`_postgresql.ENUM.create_type` behavior without requiring explicit creation of a :class:`_postgresql.ENUM` object. Fixes: #10604 Change-Id: I450003ec2a2a65c119fe7ca8ff201392ce6b91e1 --- doc/build/changelog/unreleased_21/10604.rst | 12 ++ .../dialects/postgresql/named_types.py | 8 +- lib/sqlalchemy/sql/sqltypes.py | 19 +- test/dialect/postgresql/test_types.py | 170 +++++++++++------- 4 files changed, 138 insertions(+), 71 deletions(-) create mode 100644 doc/build/changelog/unreleased_21/10604.rst diff --git a/doc/build/changelog/unreleased_21/10604.rst b/doc/build/changelog/unreleased_21/10604.rst new file mode 100644 index 0000000000..863affd7da --- /dev/null +++ b/doc/build/changelog/unreleased_21/10604.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 10604 + + Added new parameter :paramref:`.Enum.create_type` to the Core + :class:`.Enum` class. This parameter is automatically passed to the + corresponding :class:`_postgresql.ENUM` native type during DDL operations, + allowing control over whether the PostgreSQL ENUM type is implicitly + created or dropped within DDL operations that are otherwise targeting + tables only. This provides control over the + :paramref:`_postgresql.ENUM.create_type` behavior without requiring + explicit creation of a :class:`_postgresql.ENUM` object. diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py index 5807041ead..c47b381856 100644 --- a/lib/sqlalchemy/dialects/postgresql/named_types.py +++ b/lib/sqlalchemy/dialects/postgresql/named_types.py @@ -308,9 +308,9 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): "always refers to ENUM. Use sqlalchemy.types.Enum for " "non-native enum." ) - self.create_type = create_type if name is not _NoArg.NO_ARG: kw["name"] = name + kw["create_type"] = create_type super().__init__(*enums, **kw) def coerce_compared_value(self, op, value): @@ -335,6 +335,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): """ kw.setdefault("validate_strings", impl.validate_strings) kw.setdefault("name", impl.name) + kw.setdefault("create_type", impl.create_type) kw.setdefault("schema", impl.schema) kw.setdefault("inherit_schema", impl.inherit_schema) kw.setdefault("metadata", impl.metadata) @@ -342,8 +343,6 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): kw.setdefault("values_callable", impl.values_callable) kw.setdefault("omit_aliases", impl._omit_aliases) kw.setdefault("_adapted_from", impl) - if type_api._is_native_for_emulated(impl.__class__): - kw.setdefault("create_type", impl.create_type) return cls(**kw) @@ -496,8 +495,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType): if check is not None: check = coercions.expect(roles.DDLExpressionRole, check) self.check = check - self.create_type = create_type - super().__init__(name=name, **kw) + super().__init__(name=name, create_type=create_type, **kw) @classmethod def __test_init__(cls): diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 449eadda45..7154dba97c 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1073,6 +1073,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): metadata: Optional[MetaData] = None, inherit_schema: Union[bool, _NoArg] = NO_ARG, quote: Optional[bool] = None, + create_type: bool = True, _create_events: bool = True, _adapted_from: Optional[SchemaType] = None, ): @@ -1082,7 +1083,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): self.name = None self.schema = schema self.metadata = metadata - + self.create_type = create_type if inherit_schema is True and schema is not None: raise exc.ArgumentError( "Ambiguously setting inherit_schema=True while " @@ -1442,6 +1443,21 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): class was used, its name (converted to lower case) is used by default. + :param create_type: Defaults to True. This parameter only applies + to backends such as PostgreSQL which use explicitly named types + that are created and dropped separately from the table(s) they + are used by. Indicates that ``CREATE TYPE`` should be emitted, + after optionally checking for the presence of the type, when the + parent table is being created; and additionally that ``DROP TYPE`` is + called when the table is dropped. This parameter is equivalent to the + parameter of the same name on the PostgreSQL-specific + :class:`_postgresql.ENUM` datatype. + + .. versionadded:: 2.1 - The dialect agnostic :class:`.Enum` class + now includes the same :paramref:`.Enum.create_type` parameter that + was already available on the PostgreSQL native + :class:`_postgresql.ENUM` implementation. + :param native_enum: Use the database's native ENUM type when available. Defaults to True. When False, uses VARCHAR + check constraint for all backends. When False, the VARCHAR length can be @@ -1586,6 +1602,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): SchemaType.__init__( self, name=kw.pop("name", None), + create_type=kw.pop("create_type", True), inherit_schema=kw.pop("inherit_schema", NO_ARG), schema=kw.pop("schema", None), metadata=kw.pop("metadata", None), diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 734e909e09..e57991a30a 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -255,6 +255,12 @@ class NamedTypeTest( __only_on__ = "postgresql > 8.3" + def _enum_exists(self, name, connection): + return name in {d["name"] for d in inspect(connection).get_enums()} + + def _domain_exists(self, name, connection): + return name in {d["name"] for d in inspect(connection).get_domains()} + def test_native_enum_warnings(self): """test #6106""" @@ -800,28 +806,109 @@ class NamedTypeTest( connection.execute(t1.insert(), {"bar": "Ü"}) eq_(connection.scalar(select(t1.c.bar)), "Ü") - @testing.combinations( - (ENUM("one", "two", "three", name="mytype", create_type=False),), - ( - DOMAIN( + @testing.variation("datatype", ["enum", "native_enum", "domain"]) + @testing.variation("createtype", [True, False]) + def test_create_type_parameter( + self, metadata, connection, datatype, createtype + ): + + if datatype.enum: + dt = Enum( + "one", + "two", + "three", + name="mytype", + create_type=bool(createtype), + ) + elif datatype.native_enum: + dt = ENUM( + "one", + "two", + "three", + name="mytype", + create_type=bool(createtype), + ) + elif datatype.domain: + dt = DOMAIN( name="mytype", data_type=Text, check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", - create_type=False, - ), - ), - argnames="datatype", - ) - def test_disable_create(self, metadata, connection, datatype): - metadata = self.metadata + create_type=bool(createtype), + ) - t1 = Table("e1", metadata, Column("c1", datatype)) - # table can be created separately - # without conflict - datatype.create(bind=connection) - t1.create(connection) - t1.drop(connection) - datatype.drop(bind=connection) + else: + assert False + + expected_create = [ + RegexSQL( + r"CREATE TABLE e1 \(c1 mytype\)", + dialect="postgresql", + ) + ] + + expected_drop = [RegexSQL("DROP TABLE e1", dialect="postgresql")] + + if datatype.domain: + type_exists = functools.partial( + self._domain_exists, "mytype", connection + ) + if createtype: + expected_create.insert( + 0, + RegexSQL( + r"CREATE DOMAIN mytype AS TEXT CHECK \(VALUE .*\)", + dialect="postgresql", + ), + ) + expected_drop.append( + RegexSQL("DROP DOMAIN mytype", dialect="postgresql") + ) + else: + type_exists = functools.partial( + self._enum_exists, "mytype", connection + ) + + if createtype: + expected_create.insert( + 0, + RegexSQL( + r"CREATE TYPE mytype AS ENUM " + r"\('one', 'two', 'three'\)", + dialect="postgresql", + ), + ) + expected_drop.append( + RegexSQL("DROP TYPE mytype", dialect="postgresql") + ) + + t1 = Table("e1", metadata, Column("c1", dt)) + + assert not type_exists() + + if createtype: + with self.sql_execution_asserter(connection) as create_asserter: + t1.create(connection, checkfirst=False) + + assert type_exists() + + with self.sql_execution_asserter(connection) as drop_asserter: + t1.drop(connection, checkfirst=False) + else: + dt.create(bind=connection, checkfirst=False) + assert type_exists() + + with self.sql_execution_asserter(connection) as create_asserter: + t1.create(connection, checkfirst=False) + with self.sql_execution_asserter(connection) as drop_asserter: + t1.drop(connection, checkfirst=False) + + assert type_exists() + dt.drop(bind=connection, checkfirst=False) + + assert not type_exists() + + create_asserter.assert_(*expected_create) + drop_asserter.assert_(*expected_drop) def test_enum_dont_keep_checking(self, metadata, connection): metadata = self.metadata @@ -1396,53 +1483,6 @@ class DomainTest( ], ) - @testing.combinations( - tuple( - [ - DOMAIN( - name="mytype", - data_type=Text, - check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", - create_type=True, - ), - ] - ), - tuple( - [ - DOMAIN( - name="mytype", - data_type=Text, - check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", - create_type=False, - ), - ] - ), - argnames="domain", - ) - def test_create_drop_domain_with_table(self, connection, metadata, domain): - table = Table("e1", metadata, Column("e1", domain)) - - def _domain_names(): - return {d["name"] for d in inspect(connection).get_domains()} - - assert "mytype" not in _domain_names() - - if domain.create_type: - table.create(connection) - assert "mytype" in _domain_names() - else: - with expect_raises(exc.ProgrammingError): - table.create(connection) - connection.rollback() - - domain.create(connection) - assert "mytype" in _domain_names() - table.create(connection) - - table.drop(connection) - if domain.create_type: - assert "mytype" not in _domain_names() - @testing.combinations( (Integer, "value > 0", 4), (String, "value != ''", "hello world"), -- 2.47.3