]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add values_callable feature to Enum
authorJon Snyder <snyder.jon@gmail.com>
Wed, 17 Jan 2018 21:37:59 +0000 (16:37 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Feb 2018 02:22:27 +0000 (21:22 -0500)
Added support for :class:`.Enum` to persist the values of the enumeration,
rather than the keys, when using a Python pep-435 style enumerated object.
The user supplies a callable function that will return the string values to
be persisted.  This allows enumerations against non-string values to be
value-persistable as well.  Pull request courtesy Jon Snyder.

Pull-request: https://github.com/zzzeek/sqlalchemy/pull/410
Fixes: #3906
Change-Id: Id385465d215d1e5baaad68368b168afdd846b82c

doc/build/changelog/unreleased_12/3906.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/enumerated.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/sqltypes.py
test/dialect/mysql/test_types.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_12/3906.rst b/doc/build/changelog/unreleased_12/3906.rst
new file mode 100644 (file)
index 0000000..a565370
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: feature, sql
+    :tickets: 3906
+
+    Added support for :class:`.Enum` to persist the values of the enumeration,
+    rather than the keys, when using a Python pep-435 style enumerated object.
+    The user supplies a callable function that will return the string values to
+    be persisted.  This allows enumerations against non-string values to be
+    value-persistable as well.  Pull request courtesy Jon Snyder.
index dfbe96b4a8b137b1e8d5e3436a9bf9940d6e0718..f63d64e8f92e74c93832775e2997eb519256558e 100644 (file)
@@ -126,6 +126,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues):
 
         """
         kw.setdefault("validate_strings", impl.validate_strings)
+        kw.setdefault("values_callable", impl.values_callable)
         return cls(**kw)
 
     def _setup_for_values(self, values, objects, kw):
index 340d3d2be501c3bdc52d29723773e65f36af991d..0cc7c307fa704ea111fc19fd7e8ae5555211c1ff 100644 (file)
@@ -1257,6 +1257,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
         kw.setdefault('inherit_schema', impl.inherit_schema)
         kw.setdefault('metadata', impl.metadata)
         kw.setdefault('_create_events', False)
+        kw.setdefault('values_callable', impl.values_callable)
         return cls(**kw)
 
     def create(self, bind=None, checkfirst=True):
index ac915c73a90ced01eaa73202bcb2641aaca40cf8..c02ece98aab839c0dce72c3d1daceedbbbef6a60 100644 (file)
@@ -1189,6 +1189,13 @@ class Enum(Emulated, String, SchemaType):
     indicated as integers, are **not** used; the value of each enum can
     therefore be any kind of Python object whether or not it is persistable.
 
+    In order to persist the values and not the names, the
+    :paramref:`.Enum.values_callable` parameter may be used.   The value of
+    this parameter is a user-supplied callable, which  is intended to be used
+    with a PEP-435-compliant enumerated class and  returns a list of string
+    values to be persisted.   For a simple enumeration that uses string values,
+    a callable such as  ``lambda x: [e.value for e in x]`` is sufficient.
+
     .. versionadded:: 1.1 - support for PEP-435-style enumerated
        classes.
 
@@ -1277,6 +1284,14 @@ class Enum(Emulated, String, SchemaType):
 
            .. versionadded:: 1.1.0b2
 
+        :param values_callable: A callable which will be passed the PEP-435
+           compliant enumerated type, which should then return a list of string
+           values to be persisted. This allows for alternate usages such as
+           using the string value of an enum to be persisted to the database
+           instead of its name.
+
+           .. versionadded:: 1.2.3
+
         """
         self._enum_init(enums, kw)
 
