]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added support for postgres_relkind.
authorRodrigo Menezes <rodrigo.menezes@moat.com>
Thu, 14 Aug 2014 18:47:23 +0000 (14:47 -0400)
committerRodrigo Menezes <rodrigo.menezes@moat.com>
Thu, 14 Aug 2014 18:47:23 +0000 (14:47 -0400)
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/reflection.py
setup.cfg
test/dialect/postgresql/test_reflection.py
test/requirements.py

index 5ff2f7c6121d992e6be37e3c1e7ff38ee44427ea..b3506f5d2586e12dd0ecb58c92923c02a107e455 100644 (file)
@@ -1669,11 +1669,12 @@ class PGDialect(default.DefaultDialect):
             "ops": {}
         }),
         (schema.Table, {
-            "ignore_search_path": False
+            "ignore_search_path": False,
+            "relkind": None
         })
     ]
 
-    reflection_options = ('postgresql_ignore_search_path', )
+    reflection_options = ('postgresql_ignore_search_path', 'postgresql_relkind')
 
     _backslash_escapes = True
 
@@ -1898,7 +1899,7 @@ class PGDialect(default.DefaultDialect):
         return tuple([int(x) for x in m.group(1, 2, 3) if x is not None])
 
     @reflection.cache
-    def get_table_oid(self, connection, table_name, schema=None, **kw):
+    def get_table_oid(self, connection, table_name, schema=None, postgresql_relkind=None, **kw):
         """Fetch the oid for schema.table_name.
 
         Several reflection methods require the table oid.  The idea for using
@@ -1911,13 +1912,28 @@ class PGDialect(default.DefaultDialect):
             schema_where_clause = "n.nspname = :schema"
         else:
             schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
+
+        RELKIND_SYNONYMS = {
+            'materialized': 'm',
+            'foreign': 'f'
+        }
+        ACCEPTED_RELKINDS = ('r','v','m','f')
+        if postgresql_relkind is None:
+            postgresql_relkind = 'r'
+        else:
+            postgresql_relkind = postgresql_relkind.lower()
+            if postgresql_relkind in RELKIND_SYNONYMS:
+                postgresql_relkind = RELKIND_SYNONYMS[postgresql_relkind.lower()]
+            if postgresql_relkind not in ACCEPTED_RELKINDS:
+                raise exc.SQLAlchemyError('Invalid postgresql_relkind: %s' % postgresql_relkind)
+
         query = """
             SELECT c.oid
             FROM pg_catalog.pg_class c
             LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
             WHERE (%s)
-            AND c.relname = :table_name AND c.relkind in ('r','v')
-        """ % schema_where_clause
+            AND c.relname = :table_name AND c.relkind in ('%s', 'v')
+        """ % (schema_where_clause, postgresql_relkind)
         # Since we're binding to unicode, table_name and schema_name must be
         # unicode.
         table_name = util.text_type(table_name)
@@ -2014,7 +2030,8 @@ class PGDialect(default.DefaultDialect):
     def get_columns(self, connection, table_name, schema=None, **kw):
 
         table_oid = self.get_table_oid(connection, table_name, schema,
-                                       info_cache=kw.get('info_cache'))
+                                       info_cache=kw.get('info_cache'),
+                                       postgresql_relkind=kw.get('postgresql_relkind'))
         SQL_COLS = """
             SELECT a.attname,
               pg_catalog.format_type(a.atttypid, a.atttypmod),
@@ -2164,7 +2181,8 @@ class PGDialect(default.DefaultDialect):
     @reflection.cache
     def get_pk_constraint(self, connection, table_name, schema=None, **kw):
         table_oid = self.get_table_oid(connection, table_name, schema,
-                                       info_cache=kw.get('info_cache'))
+                                       info_cache=kw.get('info_cache'),
+                                       postgresql_relkind=kw.get('postgresql_relkind'))
 
         if self.server_version_info < (8, 4):
             PK_SQL = """
