From: Mike Bayer Date: Thu, 15 Apr 2010 23:05:41 +0000 (-0400) Subject: - Added get_pk_constraint() to reflection.Inspector, similar X-Git-Tag: rel_0_6_0~15 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=82d194c9a65b09fef8d52318cbe38e2c84dfd2ca;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Added get_pk_constraint() to reflection.Inspector, similar to get_primary_keys() except returns a dict that includes the name of the constraint, for supported backends (PG so far). [ticket:1769] - Postgresql reflects the name of primary key constraints, if one exists. [ticket:1769] --- diff --git a/CHANGES b/CHANGES index 7edcc9b8a9..e997eb56c8 100644 --- a/CHANGES +++ b/CHANGES @@ -116,6 +116,11 @@ CHANGES corresponding to the dialect, clause element, the column names within the VALUES or SET clause of an INSERT or UPDATE, as well as the "batch" mode for an INSERT or UPDATE statement. + + - Added get_pk_constraint() to reflection.Inspector, similar + to get_primary_keys() except returns a dict that includes the + name of the constraint, for supported backends (PG so far). + [ticket:1769] - ext - the compiler extension now allows @compiles decorators @@ -143,7 +148,10 @@ CHANGES - psycopg2/pg8000 dialects now aware of REAL[], FLOAT[], DOUBLE_PRECISION[], NUMERIC[] return types without raising an exception. - + + - Postgresql reflects the name of primary key constraints, + if one exists. [ticket:1769] + - oracle - Now using cx_oracle output converters so that the DBAPI returns natively the kinds of values we prefer: diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 312ae9aa8a..72251c8d57 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1005,6 +1005,27 @@ class PGDialect(default.DefaultDialect): primary_keys = [r[0] for r in c.fetchall()] return primary_keys + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + cols = self.get_primary_keys(connection, table_name, schema=schema, **kw) + + table_oid = self.get_table_oid(connection, table_name, schema, + info_cache=kw.get('info_cache')) + + PK_CONS_SQL = """ + SELECT conname + FROM pg_catalog.pg_constraint r + WHERE r.conrelid = :table_oid AND r.contype = 'p' + ORDER BY 1 + """ + t = sql.text(PK_CONS_SQL, typemap={'conname':sqltypes.Unicode}) + c = connection.execute(t, table_oid=table_oid) + name = c.scalar() + return { + 'constrained_columns':cols, + 'name':name + } + @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): preparer = self.identifier_preparer diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4c5a6a82b6..d357960630 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -260,8 +260,23 @@ class Dialect(object): Given a :class:`~sqlalchemy.engine.Connection`, a string `table_name`, and an optional string `schema`, return primary key information as a list of column names. + """ + raise NotImplementedError() + + def get_pk_constraint(self, table_name, schema=None, **kw): + """Return information about the primary key constraint on `table_name`. + Given a string `table_name`, and an optional string `schema`, return + primary key information as a dictionary with these keys: + + constrained_columns + a list of column names that make up the primary key + + name + optional name of the primary key constraint. + + """ raise NotImplementedError() def get_foreign_keys(self, connection, table_name, schema=None, **kw): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 6fb0a14a51..fc49c62fac 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -205,6 +205,17 @@ class DefaultDialect(base.Dialect): insp = reflection.Inspector.from_engine(connection) return insp.reflecttable(table, include_columns) + def get_pk_constraint(self, conn, table_name, schema=None, **kw): + """Compatiblity method, adapts the result of get_primary_keys() + for those dialects which don't implement get_pk_constraint(). + + """ + return { + 'constrained_columns': + self.get_primary_keys(conn, table_name, + schema=schema, **kw) + } + def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: raise exc.IdentifierError( diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 57f2205c16..56b9eafd85 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -189,6 +189,26 @@ class Inspector(object): return pkeys + def get_pk_constraint(self, table_name, schema=None, **kw): + """Return information about primary key constraint on `table_name`. + + Given a string `table_name`, and an optional string `schema`, return + primary key information as a dictionary with these keys: + + constrained_columns + a list of column names that make up the primary key + + name + optional name of the primary key constraint. + + """ + pkeys = self.dialect.get_pk_constraint(self.conn, table_name, schema, + info_cache=self.info_cache, + **kw) + + return pkeys + + def get_foreign_keys(self, table_name, schema=None, **kw): """Return information about foreign_keys in `table_name`. @@ -208,6 +228,9 @@ class Inspector(object): a list of column names in the referred table that correspond to constrained_columns + name + optional name of the foreign key constraint. + \**kw other options passed to the dialect's get_foreign_keys() method. @@ -318,12 +341,14 @@ class Inspector(object): raise exc.NoSuchTableError(table.name) # Primary keys - primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[ - table.c[pk] for pk in self.get_primary_keys(table_name, schema, **tblkw) - if pk in table.c - ]) - - table.append_constraint(primary_key_constraint) + pk_cons = self.get_pk_constraint(table_name, schema, **tblkw) + if pk_cons: + primary_key_constraint = sa_schema.PrimaryKeyConstraint(*[ + table.c[pk] for pk in pk_cons['constrained_columns'] + if pk in table.c + ], name=pk_cons.get('name')) + + table.append_constraint(primary_key_constraint) # Foreign keys fkeys = self.get_foreign_keys(table_name, schema, **tblkw) diff --git a/lib/sqlalchemy/test/requires.py b/lib/sqlalchemy/test/requires.py index bf911c2c22..1b9052fd8b 100644 --- a/lib/sqlalchemy/test/requires.py +++ b/lib/sqlalchemy/test/requires.py @@ -11,7 +11,8 @@ from testing import \ exclude, \ emits_warning_on,\ skip_if,\ - fails_on + fails_on,\ + fails_on_everything_except import testing import sys @@ -245,6 +246,13 @@ def sane_rowcount(fn): fn, skip_if(lambda: not testing.db.dialect.supports_sane_rowcount) ) + +def reflects_pk_names(fn): + """Target driver reflects the name of primary key constraints.""" + return _chain_decorators_on( + fn, + fails_on_everything_except('postgresql') + ) def python2(fn): return _chain_decorators_on( diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 18074337f2..4b1cb4652d 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -4,8 +4,7 @@ from sqlalchemy import types as sql_types from sqlalchemy import schema from sqlalchemy.engine.reflection import Inspector from sqlalchemy import MetaData -from sqlalchemy.test.schema import Table -from sqlalchemy.test.schema import Column +from sqlalchemy.test.schema import Table, Column import sqlalchemy as sa from sqlalchemy.test import TestBase, ComparesTables, \ testing, engines, AssertsCompiledSQL @@ -966,10 +965,11 @@ def createTables(meta, schema=None): test_needs_fk=True, ) addresses = Table('email_addresses', meta, - Column('address_id', sa.Integer, primary_key = True), + Column('address_id', sa.Integer), Column('remote_user_id', sa.Integer, sa.ForeignKey(users.c.user_id)), Column('email_address', sa.String(20)), + sa.PrimaryKeyConstraint('address_id', name='email_ad_pk'), schema=schema, test_needs_fk=True, ) @@ -1148,10 +1148,17 @@ class ComponentReflectionTest(TestBase): users_pkeys = insp.get_primary_keys(users.name, schema=schema) eq_(users_pkeys, ['user_id']) - addr_pkeys = insp.get_primary_keys(addresses.name, - schema=schema) + addr_cons = insp.get_pk_constraint(addresses.name, + schema=schema) + + addr_pkeys = addr_cons['constrained_columns'] eq_(addr_pkeys, ['address_id']) - + + @testing.requires.reflects_pk_names + def go(): + eq_(addr_cons['name'], 'email_ad_pk') + go() + finally: addresses.drop() users.drop()