]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
factored out column reflection from reflecttable
authorRandall Smith <randall@tnr.cc>
Mon, 2 Feb 2009 06:25:14 +0000 (06:25 +0000)
committerRandall Smith <randall@tnr.cc>
Mon, 2 Feb 2009 06:25:14 +0000 (06:25 +0000)
lib/sqlalchemy/dialects/postgres/base.py

index bab875d6867070df21b2187d4349a3f0bb8af561..2bdab3914f787ebab2f1528fb5f6d6c2a0a389e2 100644 (file)
@@ -499,20 +499,64 @@ class PGDialect(default.DefaultDialect):
             raise AssertionError("Could not determine version from string '%s'" % v)
         return tuple([int(x) for x in m.group(1, 2, 3)])
 
-    def reflecttable(self, connection, table, include_columns):
-        preparer = self.identifier_preparer
-        if table.schema is not None:
+    def _prepare_info_cache(self, info_cache, tablename, schemaname):
+        """Add index for schemaname.table_name if it does not exist.
+       
+        This is done so that certain keys can be assumed to be present.
+        
+        """
+        # First, make sure it has the keys we expect.
+        if info_cache is None: 
+            info_cache = dict(tables={})
+        elif 'tables' not in info_cache:
+            info_cache['tables'] = {}
+        # Add the table index if needed.
+        table_index = "%s.%s" % (schemaname, tablename)
+        if table_index not in info_cache['tables']:
+            info_cache['tables'][table_index] = {}
+        return info_cache
+
+    def _get_table_oid(self, connection, tablename, schemaname=None):
+        """Fetch the oid for schemaname.tablename.
+
+        Several reflection methods require the table oid.  The idea for using
+        this method is that it can be fetched one time and cached for
+        subsequent calls.
+
+        """
+        if schemaname is not None:
             schema_where_clause = "n.nspname = :schema"
-            schemaname = table.schema
-            
-            # Py2K
-            if isinstance(schemaname, str):
-                schemaname = schemaname.decode(self.encoding)
-            # end Py2K
         else:
             schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