@@ -2214,7 +2232,8 @@ class PGDialect(default.DefaultDialect):
                          postgresql_ignore_search_path=False, **kw):
         preparer = self.identifier_preparer
         table_oid = self.get_table_oid(connection, table_name, schema,
-                                       info_cache=kw.get('info_cache'))
+                                       info_cache=kw.get('info_cache'),
+                                       postgresql_relkind=kw.get('postgresql_relkind'))
 
         FK_SQL = """
           SELECT r.conname,
@@ -2318,11 +2337,11 @@ class PGDialect(default.DefaultDialect):
     @reflection.cache
     def get_indexes(self, connection, table_name, schema, **kw):
         table_oid = self.get_table_oid(connection, table_name, schema,
-                                       info_cache=kw.get('info_cache'))
+                                       info_cache=kw.get('info_cache'),
+                                       postgresql_relkind=kw.get('postgresql_relkind'))
 
         # cast indkey as varchar since it's an int2vector,
         # returned as a list by some drivers such as pypostgresql
-
         IDX_SQL = """
           SELECT
               i.relname as relname,
@@ -2336,7 +2355,7 @@ class PGDialect(default.DefaultDialect):
                         pg_attribute a
                         on t.oid=a.attrelid and %s
           WHERE
-              t.relkind = 'r'
+              t.relkind IN ('r', 'v', 'f', 'm')
               and t.oid = :table_oid
               and ix.indisprimary = 'f'
           ORDER BY
@@ -2391,7 +2410,8 @@ class PGDialect(default.DefaultDialect):
     def get_unique_constraints(self, connection, table_name,
                                schema=None, **kw):
         table_oid = self.get_table_oid(connection, table_name, schema,
-                                       info_cache=kw.get('info_cache'))
+                                       info_cache=kw.get('info_cache'),
+                                       postgresql_relkind=kw.get('postgresql_relkind'))
 
         UNIQUE_SQL = """
             SELECT
index 012d1d35ddb60dc8b4ecd8ad9df0dd728ea842a9..afe9a8b3e298f9b457ff82c58cae184b574d4b50 100644 (file)
@@ -378,7 +378,6 @@ class Inspector(object):
          use :class:`.quoted_name`.
 
         """
-
         return self.dialect.get_indexes(self.bind, table_name,
                                         schema,
                                         info_cache=self.info_cache, **kw)
