]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Merge indexes [1047]:[1048] into trunk (for #6)
authorJason Pellerin <jpellerin@gmail.com>
Sun, 26 Feb 2006 22:57:46 +0000 (22:57 +0000)
committerJason Pellerin <jpellerin@gmail.com>
Sun, 26 Feb 2006 22:57:46 +0000 (22:57 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/indexes.py

index c25a55c7acf5938b28236c6e71689bc4c103cce2..1b600a4a847cdaa00177a972d977198842c4634c 100644 (file)
@@ -517,8 +517,11 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
             self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in pks],', '))
                     
         self.append("\n)%s\n\n" % self.post_create_table(table))
-        self.execute()
-
+        self.execute()        
+        if hasattr(table, 'indexes'):
+            for index in table.indexes:
+                self.visit_index(index)
+        
     def post_create_table(self, table):
         return ''
 
@@ -550,6 +553,8 @@ class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator):
         self.execute()
         
     def visit_table(self, table):
+        # NOTE: indexes on the table will be automatically dropped, so
+        # no need to drop them individually
         self.append("\nDROP TABLE " + table.fullname)
         self.execute()
 
index 240773043c46f25e5cf0451144e0d240f6811bde..bb46578a339da73d3bbd8732348653fb682abc58 100644 (file)
@@ -260,6 +260,9 @@ class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
             self.append("\tUNIQUE (%s)" % string.join([c.name for c in table.primary_key],', '))
 
         self.append("\n)\n\n")
-        self.execute()
+        self.execute()        
+        if hasattr(table, 'indexes'):
+            for index in table.indexes:
+                self.visit_index(index)
 
         
index 2e8e407709fc0becdb2529020446eae1d3f8723b..6152537dd1d90ed41a6d3ba3250fd65dec32d892 100644 (file)
@@ -160,7 +160,10 @@ class Table(sql.TableClause, SchemaItem):
             self.primary_key.append(column)
         column.table = self
         column.type = self.engine.type_descriptor(column.type)
-            
+
+    def append_index(self, index):
+        self.indexes[index.name] = index
+        
     def _set_parent(self, schema):
         schema.tables[self.name] = self
         self.schema = schema
@@ -170,6 +173,32 @@ class Table(sql.TableClause, SchemaItem):
         for c in self.columns:
             c.accept_schema_visitor(visitor)
         return visitor.visit_table(self)
+
+    def append_index_column(self, column, index=None, unique=None):
+        """Add an index or a column to an existing index of the same name.
+        """
+        if index is not None and unique is not None:
+            raise ValueError("index and unique may not both be specified")
+        if index:
+            if index is True:
+                name = 'ix_%s' % column.name
+            else:
+                name = index
+        elif unique:
+            if unique is True:
+                name = 'ux_%s' % column.name
+            else:
+                name = unique
+        # find this index in self.indexes
+        # add this column to it if found
+        # otherwise create new
+        try:
+            index = self.indexes[name]
+            index.append_column(column)
+        except KeyError:
+            index = Index(name, column, unique=unique)
+        return index
+    
     def deregister(self):
         """removes this table from it's engines table registry.  this does not
         issue a SQL DROP statement."""
@@ -224,9 +253,22 @@ class Column(sql.ColumnClause, SchemaItem):
         which will be invoked upon insert if this column is not present in the insert list or is given a value
         of None.
         
-        hidden=False : indicates this column should not be listed in the table's list of columns.  Used for the "oid" 
-        column, which generally isnt in column lists.
-        """
+        hidden=False : indicates this column should not be listed in the
+        table's list of columns.  Used for the "oid" column, which generally
+        isnt in column lists.
+
+        index=None : True or index name. Indicates that this column is
+        indexed. Pass true to autogenerate the index name. Pass a string to
+        specify the index name. Multiple columns that specify the same index
+        name will all be included in the index, in the order of their
+        creation.
+
+        unique=None : True or undex name. Indicates that this column is
+        indexed in a unique index . Pass true to autogenerate the index
+        name. Pass a string to specify the index name. Multiple columns that
+        specify the same index name will all be included in the index, in the
+        order of their creation.  """
+        
         name = str(name) # in case of incoming unicode
         super(Column, self).__init__(name, None, type)
         self.args = args