-            schemaname = None
-
+        query = """
+            SELECT c.oid
+            FROM pg_catalog.pg_class c
+            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
+            WHERE (%s)
+            AND c.relname = :table_name AND c.relkind in ('r','v')
+        """ % schema_where_clause
+        s = sql.text(query, bindparams=[
+            sql.bindparam('table_name', type_=sqltypes.Unicode),
+            sql.bindparam('schema', type_=sqltypes.Unicode)
+            ],
+            typemap={'oid':sqltypes.Integer}
+        )
+        c = connection.execute(s, table_name=tablename, schema=schemaname)
+        table_oid = c.scalar()
+        if table_oid is None:
+            raise exc.NoSuchTableError(table_name)
+        return table_oid
+
+    def get_columns(self, connection, tablename, schemaname=None,
+                    info_cache=None):
+        info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+        # looked for cached table oid
+        table_index = "%s.%s" % (schemaname, tablename)
+        table_oid = info_cache['tables'][table_index].get('table_oid')
+        if table_oid is None:
+            table_oid = self._get_table_oid(connection, tablename, schemaname)
+            # cache it
+            info_cache['tables'][table_index]['table_oid'] = table_oid
         SQL_COLS = """
             SELECT a.attname,
               pg_catalog.format_type(a.atttypid, a.atttypmod),
@@ -521,44 +565,28 @@ class PGDialect(default.DefaultDialect):
               AS DEFAULT,
               a.attnotnull, a.attnum, a.attrelid as table_oid
             FROM pg_catalog.pg_attribute a
-            WHERE a.attrelid = (
-                SELECT c.oid
-                FROM pg_catalog.pg_class c
-                     LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
-                     WHERE (%s)
-                     AND c.relname = :table_name AND c.relkind in ('r','v')
-            ) AND a.attnum > 0 AND NOT a.attisdropped
+            WHERE a.attrelid = :table_oid
+            AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
-        """ % schema_where_clause
-
-        s = sql.text(SQL_COLS, bindparams=[sql.bindparam('table_name', type_=sqltypes.Unicode), sql.bindparam('schema', type_=sqltypes.Unicode)], typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode})
-        tablename = table.name
-        # Py2K
-        if isinstance(tablename, str):
-            tablename = tablename.decode(self.encoding)
-        # end Py2K
-        c = connection.execute(s, table_name=tablename, schema=schemaname)
+        """
+        s = sql.text(SQL_COLS, 
+            bindparams=[sql.bindparam('table_oid', type_=sqltypes.Integer)], 
+            typemap={'attname':sqltypes.Unicode, 'default':sqltypes.Unicode}
+        )
+        c = connection.execute(s, table_oid=table_oid)
         rows = c.fetchall()
-
-        if not rows:
-            raise exc.NoSuchTableError(table.name)
-
         domains = self._load_domains(connection)
-
+        # format columns
+        columns = []
         for name, format_type, default, notnull, attnum, table_oid in rows:
-            if include_columns and name not in include_columns:
-                continue
-
             ## strip (30) from character varying(30)
             attype = re.search('([^\([]+)', format_type).group(1)
             nullable = not notnull
             is_array = format_type.endswith('[]')
-
             try:
                 charlen = re.search('\(([\d,]+)\)', format_type).group(1)
             except:
                 charlen = False
-
             numericprec = False
             numericscale = False
             if attype == 'numeric':
@@ -573,20 +601,17 @@ class PGDialect(default.DefaultDialect):
             if attype == 'integer':
                 numericprec, numericscale = (32, 0)
                 charlen = False
-
             args = []
             for a in (charlen, numericprec, numericscale):
                 if a is None:
                     args.append(None)
                 elif a is not False:
                     args.append(int(a))
-
             kwargs = {}
             if attype == 'timestamp with time zone':
                 kwargs['timezone'] = True
             elif attype == 'timestamp without time zone':
                 kwargs['timezone'] = False
-
             if attype in self.ischema_names:
                 coltype = self.ischema_names[attype]
             else:
@@ -595,14 +620,12 @@ class PGDialect(default.DefaultDialect):
                     if domain['attype'] in self.ischema_names:
                         # A table can't override whether the domain is nullable.
                         nullable = domain['nullable']
-
                         if domain['default'] and not default:
                             # It can, however, override the default value, but can't set it to null.
                             default = domain['default']
                         coltype = self.ischema_names[domain['attype']]
                 else:
                     coltype = None
-
             if coltype:
                 coltype = coltype(*args, **kwargs)
                 if is_array:
@@ -611,21 +634,45 @@ class PGDialect(default.DefaultDialect):
                 util.warn("Did not recognize type '%s' of column '%s'" %
                           (attype, name))
                 coltype = sqltypes.NULLTYPE
-
             colargs = []
+            column_info = dict(name=name, type=coltype, nullable=nullable,
+                               default=default, colargs=colargs)
+            columns.append(column_info)
+        return columns
+
+    def reflecttable(self, connection, table, include_columns):
+        preparer = self.identifier_preparer
+        schemaname = table.schema
+        tablename = table.name
+        # Py2K
+        if isinstance(schemaname, str):
+            schemaname = schemaname.decode(self.encoding)
+        if isinstance(tablename, str):
+            tablename = tablename.decode(self.encoding)
+        # end Py2K
+        info_cache = {}
+        for col_d in self.get_columns(connection, tablename, schemaname,
+                                                                info_cache):
+            name = col_d['name']
+            coltype = col_d['type']
+            nullable = col_d['nullable']
+            default = col_d['default']
+            colargs = col_d['colargs']
+            if include_columns and name not in include_columns:
+                continue
             if default is not None:
                 match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
                 if match is not None:
                     # the default is related to a Sequence
-                    sch = table.schema
+                    sch = schemaname
                     if '.' not in match.group(2) and sch is not None:
                         # unconditionally quote the schema name.  this could
                         # later be enhanced to obey quoting rules / "quote schema"
                         default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
                 colargs.append(schema.DefaultClause(sql.text(default)))
             table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
-
-
+        # Now we have the table oid cached.
+        table_oid = info_cache['tables']["%s.%s" % (schemaname, tablename)]['table_oid']
         # Primary keys
         PK_SQL = """
           SELECT attname FROM pg_attribute