From 9fd4d188708ff4d5a11dde806acda6792b810703 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Wed, 7 Apr 2021 22:02:10 +0200 Subject: [PATCH] Unify native and non-native valid values for ``Enum`` Unify behaviour :class:`_schema.Enum` in native and non-native implementations regarding the accepted values for an enum with aliased elements. When :paramref:`_schema.Enum.omit_aliases` is ``False`` all values, alias included, are accepted as valid values. When :paramref:`_schema.Enum.omit_aliases` is ``True`` only non aliased values are accepted as valid values. Fixes: #6146 Change-Id: I6f40789c1ca56e533990882deadcc6a377d4fc40 --- doc/build/changelog/unreleased_14/6146.rst | 11 ++++ lib/sqlalchemy/sql/sqltypes.py | 5 +- test/sql/test_types.py | 65 +++++++++++++++++++++- 3 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6146.rst diff --git a/doc/build/changelog/unreleased_14/6146.rst b/doc/build/changelog/unreleased_14/6146.rst new file mode 100644 index 0000000000..418fbc4e2b --- /dev/null +++ b/doc/build/changelog/unreleased_14/6146.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: enum, schema + :tickets: 6146 + + Unify behaviour :class:`_schema.Enum` in native and non-native + implementations regarding the accepted values for an enum with + aliased elements. + When :paramref:`_schema.Enum.omit_aliases` is ``False`` all values, + alias included, are accepted as valid values. + When :paramref:`_schema.Enum.omit_aliases` is ``True`` only non aliased values + are accepted as valid values. diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index e57a14681d..fd3118e300 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1692,7 +1692,7 @@ class Enum(Emulated, String, SchemaType): variant_mapping = self._variant_mapping_for_set_table(column) e = schema.CheckConstraint( - type_coerce(column, self).in_(self.enums), + type_coerce(column, String()).in_(self.enums), name=_NONE_NAME if self.name is None else self.name, _create_rule=util.portable_instancemethod( self._should_create_constraint, @@ -1714,13 +1714,14 @@ class Enum(Emulated, String, SchemaType): return process def bind_processor(self, dialect): + parent_processor = super(Enum, self).bind_processor(dialect) + def process(value): value = self._db_value_for_elem(value) if parent_processor: value = parent_processor(value) return value - parent_processor = super(Enum, self).bind_processor(dialect) return process def result_processor(self, dialect, coltype): diff --git a/test/sql/test_types.py b/test/sql/test_types.py index e63197ae2f..2cfd148cbd 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -80,6 +80,7 @@ from sqlalchemy.testing import AssertsExecutionResults from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_deprecated_20 +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -2048,7 +2049,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): assert_raises( (exc.DBAPIError,), connection.exec_driver_sql, - "insert into my_table " "(data) values('two')", + "insert into my_table (data) values('two')", ) trans.rollback() @@ -2418,6 +2419,68 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): ): Enum(self.SomeEnum) + @testing.combinations( + (True, "native"), (False, "non_native"), id_="ai", argnames="native" + ) + @testing.combinations( + (True, "omit_alias"), (False, "with_alias"), id_="ai", argnames="omit" + ) + @testing.provide_metadata + @testing.skip_if('mysql < 8') + def test_duplicate_values_accepted(self, native, omit): + foo_enum = pep435_enum("foo_enum") + foo_enum("one", 1, "two") + foo_enum("three", 3, "four") + tbl = sa.Table( + "foo_table", + self.metadata, + sa.Column("id", sa.Integer), + sa.Column( + "data", + sa.Enum( + foo_enum, + native_enum=native, + omit_aliases=omit, + create_constraint=True, + ), + ), + ) + t = sa.table("foo_table", sa.column("id"), sa.column("data")) + + self.metadata.create_all(testing.db) + if omit: + with expect_raises( + ( + exc.IntegrityError, + exc.DataError, + exc.OperationalError, + exc.DBAPIError, + ) + ): + with testing.db.begin() as conn: + conn.execute( + t.insert(), + [ + {"id": 1, "data": "four"}, + {"id": 2, "data": "three"}, + ], + ) + else: + with testing.db.begin() as conn: + conn.execute( + t.insert(), + [{"id": 1, "data": "four"}, {"id": 2, "data": "three"}], + ) + + eq_( + conn.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "four"), (2, "three")], + ) + eq_( + conn.execute(tbl.select().order_by(tbl.c.id)).fetchall(), + [(1, foo_enum.three), (2, foo_enum.three)], + ) + MyPickleType = None -- 2.47.2