]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- generalized Enum to issue a CHECK constraint + VARCHAR on default platform
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Oct 2009 21:27:08 +0000 (21:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Oct 2009 21:27:08 +0000 (21:27 +0000)
- added native_enum=False flag to do the same on MySQL, PG, if desired

12 files changed:
CHANGES
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/types.py
test/dialect/test_mysql.py
test/dialect/test_postgresql.py
test/sql/test_types.py

diff --git a/CHANGES b/CHANGES
index 7f217ef1db10c4f49b74dfdee8bb69de6c3a8b35..3c9c52a3a29b6ffadb63dac7ae842193c89ff255 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -583,13 +583,12 @@ CHANGES
       type.  This means reflection now returns more accurate
       information about reflected types.  
     
-    - Added a new Enum generic type, currently supported on
-      Postgresql and MySQL.  Enum is a schema-aware object
-      to support databases which require specific DDL in 
-      order to use enum or equivalent; in the case of PG
-      it handles the details of `CREATE TYPE`, and on 
-      other databases without native enum support can 
-      support generation of CHECK constraints.
+    - Added a new Enum generic type. Enum is a schema-aware object
+      to support databases which require specific DDL in order to
+      use enum or equivalent; in the case of PG it handles the
+      details of `CREATE TYPE`, and on other databases without
+      native enum support will by generate VARCHAR + an inline CHECK
+      constraint to enforce the enum.
       [ticket:1109] [ticket:1511]
       
     - PickleType now uses == for comparison of values when
index d7ea358b549f05e51d6c91757796a2dd63350b19..e54b7687da4bd4bc34eb0d07f85877894f1000d8 100644 (file)
@@ -1351,6 +1351,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
 
         return ' '.join(colspec)
 
+    def visit_enum_constraint(self, constraint):
+        if not constraint.type.native_enum:
+            return super(MySQLDDLCompiler, self).visit_enum_constraint(constraint)
+        else:
+            return None
+
     def post_create_table(self, table):
         """Build table-level CREATE options like ENGINE and COLLATE."""
 
@@ -1576,7 +1582,10 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler):
         return self.visit_BLOB(type_)
     
     def visit_enum(self, type_):
-        return self.visit_ENUM(type_)
+        if not type_.native_enum:
+            return super(MySQLTypeCompiler, self).visit_enum(type_)
+        else:
+            return self.visit_ENUM(type_)
     
     def visit_BINARY(self, type_):
         if type_.length:
index 6108d3d660ff473e5e2578c723d7e9cd31e17a5b..e4d3b312b3971ea43bb4f90dd2cb3d0b056b4d74 100644 (file)
@@ -251,13 +251,18 @@ class Oracle_cx_oracleExecutionContext(OracleExecutionContext):
                 for bind, name in self.compiled.bind_names.iteritems():
                     if name in self.out_parameters:
                         type = bind.type
-                        result_processor = type.dialect_impl(self.dialect).result_processor(self.dialect)
+                        result_processor = type.dialect_impl(self.dialect).\
+                                                    result_processor(self.dialect)
                         if result_processor is not None:
-                            out_parameters[name] = result_processor(self.out_parameters[name].getvalue())
+                            out_parameters[name] = \
+                                    result_processor(self.out_parameters[name].getvalue())
                         else:
                             out_parameters[name] = self.out_parameters[name].getvalue()
             else:
-                result.out_parameters = dict((k, v.getvalue()) for k, v in self.out_parameters.items())
+                result.out_parameters = dict(
+                                            (k, v.getvalue()) 
+                                            for k, v in self.out_parameters.items()
+                                        )
 
         return result
 
index 1f4858cdd26b4d9250b73459ff32ec9e477c4149..26c4a8a971bba96dfd8a56fd33f72313a5065151 100644 (file)
@@ -330,7 +330,10 @@ class PGDDLCompiler(compiler.DDLCompiler):
     def visit_drop_sequence(self, drop):
         return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
         
-    
+    def visit_enum_constraint(self, constraint):
+        if not constraint.type.native_enum:
+            return super(PGDDLCompiler, self).visit_enum_constraint(constraint)
+            
     def visit_create_enum_type(self, create):
         type_ = create.element
         
