]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
revised to use PGInfoCache
authorRandall Smith <randall@tnr.cc>
Wed, 11 Feb 2009 04:45:35 +0000 (04:45 +0000)
committerRandall Smith <randall@tnr.cc>
Wed, 11 Feb 2009 04:45:35 +0000 (04:45 +0000)
lib/sqlalchemy/dialects/postgres/base.py

index d3dc9c0e3fcf32033913484c1462967f283acfc7..b4cf4350584f017bc02c404e5afea06ef82a13a0 100644 (file)
@@ -397,6 +397,21 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
             value = value[1:-1].replace('""','"')
         return value
 
+class PGInfoCache(default.DefaultInfoCache):
+    
+    def __init__(self):
+
+        default.DefaultInfoCache.__init__(self)
+
+    def getTableOID(self, tablename, schemaname=None):
+        table = self.getTable(tablename, schemaname)
+        if table:
+            return table.get('oid')
+
+    def setTableOID(self, oid, tablename, schemaname=None):
+        table = self.getTable(tablename, schemaname, create=True)
+        table['oid'] = oid
+
 class PGDialect(default.DefaultDialect):
     name = 'postgres'
     supports_alter = True
@@ -499,23 +514,6 @@ class PGDialect(default.DefaultDialect):
             raise AssertionError("Could not determine version from string '%s'" % v)
         return tuple([int(x) for x in m.group(1, 2, 3)])
 
-    def _prepare_info_cache(self, info_cache, tablename, schemaname):
-        """Add index for schemaname.table_name if it does not exist.
-       
-        This is done so that certain keys can be assumed to be present.
-        
-        """
-        # First, make sure it has the keys we expect.
-        if info_cache is None: 
-            info_cache = dict(tables={})
-        elif 'tables' not in info_cache:
-            info_cache['tables'] = {}
-        # Add the table index if needed.
-        table_index = "%s.%s" % (schemaname, tablename)
-        if table_index not in info_cache['tables']:
-            info_cache['tables'][table_index] = {}
-        return info_cache
-
     def _get_table_oid(self, connection, tablename, schemaname=None,
                        info_cache=None):
         """Fetch the oid for schemaname.tablename.
@@ -525,10 +523,9 @@ class PGDialect(default.DefaultDialect):
         subsequent calls.
 
         """
-        info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
-        # If it's in info_cache, juse use that.
-        table_index = "%s.%s" % (schemaname, tablename)
-        table_oid = info_cache['tables'][table_index].get('table_oid')
+        table_oid = None
+        if info_cache:
+            table_oid = info_cache.getTableOID(tablename, schemaname)
         if table_oid:
             return table_oid
         if schemaname is not None:
@@ -553,10 +550,15 @@ class PGDialect(default.DefaultDialect):
         if table_oid is None:
             raise exc.NoSuchTableError(table_name)
         # cache it
-        info_cache['tables'][table_index]['table_oid'] = table_oid
+        if info_cache:
+            info_cache.setTableOID(table_oid, tablename, schemaname)
         return table_oid
 
     def get_schema_names(self, connection, info_cache=None):
+        if info_cache:
+            schema_names = info_cache.getSchemaNames()
+            if schema_names is not None:
+                return schema_names
         s = """
         SELECT nspname
         FROM pg_namespace
@@ -564,28 +566,45 @@ class PGDialect(default.DefaultDialect):
         """
         rp = connection.execute(s)
         # what about system tables?
-        return [row[0].decode(self.encoding) for row in rp \
-                if not row[0].startswith('pg_')]
+        schema_names = [row[0].decode(self.encoding) for row in rp \
+                        if not row[0].startswith('pg_')]
+        if info_cache:
+            info_cache.addAllSchemas(schema_names)
+        return schema_names
 
     def get_table_names(self, connection, schemaname=None, info_cache=None):
         if schemaname is not None:
             current_schema = schemaname
         else:
             current_schema = self.get_default_schema_name(connection)
-        return self.table_names(connection, current_schema)
+        if info_cache:
+            table_names = info_cache.getTableNames(current_schema)
+            if table_names is not None:
+                return table_names
+        table_names = self.table_names(connection, current_schema)
+        if info_cache:
+            info_cache.addAllTables(table_names, current_schema)
+        return table_names
 
     def get_view_names(self, connection, schemaname=None, info_cache=None):
         if schemaname is not None:
             current_schema = schemaname
         else:
             current_schema = self.get_default_schema_name(connection)
+        if info_cache:
+            view_names = info_cache.getViewNames(current_schema)
+            if view_names is not None:
+                return view_names
         s = """
         SELECT relname
         FROM pg_class c
         WHERE relkind = 'v'
           AND '%(schema)s' = (select nspname from pg_namespace n where n.oid = c.relnamespace)
         """ % dict(schema=current_schema)
