]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- a fair amount of cleanup to the schema package, removal of ambiguous
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Oct 2006 21:58:04 +0000 (21:58 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Oct 2006 21:58:04 +0000 (21:58 +0000)
methods, methods that are no longer needed.  slightly more constrained
useage, greater emphasis on explicitness. (typo: "useage" → "usage")
- table_iterator signature fixup, includes fix for [ticket:288]
- the "primary_key" attribute of Table and other selectables becomes
a setlike ColumnCollection object; is no longer ordered or numerically
indexed.  a comparison clause between two pks that are derived from the
same underlying tables (i.e. such as two Alias objects) can be generated
via table1.primary_key==table2.primary_key
- append_item() methods removed from Table and Column; preferably
construct Table/Column/related objects inline, but if needed use
append_column(), append_foreign_key(), append_constraint(), etc.
- table.create() no longer returns the Table object, instead has no
return value.  the usual case is that tables are created via metadata,
which is preferable since it will handle table dependencies.
- added UniqueConstraint (goes at Table level), CheckConstraint
(goes at Table or Column level) fixes [ticket:217]
- index=False/unique=True on Column now creates a UniqueConstraint,
index=True/unique=False creates a plain Index,
index=True/unique=True on Column creates a unique Index.  'index'
and 'unique' keyword arguments to column are now boolean only; for
explicit names and groupings of indexes or unique constraints, use the
UniqueConstraint/Index constructs explicitly.
- relationship of Metadata/Table/SchemaGenerator/Dropper has been
improved so that the schemavisitor receives the metadata object
for greater control over groupings of creates/drops.
- added "use_alter" argument to ForeignKey, ForeignKeyConstraint,
but it doesn't do anything yet.  will utilize new generator/dropper
behavior to implement.

24 files changed:
CHANGES
doc/build/content/metadata.txt
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/information_schema.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py
test/engine/reflection.py
test/orm/cycles.py
test/orm/inheritance.py
test/orm/manytomany.py
test/orm/unitofwork.py
test/sql/alltests.py
test/sql/constraints.py [moved from test/sql/indexes.py with 60% similarity]
test/sql/testtypes.py
test/zblog/tests.py

diff --git a/CHANGES b/CHANGES
index 721a5ef19cf40edcb3efe42774107c544615b7d0..b9ac78ee9a01509654766d092f0537ad0eacf1ef 100644 (file)
--- a/CHANGES
+++ b/CHANGES
     - aliases do not use "AS"
     - correctly raises NoSuchTableError when reflecting non-existent table
 - Schema:
+    - a fair amount of cleanup to the schema package, removal of ambiguous
+    methods, methods that are no longer needed.  slightly more constrained
+    usage, greater emphasis on explicitness
+    - the "primary_key" attribute of Table and other selectables becomes
+    a setlike ColumnCollection object; is no longer ordered or numerically
+    indexed.  a comparison clause between two pks that are derived from the 
+    same underlying tables (i.e. such as two Alias objects) can be generated 
+    via table1.primary_key==table2.primary_key
+    - append_item() methods removed from Table and Column; preferably
+    construct Table/Column/related objects inline, but if needed use 
+    append_column(), append_foreign_key(), append_constraint(), etc.
+    - table.create() no longer returns the Table object, instead has no
+    return value.  the usual case is that tables are created via metadata,
+    which is preferable since it will handle table dependencies.
+    - added UniqueConstraint (goes at Table level), CheckConstraint
+    (goes at Table or Column level).
+    - index=False/unique=True on Column now creates a UniqueConstraint,
+    index=True/unique=False creates a plain Index, 
+    index=True/unique=True on Column creates a unique Index.  'index'
+    and 'unique' keyword arguments to column are now boolean only; for
+    explicit names and groupings of indexes or unique constraints, use the
+    UniqueConstraint/Index constructs explicitly.
     - added autoincrement=True to Column; will disable schema generation
     of SERIAL/AUTO_INCREMENT/identity seq for postgres/mysql/mssql if
     explicitly set to False
index bf5d78dce09882d0c2c4e5cb214331a8eb790262..4cae58cb4b21bf4396eacb82478b7d592e333b6a 100644 (file)
@@ -470,41 +470,86 @@ A Sequence object can be defined on a Table that is then used for a non-sequence
     
 A sequence can also be specified with `optional=True` which indicates the Sequence should only be used on a database that requires an explicit sequence, and not those that supply some other method of providing integer values.  At the moment, it essentially means "use this sequence only with Oracle and not Postgres".
     
-### Defining Indexes {@name=indexes}
+### Defining Constraints and Indexes {@name=constraints}
 
-Indexes can be defined on table columns, including named indexes, non-unique or unique, multiple column.  Indexes are included along with table create and drop statements.  They are not used for any kind of run-time constraint checking; SQLAlchemy leaves that job to the expert on constraint checking, the database itself.
+#### UNIQUE Constraint
+
+Unique constraints can be created anonymously on a single column using the `unique` keyword on `Column`.  Explicitly named unique constraints and/or those with multiple columns are created via the `UniqueConstraint` table-level construct.
 
     {python}
-    boundmeta = BoundMetaData('postgres:///scott:tiger@localhost/test')
-    mytable = Table('mytable', boundmeta, 
-        # define a unique index 
+    meta = MetaData()
+    mytable = Table('mytable', meta,
+    
+        # per-column anonymous unique constraint
         Column('col1', Integer, unique=True),
         
-        # define a unique index with a specific name
-        Column('col2', Integer, unique='mytab_idx_1'),
-        
-        # define a non-unique index
-        Column('col3', Integer, index=True),
+        Column('col2', Integer),
+        Column('col3', Integer),
         
-        # define a non-unique index with a specific name
-        Column('col4', Integer, index='mytab_idx_2'),
+        # explicit/composite unique constraint.  'name' is optional.
+        UniqueConstraint('col2', 'col3', name='uix_1')
+        )
+
+#### CHECK Constraint
+
+Check constraints can be named or unnamed and can be created at the Column or Table level, using the `CheckConstraint` construct.  The text of the check constraint is passed directly through to the database, so there is limited "database independent" behavior.  Column level check constraints generally should only refer to the column to which they are placed, while table level constraints can refer to any columns in the table.
+
+Note that some databases do not actively support check constraints such as MySQL and sqlite.
+
+    {python}
+    meta = MetaData()
+    mytable = Table('mytable', meta,
+    
+        # per-column CHECK constraint
+        Column('col1', Integer, CheckConstraint('col1&gt;5')),
         
-        # pass the same name to multiple columns to add them to the same index
-        Column('col5', Integer, index='mytab_idx_2'),
+        Column('col2', Integer),
+        Column('col3', Integer),
         
-        Column('col6', Integer),
-        Column('col7', Integer)
-    )
-    
-    # create the table.  all the indexes will be created along with it.
-    mytable.create()
-    
-    # indexes can also be specified standalone
-    i = Index('mytab_idx_3', mytable.c.col6, mytable.c.col7, unique=False)
+        # table level CHECK constraint.  'name' is optional.
+        CheckConstraint('col2 &gt; col3 + 5', name='check1')
+        )
     
-    # which can then be created separately (will also get created with table creates)
+#### Indexes
+
+Indexes can be created anonymously (using an auto-generated name "ix_&lt;column label&gt;") for a single column using the inline `index` keyword on `Column`, which also modifies the usage of `unique` to apply the uniqueness to the index itself, instead of adding a separate UNIQUE constraint.  For indexes with specific names or which encompass more than one column, use the `Index` construct, which requires a name.  
+
+Note that the `Index` construct is created **externally** to the table which it corresponds, using `Column` objects and not strings.
+
+    {python}
+    meta = MetaData()
+    mytable = Table('mytable', meta,
+        # an indexed column, with index "ix_mytable_col1"
+        Column('col1', Integer, index=True),
+
+        # a uniquely indexed column with index "ix_mytable_col2"
+        Column('col2', Integer, index=True, unique=True),
+
+        Column('col3', Integer),
+        Column('col4', Integer),
+
+        Column('col5', Integer),
+        Column('col6', Integer),
+        )
+
+    # place an index on col3, col4
+    Index('idx_col34', mytable.c.col3, mytable.c.col4)
+
+    # place a unique index on col5, col6
+    Index('myindex', mytable.c.col5, mytable.c.col6, unique=True)
+
+The `Index` objects will be created along with the CREATE statements for the table itself.  An index can also be created on its own independently of the table:
+
+    {python}
+    # create a table
+    sometable.create()
+
+    # define an index
+    i = Index('someindex', sometable.c.col5)
+
+    # create the index, will use the table's connectable, or specify the connectable keyword argument
     i.create()
-    
+
 ### Adapting Tables to Alternate Metadata {@name=adapting}
 
 A `Table` object created against a specific `MetaData` object can be re-created against a new MetaData using the `tometadata` method:
index 2b0d7d17e5e3be4d438ad0134ff736b72cb6950a..208b2f603a1dbbfb29a34b6dac4236fcb6f3e4b9 100644 (file)
@@ -7,7 +7,7 @@
 """defines ANSI SQL operations.  Contains default implementations for the abstract objects 
 in the sql module."""
 
-from sqlalchemy import schema, sql, engine, util
+from sqlalchemy import schema, sql, engine, util, sql_util
 import sqlalchemy.engine.default as default
 import string, re, sets, weakref
 
@@ -28,9 +28,6 @@ RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array', '
 LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$')
 ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$')
 
-def create_engine():
-    return engine.ComposedSQLEngine(None, ANSIDialect())
-    
 class ANSIDialect(default.DefaultDialect):
     def __init__(self, cache_identifiers=True, **kwargs):
         super(ANSIDialect,self).__init__(**kwargs)
@@ -174,7 +171,7 @@ class ANSICompiler(sql.Compiled):
                 if n is not None:
                     self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n)
                 elif len(column.table.primary_key) != 0:
