]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Initial implementation of support for PEP-435 enumerated types
authorAlex Grönholm <alex.gronholm@nextday.fi>
Tue, 2 Feb 2016 19:20:17 +0000 (14:20 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Feb 2016 19:21:54 +0000 (14:21 -0500)
within the Enum type.

lib/sqlalchemy/sql/sqltypes.py
test/sql/test_types.py

index d198d3912c98f408322d1a25c53329e079c5f954..21e57d519b50b81939edce7a828a201a1f0ea939 100644 (file)
@@ -1103,7 +1103,8 @@ class Enum(String, SchemaType):
         Keyword arguments which don't apply to a specific backend are ignored
         by that backend.
 
-        :param \*enums: string or unicode enumeration labels. If unicode
+        :param \*enums: either exactly one PEP 435 compliant enumerated type
+           or one or more string or unicode enumeration labels. If unicode
            labels are present, the `convert_unicode` flag is auto-enabled.
 
         :param convert_unicode: Enable unicode-aware bind parameter and
@@ -1124,7 +1125,9 @@ class Enum(String, SchemaType):
         :param name: The name of this type. This is required for Postgresql
            and any future supported database which requires an explicitly
            named type, or an explicitly named constraint in order to generate
-           the type and/or a table that uses it.
+           the type and/or a table that uses it. If an :class:`~enum.Enum`
+           class was used, its name (converted to lower case) is used by
+           default.
 
         :param native_enum: Use the database's native ENUM type when
            available. Defaults to True. When False, uses VARCHAR + check
@@ -1153,11 +1156,20 @@ class Enum(String, SchemaType):
            .. versionadded:: 0.8
 
         """
-        self.enums = enums
+        if len(enums) == 1 and hasattr(enums[0], '__members__'):
+            self.enums = list(enums[0].__members__)
+            self.enum_class = enums[0]
+            kw.setdefault('name', enums[0].__name__.lower())
+            self.key_lookup = dict((value, key) for key, value in enums[0].__members__.items())
+            self.value_lookup = enums[0].__members__.copy()
+        else:
+            self.enums = enums
+            self.enum_class = self.key_lookup = self.value_lookup = None
+
         self.native_enum = kw.pop('native_enum', True)
         convert_unicode = kw.pop('convert_unicode', None)
         if convert_unicode is None:
-            for e in enums:
+            for e in self.enums:
                 if isinstance(e, util.text_type):
                     convert_unicode = True
                     break
@@ -1203,6 +1215,7 @@ class Enum(String, SchemaType):
         metadata = kw.pop('metadata', self.metadata)
         _create_events = kw.pop('_create_events', False)
         if issubclass(impltype, Enum):
+            args = [self.enum_class] if self.enum_class is not None else self.enums
             return impltype(name=self.name,
                             schema=schema,
                             metadata=metadata,
@@ -1210,12 +1223,66 @@ class Enum(String, SchemaType):
                             native_enum=self.native_enum,
                             inherit_schema=self.inherit_schema,
                             _create_events=_create_events,
-                            *self.enums,
+                            *args,
                             **kw)
         else:
             # TODO: why would we be here?
             return super(Enum, self).adapt(impltype, **kw)
 
+    def literal_processor(self, dialect):
+        parent_processor = super(Enum, self).literal_processor(dialect)
+        if self.key_lookup:
+            def process(value):
+                value = self.key_lookup.get(value, value)
+                if parent_processor:
+                    return parent_processor(value)
+
+            return process
+        else:
+            return parent_processor
+
+    def bind_processor(self, dialect):
+        def process(value):
+            if isinstance(value, util.string_types):
+                if value not in self.enums:
+                    raise LookupError(
+                        '"%s" is not among the defined enum values' %
+                        value)
+            elif self.key_lookup and value in self.key_lookup:
+                value = self.key_lookup[value]
+
+            if parent_processor:
+                value = parent_processor(value)
+            return value
+
+        parent_processor = super(Enum, self).bind_processor(dialect)
+        return process
+
+    def result_processor(self, dialect, coltype):
+        parent_processor = super(Enum, self).result_processor(dialect,
+                                                              coltype)
+        if self.value_lookup:
+            def process(value):
+                if parent_processor:
+                    value = parent_processor(value)
+
+                try:
+                    return self.value_lookup[value]
+                except KeyError:
+                    raise LookupError('No such member in enum class %s: %s' %
+                                      (self.enum_class.__name__, value))
+
+            return process
+        else:
+            return parent_processor
+
+    @property
+    def python_type(self):
+        if self.enum_class:
+            return self.enum_class
+        else:
+            return super(Enum, self).python_type
+
 
 class PickleType(TypeDecorator):
 
index b08556926bfe646ead0326637873ab88851eee72..bd62c4fd3b294843626779bbadccedbfee7316bf 100644 (file)
@@ -15,6 +15,7 @@ from sqlalchemy.sql import ddl
 from sqlalchemy.sql import visitors
 from sqlalchemy import inspection
 from sqlalchemy import exc, types, util, dialects
+from sqlalchemy.util import OrderedDict
 for name in dialects.__all__:
     __import__("sqlalchemy.dialects.%s" % name)
 from sqlalchemy.sql import operators, column, table, null
@@ -30,6 +31,21 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
 
 
+class SomeEnum(object):
+    # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
+    __members__ = OrderedDict()
+
+    def __init__(self, name, value):
+        self.name = name
+        self.value = value
+        self.__members__[name] = self
+        setattr(SomeEnum, name, self)
+
+SomeEnum('one', 1)
+SomeEnum('two', 2)
+SomeEnum('three', 3)
+
+
 class AdaptTest(fixtures.TestBase):
 
     def _all_dialect_modules(self):
@@ -181,6 +197,8 @@ class AdaptTest(fixtures.TestBase):
         eq_(types.String().python_type, str)
         eq_(types.Unicode().python_type, util.text_type)
         eq_(types.String(convert_unicode=True).python_type, util.text_type)
+        eq_(types.Enum('one', 'two', 'three').python_type, str)
+        eq_(types.Enum(SomeEnum).python_type, SomeEnum)
 
         assert_raises(
             NotImplementedError,
@@ -278,6 +296,7 @@ class PickleTypesTest(fixtures.TestBase):
                 Column('Pic', PickleType()),
                 Column('Int', Interval()),
                 Column('Enu', Enum('x', 'y', 'z', name="somename")),
+                Column('En2', Enum(SomeEnum)),
             ]
             for column_type in column_types:
                 meta = MetaData()
@@ -1087,41 +1106,35 @@ class UnicodeTest(fixtures.TestBase):
             unicodedata.encode('ascii', 'ignore').decode()
         )
 
-enum_table = non_native_enum_table = metadata = None
-
 
-class EnumTest(AssertsCompiledSQL, fixtures.TestBase):
+class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
 
     @classmethod
-    def setup_class(cls):
-        global enum_table, non_native_enum_table, metadata
-        metadata = MetaData(testing.db)
-        enum_table = Table(
+    def define_tables(cls, metadata):
+        Table(
             'enum_table', metadata, Column("id", Integer, primary_key=True),
             Column('someenum', Enum('one', 'two', 'three', name='myenum'))
         )
 
-        non_native_enum_table = Table(
+        Table(
             'non_native_enum_table', metadata,
             Column("id", Integer, primary_key=True),
             Column('someenum', Enum('one', 'two', 'three', native_enum=False)),
         )
 
-        metadata.create_all()
-
-    def teardown(self):
-        enum_table.delete().execute()
-        non_native_enum_table.delete().execute()
-
-    @classmethod
-    def teardown_class(cls):
-        metadata.drop_all()
+        Table(
+            'stdlib_enum_table', metadata,
+            Column("id", Integer, primary_key=True),
+            Column('someenum', Enum(SomeEnum))
+        )
 
     @testing.fails_on(
         'postgresql+zxjdbc',
         'zxjdbc fails on ENUM: column "XXX" is of type XXX '
         'but expression is of type character varying')
     def test_round_trip(self):
+        enum_table = self.tables['enum_table']
+
         enum_table.insert().execute([
             {'id': 1, 'someenum': 'two'},
             {'id': 2, 'someenum': 'two'},
@@ -1138,6 +1151,8 @@ class EnumTest(AssertsCompiledSQL, fixtures.TestBase):
         )
 
     def test_non_native_round_trip(self):
+        non_native_enum_table = self.tables['non_native_enum_table']
+
         non_native_enum_table.insert().execute([
             {'id': 1, 'someenum': 'two'},
             {'id': 2, 'someenum': 'two'},
@@ -1154,6 +1169,25 @@ class EnumTest(AssertsCompiledSQL, fixtures.TestBase):
             ]
         )
 
+    def test_stdlib_enum_round_trip(self):
+        stdlib_enum_table = self.tables['stdlib_enum_table']
+
+        stdlib_enum_table.insert().execute([
+            {'id': 1, 'someenum': SomeEnum.two},
+            {'id': 2, 'someenum': SomeEnum.two},
+            {'id': 3, 'someenum': SomeEnum.one},
+        ])
+
+        eq_(
+            stdlib_enum_table.select().
+            order_by(stdlib_enum_table.c.id).execute().fetchall(),
+            [
+                (1, SomeEnum.two),
+                (2, SomeEnum.two),
+                (3, SomeEnum.one),
+            ]
+        )
+
     def test_adapt(self):
         from sqlalchemy.dialects.postgresql import ENUM
         e1 = Enum('one', 'two', 'three', native_enum=False)
@@ -1163,6 +1197,9 @@ class EnumTest(AssertsCompiledSQL, fixtures.TestBase):
         e1 = Enum('one', 'two', 'three', name='foo', schema='bar')
         eq_(e1.adapt(ENUM).name, 'foo')
         eq_(e1.adapt(ENUM).schema, 'bar')
+        e1 = Enum(SomeEnum)
+        eq_(e1.adapt(ENUM).name, 'someenum')
+        eq_(e1.adapt(ENUM).enums, ['one', 'two', 'three'])
 
     @testing.provide_metadata
     def test_create_metadata_bound_no_crash(self):
@@ -1171,13 +1208,6 @@ class EnumTest(AssertsCompiledSQL, fixtures.TestBase):
 
         m1.create_all(testing.db)
 
-    @testing.crashes(
-        'mysql', 'Inconsistent behavior across various OS/drivers')
-    def test_constraint(self):
-        assert_raises(
-            exc.DBAPIError, enum_table.insert().execute,
-            {'id': 4, 'someenum': 'four'})
-
     def test_non_native_constraint_custom_type(self):
         class Foob(object):
 
@@ -1209,12 +1239,9 @@ class EnumTest(AssertsCompiledSQL, fixtures.TestBase):
             dialect="default"
         )
 
-    @testing.fails_on(
-        'mysql',
-        "the CHECK constraint doesn't raise an exception for unknown reason")
-    def test_non_native_constraint(self):
+    def test_lookup_failure(self):
         assert_raises(
-            exc.DBAPIError, non_native_enum_table.insert().execute,
+            exc.StatementError, self.tables['non_native_enum_table'].insert().execute,
             {'id': 4, 'someenum': 'four'}
         )