Add basic support for unique constraint reflection
author     Roman Podolyaka <roman.podolyaka@gmail.com>
           Sun, 9 Jun 2013 16:07:00 +0000 (19:07 +0300)
committer  Roman Podolyaka <roman.podolyaka@gmail.com>
           Sun, 9 Jun 2013 20:49:55 +0000 (23:49 +0300)
The Inspection API already supports reflection of table
index information, and that information also includes
unique constraints (at least for PostgreSQL and MySQL).
But it is useful to be able to distinguish between indexes
and plain unique constraints (even though both are
implemented the same way internally by the RDBMS).

This change adds a new method to the Inspection API, get_unique_constraints(),
and implements it for the SQLite, PostgreSQL and MySQL dialects.
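
With the change applied, the new method is reachable through sqlalchemy.inspect()
roughly as sketched below (the engine URL and table name are placeholders, not
part of the patch):

    from sqlalchemy import create_engine, inspect

    engine = create_engine('sqlite:///example.db')   # placeholder URL/backend
    inspector = inspect(engine)

    # each entry is a dict with 'name' and 'column_names' keys
    for uc in inspector.get_unique_constraints('testtbl'):
        print(uc)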

lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/engine/reflection.py
test/engine/test_reflection.py

lib/sqlalchemy/dialects/mysql/base.py
index 9d856e271c8a899a80e9bdaccf827647da662bfb..2642b5fdcecad328dff0ba97de61b5189069b7e7 100644
@@ -2211,6 +2211,21 @@ class MySQLDialect(default.DefaultDialect):
             indexes.append(index_d)
         return indexes
 
+    @reflection.cache
+    def get_unique_constraints(self, connection, table_name,
+                               schema=None, **kw):
+        parsed_state = self._parsed_state_or_create(
+            connection, table_name, schema, **kw)
+
+        return [
+            {
+                'name': key['name'],
+                'column_names': [col[0] for col in key['columns']]
+            }
+            for key in parsed_state.keys
+            if key['type'] == 'UNIQUE'
+        ]
+
     @reflection.cache
     def get_view_definition(self, connection, view_name, schema=None, **kw):
 
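
As a rough illustration of what the new MySQL code filters out of the parsed
SHOW CREATE TABLE state (the exact key layout below is an assumption, inferred
from the col[0] unpacking in the hunk above):

    # hypothetical parsed-state entries; 'columns' holds (name, length) tuples
    keys = [
        {'type': 'UNIQUE', 'name': 'uq_a_b', 'columns': [('a', None), ('b', None)]},
        {'type': 'INDEX', 'name': 'ix_c', 'columns': [('c', None)]},
    ]

    unique_constraints = [
        {'name': key['name'],
         'column_names': [col[0] for col in key['columns']]}
        for key in keys
        if key['type'] == 'UNIQUE'
    ]
    # -> [{'name': 'uq_a_b', 'column_names': ['a', 'b']}]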
lib/sqlalchemy/dialects/postgresql/base.py
index 00d0acc2c9e93ee25ce26e9509a840251229e395..0810e03849ef6bd41e1b36622bdb92d7d887fa7f 100644
@@ -1950,6 +1950,36 @@ class PGDialect(default.DefaultDialect):
             index_d['unique'] = unique
         return indexes
 
+    @reflection.cache
+    def get_unique_constraints(self, connection, table_name,
+                               schema=None, **kw):
+        table_oid = self.get_table_oid(connection, table_name, schema,
+                                       info_cache=kw.get('info_cache'))
+
+        UNIQUE_SQL = """
+            SELECT
+                cons.conname as name,
+                ARRAY_AGG(a.attname) as column_names
+            FROM
+                pg_catalog.pg_constraint cons
+                left outer join pg_attribute a
+                    on cons.conrelid = a.attrelid and a.attnum = ANY(cons.conkey)
+            WHERE
+                cons.conrelid = :table_oid AND
+                cons.contype = 'u'
+            GROUP BY
+                cons.conname
+        """
+
+        t = sql.text(UNIQUE_SQL,
+                     typemap={'column_names': ARRAY(sqltypes.Unicode)})
+        c = connection.execute(t, table_oid=table_oid)
+
+        return [
+            {'name': row.name, 'column_names': row.column_names}
+            for row in c.fetchall()
+        ]
+
     def _load_enums(self, connection):
         if not self.supports_native_enum:
             return {}
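
The PostgreSQL version reads pg_catalog directly: contype = 'u' marks unique
constraints and conkey lists the constrained attnums. A hedged sketch of running
the same query by hand, using the same execution style (typemap, keyword bind
parameters) as the patch; the DSN and table name are made up, and the real code
resolves the OID via get_table_oid() with proper schema handling:

    from sqlalchemy import create_engine, sql, types as sqltypes
    from sqlalchemy.dialects.postgresql import ARRAY

    engine = create_engine('postgresql://scott:tiger@localhost/test')  # hypothetical DSN
    conn = engine.connect()

    # look up the table's OID straight from pg_class (simplified, no schema filter)
    table_oid = conn.execute(
        sql.text("SELECT oid FROM pg_catalog.pg_class WHERE relname = :name"),
        name='testtbl'
    ).scalar()

    t = sql.text("""
        SELECT cons.conname AS name, ARRAY_AGG(a.attname) AS column_names
        FROM pg_catalog.pg_constraint cons
        LEFT OUTER JOIN pg_attribute a
            ON cons.conrelid = a.attrelid AND a.attnum = ANY(cons.conkey)
        WHERE cons.conrelid = :table_oid AND cons.contype = 'u'
        GROUP BY cons.conname
    """, typemap={'column_names': ARRAY(sqltypes.Unicode)})

    for row in conn.execute(t, table_oid=table_oid):
        print(row.name, row.column_names)   # e.g. unique_a_b ['a', 'b']
    conn.close()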
lib/sqlalchemy/dialects/sqlite/base.py
index c7e09b164d71f4fc91d1fe0ed69d3a144715e53c..3e2a158a066437e79d562abf9c2fbe9ace702090 100644
@@ -917,6 +917,26 @@ class SQLiteDialect(default.DefaultDialect):
                 cols.append(row[2])
         return indexes
 
+    @reflection.cache
+    def get_unique_constraints(self, connection, table_name,
+                               schema=None, **kw):
+        UNIQUE_SQL = """
+            SELECT sql
+            FROM
+                sqlite_master
+            WHERE
+                type='table' AND
+                name=:table_name
+        """
+        c = connection.execute(UNIQUE_SQL, table_name=table_name)
+        table_data = c.fetchone()[0]
+
+        UNIQUE_PATTERN = 'CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)'
+        return [
+            {'name': name, 'column_names': [c.strip() for c in cols.split(',')]}
+            for name, cols in re.findall(UNIQUE_PATTERN, table_data)
+        ]
+
 
 def _pragma_cursor(cursor):
     """work around SQLite issue whereby cursor.description
lib/sqlalchemy/engine/interfaces.py
index f623a2a61d0cdf7dfbfa2f6924e2b6a428dcf694..d5fe5c5e230e5f6740455c470047c3f14c563794 100644
@@ -338,6 +338,25 @@ class Dialect(object):
 
         raise NotImplementedError()
 
+    def get_unique_constraints(self, table_name, schema=None, **kw):
+        """Return information about unique constraints in `table_name`.
+
+        Given a string `table_name` and an optional string `schema`, return
+        unique constraint information as a list of dicts with these keys:
+
+        name
+          the unique constraint's name
+
+        column_names
+          list of column names in order
+
+        \**kw
+          other options passed to the dialect's get_unique_constraints() method.
+
+        """
+
+        raise NotImplementedError()
+
     def normalize_name(self, name):
         """convert the given name to lowercase if it is detected as
         case insensitive.
lib/sqlalchemy/engine/reflection.py
index cf2caf679d910f5cb622c4b0b8520d4a28065661..1926e693a8ca4000573c01a66968eb746054fc35 100644
@@ -347,6 +347,26 @@ class Inspector(object):
                                                   schema,
                                             info_cache=self.info_cache, **kw)
 
+    def get_unique_constraints(self, table_name, schema=None, **kw):
+        """Return information about unique constraints in `table_name`.
+
+        Given a string `table_name` and an optional string `schema`, return
+        unique constraint information as a list of dicts with these keys:
+
+        name
+          the unique constraint's name
+
+        column_names
+          list of column names in order
+
+        \**kw
+          other options passed to the dialect's get_unique_constraints() method.
+
+        """
+
+        return self.dialect.get_unique_constraints(
+            self.bind, table_name, schema, info_cache=self.info_cache, **kw)
+
     def reflecttable(self, table, include_columns, exclude_columns=()):
         """Given a Table object, load its internal constructs based on
         introspection.
test/engine/test_reflection.py
index ac0fa515308de5186324e655376ba0031c3b764b..fd941187466fa3fbe460fc738dbd7245c59a2992 100644
@@ -1,3 +1,5 @@
+import operator
+
 import unicodedata
 import sqlalchemy as sa
 from sqlalchemy import schema, events, event, inspect
@@ -878,6 +880,41 @@ class ReflectionTest(fixtures.TestBase, ComparesTables):
         assert set([t2.c.name, t2.c.id]) == set(r2.columns)
         assert set([t2.c.name]) == set(r3.columns)
 
+    @testing.provide_metadata
+    def test_unique_constraints_reflection(self):
+        uniques = sorted(
+            [
+                {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']},
+                {'name': 'unique_a_c', 'column_names': ['a', 'c']},
+                {'name': 'unique_b_c', 'column_names': ['b', 'c']},
+            ],
+            key=operator.itemgetter('name')
+        )
+
+        try:
+            orig_meta = sa.MetaData(bind=testing.db)
+            table = Table(
+                'testtbl', orig_meta,
+                Column('a', sa.String(20)),
+                Column('b', sa.String(30)),
+                Column('c', sa.Integer),
+            )
+            for uc in uniques:
+                table.append_constraint(
+                    sa.UniqueConstraint(*uc['column_names'], name=uc['name'])
+                )
+            orig_meta.create_all()
+
+            inspector = inspect(testing.db)
+            reflected = sorted(
+                inspector.get_unique_constraints('testtbl'),
+                key=operator.itemgetter('name')
+            )
+
+            assert uniques == reflected
+        finally:
+            testing.db.execute('drop table if exists testtbl;')
+
     @testing.requires.views
     @testing.provide_metadata
     def test_views(self):