-                    self.strings[column] = self.preparer.format_column_with_table(column.table.primary_key[0])
+                    self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0])
                 else:
                     self.strings[column] = None
             else:
@@ -611,22 +608,30 @@ class ANSICompiler(sql.Compiled):
 
 
 class ANSISchemaGenerator(engine.SchemaIterator):
-    def __init__(self, engine, proxy, connection=None, checkfirst=False, **params):
-        super(ANSISchemaGenerator, self).__init__(engine, proxy, **params)
+    def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
+        super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs)
         self.checkfirst = checkfirst
+        self.tables = tables and util.Set(tables) or None
         self.connection = connection
         self.preparer = self.engine.dialect.preparer()
-    
+        self.dialect = self.engine.dialect
+        
     def get_column_specification(self, column, first_pk=False):
         raise NotImplementedError()
-        
-    def visit_table(self, table):
-        # the single whitespace before the "(" is significant
-        # as its MySQL's method of indicating a table name and not a reserved word.
-        # feel free to localize this logic to the mysql module
-        if self.checkfirst and self.engine.dialect.has_table(self.connection, table.name):
-            return
+    
+    def visit_metadata(self, metadata):
+        for table in metadata.table_iterator(reverse=False, tables=self.tables):
+            if self.checkfirst and self.dialect.has_table(self.connection, table.name):
+                continue
+            table.accept_schema_visitor(self, traverse=False)
             
+    def visit_table(self, table):
+        for column in table.columns:
+            if column.default is not None:
+                column.default.accept_schema_visitor(self, traverse=False)
+            #if column.onupdate is not None:
+            #    column.onupdate.accept_schema_visitor(visitor, traverse=False)
+        
         self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (")
         
         separator = "\n"
@@ -639,15 +644,17 @@ class ANSISchemaGenerator(engine.SchemaIterator):
             self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk))
             if column.primary_key:
                 first_pk = True
-
+            for constraint in column.constraints:
+                constraint.accept_schema_visitor(self, traverse=False)
+                
         for constraint in table.constraints:
-            constraint.accept_schema_visitor(self)            
+            constraint.accept_schema_visitor(self, traverse=False)
 
         self.append("\n)%s\n\n" % self.post_create_table(table))
-        self.execute()        
+        self.execute()
         if hasattr(table, 'indexes'):
             for index in table.indexes:
-                self.visit_index(index)
+                index.accept_schema_visitor(self, traverse=False)
         
     def post_create_table(self, table):
         return ''
@@ -662,10 +669,17 @@ class ANSISchemaGenerator(engine.SchemaIterator):
             return None
 
     def _compile(self, tocompile, parameters):
+        """compile the given string/parameters using this SchemaGenerator's dialect."""
         compiler = self.engine.dialect.compiler(tocompile, parameters)
         compiler.compile()
         return compiler
 
+    def visit_check_constraint(self, constraint):
+        self.append(", \n\t")
+        if constraint.name is not None:
+            self.append("CONSTRAINT %s " % constraint.name)
+        self.append(" CHECK (%s)" % constraint.sqltext)
+        
     def visit_primary_key_constraint(self, constraint):
         if len(constraint) == 0:
             return
@@ -688,6 +702,13 @@ class ANSISchemaGenerator(engine.SchemaIterator):
         if constraint.onupdate is not None:
             self.append(" ON UPDATE %s" % constraint.onupdate)
 
+    def visit_unique_constraint(self, constraint):
+        self.append(", \n\t")
+        if constraint.name is not None:
+            self.append("CONSTRAINT %s " % constraint.name)
+        self.append(" UNIQUE ")
+        self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
+
     def visit_column(self, column):
         pass
 
@@ -701,21 +722,29 @@ class ANSISchemaGenerator(engine.SchemaIterator):
         self.execute()
         
 class ANSISchemaDropper(engine.SchemaIterator):
-    def __init__(self, engine, proxy, connection=None, checkfirst=False, **params):
-        super(ANSISchemaDropper, self).__init__(engine, proxy, **params)
+    def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
+        super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs)
         self.checkfirst = checkfirst
+        self.tables = tables
         self.connection = connection
         self.preparer = self.engine.dialect.preparer()
+        self.dialect = self.engine.dialect
+
+    def visit_metadata(self, metadata):
+        for table in metadata.table_iterator(reverse=True, tables=self.tables):
+            if self.checkfirst and not self.dialect.has_table(self.connection, table.name):
+                continue
+            table.accept_schema_visitor(self, traverse=False)
 
     def visit_index(self, index):
         self.append("\nDROP INDEX " + index.name)
         self.execute()
         
     def visit_table(self, table):
-        # NOTE: indexes on the table will be automatically dropped, so
-        # no need to drop them individually
-        if self.checkfirst and not self.engine.dialect.has_table(self.connection, table.name):
-            return
+        for column in table.columns:
+            if column.default is not None:
+                column.default.accept_schema_visitor(self, traverse=False)
+
         self.append("\nDROP TABLE " + self.preparer.format_table(table))
         self.execute()
 
index fa090a89e5cf354c276f510508eb5604bdd255ca..f38a24b1f877c56f5f5fa33e8b999c7a8f876e36 100644 (file)
@@ -253,7 +253,7 @@ class FireBirdDialect(ansisql.ANSIDialect):
             # is it a primary key?
             kw['primary_key'] = name in pkfields
 
-            table.append_item(schema.Column(*args, **kw))
+            table.append_column(schema.Column(*args, **kw))
             row = c.fetchone()
 
         # get the foreign keys
@@ -276,7 +276,7 @@ class FireBirdDialect(ansisql.ANSIDialect):
             fk[1].append(refspec)
 
         for name,value in fks.iteritems():
-            table.append_item(schema.ForeignKeyConstraint(value[0], value[1], name=name))
+            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
                               
 
     def last_inserted_ids(self):
index 291637e9e5590da9da51246f61d6ecc2c2a09d49..5a7369ccda1701ac3bc2396c692b38f451692513 100644 (file)
@@ -144,7 +144,7 @@ def reflecttable(connection, table, ischema_names):
         colargs= []
         if default is not None:
             colargs.append(PassiveDefault(sql.text(default)))
-        table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs))
+        table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
     
     if not found_table:
         raise exceptions.NoSuchTableError(table.name)
@@ -175,7 +175,7 @@ def reflecttable(connection, table, ischema_names):
         )
         #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) 
         if type=='PRIMARY KEY':
-            table.c[constrained_column]._set_primary_key()
+            table.primary_key.add(table.c[constrained_column])
         elif type=='FOREIGN KEY':
             try:
                 fk = fks[constraint_name]
@@ -196,5 +196,5 @@ def reflecttable(connection, table, ischema_names):
                 fk[1].append(refspec)
     
     for name, value in fks.iteritems():
-        table.append_item(ForeignKeyConstraint(value[0], value[1], name=name))    
+        table.append_constraint(ForeignKeyConstraint(value[0], value[1], name=name))    
             
index 3d65abf0c52b64f3f9b3c9c77362b9f1b154406c..d23c417306df519c8034613c703de0d3b1f26523 100644 (file)
@@ -446,7 +446,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
             if default is not None:
                 colargs.append(schema.PassiveDefault(sql.text(default)))
                 
-            table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs))
+            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
         
         if not found_table:
             raise exceptions.NoSuchTableError(table.name)
@@ -478,7 +478,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
         c = connection.execute(s)
         for row in c:
             if 'PRIMARY' in row[TC.c.constraint_type.name]:
-                table.c[row[0]]._set_primary_key()
+                table.primary_key.add(table.c[row[0]])
 
 
         # Foreign key constraints
@@ -498,13 +498,13 @@ class MSSQLDialect(ansisql.ANSIDialect):
             scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
             if rfknm != fknm:
                 if fknm:
-                    table.append_item(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
+                    table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
                 fknm, scols, rcols = (rfknm, [], [])
             if (not scol in scols): scols.append(scol)
             if (not (rschema, rtbl, rcol) in rcols): rcols.append((rschema, rtbl, rcol))
 
         if fknm and scols:
-            table.append_item(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
+            table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
                                 
 
 
index 4443814c56e0f980c7310bfbb0fcda8fbe639ae9..2fa7e9227f42e41c45cb89219d7ece2807b22ca6 100644 (file)
@@ -353,7 +353,7 @@ class MySQLDialect(ansisql.ANSIDialect):
             colargs= []
             if default:
                 colargs.append(schema.PassiveDefault(sql.text(default)))
-            table.append_item(schema.Column(name, coltype, *colargs, 
+            table.append_column(schema.Column(name, coltype, *colargs, 
                                             **dict(primary_key=primary_key,
                                                    nullable=nullable,
                                                    )))
@@ -397,7 +397,7 @@ class MySQLDialect(ansisql.ANSIDialect):
             refcols = [match.group('reftable') + "." + x for x in re.findall(r'`(.+?)`', match.group('refcols'))]
             schema.Table(match.group('reftable'), table.metadata, autoload=True, autoload_with=connection)
             constraint = schema.ForeignKeyConstraint(columns, refcols, name=match.group('name'))
-            table.append_item(constraint)
+            table.append_constraint(constraint)
 
         return tabletype
         
index db82e3dea8ef6b46616d937cb922caa50ffe2969..b9aa096952b072a7ae6cc6f70da6cedaf9d354af 100644 (file)
@@ -256,7 +256,7 @@ class OracleDialect(ansisql.ANSIDialect):
             if (name.upper() == name): 
                 name = name.lower()
             
-            table.append_item (schema.Column(name, coltype, nullable=nullable, *colargs))
+            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
 
        
         c = connection.execute(constraintSQL, {'table_name' : table.name.upper(), 'owner' : owner})
@@ -268,7 +268,7 @@ class OracleDialect(ansisql.ANSIDialect):
             #print "ROW:" , row                
             (cons_name, cons_type, local_column, remote_table, remote_column) = row
             if cons_type == 'P':
-                table.c[local_column]._set_primary_key()
+                table.primary_key.add(table.c[local_column])
             elif cons_type == 'R':
                 try:
                     fk = fks[cons_name]
@@ -283,7 +283,7 @@ class OracleDialect(ansisql.ANSIDialect):
                     fk[1].append(refspec)
 
         for name, value in fks.iteritems():
-            table.append_item(schema.ForeignKeyConstraint(value[0], value[1], name=name))
+            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
 
     def do_executemany(self, c, statement, parameters, context=None):
         rowcount = 0
index a28a22cd62da88f53e4c16753f3f5e366fd0f425..dad2d3bff2d9b1a710fc3e161571b770b2b20d62 100644 (file)
@@ -370,7 +370,7 @@ class PGDialect(ansisql.ANSIDialect):
                 colargs= []
                 if default is not None:
                     colargs.append(PassiveDefault(sql.text(default)))
-                table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs))
+                table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
     
     
             if not found_table:
@@ -392,7 +392,7 @@ class PGDialect(ansisql.ANSIDialect):
                 if row is None:
                     break
                 pk = row[0]
-                table.c[pk]._set_primary_key()
+                table.primary_key.add(table.c[pk])
     
             # Foreign keys
             FK_SQL = """
@@ -443,7 +443,7 @@ class PGDialect(ansisql.ANSIDialect):
                     for column in referred_columns:
                         refspec.append(".".join([referred_table, column]))
                 
-                table.append_item(ForeignKeyConstraint(constrained_columns, refspec, row['conname']))
+                table.append_constraint(ForeignKeyConstraint(constrained_columns, refspec, row['conname']))
 
 class PGCompiler(ansisql.ANSICompiler):
         
@@ -502,13 +502,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
         return colspec
 
     def visit_sequence(self, sequence):
-        if not sequence.optional and not self.engine.dialect.has_sequence(self.connection, sequence.name):
+        if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)):
             self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
             self.execute()
             
 class PGSchemaDropper(ansisql.ANSISchemaDropper):
     def visit_sequence(self, sequence):
-        if not sequence.optional and self.engine.dialect.has_sequence(self.connection, sequence.name):
+        if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)):
             self.append("DROP SEQUENCE %s" % sequence.name)
             self.execute()
 
index 80d5a7d2af5fb699ad402e1984336e67005ae674..90cd66dd3c8bd88f9ecf2d0bb527b05f8e1dfc2a 100644 (file)
@@ -199,7 +199,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
             colargs= []
             if has_default:
                 colargs.append(PassiveDefault('?'))
-            table.append_item(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
+            table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
         
         if not found_table:
             raise exceptions.NoSuchTableError(table.name)
@@ -228,7 +228,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
             if refspec not in fk[1]:
                 fk[1].append(refspec)
         for name, value in fks.iteritems():
-            table.append_item(schema.ForeignKeyConstraint(value[0], value[1]))    
+            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1]))    
         # check for UNIQUE indexes
         c = connection.execute("PRAGMA index_list(" + table.name + ")", {})
         unique_indexes = []
@@ -250,8 +250,7 @@ class SQLiteDialect(ansisql.ANSIDialect):
                 col = table.columns[row[2]]
             # unique index that includes the pk is considered a multiple primary key
             for col in cols:
-                column = table.columns[col]
-                table.columns[col]._set_primary_key()
+                table.primary_key.add(table.columns[col])
                     
 class SQLiteCompiler(ansisql.ANSICompiler):
     def visit_cast(self, cast):
index 6d0cf2eb36b93cd0fa81b79c997a7f1258a943ae..4ba5e111582e14c47e8c11bd5eb47ec494171455 100644 (file)
@@ -421,7 +421,7 @@ class ComposedSQLEngine(sql.Engine, Connectable):
         else:
             conn = connection
         try:
-            element.accept_schema_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs))
+            element.accept_schema_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs), traverse=False)
         finally:
             if connection is None:
                 conn.close()
index 5afd3e1b672464195bdee0ede339d278035a10b7..462c5e799108f112c5d91c74b000b60f01b7873a 100644 (file)
@@ -366,10 +366,8 @@ class Query(object):
             if not distinct and order_by:
                 s2.order_by(*util.to_list(order_by))
             s3 = s2.alias('tbl_row_count')
-            crit = []
-            for i in range(0, len(self.table.primary_key)):
-                crit.append(s3.primary_key[i] == self.table.primary_key[i])
-            statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True, for_update=for_update)
+            crit = s3.primary_key==self.table.primary_key
+            statement = sql.select([], crit, from_obj=[self.table], use_labels=True, for_update=for_update)
             # now for the order by, convert the columns to their corresponding columns
             # in the "rowcount" query, and tack that new order by onto the "rowcount" query
             if order_by:
index 05753e424fbc40f0a80c1e4857dce600bc9b0fbf..5728d7c375ce38991689bd78c60d912352de9adf 100644 (file)
@@ -19,7 +19,7 @@ import sqlalchemy
 import copy, re, string
 
 __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint',
-            'PrimaryKeyConstraint', 
+            'PrimaryKeyConstraint', 'CheckConstraint', 'UniqueConstraint',
            'MetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
 
 class SchemaItem(object):
@@ -99,36 +99,33 @@ def _get_table_key(name, schema):
 class TableSingleton(type):
     """a metaclass used by the Table object to provide singleton behavior."""
     def __call__(self, name, metadata, *args, **kwargs):
+        if isinstance(metadata, sql.Engine):
+            # backwards compatibility - get a BoundSchema associated with the engine
+            engine = metadata
+            if not hasattr(engine, '_legacy_metadata'):
+                engine._legacy_metadata = BoundMetaData(engine)
+            metadata = engine._legacy_metadata
+        elif metadata is not None and not isinstance(metadata, MetaData):
+            # they left MetaData out, so assume its another SchemaItem, add it to *args
+            args = list(args)
+            args.insert(0, metadata)
+            metadata = None
+            
+        if metadata is None:
+            metadata = default_metadata
+            
+        name = str(name)    # in case of incoming unicode
+        schema = kwargs.get('schema', None)
+        autoload = kwargs.pop('autoload', False)
+        autoload_with = kwargs.pop('autoload_with', False)
+        mustexist = kwargs.pop('mustexist', False)
+        useexisting = kwargs.pop('useexisting', False)
+        key = _get_table_key(name, schema)
         try:
-            if isinstance(metadata, sql.Engine):
-                # backwards compatibility - get a BoundSchema associated with the engine
-                engine = metadata
-                if not hasattr(engine, '_legacy_metadata'):
-                    engine._legacy_metadata = BoundMetaData(engine)
-                metadata = engine._legacy_metadata
-            elif metadata is not None and not isinstance(metadata, MetaData):
-                # they left MetaData out, so assume its another SchemaItem, add it to *args
-                args = list(args)
-                args.insert(0, metadata)
-                metadata = None
-                
-            if metadata is None:
-                metadata = default_metadata
-                
-            name = str(name)    # in case of incoming unicode
-            schema = kwargs.get('schema', None)
-            autoload = kwargs.pop('autoload', False)
-            autoload_with = kwargs.pop('autoload_with', False)
-            redefine = kwargs.pop('redefine', False)
-            mustexist = kwargs.pop('mustexist', False)
-            useexisting = kwargs.pop('useexisting', False)
-            key = _get_table_key(name, schema)
             table = metadata.tables[key]
             if len(args):
-                if redefine:
-                    table._reload_values(*args)
-                elif not useexisting:
-                    raise exceptions.ArgumentError("Table '%s.%s' is already defined. specify 'redefine=True' to remap columns, or 'useexisting=True' to use the existing table" % (schema, name))
+                if not useexisting:
+                    raise exceptions.ArgumentError("Table '%s.%s' is already defined for this MetaData instance." % (schema, name))
             return table
         except KeyError:
             if mustexist:
@@ -145,7 +142,7 @@ class TableSingleton(type):
                     else:
                         metadata.get_engine().reflecttable(table)
                 except exceptions.NoSuchTableError:
-                    table.deregister()
+                    del metadata.tables[key]
                     raise
             # initialize all the column, etc. objects.  done after
             # reflection to allow user-overrides
@@ -210,8 +207,8 @@ class Table(SchemaItem, sql.TableClause):
         super(Table, self).__init__(name)
         self._metadata = metadata
         self.schema = kwargs.pop('schema', None)
-        self.indexes = util.OrderedProperties()
-        self.constraints = []
+        self.indexes = util.Set()
+        self.constraints = util.Set()
         self.primary_key = PrimaryKeyConstraint()
         self.quote = kwargs.get('quote', False)
         self.quote_schema = kwargs.get('quote_schema', False)
@@ -237,7 +234,7 @@ class Table(SchemaItem, sql.TableClause):
         if getattr(self, '_primary_key', None) in self.constraints:
             self.constraints.remove(self._primary_key)
         self._primary_key = pk
-        self.constraints.append(pk)
+        self.constraints.add(pk)
     primary_key = property(lambda s:s._primary_key, _set_primary_key)
     
     def _derived_metadata(self):
@@ -251,93 +248,45 @@ class Table(SchemaItem, sql.TableClause):
     
     def __str__(self):
         return _get_table_key(self.name, self.schema)
-        
-    def _reload_values(self, *args):
-        """clear out the columns and other properties of this Table, and reload them from the 
-        given argument list.  This is used with the "redefine" keyword argument sent to the
-        metaclass constructor."""
-        self._clear()
-        
-        self._init_items(*args)
 
-    def append_item(self, item):
-        """appends a Column item or other schema item to this Table."""
-        self._init_items(item)
-    
     def append_column(self, column):
-        if not column.hidden:
-            self._columns[column.key] = column
-        if column.primary_key:
-            self.primary_key.append(column)
-        column.table = self
+        """append a Column to this Table."""
+        column._set_parent(self)
+    def append_constraint(self, constraint):
+        """append a Constraint to this Table."""
+        constraint._set_parent(self)
 
-    def append_index(self, index):
-        self.indexes[index.name] = index
-    
     def _get_parent(self):
         return self._metadata    
     def _set_parent(self, metadata):
         metadata.tables[_get_table_key(self.name, self.schema)] = self
         self._metadata = metadata
-    def accept_schema_visitor(self, visitor): 
-        """traverses the given visitor across the Column objects inside this Table,
-        then calls the visit_table method on the visitor."""
-        for c in self.columns:
-            c.accept_schema_visitor(visitor)
+    def accept_schema_visitor(self, visitor, traverse=True): 
+        if traverse:
+            for c in self.columns:
+                c.accept_schema_visitor(visitor, True)
         return visitor.visit_table(self)
 
-    def append_index_column(self, column, index=None, unique=None):
-        """Add an index or a column to an existing index of the same name.
-        """
-        if index is not None and unique is not None:
-            raise ValueError("index and unique may not both be specified")
-        if index:
-            if index is True:
-                name = 'ix_%s' % column._label
-            else:
-                name = index
-        elif unique:
-            if unique is True:
-                name = 'ux_%s' % column._label
-            else:
-                name = unique
-        # find this index in self.indexes
-        # add this column to it if found
-        # otherwise create new
-        try:
-            index = self.indexes[name]
-            index.append_column(column)
-        except KeyError:
-            index = Index(name, column, unique=unique)
-        return index
-    
-    def deregister(self):
-        """remove this table from it's owning metadata.  
-        
-        this does not issue a SQL DROP statement."""
-        key = _get_table_key(self.name, self.schema)
-        del self.metadata.tables[key]
-        
-    def exists(self, engine=None):
-        if engine is None:
-            engine = self.get_engine()
+    def exists(self, connectable=None):
+        """return True if this table exists."""
+        if connectable is None:
+            connectable = self.get_engine()
 
         def do(conn):
             e = conn.engine
             return e.dialect.has_table(conn, self.name)
-        return engine.run_callable(do)
+        return connectable.run_callable(do)
 
     def create(self, connectable=None, checkfirst=False):
-        if connectable is not None:
-            connectable.create(self, checkfirst=checkfirst)
-        else:
-            self.get_engine().create(self, checkfirst=checkfirst)
-        return self
+        """issue a CREATE statement for this table.
+        
+        see also metadata.create_all()."""
+        self.metadata.create_all(connectable=connectable, checkfirst=checkfirst, tables=[self])
     def drop(self, connectable=None, checkfirst=False):
-        if connectable is not None:
-            connectable.drop(self, checkfirst=checkfirst)
-        else:
-            self.get_engine().drop(self, checkfirst=checkfirst)
+        """issue a DROP statement for this table.
+        
+        see also metadata.drop_all()."""
+        self.metadata.drop_all(connectable=connectable, checkfirst=checkfirst, tables=[self])
     def tometadata(self, metadata, schema=None):
         """return a copy of this Table associated with a different MetaData."""
         try:
@@ -389,17 +338,16 @@ class Column(SchemaItem, sql.ColumnClause):
         table's list of columns.  Used for the "oid" column, which generally
         isnt in column lists.
 
-        index=None : True or index name. Indicates that this column is
-        indexed. Pass true to autogenerate the index name. Pass a string to
-        specify the index name. Multiple columns that specify the same index
-        name will all be included in the index, in the order of their
-        creation.
+        index=False : Indicates that this column is
+        indexed. The name of the index is autogenerated.
+        To specify indexes with explicit names or indexes that contain multiple 
+        columns, use the Index construct instead.
 
-        unique=None : True or index name. Indicates that this column is
-        indexed in a unique index . Pass true to autogenerate the index
-        name. Pass a string to specify the index name. Multiple columns that
-        specify the same index name will all be included in the index, in the
-        order of their creation.
+        unique=False : Indicates that this column 
+        contains a unique constraint, or if index=True as well, indicates
+        that the Index should be created with the unique flag.
+        To specify multiple columns in the constraint/index or to specify an 
+        explicit name, use the UniqueConstraint or Index constructs instead.
 
         autoincrement=True : Indicates that integer-based primary key columns should have autoincrementing behavior,
         if supported by the underlying database.  This will affect CREATE TABLE statements such that they will
@@ -430,9 +378,8 @@ class Column(SchemaItem, sql.ColumnClause):
         self._set_casing_strategy(name, kwargs)
         self.onupdate = kwargs.pop('onupdate', None)
         self.autoincrement = kwargs.pop('autoincrement', True)
+        self.constraints = util.Set()
         self.__originating_column = self
-        if self.index is not None and self.unique is not None:
-            raise exceptions.ArgumentError("Column may not define both index and unique")
         self._foreign_keys = util.Set()
         if len(kwargs):
             raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys()))
@@ -455,7 +402,10 @@ class Column(SchemaItem, sql.ColumnClause):
         return self.table.metadata
     def _get_engine(self):
         return self.table.engine
-        
+    
+    def append_foreign_key(self, fk):
+        fk._set_parent(self)
+            
     def __repr__(self):
        return "Column(%s)" % string.join(
         [repr(self.name)] + [repr(self.type)] +
@@ -463,33 +413,33 @@ class Column(SchemaItem, sql.ColumnClause):
         ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']]
        , ',')
         
-    def append_item(self, item):
-        self._init_items(item)
-        
-    def _set_primary_key(self):
-        if self.primary_key:
-            return
-        self.primary_key = True
-        self.nullable = False
-        self.table.primary_key.append(self)
-    
     def _get_parent(self):
         return self.table        
+
     def _set_parent(self, table):
         if getattr(self, 'table', None) is not None:
             raise exceptions.ArgumentError("this Column already has a table!")
-        table.append_column(self)
-        if self.index or self.unique:
-            table.append_index_column(self, index=self.index,
-                                      unique=self.unique)
-        
+        if not self.hidden:
+            table._columns.add(self)
+        if self.primary_key:
+            table.primary_key.add(self)
+        self.table = table
+
+        if self.index:
+            if isinstance(self.index, str):
+                raise exceptions.ArgumentError("The 'index' keyword argument on Column is boolean only.  To create indexes with a specific name, append an explicit Index object to the Table's list of elements.")
+            Index('ix_%s' % self._label, self, unique=self.unique)
+        elif self.unique:
+            if isinstance(self.unique, str):
+                raise exceptions.ArgumentError("The 'unique' keyword argument on Column is boolean only.  To create unique constraints or indexes with a specific name, append an explicit UniqueConstraint or Index object to the Table's list of elements.")
+            table.append_constraint(UniqueConstraint(self.key))
+            
+        toinit = list(self.args)
         if self.default is not None:
-            self.default = ColumnDefault(self.default)
-            self._init_items(self.default)
+            toinit.append(ColumnDefault(self.default))
         if self.onupdate is not None:
-            self.onupdate = ColumnDefault(self.onupdate, for_update=True)
-            self._init_items(self.onupdate)
-        self._init_items(*self.args)
+            toinit.append(ColumnDefault(self.onupdate, for_update=True))
+        self._init_items(*toinit)
         self.args = None
 
     def copy(self): 
@@ -507,9 +457,9 @@ class Column(SchemaItem, sql.ColumnClause):
         c.orig_set = self.orig_set
         c.__originating_column = self.__originating_column
         if not c.hidden:
-            selectable.columns[c.key] = c
+            selectable.columns.add(c)
             if self.primary_key:
-                selectable.primary_key.append(c)
+                selectable.primary_key.add(c)
         [c._init_items(f) for f in fk]
         return c
 
@@ -519,15 +469,18 @@ class Column(SchemaItem, sql.ColumnClause):
         return self.__originating_column._get_case_sensitive()
     case_sensitive = property(_case_sens)
     
-    def accept_schema_visitor(self, visitor):
+    def accept_schema_visitor(self, visitor, traverse=True):
         """traverses the given visitor to this Column's default and foreign key object,
         then calls visit_column on the visitor."""
-        if self.default is not None:
-            self.default.accept_schema_visitor(visitor)
-        if self.onupdate is not None:
-            self.onupdate.accept_schema_visitor(visitor)
-        for f in self.foreign_keys:
-            f.accept_schema_visitor(visitor)
+        if traverse:
+            if self.default is not None:
+                self.default.accept_schema_visitor(visitor, traverse=True)
+            if self.onupdate is not None:
+                self.onupdate.accept_schema_visitor(visitor, traverse=True)
+            for f in self.foreign_keys:
+                f.accept_schema_visitor(visitor, traverse=True)
+            for constraint in self.constraints:
+                constraint.accept_schema_visitor(visitor, traverse=True)
         visitor.visit_column(self)
 
 
@@ -538,7 +491,7 @@ class ForeignKey(SchemaItem):
     
     One or more ForeignKey objects are used within a ForeignKeyConstraint
     object which represents the table-level constraint definition."""
-    def __init__(self, column, constraint=None):
+    def __init__(self, column, constraint=None, use_alter=False):
         """Construct a new ForeignKey object.  
         
         "column" can be a schema.Column object representing the relationship, 
@@ -553,6 +506,7 @@ class ForeignKey(SchemaItem):
         self._colspec = column
         self._column = None
         self.constraint = constraint
+        self.use_alter = use_alter
         
     def __repr__(self):
         return "ForeignKey(%s)" % repr(self._get_colspec())
@@ -611,7 +565,7 @@ class ForeignKey(SchemaItem):
             
     column = property(lambda s: s._init_column())
 
-    def accept_schema_visitor(self, visitor):
+    def accept_schema_visitor(self, visitor, traverse=True):
         """calls the visit_foreign_key method on the given visitor."""
         visitor.visit_foreign_key(self)
   
@@ -621,17 +575,13 @@ class ForeignKey(SchemaItem):
         self.parent = column
 
         if self.constraint is None and isinstance(self.parent.table, Table):
-            self.constraint = ForeignKeyConstraint([],[])
-            self.parent.table.append_item(self.constraint)
+            self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter)
+            self.parent.table.append_constraint(self.constraint)
             self.constraint._append_fk(self)
 
-        # if a foreign key was already set up for the parent column, replace it with 
-        # this one
-        #if self.parent.foreign_key is not None:
-        #    self.parent.table.foreign_keys.remove(self.parent.foreign_key)
-        #self.parent.foreign_key = self
         self.parent.foreign_keys.add(self)
         self.parent.table.foreign_keys.add(self)
+
 class DefaultGenerator(SchemaItem):
     """Base class for column "default" values."""
     def __init__(self, for_update=False, metadata=None):
@@ -661,7 +611,7 @@ class PassiveDefault(DefaultGenerator):
     def __init__(self, arg, **kwargs):
         super(PassiveDefault, self).__init__(**kwargs)
         self.arg = arg
-    def accept_schema_visitor(self, visitor):
+    def accept_schema_visitor(self, visitor, traverse=True):
         return visitor.visit_passive_default(self)
     def __repr__(self):
         return "PassiveDefault(%s)" % repr(self.arg)
@@ -672,7 +622,7 @@ class ColumnDefault(DefaultGenerator):
     def __init__(self, arg, **kwargs):
         super(ColumnDefault, self).__init__(**kwargs)
         self.arg = arg
-    def accept_schema_visitor(self, visitor):
+    def accept_schema_visitor(self, visitor, traverse=True):
         """calls the visit_column_default method on the given visitor."""
         if self.for_update:
             return visitor.visit_column_onupdate(self)
@@ -704,57 +654,66 @@ class Sequence(DefaultGenerator):
        return self
     def drop(self):
        self.get_engine().drop(self)
-    def accept_schema_visitor(self, visitor):
+    def accept_schema_visitor(self, visitor, traverse=True):
         """calls the visit_seauence method on the given visitor."""
         return visitor.visit_sequence(self)
 
 class Constraint(SchemaItem):
     """represents a table-level Constraint such as a composite primary key, foreign key, or unique constraint.
     
-    Also follows list behavior with regards to the underlying set of columns."""
+    Implements a hybrid of dict/setlike behavior with regards to the list of underlying columns"""
     def __init__(self, name=None):
         self.name = name
-        self.columns = []
+        self.columns = sql.ColumnCollection()
     def __contains__(self, x):
         return x in self.columns
+    def keys(self):
+        return self.columns.keys()
     def __add__(self, other):
         return self.columns + other
     def __iter__(self):
         return iter(self.columns)
     def __len__(self):
         return len(self.columns)
-    def __getitem__(self, index):
-        return self.columns[index]
-    def __setitem__(self, index, item):
-        self.columns[index] = item
     def copy(self):
         raise NotImplementedError()
     def _get_parent(self):
         return getattr(self, 'table', None)
-        
+
+class CheckConstraint(Constraint):
+    def __init__(self, sqltext, name=None):
+        super(CheckConstraint, self).__init__(name)
+        self.sqltext = sqltext
+    def accept_schema_visitor(self, visitor, traverse=True):
+        visitor.visit_check_constraint(self)
+    def _set_parent(self, parent):
+        self.parent = parent
+        parent.constraints.add(self)
+                
 class ForeignKeyConstraint(Constraint):
     """table-level foreign key constraint, represents a colleciton of ForeignKey objects."""
-    def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None):
+    def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False):
         super(ForeignKeyConstraint, self).__init__(name)
         self.__colnames = columns
         self.__refcolnames = refcolumns
