]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- lots of paring down and cleanup of schema / DDL. reworked
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 31 May 2009 01:27:46 +0000 (01:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 31 May 2009 01:27:46 +0000 (01:27 +0000)
all _CreateDropBase classes to extend from the same event
framework as DDL().   semi-support for dialect-conditional
Constraint objects, needs work.

14 files changed:
06CHANGES
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/ddl.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/engine/ddlevents.py
test/engine/metadata.py
test/sql/constraints.py
test/testlib/engines.py
test/testlib/testing.py

index 4594bb519bb9ade803aa2c28bea46f5a720d104a..0eab37b4401715e6bf4ca355606772d4fb7dfbb5 100644 (file)
--- a/06CHANGES
+++ b/06CHANGES
       create_engine(... isolation_level="..."); available on
       postgresql and sqlite. [ticket:443]
 
+
 - schema
-    - metadata.connect() and threadlocalmetadata.connect() have been removed.
-    - new CreateTable,DropTable,CreateSequence,DropSequence, etc.
-     
+    - deprecated metadata.connect() and threadlocalmetadata.connect() have been 
+      removed - send the "bind" attribute to bind a metadata.
+    - deprecated metadata.table_iterator() method removed (use sorted_tables)
+    - the "metadata" argument is removed from DefaultGenerator and subclasses.
+    - Removed public mutability from Index and Constraint objects:
+        - ForeignKeyConstraint.append_element()
+        - Index.append_column()
+        - UniqueConstraint.append_column()
+        - PrimaryKeyConstraint.add()
+        - PrimaryKeyConstraint.remove()
+      These should be constructed declaratively (i.e. in one construction).
+    - UniqueConstraint, Index, PrimaryKeyConstraint all accept lists
+      of column names or column objects as arguments.
+    - Other removed things:
+        - Table.key (no idea what this was for)
+        - Table.primary_key is not assignable - use table.append_constraint(PrimaryKeyConstraint(...))
+        - Column.bind       (get via column.table.bind)
+        - Column.metadata   (get via column.table.metadata)
+        
+- DDL
+    - the DDL() system has been greatly expanded:
+        - CreateTable()
+        - DropTable()
+        - AddConstraint()
+        - DropConstraint()
+        - CreateIndex()
+        - DropIndex()
+        - CreateSequence()
+        - DropSequence()
+        - these support "on" and "execute-at()" just like
+          plain DDL() does.
+    
 - dialect refactor
     - the "owner" keyword argument is removed from Table.  Use "schema" to 
       represent any namespaces to be prepended to the table name.
index e68a51c6c16decfe700004ba5d0771407bad3715..6122c5a0cf394c55850bc7f8ea7b13af658ceb7d 100644 (file)
@@ -733,19 +733,6 @@ class FBCompiler(sql.compiler.SQLCompiler):
 class FBSchemaGenerator(sql.compiler.SchemaGenerator):
     """Firebird syntactic idiosincrasies"""
 
-    def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column)
-        colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
-
-        default = self.get_column_default_string(column)
-        if default is not None:
-            colspec += " DEFAULT " + default
-
-        if not column.nullable or column.primary_key:
-            colspec += " NOT NULL"
-
-        return colspec
-
     def visit_sequence(self, sequence):
         """Generate a ``CREATE GENERATOR`` statement for the sequence."""
 
index bb3c00478a9c24cb3fcf28394ce7c00b61d7a4b3..5de31d510b7ca45ff26f2af1a40e8a4a8afeb14b 100644 (file)
@@ -411,16 +411,6 @@ class OracleCompiler(compiler.SQLCompiler):
             return super(OracleCompiler, self).for_update_clause(select)
 
 class OracleDDLCompiler(compiler.DDLCompiler):
-    def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column)
-        colspec += " " + self.dialect.type_compiler.process(column.type)
-        default = self.get_column_default_string(column)
-        if default is not None:
-            colspec += " DEFAULT " + default
-
-        if not column.nullable:
-            colspec += " NOT NULL"
-        return colspec
 
     def visit_create_sequence(self, create):
         return "CREATE SEQUENCE %s" % self.preparer.format_sequence(create.element)
index e2be54c9ff316b6496adc08e0e3033ceaef61a07..bbf07e7f1bd4cf908f52777d41e2b48a46353926 100644 (file)
@@ -586,6 +586,7 @@ class Compiled(object):
 
         raise NotImplementedError()
 
+    @property
     def params(self):
         """Return the bind params for this compiled object."""
         return self.construct_params()
index 2e3c23c970e49cc8c99558cf22635191af2f4627..669e71b48571b963143624a829e9900805e12560 100644 (file)
@@ -51,7 +51,7 @@ class SchemaGenerator(DDLBase):
             self.traverse_single(table)
         if self.dialect.supports_alter:
             for alterable in self.find_alterables(collection):
-                self.connection.execute(schema.AddForeignKey(alterable))
+                self.connection.execute(schema.AddConstraint(alterable))
 
     def visit_table(self, table):
         for listener in table.ddl_listeners['before-create']:
@@ -98,7 +98,7 @@ class SchemaDropper(DDLBase):
         collection = [t for t in reversed(sql_util.sort_tables(tables)) if self._can_drop(t)]
         if self.dialect.supports_alter:
             for alterable in self.find_alterables(collection):
-                self.connection.execute(schema.DropForeignKey(alterable))
+                self.connection.execute(schema.DropConstraint(alterable))
         for table in collection:
             self.traverse_single(table)
 
index e7c4f800246d5a5b2dabb6a88e093c35d2f68ca8..bb1f5019928d4484c884ec66b4c3e4e413774c43 100644 (file)
@@ -302,10 +302,13 @@ class Inspector(object):
             raise exc.NoSuchTableError(table.name)
 
         # Primary keys
-        for pk in self.get_primary_keys(table_name, schema, **tblkw):
-            if pk in table.c:
-                col = table.c[pk]
-                table.primary_key.add(col)
+        primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[
+            table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw)
+            if pk in table.c
+        ])
+
+        table.append_constraint(primary_key_constraint)
+
         # Foreign keys
         fkeys = self.get_foreign_keys(table_name, schema, **tblkw)
         for fkey_d in fkeys:
index 5d46a9928e4cbaad6451acab73df33092ababe73..3a8a7e09b7929c855e77da222131b2c5bdcafd80 100644 (file)
@@ -65,13 +65,6 @@ class SchemaItem(visitors.Visitable):
     def __repr__(self):
         return "%s()" % self.__class__.__name__
 
-    @property
-    def bind(self):
-        """Return the connectable associated with this SchemaItem."""
-
-        m = self.metadata
-        return m and m.bind or None
-
     @property
     def info(self):
         try:
@@ -207,7 +200,7 @@ class Table(SchemaItem, expression.TableClause):
         self.indexes = set()
         self.constraints = set()
         self._columns = expression.ColumnCollection()
-        self.primary_key = PrimaryKeyConstraint()
+        self._set_primary_key(PrimaryKeyConstraint())
         self._foreign_keys = util.OrderedSet()
         self.ddl_listeners = util.defaultdict(list)
         self.kwargs = {}
