]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Unify native and non-native valid values for ``Enum``
authorFederico Caselli <cfederico87@gmail.com>
Wed, 7 Apr 2021 20:02:10 +0000 (22:02 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Sat, 10 Apr 2021 13:36:33 +0000 (15:36 +0200)
Unify behaviour :class:`_schema.Enum` in native and non-native
implementations regarding the accepted values for an enum with
aliased elements.
When :paramref:`_schema.Enum.omit_aliases` is ``False`` all values,
alias included, are accepted as valid values.
When :paramref:`_schema.Enum.omit_aliases` is ``True`` only non aliased values
are accepted as valid values.

Fixes: #6146
Change-Id: I6f40789c1ca56e533990882deadcc6a377d4fc40

doc/build/changelog/unreleased_14/6146.rst [new file with mode: 0644]
lib/sqlalchemy/sql/sqltypes.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_14/6146.rst b/doc/build/changelog/unreleased_14/6146.rst
new file mode 100644 (file)
index 0000000..418fbc4
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: enum, schema
+    :tickets: 6146
+
+    Unify behaviour :class:`_schema.Enum` in native and non-native
+    implementations regarding the accepted values for an enum with
+    aliased elements.
+    When :paramref:`_schema.Enum.omit_aliases` is ``False`` all values,
+    alias included, are accepted as valid values.
+    When :paramref:`_schema.Enum.omit_aliases` is ``True`` only non aliased values
+    are accepted as valid values.
index e57a14681d77d4915ae4577fd3fb40df4cf0a0b4..fd3118e30074bf6593ce80012b8cfa5d6718028e 100644 (file)
@@ -1692,7 +1692,7 @@ class Enum(Emulated, String, SchemaType):
         variant_mapping = self._variant_mapping_for_set_table(column)
 
         e = schema.CheckConstraint(
-            type_coerce(column, self).in_(self.enums),
+            type_coerce(column, String()).in_(self.enums),
             name=_NONE_NAME if self.name is None else self.name,
             _create_rule=util.portable_instancemethod(
                 self._should_create_constraint,
@@ -1714,13 +1714,14 @@ class Enum(Emulated, String, SchemaType):
         return process
 
     def bind_processor(self, dialect):
+        parent_processor = super(Enum, self).bind_processor(dialect)
+
         def process(value):
             value = self._db_value_for_elem(value)
             if parent_processor:
                 value = parent_processor(value)
             return value
 
-        parent_processor = super(Enum, self).bind_processor(dialect)
         return process
 
     def result_processor(self, dialect, coltype):
index e63197ae2f98c6a41f8c64329da2c264b2cbbafd..2cfd148cbd422cc0c14bf0e9c75a6a0606677942 100644 (file)
@@ -80,6 +80,7 @@ from sqlalchemy.testing import AssertsExecutionResults
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_deprecated_20
+from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -2048,7 +2049,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             assert_raises(
                 (exc.DBAPIError,),
                 connection.exec_driver_sql,
-                "insert into my_table " "(data) values('two')",
+                "insert into my_table (data) values('two')",
             )
             trans.rollback()
 
@@ -2418,6 +2419,68 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
         ):
             Enum(self.SomeEnum)
 
+    @testing.combinations(
+        (True, "native"), (False, "non_native"), id_="ai", argnames="native"
+    )
+    @testing.combinations(
+        (True, "omit_alias"), (False, "with_alias"), id_="ai", argnames="omit"
+    )
+    @testing.provide_metadata
+    @testing.skip_if('mysql < 8')
+    def test_duplicate_values_accepted(self, native, omit):
+        foo_enum = pep435_enum("foo_enum")
+        foo_enum("one", 1, "two")
+        foo_enum("three", 3, "four")
+        tbl = sa.Table(
+            "foo_table",
+            self.metadata,
+            sa.Column("id", sa.Integer),
+            sa.Column(
+                "data",
+                sa.Enum(
+                    foo_enum,
+                    native_enum=native,
+                    omit_aliases=omit,
+                    create_constraint=True,
+                ),
+            ),
+        )
+        t = sa.table("foo_table", sa.column("id"), sa.column("data"))
+
+        self.metadata.create_all(testing.db)
+        if omit:
+            with expect_raises(
+                (
+                    exc.IntegrityError,
+                    exc.DataError,
+                    exc.OperationalError,
+                    exc.DBAPIError,
+                )
+            ):
+                with testing.db.begin() as conn:
+                    conn.execute(
+                        t.insert(),
+                        [
+                            {"id": 1, "data": "four"},
+                            {"id": 2, "data": "three"},
+                        ],
+                    )
+        else:
+            with testing.db.begin() as conn:
+                conn.execute(
+                    t.insert(),
+                    [{"id": 1, "data": "four"}, {"id": 2, "data": "three"}],
+                )
+
+                eq_(
+                    conn.execute(t.select().order_by(t.c.id)).fetchall(),
+                    [(1, "four"), (2, "three")],
+                )
+                eq_(
+                    conn.execute(tbl.select().order_by(tbl.c.id)).fetchall(),
+                    [(1, foo_enum.three), (2, foo_enum.three)],
+                )
+
 
 MyPickleType = None