-        self.elements = []
+        self.elements = util.Set()
         self.onupdate = onupdate
         self.ondelete = ondelete
+        self.use_alter = use_alter
     def _set_parent(self, table):
         self.table = table
-        table.constraints.append(self)
+        table.constraints.add(self)
         for (c, r) in zip(self.__colnames, self.__refcolnames):
-            self.append(c,r)
-    def accept_schema_visitor(self, visitor):
+            self.append_element(c,r)
+    def accept_schema_visitor(self, visitor, traverse=True):
         visitor.visit_foreign_key_constraint(self)
-    def append(self, col, refcol):
+    def append_element(self, col, refcol):
         fk = ForeignKey(refcol, constraint=self)
         fk._set_parent(self.table.c[col])
         self._append_fk(fk)
     def _append_fk(self, fk):
-        self.columns.append(self.table.c[fk.parent.key])
-        self.elements.append(fk)
+        self.columns.add(self.table.c[fk.parent.key])
+        self.elements.add(fk)
     def copy(self):
         return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec() for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete)
                         
@@ -766,37 +725,37 @@ class PrimaryKeyConstraint(Constraint):
         self.table = table
         table.primary_key = self
         for c in self.__colnames:
-            self.append(table.c[c])
-    def accept_schema_visitor(self, visitor):
+            self.append_column(table.c[c])
+    def accept_schema_visitor(self, visitor, traverse=True):
         visitor.visit_primary_key_constraint(self)
