From 7e5e985c0e17a2d300f9aa8633c3610db600f2e2 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 15 Oct 2006 00:07:06 +0000 Subject: [PATCH] - ForeignKey(Constraint) supports "use_alter=True", to create/drop a foreign key via ALTER. this allows circular foreign key relationships to be set up. --- CHANGES | 2 ++ lib/sqlalchemy/ansisql.py | 58 ++++++++++++++++++++++++------ lib/sqlalchemy/databases/mysql.py | 3 ++ lib/sqlalchemy/databases/sqlite.py | 9 +++++ lib/sqlalchemy/engine/base.py | 2 +- lib/sqlalchemy/schema.py | 7 ++-- lib/sqlalchemy/sql_util.py | 2 ++ test/orm/cycles.py | 15 ++------ 8 files changed, 72 insertions(+), 26 deletions(-) diff --git a/CHANGES b/CHANGES index b9ac78ee9a..5d3c1c34df 100644 --- a/CHANGES +++ b/CHANGES @@ -40,6 +40,8 @@ indexed. a comparison clause between two pks that are derived from the same underlying tables (i.e. such as two Alias objects) can be generated via table1.primary_key==table2.primary_key + - ForeignKey(Constraint) supports "use_alter=True", to create/drop a foreign key + via ALTER. this allows circular foreign key relationships to be set up. - append_item() methods removed from Table and Column; preferably construct Table/Column/related objects inline, but if needed use append_column(), append_foreign_key(), append_constraint(), etc. diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 208b2f603a..b6923c7da0 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -606,8 +606,20 @@ class ANSICompiler(sql.Compiled): def __str__(self): return self.get_str(self.statement) - -class ANSISchemaGenerator(engine.SchemaIterator): +class ANSISchemaBase(engine.SchemaIterator): + def find_alterables(self, tables): + alterables = [] + class FindAlterables(schema.SchemaVisitor): + def visit_foreign_key_constraint(self, constraint): + if constraint.use_alter and constraint.table in tables: + alterables.append(constraint) + findalterables = FindAlterables() + for table in tables: + for c in table.constraints: + c.accept_schema_visitor(findalterables) + return alterables + +class ANSISchemaGenerator(ANSISchemaBase): def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs) self.checkfirst = checkfirst @@ -620,11 +632,13 @@ class ANSISchemaGenerator(engine.SchemaIterator): raise NotImplementedError() def visit_metadata(self, metadata): - for table in metadata.table_iterator(reverse=False, tables=self.tables): - if self.checkfirst and self.dialect.has_table(self.connection, table.name): - continue + collection = [t for t in metadata.table_iterator(reverse=False, tables=self.tables) if (not self.checkfirst or not self.dialect.has_table(self.connection, t.name))] + for table in collection: table.accept_schema_visitor(self, traverse=False) - + if self.supports_alter(): + for alterable in self.find_alterables(collection): + self.add_foreignkey(alterable) + def visit_table(self, table): for column in table.columns: if column.default is not None: @@ -687,9 +701,22 @@ class ANSISchemaGenerator(engine.SchemaIterator): if constraint.name is not None: self.append("%s " % constraint.name) self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) - + + def supports_alter(self): + return True + def visit_foreign_key_constraint(self, constraint): + if constraint.use_alter and self.supports_alter(): + return self.append(", \n\t ") + self.define_foreign_key(constraint) + + def add_foreignkey(self, constraint): + self.append("ALTER TABLE %s ADD " % self.preparer.format_table(constraint.table)) + self.define_foreign_key(constraint) + self.execute() + + def define_foreign_key(self, constraint): if constraint.name is not None: self.append("CONSTRAINT %s " % constraint.name) self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( @@ -721,7 +748,7 @@ class ANSISchemaGenerator(engine.SchemaIterator): string.join([self.preparer.format_column(c) for c in index.columns], ', '))) self.execute() -class ANSISchemaDropper(engine.SchemaIterator): +class ANSISchemaDropper(ANSISchemaBase): def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs) self.checkfirst = checkfirst @@ -731,14 +758,23 @@ class ANSISchemaDropper(engine.SchemaIterator): self.dialect = self.engine.dialect def visit_metadata(self, metadata): - for table in metadata.table_iterator(reverse=True, tables=self.tables): - if self.checkfirst and not self.dialect.has_table(self.connection, table.name): - continue + collection = [t for t in metadata.table_iterator(reverse=True, tables=self.tables) if (not self.checkfirst or self.dialect.has_table(self.connection, t.name))] + if self.supports_alter(): + for alterable in self.find_alterables(collection): + self.drop_foreignkey(alterable) + for table in collection: table.accept_schema_visitor(self, traverse=False) + def supports_alter(self): + return True + def visit_index(self, index): self.append("\nDROP INDEX " + index.name) self.execute() + + def drop_foreignkey(self, constraint): + self.append("ALTER TABLE %s DROP CONSTRAINT %s" % (self.preparer.format_table(constraint.table), constraint.name)) + self.execute() def visit_table(self, table): for column in table.columns: diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 2fa7e9227f..86b74c3644 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -456,6 +456,9 @@ class MySQLSchemaDropper(ansisql.ANSISchemaDropper): def visit_index(self, index): self.append("\nDROP INDEX " + index.name + " ON " + index.table.name) self.execute() + def drop_foreignkey(self, constraint): + self.append("ALTER TABLE %s DROP FOREIGN KEY %s" % (self.preparer.format_table(constraint.table), constraint.name)) + self.execute() class MySQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def __init__(self, dialect): diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 90cd66dd3c..a4445b1a83 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -147,6 +147,8 @@ class SQLiteDialect(ansisql.ANSIDialect): return SQLiteCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): return SQLiteSchemaGenerator(*args, **kwargs) + def schemadropper(self, *args, **kwargs): + return SQLiteSchemaDropper(*args, **kwargs) def preparer(self): return SQLiteIdentifierPreparer(self) def create_connect_args(self, url): @@ -283,6 +285,9 @@ class SQLiteCompiler(ansisql.ANSICompiler): return ansisql.ANSICompiler.binary_operator_string(self, binary) class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): + def supports_alter(self): + return False + def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) + " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) @@ -303,6 +308,10 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): # else: # super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint) +class SQLiteSchemaDropper(ansisql.ANSISchemaDropper): + def supports_alter(self): + return False + class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def __init__(self, dialect): super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4ba5e11158..83db06090d 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -225,7 +225,7 @@ class Connection(Connectable): """when no Transaction is present, this is called after executions to provide "autocommit" behavior.""" # TODO: have the dialect determine if autocommit can be set on the connection directly without this # extra step - if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP', statement.lstrip().upper()): + if not self.in_transaction() and re.match(r'UPDATE|INSERT|CREATE|DELETE|DROP|ALTER', statement.lstrip().upper()): self._commit_impl() def _autorollback(self): if not self.in_transaction(): diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 5728d7c375..88d52f0753 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -491,7 +491,7 @@ class ForeignKey(SchemaItem): One or more ForeignKey objects are used within a ForeignKeyConstraint object which represents the table-level constraint definition.""" - def __init__(self, column, constraint=None, use_alter=False): + def __init__(self, column, constraint=None, use_alter=False, name=None): """Construct a new ForeignKey object. "column" can be a schema.Column object representing the relationship, @@ -507,6 +507,7 @@ class ForeignKey(SchemaItem): self._column = None self.constraint = constraint self.use_alter = use_alter + self.name = name def __repr__(self): return "ForeignKey(%s)" % repr(self._get_colspec()) @@ -575,7 +576,7 @@ class ForeignKey(SchemaItem): self.parent = column if self.constraint is None and isinstance(self.parent.table, Table): - self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter) + self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter, name=self.name) self.parent.table.append_constraint(self.constraint) self.constraint._append_fk(self) @@ -699,6 +700,8 @@ class ForeignKeyConstraint(Constraint): self.elements = util.Set() self.onupdate = onupdate self.ondelete = ondelete + if self.name is None and use_alter: + raise exceptions.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name") self.use_alter = use_alter def _set_parent(self, table): self.table = table diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index 94caade68b..4935b1adda 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -40,6 +40,8 @@ class TableCollection(object): tuples = [] class TVisitor(schema.SchemaVisitor): def visit_foreign_key(_self, fkey): + if fkey.use_alter: + return parent_table = fkey.column.table if parent_table in self: child_table = fkey.parent.table diff --git a/test/orm/cycles.py b/test/orm/cycles.py index eebe7af755..0ff3abb7b7 100644 --- a/test/orm/cycles.py +++ b/test/orm/cycles.py @@ -213,28 +213,19 @@ class OneToManyManyToOneTest(AssertMixin): global ball ball = Table('ball', metadata, Column('id', Integer, Sequence('ball_id_seq', optional=True), primary_key=True), - Column('person_id', Integer), + Column('person_id', Integer, ForeignKey('person.id', use_alter=True, name='fk_person_id')), Column('data', String(30)) ) person = Table('person', metadata, Column('id', Integer, Sequence('person_id_seq', optional=True), primary_key=True), Column('favorite_ball_id', Integer, ForeignKey('ball.id')), Column('data', String(30)) -# Column('favorite_ball_id', Integer), ) - ball.create() - person.create() - ball.c.person_id.append_foreign_key(ForeignKey('person.id')) + metadata.create_all() - # make the test more complete for postgres - if db.engine.__module__.endswith('postgres'): - db.execute("alter table ball add constraint fk_ball_person foreign key (person_id) references person(id)", {}) def tearDownAll(self): - if db.engine.__module__.endswith('postgres'): - db.execute("alter table ball drop constraint fk_ball_person", {}) - person.drop() - ball.drop() + metadata.drop_all() def tearDown(self): clear_mappers() -- 2.47.2