From eb9763febe58655ca0f61fa758925c56b94ece9b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 25 Oct 2009 21:27:08 +0000 Subject: [PATCH] - generalized Enum to issue a CHECK constraint + VARCHAR on default platform - added native_enum=False flag to do the same on MySQL, PG, if desired --- CHANGES | 13 ++-- lib/sqlalchemy/dialects/mysql/base.py | 11 ++- lib/sqlalchemy/dialects/oracle/cx_oracle.py | 11 ++- lib/sqlalchemy/dialects/postgresql/base.py | 10 ++- lib/sqlalchemy/dialects/sqlite/base.py | 2 +- lib/sqlalchemy/schema.py | 14 ++-- lib/sqlalchemy/sql/compiler.py | 29 ++++++-- lib/sqlalchemy/sql/expression.py | 6 +- lib/sqlalchemy/types.py | 46 ++++++++++-- test/dialect/test_mysql.py | 24 +++++- test/dialect/test_postgresql.py | 19 +++++ test/sql/test_types.py | 81 +++++++++++++++++++++ 12 files changed, 226 insertions(+), 40 deletions(-) diff --git a/CHANGES b/CHANGES index 7f217ef1db..3c9c52a3a2 100644 --- a/CHANGES +++ b/CHANGES @@ -583,13 +583,12 @@ CHANGES type. This means reflection now returns more accurate information about reflected types. - - Added a new Enum generic type, currently supported on - Postgresql and MySQL. Enum is a schema-aware object - to support databases which require specific DDL in - order to use enum or equivalent; in the case of PG - it handles the details of `CREATE TYPE`, and on - other databases without native enum support can - support generation of CHECK constraints. + - Added a new Enum generic type. Enum is a schema-aware object + to support databases which require specific DDL in order to + use enum or equivalent; in the case of PG it handles the + details of `CREATE TYPE`, and on other databases without + native enum support will by generate VARCHAR + an inline CHECK + constraint to enforce the enum. [ticket:1109] [ticket:1511] - PickleType now uses == for comparison of values when diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index d7ea358b54..e54b7687da 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1351,6 +1351,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler): return ' '.join(colspec) + def visit_enum_constraint(self, constraint): + if not constraint.type.native_enum: + return super(MySQLDDLCompiler, self).visit_enum_constraint(constraint) + else: + return None + def post_create_table(self, table): """Build table-level CREATE options like ENGINE and COLLATE.""" @@ -1576,7 +1582,10 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return self.visit_BLOB(type_) def visit_enum(self, type_): - return self.visit_ENUM(type_) + if not type_.native_enum: + return super(MySQLTypeCompiler, self).visit_enum(type_) + else: + return self.visit_ENUM(type_) def visit_BINARY(self, type_): if type_.length: diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 6108d3d660..e4d3b312b3 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -251,13 +251,18 @@ class Oracle_cx_oracleExecutionContext(OracleExecutionContext): for bind, name in self.compiled.bind_names.iteritems(): if name in self.out_parameters: type = bind.type - result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect) + result_processor = type.dialect_impl(self.dialect).\ + result_processor(self.dialect) if result_processor is not None: - out_parameters[name] = result_processor(self.out_parameters[name].getvalue()) + out_parameters[name] = \ + result_processor(self.out_parameters[name].getvalue()) else: out_parameters[name] = self.out_parameters[name].getvalue() else: - result.out_parameters = dict((k, v.getvalue()) for k, v in self.out_parameters.items()) + result.out_parameters = dict( + (k, v.getvalue()) + for k, v in self.out_parameters.items() + ) return result diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 1f4858cdd2..26c4a8a971 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -330,7 +330,10 @@ class PGDDLCompiler(compiler.DDLCompiler): def visit_drop_sequence(self, drop): return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) - + def visit_enum_constraint(self, constraint): + if not constraint.type.native_enum: + return super(PGDDLCompiler, self).visit_enum_constraint(constraint) + def visit_create_enum_type(self, create): type_ = create.element @@ -400,7 +403,10 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): return self.visit_TIMESTAMP(type_) def visit_enum(self, type_): - return self.visit_ENUM(type_) + if not type_.native_enum: + return super(PGTypeCompiler, self).visit_enum(type_) + else: + return self.visit_ENUM(type_) def visit_ENUM(self, type_): return self.dialect.identifier_preparer.format_type(type_) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index c25e75f2c9..86b2eacd35 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -236,7 +236,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): if not column.nullable: colspec += " NOT NULL" return colspec - + class SQLiteTypeCompiler(compiler.GenericTypeCompiler): def visit_binary(self, type_): return self.visit_BLOB(type_) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index b99f79a8ed..44f53f2356 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -28,7 +28,7 @@ expressions. """ import re, inspect -from sqlalchemy import types, exc, util, dialects +from sqlalchemy import exc, util, dialects from sqlalchemy.sql import expression, visitors URL = None @@ -765,12 +765,12 @@ class Column(SchemaItem, expression.ColumnClause): table.append_constraint(UniqueConstraint(self.key)) for fn in self._table_events: - fn(table) + fn(table, self) del self._table_events def _on_table_attach(self, fn): if self.table is not None: - fn(self.table) + fn(self.table, self) else: self._table_events.add(fn) @@ -819,7 +819,7 @@ class Column(SchemaItem, expression.ColumnClause): if self.primary_key: selectable.primary_key.add(c) for fn in c._table_events: - fn(selectable) + fn(selectable, c) del c._table_events return c @@ -1032,7 +1032,7 @@ class ForeignKey(SchemaItem): self.parent.foreign_keys.add(self) self.parent._on_table_attach(self._set_table) - def _set_table(self, table): + def _set_table(self, table, column): if self.constraint is None and isinstance(table, Table): self.constraint = ForeignKeyConstraint( [], [], use_alter=self.use_alter, name=self.name, @@ -1181,11 +1181,9 @@ class Sequence(DefaultGenerator): def _set_parent(self, column): super(Sequence, self)._set_parent(column) -# column.sequence = self - column._on_table_attach(self._set_table) - def _set_table(self, table): + def _set_table(self, table, column): self.metadata = table.metadata @property diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c1b421843a..088ca19695 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -964,7 +964,10 @@ class DDLCompiler(engine.Compiled): for column in table.columns: text += separator separator = ", \n" - text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk) + text += "\t" + self.get_column_specification( + column, + first_pk=column.primary_key and not first_pk + ) if column.primary_key: first_pk = True const = " ".join(self.process(constraint) for constraint in column.constraints) @@ -976,15 +979,18 @@ class DDLCompiler(engine.Compiled): if table.primary_key: text += ", \n\t" + self.process(table.primary_key) - const = ", \n\t".join( - self.process(constraint) for constraint in table.constraints + const = ", \n\t".join(p for p in + (self.process(constraint) for constraint in table.constraints if constraint is not table.primary_key and constraint.inline_ddl - and (not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False)) + and ( + not self.dialect.supports_alter or + not getattr(constraint, 'use_alter', False) + )) if p is not None ) if const: text += ", \n\t" + const - + text += "\n)%s\n\n" % self.post_create_table(table) return text @@ -1121,6 +1127,17 @@ class DDLCompiler(engine.Compiled): text += self.define_constraint_deferrability(constraint) return text + def visit_enum_constraint(self, constraint): + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % \ + self.preparer.format_constraint(constraint) + text += " CHECK (%s IN (%s))" % ( + self.preparer.format_column(constraint.column), + ",".join("'%s'" % x for x in constraint.type.enums) + ) + return text + def define_constraint_cascades(self, constraint): text = "" if constraint.ondelete is not None: @@ -1247,7 +1264,7 @@ class GenericTypeCompiler(engine.TypeCompiler): return self.visit_TEXT(type_) def visit_enum(self, type_): - raise NotImplementedError("Enum not supported generically") + return self.visit_VARCHAR(type_) def visit_null(self, type_): raise NotImplementedError("Can't generate DDL for the null type") diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index b71c1892b6..9324ed6a08 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -29,12 +29,12 @@ to stay the same in future releases. import itertools, re from operator import attrgetter -from sqlalchemy import util, exc, types as sqltypes +from sqlalchemy import util, exc #, types as sqltypes from sqlalchemy.sql import operators from sqlalchemy.sql.visitors import Visitable, cloned_traverse import operator -functions, schema, sql_util = None, None, None +functions, schema, sql_util, sqltypes = None, None, None, None DefaultDialect, ClauseAdapter, Annotated = None, None, None __all__ = [ @@ -3071,7 +3071,7 @@ class TableClause(_Immutable, FromClause): __visit_name__ = 'table' named_with_column = True - + def __init__(self, name, *columns): super(TableClause, self).__init__() self.name = self.fullname = name diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index ba1a3f9076..27918e15c7 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -24,7 +24,10 @@ import inspect import datetime as dt from decimal import Decimal as _python_Decimal -from sqlalchemy import exc +from sqlalchemy import exc, schema +from sqlalchemy.sql import expression +import sys +schema.types = expression.sqltypes =sys.modules['sqlalchemy.types'] from sqlalchemy.util import pickle from sqlalchemy.sql.visitors import Visitable import sqlalchemy.util as util @@ -809,8 +812,8 @@ class SchemaType(object): def _set_parent(self, column): column._on_table_attach(self._set_table) - - def _set_table(self, table): + + def _set_table(self, table, column): table.append_ddl_listener('before-create', self._on_table_create) table.append_ddl_listener('after-drop', self._on_table_drop) if self.metadata is None: @@ -863,9 +866,11 @@ class SchemaType(object): class Enum(String, SchemaType): """Generic Enum Type. - Currently supported on MySQL and Postgresql, the Enum type - provides a set of possible string values which the column is constrained - towards. + The Enum type provides a set of possible string values which the + column is constrained towards. + + By default, uses the backend's native ENUM type if available, + else uses VARCHAR + a CHECK constraint. Keyword arguments which don't apply to a specific backend are ignored by that backend. @@ -895,6 +900,10 @@ class Enum(String, SchemaType): or an explicitly named constraint in order to generate the type and/or a table that uses it. + :param native_enum: Use the database's native ENUM type when available. + Defaults to True. When False, uses VARCHAR + check constraint + for all backends. + :param schema: Schemaname of this type. For types that exist on the target database as an independent schema construct (Postgresql), this parameter specifies the named schema in which the type is present. @@ -909,6 +918,7 @@ class Enum(String, SchemaType): def __init__(self, *enums, **kw): self.enums = enums + self.native_enum = kw.pop('native_enum', True) convert_unicode= kw.pop('convert_unicode', None) assert_unicode = kw.pop('assert_unicode', None) if convert_unicode is None: @@ -919,11 +929,27 @@ class Enum(String, SchemaType): else: convert_unicode = False + if self.enums: + length =max(len(x) for x in self.enums) + else: + length = 0 String.__init__(self, + length =length, convert_unicode=convert_unicode, assert_unicode=assert_unicode ) SchemaType.__init__(self, **kw) + + def _set_table(self, table, column): + if self.native_enum: + SchemaType._set_table(self, table, column) + + # this constraint DDL object is conditionally + # compiled by MySQL, Postgresql based on + # the native_enum flag. + table.append_constraint( + EnumConstraint(self, column) + ) def adapt(self, impltype): return impltype(name=self.name, @@ -935,6 +961,14 @@ class Enum(String, SchemaType): *self.enums ) +class EnumConstraint(schema.CheckConstraint): + __visit_name__ = 'enum_constraint' + + def __init__(self, type_, column, **kw): + super(EnumConstraint, self).__init__('', name=type_.name, **kw) + self.type = type_ + self.column = column + class PickleType(MutableType, TypeDecorator): """Holds Python objects. diff --git a/test/dialect/test_mysql.py b/test/dialect/test_mysql.py index 64f65d8f6f..49dde1520f 100644 --- a/test/dialect/test_mysql.py +++ b/test/dialect/test_mysql.py @@ -7,18 +7,19 @@ import sets # end Py2K from sqlalchemy import * -from sqlalchemy import sql, exc +from sqlalchemy import sql, exc, schema from sqlalchemy.dialects.mysql import base as mysql from sqlalchemy.test.testing import eq_ from sqlalchemy.test import * from sqlalchemy.test.engines import utf8_engine -class TypesTest(TestBase, AssertsExecutionResults): +class TypesTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): "Test MySQL column types" __only_on__ = 'mysql' - + __dialect__ = mysql.dialect() + @testing.uses_deprecated('Manually quoting ENUM value literals') def test_basic(self): meta1 = MetaData(testing.db) @@ -643,6 +644,23 @@ class TypesTest(TestBase, AssertsExecutionResults): finally: metadata.drop_all() + def test_enum_compile(self): + e1 = Enum('x', 'y', 'z', name="somename") + t1 = Table('sometable', MetaData(), Column('somecolumn', e1)) + self.assert_compile( + schema.CreateTable(t1), + "CREATE TABLE sometable (somecolumn ENUM('x','y','z'))" + ) + t1 = Table('sometable', MetaData(), + Column('somecolumn', Enum('x', 'y', 'z', native_enum=False)) + ) + self.assert_compile( + schema.CreateTable(t1), + "CREATE TABLE sometable (" + "somecolumn VARCHAR(1), " + " CHECK (somecolumn IN ('x','y','z'))" + ")" + ) @testing.exclude('mysql', '<', (4,), "3.23 can't handle an ENUM of ''") @testing.uses_deprecated('Manually quoting ENUM value literals') diff --git a/test/dialect/test_postgresql.py b/test/dialect/test_postgresql.py index 4e9a324d44..626d546770 100644 --- a/test/dialect/test_postgresql.py +++ b/test/dialect/test_postgresql.py @@ -132,6 +132,25 @@ class EnumTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL): postgresql.DropEnumType(e2), "DROP TYPE someschema.somename" ) + + t1 = Table('sometable', MetaData(), Column('somecolumn', e1)) + self.assert_compile( + schema.CreateTable(t1), + "CREATE TABLE sometable (" + "somecolumn somename" + ")" + ) + t1 = Table('sometable', MetaData(), + Column('somecolumn', Enum('x', 'y', 'z', native_enum=False)) + ) + self.assert_compile( + schema.CreateTable(t1), + "CREATE TABLE sometable (" + "somecolumn VARCHAR(1), " + " CHECK (somecolumn IN ('x','y','z'))" + ")" + ) + @testing.fails_on('postgresql+zxjdbc', 'zxjdbc fails on ENUM: column "XXX" is of type XXX ' diff --git a/test/sql/test_types.py b/test/sql/test_types.py index c844cf696e..51dd4c12b2 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -329,6 +329,87 @@ class UnicodeTest(TestBase, AssertsExecutionResults): assert uni(unicodedata) == unicodedata.encode('utf-8') +class EnumTest(TestBase): + @classmethod + def setup_class(cls): + global enum_table, non_native_enum_table, metadata + metadata = MetaData(testing.db) + enum_table = Table('enum_table', metadata, + Column("id", Integer, primary_key=True), + Column('someenum', Enum('one','two','three', name='myenum')) + ) + + non_native_enum_table = Table('non_native_enum_table', metadata, + Column("id", Integer, primary_key=True), + Column('someenum', Enum('one','two','three', native_enum=False)), + ) + + metadata.create_all() + + def teardown(self): + enum_table.delete().execute() + non_native_enum_table.delete().execute() + + @classmethod + def teardown_class(cls): + metadata.drop_all() + + @testing.fails_on('postgresql+zxjdbc', + 'zxjdbc fails on ENUM: column "XXX" is of type XXX ' + 'but expression is of type character varying') + @testing.fails_on('postgresql+pg8000', + 'zxjdbc fails on ENUM: column "XXX" is of type XXX ' + 'but expression is of type text') + def test_round_trip(self): + enum_table.insert().execute([ + {'id':1, 'someenum':'two'}, + {'id':2, 'someenum':'two'}, + {'id':3, 'someenum':'one'}, + ]) + + eq_( + enum_table.select().order_by(enum_table.c.id).execute().fetchall(), + [ + (1, 'two'), + (2, 'two'), + (3, 'one'), + ] + ) + + def test_non_native_round_trip(self): + non_native_enum_table.insert().execute([ + {'id':1, 'someenum':'two'}, + {'id':2, 'someenum':'two'}, + {'id':3, 'someenum':'one'}, + ]) + + eq_( + non_native_enum_table.select(). + order_by(non_native_enum_table.c.id).execute().fetchall(), + [ + (1, 'two'), + (2, 'two'), + (3, 'one'), + ] + ) + + @testing.fails_on('postgresql+zxjdbc', + 'zxjdbc fails on ENUM: column "XXX" is of type XXX ' + 'but expression is of type character varying') + @testing.fails_on('mysql', "MySQL seems to issue a 'data truncated' warning.") + def test_constraint(self): + assert_raises(exc.DBAPIError, + enum_table.insert().execute, + {'id':4, 'someenum':'four'} + ) + + @testing.fails_on('mysql', "the CHECK constraint doesn't raise an exception for unknown reason") + def test_non_native_constraint(self): + assert_raises(exc.DBAPIError, + non_native_enum_table.insert().execute, + {'id':4, 'someenum':'four'} + ) + class BinaryTest(TestBase, AssertsExecutionResults): __excluded_on__ = ( ('mysql', '<', (4, 1, 1)), # screwy varbinary types -- 2.47.2