-    def append(self, col):
-        # TODO: change "columns" to a key-sensitive set ?
-        for c in self.columns:
-            if c.key == col.key:
-                self.columns.remove(c)
-        self.columns.append(col)
+    def add(self, col):
+        self.append_column(col)
+    def append_column(self, col):
+        self.columns.add(col)
         col.primary_key=True
     def copy(self):
         return PrimaryKeyConstraint(name=self.name, *[c.key for c in self])
-            
+    def __eq__(self, other):
+        return self.columns == other
+                
 class UniqueConstraint(Constraint):
-    def __init__(self, name=None, *columns):
-        super(Constraint, self).__init__(name)
+    def __init__(self, *columns, **kwargs):
+        super(UniqueConstraint, self).__init__(name=kwargs.pop('name', None))
         self.__colnames = list(columns)
     def _set_parent(self, table):
         self.table = table
-        table.constraints.append(self)
+        table.constraints.add(self)
         for c in self.__colnames:
-            self.append(table.c[c])
-    def append(self, col):
-        self.columns.append(col)
-    def accept_schema_visitor(self, visitor):
+            self.append_column(table.c[c])
+    def append_column(self, col):
+        self.columns.add(col)
+    def accept_schema_visitor(self, visitor, traverse=True):
         visitor.visit_unique_constraint(self)
         
 class Index(SchemaItem):
     """Represents an index of columns from a database table
     """
-    def __init__(self, name, *columns, **kw):
+    def __init__(self, name, *columns, **kwargs):
         """Constructs an index object. Arguments are:
 
         name : the name of the index
@@ -811,7 +770,7 @@ class Index(SchemaItem):
         self.name = name
         self.columns = []
         self.table = None
-        self.unique = kw.pop('unique', False)
+        self.unique = kwargs.pop('unique', False)
         self._init_items(*columns)
 
     def _derived_metadata(self):
@@ -821,12 +780,15 @@ class Index(SchemaItem):
             self.append_column(column)
     def _get_parent(self):
         return self.table    
+    def _set_parent(self, table):
+        self.table = table
+        table.indexes.add(self)
+
     def append_column(self, column):
         # make sure all columns are from the same table
         # and no column is repeated
         if self.table is None:
-            self.table = column.table
-            self.table.append_index(self)
+            self._set_parent(column.table)
         elif column.table != self.table:
             # all columns muse be from same table
             raise exceptions.ArgumentError("All index columns must be from same table. "
@@ -850,7 +812,7 @@ class Index(SchemaItem):
             connectable.drop(self)
         else:
             self.get_engine().drop(self)
-    def accept_schema_visitor(self, visitor):
+    def accept_schema_visitor(self, visitor, traverse=True):
         visitor.visit_index(self)
     def __str__(self):
         return repr(self)
@@ -863,7 +825,6 @@ class Index(SchemaItem):
 class MetaData(SchemaItem):
     """represents a collection of Tables and their associated schema constructs."""
     def __init__(self, name=None, **kwargs):
-        # a dictionary that stores Table objects keyed off their name (and possibly schema name)
         self.tables = {}
         self.name = name
         self._set_casing_strategy(name, kwargs)
@@ -871,11 +832,18 @@ class MetaData(SchemaItem):
         return False
     def clear(self):
         self.tables.clear()
-    def table_iterator(self, reverse=True):
-        return self._sort_tables(self.tables.values(), reverse=reverse)
+
+    def table_iterator(self, reverse=True, tables=None):
+        import sqlalchemy.sql_util
+        if tables is None:
+            tables = self.tables.values()
+        else:
+            tables = util.Set(tables).intersection(self.tables.values())
+        sorter = sqlalchemy.sql_util.TableCollection(list(tables))
+        return iter(sorter.sort(reverse=reverse))
     def _get_parent(self):
         return None    
-    def create_all(self, connectable=None, tables=None, engine=None):
+    def create_all(self, connectable=None, tables=None, checkfirst=True):
         """create all tables stored in this metadata.
         
         This will conditionally create tables depending on if they do not yet
