]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
essential refactoring complete - tests pass
authorRandall Smith <randall@tnr.cc>
Thu, 12 Feb 2009 06:32:03 +0000 (06:32 +0000)
committerRandall Smith <randall@tnr.cc>
Thu, 12 Feb 2009 06:32:03 +0000 (06:32 +0000)
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgres/base.py

index 8daf6404b7fce02c99fb74125457f38f37bf3a80..f9db033a0beec109a046f65f9846c3ce0103ad9c 100644 (file)
@@ -447,6 +447,9 @@ class OracleIdentifierPreparer(compiler.IdentifierPreparer):
         name = re.sub(r'^_+', '', savepoint.ident)
         return super(OracleIdentifierPreparer, self).format_savepoint(savepoint, name)
         
+class OracleInfoCache(default.DefaultInfoCache):
+    pass
+
 class OracleDialect(default.DefaultDialect):
     name = 'oracle'
     supports_alter = True
@@ -471,6 +474,7 @@ class OracleDialect(default.DefaultDialect):
     type_compiler = OracleTypeCompiler
     preparer = OracleIdentifierPreparer
     defaultrunner = OracleDefaultRunner
+    info_cache = OracleInfoCache
     
     def __init__(self, 
                 use_ansi=True, 
@@ -564,24 +568,15 @@ class OracleDialect(default.DefaultDialect):
             else:
                 return None, None, None, None
 
-    def reflecttable(self, connection, table, include_columns):
-        preparer = self.identifier_preparer
-
-        resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False)
-
-        if resolve_synonyms:
-            actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name))
-        else:
-            actual_name, owner, dblink, synonym = None, None, None, None
+    def get_columns(self, connection, tablename, schemaname=None,
+                    info_cache=None, dblink=''):
 
-        if not actual_name:
-            actual_name = self._denormalize_name(table.name)
-        if not dblink:
-            dblink = ''
-        if not owner:
-            owner = self._denormalize_name(table.schema or self.get_default_schema_name(connection))
-
-        c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':actual_name, 'owner':owner})
+        if info_cache:
+            columns = info_cache.getColumns(tablename, schemaname)
+            if columns:
+                return columns
+        columns = []
+        c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS%(dblink)s where TABLE_NAME = :table_name and OWNER = :owner" % {'dblink':dblink}, {'table_name':tablename, 'owner':schemaname})
 
         while True:
             row = c.fetchone()
@@ -590,9 +585,6 @@ class OracleDialect(default.DefaultDialect):
 
             (colname, coltype, length, precision, scale, nullable, default) = (self._normalize_name(row[0]), row[1], row[2], row[3], row[4], row[5]=='Y', row[6])
 
-            if include_columns and colname not in include_columns:
-                continue
-
             # INTEGER if the scale is 0 and precision is null
             # NUMBER if the scale and precision are both null
             # NUMBER(9,2) if the precision is 9 and the scale is 2
@@ -619,13 +611,26 @@ class OracleDialect(default.DefaultDialect):
             colargs = []
             if default is not None:
                 colargs.append(schema.DefaultClause(sql.text(default)))
-
-            table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))
-
-        if not table.columns:
-            raise AssertionError("Couldn't find any column information for table %s" % actual_name)
-
-        c = connection.execute("""SELECT
+            cdict = {
+                'name': colname,
+                'type': coltype,
+                'nullable': nullable,
+                'default': default,
+                'attrs': colargs
+            }
+            columns.append(cdict)
+        if info_cache:
+            info_cache.setColumns(columns, tablename, schemaname)
+        return columns
+
+    def _get_constraint_data(self, connection, tablename, schemaname=None,
+                             info_cache=None, dblink=''):
+
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, schemaname)
+            if table_cache and ['constraints'] in table_cache.keys():
+                return table_cache['constraints']
+        rp = connection.execute("""SELECT
              ac.constraint_name,
              ac.constraint_type,
              loc.column_name AS local_column,
@@ -644,19 +649,49 @@ class OracleDialect(default.DefaultDialect):
            AND ac.r_constraint_name = rem.constraint_name(+)
            -- order multiple primary keys correctly
            ORDER BY ac.constraint_name, loc.position, rem.position"""
-         % {'dblink':dblink}, {'table_name' : actual_name, 'owner' : owner})
+         % {'dblink':dblink}, {'table_name' : tablename, 'owner' : schemaname})
+        constraint_data = rp.fetchall()
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, schemaname,
+                                              create=True)
+            table_cache['constraints'] = constraint_data
+        return constraint_data
+
+    def get_primary_keys(self, connection, tablename, schemaname=None,
+                         info_cache=None, dblink=''):
+
+        if info_cache:
+            pkeys = info_cache.getPrimaryKeys(tablename, schemaname)
+            if pkeys is not None:
+                return pkeys
+        pkeys = []
+        constraint_data = self._get_constraint_data(connection, tablename,
+                                        schemaname, info_cache, dblink)
+        for row in constraint_data:
+            #print "ROW:" , row
+            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
+            if cons_type == 'P':
+                pkeys.append(local_column)
+        if info_cache:
+            info_cache.setPrimaryKeys(pkeys, tablename, schemaname)
+        return pkeys
+
+
+    def get_foreign_keys(self, connection, tablename, schemaname=None,
+                         info_cache=None, dblink='', resolve_synonyms=False):
 
+        if info_cache:
+            fkeys = info_cache.getForeignKeys(tablename, schemaname)
+            if fkeys is not None:
+                return fkeys
+
+        constraint_data = self._get_constraint_data(connection, tablename,
+                                                schemaname, info_cache, dblink)
+        fkeys = []
         fks = {}
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            #print "ROW:" , row
-            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \
-                    row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
-            if cons_type == 'P' and local_column in table.c:
-                table.primary_key.add(table.c[local_column])
-            elif cons_type == 'R':
+        for row in constraint_data:
+            (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = row[0:2] + tuple([self._normalize_name(x) for x in row[2:]])
+            if cons_type == 'R':
                 try:
                     fk = fks[cons_name]
                 except KeyError:
@@ -675,21 +710,78 @@ class OracleDialect(default.DefaultDialect):
                     if ref_synonym:
                         remote_table = self._normalize_name(ref_synonym)
                         remote_owner = self._normalize_name(ref_remote_owner)
-
-                if not table.schema and self._denormalize_name(remote_owner) == owner:
-                    refspec =  ".".join([remote_table, remote_column])
-                    t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
-                else:
-                    refspec =  ".".join([x for x in [remote_owner, remote_table, remote_column] if x])
-                    t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, schema=remote_owner, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
-
                 if local_column not in fk[0]:
                     fk[0].append(local_column)
-                if refspec not in fk[1]:
-                    fk[1].append(refspec)
+                if remote_column not in fk[1]:
+                    fk[1].append(remote_column)
+        for (name, value) in fks.items():
+            if remote_table and value[1]:
+                fkeys.append((name, value[0], remote_owner, remote_table, value[1]))
+        if info_cache:
+            info_cache.setForeignKeys(fkeys, tablename, schemaname)
+        return fkeys
+
+    def reflecttable(self, connection, table, include_columns):
+        preparer = self.identifier_preparer
+        info_cache = OracleInfoCache()
+
+        resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False)
+
+        if resolve_synonyms:
+            actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name))
+        else:
+            actual_name, owner, dblink, synonym = None, None, None, None
 
-        for name, value in fks.iteritems():
-            table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name, link_to_name=True))
+        if not actual_name:
+            actual_name = self._denormalize_name(table.name)
+        if not dblink:
+            dblink = ''
+        if not owner:
+            owner = self._denormalize_name(table.schema or self.get_default_schema_name(connection))
+
+        # columns
+        columns = self.get_columns(connection, actual_name, owner, info_cache,
+                                                                        dblink)
+        for cdict in columns:
+            colname = cdict['name']
+            coltype = cdict['type']
+            nullable = cdict['nullable']
+            colargs = cdict['attrs']
+            if include_columns and colname not in include_columns:
+                continue
+            table.append_column(schema.Column(colname, coltype,
+                                              nullable=nullable, *colargs))
+        if not table.columns:
+            raise AssertionError("Couldn't find any column information for table %s" % actual_name)
+
+        # primary keys
+        for pkcol in self.get_primary_keys(connection, actual_name, owner,
+                                                           info_cache, dblink):
+            if pkcol in table.c:
+                table.primary_key.add(table.c[pkcol])
+
+        # foreign keys
+        fks = {}
+        fkeys = []
+        fkeys = self.get_foreign_keys(connection, actual_name, owner,
+                                      info_cache, dblink, resolve_synonyms)
+        refspecs = []
+        for (conname, constrained_columns, referred_schema, referred_table,
+             referred_columns) in fkeys:
+            for (i, ref_col) in enumerate(referred_columns):
+                if not table.schema and self._denormalize_name(referred_schema) == self._denormalize_name(owner):
+                    t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
+
+                    refspec =  ".".join([referred_table, ref_col])
+                else:
+                    refspec = '.'.join([x for x in [referred_schema,
+                                    referred_table, ref_col] if x is not None])
+
+                    t = schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection, schema=referred_schema, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
+                refspecs.append(refspec)
+            table.append_constraint(
+                schema.ForeignKeyConstraint(constrained_columns, refspecs,
+                                        name=conname, link_to_name=True))
 
 
 class _OuterJoinColumn(sql.ClauseElement):
index 9277fb124969078ed65d5b6467ce2c7850956a30..7db0dd882333ee9e33f92db624731d153f1f5efb 100644 (file)
@@ -640,9 +640,9 @@ class PGDialect(default.DefaultDialect):
     def get_columns(self, connection, tablename, schemaname=None,
                     info_cache=None):
         if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname)
-            if table_cache and 'columns' in table_cache.keys():
-                return table_cache.get('columns')
+            columns = info_cache.getColumns(tablename, schemaname)
+            if columns is not None:
+                return columns
         table_oid = self._get_table_oid(connection, tablename, schemaname,
                                         info_cache)
         SQL_COLS = """
@@ -727,9 +727,7 @@ class PGDialect(default.DefaultDialect):
                                default=default, colargs=colargs)
             columns.append(column_info)
         if info_cache:
-            table_cache = info_cache.getTable(tablename, schemaname, 
-                                              create=True)
-            table_cache['columns'] = columns
+            info_cache.setColumns(columns, tablename, schemaname)
         return columns
 
     def get_primary_keys(self, connection, tablename, schemaname=None,