@@ -229,7 +222,7 @@ class Table(SchemaItem, expression.TableClause):
 
         self._prefixes = kwargs.pop('prefixes', [])
 
-        self.__extra_kwargs(**kwargs)
+        self._extra_kwargs(**kwargs)
 
         # load column definitions from the database if 'autoload' is defined
         # we do it after the table is in the singleton dictionary to support
@@ -242,7 +235,7 @@ class Table(SchemaItem, expression.TableClause):
 
         # initialize all the column, etc. objects.  done after reflection to
         # allow user-overrides
-        self.__post_init(*args, **kwargs)
+        self._init_items(*args)
 
     def _init_existing(self, *args, **kwargs):
         autoload = kwargs.pop('autoload', False)
@@ -266,10 +259,10 @@ class Table(SchemaItem, expression.TableClause):
         if 'info' in kwargs:
             self._info = kwargs.pop('info')
 
-        self.__extra_kwargs(**kwargs)
-        self.__post_init(*args, **kwargs)
+        self._extra_kwargs(**kwargs)
+        self._init_items(*args)
 
-    def __extra_kwargs(self, **kwargs):
+    def _extra_kwargs(self, **kwargs):
         # validate remaining kwargs that they all specify DB prefixes
         if len([k for k in kwargs
                 if not re.match(r'^(?:%s)_' % '|'.join(dialects.__all__), k)]):
@@ -277,22 +270,18 @@ class Table(SchemaItem, expression.TableClause):
                 "Invalid argument(s) for Table: %s" % repr(kwargs.keys()))
         self.kwargs.update(kwargs)
 
-    def __post_init(self, *args, **kwargs):
-        self._init_items(*args)
-
-    @property
-    def key(self):
-        return _get_table_key(self.name, self.schema)
-
     def _set_primary_key(self, pk):
         if getattr(self, '_primary_key', None) in self.constraints:
             self.constraints.remove(self._primary_key)
         self._primary_key = pk
         self.constraints.add(pk)
 
+        for c in pk.columns:
+            c.primary_key = True
+
+    @property
     def primary_key(self):
         return self._primary_key
-    primary_key = property(primary_key, _set_primary_key)
 
     def __repr__(self):
         return "Table(%s)" % ', '.join(
@@ -303,6 +292,13 @@ class Table(SchemaItem, expression.TableClause):
     def __str__(self):
         return _get_table_key(self.description, self.schema)
 
+    @property
+    def bind(self):
+        """Return the connectable associated with this Table."""
+
+        m = self.metadata
+        return m and m.bind or None
+
     def append_column(self, column):
         """Append a ``Column`` to this ``Table``."""
 
@@ -576,7 +572,7 @@ class Column(SchemaItem, expression.ColumnClause):
                 type_ = args.pop(0)
 
         super(Column, self).__init__(name, None, type_)
-        self.args = args
+        self._pending_args = args
         self.key = kwargs.pop('key', name)
         self.primary_key = kwargs.pop('primary_key', False)
         self.nullable = kwargs.pop('nullable', not self.primary_key)
@@ -609,10 +605,6 @@ class Column(SchemaItem, expression.ColumnClause):
         else:
             return self.description
 
-    @property
-    def bind(self):
-        return self.table.bind
-
     def references(self, column):
         """Return True if this Column references the given column via foreign key."""
         for fk in self.foreign_keys:
@@ -652,24 +644,23 @@ class Column(SchemaItem, expression.ColumnClause):
                 "before adding to a Table.")
         if self.key is None:
             self.key = self.name
-        self.metadata = table.metadata
+
         if getattr(self, 'table', None) is not None:
             raise exc.ArgumentError("this Column already has a table!")
 
         if self.key in table._columns:
             # note the column being replaced, if any
             self._pre_existing_column = table._columns.get(self.key)
+            
         table._columns.replace(self)
 
         if self.primary_key:
-            table.primary_key.replace(self)
+            table.primary_key._replace(self)
         elif self.key in table.primary_key:
             raise exc.ArgumentError(
                 "Trying to redefine primary-key column '%s' as a "
                 "non-primary-key column on table '%s'" % (
                 self.key, table.fullname))
-            # if we think this should not raise an error, we'd instead do this:
-            #table.primary_key.remove(self)
         self.table = table
 
         if self.index:
@@ -689,7 +680,7 @@ class Column(SchemaItem, expression.ColumnClause):
                     "external to the Table.")
             table.append_constraint(UniqueConstraint(self.key))
 
-        toinit = list(self.args)
+        toinit = list(self._pending_args)
         if self.default is not None:
             if isinstance(self.default, ColumnDefault):
                 toinit.append(self.default)
@@ -709,7 +700,7 @@ class Column(SchemaItem, expression.ColumnClause):
                 toinit.append(DefaultClause(self.server_onupdate,
                                             for_update=True))
         self._init_items(*toinit)
-        self.args = None
+        del self._pending_args
 
     def copy(self, **kw):
         """Create a copy of this ``Column``, unitialized.
@@ -717,7 +708,17 @@ class Column(SchemaItem, expression.ColumnClause):
         This is used in ``Table.tometadata``.
 
         """
-        return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, quote=self.quote, index=self.index, autoincrement=self.autoincrement, *[c.copy(**kw) for c in self.constraints])
+        return Column(
+                self.name, 
+                self.type, 
+                self.default, 
+                key = self.key, 
+                primary_key = self.primary_key, 
+                nullable = self.nullable, 
+                quote=self.quote, 
+                index=self.index, 
+                autoincrement=self.autoincrement, 
+                *[c.copy(**kw) for c in self.constraints])
 
     def _make_proxy(self, selectable, name=None):
         """Create a *proxy* for this column.
@@ -769,7 +770,8 @@ class ForeignKey(SchemaItem):
 
     __visit_name__ = 'foreign_key'
 
-    def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None, link_to_name=False):
+    def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, 
+                    ondelete=None, deferrable=None, initially=None, link_to_name=False):
         """
         Construct a column-level FOREIGN KEY.
 
@@ -940,6 +942,8 @@ class ForeignKey(SchemaItem):
 
     def _set_parent(self, column):
         if hasattr(self, 'parent'):
+            if self.parent is column:
+                return
             raise exc.InvalidRequestError("This ForeignKey already has a parent !")
         self.parent = column
 
@@ -954,9 +958,10 @@ class ForeignKey(SchemaItem):
             self.constraint = ForeignKeyConstraint(
                 [], [], use_alter=self.use_alter, name=self.name,
                 onupdate=self.onupdate, ondelete=self.ondelete,
-                deferrable=self.deferrable, initially=self.initially)
-            self.parent.table.append_constraint(self.constraint)
-            self.constraint._append_fk(self)
+                deferrable=self.deferrable, initially=self.initially,
+                )
+            self.constraint._elements[column] = self
+            self.constraint._set_parent(self.parent.table)
 
         self.parent.foreign_keys.add(self)
         self.parent.table.foreign_keys.add(self)
@@ -968,11 +973,9 @@ class DefaultGenerator(SchemaItem):
 
     def __init__(self, for_update=False, metadata=None):
         self.for_update = for_update
-        self.metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata')
 
     def _set_parent(self, column):
         self.column = column
-        self.metadata = self.column.table.metadata
         if self.for_update:
             self.column.onupdate = self
         else:
@@ -983,6 +986,12 @@ class DefaultGenerator(SchemaItem):
             bind = _bind_or_error(self)
         return bind._execute_default(self, **kwargs)
 
+    @property
+    def bind(self):
+        """Return the connectable associated with this default."""
+
+        return self.column.table.bind
+
     def __repr__(self):
         return "DefaultGenerator()"
 
@@ -1119,13 +1128,8 @@ class DefaultClause(FetchedValue):
 # alias; deprecated starting 0.5.0
 PassiveDefault = DefaultClause
 
-
 class Constraint(SchemaItem):
-    """A table-level SQL constraint, such as a KEY.
-
-    Implements a hybrid of dict/setlike behavior with regards to the list of
-    underying columns.
-    """
+    """A table-level SQL constraint."""
 
     __visit_name__ = 'constraint'
 
@@ -1145,30 +1149,74 @@ class Constraint(SchemaItem):
         """
 
         self.name = name
-        self.columns = expression.ColumnCollection()
         self.deferrable = deferrable
         self.initially = initially
 
+    @property
+    def table(self):
+        if isinstance(self.parent, Table):
+            return self.parent
+        else:
+            raise exc.InvalidRequestError("This constraint is not bound to a table.")
+
+    def _set_parent(self, parent):
+        self.parent = parent
+        parent.constraints.add(self)
+
+    def copy(self, **kw):
+        raise NotImplementedError()
+
+class ColumnCollectionConstraint(Constraint):
+    """A constraint that proxies a ColumnCollection."""
+    
+    def __init__(self, *columns, **kw):
+        """
+        \*columns
+          A sequence of column names or Column objects.
+
+        name
+          Optional, the in-database name of this constraint.
+
+        deferrable
+          Optional bool.  If set, emit DEFERRABLE or NOT DEFERRABLE when
+          issuing DDL for this constraint.
+
+        initially
+          Optional string.  If set, emit INITIALLY <value> when issuing DDL
+          for this constraint.
+        
+        """
+        super(ColumnCollectionConstraint, self).__init__(**kw)
+        self.columns = expression.ColumnCollection()
+        self._pending_colargs = [_to_schema_column_or_string(c) for c in columns]
+        if self._pending_colargs and \
+                isinstance(self._pending_colargs[0], Column) and \
+                self._pending_colargs[0].table is not None:
+            self._set_parent(self._pending_colargs[0].table)
+        
+    def _set_parent(self, table):
+        super(ColumnCollectionConstraint, self)._set_parent(table)
+        for col in self._pending_colargs:
+            if isinstance(col, basestring):
+                col = table.c[col]
+            self.columns.add(col)
+
     def __contains__(self, x):
         return x in self.columns
 
+    def copy(self, **kw):
+        return self.__class__(*self.columns.keys(), 
+                    name=self.name, deferrable=self.deferrable, initially=self.initially)
+
     def contains_column(self, col):
         return self.columns.contains_column(col)
 
-    def keys(self):
-        return self.columns.keys()
-
-    def __add__(self, other):
-        return self.columns + other
-
     def __iter__(self):
         return iter(self.columns)
 
     def __len__(self):
         return len(self.columns)
 
-    def copy(self, **kw):
-        raise NotImplementedError()
 
 class CheckConstraint(Constraint):
     """A table- or column-level CHECK constraint.
@@ -1176,7 +1224,7 @@ class CheckConstraint(Constraint):
     Can be included in the definition of a Table or Column.
     """
 
-    def __init__(self, sqltext, name=None, deferrable=None, initially=None):
+    def __init__(self, sqltext, name=None, deferrable=None, initially=None, table=None):
         """Construct a CHECK constraint.
 
         sqltext
@@ -1193,6 +1241,7 @@ class CheckConstraint(Constraint):
         initially
           Optional string.  If set, emit INITIALLY <value> when issuing DDL
           for this constraint.
+          
         """
 
         super(CheckConstraint, self).__init__(name, deferrable, initially)
@@ -1200,7 +1249,9 @@ class CheckConstraint(Constraint):
             raise exc.ArgumentError(
                 "sqltext must be a string and will be used verbatim.")
         self.sqltext = sqltext
-
+        if table:
+            self._set_parent(table)
+            
     def __visit_name__(self):
         if isinstance(self.parent, Table):
             return "check_constraint"
@@ -1208,10 +1259,6 @@ class CheckConstraint(Constraint):
             return "column_check_constraint"
     __visit_name__ = property(__visit_name__)
 
-    def _set_parent(self, parent):
-        self.parent = parent
-        parent.constraints.add(self)
-
     def copy(self, **kw):
         return CheckConstraint(self.sqltext, name=self.name)
 
@@ -1228,7 +1275,7 @@ class ForeignKeyConstraint(Constraint):
     """
     __visit_name__ = 'foreign_key_constraint'
 
-    def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False, deferrable=None, initially=None, link_to_name=False):
+    def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None, use_alter=False, link_to_name=False, table=None):
         """Construct a composite-capable FOREIGN KEY.
 
         :param columns: A sequence of local column names.  The named columns must be defined
@@ -1257,42 +1304,53 @@ class ForeignKeyConstraint(Constraint):
         :param link_to_name: if True, the string name given in ``column`` is the rendered
           name of the referenced column, not its locally assigned ``key``.
 
-        :param use_alter: If True, do not emit this key as part of the CREATE TABLE
+        :param use_alter: If True, do not emit this constraint as part of the CREATE TABLE
           definition.  Instead, use ALTER TABLE after table creation to add
-          the key.  Useful for circular dependencies.
+          the key.  Useful for circular dependencies and conditional constraint generation.
           
         """
         super(ForeignKeyConstraint, self).__init__(name, deferrable, initially)
-        self.__colnames = columns
-        self.__refcolnames = refcolumns
-        self.elements = util.OrderedSet()
+
         self.onupdate = onupdate
         self.ondelete = ondelete
         self.link_to_name = link_to_name
         if self.name is None and use_alter:
-            raise exc.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name")
+            raise exc.ArgumentError("Alterable Constraint requires a name")
         self.use_alter = use_alter
 
-    def _set_parent(self, table):
-        self.table = table
-        if self not in table.constraints:
-            table.constraints.add(self)
-            for (c, r) in zip(self.__colnames, self.__refcolnames):
-                self.append_element(c, r)
+        self._elements = util.OrderedDict()
+        for col, refcol in zip(columns, refcolumns):
+            self._elements[col] = ForeignKey(
+                    refcol, 
+                    constraint=self, 
+                    name=self.name, 
+                    onupdate=self.onupdate, 
+                    ondelete=self.ondelete, 
+                    use_alter=self.use_alter, 
+                    link_to_name=self.link_to_name
+                )
+
+        if table:
+            self._set_parent(table)
 
