]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
overhaul to schema, addition of ForeignKeyConstraint/
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 14 Jul 2006 20:06:09 +0000 (20:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 14 Jul 2006 20:06:09 +0000 (20:06 +0000)
PrimaryKeyConstraint objects (also UniqueConstraint not
completed yet).  table creation and reflection modified
to be more oriented towards these new table-level objects.
reflection for sqlite/postgres/mysql supports composite
foreign keys; oracle/mssql/firebird not converted yet.

14 files changed:
CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/information_schema.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/engine/reflection.py
test/orm/objectstore.py
test/sql/indexes.py

diff --git a/CHANGES b/CHANGES
index 28b3316fc939f70213a1dc54fead385ead7bd1f2..810f2ec9a953e7def0439de71b79bb776a839f38 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,4 +1,11 @@
 0.2.6
+- big overhaul to schema to allow truly composite primary and foreign
+key constraints, via new ForeignKeyConstraint and PrimaryKeyConstraint
+objects.
+Existing methods of primary/foreign key creation have not been changed
+but use these new objects behind the scenes.  table creation
+and reflection is now more table oriented rather than column oriented.
+[ticket:76]
 - tweaks to ActiveMapper, supports self-referential relationships
 - slight rearrangement to objectstore (in activemapper/threadlocal)
 so that the SessionContext is referenced by '.context' instead
index 78017bc9191042d261b4d08c4c5bca75ca18bd62..5d01e275cbf5b344887ea29d794b7065010472c3 100644 (file)
@@ -602,7 +602,7 @@ class ANSICompiler(sql.Compiled):
 
 
 class ANSISchemaGenerator(engine.SchemaIterator):
-    def get_column_specification(self, column, override_pk=False, first_pk=False):
+    def get_column_specification(self, column, first_pk=False):
         raise NotImplementedError()
         
     def visit_table(self, table):
@@ -614,19 +614,17 @@ class ANSISchemaGenerator(engine.SchemaIterator):
         separator = "\n"
         
         # if only one primary key, specify it along with the column
-        pks = table.primary_key
         first_pk = False
         for column in table.columns:
             self.append(separator)
             separator = ", \n"
-            self.append("\t" + self.get_column_specification(column, override_pk=len(pks)>1, first_pk=column.primary_key and not first_pk))
+            self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk))
             if column.primary_key:
                 first_pk = True
-        # if multiple primary keys, specify it at the bottom
-        if len(pks) > 1:
-            self.append(", \n")
-            self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in pks],', '))
-                    
+
+        for constraint in table.constraints:
+            constraint.accept_schema_visitor(self)            
+
         self.append("\n)%s\n\n" % self.post_create_table(table))
         self.execute()        
         if hasattr(table, 'indexes'):
@@ -650,6 +648,26 @@ class ANSISchemaGenerator(engine.SchemaIterator):
         compiler.compile()
         return compiler
 
+    def visit_primary_key_constraint(self, constraint):
+        if len(constraint) == 0:
+            return
+        self.append(", \n")
+        self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in constraint],', '))
+            
+    def visit_foreign_key_constraint(self, constraint):
+        self.append(", \n\t ")
+        if constraint.name is not None:
+            self.append("CONSTRAINT %s " % constraint.name)
+        self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
+            string.join([f.parent.name for f in constraint.elements], ', '),
+            list(constraint.elements)[0].column.table.name,
+            string.join([f.column.name for f in constraint.elements], ', ')
+        ))
+        if constraint.ondelete is not None:
+            self.append(" ON DELETE %s" % constraint.ondelete)
+        if constraint.onupdate is not None:
+            self.append(" ON UPDATE %s" % constraint.onupdate)
+
     def visit_column(self, column):
         pass
 
index 0039333d51205d0103dd8afd006c24c5f545d138..085d8cf444c91e150cfb34d975d4b33074a39718 100644 (file)
@@ -293,7 +293,7 @@ class FBCompiler(ansisql.ANSICompiler):
         return ""
 
 class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
-    def get_column_specification(self, column, override_pk=False, **kwargs):
+    def get_column_specification(self, column, **kwargs):
         colspec = column.name 
         colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
         default = self.get_column_default_string(column)
@@ -302,10 +302,6 @@ class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
 
         if not column.nullable:
             colspec += " NOT NULL"
-        if column.primary_key and not override_pk:
-            colspec += " PRIMARY KEY"
-        if column.foreign_key:
-            colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) 
         return colspec
 
     def visit_sequence(self, sequence):
