]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added get_pk_constraint() to reflection.Inspector, similar
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Apr 2010 23:05:41 +0000 (19:05 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Apr 2010 23:05:41 +0000 (19:05 -0400)
to get_primary_keys() except returns a dict that includes the
name of the constraint, for supported backends (PG so far).
[ticket:1769]
- Postgresql reflects the name of primary key constraints,
if one exists.  [ticket:1769]

CHANGES
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/test/requires.py
test/engine/test_reflection.py

diff --git a/CHANGES b/CHANGES
index 7edcc9b8a959bcceff4599ed58baa2a037a75293..e997eb56c87baf1d6caf068c80d030cdb8557eaa 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -116,6 +116,11 @@ CHANGES
     corresponding to the dialect, clause element, the column
     names within the VALUES or SET clause of an INSERT or UPDATE, 
     as well as the "batch" mode for an INSERT or UPDATE statement.
+  
+  - Added get_pk_constraint() to reflection.Inspector, similar
+    to get_primary_keys() except returns a dict that includes the
+    name of the constraint, for supported backends (PG so far).
+    [ticket:1769]
     
 - ext
   - the compiler extension now allows @compiles decorators
@@ -143,7 +148,10 @@ CHANGES
   - psycopg2/pg8000 dialects now aware of REAL[], FLOAT[], 
     DOUBLE_PRECISION[], NUMERIC[] return types without
     raising an exception.
-    
+
+  - Postgresql reflects the name of primary key constraints,
+    if one exists.  [ticket:1769]
+        
 - oracle
   - Now using cx_oracle output converters so that the
     DBAPI returns natively the kinds of values we prefer:
index 312ae9aa8a35b0fe2786edea42a1479353e35054..72251c8d5751604655803e925194ae1bc3c2ed68 100644 (file)
@@ -1005,6 +1005,27 @@ class PGDialect(default.DefaultDialect):
         primary_keys = [r[0] for r in c.fetchall()]
         return primary_keys
 
+    @reflection.cache
+    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+        cols = self.get_primary_keys(connection, table_name, schema=schema, **kw)
+        
+        table_oid = self.get_table_oid(connection, table_name, schema,
+                                       info_cache=kw.get('info_cache'))
+
+        PK_CONS_SQL = """
+        SELECT conname
+           FROM  pg_catalog.pg_constraint r
+           WHERE r.conrelid = :table_oid AND r.contype = 'p'
+           ORDER BY 1
+        """
+        t = sql.text(PK_CONS_SQL, typemap={'conname':sqltypes.Unicode})
+        c = connection.execute(t, table_oid=table_oid)
+        name = c.scalar()
+        return {
+            'constrained_columns':cols,
+            'name':name
+        }
+
     @reflection.cache
     def get_foreign_keys(self, connection, table_name, schema=None, **kw):
         preparer = self.identifier_preparer
index 4c5a6a82b6e7d9babd4fee9178abfee57536ca9f..d3579606306bad2204db7e6590da9a504a1899ee 100644 (file)
@@ -260,8 +260,23 @@ class Dialect(object):
         Given a :class:`~sqlalchemy.engine.Connection`, a string
         `table_name`, and an optional string `schema`, return primary
         key information as a list of column names.
+
         """
+        raise NotImplementedError()
+
+    def get_pk_constraint(self, table_name, schema=None, **kw):
+        """Return information about the primary key constraint on `table_name`.
 
+        Given a string `table_name`, and an optional string `schema`, return
+        primary key information as a dictionary with these keys:
+
+        constrained_columns
+          a list of column names that make up the primary key
+
+        name
+          optional name of the primary key constraint.
+
+        """
         raise NotImplementedError()
 
     def get_foreign_keys(self, connection, table_name, schema=None, **kw):
index 6fb0a14a518d6d6b8f72afc6d7d086bf0247bf9d..fc49c62fac7bb28def478bccbab383e2fe847976 100644 (file)
@@ -205,6 +205,17 @@ class DefaultDialect(base.Dialect):
         insp = reflection.Inspector.from_engine(connection)
         return insp.reflecttable(table, include_columns)
 
+    def get_pk_constraint(self, conn, table_name, schema=None, **kw):
+        """Compatiblity method, adapts the result of get_primary_keys()
+        for those dialects which don't implement get_pk_constraint().
+        
+        """
+        return {
+            'constrained_columns':
+                        self.get_primary_keys(conn, table_name, 
+                                                schema=schema, **kw)
+        }
+        
     def validate_identifier(self, ident):
         if len(ident) > self.max_identifier_length:
             raise exc.IdentifierError(
index 57f2205c167bc6da40473f81a33bd68f434ef1c5..56b9eafd85ad0a9f5a9c74fd9783e98ccc186fde 100644 (file)
@@ -189,6 +189,26 @@ class Inspector(object):
 
         return pkeys
 
+    def get_pk_constraint(self, table_name, schema=None, **kw):
+        """Return information about primary key constraint on `table_name`.
+
+        Given a string `table_name`, and an optional string `schema`, return
+        primary key information as a dictionary with these keys:
+        
+        constrained_columns
+          a list of column names that make up the primary key
+        
+        name
+          optional name of the primary key constraint.
+
+        """
+        pkeys = self.dialect.get_pk_constraint(self.conn, table_name, schema,
+                                              info_cache=self.info_cache,
+                                              **kw)
+
+        return pkeys
+        
+
     def get_foreign_keys(self, table_name, schema=None, **kw):
         """Return information about foreign_keys in `table_name`.
 
@@ -208,6 +228,9 @@ class Inspector(object):
           a list of column names in the referred table that correspond to
           constrained_columns
 
+        name
+          optional name of the foreign key constraint.
+          
         \**kw
           other options passed to the dialect's get_foreign_keys() method.
 
@@ -318,12 +341,14 @@ class Inspector(object):
             raise exc.NoSuchTableError(table.name)
 
         # Primary keys
-        primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[
-            table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw)
-            if pk in table.c
-        ])
-
-        table.append_constraint(primary_key_constraint)
+        pk_cons = self.get_pk_constraint(table_name, schema, **tblkw)
+        if pk_cons:
+            primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[
+                table.c[pk] for pk in pk_cons['constrained_columns']
+                if pk in table.c
+            ], name=pk_cons.get('name'))
+
+            table.append_constraint(primary_key_constraint)
 
         # Foreign keys
         fkeys = self.get_foreign_keys(table_name, schema, **tblkw)