@@ -884,28 +852,13 @@ class MetaData(SchemaItem):
         connectable - a Connectable used to access the database; or use the engine
         bound to this MetaData.
         
-        tables - optional list of tables to create
-        
-        engine - deprecated argument."""
-        if not tables:
-            tables = self.tables.values()
-
-        if connectable is None:
-            connectable = engine
-            
+        tables - optional list of tables, which is a subset of the total
+        tables in the MetaData (others are ignored)"""
         if connectable is None:
             connectable = self.get_engine()
-
-        def do(conn):
-            e = conn.engine
-            ts = self._sort_tables( tables )
-            for table in ts:
-                if e.dialect.has_table(conn, table.name):
-                    continue
-                conn.create(table)
-        connectable.run_callable(do)
+        connectable.create(self, checkfirst=checkfirst, tables=tables)
         
-    def drop_all(self, connectable=None, tables=None, engine=None):
+    def drop_all(self, connectable=None, tables=None, checkfirst=True):
         """drop all tables stored in this metadata.
         
         This will conditionally drop tables depending on if they currently 
@@ -914,33 +867,17 @@ class MetaData(SchemaItem):
         connectable - a Connectable used to access the database; or use the engine
         bound to this MetaData.
         
-        tables - optional list of tables to drop
-        
-        engine - deprecated argument."""
-        if not tables:
-            tables = self.tables.values()
-
-        if connectable is None:
-            connectable = engine
-
+        tables - optional list of tables, which is a subset of the total
+        tables in the MetaData (others are ignored)
+        """
         if connectable is None:
             connectable = self.get_engine()
-        
-        def do(conn):
-            e = conn.engine
-            ts = self._sort_tables( tables, reverse=True )
-            for table in ts:
-                if e.dialect.has_table(conn, table.name):
-                    conn.drop(table)
-        connectable.run_callable(do)
+        connectable.drop(self, checkfirst=checkfirst, tables=tables)
                 
-    def _sort_tables(self, tables, reverse=False):
-        import sqlalchemy.sql_util
-        sorter = sqlalchemy.sql_util.TableCollection()
-        for t in tables:
-            sorter.add(t)
-        return sorter.sort(reverse=reverse)
-        
+    
+    def accept_schema_visitor(self, visitor, traverse=True):
+        visitor.visit_metadata(self)
+            
     def _derived_metadata(self):
         return self
     def _get_engine(self):
@@ -1029,6 +966,8 @@ class SchemaVisitor(sql.ClauseVisitor):
         pass
     def visit_unique_constraint(self, constraint):
         pass
+    def visit_check_constraint(self, constraint):
+        pass
         
 default_metadata = DynamicMetaData('default')
 
index c113edaa32bded4ded14b66c91e3348065ad2f34..6f51ccbe9940d7b92177411ef5a2797c2b943fb4 100644 (file)
@@ -658,6 +658,17 @@ class ColumnElement(Selectable, CompareMixin):
         else:
             return self
 
+class ColumnCollection(util.OrderedProperties):
+    def add(self, column):
+        self[column.key] = column
+    def __eq__(self, other):
+        l = []
+        for c in other:
+            for local in self:
+                if c.shares_lineage(local):
+                    l.append(c==local)
+        return and_(*l)
+             
 class FromClause(Selectable):
     """represents an element that can be used within the FROM clause of a SELECT statement."""
     def __init__(self, name=None):
@@ -671,7 +682,7 @@ class FromClause(Selectable):
         visitor.visit_fromclause(self)
     def count(self, whereclause=None, **params):
         if len(self.primary_key):
-            col = self.primary_key[0]
+            col = list(self.primary_key)[0]
         else:
             col = list(self.columns)[0]
         return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
@@ -735,8 +746,8 @@ class FromClause(Selectable):
         if hasattr(self, '_columns'):
             # TODO: put a mutex here ?  this is a key place for threading probs
             return
-        self._columns = util.OrderedProperties()
-        self._primary_key = []
+        self._columns = ColumnCollection()
+        self._primary_key = ColumnCollection()
         self._foreign_keys = util.Set()
         self._orig_cols = {}
         export = self._exportable_columns()
@@ -1082,7 +1093,7 @@ class Join(FromClause):
     def _proxy_column(self, column):
         self._columns[column._label] = column
         if column.primary_key:
-            self._primary_key.append(column)
+            self._primary_key.add(column)
         for f in column.foreign_keys:
             self._foreign_keys.add(f)
         return column
@@ -1257,9 +1268,9 @@ class TableClause(FromClause):
     def __init__(self, name, *columns):
         super(TableClause, self).__init__(name)
         self.name = self.fullname = name
-        self._columns = util.OrderedProperties()
+        self._columns = ColumnCollection()
         self._foreign_keys = util.Set()
-        self._primary_key = []
+        self._primary_key = util.Set()
         for c in columns:
             self.append_column(c)
         self._oid_column = ColumnClause('oid', self, hidden=True)
@@ -1282,16 +1293,6 @@ class TableClause(FromClause):
             return self._orig_cols
     original_columns = property(_orig_columns)
 
-    def _clear(self):
-        """clears all attributes on this TableClause so that new items can be added again"""
-        self.columns.clear()
-        self.foreign_keys[:] = []
-        self.primary_key[:] = []
-        try:
-            delattr(self, '_orig_cols')
-        except AttributeError:
-            pass
-
     def accept_visitor(self, visitor):
         visitor.visit_table(self)
     def _exportable_columns(self):
@@ -1305,7 +1306,7 @@ class TableClause(FromClause):
             data[self] = self
     def count(self, whereclause=None, **params):
         if len(self.primary_key):
-            col = self.primary_key[0]
+            col = list(self.primary_key)[0]
         else:
             col = list(self.columns)[0]
         return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params)
index 5f243ae04872d15e933fd7b5a4af5ac9069b5380..d5c6a3b92bca6eb6f6f05af49b019da5aa58ca69 100644 (file)
@@ -87,6 +87,8 @@ class OrderedProperties(object):
         return len(self.__data)
     def __iter__(self):
         return self.__data.itervalues()
+    def __add__(self, other):
+        return list(self) + list(other)
     def __setitem__(self, key, object):
         self.__data[key] = object
     def __getitem__(self, key):
index 45010111e4ad0ac2f207d117ef69071e902e817d..469aab20ebe3a09758d1bce7f26ebbe9470ad333 100644 (file)
@@ -59,8 +59,6 @@ class ReflectionTest(PersistTest):
             mysql_engine='InnoDB'
         )
 
-#        users.c.parent_user_id.set_foreign_key(ForeignKey(users.c.user_id))
-
         users.create()
         addresses.create()
 
@@ -154,6 +152,7 @@ class ReflectionTest(PersistTest):
                 autoload=True)
             u2 = Table('users', meta2, autoload=True)
 
+            print "ITS", list(a2.primary_key)
             assert list(a2.primary_key) == [a2.c.id]
             assert list(u2.primary_key) == [u2.c.id]
             assert u2.join(a2).onclause == u2.c.id==a2.c.id
@@ -226,19 +225,19 @@ class ReflectionTest(PersistTest):
             
     def testmultipk(self):
         """test that creating a table checks for a sequence before creating it"""
+        meta = BoundMetaData(testbase.db)
         table = Table(
-            'engine_multi', testbase.db,
+            'engine_multi', meta,
             Column('multi_id', Integer, Sequence('multi_id_seq'), primary_key=True),
             Column('multi_rev', Integer, Sequence('multi_rev_seq'), primary_key=True),
             Column('name', String(50), nullable=False),
             Column('val', String(100))
         )
         table.create()
-        # clear out table registry
-        table.deregister()
 
+        meta2 = BoundMetaData(testbase.db)
         try:
-            table = Table('engine_multi', testbase.db, autoload=True)
+            table = Table('engine_multi', meta2, autoload=True)
         finally:
             table.drop()
         
@@ -348,19 +347,20 @@ class ReflectionTest(PersistTest):
                           testbase.db, autoload=True)
         
     def testoverride(self):
+        meta = BoundMetaData(testbase.db)
         table = Table(
-            'override_test', testbase.db,
+            'override_test', meta,
             Column('col1', Integer, primary_key=True),
             Column('col2', String(20)),
             Column('col3', Numeric)
         )
         table.create()
         # clear out table registry
-        table.deregister()
 
+        meta2 = BoundMetaData(testbase.db)
         try:
             table = Table(
-                'override_test', testbase.db,
+                'override_test', meta2,
                 Column('col2', Unicode()),
                 Column('col4', String(30)), autoload=True)
         
@@ -403,22 +403,22 @@ class CreateDropTest(PersistTest):
         )
 
     def test_sorter( self ):
-        tables = metadata._sort_tables(metadata.tables.values())
+        tables = metadata.table_iterator(reverse=False)
         table_names = [t.name for t in tables]
         self.assert_( table_names == ['users', 'orders', 'items', 'email_addresses'] or table_names ==  ['users', 'email_addresses', 'orders', 'items'])
 
 
     def test_createdrop(self):
-        metadata.create_all(engine=testbase.db)
+        metadata.create_all(connectable=testbase.db)
         self.assertEqual( testbase.db.has_table('items'), True )
         self.assertEqual( testbase.db.has_table('email_addresses'), True )        
-        metadata.create_all(engine=testbase.db)
+        metadata.create_all(connectable=testbase.db)
         self.assertEqual( testbase.db.has_table('items'), True )        
 
-        metadata.drop_all(engine=testbase.db)
+        metadata.drop_all(connectable=testbase.db)
         self.assertEqual( testbase.db.has_table('items'), False )
         self.assertEqual( testbase.db.has_table('email_addresses'), False )                
-        metadata.drop_all(engine=testbase.db)
+        metadata.drop_all(connectable=testbase.db)
         self.assertEqual( testbase.db.has_table('items'), False )                
 
 class SchemaTest(PersistTest):
@@ -438,7 +438,7 @@ class SchemaTest(PersistTest):
         buf = StringIO.StringIO()
         def foo(s, p):
             buf.write(s)
-        gen = testbase.db.dialect.schemagenerator(testbase.db.engine, foo)
+        gen = testbase.db.dialect.schemagenerator(testbase.db.engine, foo, None)
         table1.accept_schema_visitor(gen)
         table2.accept_schema_visitor(gen)
         buf = buf.getvalue()
index 63eb5b0f6fbb412d6411d4276f799a0b13c8cfc3..eebe7af7550f966db2f8fe3882c1d11540d50d37 100644 (file)
@@ -109,7 +109,7 @@ class BiDirectionalOneToManyTest(AssertMixin):
             Column('c2', Integer)
         )
         metadata.create_all()
-        t2.c.c2.append_item(ForeignKey('t1.c1'))
+        t2.c.c2.append_foreign_key(ForeignKey('t1.c1'))
     def tearDownAll(self):
         t1.drop()
         t2.drop()
@@ -153,7 +153,7 @@ class BiDirectionalOneToManyTest2(AssertMixin):
         )
         t2.create()
         t1.create()
-        t2.c.c2.append_item(ForeignKey('t1.c1'))
+        t2.c.c2.append_foreign_key(ForeignKey('t1.c1'))
         t3 = Table('t1_data', metadata, 
             Column('c1', Integer, primary_key=True),
             Column('t1id', Integer, ForeignKey('t1.c1')),
@@ -225,8 +225,7 @@ class OneToManyManyToOneTest(AssertMixin):
 
         ball.create()
         person.create()
-#        person.c.favorite_ball_id.append_item(ForeignKey('ball.id'))
-        ball.c.person_id.append_item(ForeignKey('person.id'))
+        ball.c.person_id.append_foreign_key(ForeignKey('person.id'))
         
         # make the test more complete for postgres
         if db.engine.__module__.endswith('postgres'):
index ce9a35479ca9d6297d9ec92d52e324939382bede..392e54407f5ecc98ea509d765d25f368afe4fdf0 100644 (file)
@@ -96,16 +96,16 @@ class InheritTest2(testbase.AssertMixin):
         foo = Table('foo', metadata,
             Column('id', Integer, Sequence('foo_id_seq'), primary_key=True),
             Column('data', String(20)),
-            ).create()
+            )
 
         bar = Table('bar', metadata,
             Column('bid', Integer, ForeignKey('foo.id'), primary_key=True),
             #Column('fid', Integer, ForeignKey('foo.id'), )
-            ).create()
+            )
 
         foo_bar = Table('foo_bar', metadata,
             Column('foo_id', Integer, ForeignKey('foo.id')),
-            Column('bar_id', Integer, ForeignKey('bar.bid'))).create()
+            Column('bar_id', Integer, ForeignKey('bar.bid')))
         metadata.create_all()
     def tearDownAll(self):
         metadata.drop_all()
index 3966041b5c145e13f34fc9d2b945c20348d1240f..dc343cb95d4238d02990948857b9756cc2301f58 100644 (file)
@@ -28,7 +28,7 @@ class Transition(object):
         
 class M2MTest(testbase.AssertMixin):
     def setUpAll(self):
-        self.install_threadlocal()
+        global metadata
         metadata = testbase.metadata
         global place
         place = Table('place', metadata,
@@ -68,28 +68,14 @@ class M2MTest(testbase.AssertMixin):
             Column('pl1_id', Integer, ForeignKey('place.place_id')),
             Column('pl2_id', Integer, ForeignKey('place.place_id')),
             )
-
-        place.create()
-        transition.create()
-        place_input.create()
-        place_output.create()
-        place_thingy.create()
-        place_place.create()
+        metadata.create_all()
 
     def tearDownAll(self):
-        place_place.drop()
-        place_input.drop()
-        place_output.drop()
-        place_thingy.drop()
-        place.drop()
-        transition.drop()
-        objectstore.clear()
+        metadata.drop_all()
         clear_mappers()
         #testbase.db.tables.clear()
-        self.uninstall_threadlocal()
         
     def setUp(self):
-        objectstore.clear()
         clear_mappers()
 
     def tearDown(self):
@@ -111,6 +97,7 @@ class M2MTest(testbase.AssertMixin):
             lazy=True,
             ))
 
+        sess = create_session()
         p1 = Place('place1')
         p2 = Place('place2')
         p3 = Place('place3')
@@ -118,7 +105,7 @@ class M2MTest(testbase.AssertMixin):
         p5 = Place('place5')
         p6 = Place('place6')
         p7 = Place('place7')
-
+        [sess.save(x) for x in [p1,p2,p3,p4,p5,p6,p7]]
         p1.places.append(p2)
         p1.places.append(p3)
         p5.places.append(p6)
@@ -127,10 +114,10 @@ class M2MTest(testbase.AssertMixin):
         p1.places.append(p5)
         p4.places.append(p3)
         p3.places.append(p4)
-        objectstore.flush()
+        sess.flush()
 
-        objectstore.clear()
-        l = Place.mapper.select(order_by=place.c.place_id)
+        sess.clear()
+        l = sess.query(Place).select(order_by=place.c.place_id)
         (p1, p2, p3, p4, p5, p6, p7) = l
         assert p1.places == [p2,p3,p5]
         assert p5.places == [p6]
@@ -144,8 +131,8 @@ class M2MTest(testbase.AssertMixin):
             pp = p.places
             self.echo("Place " + str(p) +" places " + repr(pp))
 
-        [objectstore.delete(p) for p in p1,p2,p3,p4,p5,p6,p7]
-        objectstore.flush()
+        [sess.delete(p) for p in p1,p2,p3,p4,p5,p6,p7]
+        sess.flush()
 
     def testdouble(self):
         """tests that a mapper can have two eager relations to the same table, via