index 08236f7991d1771687c9d274582e7161f3fc3f6f..296db2de5779cfe268c12d5822c42aba8fda7417 100644 (file)
@@ -54,6 +54,7 @@ pg_key_constraints = schema.Table("key_column_usage", ischema,
     Column("table_name", String),
     Column("column_name", String),
     Column("constraint_name", String),
+    Column("ordinal_position", Integer),
     schema="information_schema")
 
 #mysql_key_constraints = schema.Table("key_column_usage", ischema,
@@ -100,13 +101,9 @@ class ISchema(object):
         return self.cache[name]
 
 
-def reflecttable(connection, table, ischema_names, use_mysql=False):
+def reflecttable(connection, table, ischema_names):
     
-    if use_mysql:
-        # no idea which INFORMATION_SCHEMA spec is correct, mysql or postgres
-        key_constraints = mysql_key_constraints
-    else:
-        key_constraints = pg_key_constraints
+    key_constraints = pg_key_constraints
         
     if table.schema is not None:
         current_schema = table.schema
@@ -152,39 +149,50 @@ def reflecttable(connection, table, ischema_names, use_mysql=False):
     if not found_table:
         raise exceptions.NoSuchTableError(table.name)
 
-    s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)])
-    if not use_mysql:
-        s.append_column(column_constraints)
-        s.append_whereclause(constraints.c.table_name==table.name)
-        s.append_whereclause(constraints.c.table_schema==current_schema)
-        colmap = [constraints.c.constraint_type, key_constraints.c.column_name, column_constraints.c.table_schema, column_constraints.c.table_name, column_constraints.c.column_name]
-    else:
-        # this doesnt seem to pick up any foreign keys with mysql
-        s.append_whereclause(key_constraints.c.table_name==constraints.c.table_name)
-        s.append_whereclause(key_constraints.c.table_schema==constraints.c.table_schema)
-        s.append_whereclause(constraints.c.table_name==table.name)
-        s.append_whereclause(constraints.c.table_schema==current_schema)
-        colmap = [constraints.c.constraint_type, key_constraints.c.column_name, key_constraints.c.referenced_table_schema, key_constraints.c.referenced_table_name, key_constraints.c.referenced_column_name]
+    # we are relying on the natural ordering of the constraint_column_usage table to return the referenced columns
+    # in an order that corresponds to the ordinal_position in the key_constraints table, otherwise composite foreign keys
+    # wont reflect properly.  dont see a way around this based on whats available from information_schema
+    s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)], order_by=[key_constraints.c.ordinal_position])
+    s.append_column(column_constraints)
+    s.append_whereclause(constraints.c.table_name==table.name)
+    s.append_whereclause(constraints.c.table_schema==current_schema)
+    colmap = [constraints.c.constraint_type, key_constraints.c.column_name, column_constraints.c.table_schema, column_constraints.c.table_name, column_constraints.c.column_name, constraints.c.constraint_name, key_constraints.c.ordinal_position]
     c = connection.execute(s)
 
+    fks = {}
     while True:
         row = c.fetchone()
         if row is None:
             break
-#        continue
-        (type, constrained_column, referred_schema, referred_table, referred_column) = (
+        (type, constrained_column, referred_schema, referred_table, referred_column, constraint_name, ordinal_position) = (
             row[colmap[0]],
             row[colmap[1]],
             row[colmap[2]],
             row[colmap[3]],
-            row[colmap[4]]
+            row[colmap[4]],
+            row[colmap[5]],
+            row[colmap[6]]
         )
         #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) 
         if type=='PRIMARY KEY':
             table.c[constrained_column]._set_primary_key()
         elif type=='FOREIGN KEY':
+            try:
+                fk = fks[constraint_name]
+            except KeyError:
+                fk = ([],[])
+                fks[constraint_name] = fk
             if current_schema == referred_schema:
                 referred_schema = table.schema
-            remotetable = Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema)
-            table.c[constrained_column].append_item(schema.ForeignKey(remotetable.c[referred_column]))
+            if referred_schema is not None:
+                refspec = ".".join([referred_schema, referred_table, referred_column])
+            else:
+                refspec = ".".join([referred_table, referred_column])
+            if constrained_column not in fk[0]:
+                fk[0].append(constrained_column)
+            if refspec not in fk[1]:
+                fk[1].append(refspec)
+    
+    for name, value in fks.iteritems():
+        table.append_item(ForeignKeyConstraint(value[0], value[1], name=name))    
             
index c297195caff6fd3fc8e6935a366d40f756788fb8..9d51d535dad38ce507251321457c24b62dbf3702 100644 (file)
@@ -511,7 +511,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
 
         
 class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
