]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add ``omit_aliases`` in ``Enum``.
authorFederico Caselli <cfederico87@gmail.com>
Mon, 29 Mar 2021 20:30:49 +0000 (22:30 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Apr 2021 22:55:37 +0000 (18:55 -0400)
Introduce a new parameter :paramref:`_types.Enum.omit_aliases` in
:class:`_types.Enum` type allow filtering aliases when using a pep435 Enum.
Previous versions of SQLAlchemy kept aliases in all cases, creating
database enum type with additional states, meaning that they were treated
as different values in the db. For backward compatibility this flag
defaults to ``False`` in the 1.4 series, but will be switched to ``True``
in a future version. A deprecation warning is raise if this flag is not
specified and the passed enum contains aliases.

Fixes: #6146
Change-Id: I547322ffa90d0273d91bb3bf8bfea6ec934d48b9

doc/build/changelog/unreleased_14/6146.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/enumerated.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/sqltypes.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_14/6146.rst b/doc/build/changelog/unreleased_14/6146.rst
new file mode 100644 (file)
index 0000000..cca7247
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, sql, schema
+    :tickets: 6146
+
+    Introduce a new parameter :paramref:`_types.Enum.omit_aliases` in
+    :class:`_types.Enum` type allow filtering aliases when using a pep435 Enum.
+    Previous versions of SQLAlchemy kept aliases in all cases, creating
+    database enum type with additional states, meaning that they were treated
+    as different values in the db. For backward compatibility this flag
+    defaults to ``False`` in the 1.4 series, but will be switched to ``True``
+    in a future version. A deprecation warning is raise if this flag is not
+    specified and the passed enum contains aliases.
index c44b602260395e7fef28ded0fc2eb34e3f471764..3b61e4e904dfa5b62610aa0821be4a963e45f16c 100644 (file)
@@ -80,6 +80,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
         """
         kw.setdefault("validate_strings", impl.validate_strings)
         kw.setdefault("values_callable", impl.values_callable)
+        kw.setdefault("omit_aliases", impl._omit_aliases)
         return cls(**kw)
 
     def _object_value_for_elem(self, elem):
index 0854214d029ac639c375a6290007179455f275d1..7d9205c018447b0d8e84cea8922bd9d2630cc188 100644 (file)
@@ -1880,6 +1880,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
         kw.setdefault("metadata", impl.metadata)
         kw.setdefault("_create_events", False)
         kw.setdefault("values_callable", impl.values_callable)
+        kw.setdefault("omit_aliases", impl._omit_aliases)
         return cls(**kw)
 
     def create(self, bind=None, checkfirst=True):
index 367b2e203788516fb7e2fc7d1eac4771cb0bb64b..59ba7c39110a61cbbb3558302612b579101a98d2 100644 (file)
@@ -41,6 +41,7 @@ from .. import processors
 from .. import util
 from ..util import compat
 from ..util import langhelpers
+from ..util import OrderedDict
 from ..util import pickle
 
 
@@ -1425,7 +1426,15 @@ class Enum(Emulated, String, SchemaType):
 
            .. versionadded:: 1.3.8
 
+        :param omit_aliases: A boolean that when true will remove aliases from
+           pep 435 enums. For backward compatibility it defaults to ``False``.
+           A deprecation warning is raised if the enum has aliases and this
+           flag was not set.
 
+           .. versionadded:: 1.4.4
+
+           .. deprecated:: 1.4  The default will be changed to ``True`` in
+              SQLAlchemy 2.0.
 
         """
         self._enum_init(enums, kw)
@@ -1450,6 +1459,7 @@ class Enum(Emulated, String, SchemaType):
         self.values_callable = kw.pop("values_callable", None)
         self._sort_key_function = kw.pop("sort_key_function", NO_ARG)
         length_arg = kw.pop("length", NO_ARG)
+        self._omit_aliases = kw.pop("omit_aliases", NO_ARG)
 
         values, objects = self._parse_into_values(enums, kw)
         self._setup_for_values(values, objects, kw)
@@ -1506,7 +1516,24 @@ class Enum(Emulated, String, SchemaType):
 
         if len(enums) == 1 and hasattr(enums[0], "__members__"):
             self.enum_class = enums[0]
-            members = self.enum_class.__members__
+
+            _members = self.enum_class.__members__
+
+            aliases = [n for n, v in _members.items() if v.name != n]
+            if self._omit_aliases is NO_ARG and aliases:
+                util.warn_deprecated_20(
+                    "The provided enum %s contains the aliases %s. The "
+                    "``omit_aliases`` will default to ``True`` in SQLAlchemy "
+                    "2.0. Specify a value to silence this warning."
+                    % (self.enum_class.__name__, aliases)
+                )
+            if self._omit_aliases is True:
+                # remove aliases
+                members = OrderedDict(
+                    (n, v) for n, v in _members.items() if v.name == n
+                )
+            else:
+                members = _members
             if self.values_callable:
                 values = self.values_callable(self.enum_class)
             else:
@@ -1633,6 +1660,7 @@ class Enum(Emulated, String, SchemaType):
         kw.setdefault("values_callable", self.values_callable)
         kw.setdefault("create_constraint", self.create_constraint)
         kw.setdefault("length", self.length)
+        kw.setdefault("omit_aliases", self._omit_aliases)
         assert "_enums" in kw
         return impltype(**kw)
 
index 08966b38e6dc8ff7083c97b8283eb021b9635e8e..e63197ae2f98c6a41f8c64329da2c264b2cbbafd 100644 (file)
@@ -79,6 +79,7 @@ from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import AssertsExecutionResults
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_deprecated_20
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -1659,7 +1660,24 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             "stdlib_enum_table",
             metadata,
             Column("id", Integer, primary_key=True),
-            Column("someenum", Enum(cls.SomeEnum, create_constraint=True)),
+            Column(
+                "someenum",
+                Enum(cls.SomeEnum, create_constraint=True, omit_aliases=False),
+            ),
+        )
+        Table(
+            "stdlib_enum_table_no_alias",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column(
+                "someenum",
+                Enum(
+                    cls.SomeEnum,
+                    create_constraint=True,
+                    omit_aliases=True,
+                    name="someenum_no_alias",
+                ),
+            ),
         )
 
         Table(
@@ -1677,7 +1695,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
         )
 
     def test_python_type(self):
