From 10718c7bb5f01e844efc31826cacba0e5b1a24f8 Mon Sep 17 00:00:00 2001 From: Randall Smith Date: Thu, 12 Feb 2009 06:32:03 +0000 Subject: [PATCH] essential refactoring complete - tests pass --- lib/sqlalchemy/dialects/oracle/base.py | 192 +++++++++++++++++------ lib/sqlalchemy/dialects/postgres/base.py | 10 +- 2 files changed, 146 insertions(+), 56 deletions(-) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 8daf6404b7..f9db033a0b 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -447,6 +447,9 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer): name = re.sub(r'^_+', '', savepoint.ident) return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name) +class OracleInfoCache(default.DefaultInfoCache): + pass + class OracleDialect(default.DefaultDialect): name = 'oracle' supports_alter = True @@ -471,6 +474,7 @@ class OracleDialect(default.DefaultDialect): type_compiler = OracleTypeCompiler preparer = OracleIdentifierPreparer defaultrunner = OracleDefaultRunner + info_cache = OracleInfoCache def __init__(self, use_ansi=True, @@ -564,24 +568,15 @@ class OracleDialect(default.DefaultDialect): else: return None, None, None, None - def reflecttable(self, connection, table, include_columns): - preparer = self.identifier_preparer - - resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False) - - if resolve_synonyms: - actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name)) - else: - actual_name, owner, dblink, synonym = None, None, None, None + def get_columns(self, connection, tablename, schemaname=None, + info_cache=None, dblink=''): - if not actual_name: - actual_name = self._denormalize_name(table.name) - if not dblink: - dblink = '' - if not owner: - owner = self._denormalize_name(table.schema or self.get_default_schema_name(connection)) - - c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':actual_name, 'owner':owner}) + if info_cache: + columns = info_cache.getColumns(tablename, schemaname) + if columns: + return columns + columns = [] + c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':tablename, 'owner':schemaname}) while True: row = c.fetchone() @@ -590,9 +585,6 @@ class OracleDialect(default.DefaultDialect): (colname, coltype, length, precision, scale, nullable, default) = (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6]) - if include_columns and colname not in include_columns: - continue - # INTEGER if the scale is 0 and precision is null # NUMBER if the scale and precision are both null # NUMBER(9,2) if the precision is 9 and the scale is 2 @@ -619,13 +611,26 @@ class OracleDialect(default.DefaultDialect): colargs = [] if default is not None: colargs.append(schema.DefaultClause(sql.text(default))) - - table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs)) - - if not table.columns: - raise AssertionError("Couldn't find any column information for table %s" % actual_name) - - c = connection.execute("""SELECT + cdict = { + 'name': colname, + 'type': coltype, + 'nullable': nullable, + 'default': default, + 'attrs': colargs + } + columns.append(cdict) + if info_cache: + info_cache.setColumns(columns, tablename, schemaname) + return columns + + def _get_constraint_data(self, connection, tablename, schemaname=None, + info_cache=None, dblink=''): + + if info_cache: + table_cache = info_cache.getTable(tablename, schemaname) + if table_cache and ['constraints'] in table_cache.keys(): + return table_cache['constraints'] + rp = connection.execute("""SELECT ac.constraint_name, ac.constraint_type, loc.column_name AS local_column, @@ -644,19 +649,49 @@ class OracleDialect(default.DefaultDialect): AND ac.r_constraint_name = rem.constraint_name(+) -- order multiple primary keys correctly ORDER BY ac.constraint_name, loc.position, rem.position""" - % {'dblink':dblink}, {'table_name' : actual_name, 'owner' : owner}) + % {'dblink':dblink}, {'table_name' : tablename, 'owner' : schemaname}) + constraint_data = rp.fetchall() + if info_cache: + table_cache = info_cache.getTable(tablename, schemaname, + create=True) + table_cache['constraints'] = constraint_data + return constraint_data + + def get_primary_keys(self, connection, tablename, schemaname=None, + info_cache=None, dblink=''): + + if info_cache: + pkeys = info_cache.getPrimaryKeys(tablename, schemaname) + if pkeys is not None: + return pkeys + pkeys = [] + constraint_data = self._get_constraint_data(connection, tablename, + schemaname, info_cache, dblink) + for row in constraint_data: + #print "ROW:" , row + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]]) + if cons_type == 'P': + pkeys.append(local_column) + if info_cache: + info_cache.setPrimaryKeys(pkeys, tablename, schemaname) + return pkeys + + + def get_foreign_keys(self, connection, tablename, schemaname=None, + info_cache=None, dblink='', resolve_synonyms=False): + if info_cache: + fkeys = info_cache.getForeignKeys(tablename, schemaname) + if fkeys is not None: + return fkeys + + constraint_data = self._get_constraint_data(connection, tablename, + schemaname, info_cache, dblink) + fkeys = [] fks = {} - while True: - row = c.fetchone() - if row is None: - break - #print "ROW:" , row - (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ - row[0:2] + tuple([self._normalize_name(x) for x in row[2:]]) - if cons_type == 'P' and local_column in table.c: - table.primary_key.add(table.c[local_column]) - elif cons_type == 'R': + for row in constraint_data: + (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]]) + if cons_type == 'R': try: fk = fks[cons_name] except KeyError: @@ -675,21 +710,78 @@ class OracleDialect(default.DefaultDialect): if ref_synonym: remote_table = self._normalize_name(ref_synonym) remote_owner = self._normalize_name(ref_remote_owner) - - if not table.schema and self._denormalize_name(remote_owner) == owner: - refspec = ".".join([remote_table, remote_column]) - t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) - else: - refspec = ".".join([x for x in [remote_owner, remote_table, remote_column] if x]) - t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, schema=remote_owner, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) - if local_column not in fk[0]: fk[0].append(local_column) - if refspec not in fk[1]: - fk[1].append(refspec) + if remote_column not in fk[1]: + fk[1].append(remote_column) + for (name, value) in fks.items(): + if remote_table and value[1]: + fkeys.append((name, value[0], remote_owner, remote_table, value[1])) + if info_cache: + info_cache.setForeignKeys(fkeys, tablename, schemaname) + return fkeys + + def reflecttable(self, connection, table, include_columns): + preparer = self.identifier_preparer + info_cache = OracleInfoCache() + + resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False) + + if resolve_synonyms: + actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name)) + else: + actual_name, owner, dblink, synonym = None, None, None, None - for name, value in fks.iteritems(): - table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name, link_to_name=True)) + if not actual_name: + actual_name = self._denormalize_name(table.name) + if not dblink: + dblink = '' + if not owner: + owner = self._denormalize_name(table.schema or self.get_default_schema_name(connection)) + + # columns + columns = self.get_columns(connection, actual_name, owner, info_cache, + dblink) + for cdict in columns: + colname = cdict['name'] + coltype = cdict['type'] + nullable = cdict['nullable'] + colargs = cdict['attrs'] + if include_columns and colname not in include_columns: + continue + table.append_column(schema.Column(colname, coltype, + nullable=nullable, *colargs)) + if not table.columns: + raise AssertionError("Couldn't find any column information for table %s" % actual_name) + + # primary keys + for pkcol in self.get_primary_keys(connection, actual_name, owner, + info_cache, dblink): + if pkcol in table.c: + table.primary_key.add(table.c[pkcol]) + + # foreign keys + fks = {} + fkeys = [] + fkeys = self.get_foreign_keys(connection, actual_name, owner, + info_cache, dblink, resolve_synonyms) + refspecs = [] + for (conname, constrained_columns, referred_schema, referred_table, + referred_columns) in fkeys: + for (i, ref_col) in enumerate(referred_columns): + if not table.schema and self._denormalize_name(referred_schema) == self._denormalize_name(owner): + t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) + + refspec = ".".join([referred_table, ref_col]) + else: + refspec = '.'.join([x for x in [referred_schema, + referred_table, ref_col] if x is not None]) + + t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) + refspecs.append(refspec) + table.append_constraint( + schema.ForeignKeyConstraint(constrained_columns, refspecs, + name=conname, link_to_name=True)) class _OuterJoinColumn(sql.ClauseElement): diff --git a/lib/sqlalchemy/dialects/postgres/base.py b/lib/sqlalchemy/dialects/postgres/base.py index 9277fb1249..7db0dd8823 100644 --- a/lib/sqlalchemy/dialects/postgres/base.py +++ b/lib/sqlalchemy/dialects/postgres/base.py @@ -640,9 +640,9 @@ class PGDialect(default.DefaultDialect): def get_columns(self, connection, tablename, schemaname=None, info_cache=None): if info_cache: - table_cache = info_cache.getTable(tablename, schemaname) - if table_cache and 'columns' in table_cache.keys(): - return table_cache.get('columns') + columns = info_cache.getColumns(tablename, schemaname) + if columns is not None: + return columns table_oid = self._get_table_oid(connection, tablename, schemaname, info_cache) SQL_COLS = """ @@ -727,9 +727,7 @@ class PGDialect(default.DefaultDialect): default=default, colargs=colargs) columns.append(column_info) if info_cache: - table_cache = info_cache.getTable(tablename, schemaname, - create=True) - table_cache['columns'] = columns + info_cache.setColumns(columns, tablename, schemaname) return columns def get_primary_keys(self, connection, tablename, schemaname=None, -- 2.47.3