]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- ForeignKey(Constraint) supports "use_alter=True", to create/drop a foreign key
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Oct 2006 00:07:06 +0000 (00:07 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Oct 2006 00:07:06 +0000 (00:07 +0000)
via ALTER.  this allows circular foreign key relationships to be set up.

CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql_util.py
test/orm/cycles.py

diff --git a/CHANGES b/CHANGES
index b9ac78ee9a01509654766d092f0537ad0eacf1ef..5d3c1c34df5bf682fc4187a22e7e78181ccc7632 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -40,6 +40,8 @@
     indexed.  a comparison clause between two pks that are derived from the 
     same underlying tables (i.e. such as two Alias objects) can be generated 
     via table1.primary_key==table2.primary_key
+    - ForeignKey(Constraint) supports "use_alter=True", to create/drop a foreign key
+    via ALTER.  this allows circular foreign key relationships to be set up.
     - append_item() methods removed from Table and Column; preferably
     construct Table/Column/related objects inline, but if needed use 
     append_column(), append_foreign_key(), append_constraint(), etc.
index 208b2f603a1dbbfb29a34b6dac4236fcb6f3e4b9..b6923c7da0bfe0abfdd6a125b4b1b576d2b66b4a 100644 (file)
@@ -606,8 +606,20 @@ class ANSICompiler(sql.Compiled):
     def __str__(self):
         return self.get_str(self.statement)
 
-
-class ANSISchemaGenerator(engine.SchemaIterator):
+class ANSISchemaBase(engine.SchemaIterator):
+    def find_alterables(self, tables):
+        alterables = []
+        class FindAlterables(schema.SchemaVisitor):
+            def visit_foreign_key_constraint(self, constraint):
+                if constraint.use_alter and constraint.table in tables:
+                    alterables.append(constraint)
+        findalterables = FindAlterables()
+        for table in tables:
+            for c in table.constraints:
+                c.accept_schema_visitor(findalterables)
+        return alterables
+        
+class ANSISchemaGenerator(ANSISchemaBase):
     def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
         super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs)
         self.checkfirst = checkfirst
@@ -620,11 +632,13 @@ class ANSISchemaGenerator(engine.SchemaIterator):
         raise NotImplementedError()
     
     def visit_metadata(self, metadata):
-        for table in metadata.table_iterator(reverse=False, tables=self.tables):
-            if self.checkfirst and self.dialect.has_table(self.connection, table.name):
-                continue
+        collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name))]
+        for table in collection:
             table.accept_schema_visitor(self, traverse=False)
-            
+        if self.supports_alter():
+            for alterable in self.find_alterables(collection):
+                self.add_foreignkey(alterable)
+                
     def visit_table(self, table):
         for column in table.columns:
             if column.default is not None:
@@ -687,9 +701,22 @@ class ANSISchemaGenerator(engine.SchemaIterator):
         if constraint.name is not None:
             self.append("%s " % constraint.name)
         self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', ')))
-                    
+    
+    def supports_alter(self):
+        return True
+                        
     def visit_foreign_key_constraint(self, constraint):
+        if constraint.use_alter and self.supports_alter():
+            return
         self.append(", \n\t ")
+        self.define_foreign_key(constraint)
+    
+    def add_foreignkey(self, constraint):
+        self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table))
+        self.define_foreign_key(constraint)
+        self.execute()
+        
+    def define_foreign_key(self, constraint):
         if constraint.name is not None:
             self.append("CONSTRAINT %s " % constraint.name)
         self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
@@ -721,7 +748,7 @@ class ANSISchemaGenerator(engine.SchemaIterator):
                        string.join([self.preparer.format_column(c) for c in index.columns], ', ')))
         self.execute()
         
-class ANSISchemaDropper(engine.SchemaIterator):
+class ANSISchemaDropper(ANSISchemaBase):
     def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs):
         super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs)
         self.checkfirst = checkfirst
@@ -731,14 +758,23 @@ class ANSISchemaDropper(engine.SchemaIterator):
         self.dialect = self.engine.dialect
 
     def visit_metadata(self, metadata):
-        for table in metadata.table_iterator(reverse=True, tables=self.tables):
-            if self.checkfirst and not self.dialect.has_table(self.connection, table.name):
-                continue
+        collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or  self.dialect.has_table(self.connection, t.name))]
+        if self.supports_alter():
+            for alterable in self.find_alterables(collection):
+                self.drop_foreignkey(alterable)
+        for table in collection:
             table.accept_schema_visitor(self, traverse=False)
 
+    def supports_alter(self):
+        return True
+
     def visit_index(self, index):
         self.append("\nDROP INDEX " + index.name)
         self.execute()
+
+    def drop_foreignkey(self, constraint):
+        self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (self.preparer.format_table(constraint.table), constraint.name))
+        self.execute()
         
     def visit_table(self, table):
         for column in table.columns:
index 2fa7e9227f42e41c45cb89219d7ece2807b22ca6..86b74c364459efb84b12fe405e8382704c40f41f 100644 (file)
@@ -456,6 +456,9 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
     def visit_index(self, index):
         self.append("\nDROP INDEX " + index.name + " ON " + index.table.name)
         self.execute()
+    def drop_foreignkey(self, constraint):
+        self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % (self.preparer.format_table(constraint.table), constraint.name))
+        self.execute()
 
 class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
     def __init__(self, dialect):