@@ -1297,6 +1312,7 @@ class Enum(Emulated, String, SchemaType):
         """
         self.native_enum = kw.pop('native_enum', True)
         self.create_constraint = kw.pop('create_constraint', True)
+        self.values_callable = kw.pop('values_callable', None)
 
         values, objects = self._parse_into_values(enums, kw)
         self._setup_for_values(values, objects, kw)
@@ -1341,8 +1357,11 @@ class Enum(Emulated, String, SchemaType):
 
         if len(enums) == 1 and hasattr(enums[0], '__members__'):
             self.enum_class = enums[0]
-            values = list(self.enum_class.__members__)
-            objects = [self.enum_class.__members__[k] for k in values]
+            if self.values_callable:
+                values = self.values_callable(self.enum_class)
+            else:
+                values = list(self.enum_class.__members__)
+            objects = [self.enum_class.__members__[k] for k in self.enum_class.__members__]
             return values, objects
         else:
             self.enum_class = None
@@ -1423,6 +1442,7 @@ class Enum(Emulated, String, SchemaType):
         kw.setdefault('metadata', self.metadata)
         kw.setdefault('_create_events', False)
         kw.setdefault('native_enum', self.native_enum)
+        kw.setdefault('values_callable', self.values_callable)
         assert '_enums' in kw
         return impltype(**kw)
 
index 0bc9de505d1295389349c61d35cefd00e0199642..e32b92043d510118999dda12b114e0de1155c0bf 100644 (file)
@@ -13,6 +13,7 @@ from sqlalchemy import testing
 import datetime
 import decimal
 from sqlalchemy import types as sqltypes
+from collections import OrderedDict
 
 
 class TypesTest(fixtures.TestBase,
@@ -652,6 +653,26 @@ class EnumSetTest(
     __dialect__ = mysql.dialect()
     __backend__ = True
 
+    class SomeEnum(object):
+        # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
+        __members__ = OrderedDict()
+
+        def __init__(self, name, value):
+            self.name = name
+            self.value = value
+            self.__members__[name] = self
+            setattr(self.__class__, name, self)
+
+    one = SomeEnum('one', 1)
+    two = SomeEnum('two', 2)
+    three = SomeEnum('three', 3)
+    a_member = SomeEnum('AMember', 'a')
+    b_member = SomeEnum('BMember', 'b')
+
+    @staticmethod
+    def get_enum_string_values(some_enum):
+        return [str(v.value) for v in some_enum.__members__.values()]
+
     @testing.provide_metadata
     def test_enum(self):
         """Exercise the ENUM type."""
@@ -673,6 +694,10 @@ class EnumSetTest(
             Column('e5', mysql.ENUM("a", "b")),
             Column('e5generic', Enum("a", "b")),
             Column('e6', mysql.ENUM("'a'", "b")),
+            Column('e7', mysql.ENUM(EnumSetTest.SomeEnum,
+                                    values_callable=EnumSetTest.
+                                    get_enum_string_values)),
+            Column('e8', mysql.ENUM(EnumSetTest.SomeEnum))
         )
 
         eq_(
@@ -699,6 +724,14 @@ class EnumSetTest(
         eq_(
             colspec(enum_table.c.e6),
             "e6 ENUM('''a''','b')")
+        eq_(
+            colspec(enum_table.c.e7),
+            "e7 ENUM('1','2','3','a','b')"
+        )
+        eq_(
+            colspec(enum_table.c.e8),
+            "e8 ENUM('one','two','three','AMember','BMember')"
+        )
         enum_table.create()
 
         assert_raises(
@@ -710,19 +743,27 @@ class EnumSetTest(
             exc.StatementError,
             enum_table.insert().execute,
             e1='c', e2='c', e2generic='c', e3='c',
-            e4='c', e5='c', e5generic='c', e6='c')
+            e4='c', e5='c', e5generic='c', e6='c',
+            e7='c', e8='c')
 
         enum_table.insert().execute()
         enum_table.insert().execute(e1='a', e2='a', e2generic='a', e3='a',
-                                    e4='a', e5='a', e5generic='a', e6="'a'")
+                                    e4='a', e5='a', e5generic='a', e6="'a'",
+                                    e7='a', e8='AMember')
         enum_table.insert().execute(e1='b', e2='b', e2generic='b', e3='b',
-                                    e4='b', e5='b', e5generic='b', e6='b')
+                                    e4='b', e5='b', e5generic='b', e6='b',
+                                    e7='b', e8='BMember')
 
         res = enum_table.select().execute().fetchall()
 
-        expected = [(None, 'a', 'a', None, 'a', None, None, None),
-                    ('a', 'a', 'a', 'a', 'a', 'a', 'a', "'a'"),
-                    ('b', 'b', 'b', 'b', 'b', 'b', 'b', 'b')]
+        expected = [(None, 'a', 'a', None, 'a', None, None, None,
+                     None, None),
+                    ('a', 'a', 'a', 'a', 'a', 'a', 'a', "'a'",
+                     EnumSetTest.SomeEnum.AMember,
+                     EnumSetTest.SomeEnum.AMember),
+                    ('b', 'b', 'b', 'b', 'b', 'b', 'b', 'b',
+                     EnumSetTest.SomeEnum.BMember,
+                     EnumSetTest.SomeEnum.BMember)]
 
         eq_(res, expected)
 
index fa917c466bc22ae62c0b7a3554f6097a18c96a5c..002094f7bb416d6cbcd867bbcc26552b56fcceb6 100644 (file)
@@ -1170,9 +1170,24 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
                 self.__members__[alias] = self
                 setattr(self.__class__, alias, self)
 
+    class SomeOtherEnum(SomeEnum):
+        __members__ = OrderedDict()
+
     one = SomeEnum('one', 1)
     two = SomeEnum('two', 2)
     three = SomeEnum('three', 3, 'four')
+    a_member = SomeEnum('AMember', 'a')
+    b_member = SomeEnum('BMember', 'b')
+
+    other_one = SomeOtherEnum('one', 1)
+    other_two = SomeOtherEnum('two', 2)
+    other_three = SomeOtherEnum('three', 3)
+    other_a_member = SomeOtherEnum('AMember', 'a')
+    other_b_member = SomeOtherEnum('BMember', 'b')
+
+    @staticmethod
+    def get_enum_string_values(some_enum):
+        return [str(v.value) for v in some_enum.__members__.values()]
 
     @classmethod
     def define_tables(cls, metadata):
@@ -1197,6 +1212,14 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             Column('someenum', Enum(cls.SomeEnum))
         )
 
+        Table(
+            'stdlib_enum_table2', metadata,
+            Column("id", Integer, primary_key=True),
+            Column('someotherenum',
+                   Enum(cls.SomeOtherEnum,
+                        values_callable=EnumTest.get_enum_string_values))
+        )
+
     def test_python_type(self):
         eq_(types.Enum(self.SomeEnum).python_type, self.SomeEnum)
 
@@ -1521,6 +1544,27 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             ]
         )
 
+    def test_pep435_enum_values_callable_round_trip(self):
+        stdlib_enum_table_custom_values =\
+            self.tables['stdlib_enum_table2']
+
+        stdlib_enum_table_custom_values.insert().execute([
+            {'id': 1, 'someotherenum': self.SomeOtherEnum.AMember},
+            {'id': 2, 'someotherenum': self.SomeOtherEnum.BMember},
+            {'id': 3, 'someotherenum': self.SomeOtherEnum.AMember}
+        ])
+
+        eq_(
+            stdlib_enum_table_custom_values.select().
+            order_by(stdlib_enum_table_custom_values.c.id).execute().
+            fetchall(),
+            [
+                (1, self.SomeOtherEnum.AMember),
+                (2, self.SomeOtherEnum.BMember),
+                (3, self.SomeOtherEnum.AMember)
+            ]
+        )
+
     def test_adapt(self):
         from sqlalchemy.dialects.postgresql import ENUM
         e1 = Enum('one', 'two', 'three', native_enum=False)
@@ -1544,7 +1588,13 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
         is_(e1.adapt(Enum).metadata, e1.metadata)
         e1 = Enum(self.SomeEnum)
         eq_(e1.adapt(ENUM).name, 'someenum')
-        eq_(e1.adapt(ENUM).enums, ['one', 'two', 'three', 'four'])
+        eq_(e1.adapt(ENUM).enums,
+            ['one', 'two', 'three', 'four', 'AMember', 'BMember'])
+
+        e1_vc = Enum(self.SomeOtherEnum,
+                     values_callable=EnumTest.get_enum_string_values)
+        eq_(e1_vc.adapt(ENUM).name, 'someotherenum')
+        eq_(e1_vc.adapt(ENUM).enums, ['1', '2', '3', 'a', 'b'])
 
     @testing.provide_metadata
     def test_create_metadata_bound_no_crash(self):