]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Handle sqlite get_unique_constraints() call for temporary tables
authorJohannes Erdfelt <johannes@erdfelt.com>
Wed, 17 Sep 2014 14:52:34 +0000 (07:52 -0700)
committerJohannes Erdfelt <johannes@erdfelt.com>
Wed, 17 Sep 2014 15:01:01 +0000 (08:01 -0700)
The sqlite get_unique_constraints() implementation did not do a union
against the sqlite_temp_master table like other code does. This could
result in an exception being raised if get_unique_constraints() was
called against a temporary table.

lib/sqlalchemy/dialects/sqlite/base.py
test/dialect/test_sqlite.py

index af793d27575c8d39da36db1cd9a814ae54e53789..c76ef6afdc407155666883c60206f3cabaef224c 100644 (file)
@@ -1097,16 +1097,24 @@ class SQLiteDialect(default.DefaultDialect):
     @reflection.cache
     def get_unique_constraints(self, connection, table_name,
                                schema=None, **kw):
-        UNIQUE_SQL = """
-            SELECT sql
-            FROM
-                sqlite_master
-            WHERE
-                type='table' AND
-                name=:table_name
-        """
-        c = connection.execute(UNIQUE_SQL, table_name=table_name)
-        table_data = c.fetchone()[0]
+        try:
+            s = ("SELECT sql FROM "
+                 " (SELECT * FROM sqlite_master UNION ALL "
+                 "  SELECT * FROM sqlite_temp_master) "
+                 "WHERE name = '%s' "
+                 "AND type = 'table'") % table_name
+            rs = connection.execute(s)
+        except exc.DBAPIError:
+            s = ("SELECT sql FROM sqlite_master WHERE name = '%s' "
+                 "AND type = 'table'") % table_name
+            rs = connection.execute(s)
+        row = rs.fetchone()
+        if row is None:
+            # sqlite won't return the schema for the sqlite_master or
+            # sqlite_temp_master tables from this query. These tables
+            # don't have any unique constraints anyway.
+            return []
+        table_data = row[0]
 
         UNIQUE_PATTERN = 'CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)'
         return [
index e77a039806b3274ddf5152ea5cb6e96985ca7363..6fc6446894774bcddb84ebf513cb9f7d57af421c 100644 (file)
@@ -575,6 +575,24 @@ class DialectTest(fixtures.TestBase, AssertsExecutionResults):
         finally:
             meta.drop_all()
 
+    def test_get_unique_constraints(self):
+        meta = MetaData(testing.db)
+        t1 = Table('foo', meta, Column('f', Integer),
+                   UniqueConstraint('f', name='foo_f'))
+        t2 = Table('bar', meta, Column('b', Integer),
+                   UniqueConstraint('b', name='bar_b'),
+                   prefixes=['TEMPORARY'])
+        meta.create_all()
+        from sqlalchemy.engine.reflection import Inspector
+        try:
+            inspector = Inspector(testing.db)
+            eq_(inspector.get_unique_constraints('foo'),
+                [{'column_names': [u'f'], 'name': u'foo_f'}])
+            eq_(inspector.get_unique_constraints('bar'),
+                [{'column_names': [u'b'], 'name': u'bar_b'}])
+        finally:
+            meta.drop_all()
+
 
 class SQLTest(fixtures.TestBase, AssertsCompiledSQL):