@@ -400,7 +403,10 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
         return self.visit_TIMESTAMP(type_)
     
     def visit_enum(self, type_):
-        return self.visit_ENUM(type_)
+        if not type_.native_enum:
+            return super(PGTypeCompiler, self).visit_enum(type_)
+        else:
+            return self.visit_ENUM(type_)
         
     def visit_ENUM(self, type_):
         return self.dialect.identifier_preparer.format_type(type_)
index c25e75f2c9e9b5d5a6bf363ebb63b59f6dc1ab8b..86b2eacd35bef92293b76ed643b702b20ad20d06 100644 (file)
@@ -236,7 +236,7 @@ class SQLiteDDLCompiler(compiler.DDLCompiler):
         if not column.nullable:
             colspec += " NOT NULL"
         return colspec
-
+    
 class SQLiteTypeCompiler(compiler.GenericTypeCompiler):
     def visit_binary(self, type_):
         return self.visit_BLOB(type_)
index b99f79a8ed46ed886089a0fc5a600372332b97c0..44f53f2356c5146f3b3684400dff99cd1306a2ac 100644 (file)
@@ -28,7 +28,7 @@ expressions.
 
 """
 import re, inspect
-from sqlalchemy import types, exc, util, dialects
+from sqlalchemy import exc, util, dialects
 from sqlalchemy.sql import expression, visitors
 
 URL = None
@@ -765,12 +765,12 @@ class Column(SchemaItem, expression.ColumnClause):
             table.append_constraint(UniqueConstraint(self.key))
 
         for fn in self._table_events:
-            fn(table)
+            fn(table, self)
         del self._table_events
     
     def _on_table_attach(self, fn):
         if self.table is not None:
-            fn(self.table)
+            fn(self.table, self)
         else:
             self._table_events.add(fn)
             
@@ -819,7 +819,7 @@ class Column(SchemaItem, expression.ColumnClause):
         if self.primary_key:
             selectable.primary_key.add(c)
         for fn in c._table_events:
-            fn(selectable)
+            fn(selectable, c)
         del c._table_events
         return c
 
@@ -1032,7 +1032,7 @@ class ForeignKey(SchemaItem):
         self.parent.foreign_keys.add(self)
         self.parent._on_table_attach(self._set_table)
     
-    def _set_table(self, table):
+    def _set_table(self, table, column):
         if self.constraint is None and isinstance(table, Table):
             self.constraint = ForeignKeyConstraint(
                 [], [], use_alter=self.use_alter, name=self.name,
@@ -1181,11 +1181,9 @@ class Sequence(DefaultGenerator):
 
     def _set_parent(self, column):
         super(Sequence, self)._set_parent(column)
-#        column.sequence = self
-        
         column._on_table_attach(self._set_table)
     
-    def _set_table(self, table):
+    def _set_table(self, table, column):
         self.metadata = table.metadata
         
     @property
index c1b421843aa4817bbab4864f8c783ef7f343f30d..088ca19695b8bb49a353ad07b607dd43ff31d8f3 100644 (file)
@@ -964,7 +964,10 @@ class DDLCompiler(engine.Compiled):
         for column in table.columns:
             text += separator
             separator = ", \n"
-            text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)
+            text += "\t" + self.get_column_specification(
+                                            column, 
+                                            first_pk=column.primary_key and not first_pk
+                                        )
             if column.primary_key:
                 first_pk = True
             const = " ".join(self.process(constraint) for constraint in column.constraints)
@@ -976,15 +979,18 @@ class DDLCompiler(engine.Compiled):
         if table.primary_key:
             text += ", \n\t" + self.process(table.primary_key)
         
-        const = ", \n\t".join(
-                        self.process(constraint) for constraint in table.constraints 
+        const = ", \n\t".join(p for p in 
+                        (self.process(constraint) for constraint in table.constraints 
                         if constraint is not table.primary_key
                         and constraint.inline_ddl
-                        and (not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False))
+                        and (
+                            not self.dialect.supports_alter or 
+                            not getattr(constraint, 'use_alter', False)
+                        )) if p is not None
                 )
         if const:
             text += ", \n\t" + const
-
+        
         text += "\n)%s\n\n" % self.post_create_table(table)
         return text
         
@@ -1121,6 +1127,17 @@ class DDLCompiler(engine.Compiled):
         text += self.define_constraint_deferrability(constraint)
         return text
 
+    def visit_enum_constraint(self, constraint):
+        text = ""
+        if constraint.name is not None:
+            text += "CONSTRAINT %s " % \
+                        self.preparer.format_constraint(constraint)
+        text += " CHECK (%s IN (%s))" % (
+                    self.preparer.format_column(constraint.column),
+                    ",".join("'%s'" % x for x in constraint.type.enums)
+                )
+        return text
+
     def define_constraint_cascades(self, constraint):
         text = ""
         if constraint.ondelete is not None:
@@ -1247,7 +1264,7 @@ class GenericTypeCompiler(engine.TypeCompiler):
         return self.visit_TEXT(type_)
     
     def visit_enum(self, type_):
-        raise NotImplementedError("Enum not supported generically")
+        return self.visit_VARCHAR(type_)
         
     def visit_null(self, type_):
         raise NotImplementedError("Can't generate DDL for the null type")
index b71c1892b697b43adb907ed202ebee8ca789baae..9324ed6a089b28c1e699e4b9b55c4a82b0dd284a 100644 (file)
@@ -29,12 +29,12 @@ to stay the same in future releases.
 import itertools, re
 from operator import attrgetter
 
-from sqlalchemy import util, exc, types as sqltypes
+from sqlalchemy import util, exc #, types as sqltypes
 from sqlalchemy.sql import operators
 from sqlalchemy.sql.visitors import Visitable, cloned_traverse
 import operator
 
-functions, schema, sql_util = None, None, None
+functions, schema, sql_util, sqltypes = None, None, None, None
 DefaultDialect, ClauseAdapter, Annotated = None, None, None
 
 __all__ = [
@@ -3071,7 +3071,7 @@ class TableClause(_Immutable, FromClause):
     __visit_name__ = 'table'
 
     named_with_column = True
-
+    
     def __init__(self, name, *columns):
         super(TableClause, self).__init__()
         self.name = self.fullname = name
index ba1a3f9076dd825a2daf86c8d783202783b517c9..27918e15c7049c7789d40383bbded55f23a2cdc0 100644 (file)
@@ -24,7 +24,10 @@ import inspect
 import datetime as dt
 from decimal import Decimal as _python_Decimal
 
-from sqlalchemy import exc
+from sqlalchemy import exc, schema
+from sqlalchemy.sql import expression
+import sys
+schema.types = expression.sqltypes =sys.modules['sqlalchemy.types']
 from sqlalchemy.util import pickle
 from sqlalchemy.sql.visitors import Visitable
 import sqlalchemy.util as util
@@ -809,8 +812,8 @@ class SchemaType(object):
             
     def _set_parent(self, column):
         column._on_table_attach(self._set_table)
-
-    def _set_table(self, table):
+        
+    def _set_table(self, table, column):
         table.append_ddl_listener('before-create', self._on_table_create)
         table.append_ddl_listener('after-drop', self._on_table_drop)
         if self.metadata is None:
@@ -863,9 +866,11 @@ class SchemaType(object):
 class Enum(String, SchemaType):
     """Generic Enum Type.
     