@@ -235,6 +277,10 @@ class Column(sql.ColumnClause, SchemaItem):
         self.nullable = kwargs.pop('nullable', not self.primary_key)
         self.hidden = kwargs.pop('hidden', False)
         self.default = kwargs.pop('default', None)
+        self.index = kwargs.pop('index', None)
+        self.unique = kwargs.pop('unique', None)
+        if self.index is not None and self.unique is not None:
+            raise ArgumentError("Column may not define both index and unique")
         self._foreign_key = None
         self._orig = None
         self._parent = None
@@ -269,6 +315,10 @@ class Column(sql.ColumnClause, SchemaItem):
         if getattr(self, 'table', None) is not None:
             raise ArgumentError("this Column already has a table!")
         table.append_column(self)
+        if self.index or self.unique:
+            table.append_index_column(self, index=self.index,
+                                      unique=self.unique)
+        
         if self.default is not None:
             self.default = ColumnDefault(self.default)
             self._init_items(self.default)
@@ -429,7 +479,6 @@ class Sequence(DefaultGenerator):
 class Index(SchemaItem):
     """Represents an index of columns from a database table
     """
-
     def __init__(self, name, *columns, **kw):
         """Constructs an index object. Arguments are:
 
@@ -443,24 +492,34 @@ class Index(SchemaItem):
         unique=True : create a unique index
         """
         self.name = name
-        self.columns = columns
+        self.columns = []
+        self.table = None
         self.unique = kw.pop('unique', False)
-        self._init_items()
+        self._init_items(*columns)
 
     engine = property(lambda s:s.table.engine)
-    def _init_items(self):
+    def _init_items(self, *args):
+        for column in args:
+            self.append_column(column)
+            
+    def append_column(self, column):
         # make sure all columns are from the same table
-        # FIXME: and no column is repeated
-        self.table = None
-        for column in self.columns:
-            if self.table is None:
-                self.table = column.table
-            elif column.table != self.table:
-                # all columns muse be from same table
-                raise ArgumentError("All index columns must be from same table. "
-                                 "%s is from %s not %s" % (column,
-                                                           column.table,
-                                                           self.table))
+        # and no column is repeated
+        if self.table is None:
+            self.table = column.table
+            self.table.append_index(self)
+        elif column.table != self.table:
+            # all columns muse be from same table
+            raise ArgumentError("All index columns must be from same table. "
+                                "%s is from %s not %s" % (column,
+                                                          column.table,
+                                                          self.table))
+        elif column.name in [ c.name for c in self.columns ]:
+            raise ArgumentError("A column may not appear twice in the "
+                                "same index (%s already has column %s)"
+                                % (self.name, column))
+        self.columns.append(column)
+        
     def create(self):
        self.engine.create(self)
        return self
@@ -501,7 +560,7 @@ class SchemaVisitor(sql.ClauseVisitor):
         """visit a ForeignKey."""
         pass
     def visit_index(self, index):
-        """visit an Index (not implemented yet)."""
+        """visit an Index."""
         pass
     def visit_passive_default(self, default):
         """visit a passive default"""
index f88b3118fb394167387d1733c6690aede5b4e491..f905da583494dc0a5a4a5351972667112ad4455b 100644 (file)
@@ -982,11 +982,14 @@ class TableClause(FromClause):
         super(TableClause, self).__init__(name)
         self.name = self.id = self.fullname = name
         self._columns = util.OrderedProperties()
+        self._indexes = util.OrderedProperties()
         self._foreign_keys = []
         self._primary_key = []
         for c in columns:
             self.append_column(c)
 
+    indexes = property(lambda s:s._indexes)
+    
     def append_column(self, c):
         self._columns[c.text] = c
         c.table = self
