]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow auto_increment on any pk column, not just the first.
authorJason Kirtland <jek@discorporate.us>
Sun, 12 Aug 2007 01:11:44 +0000 (01:11 +0000)
committerJason Kirtland <jek@discorporate.us>
Sun, 12 Aug 2007 01:11:44 +0000 (01:11 +0000)
lib/sqlalchemy/databases/mysql.py
test/dialect/mysql.py

index f0b18d3acca45897d4072fdf05928ba2e6e28901..2d9c3af4c34e2defbda92bf06755ea63f5c0abf0 100644 (file)
@@ -1729,13 +1729,16 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
         if not column.nullable:
             colspec.append('NOT NULL')
 
-        # FIXME: #649 ASAP
-        if column.primary_key:
-            if (len(column.foreign_keys)==0
-                and first_pk
-                and column.autoincrement
-                and isinstance(column.type, sqltypes.Integer)):
-                colspec.append('AUTO_INCREMENT')
+        if column.primary_key and column.autoincrement:
+            try:
+                first = [c for c in column.table.primary_key.columns
+                         if (c.autoincrement and
+                             isinstance(c.type, sqltypes.Integer) and
+                             not c.foreign_keys)].pop(0)
+                if column is first:
+                    colspec.append('AUTO_INCREMENT')
+            except IndexError:
+                pass
 
         return ' '.join(colspec)
 
@@ -1909,6 +1912,8 @@ class MySQLSchemaReflector(object):
         # AUTO_INCREMENT
         if spec.get('autoincr', False):
             col_kw['autoincrement'] = True
+        elif issubclass(col_type, sqltypes.Integer):
+            col_kw['autoincrement'] = False
 
         # DEFAULT
         default = spec.get('default', None)
index 03a87a0ba3b82db5dd606e9df5905025e77d8677..ab3e49f93c9c29106d848d936bf856eb10c1889d 100644 (file)
@@ -600,6 +600,68 @@ class TypesTest(AssertMixin):
 
         m.drop_all()
 
+    @testing.supported('mysql')
+    def test_autoincrement(self):
+        meta = MetaData(testbase.db)
+        try:
+            Table('ai_1', meta,
+                  Column('int_y', Integer, primary_key=True),
+                  Column('int_n', Integer, PassiveDefault('0'),
+                         primary_key=True))
+            Table('ai_2', meta,
+                  Column('int_y', Integer, primary_key=True),
+                  Column('int_n', Integer, PassiveDefault('0'),
+                         primary_key=True))
+            Table('ai_3', meta,
+                  Column('int_n', Integer, PassiveDefault('0'),
+                         primary_key=True, autoincrement=False),
+                  Column('int_y', Integer, primary_key=True))
+            Table('ai_4', meta,
+                  Column('int_n', Integer, PassiveDefault('0'),
+                         primary_key=True, autoincrement=False),
+                  Column('int_n2', Integer, PassiveDefault('0'),
+                         primary_key=True, autoincrement=False))
+            Table('ai_5', meta,
+                  Column('int_y', Integer, primary_key=True),
+                  Column('int_n', Integer, PassiveDefault('0'),
+                         primary_key=True, autoincrement=False))
+            Table('ai_6', meta,
+                  Column('o1', String(1), PassiveDefault('x'),
+                         primary_key=True),
+                  Column('int_y', Integer, primary_key=True))
+            Table('ai_7', meta,
+                  Column('o1', String(1), PassiveDefault('x'),
+                         primary_key=True),
+                  Column('o2', String(1), PassiveDefault('x'),
+                         primary_key=True),
+                  Column('int_y', Integer, primary_key=True))
+            Table('ai_8', meta,
+                  Column('o1', String(1), PassiveDefault('x'),
+                         primary_key=True),
+                  Column('o2', String(1), PassiveDefault('x'),
+                         primary_key=True))
+            meta.create_all()
+
+            table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4',
+                           'ai_5', 'ai_6', 'ai_7', 'ai_8']
+            mr = MetaData(testbase.db)
+            mr.reflect(only=table_names)
+
+            for tbl in [mr.tables[name] for name in table_names]:
+                for c in tbl.c:
+                    if c.name.startswith('int_y'):
+                        assert c.autoincrement
+                    elif c.name.startswith('int_n'):
+                        assert not c.autoincrement
+                tbl.insert().execute()
+                if 'int_y' in tbl.c:
+                    assert select([tbl.c.int_y]).scalar() == 1
+                    assert list(tbl.select().execute().fetchone()).count(1) == 1
+                else:
+                    assert 1 not in list(tbl.select().execute().fetchone())
+        finally:
+            meta.drop_all()
+
     def assert_eq(self, got, wanted):
         if got != wanted:
             print "Expected %s" % wanted