class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
- def get_column_specification(self, column):
+ def get_column_specification(self, column, override_pk=False):
raise NotImplementedError()
def visit_table(self, table):
separator = "\n"
+ # if only one primary key, specify it along with the column
+ pks = table.primary_keys
for column in table.columns:
self.append(separator)
separator = ", \n"
- self.append("\t" + self.get_column_specification(column))
-
+ self.append("\t" + self.get_column_specification(column, override_pk=len(pks)>1))
+
+ # if multiple primary keys, specify it at the bottom
+ if len(pks) > 1:
+ self.append(", \n")
+ self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in pks],', '))
+
self.append("\n)\n\n")
self.execute()
return ansisql.ANSICompiler.visit_insert(self, insert)
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
- def get_column_specification(self, column):
+ def get_column_specification(self, column, override_pk=False):
colspec = column.name
colspec += " " + column.type.get_col_spec()
if not column.nullable:
colspec += " NOT NULL"
- if column.primary_key:
+ if column.primary_key and not override_pk:
colspec += " PRIMARY KEY"
if column.foreign_key:
colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name)
return ansisql.ANSICompiler.visit_insert(self, insert)
class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
- def get_column_specification(self, column):
+ def get_column_specification(self, column, override_pk=False):
colspec = column.name
if column.primary_key and isinstance(column.type, types.Integer) and (column.sequence is None or column.sequence.optional):
colspec += " SERIAL"
if not column.nullable:
colspec += " NOT NULL"
- if column.primary_key:
+ if column.primary_key and not override_pk:
colspec += " PRIMARY KEY"
if column.foreign_key:
colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name)
#print "row! " + repr(row)
remotetable = Table(tablename, self, autoload = True)
table.c[localcol].foreign_key = schema.ForeignKey(remotetable.c[remotecol])
-
+ # check for UNIQUE indexes
+ c = self.execute("PRAGMA index_list(" + table.name + ")", {})
+ unique_indexes = []
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ if (row[2] == 1):
+ unique_indexes.append(row[1])
+ # loop thru unique indexes for one that includes the primary key
+ for idx in unique_indexes:
+ c = self.execute("PRAGMA index_info(" + idx + ")", {})
+ cols = []
+ includes_primary=False
+ while True:
+ row = c.fetchone()
+ if row is None:
+ break
+ cols.append(row[2])
+ col = table.columns[row[2]]
+ if col.primary_key:
+ includes_primary= True
+ if includes_primary:
+ # unique index that includes the pk is considered a multiple primary key
+ for col in cols:
+ column = table.columns[col]
+ table.columns[col]._set_primary_key()
+
class SQLiteCompiler(ansisql.ANSICompiler):
def __init__(self, *args, **params):
params.setdefault('paramstyle', 'named')
ansisql.ANSICompiler.__init__(self, *args, **params)
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
- def get_column_specification(self, column):
+ def get_column_specification(self, column, override_pk=False):
colspec = column.name + " " + column.type.get_col_spec()
if not column.nullable:
colspec += " NOT NULL"
- if column.primary_key:
+ if column.primary_key and not override_pk:
colspec += " PRIMARY KEY"
if column.foreign_key:
colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name)
return colspec
+ def visit_table(self, table):
+ """sqlite is going to create multi-primary keys as a single PK plus a UNIQUE index. otherwise
+ its autoincrement functionality seems to get lost"""
+ self.append("\nCREATE TABLE " + table.fullname + "(")
+
+ separator = "\n"
+
+ have_pk = False
+ for column in table.columns:
+ self.append(separator)
+ separator = ", \n"
+ # specify PRIMARY KEY for just the first primary key
+ self.append("\t" + self.get_column_specification(column, override_pk=have_pk))
+ if column.primary_key:
+ have_pk = True
+
+ if len(table.primary_keys) > 1:
+ self.append(", \n")
+ # put all PRIMARY KEYS in a UNIQUE index
+ self.append("\tUNIQUE (%s)" % string.join([c.name for c in table.primary_keys],', '))
+
+ self.append("\n)\n\n")
+ self.execute()