index 3fde8828cda34fc75d40318a6dcda54a38c3bc3e..d0cb1a131e52f4eabc6fcda5a3c71edec3f38bad 100644 (file)
@@ -6,8 +6,12 @@ class IndexTest(testbase.AssertMixin):
     
     def setUp(self):
         self.created = []
-
+        self.echo = testbase.db.echo
+        self.logger = testbase.db.logger
+        
     def tearDown(self):
+        testbase.db.echo = self.echo
+        testbase.db.logger = testbase.db.engine.logger = self.logger
         if self.created:
             self.created.reverse()
             for entity in self.created:
@@ -26,11 +30,87 @@ class IndexTest(testbase.AssertMixin):
                   employees.c.last_name, employees.c.first_name)
         i.create()
         self.created.append(i)
+        assert employees.indexes['employee_name_index'] is i
         
-        i = Index('employee_email_index',
-                  employees.c.email_address, unique=True)        
+        i2 = Index('employee_email_index',
+                   employees.c.email_address, unique=True)        
+        i2.create()
+        self.created.append(i2)
+        assert employees.indexes['employee_email_index'] is i2
+
+    def test_index_create_camelcase(self):
+        """test that mixed-case index identifiers are legal"""
+        employees = Table('companyEmployees', testbase.db,
+                          Column('id', Integer, primary_key=True),
+                          Column('firstName', String),
+                          Column('lastName', String),
+                          Column('emailAddress', String))        
+        employees.create()
+        self.created.append(employees)
+        
+        i = Index('employeeNameIndex',
+                  employees.c.lastName, employees.c.firstName)
         i.create()
         self.created.append(i)
         
+        i = Index('employeeEmailIndex',
+                  employees.c.emailAddress, unique=True)        
+        i.create()
+        self.created.append(i)
+
+        # Check that the table is useable. This is mostly for pg,
+        # which can be somewhat sticky with mixed-case identifiers
+        employees.insert().execute(firstName='Joe', lastName='Smith')
+        ss = employees.select().execute().fetchall()
+        assert ss[0].firstName == 'Joe'
+        assert ss[0].lastName == 'Smith'
+
+    def test_index_create_inline(self):
+        """Test indexes defined with tables"""
+
+        testbase.db.echo = True
+        capt = []
+        class dummy:
+            pass
+        stream = dummy()
+        stream.write = capt.append
+        testbase.db.logger = testbase.db.engine.logger = stream
+        
+        events = Table('events', testbase.db,
+                       Column('id', Integer, primary_key=True),
+                       Column('name', String(30), unique=True),
+                       Column('location', String(30), index=True),
+                       Column('sport', String(30),
+                              unique='sport_announcer'),
+                       Column('announcer', String(30),
+                              unique='sport_announcer'),
+                       Column('winner', String(30), index='idx_winners'))
+        
+        index_names = [ ix.name for ix in events.indexes ]
+        assert 'ux_name' in index_names
+        assert 'ix_location' in index_names
+        assert 'sport_announcer' in index_names
+        assert 'idx_winners' in index_names
+        assert len(index_names) == 4
+
+        events.create()
+        self.created.append(events)
+
+        # verify that the table is functional
+        events.insert().execute(id=1, name='hockey finals', location='rink',
+                                sport='hockey', announcer='some canadian',
+                                winner='sweden')
+        ss = events.select().execute().fetchall()
+        
+        assert capt[0].strip().startswith('CREATE TABLE events')
+        assert capt[2].strip() == \
+            'CREATE UNIQUE INDEX ux_name ON events (name)'
+        assert capt[4].strip() == \
+            'CREATE INDEX ix_location ON events (location)'
+        assert capt[6].strip() == \
+            'CREATE UNIQUE INDEX sport_announcer ON events (sport, announcer)'
+        assert capt[8].strip() == \
+            'CREATE INDEX idx_winners ON events (winner)'
+            
 if __name__ == "__main__":    
     testbase.main()