- backport unique constraints reflection to 0.8.4, thereby
  assisting with alembic installations that have upgraded and are dealing with
  PG index/unique constraint reflection.

author     Roman Podolyaka <roman.podolyaka@gmail.com>
           Sun, 9 Jun 2013 16:07:00 +0000 (19:07 +0300)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Tue, 3 Dec 2013 19:59:50 +0000 (14:59 -0500)

The Inspection API already supports reflection of table index
information, and that information also includes unique
constraints (at least for PostgreSQL and MySQL).
It is useful, however, to be able to distinguish between
indexes and plain unique constraints (even though both are
implemented the same way internally by the RDBMS).

This change adds a new method to the Inspection API, get_unique_constraints(),
and implements it for the SQLite, PostgreSQL and MySQL dialects.
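As a minimal usage sketch (the engine URL, table name and constraint name
below are made up for illustration, and the 0.8-era engine.execute() API is
assumed), the new method is reached through an Inspector obtained from
sqlalchemy.inspect() and returns a list of dicts with 'name' and
'column_names' keys:

    import sqlalchemy as sa
    from sqlalchemy import inspect

    # hypothetical table with a named unique constraint
    engine = sa.create_engine("sqlite:///:memory:")
    engine.execute(
        "CREATE TABLE example ("
        "a VARCHAR(20), b VARCHAR(30), "
        "CONSTRAINT uq_a_b UNIQUE (a, b))"
    )

    inspector = inspect(engine)
    # each entry carries the constraint name and its columns in order
    print(inspector.get_unique_constraints("example"))
    # e.g. [{'name': 'uq_a_b', 'column_names': ['a', 'b']}]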

doc/build/changelog/changelog_08.rst
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/engine/reflection.py
test/engine/test_reflection.py

index 96f087224933d60c82455bfa5a0b79868037d337..a6583bcc2a36b128f59021926c933eaa472d8afe 100644 (file)
 .. changelog::
     :version: 0.8.4
 
+     .. change::
+        :tags: feature, sql
+        :tickets: 1443
+        :versions: 0.9.0b1
+
+        Added support for "unique constraint" reflection, via the
+        :meth:`.Inspector.get_unique_constraints` method.
+        Thanks to Roman Podolyaka for the patch.
+
     .. change::
         :tags: bug, oracle
         :tickets: 2864
index 901849bb2258d1f333b2ce3de7b6000da363b660..bd3298dab392f2b544ae251e0d5d1a391a8c4dab 100644 (file)
@@ -2284,6 +2284,21 @@ class MySQLDialect(default.DefaultDialect):
             indexes.append(index_d)
         return indexes
 
+    @reflection.cache
+    def get_unique_constraints(self, connection, table_name,
+                               schema=None, **kw):
+        parsed_state = self._parsed_state_or_create(
+            connection, table_name, schema, **kw)
+
+        return [
+            {
+                'name': key['name'],
+                'column_names': [col[0] for col in key['columns']]
+            }
+            for key in parsed_state.keys
+            if key['type'] == 'UNIQUE'
+        ]
+
     @reflection.cache
     def get_view_definition(self, connection, view_name, schema=None, **kw):
 
index 1aaafac3963664950cb03efc960232855a669b78..c451d46764d643549f29b15bd60028812bf01728 100644 (file)
@@ -1990,6 +1990,36 @@ class PGDialect(default.DefaultDialect):
             for name, idx in indexes.items()
         ]
 
+    @reflection.cache
+    def get_unique_constraints(self, connection, table_name,
+                               schema=None, **kw):
+        table_oid = self.get_table_oid(connection, table_name, schema,
+                                       info_cache=kw.get('info_cache'))
+
+        UNIQUE_SQL = """
+            SELECT
+                cons.conname as name,
+                ARRAY_AGG(a.attname) as column_names
+            FROM
+                pg_catalog.pg_constraint cons
+                left outer join pg_attribute a
+                    on cons.conrelid = a.attrelid and a.attnum = ANY(cons.conkey)
+            WHERE
+                cons.conrelid = :table_oid AND
+                cons.contype = 'u'
+            GROUP BY
+                cons.conname
+        """
+
+        t = sql.text(UNIQUE_SQL,
+                     typemap={'column_names': ARRAY(sqltypes.Unicode)})
+        c = connection.execute(t, table_oid=table_oid)
+
+        return [
+            {'name': row.name, 'column_names': row.column_names}
+            for row in c.fetchall()
+        ]
+
     def _load_enums(self, connection):
         if not self.supports_native_enum:
             return {}
index 2ea8c2494b143273b918680ea7502f3765f091fa..19a35b4954aa0995ba7e14f7908bb9e27b8e01bc 100644 (file)
@@ -924,6 +924,26 @@ class SQLiteDialect(default.DefaultDialect):
                 cols.append(row[2])
         return indexes
 
+    @reflection.cache
+    def get_unique_constraints(self, connection, table_name,
+                               schema=None, **kw):
+        UNIQUE_SQL = """
+            SELECT sql
+            FROM
+                sqlite_master
+            WHERE
+                type='table' AND
+                name=:table_name
+        """
+        c = connection.execute(UNIQUE_SQL, table_name=table_name)
+        table_data = c.fetchone()[0]
+
+        UNIQUE_PATTERN = r'CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)'
+        return [
+            {'name': name, 'column_names': [c.strip() for c in cols.split(',')]}
+            for name, cols in re.findall(UNIQUE_PATTERN, table_data)
+        ]
+
 
 def _pragma_cursor(cursor):
     """work around SQLite issue whereby cursor.description
index f623a2a61d0cdf7dfbfa2f6924e2b6a428dcf694..ac011559c8abcd6a98ae9e6879996347d1d7e6b6 100644 (file)
@@ -338,6 +338,23 @@ class Dialect(object):
 
         raise NotImplementedError()
 
+    def get_unique_constraints(self, table_name, schema=None, **kw):
+        """Return information about unique constraints in `table_name`.
+
+        Given a string `table_name` and an optional string `schema`, return
+        unique constraint information as a list of dicts with these keys:
+
+        name
+          the unique constraint's name
+
+        column_names
+          list of column names in order
+
+
+        """
+
+        raise NotImplementedError()
+
     def normalize_name(self, name):
         """convert the given name to lowercase if it is detected as
         case insensitive.
index 84e7e0432771754ed581dfd16efcc20574cdac5d..47c00f83d076a251491f4005cadca658f8389caf 100644 (file)
@@ -311,9 +311,6 @@ class Inspector(object):
         name
           optional name of the foreign key constraint.
 
-        \**kw
-          other options passed to the dialect's get_foreign_keys() method.
-
         """
 
         return self.dialect.get_foreign_keys(self.bind, table_name, schema,
@@ -335,14 +332,31 @@ class Inspector(object):
         unique
           boolean
 
-        \**kw
-          other options passed to the dialect's get_indexes() method.
         """
 
         return self.dialect.get_indexes(self.bind, table_name,
                                                   schema,
                                             info_cache=self.info_cache, **kw)
 
+    def get_unique_constraints(self, table_name, schema=None, **kw):
+        """Return information about unique constraints in `table_name`.
+
+        Given a string `table_name` and an optional string `schema`, return
+        unique constraint information as a list of dicts with these keys:
+
+        name
+          the unique constraint's name
+
+        column_names
+          list of column names in order
+
+        .. versionadded:: 0.8.4
+
+        """
+
+        return self.dialect.get_unique_constraints(
+            self.bind, table_name, schema, info_cache=self.info_cache, **kw)
+
     def reflecttable(self, table, include_columns, exclude_columns=()):
         """Given a Table object, load its internal constructs based on
         introspection.
index 3cceaca78d27150331c3b82e4d9601b32fba5ad8..93ed0898ffb288625f7ea00ad8a6018f5ad6537d 100644 (file)
@@ -1,3 +1,5 @@
+import operator
+
 import unicodedata
 import sqlalchemy as sa
 from sqlalchemy import schema, events, event, inspect
@@ -876,6 +878,41 @@ class ReflectionTest(fixtures.TestBase, ComparesTables):
         assert set([t2.c.name, t2.c.id]) == set(r2.columns)
         assert set([t2.c.name]) == set(r3.columns)
 
+    @testing.provide_metadata
+    def test_unique_constraints_reflection(self):
+        uniques = sorted(
+            [
+                {'name': 'unique_a_b_c', 'column_names': ['a', 'b', 'c']},
+                {'name': 'unique_a_c', 'column_names': ['a', 'c']},
+                {'name': 'unique_b_c', 'column_names': ['b', 'c']},
+            ],
+            key=operator.itemgetter('name')
+        )
+
+        try:
+            orig_meta = sa.MetaData(bind=testing.db)
+            table = Table(
+                'testtbl', orig_meta,
+                Column('a', sa.String(20)),
+                Column('b', sa.String(30)),
+                Column('c', sa.Integer),
+            )
+            for uc in uniques:
+                table.append_constraint(
+                    sa.UniqueConstraint(*uc['column_names'], name=uc['name'])
+                )
+            orig_meta.create_all()
+
+            inspector = inspect(testing.db)
+            reflected = sorted(
+                inspector.get_unique_constraints('testtbl'),
+                key=operator.itemgetter('name')
+            )
+
+            assert uniques == reflected
+        finally:
+            testing.db.execute('drop table if exists testtbl;')
+
     @testing.requires.views
     @testing.provide_metadata
     def test_views(self):