@@ -165,10 +152,12 @@ class M2MTest(testbase.AssertMixin):
         tran.inputs.append(Place('place1'))
         tran.outputs.append(Place('place2'))
         tran.outputs.append(Place('place3'))
-        objectstore.flush()
+        sess = create_session()
+        sess.save(tran)
+        sess.flush()
 
-        objectstore.clear()
-        r = Transition.mapper.select()
+        sess.clear()
+        r = sess.query(Transition).select()
         self.assert_result(r, Transition, 
             {'name':'transition1', 
             'inputs' : (Place, [{'name':'place1'}]),
@@ -199,15 +188,15 @@ class M2MTest(testbase.AssertMixin):
         p2.inputs.append(t2)
         p3.inputs.append(t2)
         p1.outputs.append(t1)
-        
-        objectstore.flush()
+        sess = create_session()
+        [sess.save(x) for x in [t1,t2,t3,p1,p2,p3]]
+        sess.flush()
         
         self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])})
         self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])})
 
 class M2MTest2(testbase.AssertMixin):        
     def setUpAll(self):
-        self.install_threadlocal()
         metadata = testbase.metadata
         global studentTbl
         studentTbl = Table('student', metadata, Column('name', String(20), primary_key=True))
@@ -217,22 +206,13 @@ class M2MTest2(testbase.AssertMixin):
         enrolTbl = Table('enrol', metadata,
             Column('student_id', String(20), ForeignKey('student.name'),primary_key=True),
             Column('course_id', String(20), ForeignKey('course.name'), primary_key=True))
-
-        studentTbl.create()
-        courseTbl.create()
-        enrolTbl.create()
+        metadata.create_all()
 
     def tearDownAll(self):
-        enrolTbl.drop()
-        studentTbl.drop()
-        courseTbl.drop()
-        objectstore.clear()
+        metadata.drop_all()
         clear_mappers()
-        #testbase.db.tables.clear()
-        self.uninstall_threadlocal()
         
     def setUp(self):
-        objectstore.clear()
         clear_mappers()
 
     def tearDown(self):
@@ -251,6 +231,7 @@ class M2MTest2(testbase.AssertMixin):
         Course.mapper = mapper(Course, courseTbl, properties = {
             'students': relation(Student.mapper, enrolTbl, lazy=True, backref='courses')
         })
+        sess = create_session()
         s1 = Student('Student1')
         c1 = Course('Course1')
         c2 = Course('Course2')
@@ -260,55 +241,53 @@ class M2MTest2(testbase.AssertMixin):
         c3.students.append(s1)
         self.assert_(len(s1.courses) == 3)
         self.assert_(len(c1.students) == 1)
-        objectstore.flush()
-        objectstore.clear()
-        s = Student.mapper.get_by(name='Student1')
-        c = Course.mapper.get_by(name='Course3')
+        sess.save(s1)
+        sess.flush()
+        sess.clear()
+        s = sess.query(Student).get_by(name='Student1')
+        c = sess.query(Course).get_by(name='Course3')
         self.assert_(len(s.courses) == 3)
         del s.courses[1]
         self.assert_(len(s.courses) == 2)
         
 class M2MTest3(testbase.AssertMixin):    
     def setUpAll(self):
