From: Mike Bayer Date: Sat, 14 Oct 2006 21:58:04 +0000 (+0000) Subject: - a fair amount of cleanup to the schema package, removal of ambiguous X-Git-Tag: rel_0_3_0~51 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8340006dd7ed34cf32bbb7f856397d1c7f13d295;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - a fair amount of cleanup to the schema package, removal of ambiguous methods, methods that are no longer needed. slightly more constrained useage, greater emphasis on explicitness. - table_iterator signature fixup, includes fix for [ticket:288] - the "primary_key" attribute of Table and other selectables becomes a setlike ColumnCollection object; is no longer ordered or numerically indexed. a comparison clause between two pks that are derived from the same underlying tables (i.e. such as two Alias objects) can be generated via table1.primary_key==table2.primary_key - append_item() methods removed from Table and Column; preferably construct Table/Column/related objects inline, but if needed use append_column(), append_foreign_key(), append_constraint(), etc. - table.create() no longer returns the Table object, instead has no return value. the usual case is that tables are created via metadata, which is preferable since it will handle table dependencies. - added UniqueConstraint (goes at Table level), CheckConstraint (goes at Table or Column level) fixes [ticket:217] - index=False/unique=True on Column now creates a UniqueConstraint, index=True/unique=False creates a plain Index, index=True/unique=True on Column creates a unique Index. 'index' and 'unique' keyword arguments to column are now boolean only; for explcit names and groupings of indexes or unique constraints, use the UniqueConstraint/Index constructs explicitly. - relationship of Metadata/Table/SchemaGenerator/Dropper has been improved so that the schemavisitor receives the metadata object for greater control over groupings of creates/drops. - added "use_alter" argument to ForeignKey, ForeignKeyConstraint, but it doesnt do anything yet. will utilize new generator/dropper behavior to implement. --- diff --git a/CHANGES b/CHANGES index 721a5ef19c..b9ac78ee9a 100644 --- a/CHANGES +++ b/CHANGES @@ -32,6 +32,28 @@ - aliases do not use "AS" - correctly raises NoSuchTableError when reflecting non-existent table - Schema: + - a fair amount of cleanup to the schema package, removal of ambiguous + methods, methods that are no longer needed. slightly more constrained + useage, greater emphasis on explicitness + - the "primary_key" attribute of Table and other selectables becomes + a setlike ColumnCollection object; is no longer ordered or numerically + indexed. a comparison clause between two pks that are derived from the + same underlying tables (i.e. such as two Alias objects) can be generated + via table1.primary_key==table2.primary_key + - append_item() methods removed from Table and Column; preferably + construct Table/Column/related objects inline, but if needed use + append_column(), append_foreign_key(), append_constraint(), etc. + - table.create() no longer returns the Table object, instead has no + return value. the usual case is that tables are created via metadata, + which is preferable since it will handle table dependencies. + - added UniqueConstraint (goes at Table level), CheckConstraint + (goes at Table or Column level). + - index=False/unique=True on Column now creates a UniqueConstraint, + index=True/unique=False creates a plain Index, + index=True/unique=True on Column creates a unique Index. 'index' + and 'unique' keyword arguments to column are now boolean only; for + explcit names and groupings of indexes or unique constraints, use the + UniqueConstraint/Index constructs explicitly. - added autoincrement=True to Column; will disable schema generation of SERIAL/AUTO_INCREMENT/identity seq for postgres/mysql/mssql if explicitly set to False diff --git a/doc/build/content/metadata.txt b/doc/build/content/metadata.txt index bf5d78dce0..4cae58cb4b 100644 --- a/doc/build/content/metadata.txt +++ b/doc/build/content/metadata.txt @@ -470,41 +470,86 @@ A Sequence object can be defined on a Table that is then used for a non-sequence A sequence can also be specified with `optional=True` which indicates the Sequence should only be used on a database that requires an explicit sequence, and not those that supply some other method of providing integer values. At the moment, it essentially means "use this sequence only with Oracle and not Postgres". -### Defining Indexes {@name=indexes} +### Defining Constraints and Indexes {@name=constraints} -Indexes can be defined on table columns, including named indexes, non-unique or unique, multiple column. Indexes are included along with table create and drop statements. They are not used for any kind of run-time constraint checking; SQLAlchemy leaves that job to the expert on constraint checking, the database itself. +#### UNIQUE Constraint + +Unique constraints can be created anonymously on a single column using the `unique` keyword on `Column`. Explicitly named unique constraints and/or those with multiple columns are created via the `UniqueConstraint` table-level construct. {python} - boundmeta = BoundMetaData('postgres:///scott:tiger@localhost/test') - mytable = Table('mytable', boundmeta, - # define a unique index + meta = MetaData() + mytable = Table('mytable', meta, + + # per-column anonymous unique constraint Column('col1', Integer, unique=True), - # define a unique index with a specific name - Column('col2', Integer, unique='mytab_idx_1'), - - # define a non-unique index - Column('col3', Integer, index=True), + Column('col2', Integer), + Column('col3', Integer), - # define a non-unique index with a specific name - Column('col4', Integer, index='mytab_idx_2'), + # explicit/composite unique constraint. 'name' is optional. + UniqueConstraint('col2', 'col3', name='uix_1') + ) + +#### CHECK Constraint + +Check constraints can be named or unnamed and can be created at the Column or Table level, using the `CheckConstraint` construct. The text of the check constraint is passed directly through to the database, so there is limited "database independent" behavior. Column level check constraints generally should only refer to the column to which they are placed, while table level constraints can refer to any columns in the table. + +Note that some databases do not actively support check constraints such as MySQL and sqlite. + + {python} + meta = MetaData() + mytable = Table('mytable', meta, + + # per-column CHECK constraint + Column('col1', Integer, CheckConstraint('col1>5')), - # pass the same name to multiple columns to add them to the same index - Column('col5', Integer, index='mytab_idx_2'), + Column('col2', Integer), + Column('col3', Integer), - Column('col6', Integer), - Column('col7', Integer) - ) - - # create the table. all the indexes will be created along with it. - mytable.create() - - # indexes can also be specified standalone - i = Index('mytab_idx_3', mytable.c.col6, mytable.c.col7, unique=False) + # table level CHECK constraint. 'name' is optional. + CheckConstraint('col2 > col3 + 5', name='check1') + ) - # which can then be created separately (will also get created with table creates) +#### Indexes + +Indexes can be created anonymously (using an auto-generated name "ix_<column label>") for a single column using the inline `index` keyword on `Column`, which also modifies the usage of `unique` to apply the uniqueness to the index itself, instead of adding a separate UNIQUE constraint. For indexes with specific names or which encompass more than one column, use the `Index` construct, which requires a name. + +Note that the `Index` construct is created **externally** to the table which it corresponds, using `Column` objects and not strings. + + {python} + meta = MetaData() + mytable = Table('mytable', meta, + # an indexed column, with index "ix_mytable_col1" + Column('col1', Integer, index=True), + + # a uniquely indexed column with index "ix_mytable_col2" + Column('col2', Integer, index=True, unique=True), + + Column('col3', Integer), + Column('col4', Integer), + + Column('col5', Integer), + Column('col6', Integer), + ) + + # place an index on col3, col4 + Index('idx_col34', mytable.c.col3, mytable.c.col4) + + # place a unique index on col5, col6 + Index('myindex', mytable.c.col5, mytable.c.col6, unique=True) + +The `Index` objects will be created along with the CREATE statements for the table itself. An index can also be created on its own independently of the table: + + {python} + # create a table + sometable.create() + + # define an index + i = Index('someindex', sometable.c.col5) + + # create the index, will use the table's connectable, or specify the connectable keyword argument i.create() - + ### Adapting Tables to Alternate Metadata {@name=adapting} A `Table` object created against a specific `MetaData` object can be re-created against a new MetaData using the `tometadata` method: diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 2b0d7d17e5..208b2f603a 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -7,7 +7,7 @@ """defines ANSI SQL operations. Contains default implementations for the abstract objects in the sql module.""" -from sqlalchemy import schema, sql, engine, util +from sqlalchemy import schema, sql, engine, util, sql_util import sqlalchemy.engine.default as default import string, re, sets, weakref @@ -28,9 +28,6 @@ RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array', ' LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$') ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$') -def create_engine(): - return engine.ComposedSQLEngine(None, ANSIDialect()) - class ANSIDialect(default.DefaultDialect): def __init__(self, cache_identifiers=True, **kwargs): super(ANSIDialect,self).__init__(**kwargs) @@ -174,7 +171,7 @@ class ANSICompiler(sql.Compiled): if n is not None: self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n) elif len(column.table.primary_key) != 0: - self.strings[column] = self.preparer.format_column_with_table(column.table.primary_key[0]) + self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0]) else: self.strings[column] = None else: @@ -611,22 +608,30 @@ class ANSICompiler(sql.Compiled): class ANSISchemaGenerator(engine.SchemaIterator): - def __init__(self, engine, proxy, connection=None, checkfirst=False, **params): - super(ANSISchemaGenerator, self).__init__(engine, proxy, **params) + def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): + super(ANSISchemaGenerator, self).__init__(engine, proxy, **kwargs) self.checkfirst = checkfirst + self.tables = tables and util.Set(tables) or None self.connection = connection self.preparer = self.engine.dialect.preparer() - + self.dialect = self.engine.dialect + def get_column_specification(self, column, first_pk=False): raise NotImplementedError() - - def visit_table(self, table): - # the single whitespace before the "(" is significant - # as its MySQL's method of indicating a table name and not a reserved word. - # feel free to localize this logic to the mysql module - if self.checkfirst and self.engine.dialect.has_table(self.connection, table.name): - return + + def visit_metadata(self, metadata): + for table in metadata.table_iterator(reverse=False, tables=self.tables): + if self.checkfirst and self.dialect.has_table(self.connection, table.name): + continue + table.accept_schema_visitor(self, traverse=False) + def visit_table(self, table): + for column in table.columns: + if column.default is not None: + column.default.accept_schema_visitor(self, traverse=False) + #if column.onupdate is not None: + # column.onupdate.accept_schema_visitor(visitor, traverse=False) + self.append("\nCREATE TABLE " + self.preparer.format_table(table) + " (") separator = "\n" @@ -639,15 +644,17 @@ class ANSISchemaGenerator(engine.SchemaIterator): self.append("\t" + self.get_column_specification(column, first_pk=column.primary_key and not first_pk)) if column.primary_key: first_pk = True - + for constraint in column.constraints: + constraint.accept_schema_visitor(self, traverse=False) + for constraint in table.constraints: - constraint.accept_schema_visitor(self) + constraint.accept_schema_visitor(self, traverse=False) self.append("\n)%s\n\n" % self.post_create_table(table)) - self.execute() + self.execute() if hasattr(table, 'indexes'): for index in table.indexes: - self.visit_index(index) + index.accept_schema_visitor(self, traverse=False) def post_create_table(self, table): return '' @@ -662,10 +669,17 @@ class ANSISchemaGenerator(engine.SchemaIterator): return None def _compile(self, tocompile, parameters): + """compile the given string/parameters using this SchemaGenerator's dialect.""" compiler = self.engine.dialect.compiler(tocompile, parameters) compiler.compile() return compiler + def visit_check_constraint(self, constraint): + self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % constraint.name) + self.append(" CHECK (%s)" % constraint.sqltext) + def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: return @@ -688,6 +702,13 @@ class ANSISchemaGenerator(engine.SchemaIterator): if constraint.onupdate is not None: self.append(" ON UPDATE %s" % constraint.onupdate) + def visit_unique_constraint(self, constraint): + self.append(", \n\t") + if constraint.name is not None: + self.append("CONSTRAINT %s " % constraint.name) + self.append(" UNIQUE ") + self.append("(%s)" % (string.join([self.preparer.format_column(c) for c in constraint],', '))) + def visit_column(self, column): pass @@ -701,21 +722,29 @@ class ANSISchemaGenerator(engine.SchemaIterator): self.execute() class ANSISchemaDropper(engine.SchemaIterator): - def __init__(self, engine, proxy, connection=None, checkfirst=False, **params): - super(ANSISchemaDropper, self).__init__(engine, proxy, **params) + def __init__(self, engine, proxy, connection, checkfirst=False, tables=None, **kwargs): + super(ANSISchemaDropper, self).__init__(engine, proxy, **kwargs) self.checkfirst = checkfirst + self.tables = tables self.connection = connection self.preparer = self.engine.dialect.preparer() + self.dialect = self.engine.dialect + + def visit_metadata(self, metadata): + for table in metadata.table_iterator(reverse=True, tables=self.tables): + if self.checkfirst and not self.dialect.has_table(self.connection, table.name): + continue + table.accept_schema_visitor(self, traverse=False) def visit_index(self, index): self.append("\nDROP INDEX " + index.name) self.execute() def visit_table(self, table): - # NOTE: indexes on the table will be automatically dropped, so - # no need to drop them individually - if self.checkfirst and not self.engine.dialect.has_table(self.connection, table.name): - return + for column in table.columns: + if column.default is not None: + column.default.accept_schema_visitor(self, traverse=False) + self.append("\nDROP TABLE " + self.preparer.format_table(table)) self.execute() diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index fa090a89e5..f38a24b1f8 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -253,7 +253,7 @@ class FireBirdDialect(ansisql.ANSIDialect): # is it a primary key? kw['primary_key'] = name in pkfields - table.append_item(schema.Column(*args, **kw)) + table.append_column(schema.Column(*args, **kw)) row = c.fetchone() # get the foreign keys @@ -276,7 +276,7 @@ class FireBirdDialect(ansisql.ANSIDialect): fk[1].append(refspec) for name,value in fks.iteritems(): - table.append_item(schema.ForeignKeyConstraint(value[0], value[1], name=name)) + table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name)) def last_inserted_ids(self): diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index 291637e9e5..5a7369ccda 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -144,7 +144,7 @@ def reflecttable(connection, table, ischema_names): colargs= [] if default is not None: colargs.append(PassiveDefault(sql.text(default))) - table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs)) + table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) if not found_table: raise exceptions.NoSuchTableError(table.name) @@ -175,7 +175,7 @@ def reflecttable(connection, table, ischema_names): ) #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) if type=='PRIMARY KEY': - table.c[constrained_column]._set_primary_key() + table.primary_key.add(table.c[constrained_column]) elif type=='FOREIGN KEY': try: fk = fks[constraint_name] @@ -196,5 +196,5 @@ def reflecttable(connection, table, ischema_names): fk[1].append(refspec) for name, value in fks.iteritems(): - table.append_item(ForeignKeyConstraint(value[0], value[1], name=name)) + table.append_constraint(ForeignKeyConstraint(value[0], value[1], name=name)) diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 3d65abf0c5..d23c417306 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -446,7 +446,7 @@ class MSSQLDialect(ansisql.ANSIDialect): if default is not None: colargs.append(schema.PassiveDefault(sql.text(default))) - table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs)) + table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) if not found_table: raise exceptions.NoSuchTableError(table.name) @@ -478,7 +478,7 @@ class MSSQLDialect(ansisql.ANSIDialect): c = connection.execute(s) for row in c: if 'PRIMARY' in row[TC.c.constraint_type.name]: - table.c[row[0]]._set_primary_key() + table.primary_key.add(table.c[row[0]]) # Foreign key constraints @@ -498,13 +498,13 @@ class MSSQLDialect(ansisql.ANSIDialect): scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r if rfknm != fknm: if fknm: - table.append_item(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) + table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) fknm, scols, rcols = (rfknm, [], []) if (not scol in scols): scols.append(scol) if (not (rschema, rtbl, rcol) in rcols): rcols.append((rschema, rtbl, rcol)) if fknm and scols: - table.append_item(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) + table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm)) diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 4443814c56..2fa7e9227f 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -353,7 +353,7 @@ class MySQLDialect(ansisql.ANSIDialect): colargs= [] if default: colargs.append(schema.PassiveDefault(sql.text(default))) - table.append_item(schema.Column(name, coltype, *colargs, + table.append_column(schema.Column(name, coltype, *colargs, **dict(primary_key=primary_key, nullable=nullable, ))) @@ -397,7 +397,7 @@ class MySQLDialect(ansisql.ANSIDialect): refcols = [match.group('reftable') + "." + x for x in re.findall(r'`(.+?)`', match.group('refcols'))] schema.Table(match.group('reftable'), table.metadata, autoload=True, autoload_with=connection) constraint = schema.ForeignKeyConstraint(columns, refcols, name=match.group('name')) - table.append_item(constraint) + table.append_constraint(constraint) return tabletype diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index db82e3dea8..b9aa096952 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -256,7 +256,7 @@ class OracleDialect(ansisql.ANSIDialect): if (name.upper() == name): name = name.lower() - table.append_item (schema.Column(name, coltype, nullable=nullable, *colargs)) + table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) c = connection.execute(constraintSQL, {'table_name' : table.name.upper(), 'owner' : owner}) @@ -268,7 +268,7 @@ class OracleDialect(ansisql.ANSIDialect): #print "ROW:" , row (cons_name, cons_type, local_column, remote_table, remote_column) = row if cons_type == 'P': - table.c[local_column]._set_primary_key() + table.primary_key.add(table.c[local_column]) elif cons_type == 'R': try: fk = fks[cons_name] @@ -283,7 +283,7 @@ class OracleDialect(ansisql.ANSIDialect): fk[1].append(refspec) for name, value in fks.iteritems(): - table.append_item(schema.ForeignKeyConstraint(value[0], value[1], name=name)) + table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name)) def do_executemany(self, c, statement, parameters, context=None): rowcount = 0 diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index a28a22cd62..dad2d3bff2 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -370,7 +370,7 @@ class PGDialect(ansisql.ANSIDialect): colargs= [] if default is not None: colargs.append(PassiveDefault(sql.text(default))) - table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs)) + table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) if not found_table: @@ -392,7 +392,7 @@ class PGDialect(ansisql.ANSIDialect): if row is None: break pk = row[0] - table.c[pk]._set_primary_key() + table.primary_key.add(table.c[pk]) # Foreign keys FK_SQL = """ @@ -443,7 +443,7 @@ class PGDialect(ansisql.ANSIDialect): for column in referred_columns: refspec.append(".".join([referred_table, column])) - table.append_item(ForeignKeyConstraint(constrained_columns, refspec, row['conname'])) + table.append_constraint(ForeignKeyConstraint(constrained_columns, refspec, row['conname'])) class PGCompiler(ansisql.ANSICompiler): @@ -502,13 +502,13 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): return colspec def visit_sequence(self, sequence): - if not sequence.optional and not self.engine.dialect.has_sequence(self.connection, sequence.name): + if not sequence.optional and (not self.dialect.has_sequence(self.connection, sequence.name)): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() class PGSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): - if not sequence.optional and self.engine.dialect.has_sequence(self.connection, sequence.name): + if not sequence.optional and (self.dialect.has_sequence(self.connection, sequence.name)): self.append("DROP SEQUENCE %s" % sequence.name) self.execute() diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 80d5a7d2af..90cd66dd3c 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -199,7 +199,7 @@ class SQLiteDialect(ansisql.ANSIDialect): colargs= [] if has_default: colargs.append(PassiveDefault('?')) - table.append_item(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs)) + table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs)) if not found_table: raise exceptions.NoSuchTableError(table.name) @@ -228,7 +228,7 @@ class SQLiteDialect(ansisql.ANSIDialect): if refspec not in fk[1]: fk[1].append(refspec) for name, value in fks.iteritems(): - table.append_item(schema.ForeignKeyConstraint(value[0], value[1])) + table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1])) # check for UNIQUE indexes c = connection.execute("PRAGMA index_list(" + table.name + ")", {}) unique_indexes = [] @@ -250,8 +250,7 @@ class SQLiteDialect(ansisql.ANSIDialect): col = table.columns[row[2]] # unique index that includes the pk is considered a multiple primary key for col in cols: - column = table.columns[col] - table.columns[col]._set_primary_key() + table.primary_key.add(table.columns[col]) class SQLiteCompiler(ansisql.ANSICompiler): def visit_cast(self, cast): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 6d0cf2eb36..4ba5e11158 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -421,7 +421,7 @@ class ComposedSQLEngine(sql.Engine, Connectable): else: conn = connection try: - element.accept_schema_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs)) + element.accept_schema_visitor(visitorcallable(self, conn.proxy, connection=conn, **kwargs), traverse=False) finally: if connection is None: conn.close() diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 5afd3e1b67..462c5e7991 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -366,10 +366,8 @@ class Query(object): if not distinct and order_by: s2.order_by(*util.to_list(order_by)) s3 = s2.alias('tbl_row_count') - crit = [] - for i in range(0, len(self.table.primary_key)): - crit.append(s3.primary_key[i] == self.table.primary_key[i]) - statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True, for_update=for_update) + crit = s3.primary_key==self.table.primary_key + statement = sql.select([], crit, from_obj=[self.table], use_labels=True, for_update=for_update) # now for the order by, convert the columns to their corresponding columns # in the "rowcount" query, and tack that new order by onto the "rowcount" query if order_by: diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 05753e424f..5728d7c375 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -19,7 +19,7 @@ import sqlalchemy import copy, re, string __all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', 'ForeignKeyConstraint', - 'PrimaryKeyConstraint', + 'PrimaryKeyConstraint', 'CheckConstraint', 'UniqueConstraint', 'MetaData', 'BoundMetaData', 'DynamicMetaData', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] class SchemaItem(object): @@ -99,36 +99,33 @@ def _get_table_key(name, schema): class TableSingleton(type): """a metaclass used by the Table object to provide singleton behavior.""" def __call__(self, name, metadata, *args, **kwargs): + if isinstance(metadata, sql.Engine): + # backwards compatibility - get a BoundSchema associated with the engine + engine = metadata + if not hasattr(engine, '_legacy_metadata'): + engine._legacy_metadata = BoundMetaData(engine) + metadata = engine._legacy_metadata + elif metadata is not None and not isinstance(metadata, MetaData): + # they left MetaData out, so assume its another SchemaItem, add it to *args + args = list(args) + args.insert(0, metadata) + metadata = None + + if metadata is None: + metadata = default_metadata + + name = str(name) # in case of incoming unicode + schema = kwargs.get('schema', None) + autoload = kwargs.pop('autoload', False) + autoload_with = kwargs.pop('autoload_with', False) + mustexist = kwargs.pop('mustexist', False) + useexisting = kwargs.pop('useexisting', False) + key = _get_table_key(name, schema) try: - if isinstance(metadata, sql.Engine): - # backwards compatibility - get a BoundSchema associated with the engine - engine = metadata - if not hasattr(engine, '_legacy_metadata'): - engine._legacy_metadata = BoundMetaData(engine) - metadata = engine._legacy_metadata - elif metadata is not None and not isinstance(metadata, MetaData): - # they left MetaData out, so assume its another SchemaItem, add it to *args - args = list(args) - args.insert(0, metadata) - metadata = None - - if metadata is None: - metadata = default_metadata - - name = str(name) # in case of incoming unicode - schema = kwargs.get('schema', None) - autoload = kwargs.pop('autoload', False) - autoload_with = kwargs.pop('autoload_with', False) - redefine = kwargs.pop('redefine', False) - mustexist = kwargs.pop('mustexist', False) - useexisting = kwargs.pop('useexisting', False) - key = _get_table_key(name, schema) table = metadata.tables[key] if len(args): - if redefine: - table._reload_values(*args) - elif not useexisting: - raise exceptions.ArgumentError("Table '%s.%s' is already defined. specify 'redefine=True' to remap columns, or 'useexisting=True' to use the existing table" % (schema, name)) + if not useexisting: + raise exceptions.ArgumentError("Table '%s.%s' is already defined for this MetaData instance." % (schema, name)) return table except KeyError: if mustexist: @@ -145,7 +142,7 @@ class TableSingleton(type): else: metadata.get_engine().reflecttable(table) except exceptions.NoSuchTableError: - table.deregister() + del metadata.tables[key] raise # initialize all the column, etc. objects. done after # reflection to allow user-overrides @@ -210,8 +207,8 @@ class Table(SchemaItem, sql.TableClause): super(Table, self).__init__(name) self._metadata = metadata self.schema = kwargs.pop('schema', None) - self.indexes = util.OrderedProperties() - self.constraints = [] + self.indexes = util.Set() + self.constraints = util.Set() self.primary_key = PrimaryKeyConstraint() self.quote = kwargs.get('quote', False) self.quote_schema = kwargs.get('quote_schema', False) @@ -237,7 +234,7 @@ class Table(SchemaItem, sql.TableClause): if getattr(self, '_primary_key', None) in self.constraints: self.constraints.remove(self._primary_key) self._primary_key = pk - self.constraints.append(pk) + self.constraints.add(pk) primary_key = property(lambda s:s._primary_key, _set_primary_key) def _derived_metadata(self): @@ -251,93 +248,45 @@ class Table(SchemaItem, sql.TableClause): def __str__(self): return _get_table_key(self.name, self.schema) - - def _reload_values(self, *args): - """clear out the columns and other properties of this Table, and reload them from the - given argument list. This is used with the "redefine" keyword argument sent to the - metaclass constructor.""" - self._clear() - - self._init_items(*args) - def append_item(self, item): - """appends a Column item or other schema item to this Table.""" - self._init_items(item) - def append_column(self, column): - if not column.hidden: - self._columns[column.key] = column - if column.primary_key: - self.primary_key.append(column) - column.table = self + """append a Column to this Table.""" + column._set_parent(self) + def append_constraint(self, constraint): + """append a Constraint to this Table.""" + constraint._set_parent(self) - def append_index(self, index): - self.indexes[index.name] = index - def _get_parent(self): return self._metadata def _set_parent(self, metadata): metadata.tables[_get_table_key(self.name, self.schema)] = self self._metadata = metadata - def accept_schema_visitor(self, visitor): - """traverses the given visitor across the Column objects inside this Table, - then calls the visit_table method on the visitor.""" - for c in self.columns: - c.accept_schema_visitor(visitor) + def accept_schema_visitor(self, visitor, traverse=True): + if traverse: + for c in self.columns: + c.accept_schema_visitor(visitor, True) return visitor.visit_table(self) - def append_index_column(self, column, index=None, unique=None): - """Add an index or a column to an existing index of the same name. - """ - if index is not None and unique is not None: - raise ValueError("index and unique may not both be specified") - if index: - if index is True: - name = 'ix_%s' % column._label - else: - name = index - elif unique: - if unique is True: - name = 'ux_%s' % column._label - else: - name = unique - # find this index in self.indexes - # add this column to it if found - # otherwise create new - try: - index = self.indexes[name] - index.append_column(column) - except KeyError: - index = Index(name, column, unique=unique) - return index - - def deregister(self): - """remove this table from it's owning metadata. - - this does not issue a SQL DROP statement.""" - key = _get_table_key(self.name, self.schema) - del self.metadata.tables[key] - - def exists(self, engine=None): - if engine is None: - engine = self.get_engine() + def exists(self, connectable=None): + """return True if this table exists.""" + if connectable is None: + connectable = self.get_engine() def do(conn): e = conn.engine return e.dialect.has_table(conn, self.name) - return engine.run_callable(do) + return connectable.run_callable(do) def create(self, connectable=None, checkfirst=False): - if connectable is not None: - connectable.create(self, checkfirst=checkfirst) - else: - self.get_engine().create(self, checkfirst=checkfirst) - return self + """issue a CREATE statement for this table. + + see also metadata.create_all().""" + self.metadata.create_all(connectable=connectable, checkfirst=checkfirst, tables=[self]) def drop(self, connectable=None, checkfirst=False): - if connectable is not None: - connectable.drop(self, checkfirst=checkfirst) - else: - self.get_engine().drop(self, checkfirst=checkfirst) + """issue a DROP statement for this table. + + see also metadata.drop_all().""" + self.metadata.drop_all(connectable=connectable, checkfirst=checkfirst, tables=[self]) def tometadata(self, metadata, schema=None): """return a copy of this Table associated with a different MetaData.""" try: @@ -389,17 +338,16 @@ class Column(SchemaItem, sql.ColumnClause): table's list of columns. Used for the "oid" column, which generally isnt in column lists. - index=None : True or index name. Indicates that this column is - indexed. Pass true to autogenerate the index name. Pass a string to - specify the index name. Multiple columns that specify the same index - name will all be included in the index, in the order of their - creation. + index=False : Indicates that this column is + indexed. The name of the index is autogenerated. + to specify indexes with explicit names or indexes that contain multiple + columns, use the Index construct instead. - unique=None : True or index name. Indicates that this column is - indexed in a unique index . Pass true to autogenerate the index - name. Pass a string to specify the index name. Multiple columns that - specify the same index name will all be included in the index, in the - order of their creation. + unique=False : Indicates that this column + contains a unique constraint, or if index=True as well, indicates + that the Index should be created with the unique flag. + To specify multiple columns in the constraint/index or to specify an + explicit name, use the UniqueConstraint or Index constructs instead. autoincrement=True : Indicates that integer-based primary key columns should have autoincrementing behavior, if supported by the underlying database. This will affect CREATE TABLE statements such that they will @@ -430,9 +378,8 @@ class Column(SchemaItem, sql.ColumnClause): self._set_casing_strategy(name, kwargs) self.onupdate = kwargs.pop('onupdate', None) self.autoincrement = kwargs.pop('autoincrement', True) + self.constraints = util.Set() self.__originating_column = self - if self.index is not None and self.unique is not None: - raise exceptions.ArgumentError("Column may not define both index and unique") self._foreign_keys = util.Set() if len(kwargs): raise exceptions.ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys())) @@ -455,7 +402,10 @@ class Column(SchemaItem, sql.ColumnClause): return self.table.metadata def _get_engine(self): return self.table.engine - + + def append_foreign_key(self, fk): + fk._set_parent(self) + def __repr__(self): return "Column(%s)" % string.join( [repr(self.name)] + [repr(self.type)] + @@ -463,33 +413,33 @@ class Column(SchemaItem, sql.ColumnClause): ["%s=%s" % (k, repr(getattr(self, k))) for k in ['key', 'primary_key', 'nullable', 'hidden', 'default', 'onupdate']] , ',') - def append_item(self, item): - self._init_items(item) - - def _set_primary_key(self): - if self.primary_key: - return - self.primary_key = True - self.nullable = False - self.table.primary_key.append(self) - def _get_parent(self): return self.table + def _set_parent(self, table): if getattr(self, 'table', None) is not None: raise exceptions.ArgumentError("this Column already has a table!") - table.append_column(self) - if self.index or self.unique: - table.append_index_column(self, index=self.index, - unique=self.unique) - + if not self.hidden: + table._columns.add(self) + if self.primary_key: + table.primary_key.add(self) + self.table = table + + if self.index: + if isinstance(self.index, str): + raise exceptions.ArgumentError("The 'index' keyword argument on Column is boolean only. To create indexes with a specific name, append an explicit Index object to the Table's list of elements.") + Index('ix_%s' % self._label, self, unique=self.unique) + elif self.unique: + if isinstance(self.unique, str): + raise exceptions.ArgumentError("The 'unique' keyword argument on Column is boolean only. To create unique constraints or indexes with a specific name, append an explicit UniqueConstraint or Index object to the Table's list of elements.") + table.append_constraint(UniqueConstraint(self.key)) + + toinit = list(self.args) if self.default is not None: - self.default = ColumnDefault(self.default) - self._init_items(self.default) + toinit.append(ColumnDefault(self.default)) if self.onupdate is not None: - self.onupdate = ColumnDefault(self.onupdate, for_update=True) - self._init_items(self.onupdate) - self._init_items(*self.args) + toinit.append(ColumnDefault(self.onupdate, for_update=True)) + self._init_items(*toinit) self.args = None def copy(self): @@ -507,9 +457,9 @@ class Column(SchemaItem, sql.ColumnClause): c.orig_set = self.orig_set c.__originating_column = self.__originating_column if not c.hidden: - selectable.columns[c.key] = c + selectable.columns.add(c) if self.primary_key: - selectable.primary_key.append(c) + selectable.primary_key.add(c) [c._init_items(f) for f in fk] return c @@ -519,15 +469,18 @@ class Column(SchemaItem, sql.ColumnClause): return self.__originating_column._get_case_sensitive() case_sensitive = property(_case_sens) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """traverses the given visitor to this Column's default and foreign key object, then calls visit_column on the visitor.""" - if self.default is not None: - self.default.accept_schema_visitor(visitor) - if self.onupdate is not None: - self.onupdate.accept_schema_visitor(visitor) - for f in self.foreign_keys: - f.accept_schema_visitor(visitor) + if traverse: + if self.default is not None: + self.default.accept_schema_visitor(visitor, traverse=True) + if self.onupdate is not None: + self.onupdate.accept_schema_visitor(visitor, traverse=True) + for f in self.foreign_keys: + f.accept_schema_visitor(visitor, traverse=True) + for constraint in self.constraints: + constraint.accept_schema_visitor(visitor, traverse=True) visitor.visit_column(self) @@ -538,7 +491,7 @@ class ForeignKey(SchemaItem): One or more ForeignKey objects are used within a ForeignKeyConstraint object which represents the table-level constraint definition.""" - def __init__(self, column, constraint=None): + def __init__(self, column, constraint=None, use_alter=False): """Construct a new ForeignKey object. "column" can be a schema.Column object representing the relationship, @@ -553,6 +506,7 @@ class ForeignKey(SchemaItem): self._colspec = column self._column = None self.constraint = constraint + self.use_alter = use_alter def __repr__(self): return "ForeignKey(%s)" % repr(self._get_colspec()) @@ -611,7 +565,7 @@ class ForeignKey(SchemaItem): column = property(lambda s: s._init_column()) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """calls the visit_foreign_key method on the given visitor.""" visitor.visit_foreign_key(self) @@ -621,17 +575,13 @@ class ForeignKey(SchemaItem): self.parent = column if self.constraint is None and isinstance(self.parent.table, Table): - self.constraint = ForeignKeyConstraint([],[]) - self.parent.table.append_item(self.constraint) + self.constraint = ForeignKeyConstraint([],[], use_alter=self.use_alter) + self.parent.table.append_constraint(self.constraint) self.constraint._append_fk(self) - # if a foreign key was already set up for the parent column, replace it with - # this one - #if self.parent.foreign_key is not None: - # self.parent.table.foreign_keys.remove(self.parent.foreign_key) - #self.parent.foreign_key = self self.parent.foreign_keys.add(self) self.parent.table.foreign_keys.add(self) + class DefaultGenerator(SchemaItem): """Base class for column "default" values.""" def __init__(self, for_update=False, metadata=None): @@ -661,7 +611,7 @@ class PassiveDefault(DefaultGenerator): def __init__(self, arg, **kwargs): super(PassiveDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): return visitor.visit_passive_default(self) def __repr__(self): return "PassiveDefault(%s)" % repr(self.arg) @@ -672,7 +622,7 @@ class ColumnDefault(DefaultGenerator): def __init__(self, arg, **kwargs): super(ColumnDefault, self).__init__(**kwargs) self.arg = arg - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """calls the visit_column_default method on the given visitor.""" if self.for_update: return visitor.visit_column_onupdate(self) @@ -704,57 +654,66 @@ class Sequence(DefaultGenerator): return self def drop(self): self.get_engine().drop(self) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): """calls the visit_seauence method on the given visitor.""" return visitor.visit_sequence(self) class Constraint(SchemaItem): """represents a table-level Constraint such as a composite primary key, foreign key, or unique constraint. - Also follows list behavior with regards to the underlying set of columns.""" + Implements a hybrid of dict/setlike behavior with regards to the list of underying columns""" def __init__(self, name=None): self.name = name - self.columns = [] + self.columns = sql.ColumnCollection() def __contains__(self, x): return x in self.columns + def keys(self): + return self.columns.keys() def __add__(self, other): return self.columns + other def __iter__(self): return iter(self.columns) def __len__(self): return len(self.columns) - def __getitem__(self, index): - return self.columns[index] - def __setitem__(self, index, item): - self.columns[index] = item def copy(self): raise NotImplementedError() def _get_parent(self): return getattr(self, 'table', None) - + +class CheckConstraint(Constraint): + def __init__(self, sqltext, name=None): + super(CheckConstraint, self).__init__(name) + self.sqltext = sqltext + def accept_schema_visitor(self, visitor, traverse=True): + visitor.visit_check_constraint(self) + def _set_parent(self, parent): + self.parent = parent + parent.constraints.add(self) + class ForeignKeyConstraint(Constraint): """table-level foreign key constraint, represents a colleciton of ForeignKey objects.""" - def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None): + def __init__(self, columns, refcolumns, name=None, onupdate=None, ondelete=None, use_alter=False): super(ForeignKeyConstraint, self).__init__(name) self.__colnames = columns self.__refcolnames = refcolumns - self.elements = [] + self.elements = util.Set() self.onupdate = onupdate self.ondelete = ondelete + self.use_alter = use_alter def _set_parent(self, table): self.table = table - table.constraints.append(self) + table.constraints.add(self) for (c, r) in zip(self.__colnames, self.__refcolnames): - self.append(c,r) - def accept_schema_visitor(self, visitor): + self.append_element(c,r) + def accept_schema_visitor(self, visitor, traverse=True): visitor.visit_foreign_key_constraint(self) - def append(self, col, refcol): + def append_element(self, col, refcol): fk = ForeignKey(refcol, constraint=self) fk._set_parent(self.table.c[col]) self._append_fk(fk) def _append_fk(self, fk): - self.columns.append(self.table.c[fk.parent.key]) - self.elements.append(fk) + self.columns.add(self.table.c[fk.parent.key]) + self.elements.add(fk) def copy(self): return ForeignKeyConstraint([x.parent.name for x in self.elements], [x._get_colspec() for x in self.elements], name=self.name, onupdate=self.onupdate, ondelete=self.ondelete) @@ -766,37 +725,37 @@ class PrimaryKeyConstraint(Constraint): self.table = table table.primary_key = self for c in self.__colnames: - self.append(table.c[c]) - def accept_schema_visitor(self, visitor): + self.append_column(table.c[c]) + def accept_schema_visitor(self, visitor, traverse=True): visitor.visit_primary_key_constraint(self) - def append(self, col): - # TODO: change "columns" to a key-sensitive set ? - for c in self.columns: - if c.key == col.key: - self.columns.remove(c) - self.columns.append(col) + def add(self, col): + self.append_column(col) + def append_column(self, col): + self.columns.add(col) col.primary_key=True def copy(self): return PrimaryKeyConstraint(name=self.name, *[c.key for c in self]) - + def __eq__(self, other): + return self.columns == other + class UniqueConstraint(Constraint): - def __init__(self, name=None, *columns): - super(Constraint, self).__init__(name) + def __init__(self, *columns, **kwargs): + super(UniqueConstraint, self).__init__(name=kwargs.pop('name', None)) self.__colnames = list(columns) def _set_parent(self, table): self.table = table - table.constraints.append(self) + table.constraints.add(self) for c in self.__colnames: - self.append(table.c[c]) - def append(self, col): - self.columns.append(col) - def accept_schema_visitor(self, visitor): + self.append_column(table.c[c]) + def append_column(self, col): + self.columns.add(col) + def accept_schema_visitor(self, visitor, traverse=True): visitor.visit_unique_constraint(self) class Index(SchemaItem): """Represents an index of columns from a database table """ - def __init__(self, name, *columns, **kw): + def __init__(self, name, *columns, **kwargs): """Constructs an index object. Arguments are: name : the name of the index @@ -811,7 +770,7 @@ class Index(SchemaItem): self.name = name self.columns = [] self.table = None - self.unique = kw.pop('unique', False) + self.unique = kwargs.pop('unique', False) self._init_items(*columns) def _derived_metadata(self): @@ -821,12 +780,15 @@ class Index(SchemaItem): self.append_column(column) def _get_parent(self): return self.table + def _set_parent(self, table): + self.table = table + table.indexes.add(self) + def append_column(self, column): # make sure all columns are from the same table # and no column is repeated if self.table is None: - self.table = column.table - self.table.append_index(self) + self._set_parent(column.table) elif column.table != self.table: # all columns muse be from same table raise exceptions.ArgumentError("All index columns must be from same table. " @@ -850,7 +812,7 @@ class Index(SchemaItem): connectable.drop(self) else: self.get_engine().drop(self) - def accept_schema_visitor(self, visitor): + def accept_schema_visitor(self, visitor, traverse=True): visitor.visit_index(self) def __str__(self): return repr(self) @@ -863,7 +825,6 @@ class Index(SchemaItem): class MetaData(SchemaItem): """represents a collection of Tables and their associated schema constructs.""" def __init__(self, name=None, **kwargs): - # a dictionary that stores Table objects keyed off their name (and possibly schema name) self.tables = {} self.name = name self._set_casing_strategy(name, kwargs) @@ -871,11 +832,18 @@ class MetaData(SchemaItem): return False def clear(self): self.tables.clear() - def table_iterator(self, reverse=True): - return self._sort_tables(self.tables.values(), reverse=reverse) + + def table_iterator(self, reverse=True, tables=None): + import sqlalchemy.sql_util + if tables is None: + tables = self.tables.values() + else: + tables = util.Set(tables).intersection(self.tables.values()) + sorter = sqlalchemy.sql_util.TableCollection(list(tables)) + return iter(sorter.sort(reverse=reverse)) def _get_parent(self): return None - def create_all(self, connectable=None, tables=None, engine=None): + def create_all(self, connectable=None, tables=None, checkfirst=True): """create all tables stored in this metadata. This will conditionally create tables depending on if they do not yet @@ -884,28 +852,13 @@ class MetaData(SchemaItem): connectable - a Connectable used to access the database; or use the engine bound to this MetaData. - tables - optional list of tables to create - - engine - deprecated argument.""" - if not tables: - tables = self.tables.values() - - if connectable is None: - connectable = engine - + tables - optional list of tables, which is a subset of the total + tables in the MetaData (others are ignored)""" if connectable is None: connectable = self.get_engine() - - def do(conn): - e = conn.engine - ts = self._sort_tables( tables ) - for table in ts: - if e.dialect.has_table(conn, table.name): - continue - conn.create(table) - connectable.run_callable(do) + connectable.create(self, checkfirst=checkfirst, tables=tables) - def drop_all(self, connectable=None, tables=None, engine=None): + def drop_all(self, connectable=None, tables=None, checkfirst=True): """drop all tables stored in this metadata. This will conditionally drop tables depending on if they currently @@ -914,33 +867,17 @@ class MetaData(SchemaItem): connectable - a Connectable used to access the database; or use the engine bound to this MetaData. - tables - optional list of tables to drop - - engine - deprecated argument.""" - if not tables: - tables = self.tables.values() - - if connectable is None: - connectable = engine - + tables - optional list of tables, which is a subset of the total + tables in the MetaData (others are ignored) + """ if connectable is None: connectable = self.get_engine() - - def do(conn): - e = conn.engine - ts = self._sort_tables( tables, reverse=True ) - for table in ts: - if e.dialect.has_table(conn, table.name): - conn.drop(table) - connectable.run_callable(do) + connectable.drop(self, checkfirst=checkfirst, tables=tables) - def _sort_tables(self, tables, reverse=False): - import sqlalchemy.sql_util - sorter = sqlalchemy.sql_util.TableCollection() - for t in tables: - sorter.add(t) - return sorter.sort(reverse=reverse) - + + def accept_schema_visitor(self, visitor, traverse=True): + visitor.visit_metadata(self) + def _derived_metadata(self): return self def _get_engine(self): @@ -1029,6 +966,8 @@ class SchemaVisitor(sql.ClauseVisitor): pass def visit_unique_constraint(self, constraint): pass + def visit_check_constraint(self, constraint): + pass default_metadata = DynamicMetaData('default') diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index c113edaa32..6f51ccbe99 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -658,6 +658,17 @@ class ColumnElement(Selectable, CompareMixin): else: return self +class ColumnCollection(util.OrderedProperties): + def add(self, column): + self[column.key] = column + def __eq__(self, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c==local) + return and_(*l) + class FromClause(Selectable): """represents an element that can be used within the FROM clause of a SELECT statement.""" def __init__(self, name=None): @@ -671,7 +682,7 @@ class FromClause(Selectable): visitor.visit_fromclause(self) def count(self, whereclause=None, **params): if len(self.primary_key): - col = self.primary_key[0] + col = list(self.primary_key)[0] else: col = list(self.columns)[0] return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) @@ -735,8 +746,8 @@ class FromClause(Selectable): if hasattr(self, '_columns'): # TODO: put a mutex here ? this is a key place for threading probs return - self._columns = util.OrderedProperties() - self._primary_key = [] + self._columns = ColumnCollection() + self._primary_key = ColumnCollection() self._foreign_keys = util.Set() self._orig_cols = {} export = self._exportable_columns() @@ -1082,7 +1093,7 @@ class Join(FromClause): def _proxy_column(self, column): self._columns[column._label] = column if column.primary_key: - self._primary_key.append(column) + self._primary_key.add(column) for f in column.foreign_keys: self._foreign_keys.add(f) return column @@ -1257,9 +1268,9 @@ class TableClause(FromClause): def __init__(self, name, *columns): super(TableClause, self).__init__(name) self.name = self.fullname = name - self._columns = util.OrderedProperties() + self._columns = ColumnCollection() self._foreign_keys = util.Set() - self._primary_key = [] + self._primary_key = util.Set() for c in columns: self.append_column(c) self._oid_column = ColumnClause('oid', self, hidden=True) @@ -1282,16 +1293,6 @@ class TableClause(FromClause): return self._orig_cols original_columns = property(_orig_columns) - def _clear(self): - """clears all attributes on this TableClause so that new items can be added again""" - self.columns.clear() - self.foreign_keys[:] = [] - self.primary_key[:] = [] - try: - delattr(self, '_orig_cols') - except AttributeError: - pass - def accept_visitor(self, visitor): visitor.visit_table(self) def _exportable_columns(self): @@ -1305,7 +1306,7 @@ class TableClause(FromClause): data[self] = self def count(self, whereclause=None, **params): if len(self.primary_key): - col = self.primary_key[0] + col = list(self.primary_key)[0] else: col = list(self.columns)[0] return select([func.count(col).label('tbl_row_count')], whereclause, from_obj=[self], **params) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 5f243ae048..d5c6a3b92b 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -87,6 +87,8 @@ class OrderedProperties(object): return len(self.__data) def __iter__(self): return self.__data.itervalues() + def __add__(self, other): + return list(self) + list(other) def __setitem__(self, key, object): self.__data[key] = object def __getitem__(self, key): diff --git a/test/engine/reflection.py b/test/engine/reflection.py index 45010111e4..469aab20eb 100644 --- a/test/engine/reflection.py +++ b/test/engine/reflection.py @@ -59,8 +59,6 @@ class ReflectionTest(PersistTest): mysql_engine='InnoDB' ) -# users.c.parent_user_id.set_foreign_key(ForeignKey(users.c.user_id)) - users.create() addresses.create() @@ -154,6 +152,7 @@ class ReflectionTest(PersistTest): autoload=True) u2 = Table('users', meta2, autoload=True) + print "ITS", list(a2.primary_key) assert list(a2.primary_key) == [a2.c.id] assert list(u2.primary_key) == [u2.c.id] assert u2.join(a2).onclause == u2.c.id==a2.c.id @@ -226,19 +225,19 @@ class ReflectionTest(PersistTest): def testmultipk(self): """test that creating a table checks for a sequence before creating it""" + meta = BoundMetaData(testbase.db) table = Table( - 'engine_multi', testbase.db, + 'engine_multi', meta, Column('multi_id', Integer, Sequence('multi_id_seq'), primary_key=True), Column('multi_rev', Integer, Sequence('multi_rev_seq'), primary_key=True), Column('name', String(50), nullable=False), Column('val', String(100)) ) table.create() - # clear out table registry - table.deregister() + meta2 = BoundMetaData(testbase.db) try: - table = Table('engine_multi', testbase.db, autoload=True) + table = Table('engine_multi', meta2, autoload=True) finally: table.drop() @@ -348,19 +347,20 @@ class ReflectionTest(PersistTest): testbase.db, autoload=True) def testoverride(self): + meta = BoundMetaData(testbase.db) table = Table( - 'override_test', testbase.db, + 'override_test', meta, Column('col1', Integer, primary_key=True), Column('col2', String(20)), Column('col3', Numeric) ) table.create() # clear out table registry - table.deregister() + meta2 = BoundMetaData(testbase.db) try: table = Table( - 'override_test', testbase.db, + 'override_test', meta2, Column('col2', Unicode()), Column('col4', String(30)), autoload=True) @@ -403,22 +403,22 @@ class CreateDropTest(PersistTest): ) def test_sorter( self ): - tables = metadata._sort_tables(metadata.tables.values()) + tables = metadata.table_iterator(reverse=False) table_names = [t.name for t in tables] self.assert_( table_names == ['users', 'orders', 'items', 'email_addresses'] or table_names == ['users', 'email_addresses', 'orders', 'items']) def test_createdrop(self): - metadata.create_all(engine=testbase.db) + metadata.create_all(connectable=testbase.db) self.assertEqual( testbase.db.has_table('items'), True ) self.assertEqual( testbase.db.has_table('email_addresses'), True ) - metadata.create_all(engine=testbase.db) + metadata.create_all(connectable=testbase.db) self.assertEqual( testbase.db.has_table('items'), True ) - metadata.drop_all(engine=testbase.db) + metadata.drop_all(connectable=testbase.db) self.assertEqual( testbase.db.has_table('items'), False ) self.assertEqual( testbase.db.has_table('email_addresses'), False ) - metadata.drop_all(engine=testbase.db) + metadata.drop_all(connectable=testbase.db) self.assertEqual( testbase.db.has_table('items'), False ) class SchemaTest(PersistTest): @@ -438,7 +438,7 @@ class SchemaTest(PersistTest): buf = StringIO.StringIO() def foo(s, p): buf.write(s) - gen = testbase.db.dialect.schemagenerator(testbase.db.engine, foo) + gen = testbase.db.dialect.schemagenerator(testbase.db.engine, foo, None) table1.accept_schema_visitor(gen) table2.accept_schema_visitor(gen) buf = buf.getvalue() diff --git a/test/orm/cycles.py b/test/orm/cycles.py index 63eb5b0f6f..eebe7af755 100644 --- a/test/orm/cycles.py +++ b/test/orm/cycles.py @@ -109,7 +109,7 @@ class BiDirectionalOneToManyTest(AssertMixin): Column('c2', Integer) ) metadata.create_all() - t2.c.c2.append_item(ForeignKey('t1.c1')) + t2.c.c2.append_foreign_key(ForeignKey('t1.c1')) def tearDownAll(self): t1.drop() t2.drop() @@ -153,7 +153,7 @@ class BiDirectionalOneToManyTest2(AssertMixin): ) t2.create() t1.create() - t2.c.c2.append_item(ForeignKey('t1.c1')) + t2.c.c2.append_foreign_key(ForeignKey('t1.c1')) t3 = Table('t1_data', metadata, Column('c1', Integer, primary_key=True), Column('t1id', Integer, ForeignKey('t1.c1')), @@ -225,8 +225,7 @@ class OneToManyManyToOneTest(AssertMixin): ball.create() person.create() -# person.c.favorite_ball_id.append_item(ForeignKey('ball.id')) - ball.c.person_id.append_item(ForeignKey('person.id')) + ball.c.person_id.append_foreign_key(ForeignKey('person.id')) # make the test more complete for postgres if db.engine.__module__.endswith('postgres'): diff --git a/test/orm/inheritance.py b/test/orm/inheritance.py index ce9a35479c..392e54407f 100644 --- a/test/orm/inheritance.py +++ b/test/orm/inheritance.py @@ -96,16 +96,16 @@ class InheritTest2(testbase.AssertMixin): foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_id_seq'), primary_key=True), Column('data', String(20)), - ).create() + ) bar = Table('bar', metadata, Column('bid', Integer, ForeignKey('foo.id'), primary_key=True), #Column('fid', Integer, ForeignKey('foo.id'), ) - ).create() + ) foo_bar = Table('foo_bar', metadata, Column('foo_id', Integer, ForeignKey('foo.id')), - Column('bar_id', Integer, ForeignKey('bar.bid'))).create() + Column('bar_id', Integer, ForeignKey('bar.bid'))) metadata.create_all() def tearDownAll(self): metadata.drop_all() diff --git a/test/orm/manytomany.py b/test/orm/manytomany.py index 3966041b5c..dc343cb95d 100644 --- a/test/orm/manytomany.py +++ b/test/orm/manytomany.py @@ -28,7 +28,7 @@ class Transition(object): class M2MTest(testbase.AssertMixin): def setUpAll(self): - self.install_threadlocal() + global metadata metadata = testbase.metadata global place place = Table('place', metadata, @@ -68,28 +68,14 @@ class M2MTest(testbase.AssertMixin): Column('pl1_id', Integer, ForeignKey('place.place_id')), Column('pl2_id', Integer, ForeignKey('place.place_id')), ) - - place.create() - transition.create() - place_input.create() - place_output.create() - place_thingy.create() - place_place.create() + metadata.create_all() def tearDownAll(self): - place_place.drop() - place_input.drop() - place_output.drop() - place_thingy.drop() - place.drop() - transition.drop() - objectstore.clear() + metadata.drop_all() clear_mappers() #testbase.db.tables.clear() - self.uninstall_threadlocal() def setUp(self): - objectstore.clear() clear_mappers() def tearDown(self): @@ -111,6 +97,7 @@ class M2MTest(testbase.AssertMixin): lazy=True, )) + sess = create_session() p1 = Place('place1') p2 = Place('place2') p3 = Place('place3') @@ -118,7 +105,7 @@ class M2MTest(testbase.AssertMixin): p5 = Place('place5') p6 = Place('place6') p7 = Place('place7') - + [sess.save(x) for x in [p1,p2,p3,p4,p5,p6,p7]] p1.places.append(p2) p1.places.append(p3) p5.places.append(p6) @@ -127,10 +114,10 @@ class M2MTest(testbase.AssertMixin): p1.places.append(p5) p4.places.append(p3) p3.places.append(p4) - objectstore.flush() + sess.flush() - objectstore.clear() - l = Place.mapper.select(order_by=place.c.place_id) + sess.clear() + l = sess.query(Place).select(order_by=place.c.place_id) (p1, p2, p3, p4, p5, p6, p7) = l assert p1.places == [p2,p3,p5] assert p5.places == [p6] @@ -144,8 +131,8 @@ class M2MTest(testbase.AssertMixin): pp = p.places self.echo("Place " + str(p) +" places " + repr(pp)) - [objectstore.delete(p) for p in p1,p2,p3,p4,p5,p6,p7] - objectstore.flush() + [sess.delete(p) for p in p1,p2,p3,p4,p5,p6,p7] + sess.flush() def testdouble(self): """tests that a mapper can have two eager relations to the same table, via @@ -165,10 +152,12 @@ class M2MTest(testbase.AssertMixin): tran.inputs.append(Place('place1')) tran.outputs.append(Place('place2')) tran.outputs.append(Place('place3')) - objectstore.flush() + sess = create_session() + sess.save(tran) + sess.flush() - objectstore.clear() - r = Transition.mapper.select() + sess.clear() + r = sess.query(Transition).select() self.assert_result(r, Transition, {'name':'transition1', 'inputs' : (Place, [{'name':'place1'}]), @@ -199,15 +188,15 @@ class M2MTest(testbase.AssertMixin): p2.inputs.append(t2) p3.inputs.append(t2) p1.outputs.append(t1) - - objectstore.flush() + sess = create_session() + [sess.save(x) for x in [t1,t2,t3,p1,p2,p3]] + sess.flush() self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])}) self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])}) class M2MTest2(testbase.AssertMixin): def setUpAll(self): - self.install_threadlocal() metadata = testbase.metadata global studentTbl studentTbl = Table('student', metadata, Column('name', String(20), primary_key=True)) @@ -217,22 +206,13 @@ class M2MTest2(testbase.AssertMixin): enrolTbl = Table('enrol', metadata, Column('student_id', String(20), ForeignKey('student.name'),primary_key=True), Column('course_id', String(20), ForeignKey('course.name'), primary_key=True)) - - studentTbl.create() - courseTbl.create() - enrolTbl.create() + metadata.create_all() def tearDownAll(self): - enrolTbl.drop() - studentTbl.drop() - courseTbl.drop() - objectstore.clear() + metadata.drop_all() clear_mappers() - #testbase.db.tables.clear() - self.uninstall_threadlocal() def setUp(self): - objectstore.clear() clear_mappers() def tearDown(self): @@ -251,6 +231,7 @@ class M2MTest2(testbase.AssertMixin): Course.mapper = mapper(Course, courseTbl, properties = { 'students': relation(Student.mapper, enrolTbl, lazy=True, backref='courses') }) + sess = create_session() s1 = Student('Student1') c1 = Course('Course1') c2 = Course('Course2') @@ -260,55 +241,53 @@ class M2MTest2(testbase.AssertMixin): c3.students.append(s1) self.assert_(len(s1.courses) == 3) self.assert_(len(c1.students) == 1) - objectstore.flush() - objectstore.clear() - s = Student.mapper.get_by(name='Student1') - c = Course.mapper.get_by(name='Course3') + sess.save(s1) + sess.flush() + sess.clear() + s = sess.query(Student).get_by(name='Student1') + c = sess.query(Course).get_by(name='Course3') self.assert_(len(s.courses) == 3) del s.courses[1] self.assert_(len(s.courses) == 2) class M2MTest3(testbase.AssertMixin): def setUpAll(self): - self.install_threadlocal() metadata = testbase.metadata global c, c2a1, c2a2, b, a c = Table('c', metadata, Column('c1', Integer, primary_key = True), Column('c2', String(20)), - ).create() + ) a = Table('a', metadata, Column('a1', Integer, primary_key=True), Column('a2', String(20)), Column('c1', Integer, ForeignKey('c.c1')) - ).create() + ) c2a1 = Table('ctoaone', metadata, Column('c1', Integer, ForeignKey('c.c1')), Column('a1', Integer, ForeignKey('a.a1')) - ).create() + ) c2a2 = Table('ctoatwo', metadata, Column('c1', Integer, ForeignKey('c.c1')), Column('a1', Integer, ForeignKey('a.a1')) - ).create() + ) b = Table('b', metadata, Column('b1', Integer, primary_key=True), Column('a1', Integer, ForeignKey('a.a1')), Column('b2', Boolean) - ).create() - + ) + metadata.create_all() + def tearDownAll(self): b.drop() c2a2.drop() c2a1.drop() a.drop() c.drop() - objectstore.clear() clear_mappers() - #testbase.db.tables.clear() - self.uninstall_threadlocal() def testbasic(self): class C(object):pass diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 63d0904287..6cf2b4b494 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -91,7 +91,8 @@ class VersioningTest(UnitOfWorkTest): Column('id', Integer, Sequence('version_test_seq'), primary_key=True ), Column('version_id', Integer, nullable=False), Column('value', String(40), nullable=False) - ).create() + ) + version_table.create() def tearDownAll(self): version_table.drop() UnitOfWorkTest.tearDownAll(self) @@ -408,12 +409,14 @@ class PrivateAttrTest(UnitOfWorkTest): a_table = Table('a',testbase.db, Column('a_id', Integer, Sequence('next_a_id'), primary_key=True), Column('data', String(10)), - ).create() + ) b_table = Table('b',testbase.db, Column('b_id',Integer,Sequence('next_b_id'),primary_key=True), Column('a_id',Integer,ForeignKey('a.a_id')), - Column('data',String(10))).create() + Column('data',String(10))) + a_table.create() + b_table.create() def tearDownAll(self): b_table.drop() a_table.drop() diff --git a/test/sql/alltests.py b/test/sql/alltests.py index 29b638bb8e..c79d7b67e8 100644 --- a/test/sql/alltests.py +++ b/test/sql/alltests.py @@ -5,7 +5,7 @@ import unittest def suite(): modules_to_test = ( 'sql.testtypes', - 'sql.indexes', + 'sql.constraints', # SQL syntax 'sql.select', diff --git a/test/sql/indexes.py b/test/sql/constraints.py similarity index 60% rename from test/sql/indexes.py rename to test/sql/constraints.py index 5c46b63f2c..045d449687 100644 --- a/test/sql/indexes.py +++ b/test/sql/constraints.py @@ -2,7 +2,7 @@ import testbase from sqlalchemy import * import sys -class IndexTest(testbase.AssertMixin): +class ConstraintTest(testbase.AssertMixin): def setUp(self): global metadata @@ -27,6 +27,59 @@ class IndexTest(testbase.AssertMixin): ForeignKeyConstraint(['emp_id', 'emp_soc'], ['employees.id', 'employees.soc']) ) metadata.create_all() + + @testbase.unsupported('sqlite', 'mysql') + def test_check_constraint(self): + foo = Table('foo', metadata, + Column('id', Integer, primary_key=True), + Column('x', Integer), + Column('y', Integer), + CheckConstraint('x>y')) + bar = Table('bar', metadata, + Column('id', Integer, primary_key=True), + Column('x', Integer, CheckConstraint('x>7')), + ) + + metadata.create_all() + foo.insert().execute(id=1,x=9,y=5) + try: + foo.insert().execute(id=2,x=5,y=9) + assert False + except exceptions.SQLError: + assert True + + bar.insert().execute(id=1,x=10) + try: + bar.insert().execute(id=2,x=5) + assert False + except exceptions.SQLError: + assert True + + def test_unique_constraint(self): + foo = Table('foo', metadata, + Column('id', Integer, primary_key=True), + Column('value', String(30), unique=True)) + bar = Table('bar', metadata, + Column('id', Integer, primary_key=True), + Column('value', String(30)), + Column('value2', String(30)), + UniqueConstraint('value', 'value2', name='uix1') + ) + metadata.create_all() + foo.insert().execute(id=1, value='value1') + foo.insert().execute(id=2, value='value2') + bar.insert().execute(id=1, value='a', value2='a') + bar.insert().execute(id=2, value='a', value2='b') + try: + foo.insert().execute(id=3, value='value1') + assert False + except exceptions.SQLError: + assert True + try: + bar.insert().execute(id=3, value='a', value2='b') + assert False + except exceptions.SQLError: + assert True def test_index_create(self): employees = Table('employees', metadata, @@ -39,12 +92,12 @@ class IndexTest(testbase.AssertMixin): i = Index('employee_name_index', employees.c.last_name, employees.c.first_name) i.create() - assert employees.indexes['employee_name_index'] is i + assert i in employees.indexes i2 = Index('employee_email_index', employees.c.email_address, unique=True) i2.create() - assert employees.indexes['employee_email_index'] is i2 + assert i2 in employees.indexes def test_index_create_camelcase(self): """test that mixed-case index identifiers are legal""" @@ -76,16 +129,17 @@ class IndexTest(testbase.AssertMixin): events = Table('events', metadata, Column('id', Integer, primary_key=True), - Column('name', String(30), unique=True), + Column('name', String(30), index=True, unique=True), Column('location', String(30), index=True), - Column('sport', String(30), - unique='sport_announcer'), - Column('announcer', String(30), - unique='sport_announcer'), - Column('winner', String(30), index='idx_winners')) + Column('sport', String(30)), + Column('announcer', String(30)), + Column('winner', String(30))) + + Index('sport_announcer', events.c.sport, events.c.announcer, unique=True) + Index('idx_winners', events.c.winner) index_names = [ ix.name for ix in events.indexes ] - assert 'ux_events_name' in index_names + assert 'ix_events_name' in index_names assert 'ix_events_location' in index_names assert 'sport_announcer' in index_names assert 'idx_winners' in index_names @@ -97,19 +151,20 @@ class IndexTest(testbase.AssertMixin): capt.append(statement) capt.append(repr(parameters)) connection.proxy(statement, parameters) - schemagen = testbase.db.dialect.schemagenerator(testbase.db, proxy) + schemagen = testbase.db.dialect.schemagenerator(testbase.db, proxy, connection) events.accept_schema_visitor(schemagen) assert capt[0].strip().startswith('CREATE TABLE events') - assert capt[2].strip() == \ - 'CREATE UNIQUE INDEX ux_events_name ON events (name)' - assert capt[4].strip() == \ - 'CREATE INDEX ix_events_location ON events (location)' - assert capt[6].strip() == \ - 'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)' - assert capt[8].strip() == \ + + s = set([capt[x].strip() for x in [2,4,6,8]]) + + assert s == set([ + 'CREATE UNIQUE INDEX ix_events_name ON events (name)', + 'CREATE INDEX ix_events_location ON events (location)', + 'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)', 'CREATE INDEX idx_winners ON events (winner)' - + ]) + # verify that the table is functional events.insert().execute(id=1, name='hockey finals', location='rink', sport='hockey', announcer='some canadian', diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index e08bdb89f1..ef851cf630 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -121,7 +121,7 @@ class ColumnsTest(AssertMixin): ) for aCol in testTable.c: - self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db, None).get_column_specification(aCol)) + self.assertEquals(expectedResults[aCol.name], db.dialect.schemagenerator(db, None, None).get_column_specification(aCol)) class UnicodeTest(AssertMixin): """tests the Unicode type. also tests the TypeDecorator with instances in the types package.""" diff --git a/test/zblog/tests.py b/test/zblog/tests.py index 7cec195901..e538cff9d8 100644 --- a/test/zblog/tests.py +++ b/test/zblog/tests.py @@ -12,9 +12,9 @@ from zblog.blog import * class ZBlogTest(AssertMixin): def create_tables(self): - tables.metadata.create_all(engine=db) + tables.metadata.create_all(connectable=db) def drop_tables(self): - tables.metadata.drop_all(engine=db) + tables.metadata.drop_all(connectable=db) def setUpAll(self): self.create_tables()