From f81f3f8bd4d496dcc5fc142f6dbad987ce1ea45c Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 26 Nov 2005 20:14:03 +0000 Subject: [PATCH] got round trip for multiple priamry keys to work with table create/reflection (postgres, sqlite) small fix to ORM get with multiple primary keys --- lib/sqlalchemy/ansisql.py | 13 +++++-- lib/sqlalchemy/databases/oracle.py | 4 +- lib/sqlalchemy/databases/postgres.py | 4 +- lib/sqlalchemy/databases/sqlite.py | 56 ++++++++++++++++++++++++++-- lib/sqlalchemy/mapper.py | 2 +- 5 files changed, 68 insertions(+), 11 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 8aff72c47a..109a7deb52 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -308,7 +308,7 @@ class ANSICompiler(sql.Compiled): class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): - def get_column_specification(self, column): + def get_column_specification(self, column, override_pk=False): raise NotImplementedError() def visit_table(self, table): @@ -316,11 +316,18 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator): separator = "\n" + # if only one primary key, specify it along with the column + pks = table.primary_keys for column in table.columns: self.append(separator) separator = ", \n" - self.append("\t" + self.get_column_specification(column)) - + self.append("\t" + self.get_column_specification(column, override_pk=len(pks)>1)) + + # if multiple primary keys, specify it at the bottom + if len(pks) > 1: + self.append(", \n") + self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in pks],', ')) + self.append("\n)\n\n") self.execute() diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index aa25ffec62..3fa14b01b0 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -196,13 +196,13 @@ class OracleCompiler(ansisql.ANSICompiler): return ansisql.ANSICompiler.visit_insert(self, insert) class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column): + def get_column_specification(self, column, override_pk=False): colspec = column.name colspec += " " + column.type.get_col_spec() if not column.nullable: colspec += " NOT NULL" - if column.primary_key: + if column.primary_key and not override_pk: colspec += " PRIMARY KEY" if column.foreign_key: colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index fd468a8785..0a41fb4c8e 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -331,7 +331,7 @@ class PGCompiler(ansisql.ANSICompiler): return ansisql.ANSICompiler.visit_insert(self, insert) class PGSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column): + def get_column_specification(self, column, override_pk=False): colspec = column.name if column.primary_key and isinstance(column.type, types.Integer) and (column.sequence is None or column.sequence.optional): colspec += " SERIAL" @@ -340,7 +340,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator): if not column.nullable: colspec += " NOT NULL" - if column.primary_key: + if column.primary_key and not override_pk: colspec += " PRIMARY KEY" if column.foreign_key: colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 860e82c2cd..8f6bedff69 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -149,20 +149,70 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): #print "row! " + repr(row) remotetable = Table(tablename, self, autoload = True) table.c[localcol].foreign_key = schema.ForeignKey(remotetable.c[remotecol]) - + # check for UNIQUE indexes + c = self.execute("PRAGMA index_list(" + table.name + ")", {}) + unique_indexes = [] + while True: + row = c.fetchone() + if row is None: + break + if (row[2] == 1): + unique_indexes.append(row[1]) + # loop thru unique indexes for one that includes the primary key + for idx in unique_indexes: + c = self.execute("PRAGMA index_info(" + idx + ")", {}) + cols = [] + includes_primary=False + while True: + row = c.fetchone() + if row is None: + break + cols.append(row[2]) + col = table.columns[row[2]] + if col.primary_key: + includes_primary= True + if includes_primary: + # unique index that includes the pk is considered a multiple primary key + for col in cols: + column = table.columns[col] + table.columns[col]._set_primary_key() + class SQLiteCompiler(ansisql.ANSICompiler): def __init__(self, *args, **params): params.setdefault('paramstyle', 'named') ansisql.ANSICompiler.__init__(self, *args, **params) class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): - def get_column_specification(self, column): + def get_column_specification(self, column, override_pk=False): colspec = column.name + " " + column.type.get_col_spec() if not column.nullable: colspec += " NOT NULL" - if column.primary_key: + if column.primary_key and not override_pk: colspec += " PRIMARY KEY" if column.foreign_key: colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) return colspec + def visit_table(self, table): + """sqlite is going to create multi-primary keys as a single PK plus a UNIQUE index. otherwise + its autoincrement functionality seems to get lost""" + self.append("\nCREATE TABLE " + table.fullname + "(") + + separator = "\n" + + have_pk = False + for column in table.columns: + self.append(separator) + separator = ", \n" + # specify PRIMARY KEY for just the first primary key + self.append("\t" + self.get_column_specification(column, override_pk=have_pk)) + if column.primary_key: + have_pk = True + + if len(table.primary_keys) > 1: + self.append(", \n") + # put all PRIMARY KEYS in a UNIQUE index + self.append("\tUNIQUE (%s)" % string.join([c.name for c in table.primary_keys],', ')) + + self.append("\n)\n\n") + self.execute() diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 186da5af9f..5203199425 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -371,7 +371,7 @@ class Mapper(object): # appending to the and_'s clause list directly to skip # typechecks etc. clause.clauses.append(primary_key == ident[i]) - i += 2 + i += 1 try: return self.select(clause)[0] except IndexError: -- 2.47.2