-        self.install_threadlocal()
         metadata = testbase.metadata
         global c, c2a1, c2a2, b, a
         c = Table('c', metadata, 
             Column('c1', Integer, primary_key = True),
             Column('c2', String(20)),
-        ).create()
+        )
 
         a = Table('a', metadata, 
             Column('a1', Integer, primary_key=True),
             Column('a2', String(20)),
             Column('c1', Integer, ForeignKey('c.c1'))
-            ).create()
+            )
 
         c2a1 = Table('ctoaone', metadata, 
             Column('c1', Integer, ForeignKey('c.c1')),
             Column('a1', Integer, ForeignKey('a.a1'))
-        ).create()
+        )
         c2a2 = Table('ctoatwo', metadata, 
             Column('c1', Integer, ForeignKey('c.c1')),
             Column('a1', Integer, ForeignKey('a.a1'))
-        ).create()
+        )
 
         b = Table('b', metadata, 
             Column('b1', Integer, primary_key=True),
             Column('a1', Integer, ForeignKey('a.a1')),
             Column('b2', Boolean)
-        ).create()
-
+        )
+        metadata.create_all()
+        
     def tearDownAll(self):
         b.drop()
         c2a2.drop()
         c2a1.drop()
         a.drop()
         c.drop()
-        objectstore.clear()
         clear_mappers()
-        #testbase.db.tables.clear()
-        self.uninstall_threadlocal()
         
     def testbasic(self):
         class C(object):pass
index 63d0904287130acf18a6ab453aea95074c61e646..6cf2b4b49431aa8d29b1ee4d4d2e640cdccb1e2f 100644 (file)
@@ -91,7 +91,8 @@ class VersioningTest(UnitOfWorkTest):
         Column('id', Integer, Sequence('version_test_seq'), primary_key=True ),
         Column('version_id', Integer, nullable=False),
         Column('value', String(40), nullable=False)
-        ).create()
+        )
+        version_table.create()
     def tearDownAll(self):
         version_table.drop()
         UnitOfWorkTest.tearDownAll(self)
@@ -408,12 +409,14 @@ class PrivateAttrTest(UnitOfWorkTest):
         a_table = Table('a',testbase.db,
             Column('a_id', Integer, Sequence('next_a_id'), primary_key=True),
             Column('data', String(10)),
-            ).create()
+            )
     
         b_table = Table('b',testbase.db,
             Column('b_id',Integer,Sequence('next_b_id'),primary_key=True),
             Column('a_id',Integer,ForeignKey('a.a_id')),
-            Column('data',String(10))).create()
+            Column('data',String(10)))
+        a_table.create()
+        b_table.create()
     def tearDownAll(self):
         b_table.drop()
         a_table.drop()
index 29b638bb8e44ab9ec72773d2f96841b456933947..c79d7b67e881172c39acb9cefbca4852485f44a8 100644 (file)
@@ -5,7 +5,7 @@ import unittest
 def suite():
     modules_to_test = (
         'sql.testtypes',
-        'sql.indexes',
+        'sql.constraints',
 
         # SQL syntax
         'sql.select',
similarity index 60%
rename from test/sql/indexes.py
rename to test/sql/constraints.py
index 5c46b63f2cc63d82d79fb46e76b04aa1e4ba14c9..045d44968793b0d0cfeaf42ead38af6c3de3088d 100644 (file)
@@ -2,7 +2,7 @@ import testbase
 from sqlalchemy import *
 import sys
 
-class IndexTest(testbase.AssertMixin):
+class ConstraintTest(testbase.AssertMixin):
     
     def setUp(self):
         global metadata
@@ -27,6 +27,59 @@ class IndexTest(testbase.AssertMixin):
             ForeignKeyConstraint(['emp_id', 'emp_soc'], ['employees.id', 'employees.soc'])
             )
         metadata.create_all()
+
+    @testbase.unsupported('sqlite', 'mysql')
+    def test_check_constraint(self):
+        foo = Table('foo', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('x', Integer),
+            Column('y', Integer),
+            CheckConstraint('x>y'))
+        bar = Table('bar', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('x', Integer, CheckConstraint('x>7')),
+            )
+
+        metadata.create_all()
+        foo.insert().execute(id=1,x=9,y=5)
+        try:
+            foo.insert().execute(id=2,x=5,y=9)
+            assert False
+        except exceptions.SQLError:
+            assert True
+
+        bar.insert().execute(id=1,x=10)
+        try:
+            bar.insert().execute(id=2,x=5)
+            assert False
+        except exceptions.SQLError:
+            assert True
+    
+    def test_unique_constraint(self):
+        foo = Table('foo', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('value', String(30), unique=True))
+        bar = Table('bar', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('value', String(30)),
+            Column('value2', String(30)),
+            UniqueConstraint('value', 'value2', name='uix1')
+            )
+        metadata.create_all()
+        foo.insert().execute(id=1, value='value1')
+        foo.insert().execute(id=2, value='value2')
+        bar.insert().execute(id=1, value='a', value2='a')
+        bar.insert().execute(id=2, value='a', value2='b')
+        try:
+            foo.insert().execute(id=3, value='value1')
+            assert False
+        except exceptions.SQLError:
+            assert True
+        try:
+            bar.insert().execute(id=3, value='a', value2='b')
+            assert False
+        except exceptions.SQLError:
+            assert True
         
     def test_index_create(self):
         employees = Table('employees', metadata,
@@ -39,12 +92,12 @@ class IndexTest(testbase.AssertMixin):
         i = Index('employee_name_index',
                   employees.c.last_name, employees.c.first_name)
         i.create()
-        assert employees.indexes['employee_name_index'] is i
+        assert i in employees.indexes
         
         i2 = Index('employee_email_index',
                    employees.c.email_address, unique=True)        
         i2.create()
-        assert employees.indexes['employee_email_index'] is i2
+        assert i2 in employees.indexes
 
     def test_index_create_camelcase(self):
         """test that mixed-case index identifiers are legal"""
@@ -76,16 +129,17 @@ class IndexTest(testbase.AssertMixin):
 
         events = Table('events', metadata,
                        Column('id', Integer, primary_key=True),
-                       Column('name', String(30), unique=True),
+                       Column('name', String(30), index=True, unique=True),
                        Column('location', String(30), index=True),
-                       Column('sport', String(30),
-                              unique='sport_announcer'),
-                       Column('announcer', String(30),
-                              unique='sport_announcer'),
-                       Column('winner', String(30), index='idx_winners'))
+                       Column('sport', String(30)),
+                       Column('announcer', String(30)),
+                       Column('winner', String(30)))
+
+        Index('sport_announcer', events.c.sport, events.c.announcer, unique=True)
+        Index('idx_winners', events.c.winner)
         
         index_names = [ ix.name for ix in events.indexes ]
-        assert 'ux_events_name' in index_names
+        assert 'ix_events_name' in index_names
         assert 'ix_events_location' in index_names
         assert 'sport_announcer' in index_names
         assert 'idx_winners' in index_names
@@ -97,19 +151,20 @@ class IndexTest(testbase.AssertMixin):
             capt.append(statement)
             capt.append(repr(parameters))
             connection.proxy(statement, parameters)
-        schemagen = testbase.db.dialect.schemagenerator(testbase.db, proxy)
+        schemagen = testbase.db.dialect.schemagenerator(testbase.db, proxy, connection)
         events.accept_schema_visitor(schemagen)
         
         assert capt[0].strip().startswith('CREATE TABLE events')
-        assert capt[2].strip() == \
-            'CREATE UNIQUE INDEX ux_events_name ON events (name)'
-        assert capt[4].strip() == \
-            'CREATE INDEX ix_events_location ON events (location)'
-        assert capt[6].strip() == \
-            'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)'
-        assert capt[8].strip() == \
+        
+        s = set([capt[x].strip() for x in [2,4,6,8]])
+        
+        assert s == set([
+            'CREATE UNIQUE INDEX ix_events_name ON events (name)',
+            'CREATE INDEX ix_events_location ON events (location)',
+            'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)',
             'CREATE INDEX idx_winners ON events (winner)'
-
+            ])
+            
         # verify that the table is functional
         events.insert().execute(id=1, name='hockey finals', location='rink',
                                 sport='hockey', announcer='some canadian',
index e08bdb89f1cb6e09e9481c2a1b76855a1b76fed0..ef851cf63005f1cfba2e9a98259eb1dc28dde8f0 100644 (file)
@@ -121,7 +121,7 @@ class ColumnsTest(AssertMixin):
         )
 
         for aCol in testTable.c:
-            self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db, None).get_column_specification(aCol))
+            self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db, None, None).get_column_specification(aCol))
         
 class UnicodeTest(AssertMixin):
     """tests the Unicode type.  also tests the TypeDecorator with instances in the types package."""
index 7cec19590125af5249a0093ce72ed1dfbaf47fdd..e538cff9d832b7340c42f35f0367fac9330280fc 100644 (file)
@@ -12,9 +12,9 @@ from zblog.blog import *
 class ZBlogTest(AssertMixin):
 
     def create_tables(self):
-        tables.metadata.create_all(engine=db)
+        tables.metadata.create_all(connectable=db)
     def drop_tables(self):
-        tables.metadata.drop_all(engine=db)
+        tables.metadata.drop_all(connectable=db)
         
     def setUpAll(self):
         self.create_tables()