]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
got round trip for multiple priamry keys to work with table create/reflection (postgr...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Nov 2005 20:14:03 +0000 (20:14 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Nov 2005 20:14:03 +0000 (20:14 +0000)
small fix to ORM get with multiple primary keys

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/mapper.py

index 8aff72c47a1442e3971a6cdb45c68034a0a1a757..109a7deb52519ca6c5c06df4c9b73dab5f6783e1 100644 (file)
@@ -308,7 +308,7 @@ class ANSICompiler(sql.Compiled):
 
 class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
 
-    def get_column_specification(self, column):
+    def get_column_specification(self, column, override_pk=False):
         raise NotImplementedError()
         
     def visit_table(self, table):
@@ -316,11 +316,18 @@ class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
         
         separator = "\n"
         
+        # if only one primary key, specify it along with the column
+        pks = table.primary_keys
         for column in table.columns:
             self.append(separator)
             separator = ", \n"
-            self.append("\t" + self.get_column_specification(column))
-            
+            self.append("\t" + self.get_column_specification(column, override_pk=len(pks)>1))
+        
+        # if multiple primary keys, specify it at the bottom
+        if len(pks) > 1:
+            self.append(", \n")
+            self.append("\tPRIMARY KEY (%s)" % string.join([c.name for c in pks],', '))
+                    
         self.append("\n)\n\n")
         self.execute()
 
index aa25ffec62965603523f8186b10c1ce890172d4a..3fa14b01b05d6f7b078bb0edad50e376cecfd5f8 100644 (file)
@@ -196,13 +196,13 @@ class OracleCompiler(ansisql.ANSICompiler):
         return ansisql.ANSICompiler.visit_insert(self, insert)
 
 class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
-    def get_column_specification(self, column):
+    def get_column_specification(self, column, override_pk=False):
         colspec = column.name
         colspec += " " + column.type.get_col_spec()
 
         if not column.nullable:
             colspec += " NOT NULL"
-        if column.primary_key:
+        if column.primary_key and not override_pk:
             colspec += " PRIMARY KEY"
         if column.foreign_key:
             colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) 
index fd468a87851c43abc85083d04fa624eece9442a6..0a41fb4c8e95daacc124a503f402591b6e72b314 100644 (file)
@@ -331,7 +331,7 @@ class PGCompiler(ansisql.ANSICompiler):
         return ansisql.ANSICompiler.visit_insert(self, insert)
         
 class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
-    def get_column_specification(self, column):
+    def get_column_specification(self, column, override_pk=False):
         colspec = column.name
         if column.primary_key and isinstance(column.type, types.Integer) and (column.sequence is None or column.sequence.optional):
             colspec += " SERIAL"
@@ -340,7 +340,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
 
         if not column.nullable:
             colspec += " NOT NULL"
-        if column.primary_key:
+        if column.primary_key and not override_pk:
             colspec += " PRIMARY KEY"
         if column.foreign_key:
             colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) 
index 860e82c2cde213c309eedec3bbb6ecc582ff9697..8f6bedff697eb7ba131b4384ec4142567774da2e 100644 (file)
@@ -149,20 +149,70 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
             #print "row! " + repr(row)
             remotetable = Table(tablename, self, autoload = True)
             table.c[localcol].foreign_key = schema.ForeignKey(remotetable.c[remotecol])
-            
+        # check for UNIQUE indexes
+        c = self.execute("PRAGMA index_list(" + table.name + ")", {})
+        unique_indexes = []
+        while True:
+            row = c.fetchone()
+            if row is None:
+                break
+            if (row[2] == 1):
+                unique_indexes.append(row[1])
+        # loop thru unique indexes for one that includes the primary key
+        for idx in unique_indexes:
+            c = self.execute("PRAGMA index_info(" + idx + ")", {})
+            cols = []
+            includes_primary=False
+            while True:
+                row = c.fetchone()
+                if row is None:
+                    break
+                cols.append(row[2])
+                col = table.columns[row[2]]
+                if col.primary_key:
+                    includes_primary= True
+            if includes_primary:
+                # unique index that includes the pk is considered a multiple primary key
+                for col in cols:
+                    column = table.columns[col]
+                    table.columns[col]._set_primary_key()
+                    
 class SQLiteCompiler(ansisql.ANSICompiler):
     def __init__(self, *args, **params):
         params.setdefault('paramstyle', 'named')
         ansisql.ANSICompiler.__init__(self, *args, **params)
 
 class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
-    def get_column_specification(self, column):
+    def get_column_specification(self, column, override_pk=False):
         colspec = column.name + " " + column.type.get_col_spec()
         if not column.nullable:
             colspec += " NOT NULL"
-        if column.primary_key:
+        if column.primary_key and not override_pk:
             colspec += " PRIMARY KEY"
         if column.foreign_key:
             colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) 
         return colspec
+    def visit_table(self, table):
+        """sqlite is going to create multi-primary keys as a single PK plus a UNIQUE index.  otherwise
+        its autoincrement functionality seems to get lost"""
+        self.append("\nCREATE TABLE " + table.fullname + "(")
+
+        separator = "\n"
+
+        have_pk = False
+        for column in table.columns:
+            self.append(separator)
+            separator = ", \n"
+            # specify PRIMARY KEY for just the first primary key
+            self.append("\t" + self.get_column_specification(column, override_pk=have_pk))
+            if column.primary_key:
+                have_pk = True
+                
+        if len(table.primary_keys) > 1:
+            self.append(", \n")
+            # put all PRIMARY KEYS in a UNIQUE index
+            self.append("\tUNIQUE (%s)" % string.join([c.name for c in table.primary_keys],', '))
+
+        self.append("\n)\n\n")
+        self.execute()
 
index 186da5af9f4c03b1f8a9c4d4f181896e6bb0be6e..5203199425d3feda2d1fdd8810ec610a93c27393 100644 (file)
@@ -371,7 +371,7 @@ class Mapper(object):
                 # appending to the and_'s clause list directly to skip
                 # typechecks etc.
                 clause.clauses.append(primary_key == ident[i])
-                i += 2
+                i += 1
             try:
                 return self.select(clause)[0]
             except IndexError: