From bc6fbfa84ab6e1e9639e00cc23b3c41ab1d30dc1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 14 Jul 2006 20:06:09 +0000 Subject: [PATCH] overhaul to schema, addition of ForeignKeyConstraint/ PrimaryKeyConstraint objects (also UniqueConstraint not completed yet). table creation and reflection modified to be more oriented towards these new table-level objects. reflection for sqlite/postgres/mysql supports composite foreign keys; oracle/mssql/firebird not converted yet. --- CHANGES | 7 + lib/sqlalchemy/ansisql.py | 34 +++- lib/sqlalchemy/databases/firebird.py | 6 +- .../databases/information_schema.py | 56 +++--- lib/sqlalchemy/databases/mssql.py | 8 +- lib/sqlalchemy/databases/mysql.py | 31 ++-- lib/sqlalchemy/databases/oracle.py | 6 +- lib/sqlalchemy/databases/postgres.py | 6 +- lib/sqlalchemy/databases/sqlite.py | 39 ++-- lib/sqlalchemy/schema.py | 171 +++++++++++++++--- lib/sqlalchemy/sql.py | 8 - test/engine/reflection.py | 76 ++++++-- test/orm/objectstore.py | 7 +- test/sql/indexes.py | 23 ++- 14 files changed, 318 insertions(+), 160 deletions(-) diff --git a/CHANGES b/CHANGES index 28b3316fc9..810f2ec9a9 100644 --- a/CHANGES +++ b/CHANGES @@ -1,4 +1,11 @@ 0.2.6 +- big overhaul to schema to allow truly composite primary and foreign +key constraints, via new ForeignKeyConstraint and PrimaryKeyConstraint +objects. +Existing methods of primary/foreign key creation have not been changed +but use these new objects behind the scenes. table creation +and reflection is now more table oriented rather than column oriented. +[ticket:76] - tweaks to ActiveMapper, supports self-referential relationships - slight rearrangement to objectstore (in activemapper/threadlocal) so that the SessionContext is referenced by '.context' instead diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 78017bc919..5d01e275cb 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -602,7 +602,7 @@ class ANSICompiler(sql.Compiled): class ANSISchemaGenerator(engine.SchemaIterator): - def get_column_specification(self, column, override_pk=False, first_pk=False): + def get_column_specification(self, column, first_pk=False): raise NotImplementedError() def visit_table(self, table): @@ -614,19 +614,17 @@ class ANSISchemaGenerator(engine.SchemaIterator): separator = "\n" # if only one primary key, specify it along with the column - pks = table.primary_key first_pk = False for column in table.columns: self.append(separator) separator = ", \n" - self.append("\t" + self.get_column_specification(column, override_pk=len(pks)>1, first_pk=column.primary_key and not first_pk)) + self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) if column.primary_key: first_pk = True - # if multiple primary keys, specify it at the bottom - if len(pks) > 1: - self.append(", \n") - self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in pks],', ')) - + + for constraint in table.constraints: + constraint.accept_schema_visitor(self) + self.append("\n)%s\n\n" % self.post_create_table(table)) self.execute() if hasattr(table, 'indexes'): @@ -650,6 +648,26 @@ class ANSISchemaGenerator(engine.SchemaIterator): compiler.compile() return compiler + def visit_primary_key_constraint(self, constraint): + if len(constraint) == 0: + return + self.append(", \n") + self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in constraint],', ')) + + def visit_foreign_key_constraint(self, constraint): + self.append(", \n\t ") + if constraint.name is not None: + self.append("CONSTRAINT %s " % constraint.name) + self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + string.join([f.parent.name for f in constraint.elements], ', '), + list(constraint.elements)[0].column.table.name, + string.join([f.column.name for f in constraint.elements], ', ') + )) + if constraint.ondelete is not None: + self.append(" ON DELETE %s" % constraint.ondelete) + if constraint.onupdate is not None: + self.append(" ON UPDATE %s" % constraint.onupdate) + def visit_column(self, column): pass diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 0039333d51..085d8cf444 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -293,7 +293,7 @@ class FBCompiler(ansisql.ANSICompiler): return "" class FBSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, **kwargs): + def get_column_specification(self, column, **kwargs): colspec = column.name colspec += " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) @@ -302,10 +302,6 @@ class FBSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" - if column.primary_key and not override_pk: - colspec += " PRIMARY KEY" - if column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) return colspec def visit_sequence(self, sequence): diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index 08236f7991..296db2de57 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -54,6 +54,7 @@ pg_key_constraints = schema.Table("key_column_usage", ischema, Column("table_name", String), Column("column_name", String), Column("constraint_name", String), + Column("ordinal_position", Integer), schema="information_schema") #mysql_key_constraints = schema.Table("key_column_usage", ischema, @@ -100,13 +101,9 @@ class ISchema(object): return self.cache[name] -def reflecttable(connection, table, ischema_names, use_mysql=False): +def reflecttable(connection, table, ischema_names): - if use_mysql: - # no idea which INFORMATION_SCHEMA spec is correct, mysql or postgres - key_constraints = mysql_key_constraints - else: - key_constraints = pg_key_constraints + key_constraints = pg_key_constraints if table.schema is not None: current_schema = table.schema @@ -152,39 +149,50 @@ def reflecttable(connection, table, ischema_names, use_mysql=False): if not found_table: raise exceptions.NoSuchTableError(table.name) - s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)]) - if not use_mysql: - s.append_column(column_constraints) - s.append_whereclause(constraints.c.table_name==table.name) - s.append_whereclause(constraints.c.table_schema==current_schema) - colmap = [constraints.c.constraint_type, key_constraints.c.column_name, column_constraints.c.table_schema, column_constraints.c.table_name, column_constraints.c.column_name] - else: - # this doesnt seem to pick up any foreign keys with mysql - s.append_whereclause(key_constraints.c.table_name==constraints.c.table_name) - s.append_whereclause(key_constraints.c.table_schema==constraints.c.table_schema) - s.append_whereclause(constraints.c.table_name==table.name) - s.append_whereclause(constraints.c.table_schema==current_schema) - colmap = [constraints.c.constraint_type, key_constraints.c.column_name, key_constraints.c.referenced_table_schema, key_constraints.c.referenced_table_name, key_constraints.c.referenced_column_name] + # we are relying on the natural ordering of the constraint_column_usage table to return the referenced columns + # in an order that corresponds to the ordinal_position in the key_constraints table, otherwise composite foreign keys + # wont reflect properly. dont see a way around this based on whats available from information_schema + s = select([constraints.c.constraint_name, constraints.c.constraint_type, constraints.c.table_name, key_constraints], use_labels=True, from_obj=[constraints.join(column_constraints, column_constraints.c.constraint_name==constraints.c.constraint_name).join(key_constraints, key_constraints.c.constraint_name==column_constraints.c.constraint_name)], order_by=[key_constraints.c.ordinal_position]) + s.append_column(column_constraints) + s.append_whereclause(constraints.c.table_name==table.name) + s.append_whereclause(constraints.c.table_schema==current_schema) + colmap = [constraints.c.constraint_type, key_constraints.c.column_name, column_constraints.c.table_schema, column_constraints.c.table_name, column_constraints.c.column_name, constraints.c.constraint_name, key_constraints.c.ordinal_position] c = connection.execute(s) + fks = {} while True: row = c.fetchone() if row is None: break -# continue - (type, constrained_column, referred_schema, referred_table, referred_column) = ( + (type, constrained_column, referred_schema, referred_table, referred_column, constraint_name, ordinal_position) = ( row[colmap[0]], row[colmap[1]], row[colmap[2]], row[colmap[3]], - row[colmap[4]] + row[colmap[4]], + row[colmap[5]], + row[colmap[6]] ) #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) if type=='PRIMARY KEY': table.c[constrained_column]._set_primary_key() elif type=='FOREIGN KEY': + try: + fk = fks[constraint_name] + except KeyError: + fk = ([],[]) + fks[constraint_name] = fk if current_schema == referred_schema: referred_schema = table.schema - remotetable = Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema) - table.c[constrained_column].append_item(schema.ForeignKey(remotetable.c[referred_column])) + if referred_schema is not None: + refspec = ".".join([referred_schema, referred_table, referred_column]) + else: + refspec = ".".join([referred_table, referred_column]) + if constrained_column not in fk[0]: + fk[0].append(constrained_column) + if refspec not in fk[1]: + fk[1].append(refspec) + + for name, value in fks.iteritems(): + table.append_item(ForeignKeyConstraint(value[0], value[1], name=name)) diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index c297195caf..9d51d535da 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -511,7 +511,7 @@ class MSSQLCompiler(ansisql.ANSICompiler): class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, first_pk=False): + def get_column_specification(self, column, **kwargs): colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec() # install a IDENTITY Sequence if we have an implicit IDENTITY column @@ -528,12 +528,6 @@ class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator): default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default - - if column.primary_key: - if not override_pk: - colspec += " PRIMARY KEY" - if column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.fullname, column.foreign_key.column.name) return colspec diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 997010f1c2..1d587ff7c5 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -309,8 +309,6 @@ class MySQLDialect(ansisql.ANSIDialect): break #print "row! " + repr(row) if not found_table: - tabletype, foreignkeyD = self.moretableinfo(connection, table=table) - table.kwargs['mysql_engine'] = tabletype found_table = True (name, type, nullable, primary_key, default) = (row[0], row[1], row[2] == 'YES', row[3] == 'PRI', row[4]) @@ -338,16 +336,15 @@ class MySQLDialect(ansisql.ANSIDialect): argslist = re.findall(r'(\d+)', args) coltype = coltype(*[int(a) for a in argslist], **kw) - arglist = [] - fkey = foreignkeyD.get(name) - if fkey is not None: - arglist.append(schema.ForeignKey(fkey)) - - table.append_item(schema.Column(name, coltype, *arglist, + table.append_item(schema.Column(name, coltype, **dict(primary_key=primary_key, nullable=nullable, default=default ))) + + tabletype = self.moretableinfo(connection, table=table) + table.kwargs['mysql_engine'] = tabletype + if not found_table: raise exceptions.NoSuchTableError(table.name) @@ -368,15 +365,15 @@ class MySQLDialect(ansisql.ANSIDialect): match = re.search(r'\b(?:TYPE|ENGINE)=(?P.+)\b', desc[lastparen.start():], re.I) if match: tabletype = match.group('ttype') - foreignkeyD = {} - fkpat = (r'FOREIGN KEY\s*\(`?(?P.+?)`?\)' - r'\s*REFERENCES\s*`?(?P.+?)`?' - r'\s*\(`?(?P.+?)`?\)' - ) + + fkpat = r'CONSTRAINT `(?P.+?)` FOREIGN KEY \((?P.+?)\) REFERENCES `(?P.+?)` \((?P.+?)\)' for match in re.finditer(fkpat, desc): - foreignkeyD[match.group('name')] = match.group('reftable') + '.' + match.group('refcol') + columns = re.findall(r'`(.+?)`', match.group('columns')) + refcols = [match.group('reftable') + "." + x for x in re.findall(r'`(.+?)`', match.group('refcols'))] + constraint = schema.ForeignKeyConstraint(columns, refcols, name=match.group('name')) + table.append_item(constraint) - return (tabletype, foreignkeyD) + return tabletype class MySQLCompiler(ansisql.ANSICompiler): @@ -411,12 +408,8 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" if column.primary_key: - if not override_pk: - colspec += " PRIMARY KEY" if not column.foreign_key and first_pk and isinstance(column.type, sqltypes.Integer): colspec += " AUTO_INCREMENT" - if column.foreign_key: - colspec += ", FOREIGN KEY (%s) REFERENCES %s(%s)" % (column.name, column.foreign_key.column.table.name, column.foreign_key.column.name) return colspec def post_create_table(self, table): diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index bf6c1fd8d4..d184291fd5 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -320,7 +320,7 @@ class OracleCompiler(ansisql.ANSICompiler): return "" class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, **kwargs): + def get_column_specification(self, column, **kwargs): colspec = column.name colspec += " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) @@ -329,10 +329,6 @@ class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" - if column.primary_key and not override_pk: - colspec += " PRIMARY KEY" - if column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) return colspec def visit_sequence(self, sequence): diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index de21bd570d..decccba58d 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -329,7 +329,7 @@ class PGCompiler(ansisql.ANSICompiler): class PGSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, **kwargs): + def get_column_specification(self, column, **kwargs): colspec = column.name if column.primary_key and not column.foreign_key and isinstance(column.type, sqltypes.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" @@ -341,10 +341,6 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" - if column.primary_key and not override_pk: - colspec += " PRIMARY KEY" - if column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.fullname, column.foreign_key.column.name) return colspec def visit_sequence(self, sequence): diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index c07952ff21..c703cd81eb 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -257,7 +257,7 @@ class SQLiteCompiler(ansisql.ANSICompiler): return ansisql.ANSICompiler.binary_operator_string(self, binary) class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column, override_pk=False, **kwargs): + def get_column_specification(self, column, **kwargs): colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) if default is not None: @@ -265,34 +265,17 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" - if column.primary_key and not override_pk: - colspec += " PRIMARY KEY" - if column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) return colspec - def visit_table(self, table): - """sqlite is going to create multi-primary keys with just a UNIQUE index.""" - self.append("\nCREATE TABLE " + table.fullname + "(") - - separator = "\n" - - have_pk = False - use_pks = len(table.primary_key) == 1 - for column in table.columns: - self.append(separator) - separator = ", \n" - self.append("\t" + self.get_column_specification(column, override_pk=not use_pks)) - - if len(table.primary_key) > 1: - self.append(", \n") - # put all PRIMARY KEYS in a UNIQUE index - self.append("\tUNIQUE (%s)" % string.join([c.name for c in table.primary_key],', ')) - - self.append("\n)\n\n") - self.execute() - if hasattr(table, 'indexes'): - for index in table.indexes: - self.visit_index(index) + # this doesnt seem to be needed, although i suspect older versions of sqlite might still + # not directly support composite primary keys + #def visit_primary_key_constraint(self, constraint): + # if len(constraint) > 1: + # self.append(", \n") + # # put all PRIMARY KEYS in a UNIQUE index + # self.append("\tUNIQUE (%s)" % string.join([c.name for c in constraint],', ')) + # else: + # super(SQLiteSchemaGenerator, self).visit_primary_key_constraint(constraint) + dialect = SQLiteDialect poolclass = pool.SingletonThreadPool diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 1df2d30053..dcd023fe95 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -18,23 +18,24 @@ from sqlalchemy import sql, types, exceptions,util import sqlalchemy import copy, re, string -__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', +__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint', + 'PrimaryKeyConstraint', 'MetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] class SchemaItem(object): """base class for items that define a database schema.""" def _init_items(self, *args): + """initialize the list of child items for this SchemaItem""" for item in args: if item is not None: item._set_parent(self) def _set_parent(self, parent): - """a child item attaches itself to its parent via this method.""" + """associate with this SchemaItem's parent object.""" raise NotImplementedError() def __repr__(self): return "%s()" % self.__class__.__name__ def _derived_metadata(self): - """subclasses override this method to return a the MetaData - to which this item is bound""" + """return the the MetaData to which this item is bound""" return None def _get_engine(self): return self._derived_metadata().engine @@ -77,7 +78,7 @@ class TableSingleton(type): table = metadata.tables[key] if len(args): if redefine: - table.reload_values(*args) + table._reload_values(*args) elif not useexisting: raise exceptions.ArgumentError("Table '%s.%s' is already defined. specify 'redefine=True' to remap columns, or 'useexisting=True' to use the existing table" % (schema, name)) return table @@ -109,7 +110,9 @@ class Table(SchemaItem, sql.TableClause): __metaclass__ = TableSingleton def __init__(self, name, metadata, **kwargs): - """Table objects can be constructed directly. The init method is actually called via + """Construct a Table. + + Table objects can be constructed directly. The init method is actually called via the TableSingleton metaclass. Arguments are: name : the name of this table, exactly as it appears, or will appear, in the database. @@ -141,11 +144,23 @@ class Table(SchemaItem, sql.TableClause): super(Table, self).__init__(name) self._metadata = metadata self.schema = kwargs.pop('schema', None) + self.indexes = util.OrderedProperties() + self.constraints = [] + self.primary_key = PrimaryKeyConstraint() + if self.schema is not None: self.fullname = "%s.%s" % (self.schema, self.name) else: self.fullname = self.name self.kwargs = kwargs + + def _set_primary_key(self, pk): + if getattr(self, '_primary_key', None) in self.constraints: + self.constraints.remove(self._primary_key) + self._primary_key = pk + self.constraints.append(pk) + primary_key = property(lambda s:s._primary_key, _set_primary_key) + def _derived_metadata(self): return self._metadata def __repr__(self): @@ -158,8 +173,8 @@ class Table(SchemaItem, sql.TableClause): def __str__(self): return _get_table_key(self.name, self.schema) - def reload_values(self, *args): - """clears out the columns and other properties of this Table, and reloads them from the + def _reload_values(self, *args): + """clear out the columns and other properties of this Table, and reload them from the given argument list. This is used with the "redefine" keyword argument sent to the metaclass constructor.""" self._clear() @@ -216,8 +231,9 @@ class Table(SchemaItem, sql.TableClause): return index def deregister(self): - """removes this table from it's metadata. this does not - issue a SQL DROP statement.""" + """remove this table from it's owning metadata. + + this does not issue a SQL DROP statement.""" key = _get_table_key(self.name, self.schema) del self.metadata.tables[key] def create(self, connectable=None): @@ -232,7 +248,7 @@ class Table(SchemaItem, sql.TableClause): else: self.engine.drop(self) def tometadata(self, metadata, schema=None): - """returns a singleton instance of this Table with a different Schema""" + """return a copy of this Table associated with a different MetaData.""" try: if schema is None: schema = self.schema @@ -242,6 +258,8 @@ class Table(SchemaItem, sql.TableClause): args = [] for c in self.columns: args.append(c.copy()) + for c in self.constraints: + args.append(c.copy()) return Table(self.name, metadata, schema=schema, *args) class Column(SchemaItem, sql.ColumnClause): @@ -362,13 +380,9 @@ class Column(SchemaItem, sql.ColumnClause): self._init_items(*self.args) self.args = None - def copy(self): + def copy(self): """creates a copy of this Column, unitialized""" - if self.foreign_key is None: - fk = None - else: - fk = self.foreign_key.copy() - return Column(self.name, self.type, fk, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden) + return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, hidden = self.hidden) def _make_proxy(self, selectable, name = None): """creates a copy of this Column, initialized the way this Column is""" @@ -401,23 +415,33 @@ class Column(SchemaItem, sql.ColumnClause): class ForeignKey(SchemaItem): - """defines a ForeignKey constraint between two columns. ForeignKey is - specified as an argument to a Column object.""" - def __init__(self, column): - """Constructs a new ForeignKey object. "column" can be a schema.Column - object representing the relationship, or just its string name given as - "tablename.columnname". schema can be specified as - "schema.tablename.columnname" """ + """defines a column-level ForeignKey constraint between two columns. + + ForeignKey is specified as an argument to a Column object. + + One or more ForeignKey objects are used within a ForeignKeyConstraint + object which represents the table-level constraint definition.""" + def __init__(self, column, constraint=None): + """Construct a new ForeignKey object. + + "column" can be a schema.Column object representing the relationship, + or just its string name given as "tablename.columnname". schema can be + specified as "schema.tablename.columnname" + + "constraint" is the owning ForeignKeyConstraint object, if any. if not given, + then a ForeignKeyConstraint will be automatically created and added to the parent table. + """ if isinstance(column, unicode): column = str(column) self._colspec = column self._column = None - + self.constraint = constraint + def __repr__(self): return "ForeignKey(%s)" % repr(self._get_colspec()) def copy(self): - """produces a copy of this ForeignKey object.""" + """produce a copy of this ForeignKey object.""" return ForeignKey(self._get_colspec()) def _get_colspec(self): @@ -462,6 +486,7 @@ class ForeignKey(SchemaItem): self._column = table.c[colname] else: self._column = self._colspec + return self._column column = property(lambda s: s._init_column()) @@ -472,8 +497,14 @@ class ForeignKey(SchemaItem): def _set_parent(self, column): self.parent = column - # if a foreign key was already set up for this, replace it with - # this one, including removing from the parent + + if self.constraint is None and isinstance(self.parent.table, Table): + self.constraint = ForeignKeyConstraint([],[]) + self.parent.table.append_item(self.constraint) + self.constraint._append_fk(self) + + # if a foreign key was already set up for the parent column, replace it with + # this one if self.parent.foreign_key is not None: self.parent.table.foreign_keys.remove(self.parent.foreign_key) self.parent.foreign_key = self @@ -551,7 +582,81 @@ class Sequence(DefaultGenerator): """calls the visit_seauence method on the given visitor.""" return visitor.visit_sequence(self) - +class Constraint(SchemaItem): + """represents a table-level Constraint such as a composite primary key, foreign key, or unique constraint. + + Also follows list behavior with regards to the underlying set of columns.""" + def __init__(self, name=None): + self.name = name + self.columns = [] + def __contains__(self, x): + return x in self.columns + def __add__(self, other): + return self.columns + other + def __iter__(self): + return iter(self.columns) + def __len__(self): + return len(self.columns) + def __getitem__(self, index): + return self.columns[index] + def copy(self): + raise NotImplementedError() + +class ForeignKeyConstraint(Constraint): + """table-level foreign key constraint, represents a colleciton of ForeignKey objects.""" + def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None): + super(ForeignKeyConstraint, self).__init__(name) + self.__colnames = columns + self.__refcolnames = refcolumns + self.elements = [] + self.onupdate = onupdate + self.ondelete = ondelete + def _set_parent(self, table): + self.table = table + table.constraints.append(self) + for (c, r) in zip(self.__colnames, self.__refcolnames): + self.append(c,r) + def accept_schema_visitor(self, visitor): + visitor.visit_foreign_key_constraint(self) + def append(self, col, refcol): + fk = ForeignKey(refcol, constraint=self) + fk._set_parent(self.table.c[col]) + self._append_fk(fk) + def _append_fk(self, fk): + self.columns.append(self.table.c[fk.parent.name]) + self.elements.append(fk) + def copy(self): + return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec() for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete) + +class PrimaryKeyConstraint(Constraint): + def __init__(self, *columns, **kwargs): + super(PrimaryKeyConstraint, self).__init__(name=kwargs.pop('name', None)) + self.__colnames = list(columns) + def _set_parent(self, table): + table.primary_key = self + for c in self.__colnames: + self.append(table.c[c]) + def accept_schema_visitor(self, visitor): + visitor.visit_primary_key_constraint(self) + def append(self, col): + self.columns.append(col) + col.primary_key=True + def copy(self): + return PrimaryKeyConstraint(name=self.name, *[c.name for c in self]) + +class UniqueConstraint(Constraint): + def __init__(self, name=None, *columns): + super(Constraint, self).__init__(name) + self.__colnames = list(columns) + def _set_parent(self, table): + table.constraints.append(self) + for c in self.__colnames: + self.append(table.c[c]) + def append(self, col): + self.columns.append(col) + def accept_schema_visitor(self, visitor): + visitor.visit_unique_constraint(self) + class Index(SchemaItem): """Represents an index of columns from a database table """ @@ -746,7 +851,13 @@ class SchemaVisitor(sql.ClauseVisitor): def visit_sequence(self, sequence): """visit a Sequence.""" pass - + def visit_primary_key_constraint(self, constraint): + pass + def visit_foreign_key_constraint(self, constraint): + pass + def visit_unique_constraint(self, constraint): + pass + default_metadata = DynamicMetaData('default') diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 7b17927f08..8109d8cd52 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -1225,15 +1225,12 @@ class TableClause(FromClause): super(TableClause, self).__init__(name) self.name = self.fullname = name self._columns = util.OrderedProperties() - self._indexes = util.OrderedProperties() self._foreign_keys = [] self._primary_key = [] for c in columns: self.append_column(c) self._oid_column = ColumnClause('oid', self, hidden=True) - indexes = property(lambda s:s._indexes) - def named_with_column(self): return True def append_column(self, c): @@ -1250,16 +1247,11 @@ class TableClause(FromClause): for ci in c.orig_set: self._orig_cols[ci] = c return self._orig_cols - columns = property(lambda s:s._columns) - c = property(lambda s:s._columns) - primary_key = property(lambda s:s._primary_key) - foreign_keys = property(lambda s:s._foreign_keys) original_columns = property(_orig_columns) def _clear(self): """clears all attributes on this TableClause so that new items can be added again""" self.columns.clear() - self.indexes.clear() self.foreign_keys[:] = [] self.primary_key[:] = [] try: diff --git a/test/engine/reflection.py b/test/engine/reflection.py index f9fa4e40c5..ec59d652cd 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -29,8 +29,10 @@ class ReflectionTest(PersistTest): else: deftype2 = Integer defval2 = "15" - - users = Table('engine_users', testbase.db, + + meta = BoundMetaData(testbase.db) + + users = Table('engine_users', meta, Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20), nullable = False), Column('test1', CHAR(5), nullable = False), @@ -49,14 +51,13 @@ class ReflectionTest(PersistTest): mysql_engine='InnoDB' ) - addresses = Table('engine_email_addresses', testbase.db, + addresses = Table('engine_email_addresses', meta, Column('address_id', Integer, primary_key = True), Column('remote_user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(20)), mysql_engine='InnoDB' ) - # users.c.parent_user_id.set_foreign_key(ForeignKey(users.c.user_id)) users.create() @@ -119,28 +120,69 @@ class ReflectionTest(PersistTest): table.insert().execute({'multi_id':3,'multi_rev':3,'name':'row3', 'val':'value3'}) table.select().execute().fetchall() table.drop() - + + def testcompositefk(self): + meta = BoundMetaData(testbase.db) + table = Table( + 'multi', meta, + Column('multi_id', Integer, primary_key=True), + Column('multi_rev', Integer, primary_key=True), + Column('name', String(50), nullable=False), + Column('val', String(100)), + mysql_engine='InnoDB' + ) + table2 = Table('multi2', meta, + Column('id', Integer, primary_key=True), + Column('foo', Integer), + Column('bar', Integer), + Column('data', String(50)), + ForeignKeyConstraint(['foo', 'bar'], ['multi.multi_id', 'multi.multi_rev']), + mysql_engine='InnoDB' + ) + meta.create_all() + meta.clear() + + try: + table = Table('multi', meta, autoload=True) + table2 = Table('multi2', meta, autoload=True) + + print table + print table2 + j = join(table, table2) + print str(j.onclause) + self.assert_(and_(table.c.multi_id==table2.c.foo, table.c.multi_rev==table2.c.bar).compare(j.onclause)) + + finally: + meta.drop_all() + def testtoengine(self): meta = MetaData('md1') meta2 = MetaData('md2') table = Table('mytable', meta, - Column('myid', Integer, key = 'id'), - Column('name', String, key = 'name', nullable=False), - Column('description', String, key = 'description'), + Column('myid', Integer, primary_key=True), + Column('name', String, nullable=False), + Column('description', String(30)), ) - print repr(table) - - table2 = table.tometadata(meta2) + table2 = Table('othertable', meta, + Column('id', Integer, primary_key=True), + Column('myid', Integer, ForeignKey('mytable.myid')) + ) + - print repr(table2) + table_c = table.tometadata(meta2) + table2_c = table2.tometadata(meta2) + + assert table is not table_c + assert table_c.c.myid.primary_key + assert not table_c.c.name.nullable + assert table_c.c.description.nullable + assert table.primary_key is not table_c.primary_key + assert [x.name for x in table.primary_key] == [x.name for x in table_c.primary_key] + assert table2_c.c.myid.foreign_key.column is table_c.c.myid + assert table2_c.c.myid.foreign_key.column is not table.c.myid - assert table is not table2 - assert table2.c.id.nullable - assert not table2.c.name.nullable - assert table2.c.description.nullable - # mysql throws its own exception for no such table, resulting in # a sqlalchemy.SQLError instead of sqlalchemy.NoSuchTableError. # this could probably be fixed at some point. diff --git a/test/orm/objectstore.py b/test/orm/objectstore.py index 6d3a712762..237c4b5545 100644 --- a/test/orm/objectstore.py +++ b/test/orm/objectstore.py @@ -230,7 +230,7 @@ class PKTest(SessionTest): @testbase.unsupported('mssql') def setUpAll(self): SessionTest.setUpAll(self) - db.echo = False + #db.echo = False global table global table2 global table3 @@ -266,6 +266,8 @@ class PKTest(SessionTest): db.echo = testbase.echo SessionTest.tearDownAll(self) + # not support on sqlite since sqlite's auto-pk generation only works with + # single column primary keys @testbase.unsupported('sqlite', 'mssql') def testprimarykey(self): class Entry(object): @@ -279,6 +281,8 @@ class PKTest(SessionTest): ctx.current.clear() e2 = Entry.mapper.get((e.multi_id, 2)) self.assert_(e is not e2 and e._instance_key == e2._instance_key) + + # this one works with sqlite since we are manually setting up pk values @testbase.unsupported('mssql') def testmanualpk(self): class Entry(object): @@ -289,6 +293,7 @@ class PKTest(SessionTest): e.pk_col_2 = 'pk1_related' e.data = 'im the data' ctx.current.flush() + @testbase.unsupported('mssql') def testkeypks(self): import datetime diff --git a/test/sql/indexes.py b/test/sql/indexes.py index ec72beda39..e9af301de6 100644 --- a/test/sql/indexes.py +++ b/test/sql/indexes.py @@ -5,16 +5,33 @@ import sys class IndexTest(testbase.AssertMixin): def setUp(self): - global metadata - metadata = BoundMetaData(testbase.db) + global metadata + metadata = BoundMetaData(testbase.db) self.echo = testbase.db.echo self.logger = testbase.db.logger def tearDown(self): testbase.db.echo = self.echo testbase.db.logger = testbase.db.engine.logger = self.logger - metadata.drop_all() + metadata.drop_all() + def test_constraint(self): + employees = Table('employees', metadata, + Column('id', Integer), + Column('soc', String(40)), + Column('name', String(30)), + PrimaryKeyConstraint('id', 'soc') + ) + elements = Table('elements', metadata, + Column('id', Integer), + Column('stuff', String(30)), + Column('emp_id', Integer), + Column('emp_soc', String(40)), + PrimaryKeyConstraint('id'), + ForeignKeyConstraint(['emp_id', 'emp_soc'], ['employees.id', 'employees.soc']) + ) + metadata.create_all() + def test_index_create(self): employees = Table('employees', metadata, Column('id', Integer, primary_key=True), -- 2.47.2