-    Currently supported on MySQL and Postgresql, the Enum type
-    provides a set of possible string values which the column is constrained
-    towards.
+    The Enum type provides a set of possible string values which the 
+    column is constrained towards.
+    
+    By default, uses the backend's native ENUM type if available, 
+    else uses VARCHAR + a CHECK constraint.
     
     Keyword arguments which don't apply to a specific backend are ignored
     by that backend.
@@ -895,6 +900,10 @@ class Enum(String, SchemaType):
         or an explicitly named constraint in order to generate the type and/or
         a table that uses it.
     
+    :param native_enum: Use the database's native ENUM type when available.
+        Defaults to True.  When False, uses VARCHAR + check constraint
+        for all backends.
+    
     :param schema: Schemaname of this type. For types that exist on the target
         database as an independent schema construct (Postgresql), this
         parameter specifies the named schema in which the type is present.
@@ -909,6 +918,7 @@ class Enum(String, SchemaType):
     
     def __init__(self, *enums, **kw):
         self.enums = enums
+        self.native_enum = kw.pop('native_enum', True)
         convert_unicode= kw.pop('convert_unicode', None)
         assert_unicode = kw.pop('assert_unicode', None)
         if convert_unicode is None:
@@ -919,11 +929,27 @@ class Enum(String, SchemaType):
             else:
                 convert_unicode = False
         
