]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
sqlite: reflect primary key constraint names, fixes #3629
authorDiana Clarke <diana.joan.clarke@gmail.com>
Thu, 28 Jan 2016 03:54:05 +0000 (22:54 -0500)
committerDiana Clarke <diana.joan.clarke@gmail.com>
Thu, 28 Jan 2016 03:54:05 +0000 (22:54 -0500)
lib/sqlalchemy/dialects/sqlite/base.py
test/dialect/test_sqlite.py
test/requirements.py

index 0e048aefff1cc9a608175531424828dcb4fcb1ae..3ab9022cc06f4bd45dfa6dbfe0ffe123ba422c49 100644 (file)
@@ -1297,12 +1297,21 @@ class SQLiteDialect(default.DefaultDialect):
 
     @reflection.cache
     def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+        table_data = self._get_table_sql(connection, table_name, schema=schema)
+
+        def parse_pk():
+            PK_PATTERN = 'CONSTRAINT (\w+) PRIMARY KEY'
+            result = re.search(PK_PATTERN, table_data, re.I)
+            return result.group(1) if result else None
+
         cols = self.get_columns(connection, table_name, schema, **kw)
         pkeys = []
         for col in cols:
             if col['primary_key']:
                 pkeys.append(col['name'])
-        return {'constrained_columns': pkeys, 'name': None}
+
+        constraint_name = parse_pk() if table_data else None
+        return {'constrained_columns': pkeys, 'name': constraint_name}
 
     @reflection.cache
     def get_foreign_keys(self, connection, table_name, schema=None, **kw):
index 33903ff89826f225f35de9d09cc9d6da05a77c4e..580950b12f2aa0f68b0d99f6f3d949be575a155d 100644 (file)
@@ -8,7 +8,7 @@ from sqlalchemy.testing import eq_, assert_raises, \
     assert_raises_message, is_
 from sqlalchemy import Table, select, bindparam, Column,\
     MetaData, func, extract, ForeignKey, text, DefaultClause, and_, \
-    create_engine, UniqueConstraint, Index
+    create_engine, UniqueConstraint, Index, PrimaryKeyConstraint
 from sqlalchemy.types import Integer, String, Boolean, DateTime, Date, Time
 from sqlalchemy import types as sqltypes
 from sqlalchemy import event, inspect
@@ -1130,6 +1130,18 @@ class ConstraintReflectionTest(fixtures.TestBase):
                 prefixes=['TEMPORARY']
             )
 
+            Table(
+                'p', meta,
+                Column('id', Integer),
+                PrimaryKeyConstraint('id', name='pk_name'),
+            )
+
+            Table(
+                'q', meta,
+                Column('id', Integer),
+                PrimaryKeyConstraint('id'),
+            )
+
             meta.create_all(conn)
 
             # will contain an "autoindex"
@@ -1223,8 +1235,6 @@ class ConstraintReflectionTest(fixtures.TestBase):
         )
 
     def test_unnamed_inline_foreign_key_quoted(self):
-        inspector = Inspector(testing.db)
-
         inspector = Inspector(testing.db)
         fks = inspector.get_foreign_keys('e1')
         eq_(
@@ -1342,6 +1352,27 @@ class ConstraintReflectionTest(fixtures.TestBase):
             [{'column_names': ['x'], 'name': None}]
         )
 
+    def test_primary_key_constraint_named(self):
+        inspector = Inspector(testing.db)
+        eq_(
+            inspector.get_pk_constraint("p"),
+            {'constrained_columns': ['id'], 'name': 'pk_name'}
+        )
+
+    def test_primary_key_constraint_unnamed(self):
+        inspector = Inspector(testing.db)
+        eq_(
+            inspector.get_pk_constraint("q"),
+            {'constrained_columns': ['id'], 'name': None}
+        )
+
+    def test_primary_key_constraint_no_pk(self):
+        inspector = Inspector(testing.db)
+        eq_(
+            inspector.get_pk_constraint("d"),
+            {'constrained_columns': [], 'name': None}
+        )
+
 
 class SavepointTest(fixtures.TablesTest):
 
index 522a376e00c7f215fd6ccdde3d3889fcb6a64520..658ffab017ca27650671a7f766cedcbd65ff6a26 100644 (file)
@@ -528,7 +528,7 @@ class DefaultRequirements(SuiteRequirements):
         """Target driver reflects the name of primary key constraints."""
 
         return fails_on_everything_except('postgresql', 'oracle', 'mssql',
-                    'sybase')
+                    'sybase', 'sqlite')
 
     @property
     def json_type(self):