]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- reworked the DDL generation of ENUM and similar to be more platform agnostic.
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 6 Dec 2009 19:51:10 +0000 (19:51 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 6 Dec 2009 19:51:10 +0000 (19:51 +0000)
Uses a straight CheckConstraint with a generic expression.  Preparing for boolean
constraint in [ticket:1589]
- CheckConstraint now accepts SQL expressions, though support for quoting of values
will be very limited.  we don't want to get into formatting dates and such.

13 files changed:
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/types.py
test/dialect/test_mysql.py
test/dialect/test_postgresql.py
test/dialect/test_sqlite.py
test/engine/test_metadata.py
test/sql/test_constraints.py

index d5ee4f5bb4c329dbc95f768e6c0e7060abfbbe82..3a277389250902e9ebe52062a821be2feb83ce41 100644 (file)
@@ -1354,12 +1354,6 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
 
         return ' '.join(colspec)
 
-    def visit_enum_constraint(self, constraint):
-        if not constraint.type.native_enum:
-            return super(MySQLDDLCompiler, self).visit_enum_constraint(constraint)
-        else:
-            return None
-
     def post_create_table(self, table):
         """Build table-level CREATE options like ENGINE and COLLATE."""
 
@@ -1661,6 +1655,9 @@ class MySQLDialect(default.DefaultDialect):
     # identifiers are 64, however aliases can be 255...
     max_identifier_length = 255
     
+    supports_native_enum = True
+    supports_native_boolean = True
+    
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = False
     
index e31cc7d91b8707dfb98e69eb98f8a6bd96f70e01..d0a87d28232a3b7f05935bf8b8187743899e8657 100644 (file)
@@ -74,7 +74,7 @@ import re
 from sqlalchemy import schema as sa_schema
 from sqlalchemy import sql, schema, exc, util
 from sqlalchemy.engine import base, default, reflection
-from sqlalchemy.sql import compiler, expression
+from sqlalchemy.sql import compiler, expression, util as sql_util
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy import types as sqltypes
 
@@ -348,10 +348,6 @@ class PGDDLCompiler(compiler.DDLCompiler):
             colspec += " NOT NULL"
         return colspec
 
-    def visit_enum_constraint(self, constraint):
-        if not constraint.type.native_enum:
-            return super(PGDDLCompiler, self).visit_enum_constraint(constraint)
-            
     def visit_create_enum_type(self, create):
         type_ = create.element
         
@@ -387,11 +383,9 @@ class PGDDLCompiler(compiler.DDLCompiler):
             whereclause = None
             
         if whereclause is not None:
-            compiler = self._compile(whereclause, None)
-            # this might belong to the compiler class
-            inlined_clause = str(compiler) % dict(
-                [(key,bind.value) for key,bind in compiler.binds.iteritems()])
-            text += " WHERE " + inlined_clause
+            whereclause = sql_util.expression_as_ddl(whereclause)
+            where_compiled = self.sql_compiler.process(whereclause)
+            text += " WHERE " + where_compiled
         return text
 
 
@@ -530,6 +524,8 @@ class PGDialect(default.DefaultDialect):
     max_identifier_length = 63
     supports_sane_rowcount = True
     
+    supports_native_enum = True
+    
     supports_sequences = True
     sequences_optional = True
     preexecute_autoincrement_sequences = True
index 3ea52cd725176cee4b57fbf4d3e972c1aed16747..ddf2602c2d13c39a4724f3e637262ed62402b73f 100644 (file)
@@ -135,6 +135,27 @@ class Dialect(object):
     supports_default_values
       Indicates if the construct ``INSERT INTO tablename DEFAULT
       VALUES`` is supported
+    
+    supports_sequences
+      Indicates if the dialect supports CREATE SEQUENCE or similar.
+    
+    sequences_optional
+      If True, indicates if the "optional" flag on the Sequence() construct
+      should signal to not generate a CREATE SEQUENCE. Applies only to
+      dialects that support sequences. Currently used only to allow Postgresql
+      SERIAL to be used on a column that specifies Sequence() for usage on
+      other backends.
+        
+    supports_native_enum
+      Indicates if the dialect supports a native ENUM construct.
+      This will prevent types.Enum from generating a CHECK
+      constraint when that type is used.
+
+    supports_native_boolean
+      Indicates if the dialect supports a native boolean construct.
+      This will prevent types.Boolean from generating a CHECK
+      constraint when that type is used.
+      
     """
 
     def create_connect_args(self, url):
index ca5106c3411c7e5ae039a9a765cfff7118c9768f..41470f35998a9b5a32b579b1be815622b2de7662 100644 (file)
@@ -36,6 +36,9 @@ class DefaultDialect(base.Dialect):
     postfetch_lastrowid = True
     implicit_returning = False
     
+    supports_native_enum = False
+    supports_native_boolean = False
+    
     # Py3K
     #supports_unicode_statements = True
     #supports_unicode_binds = True
index 2e5a1a6371e94fa34d24108ab401eb796bde7b25..70087ee739206feb63c083a332efbeed1aefa13f 100644 (file)
@@ -1261,7 +1261,8 @@ class Constraint(SchemaItem):
 
     __visit_name__ = 'constraint'
 
-    def __init__(self, name=None, deferrable=None, initially=None, inline_ddl=True):
+    def __init__(self, name=None, deferrable=None, initially=None, 
+                            _create_rule=None):
         """Create a SQL constraint.
 
         name
@@ -1275,20 +1276,29 @@ class Constraint(SchemaItem):
           Optional string.  If set, emit INITIALLY <value> when issuing DDL
           for this constraint.
           
-        inline_ddl
-          if True, DDL for this Constraint will be generated within the span of a
-          CREATE TABLE or DROP TABLE statement, when the associated table's
-          DDL is generated.  if False, no DDL is issued within that process.
-          Instead, it is expected that an AddConstraint or DropConstraint 
-          construct will be used to issue DDL for this Contraint.
-          The AddConstraint/DropConstraint constructs set this flag automatically
-          as well.
+        _create_rule
+          a callable which is passed the DDLCompiler object during
+          compilation. Returns True or False to signal inline generation of
+          this Constraint.
+
+          The AddConstraint and DropConstraint DDL constructs provide
+          DDLElement's more comprehensive "conditional DDL" approach that is
+          passed a database connection when DDL is being issued. _create_rule
+          is instead called during any CREATE TABLE compilation, where there
+          may not be any transaction/connection in progress. However, it
+          allows conditional compilation of the constraint even for backends
+          which do not support addition of constraints through ALTER TABLE,
+          which currently includes SQLite.
+
+          _create_rule is used by some types to create constraints.
+          Currently, its call signature is subject to change at any time.
+          
         """
 
         self.name = name
         self.deferrable = deferrable
         self.initially = initially
-        self.inline_ddl = inline_ddl
+        self._create_rule = _create_rule
 
     @property
     def table(self):
@@ -1364,13 +1374,13 @@ class CheckConstraint(Constraint):
     Can be included in the definition of a Table or Column.
     """
 
-    def __init__(self, sqltext, name=None, deferrable=None, initially=None, table=None):
+    def __init__(self, sqltext, name=None, deferrable=None, initially=None, table=None, _create_rule=None):
         """Construct a CHECK constraint.
 
         sqltext
-          A string containing the constraint definition.  Will be used
-          verbatim.
-
+          A string containing the constraint definition, which will be used
+          verbatim, or a SQL expression construct.
+          
         name
           Optional, the in-database name of the constraint.
 
@@ -1384,11 +1394,8 @@ class CheckConstraint(Constraint):
           
         """
 
-        super(CheckConstraint, self).__init__(name, deferrable, initially)
-        if not isinstance(sqltext, basestring):
-            raise exc.ArgumentError(
-                "sqltext must be a string and will be used verbatim.")
-        self.sqltext = sqltext
+        super(CheckConstraint, self).__init__(name, deferrable, initially, _create_rule)
+        self.sqltext = expression._literal_as_text(sqltext)
         if table is not None:
             self._set_parent(table)
             
@@ -2224,7 +2231,6 @@ class _CreateDropBase(DDLElement):
         self._check_ddl_on(on)
         self.on = on
         self.bind = bind
-        element.inline_ddl = False
 
 class CreateTable(_CreateDropBase):
     """Represent a CREATE TABLE statement."""
@@ -2265,6 +2271,10 @@ class AddConstraint(_CreateDropBase):
     
     __visit_name__ = "add_constraint"
 
+    def __init__(self, element, *args, **kw):
+        super(AddConstraint, self).__init__(element, *args, **kw)
+        element._create_rule = lambda compiler: False
+        
 class DropConstraint(_CreateDropBase):
     """Represent an ALTER TABLE DROP CONSTRAINT statement."""
 
@@ -2273,6 +2283,7 @@ class DropConstraint(_CreateDropBase):
     def __init__(self, element, cascade=False, **kw):
         self.cascade = cascade
         super(DropConstraint, self).__init__(element, **kw)
+        element._create_rule = lambda compiler: False
 
 def _bind_or_error(schemaitem):
     bind = schemaitem.bind
index 6802bfbefc8d0d6b8e5718bc1117e9e601650749..a41a149d1dcc7c6cd027d57f33349e6c7713141f 100644 (file)
@@ -922,6 +922,11 @@ class SQLCompiler(engine.Compiled):
 
 
 class DDLCompiler(engine.Compiled):
+    
+    @util.memoized_property
+    def sql_compiler(self):
+        return self.dialect.statement_compiler(self.dialect, self.statement)
+        
     @property
     def preparer(self):
         return self.dialect.identifier_preparer
@@ -982,7 +987,9 @@ class DDLCompiler(engine.Compiled):
         const = ", \n\t".join(p for p in 
                         (self.process(constraint) for constraint in table.constraints 
                         if constraint is not table.primary_key
-                        and constraint.inline_ddl
+                        and (
+                            constraint._create_rule is None or
+                            constraint._create_rule(self))
                         and (
                             not self.dialect.supports_alter or 
                             not getattr(constraint, 'use_alter', False)
@@ -1058,13 +1065,6 @@ class DDLCompiler(engine.Compiled):
     def post_create_table(self, table):
         return ''
 
-    def _compile(self, tocompile, parameters):
-        """compile the given string/parameters using this SchemaGenerator's dialect."""
-        
-        compiler = self.dialect.statement_compiler(self.dialect, tocompile, parameters)
-        compiler.compile()
-        return compiler
-
     def _validate_identifier(self, ident, truncate):
         if truncate:
             if len(ident) > self.dialect.max_identifier_length:
@@ -1082,7 +1082,7 @@ class DDLCompiler(engine.Compiled):
             if isinstance(column.server_default.arg, basestring):
                 return "'%s'" % column.server_default.arg
             else:
-                return unicode(self._compile(column.server_default.arg, None))
+                return self.sql_compiler.process(column.server_default.arg)
         else:
             return None
 
@@ -1091,7 +1091,8 @@ class DDLCompiler(engine.Compiled):
         if constraint.name is not None:
             text += "CONSTRAINT %s " % \
                         self.preparer.format_constraint(constraint)
-        text += " CHECK (%s)" % constraint.sqltext
+        sqltext = sql_util.expression_as_ddl(constraint.sqltext)
+        text += "CHECK (%s)" % self.sql_compiler.process(sqltext)
         text += self.define_constraint_deferrability(constraint)
         return text
 
@@ -1138,17 +1139,6 @@ class DDLCompiler(engine.Compiled):
         text += self.define_constraint_deferrability(constraint)
         return text
 
-    def visit_enum_constraint(self, constraint):
-        text = ""
-        if constraint.name is not None:
-            text += "CONSTRAINT %s " % \
-                        self.preparer.format_constraint(constraint)
-        text += " CHECK (%s IN (%s))" % (
-                    self.preparer.format_column(constraint.column),
-                    ",".join("'%s'" % x for x in constraint.type.enums)
-                )
-        return text
-
     def define_constraint_cascades(self, constraint):
         text = ""
         if constraint.ondelete is not None:
index a84a3eb7477a2dbf43e6e5cf770ffba4077610e5..78160ad1e2668a3f65d67273192c921dd4ff8bde 100644 (file)
@@ -79,6 +79,26 @@ def find_columns(clause):
     visitors.traverse(clause, {}, {'column':cols.add})
     return cols
 
+def expression_as_ddl(clause):
+    """Given a SQL expression, convert for usage in DDL, such as 
+     CREATE INDEX and CHECK CONSTRAINT.
+     
+     Converts bind params into quoted literals, column identifiers
+     into detached column constructs so that the parent table
+     identifier is not included.
+    
+    """
+    def repl(element):
+        if isinstance(element, expression._BindParamClause):
+            return expression.literal_column(repr(element.value))
+        elif isinstance(element, expression.ColumnClause) and \
+                element.table is not None:
+            return expression.column(element.name)
+        else:
+            return None
+        
+    return visitors.replacement_traverse(clause, {}, repl)
+    
 def adapt_criterion_to_null(crit, nulls):
     """given criterion containing bind params, convert selected elements to IS NULL."""
 
index 25aa8b7e73ed349a192b5f37057bc5e8ba00e893..d7dda85e26112896dfe12d46692f2ecfd93b13b9 100644 (file)
@@ -836,7 +836,13 @@ class Binary(TypeEngine):
         return dbapi.BINARY
 
 class SchemaType(object):
-    """Mark a type as possibly requiring schema-level DDL for usage."""
+    """Mark a type as possibly requiring schema-level DDL for usage.
+    
+    Supports types that must be explicitly created/dropped (i.e. PG ENUM type)
+    as well as types that are complimented by table or schema level
+    constraints, triggers, and other rules.
+    
+    """
     
     def __init__(self, **kw):
         self.name = kw.pop('name', None)
@@ -867,6 +873,8 @@ class SchemaType(object):
         return self.metadata and self.metadata.bind or None
         
     def create(self, bind=None, checkfirst=False):
+        """Issue CREATE ddl for this type, if applicable."""
+        
         from sqlalchemy.schema import _bind_or_error
         if bind is None:
             bind = _bind_or_error(self)
@@ -875,6 +883,8 @@ class SchemaType(object):
             t.create(bind=bind, checkfirst=checkfirst)
 
     def drop(self, bind=None, checkfirst=False):
+        """Issue DROP ddl for this type, if applicable."""
+
         from sqlalchemy.schema import _bind_or_error
         if bind is None:
             bind = _bind_or_error(self)
@@ -983,12 +993,16 @@ class Enum(String, SchemaType):
         if self.native_enum:
             SchemaType._set_table(self, table, column)
             
-        # this constraint DDL object is conditionally
-        # compiled by MySQL, Postgresql based on
-        # the native_enum flag.
-        table.append_constraint(
-            EnumConstraint(self, column)
-        )
+        def should_create_constraint(compiler):
+            return not self.native_enum or \
+                        not compiler.dialect.supports_native_enum
+
+        e = schema.CheckConstraint(
+                        column.in_(self.enums),
+                        name=self.name,
+                        _create_rule=should_create_constraint
+                    )
+        table.append_constraint(e)
         
     def adapt(self, impltype):
         return impltype(name=self.name, 
@@ -1000,14 +1014,6 @@ class Enum(String, SchemaType):
                         *self.enums
                         )
 
-class EnumConstraint(schema.CheckConstraint):
-    __visit_name__ = 'enum_constraint'
-    
-    def __init__(self, type_, column, **kw):
-        super(EnumConstraint, self).__init__('', name=type_.name, **kw)
-        self.type = type_
-        self.column = column
-    
 class PickleType(MutableType, TypeDecorator):
     """Holds Python objects.
 
index f40fa89bd2816fd80eb916fc3843b864d30b5918..accc84c2c9a70ab7427534df37cc2b360bef89ff 100644 (file)
@@ -660,7 +660,7 @@ class TypesTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             schema.CreateTable(t1),
             "CREATE TABLE sometable ("
             "somecolumn VARCHAR(1), "
-            " CHECK (somecolumn IN ('x','y','z'))"
+            "CHECK (somecolumn IN ('x', 'y', 'z'))"
             ")"
         )
         
index 39771bbe959cd947cae25b7a5a0c86ced82bf99c..c929d38b3927ddaa093118e2f64691cd23fe7153 100644 (file)
@@ -84,19 +84,20 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(i, "INSERT INTO mytable (name) VALUES (%(name)s) RETURNING mytable.myid, mytable.name", dialect=dialect)
         
     def test_create_partial_index(self):
-        tbl = Table('testtbl', MetaData(), Column('data',Integer))
+        m = MetaData()
+        tbl = Table('testtbl', m, Column('data',Integer))
         idx = Index('test_idx1', tbl.c.data, postgresql_where=and_(tbl.c.data > 5, tbl.c.data < 10))
 
         self.assert_compile(schema.CreateIndex(idx), 
-            "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgresql.dialect())
-
+            "CREATE INDEX test_idx1 ON testtbl (data) WHERE data > 5 AND data < 10", dialect=postgresql.dialect())
+            
     @testing.uses_deprecated(r".*'postgres_where' argument has been renamed.*")
     def test_old_create_partial_index(self):
         tbl = Table('testtbl', MetaData(), Column('data',Integer))
         idx = Index('test_idx1', tbl.c.data, postgres_where=and_(tbl.c.data > 5, tbl.c.data < 10))
 
         self.assert_compile(schema.CreateIndex(idx), 
-            "CREATE INDEX test_idx1 ON testtbl (data) WHERE testtbl.data > 5 AND testtbl.data < 10", dialect=postgresql.dialect())
+            "CREATE INDEX test_idx1 ON testtbl (data) WHERE data > 5 AND data < 10", dialect=postgresql.dialect())
 
     def test_extract(self):
         t = table('t', column('col1'))
@@ -214,7 +215,7 @@ class EnumTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
             schema.CreateTable(t1),
             "CREATE TABLE sometable ("
             "somecolumn VARCHAR(1), "
-            " CHECK (somecolumn IN ('x','y','z'))"
+            "CHECK (somecolumn IN ('x', 'y', 'z'))"
             ")"
         )
 
index 6c6ad65e0b286cbcefa40dc8efb4e036d71d7d36..e817d257b51fad23032befd00aa4a633f1e9eb4a 100644 (file)
@@ -39,18 +39,13 @@ class TestTypes(TestBase, AssertsExecutionResults):
     def test_time_microseconds(self):
         dt = datetime.datetime(2008, 6, 27, 12, 0, 0, 125)  # 125 usec
         eq_(str(dt), '2008-06-27 12:00:00.000125')
-        sldt = sqlite._SLDateTime()
+        sldt = sqlite.DATETIME()
         bp = sldt.bind_processor(None)
         eq_(bp(dt), '2008-06-27 12:00:00.000125')
         
         rp = sldt.result_processor(None, None)
         eq_(rp(bp(dt)), dt)
         
-        sldt.__legacy_microseconds__ = True
-        bp = sldt.bind_processor(None)
-        eq_(bp(dt), '2008-06-27 12:00:00.125')
-        eq_(rp(bp(dt)), dt)
-
     def test_no_convert_unicode(self):
         """test no utf-8 encoding occurs"""
         
index 56b6d6a102b32c2e2347a50b537940df856b3814..e2179da09eb30191fe3ce2d666e6b197e250b03e 100644 (file)
@@ -148,7 +148,7 @@ class MetaDataTest(TestBase, ComparesTables):
                             break
                     else:
                         assert False
-                    assert c.sqltext=="description='hi'"
+                    assert str(c.sqltext)=="description='hi'"
 
                     for c in table_c.constraints:
                         if isinstance(c, UniqueConstraint):
index 4ad52604d32e508ebec642b3412395cd8e39e61e..55dcd3484ea41170560df6d1b18fc17a2ef7c83f 100644 (file)
@@ -319,7 +319,7 @@ class ConstraintCompilationTest(TestBase, AssertsCompiledSQL):
         constraint = CheckConstraint('a < b',name="my_test_constraint", deferrable=True,initially='DEFERRED', table=t)
         self.assert_compile(
             schema.AddConstraint(constraint),
-            "ALTER TABLE tbl ADD CONSTRAINT my_test_constraint  CHECK (a < b) DEFERRABLE INITIALLY DEFERRED"
+            "ALTER TABLE tbl ADD CONSTRAINT my_test_constraint CHECK (a < b) DEFERRABLE INITIALLY DEFERRED"
         )
 
         self.assert_compile(