@@ -405,7 +404,6 @@ class Inspector(object):
         .. versionadded:: 0.8.4
 
         """
-
         return self.dialect.get_unique_constraints(
             self.bind, table_name, schema, info_cache=self.info_cache, **kw)
 
@@ -573,7 +571,7 @@ class Inspector(object):
                                                conname, link_to_name=True,
                                                **options))
         # Indexes
-        indexes = self.get_indexes(table_name, schema)
+        indexes = self.get_indexes(table_name, schema, **table.dialect_kwargs)
         for index_d in indexes:
             name = index_d['name']
             columns = index_d['column_names']
index 7517220a669c80353119b3b7e4b58ae19458c7d6..4ec4b0837c2a09505b53b3a671bb964186715f64 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -26,6 +26,13 @@ profile_file=test/profiles.txt
 # create database link test_link connect to scott identified by tiger using 'xe';
 oracle_db_link = test_link
 
+# host name of a postgres database that has the postgres_fdw extension.
+# to create this run:
+# CREATE EXTENSION postgres_fdw;
+# GRANT USAGE ON FOREIGN DATA WRAPPER postgres_fdw TO public;
+# this can be localhost to create a loopback foreign table
+postgres_test_db_link = localhost
+
 
 [db]
 default=sqlite:///:memory:
index 1d6a4176563116dcb730ce847371f1a8ac410df1..313be0b377a0f05d688b428cedb1e2744392a3ef 100644 (file)
@@ -12,8 +12,108 @@ import sqlalchemy as sa
 from sqlalchemy.dialects.postgresql import base as postgresql
 
 
-class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
+class RelKindReflectionTest(fixtures.TestBase, AssertsExecutionResults):
+    """Test postgresql_relkind reflection option"""
+
+    __requires__ = 'postgresql_test_dblink',
+    __only_on__ = 'postgresql >= 9.3'
+    __backend__ = True
+
+    @classmethod
+    def setup_class(cls):
+        from sqlalchemy.testing import config
+        cls.dblink = config.file_config.get('sqla_testing', 'postgres_test_db_link')
+
+        metadata = MetaData(testing.db)
+        testtable = Table(
+            'testtable', metadata,
+            Column(
+                'id', Integer, primary_key=True),
+            Column(
+                'data', String(30)))
+        metadata.create_all()
+        testtable.insert().execute({'id': 89, 'data': 'd1'})
+
+        con = testing.db.connect()
 
+        for ddl in \
+                "CREATE MATERIALIZED VIEW test_mview AS SELECT * FROM testtable;", \
+                "CREATE SERVER test_server FOREIGN DATA WRAPPER postgres_fdw OPTIONS (dbname 'test', host '%s');" % cls.dblink, \
+                "CREATE USER MAPPING FOR public SERVER test_server options (user 'scott', password 'tiger');", \
+                "CREATE FOREIGN TABLE test_foreigntable ( \
+                    id          INT, \
+                    data        VARCHAR(30) \
+                ) SERVER test_server OPTIONS (table_name 'testtable');":
+            try:
+                con.execute(ddl)
+            except exc.DBAPIError as e:
+                if 'already exists' not in str(e):
+                    raise e
+
+    @classmethod
+    def teardown_class(cls):
+        con = testing.db.connect()
+        con.execute('DROP FOREIGN TABLE test_foreigntable;')
+        con.execute('DROP USER MAPPING FOR public SERVER test_server;')
+        con.execute('DROP SERVER test_server;')
+        con.execute('DROP MATERIALIZED VIEW test_mview;')
+        con.execute('DROP TABLE testtable;')
+
+    def test_mview_is_reflected(self):
+        mview_relkind_names = ('m', 'materialized')
+        for mview_relkind_name in mview_relkind_names:
+            metadata = MetaData(testing.db)
+            table = Table('test_mview', metadata, autoload=True, postgresql_relkind=mview_relkind_name)
+            eq_(set(table.columns.keys()), set(['id', 'data']), "Columns of reflected mview didn't equal expected columns")
+
+    def test_mview_select(self):
+        metadata = MetaData(testing.db)
+        table = Table('test_mview', metadata, autoload=True, postgresql_relkind='m')
+        assert table.select().execute().fetchall() == [
+            (89, 'd1',)
+        ]
+
+    def test_foreign_table_is_reflected(self):
+        foreign_table_relkind_names = ('f', 'foreign')
+        for foreign_table_relkind_name in foreign_table_relkind_names:
+            metadata = MetaData(testing.db)
+            table = Table('test_foreigntable', metadata, autoload=True, postgresql_relkind=foreign_table_relkind_name)
+            eq_(set(table.columns.keys()), set(['id', 'data']), "Columns of reflected foreign table didn't equal expected columns")
+
+    def test_foreign_table_select(self):
+        metadata = MetaData(testing.db)
+        table = Table('test_foreigntable', metadata, autoload=True, postgresql_relkind='f')
+        assert table.select().execute().fetchall() == [
+            (89, 'd1',)
+        ]
+
+    def test_foreign_table_roundtrip(self):
+        metadata = MetaData(testing.db)
+        table = Table('test_foreigntable', metadata, autoload=True, postgresql_relkind='f')
+
+        connection = testing.db.connect()
+        trans = connection.begin()
+        try:
+            table.delete().execute()
+            table.insert().execute({'id': 89, 'data': 'd1'})
+            trans.commit()
+        except:
+            trans.rollback()
+            raise
+
+        assert table.select().execute().fetchall() == [
+            (89, 'd1',)
+        ]
+
+    def test_invalid_relkind(self):
+        metadata = MetaData(testing.db)
+        def create_bad_table():
+            return Table('test_foreigntable', metadata, autoload=True, postgresql_relkind='nope')
+
+        assert_raises(exc.SQLAlchemyError, create_bad_table)
+
+
+class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
     """Test PostgreSQL domains"""
 
     __only_on__ = 'postgresql > 8.3'
index e8705d145231b0fad6a23b025f00edeb0a68dfca..927c94bfbfe48682219cb692485061d16409a0fe 100644 (file)
@@ -716,6 +716,14 @@ class DefaultRequirements(SuiteRequirements):
                     "oracle_db_link option not specified in config"
                 )
 
+    @property
+    def postgresql_test_dblink(self):
+        return skip_if(
+                    lambda config: not config.file_config.has_option(
+                        'sqla_testing', 'postgres_test_db_link'),
+                    "postgres_test_db_link option not specified in config"
+                )
+    
     @property
     def percent_schema_names(self):
         return skip_if(