-    def get_column_specification(self, column, override_pk=False, first_pk=False):
+    def get_column_specification(self, column, **kwargs):
         colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
 
         # install a IDENTITY Sequence if we have an implicit IDENTITY column
@@ -528,12 +528,6 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
             default = self.get_column_default_string(column)
             if default is not None:
                 colspec += " DEFAULT " + default
-
-        if column.primary_key:
-            if not override_pk:
-                colspec += " PRIMARY KEY"
-        if column.foreign_key:
-            colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.fullname, column.foreign_key.column.name)
         
         return colspec
 
index 997010f1c233bcdce0275b3e3d6a79b770578836..1d587ff7c511b97b1e16ecc82e4afbe26f845163 100644 (file)
@@ -309,8 +309,6 @@ class MySQLDialect(ansisql.ANSIDialect):
                 break
             #print "row! " + repr(row)
             if not found_table:
-                tabletype, foreignkeyD = self.moretableinfo(connection, table=table)
-                table.kwargs['mysql_engine'] = tabletype
                 found_table = True
 
             (name, type, nullable, primary_key, default) = (row[0], row[1], row[2] == 'YES', row[3] == 'PRI', row[4])
@@ -338,16 +336,15 @@ class MySQLDialect(ansisql.ANSIDialect):
                     argslist = re.findall(r'(\d+)', args)
                     coltype = coltype(*[int(a) for a in argslist], **kw)
             
-            arglist = []
-            fkey = foreignkeyD.get(name)
-            if fkey is not None:
-                arglist.append(schema.ForeignKey(fkey))
-    
-            table.append_item(schema.Column(name, coltype, *arglist,
+            table.append_item(schema.Column(name, coltype, 
                                             **dict(primary_key=primary_key,
                                                    nullable=nullable,
                                                    default=default
                                                    )))
+
+        tabletype = self.moretableinfo(connection, table=table)
+        table.kwargs['mysql_engine'] = tabletype
+
         if not found_table:
             raise exceptions.NoSuchTableError(table.name)
     
@@ -368,15 +365,15 @@ class MySQLDialect(ansisql.ANSIDialect):
             match = re.search(r'\b(?:TYPE|ENGINE)=(?P<ttype>.+)\b', desc[lastparen.start():], re.I)
             if match:
                 tabletype = match.group('ttype')
-        foreignkeyD = {}
-        fkpat = (r'FOREIGN KEY\s*\(`?(?P<name>.+?)`?\)'
-                 r'\s*REFERENCES\s*`?(?P<reftable>.+?)`?'
-                 r'\s*\(`?(?P<refcol>.+?)`?\)'
-                )
+
+        fkpat = r'CONSTRAINT `(?P<name>.+?)` FOREIGN KEY \((?P<columns>.+?)\) REFERENCES `(?P<reftable>.+?)` \((?P<refcols>.+?)\)'
         for match in re.finditer(fkpat, desc):
-            foreignkeyD[match.group('name')] = match.group('reftable') + '.' + match.group('refcol')
+            columns = re.findall(r'`(.+?)`', match.group('columns'))
+            refcols = [match.group('reftable') + "." + x for x in re.findall(r'`(.+?)`', match.group('refcols'))]
+            constraint = schema.ForeignKeyConstraint(columns, refcols, name=match.group('name'))
+            table.append_item(constraint)
 
-        return (tabletype, foreignkeyD)
+        return tabletype
         
 
 class MySQLCompiler(ansisql.ANSICompiler):
@@ -411,12 +408,8 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
         if not column.nullable:
             colspec += " NOT NULL"
         if column.primary_key:
-            if not override_pk:
-                colspec += " PRIMARY KEY"
             if not column.foreign_key and first_pk and isinstance(column.type, sqltypes.Integer):
                 colspec += " AUTO_INCREMENT"
-        if column.foreign_key:
-            colspec += ", FOREIGN KEY (%s) REFERENCES %s(%s)" % (column.name, column.foreign_key.column.table.name, column.foreign_key.column.name) 
         return colspec
 
     def post_create_table(self, table):
index bf6c1fd8d4b3c304c20af4a477ab5d5b94fdd37e..d184291fd532d6ff371c9e6d34bf7116e56ad544 100644 (file)
@@ -320,7 +320,7 @@ class OracleCompiler(ansisql.ANSICompiler):
         return ""
 
 class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
-    def get_column_specification(self, column, override_pk=False, **kwargs):
+    def get_column_specification(self, column, **kwargs):
         colspec = column.name
         colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
         default = self.get_column_default_string(column)
@@ -329,10 +329,6 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
 
         if not column.nullable:
             colspec += " NOT NULL"
-        if column.primary_key and not override_pk:
-            colspec += " PRIMARY KEY"
-        if column.foreign_key:
-            colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) 
         return colspec
 
     def visit_sequence(self, sequence):
index de21bd570dd6cb6a6d4b3caa33699d38db81a57a..decccba58d1bd4230c25ac3efe495288e8bc4cd4 100644 (file)
@@ -329,7 +329,7 @@ class PGCompiler(ansisql.ANSICompiler):
         
 class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
         
-    def get_column_specification(self, column, override_pk=False, **kwargs):
+    def get_column_specification(self, column, **kwargs):
         colspec = column.name
         if column.primary_key and not column.foreign_key and isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
             colspec += " SERIAL"
@@ -341,10 +341,6 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
 
         if not column.nullable:
             colspec += " NOT NULL"
-        if column.primary_key and not override_pk:
-            colspec += " PRIMARY KEY"
-        if column.foreign_key:
-            colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.fullname, column.foreign_key.column.name) 
         return colspec
 
     def visit_sequence(self, sequence):
index c07952ff215c0e7d5e99f92ab73f3e9578aed95d..c703cd81eb97ca92f01f774c83493c6221c2be19 100644 (file)
@@ -257,7 +257,7 @@ class SQLiteCompiler(ansisql.ANSICompiler):
             return ansisql.ANSICompiler.binary_operator_string(self, binary)
 
 class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
-    def get_column_specification(self, column, override_pk=False, **kwargs):
+    def get_column_specification(self, column, **kwargs):
         colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
@@ -265,34 +265,17 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
 
         if not column.nullable:
             colspec += " NOT NULL"
-        if column.primary_key and not override_pk:
-            colspec += " PRIMARY KEY"
-        if column.foreign_key:
-            colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) 
         return colspec
-    def visit_table(self, table):
-        """sqlite is going to create multi-primary keys with just a UNIQUE index."""
-        self.append("\nCREATE TABLE " + table.fullname + "(")
-
-        separator = "\n"
-
-        have_pk = False
-        use_pks = len(table.primary_key) == 1
-        for column in table.columns:
-            self.append(separator)
-            separator = ", \n"
-            self.append("\t" + self.get_column_specification(column, override_pk=not use_pks))
-                
-        if len(table.primary_key) > 1:
-            self.append(", \n")
-            # put all PRIMARY KEYS in a UNIQUE index
-            self.append("\tUNIQUE (%s)" % string.join([c.name for c in table.primary_key],', '))
-
-        self.append("\n)\n\n")
-        self.execute()        
-        if hasattr(table, 'indexes'):
-            for index in table.indexes:
-                self.visit_index(index)
 
+    # this doesnt seem to be needed, although i suspect older versions of sqlite might still
+    # not directly support composite primary keys
+    #def visit_primary_key_constraint(self, constraint):
+    #    if len(constraint) > 1:
+    #        self.append(", \n")
+    #        # put all PRIMARY KEYS in a UNIQUE index
+    #        self.append("\tUNIQUE (%s)" % string.join([c.name for c in constraint],', '))
+    #    else:
+    #        super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
+            
 dialect = SQLiteDialect
 poolclass = pool.SingletonThreadPool       
index 1df2d30053f6da0dd3f661aab07fc2bb8814af4f..dcd023fe9510561d4b8b1e206b260c3dc84fe75b 100644 (file)
@@ -18,23 +18,24 @@ from sqlalchemy import sql, types, exceptions,util
 import sqlalchemy
 import copy, re, string
 
-__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
+__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint',
+            'PrimaryKeyConstraint', 
            'MetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
 
 class SchemaItem(object):
     """base class for items that define a database schema."""
     def _init_items(self, *args):
+        """initialize the list of child items for this SchemaItem"""
         for item in args:
             if item is not None:
                 item._set_parent(self)
     def _set_parent(self, parent):
-        """a child item attaches itself to its parent via this method."""
+        """associate with this SchemaItem's parent object."""
         raise NotImplementedError()
     def __repr__(self):
         return "%s()" % self.__class__.__name__
     def _derived_metadata(self):
-        """subclasses override this method to return a the MetaData
-        to which this item is bound"""
+        """return the the MetaData to which this item is bound"""
         return None
     def _get_engine(self):
         return self._derived_metadata().engine
@@ -77,7 +78,7 @@ class TableSingleton(type):
             table = metadata.tables[key]
             if len(args):
                 if redefine:
-                    table.reload_values(*args)
+                    table._reload_values(*args)
                 elif not useexisting:
                     raise exceptions.ArgumentError("Table '%s.%s' is already defined. specify 'redefine=True' to remap columns, or 'useexisting=True' to use the existing table" % (schema, name))
             return table
@@ -109,7 +110,9 @@ class Table(SchemaItem, sql.TableClause):
     __metaclass__ = TableSingleton
     
     def __init__(self, name, metadata, **kwargs):
-        """Table objects can be constructed directly.  The init method is actually called via 
+        """Construct a Table.
+        
+        Table objects can be constructed directly.  The init method is actually called via 
         the TableSingleton metaclass.  Arguments are:
         
         name : the name of this table, exactly as it appears, or will appear, in the database.
@@ -141,11 +144,23 @@ class Table(SchemaItem, sql.TableClause):
         super(Table, self).__init__(name)
         self._metadata = metadata
         self.schema = kwargs.pop('schema', None)
+        self.indexes = util.OrderedProperties()
+        self.constraints = []
+        self.primary_key = PrimaryKeyConstraint()
+
         if self.schema is not None:
             self.fullname = "%s.%s" % (self.schema, self.name)
         else:
             self.fullname = self.name
         self.kwargs = kwargs
+
+    def _set_primary_key(self, pk):
+        if getattr(self, '_primary_key', None) in self.constraints:
+            self.constraints.remove(self._primary_key)
+        self._primary_key = pk
+        self.constraints.append(pk)
+    primary_key = property(lambda s:s._primary_key, _set_primary_key)
+    
     def _derived_metadata(self):
         return self._metadata
     def __repr__(self):
@@ -158,8 +173,8 @@ class Table(SchemaItem, sql.TableClause):
     def __str__(self):
         return _get_table_key(self.name, self.schema)
         
