From: Mike Bayer Date: Sat, 18 Feb 2006 20:33:20 +0000 (+0000) Subject: added indexes to schema/ansisql/engine X-Git-Tag: rel_0_1_1~29 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=55ce6851e08daeba3be8e8c32d9e4618e53a8d5e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git added indexes to schema/ansisql/engine slightly different index syntax for mysql fixed mysql Time type to convert from a timedelta to time tweaks to date unit tests for mysql --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index a300dc6392..ed0f829fbc 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -168,6 +168,9 @@ class ANSICompiler(sql.Compiled): def visit_fromclause(self, fromclause): self.froms[fromclause] = fromclause.from_name + def visit_index(self, index): + self.strings[index] = index.name + def visit_textclause(self, textclause): if textclause.parens and len(textclause.text): self.strings[textclause] = "(" + textclause.text + ")" @@ -200,7 +203,7 @@ class ANSICompiler(sql.Compiled): def visit_function(self, func): self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")" - + def visit_compound_select(self, cs): text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ") for tup in cs.clauses: @@ -531,12 +534,26 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): def visit_column(self, column): pass + + def visit_index(self, index): + self.append('CREATE ') + if index.unique: + self.append('UNIQUE ') + self.append('INDEX %s ON %s (%s)' \ + % (index.name, index.table.name, + string.join([c.name for c in index.columns], ', '))) + self.execute() + class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator): + def visit_index(self, index): + self.append("\nDROP INDEX " + index.name) + self.execute() + def visit_table(self, table): self.append("\nDROP TABLE " + table.fullname) self.execute() class ANSIDefaultRunner(sqlalchemy.engine.DefaultRunner): - pass \ No newline at end of file + pass diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 0afac7df39..04bdc24fa4 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import sys, StringIO, string, types, re +import sys, StringIO, string, types, re, datetime import sqlalchemy.sql as sql import sqlalchemy.engine as engine @@ -40,6 +40,13 @@ class MSDate(sqltypes.Date): class MSTime(sqltypes.Time): def get_col_spec(self): return "TIME" + def convert_result_value(self, value, engine): + # convert from a timedelta value + if value is not None: + return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60)) + else: + return None + class MSText(sqltypes.TEXT): def get_col_spec(self): return "TEXT" @@ -135,6 +142,9 @@ class MySQLEngine(ansisql.ANSISQLEngine): def schemagenerator(self, **params): return MySQLSchemaGenerator(self, **params) + def schemadropper(self, **params): + return MySQLSchemaDropper(self, **params) + def get_default_schema_name(self): if not hasattr(self, '_default_schema_name'): self._default_schema_name = text("select database()", self).scalar() @@ -276,3 +286,7 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): else: return "" +class MySQLSchemaDropper(ansisql.ANSISchemaDropper): + def visit_index(self, index): + self.append("\nDROP INDEX " + index.name + " ON " + index.table.name) + self.execute() diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index ac6df0f9e2..d553eee2db 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -294,13 +294,13 @@ class SQLEngine(schema.SchemaEngine): for the "rowcount" function on a statement handle. """ return True - def create(self, table, **params): - """creates a table within this engine's database connection given a schema.Table object.""" - table.accept_visitor(self.schemagenerator(**params)) + def create(self, entity, **params): + """creates a table or index within this engine's database connection given a schema.Table object.""" + entity.accept_visitor(self.schemagenerator(**params)) - def drop(self, table, **params): - """drops a table within this engine's database connection given a schema.Table object.""" - table.accept_visitor(self.schemadropper(**params)) + def drop(self, entity, **params): + """drops a table or index within this engine's database connection given a schema.Table object.""" + entity.accept_visitor(self.schemadropper(**params)) def compile(self, statement, parameters, **kwargs): """given a sql.ClauseElement statement plus optional bind parameters, creates a new @@ -329,6 +329,14 @@ class SQLEngine(schema.SchemaEngine): database-specific behavior.""" return sql.ColumnImpl(column) + def indeximpl(self, index): + """returns a new sql.IndexImpl object to correspond to the given Index + object. An IndexImpl provides SQL statement builder operations on an + Index metadata object, and a subclass of this object may be provided + by a SQLEngine subclass to provide database-specific behavior. + """ + return sql.IndexImpl(index) + def get_default_schema_name(self): """returns the currently selected schema in the current connection.""" return None diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index bb926053d6..a9c6a4d968 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -19,8 +19,8 @@ from sqlalchemy.util import * from sqlalchemy.types import * import copy, re, string -__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] - +__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index', + 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault'] class SchemaItem(object): """base class for items that define a database schema.""" @@ -455,6 +455,53 @@ class Sequence(DefaultGenerator): """calls the visit_seauence method on the given visitor.""" return visitor.visit_sequence(self) +class Index(SchemaItem): + """Represents an index of columns from a database table + """ + + def __init__(self, name, *columns, **kw): + """Constructs an index object. Arguments are: + + name : the name of the index + + *columns : columns to include in the index. All columns must belong to + the same table, and no column may appear more than once. + + **kw : keyword arguments include: + + unique=True : create a unique index + """ + self.name = name + self.columns = columns + self.unique = kw.pop('unique', False) + self._init_items() + + def _init_items(self): + # make sure all columns are from the same table + # FIXME: and no column is repeated + self.table = None + for column in self.columns: + if self.table is None: + self.table = column.table + elif column.table != self.table: + # all columns muse be from same table + raise ValueError("All index columns must be from same table. " + "%s is from %s not %s" % (column, + column.table, + self.table)) + # set my _impl from col.table.engine + self._impl = self.table.engine.indeximpl(self) + + def accept_visitor(self, visitor): + visitor.visit_index(self) + def __str__(self): + return repr(self) + def __repr__(self): + return 'Index("%s", %s%s)' % (self.name, + ', '.join([repr(c) + for c in self.columns]), + (self.unique and ', unique=True') or '') + class SchemaEngine(object): """a factory object used to create implementations for schema objects. This object is the ultimate base class for the engine.SQLEngine class.""" @@ -464,6 +511,11 @@ class SchemaEngine(object): def columnimpl(self, column): """returns a new implementation object for a Column (usually sql.ColumnImpl)""" raise NotImplementedError() + def indeximpl(self, index): + """returns a new implementation object for an Index (usually + sql.IndexImpl) + """ + raise NotImplementedError() def reflecttable(self, table): """given a table, will query the database and populate its Column and ForeignKey objects.""" diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 293880bf3d..03c94c5e33 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -1430,4 +1430,25 @@ class Delete(UpdateBase): self.whereclause.accept_visitor(visitor) visitor.visit_delete(self) +class IndexImpl(ClauseElement): + + def __init__(self, index): + self.index = index + self.name = index.name + self._engine = self.index.table.engine + + table = property(lambda s: s.index.table) + columns = property(lambda s: s.index.columns) + def hash_key(self): + return self.index.hash_key() + def accept_visitor(self, visitor): + visitor.visit_index(self.index) + def compare(self, other): + return self.index is other + def create(self): + self._engine.create(self.index) + def drop(self): + self._engine.drop(self.index) + def execute(self): + self.create() diff --git a/test/alltests.py b/test/alltests.py index 3fe704e93f..d5697e8832 100644 --- a/test/alltests.py +++ b/test/alltests.py @@ -16,7 +16,8 @@ def suite(): # schema/tables 'engines', 'testtypes', - + 'indexes', + # SQL syntax 'select', 'selectable', diff --git a/test/indexes.py b/test/indexes.py new file mode 100644 index 0000000000..3fde8828cd --- /dev/null +++ b/test/indexes.py @@ -0,0 +1,36 @@ +from sqlalchemy import * +import sys +import testbase + +class IndexTest(testbase.AssertMixin): + + def setUp(self): + self.created = [] + + def tearDown(self): + if self.created: + self.created.reverse() + for entity in self.created: + entity.drop() + + def test_index_create(self): + employees = Table('employees', testbase.db, + Column('id', Integer, primary_key=True), + Column('first_name', String(30)), + Column('last_name', String(30)), + Column('email_address', String(30))) + employees.create() + self.created.append(employees) + + i = Index('employee_name_index', + employees.c.last_name, employees.c.first_name) + i.create() + self.created.append(i) + + i = Index('employee_email_index', + employees.c.email_address, unique=True) + i.create() + self.created.append(i) + +if __name__ == "__main__": + testbase.main() diff --git a/test/testtypes.py b/test/testtypes.py index 3ea868ac90..4c37f64bbf 100644 --- a/test/testtypes.py +++ b/test/testtypes.py @@ -143,37 +143,36 @@ class DateTest(AssertMixin): def setUpAll(self): global users_with_date, insert_data - insert_data = [[7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), datetime.date(2005,11,10), datetime.time(12,20,2)], + insert_data = [ + [7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), datetime.date(2005,11,10), datetime.time(12,20,2)], [8, 'roy', datetime.datetime(2005, 11, 10, 11, 52, 35), datetime.date(2005,10,10), datetime.time(0,0,0)], [9, 'foo', datetime.datetime(2005, 11, 10, 11, 52, 35, 54839), datetime.date(1970,4,1), datetime.time(23,59,59,999)], - [10, 'colber', None, None, None]] + [10, 'colber', None, None, None] + ] fnames = ['user_id', 'user_name', 'user_datetime', 'user_date', 'user_time'] collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime), Column('user_date', Date), Column('user_time', Time)] - - if db.engine.__module__.endswith('mysql'): # strip microseconds -- not supported by this engine (should be an easier way to detect this) for d in insert_data: - d[2] = d[2].replace(microsecond=0) - d[4] = d[4].replace(microsecond=0) + if d[2] is not None: + d[2] = d[2].replace(microsecond=0) + if d[4] is not None: + d[4] = d[4].replace(microsecond=0) try: db.type_descriptor(types.TIME).get_col_spec() - print "HI" except: # don't test TIME type -- not supported by this engine insert_data = [d[:-1] for d in insert_data] fnames = fnames[:-1] collist = collist[:-1] - users_with_date = Table('query_users_with_date', db, redefine = True, *collist) users_with_date.create() - insert_dicts = [dict(zip(fnames, d)) for d in insert_data] for idict in insert_dicts: users_with_date.insert().execute(**idict) # insert the data