]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add create_type to Enum
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 21 Sep 2025 16:13:21 +0000 (12:13 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 21 Sep 2025 16:15:49 +0000 (12:15 -0400)
Added new parameter :paramref:`.Enum.create_type` to the Core
:class:`.Enum` class. This parameter is automatically passed to the
corresponding :class:`_postgresql.ENUM` native type during DDL operations,
allowing control over whether the PostgreSQL ENUM type is implicitly
created or dropped within DDL operations that are otherwise targeting
tables only. This provides control over the
:paramref:`_postgresql.ENUM.create_type` behavior without requiring
explicit creation of a :class:`_postgresql.ENUM` object.

Fixes: #10604
Change-Id: I450003ec2a2a65c119fe7ca8ff201392ce6b91e1

doc/build/changelog/unreleased_21/10604.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/named_types.py
lib/sqlalchemy/sql/sqltypes.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_21/10604.rst b/doc/build/changelog/unreleased_21/10604.rst
new file mode 100644 (file)
index 0000000..863affd
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 10604
+
+    Added new parameter :paramref:`.Enum.create_type` to the Core
+    :class:`.Enum` class. This parameter is automatically passed to the
+    corresponding :class:`_postgresql.ENUM` native type during DDL operations,
+    allowing control over whether the PostgreSQL ENUM type is implicitly
+    created or dropped within DDL operations that are otherwise targeting
+    tables only. This provides control over the
+    :paramref:`_postgresql.ENUM.create_type` behavior without requiring
+    explicit creation of a :class:`_postgresql.ENUM` object.
index 5807041ead3e0f165b7875f8d7484833b1490113..c47b38185653d7920f85889c6a6237c582daaef4 100644 (file)
@@ -308,9 +308,9 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
                 "always refers to ENUM.   Use sqlalchemy.types.Enum for "
                 "non-native enum."
             )
-        self.create_type = create_type
         if name is not _NoArg.NO_ARG:
             kw["name"] = name
+        kw["create_type"] = create_type
         super().__init__(*enums, **kw)
 
     def coerce_compared_value(self, op, value):
@@ -335,6 +335,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
         """
         kw.setdefault("validate_strings", impl.validate_strings)
         kw.setdefault("name", impl.name)
+        kw.setdefault("create_type", impl.create_type)
         kw.setdefault("schema", impl.schema)
         kw.setdefault("inherit_schema", impl.inherit_schema)
         kw.setdefault("metadata", impl.metadata)
@@ -342,8 +343,6 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
         kw.setdefault("values_callable", impl.values_callable)
         kw.setdefault("omit_aliases", impl._omit_aliases)
         kw.setdefault("_adapted_from", impl)
-        if type_api._is_native_for_emulated(impl.__class__):
-            kw.setdefault("create_type", impl.create_type)
 
         return cls(**kw)
 
@@ -496,8 +495,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
         if check is not None:
             check = coercions.expect(roles.DDLExpressionRole, check)
         self.check = check
-        self.create_type = create_type
-        super().__init__(name=name, **kw)
+        super().__init__(name=name, create_type=create_type, **kw)
 
     @classmethod
     def __test_init__(cls):
index 449eadda456c0f452d15d0b264086fb167746ce8..7154dba97c709b4ef468aea4f685721b9c73842f 100644 (file)
@@ -1073,6 +1073,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
         metadata: Optional[MetaData] = None,
         inherit_schema: Union[bool, _NoArg] = NO_ARG,
         quote: Optional[bool] = None,
+        create_type: bool = True,
         _create_events: bool = True,
         _adapted_from: Optional[SchemaType] = None,
     ):
@@ -1082,7 +1083,7 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin):
             self.name = None
         self.schema = schema
         self.metadata = metadata
-
+        self.create_type = create_type
         if inherit_schema is True and schema is not None:
             raise exc.ArgumentError(
                 "Ambiguously setting inherit_schema=True while "
@@ -1442,6 +1443,21 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
            class was used, its name (converted to lower case) is used by
            default.
 
+        :param create_type: Defaults to True.  This parameter only applies
+         to backends such as PostgreSQL which use explicitly named types
+         that are created and dropped separately from the table(s) they
+         are used by.   Indicates that ``CREATE TYPE`` should be emitted,
+         after optionally checking for the presence of the type, when the
+         parent table is being created; and additionally that ``DROP TYPE`` is
+         called when the table is dropped.  This parameter is equivalent to the
+         parameter of the same name on the PostgreSQL-specific
+         :class:`_postgresql.ENUM` datatype.
+
+          .. versionadded:: 2.1 - The dialect agnostic :class:`.Enum` class
+             now includes the same :paramref:`.Enum.create_type` parameter that
+             was already available on the PostgreSQL native
+             :class:`_postgresql.ENUM` implementation.
+
         :param native_enum: Use the database's native ENUM type when
            available. Defaults to True. When False, uses VARCHAR + check
            constraint for all backends. When False, the VARCHAR length can be
@@ -1586,6 +1602,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
         SchemaType.__init__(
             self,
             name=kw.pop("name", None),
+            create_type=kw.pop("create_type", True),
             inherit_schema=kw.pop("inherit_schema", NO_ARG),
             schema=kw.pop("schema", None),
             metadata=kw.pop("metadata", None),
index 734e909e099779f0d764e540a91a7cfad15d23c3..e57991a30aafe83f5290bc6bd5356985f983aef5 100644 (file)
@@ -255,6 +255,12 @@ class NamedTypeTest(
 
     __only_on__ = "postgresql > 8.3"
 
+    def _enum_exists(self, name, connection):
+        return name in {d["name"] for d in inspect(connection).get_enums()}
+
+    def _domain_exists(self, name, connection):
+        return name in {d["name"] for d in inspect(connection).get_domains()}
+
     def test_native_enum_warnings(self):
         """test #6106"""
 