-        eq_(types.Enum(self.SomeEnum).python_type, self.SomeEnum)
+        eq_(types.Enum(self.SomeOtherEnum).python_type, self.SomeOtherEnum)
 
     def test_pickle_types(self):
         global SomeEnum
@@ -1685,7 +1703,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
         for loads, dumps in picklers():
             column_types = [
                 Column("Enu", Enum("x", "y", "z", name="somename")),
-                Column("En2", Enum(self.SomeEnum)),
+                Column("En2", Enum(self.SomeEnum, omit_aliases=False)),
             ]
             for column_type in column_types:
                 meta = MetaData()
@@ -1694,8 +1712,10 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
                 loads(dumps(meta))
 
     def test_validators_pep435(self):
-        type_ = Enum(self.SomeEnum)
-        validate_type = Enum(self.SomeEnum, validate_strings=True)
+        type_ = Enum(self.SomeEnum, omit_aliases=False)
+        validate_type = Enum(
+            self.SomeEnum, validate_strings=True, omit_aliases=False
+        )
 
         bind_processor = type_.bind_processor(testing.db.dialect)
         bind_processor_validates = validate_type.bind_processor(
@@ -2086,7 +2106,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             self.a_member,
             self.b_member,
         )
-        typ = Enum(self.SomeEnum)
+        typ = Enum(self.SomeEnum, omit_aliases=False)
 
         is_(typ.sort_key_function.__func__, typ._db_value_for_elem.__func__)
 
@@ -2106,7 +2126,11 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
         def sort_enum_key_value(value):
             return str(value.value)
 
-        typ = Enum(self.SomeEnum, sort_key_function=sort_enum_key_value)
+        typ = Enum(
+            self.SomeEnum,
+            sort_key_function=sort_enum_key_value,
+            omit_aliases=False,
+        )
         is_(typ.sort_key_function, sort_enum_key_value)
 
         eq_(
@@ -2115,7 +2139,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
         )
 
     def test_pep435_no_sort_key(self):
-        typ = Enum(self.SomeEnum, sort_key_function=None)
+        typ = Enum(self.SomeEnum, sort_key_function=None, omit_aliases=False)
         is_(typ.sort_key_function, None)
 
     def test_pep435_enum_round_trip(self, connection):
@@ -2231,7 +2255,7 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
         eq_(e1.adapt(Enum).name, "foo")
         eq_(e1.adapt(Enum).schema, "bar")
         is_(e1.adapt(Enum).metadata, e1.metadata)
-        e1 = Enum(self.SomeEnum)
+        e1 = Enum(self.SomeEnum, omit_aliases=False)
         eq_(e1.adapt(ENUM).name, "someenum")
         eq_(
             e1.adapt(ENUM).enums,
@@ -2366,6 +2390,34 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
         e = Enum("x", "y", "long", native_enum=False, length=42)
         eq_(e.length, 42)
 
+    def test_omit_aliases(self, connection):
+        table0 = self.tables["stdlib_enum_table"]
+        type0 = table0.c.someenum.type
+        eq_(type0.enums, ["one", "two", "three", "four", "AMember", "BMember"])
+
+        table = self.tables["stdlib_enum_table_no_alias"]
+
+        type_ = table.c.someenum.type
+        eq_(type_.enums, ["one", "two", "three", "AMember", "BMember"])
+
+        connection.execute(
+            table.insert(),
+            [
+                {"id": 1, "someenum": self.SomeEnum.three},
+                {"id": 2, "someenum": self.SomeEnum.four},
+            ],
+        )
+        eq_(
+            connection.execute(table.select().order_by(table.c.id)).fetchall(),
+            [(1, self.SomeEnum.three), (2, self.SomeEnum.three)],
+        )
+
+    def test_omit_warn(self):
+        with expect_deprecated_20(
+            r"The provided enum someenum contains the aliases \['four'\]"
+        ):
+            Enum(self.SomeEnum)
+
 
 MyPickleType = None