+        if self.enums:
+            length =max(len(x) for x in self.enums)
+        else:
+            length = 0
         String.__init__(self, 
+                        length =length,
                         convert_unicode=convert_unicode, 
                         assert_unicode=assert_unicode
                         )
         SchemaType.__init__(self, **kw)
+    
+    def _set_table(self, table, column):
+        if self.native_enum:
+            SchemaType._set_table(self, table, column)
+            
+        # this constraint DDL object is conditionally
+        # compiled by MySQL, Postgresql based on
+        # the native_enum flag.
+        table.append_constraint(
+            EnumConstraint(self, column)
+        )
         
     def adapt(self, impltype):
         return impltype(name=self.name, 
@@ -935,6 +961,14 @@ class Enum(String, SchemaType):
                         *self.enums
                         )
 
+class EnumConstraint(schema.CheckConstraint):
+    __visit_name__ = 'enum_constraint'
+    
+    def __init__(self, type_, column, **kw):
+        super(EnumConstraint, self).__init__('', name=type_.name, **kw)
+        self.type = type_
+        self.column = column
+    
 class PickleType(MutableType, TypeDecorator):
     """Holds Python objects.
 
index 64f65d8f6f06006befae796e8a9bd9f4e3394548..49dde1520f03fab8120b5fa9eea1e84392bf9b66 100644 (file)
@@ -7,18 +7,19 @@ import sets
 # end Py2K
 
 from sqlalchemy import *
-from sqlalchemy import sql, exc
+from sqlalchemy import sql, exc, schema
 from sqlalchemy.dialects.mysql import base as mysql
 from sqlalchemy.test.testing import eq_
 from sqlalchemy.test import *
 from sqlalchemy.test.engines import utf8_engine
 
 
-class TypesTest(TestBase, AssertsExecutionResults):
+class TypesTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
     "Test MySQL column types"
 
     __only_on__ = 'mysql'
-
+    __dialect__ = mysql.dialect()
+    
     @testing.uses_deprecated('Manually quoting ENUM value literals')
     def test_basic(self):
         meta1 = MetaData(testing.db)
@@ -643,6 +644,23 @@ class TypesTest(TestBase, AssertsExecutionResults):
         finally:
             metadata.drop_all()
         
+    def test_enum_compile(self):
+        e1 = Enum('x', 'y', 'z', name="somename")
+        t1 = Table('sometable', MetaData(), Column('somecolumn', e1))
+        self.assert_compile(
+            schema.CreateTable(t1),
+            "CREATE TABLE sometable (somecolumn ENUM('x','y','z'))"
+        )
+        t1 = Table('sometable', MetaData(), 
+                    Column('somecolumn', Enum('x', 'y', 'z', native_enum=False))
+                )
+        self.assert_compile(
+            schema.CreateTable(t1),
+            "CREATE TABLE sometable ("
+            "somecolumn VARCHAR(1), "
+            " CHECK (somecolumn IN ('x','y','z'))"
+            ")"
+        )
         
     @testing.exclude('mysql', '<', (4,), "3.23 can't handle an ENUM of ''")
     @testing.uses_deprecated('Manually quoting ENUM value literals')
index 4e9a324d449f879d8d894091186a00c92022ca6c..626d5467707f0d93d01bb86ec738e026be1b25f4 100644 (file)
@@ -132,6 +132,25 @@ class EnumTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             postgresql.DropEnumType(e2), 
             "DROP TYPE someschema.somename"
         )
+        
+        t1 = Table('sometable', MetaData(), Column('somecolumn', e1))
+        self.assert_compile(
+            schema.CreateTable(t1),
+            "CREATE TABLE sometable ("
+            "somecolumn somename"
+            ")"
+        )
+        t1 = Table('sometable', MetaData(), 
+                    Column('somecolumn', Enum('x', 'y', 'z', native_enum=False))
+                )
+        self.assert_compile(
+            schema.CreateTable(t1),
+            "CREATE TABLE sometable ("
+            "somecolumn VARCHAR(1), "
+            " CHECK (somecolumn IN ('x','y','z'))"
+            ")"
+        )
+
     
     @testing.fails_on('postgresql+zxjdbc', 
                         'zxjdbc fails on ENUM: column "XXX" is of type XXX '
index c844cf696e4f8cec9625c32ee1d5ad20ee227b53..51dd4c12b2fb217129d375fcdb0ee5c04827bf05 100644 (file)
@@ -329,6 +329,87 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
         
         assert uni(unicodedata) == unicodedata.encode('utf-8')
 
+class EnumTest(TestBase):
+    @classmethod
+    def setup_class(cls):
+        global enum_table, non_native_enum_table, metadata
+        metadata = MetaData(testing.db)
+        enum_table = Table('enum_table', metadata,
+            Column("id", Integer, primary_key=True),
+            Column('someenum', Enum('one','two','three', name='myenum'))
+        )
+
+        non_native_enum_table = Table('non_native_enum_table', metadata,
+            Column("id", Integer, primary_key=True),
+            Column('someenum', Enum('one','two','three', native_enum=False)),
+        )
+
+        metadata.create_all()
+    
+    def teardown(self):
+        enum_table.delete().execute()
+        non_native_enum_table.delete().execute()
+        
+    @classmethod
+    def teardown_class(cls):
+        metadata.drop_all()
+
+    @testing.fails_on('postgresql+zxjdbc', 
+                        'zxjdbc fails on ENUM: column "XXX" is of type XXX '
+                        'but expression is of type character varying')
+    @testing.fails_on('postgresql+pg8000', 
+                        'zxjdbc fails on ENUM: column "XXX" is of type XXX '
+                        'but expression is of type text')
+    def test_round_trip(self):
+        enum_table.insert().execute([
+            {'id':1, 'someenum':'two'},
+            {'id':2, 'someenum':'two'},
+            {'id':3, 'someenum':'one'},
+        ])
+        
+        eq_(
+            enum_table.select().order_by(enum_table.c.id).execute().fetchall(), 
+            [
+                (1, 'two'),
+                (2, 'two'),
+                (3, 'one'),
+            ]
+        )
+
+    def test_non_native_round_trip(self):
+        non_native_enum_table.insert().execute([
+            {'id':1, 'someenum':'two'},
+            {'id':2, 'someenum':'two'},
+            {'id':3, 'someenum':'one'},
+        ])
+
+        eq_(
+            non_native_enum_table.select().
+                    order_by(non_native_enum_table.c.id).execute().fetchall(), 
+            [
+                (1, 'two'),
+                (2, 'two'),
+                (3, 'one'),
+            ]
+        )
+
+    @testing.fails_on('postgresql+zxjdbc', 
+                        'zxjdbc fails on ENUM: column "XXX" is of type XXX '
+                        'but expression is of type character varying')
+    @testing.fails_on('mysql', "MySQL seems to issue a 'data truncated' warning.")
+    def test_constraint(self):
+        assert_raises(exc.DBAPIError, 
+            enum_table.insert().execute,
+            {'id':4, 'someenum':'four'}
+        )
+
+    @testing.fails_on('mysql', "the CHECK constraint doesn't raise an exception for unknown reason")
+    def test_non_native_constraint(self):
+        assert_raises(exc.DBAPIError, 
+            non_native_enum_table.insert().execute,
+            {'id':4, 'someenum':'four'}
+        )
+        
 class BinaryTest(TestBase, AssertsExecutionResults):
     __excluded_on__ = (
         ('mysql', '<', (4, 1, 1)),  # screwy varbinary types