From f6f9e5232a6496086add3a7aeeb2195d4b537d14 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Wed, 4 Feb 2009 05:43:15 +0000 Subject: [PATCH] completed refactoring of reflecttable --- lib/sqlalchemy/dialects/postgres/base.py | 199 +++++++++++++++-------- 1 file changed, 127 insertions(+), 72 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index 2bdab3914f..cb671dfa5f 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -516,7 +516,8 @@ class PGDialect(default.DefaultDialect): info_cache['tables'][table_index] = {} return info_cache - def _get_table_oid(self, connection, tablename, schemaname=None): + def _get_table_oid(self, connection, tablename, schemaname=None, + info_cache=None): """Fetch the oid for schemaname.tablename. Several reflection methods require the table oid. The idea for using @@ -524,6 +525,12 @@ class PGDialect(default.DefaultDialect): subsequent calls. """ + info_cache = self._prepare_info_cache(info_cache, tablename, schemaname) + # If it's in info_cache, juse use that. + table_index = "%s.%s" % (schemaname, tablename) + table_oid = info_cache['tables'][table_index].get('table_oid') + if table_oid: + return table_oid if schemaname is not None: schema_where_clause = "n.nspname = :schema" else: @@ -545,18 +552,15 @@ class PGDialect(default.DefaultDialect): table_oid = c.scalar() if table_oid is None: raise exc.NoSuchTableError(table_name) + # cache it + info_cache['tables'][table_index]['table_oid'] = table_oid return table_oid def get_columns(self, connection, tablename, schemaname=None, info_cache=None): info_cache = self._prepare_info_cache(info_cache, tablename, schemaname) - # looked for cached table oid - table_index = "%s.%s" % (schemaname, tablename) - table_oid = info_cache['tables'][table_index].get('table_oid') - if table_oid is None: - table_oid = self._get_table_oid(connection, tablename, schemaname) - # cache it - info_cache['tables'][table_index]['table_oid'] = table_oid + table_oid = self._get_table_oid(connection, tablename, schemaname, + info_cache) SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), @@ -640,50 +644,22 @@ class PGDialect(default.DefaultDialect): columns.append(column_info) return columns - def reflecttable(self, connection, table, include_columns): - preparer = self.identifier_preparer - schemaname = table.schema - tablename = table.name - # Py2K - if isinstance(schemaname, str): - schemaname = schemaname.decode(self.encoding) - if isinstance(tablename, str): - tablename = tablename.decode(self.encoding) - # end Py2K - info_cache = {} - for col_d in self.get_columns(connection, tablename, schemaname, - info_cache): - name = col_d['name'] - coltype = col_d['type'] - nullable = col_d['nullable'] - default = col_d['default'] - colargs = col_d['colargs'] - if include_columns and name not in include_columns: - continue - if default is not None: - match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) - if match is not None: - # the default is related to a Sequence - sch = schemaname - if '.' not in match.group(2) and sch is not None: - # unconditionally quote the schema name. this could - # later be enhanced to obey quoting rules / "quote schema" - default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3) - colargs.append(schema.DefaultClause(sql.text(default))) - table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) - # Now we have the table oid cached. - table_oid = info_cache['tables']["%s.%s" % (schemaname, tablename)]['table_oid'] - # Primary keys + def get_primary_keys(self, connection, tablename, schemaname=None, + info_cache=None): + info_cache = self._prepare_info_cache(info_cache, tablename, schemaname) + table_oid = self._get_table_oid(connection, tablename, schemaname, + info_cache) PK_SQL = """ SELECT attname FROM pg_attribute WHERE attrelid = ( SELECT indexrelid FROM pg_index i - WHERE i.indrelid = :table + WHERE i.indrelid = :table_oid AND i.indisprimary = 't') ORDER BY attnum """ t = sql.text(PK_SQL, typemap={'attname':sqltypes.Unicode}) - c = connection.execute(t, table=table_oid) + c = connection.execute(t, table_oid=table_oid) + return [r[0] for r in c.fetchall()] for row in c.fetchall(): pk = row[0] if pk in table.c: @@ -692,7 +668,12 @@ class PGDialect(default.DefaultDialect): if col.default is None: col.autoincrement = False - # Foreign keys + def get_foreign_keys(self, connection, tablename, schemaname=None, + info_cache=None): + preparer = self.identifier_preparer + info_cache = self._prepare_info_cache(info_cache, tablename, schemaname) + table_oid = self._get_table_oid(connection, tablename, schemaname, + info_cache) FK_SQL = """ SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef FROM pg_catalog.pg_constraint r @@ -702,49 +683,49 @@ class PGDialect(default.DefaultDialect): t = sql.text(FK_SQL, typemap={'conname':sqltypes.Unicode, 'condef':sqltypes.Unicode}) c = connection.execute(t, table=table_oid) + fkeys = [] for conname, condef in c.fetchall(): m = re.search('FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)', condef).groups() (constrained_columns, referred_schema, referred_table, referred_columns) = m constrained_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s*', constrained_columns)] if referred_schema: referred_schema = preparer._unquote_identifier(referred_schema) - elif table.schema is not None and table.schema == self.get_default_schema_name(connection): + elif schemaname is not None and schemaname == self.get_default_schema_name(connection): # no schema (i.e. its the default schema), and the table we're # reflecting has the default schema explicit, then use that. # i.e. try to use the user's conventions - referred_schema = table.schema + referred_schema = schemaname referred_table = preparer._unquote_identifier(referred_table) referred_columns = [preparer._unquote_identifier(x) for x in re.split(r'\s*,\s', referred_columns)] - - refspec = [] - if referred_schema is not None: - schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, - autoload_with=connection) - for column in referred_columns: - refspec.append(".".join([referred_schema, referred_table, column])) - else: - schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) - for column in referred_columns: - refspec.append(".".join([referred_table, column])) - - table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True)) - - # Indexes + fkey_d = { + 'name' : conname, + 'constrained_columns' : constrained_columns, + 'referred_schema' : referred_schema, + 'referred_table' : referred_table, + 'referred_columns' : referred_columns + } + fkeys.append(fkey_d) + return fkeys + + def get_indexes(self, connection, tablename, schemaname, info_cache=None): + info_cache = self._prepare_info_cache(info_cache, tablename, schemaname) + table_oid = self._get_table_oid(connection, tablename, schemaname, + info_cache) IDX_SQL = """ SELECT c.relname, i.indisunique, i.indexprs, i.indpred, a.attname FROM pg_index i, pg_class c, pg_attribute a - WHERE i.indrelid = :table AND i.indexrelid = c.oid + WHERE i.indrelid = :table_oid AND i.indexrelid = c.oid AND a.attrelid = i.indexrelid AND i.indisprimary = 'f' ORDER BY c.relname, a.attnum """ t = sql.text(IDX_SQL, typemap={'attname':sqltypes.Unicode}) - c = connection.execute(t, table=table_oid) - indexes = {} + c = connection.execute(t, table_oid=table_oid) + index_names = {} + indexes = [] sv_idx_name = None for row in c.fetchall(): idx_name, unique, expr, prd, col = row - if expr and not idx_name == sv_idx_name: util.warn( "Skipped unsupported reflection of expression-based index %s" @@ -756,16 +737,90 @@ class PGDialect(default.DefaultDialect): "Predicate of partial index %s ignored during reflection" % idx_name) sv_idx_name = idx_name + if idx_name in index_names: + index_d = index_names[idx_name] + else: + index_d = {'column_names':[]} + indexes.append(index_d) + index_names[idx_name] = index_d + index_d['name'] = idx_name + index_d['column_names'].append(col) + index_d['unique'] = unique + return indexes + + def reflecttable(self, connection, table, include_columns): + preparer = self.identifier_preparer + schemaname = table.schema + tablename = table.name + # Py2K + if isinstance(schemaname, str): + schemaname = schemaname.decode(self.encoding) + if isinstance(tablename, str): + tablename = tablename.decode(self.encoding) + # end Py2K + info_cache = {} + for col_d in self.get_columns(connection, tablename, schemaname, + info_cache): + name = col_d['name'] + coltype = col_d['type'] + nullable = col_d['nullable'] + default = col_d['default'] + colargs = col_d['colargs'] + if include_columns and name not in include_columns: + continue + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + # the default is related to a Sequence + sch = schemaname + if '.' not in match.group(2) and sch is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / "quote schema" + default = match.group(1) + ('"%s"' % sch) + '.' + match.group(2) + match.group(3) + colargs.append(schema.DefaultClause(sql.text(default))) + table.append_column(schema.Column(name, coltype, nullable=nullable, *colargs)) + # Now we have the table oid cached. + table_oid = self._get_table_oid(connection, tablename, schemaname, + info_cache) + # Primary keys + for pk in self.get_primary_keys(connection, tablename, schemaname, + info_cache): + if pk in table.c: + col = table.c[pk] + table.primary_key.add(col) + if col.default is None: + col.autoincrement = False + # Foreign keys + fkeys = self.get_foreign_keys(connection, tablename, schemaname, + info_cache) + for fkey_d in fkeys: + conname = fkey_d['name'] + constrained_columns = fkey_d['constrained_columns'] + referred_schema = fkey_d['referred_schema'] + referred_table = fkey_d['referred_table'] + referred_columns = fkey_d['referred_columns'] + refspec = [] + if referred_schema is not None: + schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, + autoload_with=connection) + for column in referred_columns: + refspec.append(".".join([referred_schema, referred_table, column])) + else: + schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) + for column in referred_columns: + refspec.append(".".join([referred_table, column])) - if not indexes.has_key(idx_name): - indexes[idx_name] = [unique, []] - indexes[idx_name][1].append(col) + table.append_constraint(schema.ForeignKeyConstraint(constrained_columns, refspec, conname, link_to_name=True)) - for name, (unique, columns) in indexes.items(): + # Indexes + indexes = self.get_indexes(connection, tablename, schemaname, + info_cache) + for index_d in indexes: + name = index_d['name'] + columns = index_d['column_names'] + unique = index_d['unique'] schema.Index(name, *[table.columns[c] for c in columns], **dict(unique=unique)) - - def _load_domains(self, connection): ## Load data types for domains: -- 2.47.3