index 90cd66dd3c8bd88f9ecf2d0bb527b05f8e1dfc2a..a4445b1a83100c8d43d60810d481377d98644d2d 100644 (file)
@@ -147,6 +147,8 @@ class SQLiteDialect(ansisql.ANSIDialect):
         return SQLiteCompiler(self, statement, bindparams, **kwargs)
     def schemagenerator(self, *args, **kwargs):
         return SQLiteSchemaGenerator(*args, **kwargs)
+    def schemadropper(self, *args, **kwargs):
+        return SQLiteSchemaDropper(*args, **kwargs)
     def preparer(self):
         return SQLiteIdentifierPreparer(self)
     def create_connect_args(self, url):
@@ -283,6 +285,9 @@ class SQLiteCompiler(ansisql.ANSICompiler):
             return ansisql.ANSICompiler.binary_operator_string(self, binary)
 
 class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
+    def supports_alter(self):
+        return False
+        
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec()
         default = self.get_column_default_string(column)
@@ -303,6 +308,10 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
     #    else:
     #        super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint)
 
+class SQLiteSchemaDropper(ansisql.ANSISchemaDropper):
+    def supports_alter(self):
+        return False
+
 class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
     def __init__(self, dialect):
         super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True)
index 4ba5e111582e14c47e8c11bd5eb47ec494171455..83db06090d7f175aadf83452d2c1d69ad0155918 100644 (file)
@@ -225,7 +225,7 @@ class Connection(Connectable):
         """when no Transaction is present, this is called after executions to provide "autocommit" behavior."""
         # TODO: have the dialect determine if autocommit can be set on the connection directly without this 
         # extra step
-        if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP', statement.lstrip().upper()):
+        if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip().upper()):
             self._commit_impl()
     def _autorollback(self):
         if not self.in_transaction():
index 5728d7c375ce38991689bd78c60d912352de9adf..88d52f0753b4eebfc464de1bf661c554c891f8e2 100644 (file)
@@ -491,7 +491,7 @@ class ForeignKey(SchemaItem):
     
     One or more ForeignKey objects are used within a ForeignKeyConstraint
     object which represents the table-level constraint definition."""
-    def __init__(self, column, constraint=None, use_alter=False):
+    def __init__(self, column, constraint=None, use_alter=False, name=None):
         """Construct a new ForeignKey object.  
         
         "column" can be a schema.Column object representing the relationship, 
@@ -507,6 +507,7 @@ class ForeignKey(SchemaItem):
         self._column = None
         self.constraint = constraint
         self.use_alter = use_alter
+        self.name = name
         
     def __repr__(self):
         return "ForeignKey(%s)" % repr(self._get_colspec())
@@ -575,7 +576,7 @@ class ForeignKey(SchemaItem):
         self.parent = column
 
         if self.constraint is None and isinstance(self.parent.table, Table):
-            self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter)
+            self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter, name=self.name)
             self.parent.table.append_constraint(self.constraint)
             self.constraint._append_fk(self)
 
@@ -699,6 +700,8 @@ class ForeignKeyConstraint(Constraint):
         self.elements = util.Set()
         self.onupdate = onupdate
         self.ondelete = ondelete
+        if self.name is None and use_alter:
+            raise exceptions.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name")
         self.use_alter = use_alter
     def _set_parent(self, table):
         self.table = table
index 94caade68bd74666f7d308d48340f02060a177f1..4935b1adda0812aecf5f9e673b9174cf6a356b16 100644 (file)
@@ -40,6 +40,8 @@ class TableCollection(object):
         tuples = []
         class TVisitor(schema.SchemaVisitor):
             def visit_foreign_key(_self, fkey):
+                if fkey.use_alter:
+                    return
                 parent_table = fkey.column.table
                 if parent_table in self:
                     child_table = fkey.parent.table
index eebe7af7550f966db2f8fe3882c1d11540d50d37..0ff3abb7b76673dd677fba463ffb167026db86c0 100644 (file)
@@ -213,28 +213,19 @@ class OneToManyManyToOneTest(AssertMixin):
         global ball
         ball = Table('ball', metadata,
          Column('id', Integer, Sequence('ball_id_seq', optional=True), primary_key=True),
-         Column('person_id', Integer),
+         Column('person_id', Integer, ForeignKey('person.id', use_alter=True, name='fk_person_id')),
          Column('data', String(30))
          )
         person = Table('person', metadata,
          Column('id', Integer, Sequence('person_id_seq', optional=True), primary_key=True),
          Column('favorite_ball_id', Integer, ForeignKey('ball.id')),
          Column('data', String(30))
-#         Column('favorite_ball_id', Integer),
          )
 
-        ball.create()
-        person.create()
-        ball.c.person_id.append_foreign_key(ForeignKey('person.id'))
+        metadata.create_all()
         
-        # make the test more complete for postgres
-        if db.engine.__module__.endswith('postgres'):
-            db.execute("alter table ball add constraint fk_ball_person foreign key (person_id) references person(id)", {})
     def tearDownAll(self):
-        if db.engine.__module__.endswith('postgres'):
-            db.execute("alter table ball drop constraint fk_ball_person", {})
-        person.drop()
-        ball.drop()
+        metadata.drop_all()
         
     def tearDown(self):
         clear_mappers()