From: Alex Grönholm Date: Tue, 2 Feb 2016 19:20:17 +0000 (-0500) Subject: - Initial implementation of support for PEP-435 enumerated types X-Git-Tag: rel_1_1_0b1~98^2~56 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5401c4d8514aa42e8ac4b5579454e68151e78a93;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Initial implementation of support for PEP-435 enumerated types within the Enum type. --- diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index d198d3912c..21e57d519b 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1103,7 +1103,8 @@ class Enum(String, SchemaType): Keyword arguments which don't apply to a specific backend are ignored by that backend. - :param \*enums: string or unicode enumeration labels. If unicode + :param \*enums: either exactly one PEP 435 compliant enumerated type + or one or more string or unicode enumeration labels. If unicode labels are present, the `convert_unicode` flag is auto-enabled. :param convert_unicode: Enable unicode-aware bind parameter and @@ -1124,7 +1125,9 @@ class Enum(String, SchemaType): :param name: The name of this type. This is required for Postgresql and any future supported database which requires an explicitly named type, or an explicitly named constraint in order to generate - the type and/or a table that uses it. + the type and/or a table that uses it. If an :class:`~enum.Enum` + class was used, its name (converted to lower case) is used by + default. :param native_enum: Use the database's native ENUM type when available. Defaults to True. When False, uses VARCHAR + check @@ -1153,11 +1156,20 @@ class Enum(String, SchemaType): .. versionadded:: 0.8 """ - self.enums = enums + if len(enums) == 1 and hasattr(enums[0], '__members__'): + self.enums = list(enums[0].__members__) + self.enum_class = enums[0] + kw.setdefault('name', enums[0].__name__.lower()) + self.key_lookup = dict((value, key) for key, value in enums[0].__members__.items()) + self.value_lookup = enums[0].__members__.copy() + else: + self.enums = enums + self.enum_class = self.key_lookup = self.value_lookup = None + self.native_enum = kw.pop('native_enum', True) convert_unicode = kw.pop('convert_unicode', None) if convert_unicode is None: - for e in enums: + for e in self.enums: if isinstance(e, util.text_type): convert_unicode = True break @@ -1203,6 +1215,7 @@ class Enum(String, SchemaType): metadata = kw.pop('metadata', self.metadata) _create_events = kw.pop('_create_events', False) if issubclass(impltype, Enum): + args = [self.enum_class] if self.enum_class is not None else self.enums return impltype(name=self.name, schema=schema, metadata=metadata, @@ -1210,12 +1223,66 @@ class Enum(String, SchemaType): native_enum=self.native_enum, inherit_schema=self.inherit_schema, _create_events=_create_events, - *self.enums, + *args, **kw) else: # TODO: why would we be here? return super(Enum, self).adapt(impltype, **kw) + def literal_processor(self, dialect): + parent_processor = super(Enum, self).literal_processor(dialect) + if self.key_lookup: + def process(value): + value = self.key_lookup.get(value, value) + if parent_processor: + return parent_processor(value) + + return process + else: + return parent_processor + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, util.string_types): + if value not in self.enums: + raise LookupError( + '"%s" is not among the defined enum values' % + value) + elif self.key_lookup and value in self.key_lookup: + value = self.key_lookup[value] + + if parent_processor: + value = parent_processor(value) + return value + + parent_processor = super(Enum, self).bind_processor(dialect) + return process + + def result_processor(self, dialect, coltype): + parent_processor = super(Enum, self).result_processor(dialect, + coltype) + if self.value_lookup: + def process(value): + if parent_processor: + value = parent_processor(value) + + try: + return self.value_lookup[value] + except KeyError: + raise LookupError('No such member in enum class %s: %s' % + (self.enum_class.__name__, value)) + + return process + else: + return parent_processor + + @property + def python_type(self): + if self.enum_class: + return self.enum_class + else: + return super(Enum, self).python_type + class PickleType(TypeDecorator): diff --git a/test/sql/test_types.py b/test/sql/test_types.py index b08556926b..bd62c4fd3b 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -15,6 +15,7 @@ from sqlalchemy.sql import ddl from sqlalchemy.sql import visitors from sqlalchemy import inspection from sqlalchemy import exc, types, util, dialects +from sqlalchemy.util import OrderedDict for name in dialects.__all__: __import__("sqlalchemy.dialects.%s" % name) from sqlalchemy.sql import operators, column, table, null @@ -30,6 +31,21 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock +class SomeEnum(object): + # Implements PEP 435 in the minimal fashion needed by SQLAlchemy + __members__ = OrderedDict() + + def __init__(self, name, value): + self.name = name + self.value = value + self.__members__[name] = self + setattr(SomeEnum, name, self) + +SomeEnum('one', 1) +SomeEnum('two', 2) +SomeEnum('three', 3) + + class AdaptTest(fixtures.TestBase): def _all_dialect_modules(self): @@ -181,6 +197,8 @@ class AdaptTest(fixtures.TestBase): eq_(types.String().python_type, str) eq_(types.Unicode().python_type, util.text_type) eq_(types.String(convert_unicode=True).python_type, util.text_type) + eq_(types.Enum('one', 'two', 'three').python_type, str) + eq_(types.Enum(SomeEnum).python_type, SomeEnum) assert_raises( NotImplementedError, @@ -278,6 +296,7 @@ class PickleTypesTest(fixtures.TestBase): Column('Pic', PickleType()), Column('Int', Interval()), Column('Enu', Enum('x', 'y', 'z', name="somename")), + Column('En2', Enum(SomeEnum)), ] for column_type in column_types: meta = MetaData() @@ -1087,41 +1106,35 @@ class UnicodeTest(fixtures.TestBase): unicodedata.encode('ascii', 'ignore').decode() ) -enum_table = non_native_enum_table = metadata = None - -class EnumTest(AssertsCompiledSQL, fixtures.TestBase): +class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): @classmethod - def setup_class(cls): - global enum_table, non_native_enum_table, metadata - metadata = MetaData(testing.db) - enum_table = Table( + def define_tables(cls, metadata): + Table( 'enum_table', metadata, Column("id", Integer, primary_key=True), Column('someenum', Enum('one', 'two', 'three', name='myenum')) ) - non_native_enum_table = Table( + Table( 'non_native_enum_table', metadata, Column("id", Integer, primary_key=True), Column('someenum', Enum('one', 'two', 'three', native_enum=False)), ) - metadata.create_all() - - def teardown(self): - enum_table.delete().execute() - non_native_enum_table.delete().execute() - - @classmethod - def teardown_class(cls): - metadata.drop_all() + Table( + 'stdlib_enum_table', metadata, + Column("id", Integer, primary_key=True), + Column('someenum', Enum(SomeEnum)) + ) @testing.fails_on( 'postgresql+zxjdbc', 'zxjdbc fails on ENUM: column "XXX" is of type XXX ' 'but expression is of type character varying') def test_round_trip(self): + enum_table = self.tables['enum_table'] + enum_table.insert().execute([ {'id': 1, 'someenum': 'two'}, {'id': 2, 'someenum': 'two'}, @@ -1138,6 +1151,8 @@ class EnumTest(AssertsCompiledSQL, fixtures.TestBase): ) def test_non_native_round_trip(self): + non_native_enum_table = self.tables['non_native_enum_table'] + non_native_enum_table.insert().execute([ {'id': 1, 'someenum': 'two'}, {'id': 2, 'someenum': 'two'}, @@ -1154,6 +1169,25 @@ class EnumTest(AssertsCompiledSQL, fixtures.TestBase): ] ) + def test_stdlib_enum_round_trip(self): + stdlib_enum_table = self.tables['stdlib_enum_table'] + + stdlib_enum_table.insert().execute([ + {'id': 1, 'someenum': SomeEnum.two}, + {'id': 2, 'someenum': SomeEnum.two}, + {'id': 3, 'someenum': SomeEnum.one}, + ]) + + eq_( + stdlib_enum_table.select(). + order_by(stdlib_enum_table.c.id).execute().fetchall(), + [ + (1, SomeEnum.two), + (2, SomeEnum.two), + (3, SomeEnum.one), + ] + ) + def test_adapt(self): from sqlalchemy.dialects.postgresql import ENUM e1 = Enum('one', 'two', 'three', native_enum=False) @@ -1163,6 +1197,9 @@ class EnumTest(AssertsCompiledSQL, fixtures.TestBase): e1 = Enum('one', 'two', 'three', name='foo', schema='bar') eq_(e1.adapt(ENUM).name, 'foo') eq_(e1.adapt(ENUM).schema, 'bar') + e1 = Enum(SomeEnum) + eq_(e1.adapt(ENUM).name, 'someenum') + eq_(e1.adapt(ENUM).enums, ['one', 'two', 'three']) @testing.provide_metadata def test_create_metadata_bound_no_crash(self): @@ -1171,13 +1208,6 @@ class EnumTest(AssertsCompiledSQL, fixtures.TestBase): m1.create_all(testing.db) - @testing.crashes( - 'mysql', 'Inconsistent behavior across various OS/drivers') - def test_constraint(self): - assert_raises( - exc.DBAPIError, enum_table.insert().execute, - {'id': 4, 'someenum': 'four'}) - def test_non_native_constraint_custom_type(self): class Foob(object): @@ -1209,12 +1239,9 @@ class EnumTest(AssertsCompiledSQL, fixtures.TestBase): dialect="default" ) - @testing.fails_on( - 'mysql', - "the CHECK constraint doesn't raise an exception for unknown reason") - def test_non_native_constraint(self): + def test_lookup_failure(self): assert_raises( - exc.DBAPIError, non_native_enum_table.insert().execute, + exc.StatementError, self.tables['non_native_enum_table'].insert().execute, {'id': 4, 'someenum': 'four'} )