index bf911c2c2201425b98d38d0057fa9d4adae00e8f..1b9052fd8b03f7f6e35dc7e908828133208ebf57 100644 (file)
@@ -11,7 +11,8 @@ from testing import \
      exclude, \
      emits_warning_on,\
      skip_if,\
-     fails_on
+     fails_on,\
+     fails_on_everything_except
 
 import testing
 import sys
@@ -245,6 +246,13 @@ def sane_rowcount(fn):
         fn,
         skip_if(lambda: not testing.db.dialect.supports_sane_rowcount)
     )
+
+def reflects_pk_names(fn):
+    """Target driver reflects the name of primary key constraints."""
+    return _chain_decorators_on(
+        fn,
+        fails_on_everything_except('postgresql')
+    )
     
 def python2(fn):
     return _chain_decorators_on(
index 18074337f205cbe6c54c482086e1a5141ffb7d33..4b1cb4652d796636470119aebe75998a7f6e76cf 100644 (file)
@@ -4,8 +4,7 @@ from sqlalchemy import types as sql_types
 from sqlalchemy import schema
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy import MetaData
-from sqlalchemy.test.schema import Table
-from sqlalchemy.test.schema import Column
+from sqlalchemy.test.schema import Table, Column
 import sqlalchemy as sa
 from sqlalchemy.test import TestBase, ComparesTables, \
                             testing, engines, AssertsCompiledSQL
@@ -966,10 +965,11 @@ def createTables(meta, schema=None):
         test_needs_fk=True,
     )
     addresses = Table('email_addresses', meta,
-        Column('address_id', sa.Integer, primary_key = True),
+        Column('address_id', sa.Integer),
         Column('remote_user_id', sa.Integer,
                sa.ForeignKey(users.c.user_id)),
         Column('email_address', sa.String(20)),
+        sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'),
         schema=schema,
         test_needs_fk=True,
     )
@@ -1148,10 +1148,17 @@ class ComponentReflectionTest(TestBase):
             users_pkeys = insp.get_primary_keys(users.name,
                                                 schema=schema)
             eq_(users_pkeys,  ['user_id'])
-            addr_pkeys = insp.get_primary_keys(addresses.name,
-                                               schema=schema)
+            addr_cons = insp.get_pk_constraint(addresses.name,
+                                                schema=schema)
+                                                
+            addr_pkeys = addr_cons['constrained_columns']
             eq_(addr_pkeys,  ['address_id'])
-
+            
+            @testing.requires.reflects_pk_names
+            def go():
+                eq_(addr_cons['name'], 'email_ad_pk')
+            go()
+            
         finally:
             addresses.drop()
             users.drop()