def visit_fromclause(self, fromclause):
self.froms[fromclause] = fromclause.from_name
+ def visit_index(self, index):
+ self.strings[index] = index.name
+
def visit_textclause(self, textclause):
if textclause.parens and len(textclause.text):
self.strings[textclause] = "(" + textclause.text + ")"
def visit_function(self, func):
self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
-
+
def visit_compound_select(self, cs):
text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
for tup in cs.clauses:
def visit_column(self, column):
pass
+
+ def visit_index(self, index):
+ self.append('CREATE ')
+ if index.unique:
+ self.append('UNIQUE ')
+ self.append('INDEX %s ON %s (%s)' \
+ % (index.name, index.table.name,
+ string.join([c.name for c in index.columns], ', ')))
+ self.execute()
+
class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator):
+ def visit_index(self, index):
+ self.append("\nDROP INDEX " + index.name)
+ self.execute()
+
def visit_table(self, table):
self.append("\nDROP TABLE " + table.fullname)
self.execute()
class ANSIDefaultRunner(sqlalchemy.engine.DefaultRunner):
- pass
\ No newline at end of file
+ pass
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import sys, StringIO, string, types, re
+import sys, StringIO, string, types, re, datetime
import sqlalchemy.sql as sql
import sqlalchemy.engine as engine
class MSTime(sqltypes.Time):
def get_col_spec(self):
return "TIME"
+ def convert_result_value(self, value, engine):
+ # convert from a timedelta value
+ if value is not None:
+ return datetime.time(value.seconds/60/60, value.seconds/60%60, value.seconds - (value.seconds/60*60))
+ else:
+ return None
+
class MSText(sqltypes.TEXT):
def get_col_spec(self):
return "TEXT"
def schemagenerator(self, **params):
return MySQLSchemaGenerator(self, **params)
+ def schemadropper(self, **params):
+ return MySQLSchemaDropper(self, **params)
+
def get_default_schema_name(self):
if not hasattr(self, '_default_schema_name'):
self._default_schema_name = text("select database()", self).scalar()
else:
return ""
+class MySQLSchemaDropper(ansisql.ANSISchemaDropper):
+ def visit_index(self, index):
+ self.append("\nDROP INDEX " + index.name + " ON " + index.table.name)
+ self.execute()
for the "rowcount" function on a statement handle. """
return True
- def create(self, table, **params):
- """creates a table within this engine's database connection given a schema.Table object."""
- table.accept_visitor(self.schemagenerator(**params))
+ def create(self, entity, **params):
+ """creates a table or index within this engine's database connection given a schema.Table object."""
+ entity.accept_visitor(self.schemagenerator(**params))
- def drop(self, table, **params):
- """drops a table within this engine's database connection given a schema.Table object."""
- table.accept_visitor(self.schemadropper(**params))
+ def drop(self, entity, **params):
+ """drops a table or index within this engine's database connection given a schema.Table object."""
+ entity.accept_visitor(self.schemadropper(**params))
def compile(self, statement, parameters, **kwargs):
"""given a sql.ClauseElement statement plus optional bind parameters, creates a new
database-specific behavior."""
return sql.ColumnImpl(column)
+ def indeximpl(self, index):
+ """returns a new sql.IndexImpl object to correspond to the given Index
+ object. An IndexImpl provides SQL statement builder operations on an
+ Index metadata object, and a subclass of this object may be provided
+ by a SQLEngine subclass to provide database-specific behavior.
+ """
+ return sql.IndexImpl(index)
+
def get_default_schema_name(self):
"""returns the currently selected schema in the current connection."""
return None
from sqlalchemy.types import *
import copy, re, string
-__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
-
+__all__ = ['SchemaItem', 'Table', 'Column', 'ForeignKey', 'Sequence', 'Index',
+ 'SchemaEngine', 'SchemaVisitor', 'PassiveDefault', 'ColumnDefault']
class SchemaItem(object):
"""base class for items that define a database schema."""
"""calls the visit_seauence method on the given visitor."""
return visitor.visit_sequence(self)
+class Index(SchemaItem):
+ """Represents an index of columns from a database table
+ """
+
+ def __init__(self, name, *columns, **kw):
+ """Constructs an index object. Arguments are:
+
+ name : the name of the index
+
+ *columns : columns to include in the index. All columns must belong to
+ the same table, and no column may appear more than once.
+
+ **kw : keyword arguments include:
+
+ unique=True : create a unique index
+ """
+ self.name = name
+ self.columns = columns
+ self.unique = kw.pop('unique', False)
+ self._init_items()
+
+ def _init_items(self):
+ # make sure all columns are from the same table
+ # FIXME: and no column is repeated
+ self.table = None
+ for column in self.columns:
+ if self.table is None:
+ self.table = column.table
+ elif column.table != self.table:
+ # all columns muse be from same table
+ raise ValueError("All index columns must be from same table. "
+ "%s is from %s not %s" % (column,
+ column.table,
+ self.table))
+ # set my _impl from col.table.engine
+ self._impl = self.table.engine.indeximpl(self)
+
+ def accept_visitor(self, visitor):
+ visitor.visit_index(self)
+ def __str__(self):
+ return repr(self)
+ def __repr__(self):
+ return 'Index("%s", %s%s)' % (self.name,
+ ', '.join([repr(c)
+ for c in self.columns]),
+ (self.unique and ', unique=True') or '')
+
class SchemaEngine(object):
"""a factory object used to create implementations for schema objects. This object
is the ultimate base class for the engine.SQLEngine class."""
def columnimpl(self, column):
"""returns a new implementation object for a Column (usually sql.ColumnImpl)"""
raise NotImplementedError()
+ def indeximpl(self, index):
+ """returns a new implementation object for an Index (usually
+ sql.IndexImpl)
+ """
+ raise NotImplementedError()
def reflecttable(self, table):
"""given a table, will query the database and populate its Column and ForeignKey
objects."""
self.whereclause.accept_visitor(visitor)
visitor.visit_delete(self)
+class IndexImpl(ClauseElement):
+
+ def __init__(self, index):
+ self.index = index
+ self.name = index.name
+ self._engine = self.index.table.engine
+
+ table = property(lambda s: s.index.table)
+ columns = property(lambda s: s.index.columns)
+ def hash_key(self):
+ return self.index.hash_key()
+ def accept_visitor(self, visitor):
+ visitor.visit_index(self.index)
+ def compare(self, other):
+ return self.index is other
+ def create(self):
+ self._engine.create(self.index)
+ def drop(self):
+ self._engine.drop(self.index)
+ def execute(self):
+ self.create()
# schema/tables
'engines',
'testtypes',
-
+ 'indexes',
+
# SQL syntax
'select',
'selectable',
--- /dev/null
+from sqlalchemy import *
+import sys
+import testbase
+
+class IndexTest(testbase.AssertMixin):
+
+ def setUp(self):
+ self.created = []
+
+ def tearDown(self):
+ if self.created:
+ self.created.reverse()
+ for entity in self.created:
+ entity.drop()
+
+ def test_index_create(self):
+ employees = Table('employees', testbase.db,
+ Column('id', Integer, primary_key=True),
+ Column('first_name', String(30)),
+ Column('last_name', String(30)),
+ Column('email_address', String(30)))
+ employees.create()
+ self.created.append(employees)
+
+ i = Index('employee_name_index',
+ employees.c.last_name, employees.c.first_name)
+ i.create()
+ self.created.append(i)
+
+ i = Index('employee_email_index',
+ employees.c.email_address, unique=True)
+ i.create()
+ self.created.append(i)
+
+if __name__ == "__main__":
+ testbase.main()
def setUpAll(self):
global users_with_date, insert_data
- insert_data = [[7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), datetime.date(2005,11,10), datetime.time(12,20,2)],
+ insert_data = [
+ [7, 'jack', datetime.datetime(2005, 11, 10, 0, 0), datetime.date(2005,11,10), datetime.time(12,20,2)],
[8, 'roy', datetime.datetime(2005, 11, 10, 11, 52, 35), datetime.date(2005,10,10), datetime.time(0,0,0)],
[9, 'foo', datetime.datetime(2005, 11, 10, 11, 52, 35, 54839), datetime.date(1970,4,1), datetime.time(23,59,59,999)],
- [10, 'colber', None, None, None]]
+ [10, 'colber', None, None, None]
+ ]
fnames = ['user_id', 'user_name', 'user_datetime', 'user_date', 'user_time']
collist = [Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), Column('user_datetime', DateTime),
Column('user_date', Date), Column('user_time', Time)]
-
-
if db.engine.__module__.endswith('mysql'):
# strip microseconds -- not supported by this engine (should be an easier way to detect this)
for d in insert_data:
- d[2] = d[2].replace(microsecond=0)
- d[4] = d[4].replace(microsecond=0)
+ if d[2] is not None:
+ d[2] = d[2].replace(microsecond=0)
+ if d[4] is not None:
+ d[4] = d[4].replace(microsecond=0)
try:
db.type_descriptor(types.TIME).get_col_spec()
- print "HI"
except:
# don't test TIME type -- not supported by this engine
insert_data = [d[:-1] for d in insert_data]
fnames = fnames[:-1]
collist = collist[:-1]
-
users_with_date = Table('query_users_with_date', db, redefine = True, *collist)
users_with_date.create()
-
insert_dicts = [dict(zip(fnames, d)) for d in insert_data]
for idict in insert_dicts:
users_with_date.insert().execute(**idict) # insert the data