-    def append_element(self, col, refcol):
-        fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter, link_to_name=self.link_to_name)
-        fk._set_parent(self.table.c[col])
-        self._append_fk(fk)
-
-    def _append_fk(self, fk):
-        self.columns.add(self.table.c[fk.parent.key])
-        self.elements.add(fk)
+    def _set_parent(self, table):
+        super(ForeignKeyConstraint, self)._set_parent(table)
+        for col, fk in self._elements.iteritems():
+            if isinstance(col, basestring):
+                col = table.c[col]
+            fk._set_parent(col)
 
     def copy(self, **kw):
-        return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec(**kw) for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter)
-
-class PrimaryKeyConstraint(Constraint):
+        return ForeignKeyConstraint(
+                    [x.parent.name for x in self._elements.values()], 
+                    [x._get_colspec(**kw) for x in self._elements.values()], 
+                    name=self.name, 
+                    onupdate=self.onupdate, 
+                    ondelete=self.ondelete, 
+                    use_alter=self.use_alter
+                )
+
+class PrimaryKeyConstraint(ColumnCollectionConstraint):
     """A table-level PRIMARY KEY constraint.
 
     Defines a single column or composite PRIMARY KEY constraint. For a
@@ -1303,63 +1361,14 @@ class PrimaryKeyConstraint(Constraint):
 
     __visit_name__ = 'primary_key_constraint'
 
-    def __init__(self, *columns, **kwargs):
-        """Construct a composite-capable PRIMARY KEY.
-
-        \*columns
-          A sequence of column names.  All columns named must be defined and
-          present within the parent Table.
-
-        name
-          Optional, the in-database name of the key.
-
-        deferrable
-          Optional bool.  If set, emit DEFERRABLE or NOT DEFERRABLE when
-          issuing DDL for this constraint.
-
-        initially
-          Optional string.  If set, emit INITIALLY <value> when issuing DDL
-          for this constraint.
-        """
-
-        constraint_args = dict(name=kwargs.pop('name', None),
-                               deferrable=kwargs.pop('deferrable', None),
-                               initially=kwargs.pop('initially', None))
-        if kwargs:
-            raise exc.ArgumentError(
-                'Unknown PrimaryKeyConstraint argument(s): %s' %
-                ', '.join(repr(x) for x in kwargs.iterkeys()))
-
-        super(PrimaryKeyConstraint, self).__init__(**constraint_args)
-        self.__colnames = list(columns)
-
     def _set_parent(self, table):
-        self.table = table
-        table.primary_key = self
-        for name in self.__colnames:
-            self.add(table.c[name])
-
-    def add(self, col):
-        self.columns.add(col)
-        col.primary_key = True
-    append_column = add
+        super(PrimaryKeyConstraint, self)._set_parent(table)
+        table._set_primary_key(self)
 
-    def replace(self, col):
+    def _replace(self, col):
         self.columns.replace(col)
 
-    def remove(self, col):
-        col.primary_key = False
-        del self.columns[col.key]
-
-    def copy(self, **kw):
-        return PrimaryKeyConstraint(name=self.name, *[c.key for c in self])
-
-    __hash__ = Constraint.__hash__
-    
-    def __eq__(self, other):
-        return self.columns == other
-
-class UniqueConstraint(Constraint):
+class UniqueConstraint(ColumnCollectionConstraint):
     """A table-level UNIQUE constraint.
 
     Defines a single column or composite UNIQUE constraint. For a no-frills,
@@ -1370,48 +1379,6 @@ class UniqueConstraint(Constraint):
 
     __visit_name__ = 'unique_constraint'
 
-    def __init__(self, *columns, **kwargs):
-        """Construct a UNIQUE constraint.
-
-        \*columns
-          A sequence of column names.  All columns named must be defined and
-          present within the parent Table.
-
-        name
-          Optional, the in-database name of the key.
-
-        deferrable
-          Optional bool.  If set, emit DEFERRABLE or NOT DEFERRABLE when
-          issuing DDL for this constraint.
-
-        initially
-          Optional string.  If set, emit INITIALLY <value> when issuing DDL
-          for this constraint.
-        """
-
-        constraint_args = dict(name=kwargs.pop('name', None),
-                               deferrable=kwargs.pop('deferrable', None),
-                               initially=kwargs.pop('initially', None))
-        if kwargs:
-            raise exc.ArgumentError(
-                'Unknown UniqueConstraint argument(s): %s' %
-                ', '.join(repr(x) for x in kwargs.iterkeys()))
-
-        super(UniqueConstraint, self).__init__(**constraint_args)
-        self.__colnames = list(columns)
-
-    def _set_parent(self, table):
-        self.table = table
-        table.constraints.add(self)
-        for c in self.__colnames:
-            self.append_column(table.c[c])
-
-    def append_column(self, col):
-        self.columns.add(col)
-
-    def copy(self, **kw):
-        return UniqueConstraint(name=self.name, *self.__colnames)
-
 class Index(SchemaItem):
     """A table-level INDEX.
 
@@ -1432,7 +1399,7 @@ class Index(SchemaItem):
 
         \*columns
           Columns to include in the index. All columns must belong to the same
-          table, and no column may appear more than once.
+          table.
 
         \**kwargs
           Keyword arguments include:
@@ -1445,7 +1412,7 @@ class Index(SchemaItem):
         """
 
         self.name = name
-        self.columns = []
+        self.columns = expression.ColumnCollection()
         self.table = None
         self.unique = kwargs.pop('unique', False)
         self.kwargs = kwargs
@@ -1454,28 +1421,25 @@ class Index(SchemaItem):
 
     def _init_items(self, *args):
         for column in args:
-            self.append_column(_to_schema_column(column))
+            column = _to_schema_column(column)
+            if self.table is None:
+                self._set_parent(column.table)
+            elif column.table != self.table:
+                # all columns muse be from same table
+                raise exc.ArgumentError(
+                    "All index columns must be from same table. "
+                    "%s is from %s not %s" % (column, column.table, self.table))
+            self.columns.add(column)
 
     def _set_parent(self, table):
         self.table = table
-        self.metadata = table.metadata
         table.indexes.add(self)
 
-    def append_column(self, column):
-        # make sure all columns are from the same table
-        # and no column is repeated
-        if self.table is None:
-            self._set_parent(column.table)
-        elif column.table != self.table:
-            # all columns muse be from same table
-            raise exc.ArgumentError(
-                "All index columns must be from same table. "
-                "%s is from %s not %s" % (column, column.table, self.table))
-        elif column.name in [ c.name for c in self.columns ]:
-            raise exc.ArgumentError(
-                "A column may not appear twice in the "
-                "same index (%s already has column %s)" % (self.name, column))
-        self.columns.append(column)
+    @property
+    def bind(self):
+        """Return the connectable associated with this Index."""
+        
+        return self.table.bind
 
     def create(self, bind=None):
         if bind is None:
@@ -1488,9 +1452,6 @@ class Index(SchemaItem):
             bind = _bind_or_error(self)
         bind.drop(self)
 
-    def __str__(self):
-        return repr(self)
-
     def __repr__(self):
         return 'Index("%s", %s%s)' % (self.name,
                                       ', '.join(repr(c) for c in self.columns),
@@ -1608,20 +1569,6 @@ class MetaData(SchemaItem):
         # TODO: scan all other tables and remove FK _column
         del self.tables[table.key]
 
-    @util.deprecated('Deprecated. Use ``metadata.sorted_tables``')
-    def table_iterator(self, reverse=True, tables=None):
-        """Deprecated - use metadata.sorted_tables()."""
-        
-        from sqlalchemy.sql.util import sort_tables
-        if tables is None:
-            tables = self.tables.itervalues()
-        else:
-            tables = set(tables).intersection(self.tables.itervalues())
-        ret = sort_tables(tables)
-        if reverse:
-            ret = reversed(ret)
-        return iter(ret)
-    
     @property
     def sorted_tables(self):
         """Returns a list of ``Table`` objects sorted in order of
@@ -1752,6 +1699,7 @@ class MetaData(SchemaItem):
         """
         if bind is None:
             bind = _bind_or_error(self)
+        # TODO!!! the listener stuff here needs to move to engine/ddl.py
         for listener in self.ddl_listeners['before-create']:
             listener('before-create', self, bind)
         bind.create(self, checkfirst=checkfirst, tables=tables)
@@ -1779,6 +1727,7 @@ class MetaData(SchemaItem):
         """
         if bind is None:
             bind = _bind_or_error(self)
+        # TODO!!! the listener stuff here needs to move to engine/ddl.py
         for listener in self.ddl_listeners['before-drop']:
             listener('before-drop', self, bind)
         bind.drop(self, checkfirst=checkfirst, tables=tables)
@@ -1865,6 +1814,111 @@ class DDLElement(expression.ClauseElement):
     supports_execution = True
     _autocommit = True
 
+    schema_item = None
+    on = None
+    
+    def execute(self, bind=None, schema_item=None):
+        """Execute this DDL immediately.
+
+        Executes the DDL statement in isolation using the supplied
+        :class:`~sqlalchemy.engine.base.Connectable` or :class:`~sqlalchemy.engine.base.Connectable` assigned to the ``.bind`` property,
+        if not supplied.  If the DDL has a conditional ``on`` criteria, it
+        will be invoked with None as the event.
+
+        bind
+          Optional, an ``Engine`` or ``Connection``.  If not supplied, a
+          valid :class:`~sqlalchemy.engine.base.Connectable` must be present in the ``.bind`` property.
+
+        schema_item
+          Optional, defaults to None.  Will be passed to the ``on`` callable
+          criteria, if any, and may provide string expansion data for the
+          statement. See ``execute_at`` for more information.
+        """
+
+        if bind is None:
+            bind = _bind_or_error(self)
+
+        if self._should_execute(None, schema_item, bind):
+            return bind.execute(self.against(schema_item))
+        else:
+            bind.engine.logger.info("DDL execution skipped, criteria not met.")
+
+    def execute_at(self, event, schema_item):
+        """Link execution of this DDL to the DDL lifecycle of a SchemaItem.
+
+        Links this ``DDL`` to a ``Table`` or ``MetaData`` instance, executing
+        it when that schema item is created or dropped.  The DDL statement
+        will be executed using the same Connection and transactional context
+        as the Table create/drop itself.  The ``.bind`` property of this
+        statement is ignored.
+
+        event
+          One of the events defined in the schema item's ``.ddl_events``;
+          e.g. 'before-create', 'after-create', 'before-drop' or 'after-drop'
+
+        schema_item
+          A Table or MetaData instance
+
+        When operating on Table events, the following additional ``statement``
+        string substitions are available::
+
+            %(table)s  - the Table name, with any required quoting applied
+            %(schema)s - the schema name, with any required quoting applied
+            %(fullname)s - the Table name including schema, quoted if needed
+
+        The DDL's ``context``, if any, will be combined with the standard
+        substutions noted above.  Keys present in the context will override
+        the standard substitutions.
+
+        A DDL instance can be linked to any number of schema items. The
+        statement subsitution support allows for DDL instances to be used in a
+        template fashion.
+
+        ``execute_at`` builds on the ``append_ddl_listener`` interface of
+        MetaDta and Table objects.
+
+        Caveat: Creating or dropping a Table in isolation will also trigger
+        any DDL set to ``execute_at`` that Table's MetaData.  This may change
+        in a future release.
+        """
+
+        if not hasattr(schema_item, 'ddl_listeners'):
+            raise exc.ArgumentError(
+                "%s does not support DDL events" % type(schema_item).__name__)
+        if event not in schema_item.ddl_events:
+            raise exc.ArgumentError(
+                "Unknown event, expected one of (%s), got '%r'" %
+                (', '.join(schema_item.ddl_events), event))
+        schema_item.ddl_listeners[event].append(self)
+        return self
+
+    @expression._generative
+    def against(self, schema_item):
+        """Return a copy of this DDL against a specific schema item."""
+
+        self.schema_item = schema_item
+
+    def __call__(self, event, schema_item, bind):
+        """Execute the DDL as a ddl_listener."""
+
+        if self._should_execute(event, schema_item, bind):
+            return bind.execute(self.against(schema_item))
+
+    def _check_ddl_on(self, on):
+        if (on is not None and
+            (not isinstance(on, basestring) and not util.callable(on))):
+            raise exc.ArgumentError(
+                "Expected the name of a database dialect or a callable for "
+                "'on' criteria, got type '%s'." % type(on).__name__)
+
+    def _should_execute(self, event, schema_item, bind):
+        if self.on is None:
+            return True
+        elif isinstance(self.on, basestring):
+            return self.on == bind.engine.name
+        else:
+            return self.on(event, schema_item, bind)
+
     def bind(self):
         if self._bind:
             return self._bind
@@ -1954,112 +2008,14 @@ class DDL(DDLElement):
             raise exc.ArgumentError(
                 "Expected a string or unicode SQL statement, got '%r'" %
                 statement)
-        if (on is not None and
-            (not isinstance(on, basestring) and not util.callable(on))):
-            raise exc.ArgumentError(
-                "Expected the name of a database dialect or a callable for "
-                "'on' criteria, got type '%s'." % type(on).__name__)
 
         self.statement = statement
-        self.on = on
         self.context = context or {}
-        self._bind = bind
-        self.schema_item = None
-
-    def execute(self, bind=None, schema_item=None):
-        """Execute this DDL immediately.
-
-        Executes the DDL statement in isolation using the supplied
-        :class:`~sqlalchemy.engine.base.Connectable` or :class:`~sqlalchemy.engine.base.Connectable` assigned to the ``.bind`` property,
-        if not supplied.  If the DDL has a conditional ``on`` criteria, it
-        will be invoked with None as the event.
-
-        bind
-          Optional, an ``Engine`` or ``Connection``.  If not supplied, a
-          valid :class:`~sqlalchemy.engine.base.Connectable` must be present in the ``.bind`` property.
-
-        schema_item
-          Optional, defaults to None.  Will be passed to the ``on`` callable
-          criteria, if any, and may provide string expansion data for the
-          statement. See ``execute_at`` for more information.
-        """
-
-        if bind is None:
-            bind = _bind_or_error(self)
 
-        if self._should_execute(None, schema_item, bind):
-            return bind.execute(self.against(schema_item))
-        else:
-            bind.engine.logger.info("DDL execution skipped, criteria not met.")
-
-    def execute_at(self, event, schema_item):
-        """Link execution of this DDL to the DDL lifecycle of a SchemaItem.
-
-        Links this ``DDL`` to a ``Table`` or ``MetaData`` instance, executing
-        it when that schema item is created or dropped.  The DDL statement
-        will be executed using the same Connection and transactional context
-        as the Table create/drop itself.  The ``.bind`` property of this
-        statement is ignored.
-
-        event
-          One of the events defined in the schema item's ``.ddl_events``;
-          e.g. 'before-create', 'after-create', 'before-drop' or 'after-drop'
-
-        schema_item
-          A Table or MetaData instance
-
-        When operating on Table events, the following additional ``statement``
-        string substitions are available::
-
-            %(table)s  - the Table name, with any required quoting applied
-            %(schema)s - the schema name, with any required quoting applied
-            %(fullname)s - the Table name including schema, quoted if needed
-
-        The DDL's ``context``, if any, will be combined with the standard
-        substutions noted above.  Keys present in the context will override
-        the standard substitutions.
-
-        A DDL instance can be linked to any number of schema items. The
-        statement subsitution support allows for DDL instances to be used in a
-        template fashion.
-
-        ``execute_at`` builds on the ``append_ddl_listener`` interface of
-        MetaDta and Table objects.
-
-        Caveat: Creating or dropping a Table in isolation will also trigger
-        any DDL set to ``execute_at`` that Table's MetaData.  This may change
-        in a future release.
-        """
-
-        if not hasattr(schema_item, 'ddl_listeners'):
-            raise exc.ArgumentError(
-                "%s does not support DDL events" % type(schema_item).__name__)
-        if event not in schema_item.ddl_events:
-            raise exc.ArgumentError(
-                "Unknown event, expected one of (%s), got '%r'" %
-                (', '.join(schema_item.ddl_events), event))
-        schema_item.ddl_listeners[event].append(self)
-        return self
-    
-    @expression._generative
-    def against(self, schema_item):
-        """Return a copy of this DDL against a specific schema item."""
-        
-        self.schema_item = schema_item
-        
-    def __call__(self, event, schema_item, bind):
-        """Execute the DDL as a ddl_listener."""
-
-        if self._should_execute(event, schema_item, bind):
-            return bind.execute(self.against(schema_item))
+        self._check_ddl_on(on)
+        self.on = on
+        self._bind = bind
 
-    def _should_execute(self, event, schema_item, bind):
-        if self.on is None:
-            return True
-        elif isinstance(self.on, basestring):
-            return self.on == bind.engine.name
-        else:
-            return self.on(event, schema_item, bind)
 
     def __repr__(self):
         return '<%s@%s; %s>' % (
@@ -2076,6 +2032,11 @@ def _to_schema_column(element):
        raise exc.ArgumentError("schema.Column object expected")
    return element
 
+def _to_schema_column_or_string(element):
+  if hasattr(element, '__clause_element__'):
+      element = element.__clause_element__()
+  return element
+
 class _CreateDropBase(DDLElement):
     """Base class for DDL constucts that represent CREATE and DROP or equivalents.
 
@@ -2085,21 +2046,12 @@ class _CreateDropBase(DDLElement):
     
     """
     
-    def __init__(self, element):
+    def __init__(self, element, on=None, bind=None):
         self.element = element
+        self._check_ddl_on(on)
+        self.on = on
+        self.bind = bind
         
-    def bind(self):
-        if self._bind:
-            return self._bind
-        if self.element:
-            e = self.element.bind
-            if e:
-                return e
-        return None
-
-    def _set_bind(self, bind):
-        self._bind = bind
-    bind = property(bind, _set_bind)
 
 class CreateTable(_CreateDropBase):
     """Represent a CREATE TABLE statement."""
@@ -2111,16 +2063,6 @@ class DropTable(_CreateDropBase):
 
     __visit_name__ = "drop_table"
 
-class AddForeignKey(_CreateDropBase):
-    """Represent an ALTER TABLE ADD FOREIGN KEY statement."""
-    
-    __visit_name__ = "add_foreignkey"
-    
-class DropForeignKey(_CreateDropBase):
-    """Represent an ALTER TABLE DROP FOREIGN KEY statement."""
-    
-    __visit_name__ = "drop_foreignkey"
-
 class CreateSequence(_CreateDropBase):
     """Represent a CREATE SEQUENCE statement."""
     
@@ -2140,7 +2082,17 @@ class DropIndex(_CreateDropBase):
     """Represent a DROP INDEX statement."""
 
     __visit_name__ = "drop_index"
+
+class AddConstraint(_CreateDropBase):
+    """Represent an ALTER TABLE ADD CONSTRAINT statement."""
     
+    __visit_name__ = "add_constraint"
+
+class DropConstraint(_CreateDropBase):
+    """Represent an ALTER TABLE DROP CONSTRAINT statement."""
+
+    __visit_name__ = "drop_constraint"
+
 def _bind_or_error(schemaitem):
     bind = schemaitem.bind
     if not bind:
index 5ea22b2fdb511ed5517277b6adbc57e167d6f408..1258dde90db88721f342043d9782596407c5dd7e 100644 (file)
@@ -780,6 +780,9 @@ class DDLCompiler(engine.Compiled):
     @property
     def preparer(self):
         return self.dialect.identifier_preparer
+
+    def construct_params(self, params=None):
+        return None
         
     def visit_ddl(self, ddl, **kwargs):
         # table events can substitute table and schema name
@@ -819,16 +822,21 @@ class DDLCompiler(engine.Compiled):
             text += "\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)
             if column.primary_key:
                 first_pk = True
-            for constraint in column.constraints:
-                text += self.process(constraint)
+            const = " ".join(self.process(constraint) for constraint in column.constraints)
+            if const:
+                text += " " + const
 
         # On some DB order is significant: visit PK first, then the
         # other constraints (engine.ReflectionTest.testbasic failed on FB2)
         if table.primary_key:
-            text += self.process(table.primary_key)
-
-        for constraint in [c for c in table.constraints if c is not table.primary_key]:
-            text += self.process(constraint)
+            text += ", \n\t" + self.process(table.primary_key)
+        
+        const = ", \n\t".join(
+                        self.process(constraint) for constraint in table.constraints if constraint is not table.primary_key
+                        and (not self.dialect.supports_alter or not getattr(constraint, 'use_alter', False))
+                )
+        if const:
+            text += ", \n\t" + const
 
         text += "\n)%s\n\n" % self.post_create_table(table)
         return text
@@ -836,15 +844,6 @@ class DDLCompiler(engine.Compiled):
     def visit_drop_table(self, drop):
         return "\nDROP TABLE " + self.preparer.format_table(drop.element)
         
-    def visit_add_foreignkey(self, add):
-        return "ALTER TABLE %s ADD " % self.preparer.format_table(add.element.table) + \
-            self.define_foreign_key(add.element)
-
-    def visit_drop_foreignkey(self, drop):
-        return "ALTER TABLE %s DROP CONSTRAINT %s" % (
-            self.preparer.format_table(drop.element.table),
-            self.preparer.format_constraint(drop.element))
-
     def visit_create_index(self, create):
         index = create.element
         preparer = self.preparer
@@ -862,8 +861,29 @@ class DDLCompiler(engine.Compiled):
         index = drop.element
         return "\nDROP INDEX " + self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
 
-    def get_column_specification(self, column, first_pk=False):
-        raise NotImplementedError()
+    def visit_add_constraint(self, create):
+        preparer = self.preparer
+        return "ALTER TABLE %s ADD %s" % (
+            self.preparer.format_table(create.element.table),
+            self.process(create.element)
+        )
+        
+    def visit_drop_constraint(self, drop):
+        preparer = self.preparer
+        return "ALTER TABLE %s DROP CONSTRAINT %s" % (
+            self.preparer.format_table(drop.element.table),
+            self.preparer.format_constraint(drop.element)
+        )
+        
+    def get_column_specification(self, column, **kwargs):
+        colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type)
+        default = self.get_column_default_string(column)
+        if default is not None:
+            colspec += " DEFAULT " + default
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+        return colspec
 
     def post_create_table(self, table):
         return ''
@@ -896,7 +916,7 @@ class DDLCompiler(engine.Compiled):
             return None
 
     def visit_check_constraint(self, constraint):
-        text = ", \n\t"
+        text = ""
         if constraint.name is not None:
             text += "CONSTRAINT %s " % \
                         self.preparer.format_constraint(constraint)
@@ -912,7 +932,7 @@ class DDLCompiler(engine.Compiled):
     def visit_primary_key_constraint(self, constraint):
         if len(constraint) == 0:
             return ''
-        text = ", \n\t"
+        text = ""
         if constraint.name is not None:
             text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
         text += "PRIMARY KEY "
@@ -922,24 +942,18 @@ class DDLCompiler(engine.Compiled):
         return text
 
     def visit_foreign_key_constraint(self, constraint):
-        if constraint.use_alter and self.dialect.supports_alter:
-            return ''
-        
-        return ", \n\t " + self.define_foreign_key(constraint)
-
-    def define_foreign_key(self, constraint):
         preparer = self.dialect.identifier_preparer
         text = ""
         if constraint.name is not None:
             text += "CONSTRAINT %s " % \
                         preparer.format_constraint(constraint)
-        table = list(constraint.elements)[0].column.table
+        remote_table = list(constraint._elements.values())[0].column.table
         text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
             ', '.join(preparer.quote(f.parent.name, f.parent.quote)
-                      for f in constraint.elements),
-            preparer.format_table(table),
+                      for f in constraint._elements.values()),
+            preparer.format_table(remote_table),
             ', '.join(preparer.quote(f.column.name, f.column.quote)
-                      for f in constraint.elements)
+                      for f in constraint._elements.values())
         )
         if constraint.ondelete is not None:
             text += " ON DELETE %s" % constraint.ondelete
@@ -949,7 +963,7 @@ class DDLCompiler(engine.Compiled):
         return text
 
     def visit_unique_constraint(self, constraint):
-        text = ", \n\t"
+        text = ""
         if constraint.name is not None:
             text += "CONSTRAINT %s " % self.preparer.format_constraint(constraint)
         text += " UNIQUE (%s)" % (', '.join(self.preparer.quote(c.name, c.quote) for c in constraint))
index 68696ae8e41ea6ca10d91ac0184d5104b59f6ae4..58366f81feee20eb3bf32e8a3b11310f87c60477 100644 (file)
@@ -1088,6 +1088,11 @@ class ClauseElement(Visitable):
     def self_group(self, against=None):
         return self
 
+    # TODO: remove .bind as a method from the root ClauseElement.
+    # we should only be deriving binds from FromClause elements
+    # and certain SchemaItem subclasses.
+    # the "search_for_bind" functionality can still be used by
+    # execute(), however.
     @property
     def bind(self):
         """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found."""
index 4c929b766c6d677c67c05f1981706b8139dde70c..61bb4c85d373181dd68d1c0f205a645dd9740838 100644 (file)
@@ -1,5 +1,5 @@
 import testenv; testenv.configure_for_tests()
-from sqlalchemy.schema import DDL
+from sqlalchemy.schema import DDL, CheckConstraint, AddConstraint, DropConstraint
 from sqlalchemy import create_engine
 from testlib.sa import MetaData, Table, Column, Integer, String
 import testlib.sa as tsa
@@ -230,7 +230,40 @@ class DDLExecutionTest(TestBase):
         assert 'klptzyxm' not in strings
         assert 'xyzzy' in strings
         assert 'fnord' in strings
-
+    
+    def test_conditional_constraint(self):
+        metadata, users, engine = self.metadata, self.users, self.engine
+        nonpg_mock = engines.mock_engine(dialect_name='sqlite')
+        pg_mock = engines.mock_engine(dialect_name='postgres')
+        
+        constraint = CheckConstraint('a < b',name="my_test_constraint", table=users)
+        
+        AddConstraint(constraint, on='postgres').execute_at("after-create", users)
+        DropConstraint(constraint, on='postgres').execute_at("before-drop", users)
+        
+        # TODO: need to figure out how to achieve
+        # finer grained control of the DDL process in a
+        # consistent way.
+        # Constraint should get a new flag that is not part of the constructor:
+        # "manual_ddl" or similar.  The flag is public but is normally 
+        # set automatically by DDLElement.execute_at(), so that the
+        # remove() step here is not needed.
+        users.constraints.remove(constraint)
+        
+        metadata.create_all(bind=nonpg_mock)
+        strings = " ".join(str(x) for x in nonpg_mock.mock)
+        assert "my_test_constraint" not in strings
+        metadata.drop_all(bind=nonpg_mock)
+        strings = " ".join(str(x) for x in nonpg_mock.mock)
+        assert "my_test_constraint" not in strings
+
+        metadata.create_all(bind=pg_mock)
+        strings = " ".join(str(x) for x in pg_mock.mock)
+        assert "my_test_constraint" in strings
+        metadata.drop_all(bind=pg_mock)
+        strings = " ".join(str(x) for x in pg_mock.mock)
+        assert "my_test_constraint" in strings
+        
     def test_metadata(self):
         metadata, engine = self.metadata, self.engine
         DDL('mxyzptlk').execute_at('before-create', metadata)
@@ -295,7 +328,6 @@ class DDLTest(TestBase, AssertsCompiledSQL):
 
     def test_tokens(self):
         m = MetaData()
-        bind = self.mock_engine()
         sane_alone = Table('t', m, Column('id', Integer))
         sane_schema = Table('t', m, Column('id', Integer), schema='s')
         insane_alone = Table('t t', m, Column('id', Integer))
@@ -303,7 +335,7 @@ class DDLTest(TestBase, AssertsCompiledSQL):
 
         ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
 
-        dialect = bind.dialect
+        dialect = testing.db.dialect
         self.assert_compile(ddl.against(sane_alone), '-t-t', dialect=dialect)
         self.assert_compile(ddl.against(sane_schema), 's-t-s.t', dialect=dialect)
         self.assert_compile(ddl.against(insane_alone), '-"t t"-"t t"', dialect=dialect)
index c8fc6f7e0fdf4b925f052a0cc4ffde5ec0f2ba08..5d537bc3089eed35e98de13b3bf6a42ce79dcab3 100644 (file)
@@ -82,7 +82,7 @@ class MetaDataTest(TestBase, ComparesTables):
 
         meta.create_all(testing.db)
         try:
-            for test, has_constraints in ((test_to_metadata, True), (test_pickle, True), (test_pickle_via_reflect, False)):
+            for test, has_constraints in ((test_to_metadata, True), (test_pickle, True),(test_pickle_via_reflect, False)):
                 table_c, table2_c = test()
                 self.assert_tables_equal(table, table_c)
                 self.assert_tables_equal(table2, table2_c)
index b03005c00efb39d4649f481777cc2fe4f4fa4566..e2a1adb93af1a0a3a226eae849999d9e7ba59b4a 100644 (file)
@@ -6,6 +6,7 @@ from testlib import config, engines
 from sqlalchemy.engine import ddl
 from testlib.testing import eq_
 from testlib.assertsql import AllOf, RegexSQL, ExactSQL, CompiledSQL
+from sqlalchemy.dialects.postgres import base as postgres
 
 class ConstraintTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
 
@@ -200,7 +201,7 @@ class ConstraintTest(TestBase, AssertsExecutionResults, AssertsCompiledSQL):
         )
 
     
-class ConstraintCompilationTest(TestBase):
+class ConstraintCompilationTest(TestBase, AssertsCompiledSQL):
 
     def _test_deferrable(self, constraint_factory):
         t = Table('tbl', MetaData(),
@@ -253,9 +254,10 @@ class ConstraintCompilationTest(TestBase):
                          ForeignKey('tbl.a', deferrable=True,
                                     initially='DEFERRED')))
 
-        sql = str(schema.CreateTable(t).compile(bind=testing.db))
-        assert 'DEFERRABLE' in sql
-        assert 'INITIALLY DEFERRED' in sql
+        self.assert_compile(
+            schema.CreateTable(t),
+            "CREATE TABLE tbl (a INTEGER, b INTEGER, FOREIGN KEY(b) REFERENCES tbl (a) DEFERRABLE INITIALLY DEFERRED)",
+        )
 
     def test_deferrable_unique(self):
         factory = lambda **kw: UniqueConstraint('b', **kw)
@@ -272,10 +274,71 @@ class ConstraintCompilationTest(TestBase):
                          CheckConstraint('a < b',
                                          deferrable=True,
                                          initially='DEFERRED')))
-        sql = str(schema.CreateTable(t).compile(bind=testing.db))
-        assert 'DEFERRABLE' in sql
-        assert 'INITIALLY DEFERRED' in sql
+        
+        self.assert_compile(
+            schema.CreateTable(t),
+            "CREATE TABLE tbl (a INTEGER, b INTEGER  CHECK (a < b) DEFERRABLE INITIALLY DEFERRED)"
+        )
+    
+    def test_add_drop_constraint(self):
+        m = MetaData()
+        
+        t = Table('tbl', m,
+                  Column('a', Integer),
+                  Column('b', Integer)
+        )
+        
+        t2 = Table('t2', m,
+                Column('a', Integer),
+                Column('b', Integer)
+        )
+        
+        constraint = CheckConstraint('a < b',name="my_test_constraint", deferrable=True,initially='DEFERRED', table=t)
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD CONSTRAINT my_test_constraint  CHECK (a < b) DEFERRABLE INITIALLY DEFERRED"
+        )
 
+        self.assert_compile(
+            schema.DropConstraint(constraint),
+            "ALTER TABLE tbl DROP CONSTRAINT my_test_constraint"
+        )
+
+        constraint = ForeignKeyConstraint(["b"], ["t2.a"])
+        t.append_constraint(constraint)
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD FOREIGN KEY(b) REFERENCES t2 (a)"
+        )
 
+        constraint = ForeignKeyConstraint([t.c.a], [t2.c.b])
+        t.append_constraint(constraint)
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD FOREIGN KEY(a) REFERENCES t2 (b)"
+        )
+
+        constraint = UniqueConstraint("a", "b", name="uq_cst")
+        t2.append_constraint(constraint)
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE t2 ADD CONSTRAINT uq_cst  UNIQUE (a, b)"
+        )
+        
+        constraint = UniqueConstraint(t2.c.a, t2.c.b, name="uq_cs2")
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE t2 ADD CONSTRAINT uq_cs2  UNIQUE (a, b)"
+        )
+        
+        assert t.c.a.primary_key is False
+        constraint = PrimaryKeyConstraint(t.c.a)
+        assert t.c.a.primary_key is True
+        self.assert_compile(
+            schema.AddConstraint(constraint),
+            "ALTER TABLE tbl ADD PRIMARY KEY (a)"
+        )
+    
+        
 if __name__ == "__main__":
     testenv.main()
index b923b2092ea07b6cd9c762a38b17f286feed1733..5b6eaa828d9344158a625a34fa93bd946213f758 100644 (file)
@@ -154,16 +154,18 @@ def utf8_engine(url=None, options=None):
 
     return testing_engine(url, options)
 
-def mock_engine(db=None):
+def mock_engine(db=None, dialect_name=None):
     """Provides a mocking engine based on the current testing.db."""
     
     from sqlalchemy import create_engine
     
-    dbi = db or config.db
+    if not dialect_name:
+        dbi = db or config.db
+        dialect_name = dbi.name
     buffer = []
     def executor(sql, *a, **kw):
         buffer.append(sql)
-    engine = create_engine(dbi.name + '://',
+    engine = create_engine(dialect_name + '://',
                            strategy='mock', executor=executor)
     assert not hasattr(engine, 'mock')
     engine.mock = buffer
index 139661e5e8b92a81dbe24ad413e827aa95153dbb..4348c61a7b3f9c38dfd876afc28d041f7ddcf19e 100644 (file)
@@ -651,7 +651,7 @@ class AssertsCompiledSQL(object):
 
         print "\nSQL String:\n" + str(c) + repr(getattr(c, 'params', {}))
 
-        cc = re.sub(r'\n', '', str(c))
+        cc = re.sub(r'[\n\t]', '', str(c))
 
         self.assertEquals(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))