-        return [row[0].decode(self.encoding) for row in connection.execute(s)]
+        view_names = [row[0].decode(self.encoding) for row in connection.execute(s)]
+        if info_cache:
+            info_cache.addAllViews(view_names, schemaname)
+        return view_names
 
     def get_view_definition(self, connection, viewname, schemaname=None,
                                                             info_cache=None):
@@ -593,6 +612,10 @@ class PGDialect(default.DefaultDialect):
             current_schema = schemaname
         else:
             current_schema = self.get_default_schema_name(connection)
+        if info_cache:
+            view = info_cache.getView(viewname, current_schema)
+            if view.get('definition'):
+                return view['definition']
         s = """
         SELECT definition FROM pg_views
         WHERE schemaname = :schemaname
@@ -601,11 +624,19 @@ class PGDialect(default.DefaultDialect):
         rp = connection.execute(sql.text(s),
                                 viewname=viewname, schemaname=current_schema)
         if rp:
-            return rp.scalar().decode(self.encoding)
+            view_def = rp.scalar().decode(self.encoding)
+            if info_cache:
+                view = info_cache.getView(viewname, current_schema,
+                                          create=True)
+                view['definition'] = view_def
+            return view_def
 
     def get_columns(self, connection, tablename, schemaname=None,
                     info_cache=None):
-        info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, schemaname)
+            if table_cache and 'columns' in table_cache.keys():
+                return table_cache.get('columns')
         table_oid = self._get_table_oid(connection, tablename, schemaname,
                                         info_cache)
         SQL_COLS = """
@@ -689,11 +720,18 @@ class PGDialect(default.DefaultDialect):
             column_info = dict(name=name, type=coltype, nullable=nullable,
                                default=default, colargs=colargs)
             columns.append(column_info)
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, schemaname, 
+                                              create=True)
+            table_cache['columns'] = columns
         return columns
 
     def get_primary_keys(self, connection, tablename, schemaname=None,
                          info_cache=None):
-        info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, schemaname)
+            if table_cache and 'primary_keys' in table_cache.keys():
+                return table_cache.get('primary_keys')
         table_oid = self._get_table_oid(connection, tablename, schemaname,
                                         info_cache)
         PK_SQL = """
@@ -706,19 +744,20 @@ class PGDialect(default.DefaultDialect):
         """
         t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
         c = connection.execute(t, table_oid=table_oid)
-        return [r[0] for r in c.fetchall()]
-        for row in c.fetchall():
-            pk = row[0]
-            if pk in table.c:
-                col = table.c[pk]
-                table.primary_key.add(col)
-                if col.default is None:
-                    col.autoincrement = False
+        primary_keys = [r[0] for r in c.fetchall()]
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, schemaname,
+                                              create=True)
+            table_cache['primary_keys'] = primary_keys
+        return primary_keys
 
     def get_foreign_keys(self, connection, tablename, schemaname=None,
                          info_cache=None):
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, schemaname)
+            if table_cache and 'foreign_keys' in table_cache.keys():
+                return table_cache.get('foreign_keys')
         preparer = self.identifier_preparer
-        info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
         table_oid = self._get_table_oid(connection, tablename, schemaname,
                                         info_cache)
         FK_SQL = """
@@ -752,10 +791,17 @@ class PGDialect(default.DefaultDialect):
                 'referred_columns' : referred_columns
             }
             fkeys.append(fkey_d)
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, schemaname,
+                                              create=True)
+            table_cache['foreign_keys'] = fkeys
         return fkeys
 
     def get_indexes(self, connection, tablename, schemaname, info_cache=None):
-        info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, schemaname)
+            if table_cache and 'indexes' in table_cache.keys():
+                return table_cache.get('indexes')
         table_oid = self._get_table_oid(connection, tablename, schemaname,
                                         info_cache)
         IDX_SQL = """
@@ -793,6 +839,10 @@ class PGDialect(default.DefaultDialect):
             index_d['name'] = idx_name
             index_d['column_names'].append(col)
             index_d['unique'] = unique
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, schemaname,
+                                              create=True)
+            table_cache['indexes'] = indexes
         return indexes
 
     def reflecttable(self, connection, table, include_columns):
@@ -805,7 +855,7 @@ class PGDialect(default.DefaultDialect):
         if isinstance(tablename, str):
             tablename = tablename.decode(self.encoding)
         # end Py2K
-        info_cache = {}
+        info_cache = PGInfoCache()
         for col_d in self.get_columns(connection, tablename, schemaname,
                                                                 info_cache):
             name = col_d['name']