-    def reload_values(self, *args):
-        """clears out the columns and other properties of this Table, and reloads them from the 
+    def _reload_values(self, *args):
+        """clear out the columns and other properties of this Table, and reload them from the 
         given argument list.  This is used with the "redefine" keyword argument sent to the
         metaclass constructor."""
         self._clear()
@@ -216,8 +231,9 @@ class Table(SchemaItem, sql.TableClause):
         return index
     
     def deregister(self):
-        """removes this table from it's metadata.  this does not
-        issue a SQL DROP statement."""
+        """remove this table from it's owning metadata.  
+        
+        this does not issue a SQL DROP statement."""
         key = _get_table_key(self.name, self.schema)
         del self.metadata.tables[key]
     def create(self, connectable=None):
@@ -232,7 +248,7 @@ class Table(SchemaItem, sql.TableClause):
         else:
             self.engine.drop(self)
     def tometadata(self, metadata, schema=None):
-        """returns a singleton instance of this Table with a different Schema"""
+        """return a copy of this Table associated with a different MetaData."""
         try:
             if schema is None:
                 schema = self.schema
@@ -242,6 +258,8 @@ class Table(SchemaItem, sql.TableClause):
             args = []
             for c in self.columns:
                 args.append(c.copy())
+            for c in self.constraints:
+                args.append(c.copy())
             return Table(self.name, metadata, schema=schema, *args)
 
 class Column(SchemaItem, sql.ColumnClause):
@@ -362,13 +380,9 @@ class Column(SchemaItem, sql.ColumnClause):
         self._init_items(*self.args)
         self.args = None
 
-    def copy(self):
+    def copy(self): 
         """creates a copy of this Column, unitialized"""
-        if self.foreign_key is None:
-            fk = None
-        else:
-            fk = self.foreign_key.copy()
-        return Column(self.name, self.type, fk, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden)
+        return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden)
         
     def _make_proxy(self, selectable, name = None):
         """creates a copy of this Column, initialized the way this Column is"""
@@ -401,23 +415,33 @@ class Column(SchemaItem, sql.ColumnClause):
 
 
 class ForeignKey(SchemaItem):
-    """defines a ForeignKey constraint between two columns.  ForeignKey is 
-    specified as an argument to a Column object."""
-    def __init__(self, column):
-        """Constructs a new ForeignKey object.  "column" can be a schema.Column
-        object representing the relationship, or just its string name given as 
-        "tablename.columnname".  schema can be specified as 
-        "schema.tablename.columnname" """
+    """defines a column-level ForeignKey constraint between two columns.  
+    
+    ForeignKey is specified as an argument to a Column object.
+    
+    One or more ForeignKey objects are used within a ForeignKeyConstraint
+    object which represents the table-level constraint definition."""
+    def __init__(self, column, constraint=None):
+        """Construct a new ForeignKey object.  
+        
+        "column" can be a schema.Column object representing the relationship, 
+        or just its string name given as "tablename.columnname".  schema can be 
+        specified as "schema.tablename.columnname" 
+        
+        "constraint" is the owning ForeignKeyConstraint object, if any.  if not given,
+        then a ForeignKeyConstraint will be automatically created and added to the parent table.
+        """
         if isinstance(column, unicode):
             column = str(column)
         self._colspec = column
         self._column = None
-
+        self.constraint = constraint
+        
     def __repr__(self):
         return "ForeignKey(%s)" % repr(self._get_colspec())
         
     def copy(self):
-        """produces a copy of this ForeignKey object."""
+        """produce a copy of this ForeignKey object."""
         return ForeignKey(self._get_colspec())
     
     def _get_colspec(self):
@@ -462,6 +486,7 @@ class ForeignKey(SchemaItem):
                     self._column = table.c[colname]
             else:
                 self._column = self._colspec
+
         return self._column
             
     column = property(lambda s: s._init_column())
@@ -472,8 +497,14 @@ class ForeignKey(SchemaItem):
         
     def _set_parent(self, column):
         self.parent = column
-        # if a foreign key was already set up for this, replace it with 
-        # this one, including removing from the parent
+
+        if self.constraint is None and isinstance(self.parent.table, Table):
+            self.constraint = ForeignKeyConstraint([],[])
+            self.parent.table.append_item(self.constraint)
+            self.constraint._append_fk(self)
+
+        # if a foreign key was already set up for the parent column, replace it with 
+        # this one
         if self.parent.foreign_key is not None:
             self.parent.table.foreign_keys.remove(self.parent.foreign_key)
         self.parent.foreign_key = self
@@ -551,7 +582,81 @@ class Sequence(DefaultGenerator):
         """calls the visit_seauence method on the given visitor."""
         return visitor.visit_sequence(self)
 
-
+class Constraint(SchemaItem):
+    """represents a table-level Constraint such as a composite primary key, foreign key, or unique constraint.
+    
+    Also follows list behavior with regards to the underlying set of columns."""
+    def __init__(self, name=None):
+        self.name = name
+        self.columns = []
+    def __contains__(self, x):
+        return x in self.columns
+    def __add__(self, other):
+        return self.columns + other
+    def __iter__(self):
+        return iter(self.columns)
+    def __len__(self):
+        return len(self.columns)
+    def __getitem__(self, index):
+        return self.columns[index]
+    def copy(self):
+        raise NotImplementedError()
+        
+class ForeignKeyConstraint(Constraint):
+    """table-level foreign key constraint, represents a colleciton of ForeignKey objects."""
+    def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None):
+        super(ForeignKeyConstraint, self).__init__(name)
+        self.__colnames = columns
+        self.__refcolnames = refcolumns
+        self.elements = []
+        self.onupdate = onupdate
+        self.ondelete = ondelete
+    def _set_parent(self, table):
+        self.table = table
+        table.constraints.append(self)
+        for (c, r) in zip(self.__colnames, self.__refcolnames):
+            self.append(c,r)
+    def accept_schema_visitor(self, visitor):
+        visitor.visit_foreign_key_constraint(self)
+    def append(self, col, refcol):
+        fk = ForeignKey(refcol, constraint=self)
+        fk._set_parent(self.table.c[col])
+        self._append_fk(fk)
+    def _append_fk(self, fk):
+        self.columns.append(self.table.c[fk.parent.name])
+        self.elements.append(fk)
+    def copy(self):
+        return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec() for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete)
+                        
+class PrimaryKeyConstraint(Constraint):
+    def __init__(self, *columns, **kwargs):
+        super(PrimaryKeyConstraint, self).__init__(name=kwargs.pop('name', None))
+        self.__colnames = list(columns)
+    def _set_parent(self, table):
+        table.primary_key = self
+        for c in self.__colnames:
+            self.append(table.c[c])
+    def accept_schema_visitor(self, visitor):
+        visitor.visit_primary_key_constraint(self)
+    def append(self, col):
+        self.columns.append(col)
+        col.primary_key=True
+    def copy(self):
+        return PrimaryKeyConstraint(name=self.name, *[c.name for c in self])
+            
+class UniqueConstraint(Constraint):
+    def __init__(self, name=None, *columns):
+        super(Constraint, self).__init__(name)
+        self.__colnames = list(columns)
+    def _set_parent(self, table):
+        table.constraints.append(self)
+        for c in self.__colnames:
+            self.append(table.c[c])
+    def append(self, col):
+        self.columns.append(col)
+    def accept_schema_visitor(self, visitor):
+        visitor.visit_unique_constraint(self)
+        
 class Index(SchemaItem):
     """Represents an index of columns from a database table
     """
@@ -746,7 +851,13 @@ class SchemaVisitor(sql.ClauseVisitor):
     def visit_sequence(self, sequence):
         """visit a Sequence."""
         pass
-
+    def visit_primary_key_constraint(self, constraint):
+        pass
+    def visit_foreign_key_constraint(self, constraint):
+        pass
+    def visit_unique_constraint(self, constraint):
+        pass
+        
 default_metadata = DynamicMetaData('default')
 
             
index 7b17927f0896b57583b1dd8d1cebc3701a062542..8109d8cd52bbc60ca433db9ced8be4528f6867a4 100644 (file)
@@ -1225,15 +1225,12 @@ class TableClause(FromClause):
         super(TableClause, self).__init__(name)
         self.name = self.fullname = name
         self._columns = util.OrderedProperties()
-        self._indexes = util.OrderedProperties()
         self._foreign_keys = []
         self._primary_key = []
         for c in columns:
             self.append_column(c)
         self._oid_column = ColumnClause('oid', self, hidden=True)
 
-    indexes = property(lambda s:s._indexes)
-
     def named_with_column(self):
         return True
     def append_column(self, c):
@@ -1250,16 +1247,11 @@ class TableClause(FromClause):
                 for ci in c.orig_set:
                     self._orig_cols[ci] = c
             return self._orig_cols
-    columns = property(lambda s:s._columns)
-    c = property(lambda s:s._columns)
-    primary_key = property(lambda s:s._primary_key)
-    foreign_keys = property(lambda s:s._foreign_keys)
     original_columns = property(_orig_columns)
 
     def _clear(self):
         """clears all attributes on this TableClause so that new items can be added again"""
         self.columns.clear()
-        self.indexes.clear()
         self.foreign_keys[:] = []
         self.primary_key[:] = []
         try:
index f9fa4e40c544c2cf9d7e6477e295bc2a38f91a30..ec59d652cdd947e74e2973a11496e6bf932089d7 100644 (file)
@@ -29,8 +29,10 @@ class ReflectionTest(PersistTest):
         else:
             deftype2 = Integer
             defval2 = "15"
-            
-        users = Table('engine_users', testbase.db,
+        
+        meta = BoundMetaData(testbase.db)
+        
+        users = Table('engine_users', meta,
             Column('user_id', INT, primary_key = True),
             Column('user_name', VARCHAR(20), nullable = False),
             Column('test1', CHAR(5), nullable = False),
@@ -49,14 +51,13 @@ class ReflectionTest(PersistTest):
             mysql_engine='InnoDB'
         )
 
-        addresses = Table('engine_email_addresses', testbase.db,
+        addresses = Table('engine_email_addresses', meta,
             Column('address_id', Integer, primary_key = True),
             Column('remote_user_id', Integer, ForeignKey(users.c.user_id)),
             Column('email_address', String(20)),
             mysql_engine='InnoDB'
         )
 
-        
 #        users.c.parent_user_id.set_foreign_key(ForeignKey(users.c.user_id))
 
         users.create()
@@ -119,28 +120,69 @@ class ReflectionTest(PersistTest):
         table.insert().execute({'multi_id':3,'multi_rev':3,'name':'row3', 'val':'value3'})
         table.select().execute().fetchall()
         table.drop()
-    
+
+    def testcompositefk(self):
+        meta = BoundMetaData(testbase.db)
+        table = Table(
+            'multi', meta, 
+            Column('multi_id', Integer, primary_key=True),
+            Column('multi_rev', Integer, primary_key=True),
+            Column('name', String(50), nullable=False),
+            Column('val', String(100)),
+            mysql_engine='InnoDB'
+        )
+        table2 = Table('multi2', meta, 
+            Column('id', Integer, primary_key=True),
+            Column('foo', Integer),
+            Column('bar', Integer),
+            Column('data', String(50)),
+            ForeignKeyConstraint(['foo', 'bar'], ['multi.multi_id', 'multi.multi_rev']),
+            mysql_engine='InnoDB'
+        )
+        meta.create_all()
+        meta.clear()
+        
+        try:
+            table = Table('multi', meta, autoload=True)
+            table2 = Table('multi2', meta, autoload=True)
+            
+            print table
+            print table2
+            j = join(table, table2)
+            print str(j.onclause)
+            self.assert_(and_(table.c.multi_id==table2.c.foo, table.c.multi_rev==table2.c.bar).compare(j.onclause))
+
+        finally:
+            meta.drop_all()
+
     def testtoengine(self):
         meta = MetaData('md1')
         meta2 = MetaData('md2')
         
         table = Table('mytable', meta,
-            Column('myid', Integer, key = 'id'),
-            Column('name', String, key = 'name', nullable=False),
-            Column('description', String, key = 'description'),
+            Column('myid', Integer, primary_key=True),
+            Column('name', String, nullable=False),
+            Column('description', String(30)),
         )
         
-        print repr(table)
-        
-        table2 = table.tometadata(meta2)
+        table2 = Table('othertable', meta,
+            Column('id', Integer, primary_key=True),
+            Column('myid', Integer, ForeignKey('mytable.myid'))
+            )
+            
         
-        print repr(table2)
+        table_c = table.tometadata(meta2)
+        table2_c = table2.tometadata(meta2)
+
+        assert table is not table_c
+        assert table_c.c.myid.primary_key
+        assert not table_c.c.name.nullable 
+        assert table_c.c.description.nullable 
+        assert table.primary_key is not table_c.primary_key
+        assert [x.name for x in table.primary_key] == [x.name for x in table_c.primary_key]
+        assert table2_c.c.myid.foreign_key.column is table_c.c.myid
+        assert table2_c.c.myid.foreign_key.column is not table.c.myid
         
-        assert table is not table2
-        assert table2.c.id.nullable 
-        assert not table2.c.name.nullable 
-        assert table2.c.description.nullable 
-    
     # mysql throws its own exception for no such table, resulting in 
     # a sqlalchemy.SQLError instead of sqlalchemy.NoSuchTableError.
     # this could probably be fixed at some point.
index 6d3a712762711d7fc2eed6663792a3cab359fed6..237c4b554587f63effd213635f712f823bd85c55 100644 (file)
@@ -230,7 +230,7 @@ class PKTest(SessionTest):
     @testbase.unsupported('mssql')
     def setUpAll(self):
         SessionTest.setUpAll(self)
-        db.echo = False
+        #db.echo = False
         global table
         global table2
         global table3
@@ -266,6 +266,8 @@ class PKTest(SessionTest):
         db.echo = testbase.echo
         SessionTest.tearDownAll(self)
         
+    # not support on sqlite since sqlite's auto-pk generation only works with
+    # single column primary keys    
     @testbase.unsupported('sqlite', 'mssql')
     def testprimarykey(self):
         class Entry(object):
@@ -279,6 +281,8 @@ class PKTest(SessionTest):
         ctx.current.clear()
         e2 = Entry.mapper.get((e.multi_id, 2))
         self.assert_(e is not e2 and e._instance_key == e2._instance_key)
+        
+    # this one works with sqlite since we are manually setting up pk values
     @testbase.unsupported('mssql')
     def testmanualpk(self):
         class Entry(object):
@@ -289,6 +293,7 @@ class PKTest(SessionTest):
         e.pk_col_2 = 'pk1_related'
         e.data = 'im the data'
         ctx.current.flush()
+        
     @testbase.unsupported('mssql')
     def testkeypks(self):
         import datetime
index ec72beda39dd8c9f83564f469a543331873b9ebb..e9af301de65d941d83aeea3a5663367128016b6b 100644 (file)
@@ -5,16 +5,33 @@ import sys
 class IndexTest(testbase.AssertMixin):
     
     def setUp(self):
-       global metadata
-       metadata = BoundMetaData(testbase.db)
+        global metadata
+        metadata = BoundMetaData(testbase.db)
         self.echo = testbase.db.echo
         self.logger = testbase.db.logger
         
     def tearDown(self):
         testbase.db.echo = self.echo
         testbase.db.logger = testbase.db.engine.logger = self.logger
-       metadata.drop_all()
+        metadata.drop_all()
     
+    def test_constraint(self):
+        employees = Table('employees', metadata, 
+            Column('id', Integer),
+            Column('soc', String(40)),
+            Column('name', String(30)),
+            PrimaryKeyConstraint('id', 'soc')
+            )
+        elements = Table('elements', metadata,
+            Column('id', Integer),
+            Column('stuff', String(30)),
+            Column('emp_id', Integer),
+            Column('emp_soc', String(40)),
+            PrimaryKeyConstraint('id'),
+            ForeignKeyConstraint(['emp_id', 'emp_soc'], ['employees.id', 'employees.soc'])
+            )
+        metadata.create_all()
+        
     def test_index_create(self):
         employees = Table('employees', metadata,
                           Column('id', Integer, primary_key=True),