@@ -800,28 +806,109 @@ class NamedTypeTest(
         connection.execute(t1.insert(), {"bar": "Ü"})
         eq_(connection.scalar(select(t1.c.bar)), "Ü")
 
-    @testing.combinations(
-        (ENUM("one", "two", "three", name="mytype", create_type=False),),
-        (
-            DOMAIN(
+    @testing.variation("datatype", ["enum", "native_enum", "domain"])
+    @testing.variation("createtype", [True, False])
+    def test_create_type_parameter(
+        self, metadata, connection, datatype, createtype
+    ):
+
+        if datatype.enum:
+            dt = Enum(
+                "one",
+                "two",
+                "three",
+                name="mytype",
+                create_type=bool(createtype),
+            )
+        elif datatype.native_enum:
+            dt = ENUM(
+                "one",
+                "two",
+                "three",
+                name="mytype",
+                create_type=bool(createtype),
+            )
+        elif datatype.domain:
+            dt = DOMAIN(
                 name="mytype",
                 data_type=Text,
                 check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
-                create_type=False,
-            ),
-        ),
-        argnames="datatype",
-    )
-    def test_disable_create(self, metadata, connection, datatype):
-        metadata = self.metadata
+                create_type=bool(createtype),
+            )
 
-        t1 = Table("e1", metadata, Column("c1", datatype))
-        # table can be created separately
-        # without conflict
-        datatype.create(bind=connection)
-        t1.create(connection)
-        t1.drop(connection)
-        datatype.drop(bind=connection)
+        else:
+            assert False
+
+        expected_create = [
+            RegexSQL(
+                r"CREATE TABLE e1 \(c1 mytype\)",
+                dialect="postgresql",
+            )
+        ]
+
+        expected_drop = [RegexSQL("DROP TABLE e1", dialect="postgresql")]
+
+        if datatype.domain:
+            type_exists = functools.partial(
+                self._domain_exists, "mytype", connection
+            )
+            if createtype:
+                expected_create.insert(
+                    0,
+                    RegexSQL(
+                        r"CREATE DOMAIN mytype AS TEXT CHECK \(VALUE .*\)",
+                        dialect="postgresql",
+                    ),
+                )
+                expected_drop.append(
+                    RegexSQL("DROP DOMAIN mytype", dialect="postgresql")
+                )
+        else:
+            type_exists = functools.partial(
+                self._enum_exists, "mytype", connection
+            )
+
+            if createtype:
+                expected_create.insert(
+                    0,
+                    RegexSQL(
+                        r"CREATE TYPE mytype AS ENUM "
+                        r"\('one', 'two', 'three'\)",
+                        dialect="postgresql",
+                    ),
+                )
+                expected_drop.append(
+                    RegexSQL("DROP TYPE mytype", dialect="postgresql")
+                )
+
+        t1 = Table("e1", metadata, Column("c1", dt))
+
+        assert not type_exists()
+
+        if createtype:
+            with self.sql_execution_asserter(connection) as create_asserter:
+                t1.create(connection, checkfirst=False)
+
+            assert type_exists()
+
+            with self.sql_execution_asserter(connection) as drop_asserter:
+                t1.drop(connection, checkfirst=False)
+        else:
+            dt.create(bind=connection, checkfirst=False)
+            assert type_exists()
+
+            with self.sql_execution_asserter(connection) as create_asserter:
+                t1.create(connection, checkfirst=False)
+            with self.sql_execution_asserter(connection) as drop_asserter:
+                t1.drop(connection, checkfirst=False)
+
+            assert type_exists()
+            dt.drop(bind=connection, checkfirst=False)
+
+        assert not type_exists()
+
+        create_asserter.assert_(*expected_create)
+        drop_asserter.assert_(*expected_drop)
 
     def test_enum_dont_keep_checking(self, metadata, connection):
         metadata = self.metadata
@@ -1396,53 +1483,6 @@ class DomainTest(
             ],
         )
 
-    @testing.combinations(
-        tuple(
-            [
-                DOMAIN(
-                    name="mytype",
-                    data_type=Text,
-                    check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
-                    create_type=True,
-                ),
-            ]
-        ),
-        tuple(
-            [
-                DOMAIN(
-                    name="mytype",
-                    data_type=Text,
-                    check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
-                    create_type=False,
-                ),
-            ]
-        ),
-        argnames="domain",
-    )
-    def test_create_drop_domain_with_table(self, connection, metadata, domain):
-        table = Table("e1", metadata, Column("e1", domain))
-
-        def _domain_names():
-            return {d["name"] for d in inspect(connection).get_domains()}
-
-        assert "mytype" not in _domain_names()
-
-        if domain.create_type:
-            table.create(connection)
-            assert "mytype" in _domain_names()
-        else:
-            with expect_raises(exc.ProgrammingError):
-                table.create(connection)
-            connection.rollback()
-
-            domain.create(connection)
-            assert "mytype" in _domain_names()
-            table.create(connection)
-
-        table.drop(connection)
-        if domain.create_type:
-            assert "mytype" not in _domain_names()
-
     @testing.combinations(
         (Integer, "value > 0", 4),
         (String, "value != ''", "hello world"),