git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
completed refactoring of reflecttable
author    Randall Smith <randall@tnr.cc>
Wed, 4 Feb 2009 05:43:15 +0000 (05:43 +0000)
committer Randall Smith <randall@tnr.cc>
Wed, 4 Feb 2009 05:43:15 +0000 (05:43 +0000)
lib/sqlalchemy/dialects/postgres/base.py

index 2bdab3914f787ebab2f1528fb5f6d6c2a0a389e2..cb671dfa5f32cfa6774ea2404c80c4b7099a8d52 100644 (file)
@@ -516,7 +516,8 @@ class PGDialect(default.DefaultDialect):
             info_cache['tables'][table_index] = {}
         return info_cache
 
-    def _get_table_oid(self, connection, tablename, schemaname=None):
+    def _get_table_oid(self, connection, tablename, schemaname=None,
+                       info_cache=None):
         """Fetch the oid for schemaname.tablename.
 
         Several reflection methods require the table oid.  The idea for using
@@ -524,6 +525,12 @@ class PGDialect(default.DefaultDialect):
         subsequent calls.
 
         """
+        info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+        # If it's in info_cache, just use that.
+        table_index = "%s.%s" % (schemaname, tablename)
+        table_oid = info_cache['tables'][table_index].get('table_oid')
+        if table_oid:
+            return table_oid
         if schemaname is not None:
             schema_where_clause = "n.nspname = :schema"
         else:
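
The caching idea the docstring describes, reduced to a standalone sketch:
_prepare_info_cache (context lines above) reserves a
{'tables': {'schema.table': {}}} slot per table, and _get_table_oid now
memoizes into it.  The fetch_oid callable below stands in for the real
pg_class/pg_namespace query and is purely illustrative:

    # Standalone sketch of the memoization added above; `fetch_oid` stands
    # in for the real catalog query and is invented for illustration.
    def get_table_oid_cached(fetch_oid, tablename, schemaname, info_cache):
        table_index = "%s.%s" % (schemaname, tablename)
        entry = info_cache['tables'].setdefault(table_index, {})
        if entry.get('table_oid') is None:
            entry['table_oid'] = fetch_oid(tablename, schemaname)
        return entry['table_oid']

    calls = []
    def fetch_oid(tablename, schemaname):
        calls.append((tablename, schemaname))  # one catalog round trip
        return 16384
    cache = {'tables': {}}
    assert get_table_oid_cached(fetch_oid, 'users', 'public', cache) == 16384
    assert get_table_oid_cached(fetch_oid, 'users', 'public', cache) == 16384
    assert len(calls) == 1  # the second call was served from info_cache
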
@@ -545,18 +552,15 @@ class PGDialect(default.DefaultDialect):
         table_oid = c.scalar()
         if table_oid is None:
             raise exc.NoSuchTableError(tablename)
+        # cache it
+        info_cache['tables'][table_index]['table_oid'] = table_oid
         return table_oid
 
     def get_columns(self, connection, tablename, schemaname=None,
                     info_cache=None):
         info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
-        # looked for cached table oid
-        table_index = "%s.%s" % (schemaname, tablename)
-        table_oid = info_cache['tables'][table_index].get('table_oid')
-        if table_oid is None:
-            table_oid = self._get_table_oid(connection, tablename, schemaname)
-            # cache it
-            info_cache['tables'][table_index]['table_oid'] = table_oid
+        table_oid = self._get_table_oid(connection, tablename, schemaname,
+                                        info_cache)
         SQL_COLS = """
             SELECT a.attname,
               pg_catalog.format_type(a.atttypid, a.atttypmod),
@@ -640,50 +644,19 @@ class PGDialect(default.DefaultDialect):
             columns.append(column_info)
         return columns
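
get_columns now hands back plain dictionaries instead of mutating a Table
directly.  Judging from the keys the new reflecttable reads further down in
this diff, each element looks roughly like this (values invented):

    # Invented sample of one entry in the list get_columns() returns,
    # inferred from the keys reflecttable consumes later in this diff:
    column_info = {
        'name': u'id',
        'type': sqltypes.Integer(),  # resolved via pg_catalog.format_type
        'nullable': False,
        'default': u"nextval('users_id_seq'::regclass)",
        'colargs': [],
    }
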
 
-    def reflecttable(self, connection, table, include_columns):
-        preparer = self.identifier_preparer
-        schemaname = table.schema
-        tablename = table.name
-        # Py2K
-        if isinstance(schemaname, str):
-            schemaname = schemaname.decode(self.encoding)
-        if isinstance(tablename, str):
-            tablename = tablename.decode(self.encoding)
-        # end Py2K
-        info_cache = {}
-        for col_d in self.get_columns(connection, tablename, schemaname,
-                                                                info_cache):
-            name = col_d['name']
-            coltype = col_d['type']
-            nullable = col_d['nullable']
-            default = col_d['default']
-            colargs = col_d['colargs']
-            if include_columns and name not in include_columns:
-                continue
-            if default is not None:
-                match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
-                if match is not None:
-                    # the default is related to a Sequence
-                    sch = schemaname
-                    if '.' not in match.group(2) and sch is not None:
-                        # unconditionally quote the schema name.  this could
-                        # later be enhanced to obey quoting rules / "quote schema"
-                        default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
-                colargs.append(schema.DefaultClause(sql.text(default)))
-            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
-        # Now we have the table oid cached.
-        table_oid = info_cache['tables']["%s.%s" % (schemaname, tablename)]['table_oid']
-        # Primary keys
+    def get_primary_keys(self, connection, tablename, schemaname=None,
+                         info_cache=None):
+        info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+        table_oid = self._get_table_oid(connection, tablename, schemaname,
+                                        info_cache)
         PK_SQL = """
           SELECT attname FROM pg_attribute
           WHERE attrelid = (
              SELECT indexrelid FROM pg_index i
-             WHERE i.indrelid = :table
+             WHERE i.indrelid = :table_oid
              AND i.indisprimary = 't')
           ORDER BY attnum
         """
         t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode})
-        c = connection.execute(t, table=table_oid)
+        c = connection.execute(t, table_oid=table_oid)
+        return [r[0] for r in c.fetchall()]
-        for row in c.fetchall():
-            pk = row[0]
-            if pk in table.c:
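
get_primary_keys likewise returns bare data (the key column names in
attnum order) and leaves the Table mutation to reflecttable.  An invented
example of a call against a composite-key table:

    # Invented example; `dialect` and `connection` setup is assumed.
    pks = dialect.get_primary_keys(connection, 'order_items', 'public',
                                   info_cache={})
    # e.g. pks == ['order_id', 'line_no']  (ordered by pg_attribute.attnum)
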
@@ -692,7 +665,10 @@ class PGDialect(default.DefaultDialect):
-                if col.default is None:
-                    col.autoincrement = False
 
-        # Foreign keys
+    def get_foreign_keys(self, connection, tablename, schemaname=None,
+                         info_cache=None):
+        preparer = self.identifier_preparer
+        info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+        table_oid = self._get_table_oid(connection, tablename, schemaname,
+                                        info_cache)
         FK_SQL = """
           SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef
           FROM  pg_catalog.pg_constraint r
@@ -702,49 +678,49 @@ class PGDialect(default.DefaultDialect):
 
         t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode})
         c = connection.execute(t, table=table_oid)
+        fkeys = []
         for conname, condef in c.fetchall():
             m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups()
             (constrained_columns, referred_schema, referred_table, referred_columns) = m
             constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)]
             if referred_schema:
                 referred_schema = preparer._unquote_identifier(referred_schema)
-            elif table.schema is not None and table.schema == self.get_default_schema_name(connection):
+            elif schemaname is not None and schemaname == self.get_default_schema_name(connection):
                 # no schema (i.e. it's the default schema), and the table we're
                 # reflecting has the default schema explicit, then use that.
                 # i.e. try to use the user's conventions
-                referred_schema = table.schema
+                referred_schema = schemaname
             referred_table = preparer._unquote_identifier(referred_table)
             referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', referred_columns)]
-
-            refspec = []
-            if referred_schema is not None:
-                schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
-                            autoload_with=connection)
-                for column in referred_columns:
-                    refspec.append(".".join([referred_schema, referred_table, column]))
-            else:
-                schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
-                for column in referred_columns:
-                    refspec.append(".".join([referred_table, column]))
-
-            table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
-
-        # Indexes 
+            fkey_d = {
+                'name' : conname,
+                'constrained_columns' : constrained_columns,
+                'referred_schema' : referred_schema,
+                'referred_table' : referred_table,
+                'referred_columns' : referred_columns
+            }
+            fkeys.append(fkey_d)
+        return fkeys
+
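
The pg_get_constraintdef() parse above can be exercised in isolation; the
constraint text here is made up rather than read from a live catalog:

    import re
    FK_RE = r'FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)'
    condef = "FOREIGN KEY (user_id) REFERENCES public.users(id)"
    assert re.search(FK_RE, condef).groups() == \
        ('user_id', 'public', 'users', 'id')
    # an unqualified reference leaves the schema group as None:
    condef = "FOREIGN KEY (uid) REFERENCES users(id)"
    assert re.search(FK_RE, condef).groups() == ('uid', None, 'users', 'id')
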
+    def get_indexes(self, connection, tablename, schemaname=None, info_cache=None):
+        info_cache = self._prepare_info_cache(info_cache, tablename, schemaname)
+        table_oid = self._get_table_oid(connection, tablename, schemaname,
+                                        info_cache)
         IDX_SQL = """
           SELECT c.relname, i.indisunique, i.indexprs, i.indpred,
             a.attname
           FROM pg_index i, pg_class c, pg_attribute a
-          WHERE i.indrelid = :table AND i.indexrelid = c.oid
+          WHERE i.indrelid = :table_oid AND i.indexrelid = c.oid
             AND a.attrelid = i.indexrelid AND i.indisprimary = 'f'
           ORDER BY c.relname, a.attnum
         """
         t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode})
-        c = connection.execute(t, table=table_oid)
-        indexes = {}
+        c = connection.execute(t, table_oid=table_oid)
+        index_names = {}
+        indexes = []
         sv_idx_name = None
         for row in c.fetchall():
             idx_name, unique, expr, prd, col = row
-
             if expr and not idx_name == sv_idx_name:
                 util.warn(
                   "Skipped unsupported reflection of expression-based index %s"
@@ -756,16 +732,90 @@ class PGDialect(default.DefaultDialect):
                    "Predicate of partial index %s ignored during reflection"
                    % idx_name)
                 sv_idx_name = idx_name
+            if idx_name in index_names:
+                index_d = index_names[idx_name]
+            else:
+                index_d = {'column_names':[]}
+                indexes.append(index_d)
+                index_names[idx_name] = index_d
+            index_d['name'] = idx_name
+            index_d['column_names'].append(col)
+            index_d['unique'] = unique
+        return indexes
+
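
The loop above folds one row per (index, column) pair into one dictionary
per index; the same grouping, self-contained and with invented rows:

    # Self-contained sketch of the index-grouping loop above (rows invented):
    rows = [('ix_name', False, 'last_name'),
            ('ix_name', False, 'first_name'),
            ('ix_email', True, 'email')]
    index_names, indexes = {}, []
    for idx_name, unique, col in rows:
        if idx_name not in index_names:
            index_names[idx_name] = {'name': idx_name,
                                     'column_names': [], 'unique': unique}
            indexes.append(index_names[idx_name])
        index_names[idx_name]['column_names'].append(col)
    assert indexes == [
        {'name': 'ix_name', 'column_names': ['last_name', 'first_name'],
         'unique': False},
        {'name': 'ix_email', 'column_names': ['email'], 'unique': True}]
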
+    def reflecttable(self, connection, table, include_columns):
+        preparer = self.identifier_preparer
+        schemaname = table.schema
+        tablename = table.name
+        # Py2K
+        if isinstance(schemaname, str):
+            schemaname = schemaname.decode(self.encoding)
+        if isinstance(tablename, str):
+            tablename = tablename.decode(self.encoding)
+        # end Py2K
+        info_cache = {}
+        for col_d in self.get_columns(connection, tablename, schemaname,
+                                                                info_cache):
+            name = col_d['name']
+            coltype = col_d['type']
+            nullable = col_d['nullable']
+            default = col_d['default']
+            colargs = col_d['colargs']
+            if include_columns and name not in include_columns:
+                continue
+            if default is not None:
+                match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
+                if match is not None:
+                    # the default is related to a Sequence
+                    sch = schemaname
+                    if '.' not in match.group(2) and sch is not None:
+                        # unconditionally quote the schema name.  this could
+                        # later be enhanced to obey quoting rules / "quote schema"
+                        default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3)
+                colargs.append(schema.DefaultClause(sql.text(default)))
+            table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs))
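
The nextval() rewrite in the loop above, run on a sample default (the
default text is invented, shaped like a typical serial-column default):

    import re
    default = "nextval('users_id_seq'::regclass)"
    match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
    assert match.groups() == ("nextval('", 'users_id_seq', "'::regclass)")
    sch = 'public'
    default = match.group(1) + ('"%s"' % sch) + '.' + \
        match.group(2) + match.group(3)
    assert default == 'nextval(\'"public".users_id_seq\'::regclass)'
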
+        # get_columns() already cached the table oid; this hits the cache.
+        table_oid = self._get_table_oid(connection, tablename, schemaname,
+                                        info_cache)
+        # Primary keys
+        for pk in self.get_primary_keys(connection, tablename, schemaname,
+                                                                    info_cache):
+            if pk in table.c:
+                col = table.c[pk]
+                table.primary_key.add(col)
+                if col.default is None:
+                    col.autoincrement = False
+        # Foreign keys
+        fkeys = self.get_foreign_keys(connection, tablename, schemaname,
+                                      info_cache)
+        for fkey_d in fkeys:
+            conname = fkey_d['name']
+            constrained_columns = fkey_d['constrained_columns']
+            referred_schema = fkey_d['referred_schema']
+            referred_table = fkey_d['referred_table']
+            referred_columns = fkey_d['referred_columns']
+            refspec = []
+            if referred_schema is not None:
+                schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema,
+                            autoload_with=connection)
+                for column in referred_columns:
+                    refspec.append(".".join([referred_schema, referred_table, column]))
+            else:
+                schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection)
+                for column in referred_columns:
+                    refspec.append(".".join([referred_table, column]))
 
-            if not indexes.has_key(idx_name):
-                indexes[idx_name] = [unique, []]
-            indexes[idx_name][1].append(col)
+            table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True))
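
The refspec entries handed to ForeignKeyConstraint are dotted column paths;
for a schema-qualified reference they come out like this (values invented):

    # Invented example of the refspec list built above:
    referred_schema, referred_table = 'public', 'users'
    referred_columns = ['id']
    refspec = [".".join([referred_schema, referred_table, c])
               for c in referred_columns]
    assert refspec == ['public.users.id']

link_to_name=True makes the constraint resolve those strings against column
names rather than keys, matching the names just reflected from the catalog.
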
 
-        for name, (unique, columns) in indexes.items():
+        # Indexes 
+        indexes = self.get_indexes(connection, tablename, schemaname,
+                                   info_cache)
+        for index_d in indexes:
+            name = index_d['name']
+            columns = index_d['column_names']
+            unique = index_d['unique']
             schema.Index(name, *[table.columns[c] for c in columns], 
                          **dict(unique=unique))
-
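
With the refactoring complete, reflecttable is a thin driver over the four
granular methods, and a shared info_cache means the table-oid query runs
once per table.  A hedged end-to-end sketch (dialect and connection setup
is assumed; method names are as defined in this commit):

    info_cache = {}
    cols = dialect.get_columns(connection, 'users', 'public', info_cache)
    pks = dialect.get_primary_keys(connection, 'users', 'public', info_cache)
    fks = dialect.get_foreign_keys(connection, 'users', 'public', info_cache)
    idxs = dialect.get_indexes(connection, 'users', 'public', info_cache)
    # all four share a single cached table-oid lookup via info_cache
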
 
     def _load_domains(self, connection):
         ## Load data types for domains: