]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refactored mssql for reflection (tests pass/fail same)
authorRandall Smith <randall@tnr.cc>
Wed, 11 Feb 2009 06:30:27 +0000 (06:30 +0000)
committerRandall Smith <randall@tnr.cc>
Wed, 11 Feb 2009 06:30:27 +0000 (06:30 +0000)
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/engine/reflection.py

index bbde62bbac8d2797d6a6262eabaac9dab6cf3484..43fe4b5d57c0ff609f7cab8dd9bcd9ea3928d929 100644 (file)
@@ -1043,6 +1043,13 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
         #TODO: determine MSSQL's escaping rules
         return value
 
+
+class MSInfoCache(default.DefaultInfoCache):
+    
+    def __init__(self, *args, **kwargs):
+        default.DefaultInfoCache.__init__(self, *args, **kwargs)
+
+
 class MSDialect(default.DefaultDialect):
     name = 'mssql'
     supports_default_values = True
@@ -1064,6 +1071,7 @@ class MSDialect(default.DefaultDialect):
     ddl_compiler = MSDDLCompiler
     type_compiler = MSTypeCompiler
     preparer = MSIdentifierPreparer
+    info_cache = MSInfoCache
 
     def __init__(self,
                  auto_identity_insert=True, query_timeout=None,
@@ -1139,28 +1147,130 @@ class MSDialect(default.DefaultDialect):
         row  = c.fetchone()
         return row is not None
 
-    def reflecttable(self, connection, table, include_columns):
-        import sqlalchemy.dialects.information_schema as ischema
+    def get_schema_names(self, connection, info_cache=None):
+        if info_cache:
+            schema_names = info_cache.getSchemaNames()
+            if schema_names is not None:
+                return schema_names
+        import sqlalchemy.dialects.information_schema as ischema
+        s = sql.select([self.uppercase_table(ischema.schemata).c.schema_name],
+            order_by=[ischema.schemata.c.schema_name]
+        )
+        schema_names = [r[0] for r in connection.execute(s)]
+        if info_cache:
+            info_cache.addAllSchemas(schema_names)
+        return schema_names
+
+    def get_table_names(self, connection, schemaname, info_cache=None):
+        import sqlalchemy.dialects.information_schema as ischema
+        current_schema = schemaname or self.get_default_schema_name(connection)
+        if info_cache:
+            table_names = info_cache.getTableNames(current_schema)
+            if table_names is not None:
+                return table_names
+        tables = self.uppercase_table(ischema.tables)
+        s = sql.select([tables.c.table_name],
+            sql.and_(
+                tables.c.table_schema == current_schema,
+                tables.c.table_type == 'BASE TABLE'
+            ),
+            order_by=[tables.c.table_name]
+        )
+        table_names = [r[0] for r in connection.execute(s)]
+        if info_cache:
+            info_cache.addAllTables(table_names, current_schema)
+        return table_names
+
+    def get_view_names(self, connection, schemaname=None, info_cache=None):
+        import sqlalchemy.dialects.information_schema as ischema
+        current_schema = schemaname or self.get_default_schema_name(connection)
+        if info_cache:
+            view_names = info_cache.getViewNames(current_schema)
+            if view_names is not None:
+                return view_names
+        tables = self.uppercase_table(ischema.tables)
+        s = sql.select([tables.c.table_name],
+            sql.and_(
+                tables.c.table_schema == current_schema,
+                tables.c.table_type == 'VIEW'
+            ),
+            order_by=[tables.c.table_name]
+        )
+        view_names = [r[0] for r in connection.execute(s)]
+        if info_cache:
+            info_cache.addAllViews(view_names, current_schema)
+        return view_names
+
+    def get_indexes(self, connection, tablename, schemaname=None,
+                                                            info_cache=None):
+        current_schema = schemaname or self.get_default_schema_name(connection)
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, current_schema)
+            if table_cache and 'indexes' in table_cache:
+                return table_cache.get('indexes')
+        full_tname = "%s.%s" % (current_schema, tablename)
+        indexes = []
+        s = sql.text("exec sp_helpindex '%s'" % full_tname)
+        rp = connection.execute(s)
+        for row in rp:
+            if 'primary key' not in row['index_description']:
+                indexes.append({
+                    'name' : row['index_name'],
+                    'column_names' : row['index_keys'].split(','),
+                    'unique': 'unique' in row['index_description']
+                })
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, current_schema,
+                                              create=True)
+            table_cache['indexes'] = indexes
+        return indexes
+
+    def get_view_definition(self, connection, viewname, schemaname=None,
+                            info_cache=None):
+        import sqlalchemy.dialects.information_schema as ischema
+        current_schema = schemaname or self.get_default_schema_name(connection)
+        if info_cache:
+            view_cache = info_cache.getView(viewname, current_schema)
+            if view_cache and 'definition' in view_cache.keys():
+                return view_cache.get('definition')
+        views = self.uppercase_table(ischema.views)
+        s = sql.select([views.c.view_definition],
+            sql.and_(
+                views.c.table_schema == current_schema,
+                views.c.table_name == viewname
+            ),
+        )
+        rp = connection.execute(s)
+        if rp:
+            view_def = rp.scalar()
+            if info_cache:
+                view_cache = info_cache.getView(viewname, current_schema,
+                                                create=True)
+                view_cache['definition'] = view_def
+            return view_def
+
+    def get_columns(self, connection, tablename, schemaname=None,
+                                                            info_cache=None):
         # Get base columns
-        if table.schema is not None:
-            current_schema = table.schema
-        else:
-            current_schema = self.get_default_schema_name(connection)
-
+        current_schema = schemaname or self.get_default_schema_name(connection)
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, current_schema)
+            if table_cache and 'columns' in table_cache.keys():
+                return table_cache.get('columns')
+        import sqlalchemy.dialects.information_schema as ischema
         columns = self.uppercase_table(ischema.columns)
         s = sql.select([columns],
                    current_schema
-                       and sql.and_(columns.c.table_name==table.name, columns.c.table_schema==current_schema)
-                       or columns.c.table_name==table.name,
+                       and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema)
+                       or columns.c.table_name==tablename,
                    order_by=[columns.c.ordinal_position])
 
         c = connection.execute(s)
-        found_table = False
+        cols = []
         while True:
             row = c.fetchone()
             if row is None:
                 break
-            found_table = True
             (name, type, nullable, charlen, numericprec, numericscale, default, collation) = (
                 row[columns.c.column_name],
                 row[columns.c.data_type],
@@ -1171,9 +1281,6 @@ class MSDialect(default.DefaultDialect):
                 row[columns.c.column_default],
                 row[columns.c.collation_name]
             )
-            if include_columns and name not in include_columns:
-                continue
-
             coltype = self.ischema_names.get(type, None)
 
             kwargs = {}
@@ -1196,8 +1303,131 @@ class MSDialect(default.DefaultDialect):
             colargs = []
             if default is not None:
                 colargs.append(schema.DefaultClause(sql.text(default)))
-            table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs))
+            cdict = {
+                'name' : name,
+                'type' : coltype,
+                'nullable' : nullable,
+                'default' : default,
+                'attrs' : colargs
+            }
+            cols.append(cdict)
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, current_schema,
+                                              create=True)
+            table_cache['columns'] = cols
+        return cols
+
+    def get_primary_keys(self, connection, tablename, schemaname=None,
+                                                            info_cache=None):
+        import sqlalchemy.dialects.information_schema as ischema
+        current_schema = schemaname or self.get_default_schema_name(connection)
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, current_schema)
+            if table_cache and 'primary_keys' in table_cache.keys():
+                return table_cache.get('primary_keys')
+        pkeys = []
+        # Add constraints
+        RR = self.uppercase_table(ischema.ref_constraints)    #information_schema.referential_constraints
+        TC = self.uppercase_table(ischema.constraints)        #information_schema.table_constraints
+        C  = self.uppercase_table(ischema.pg_key_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column
+        R  = self.uppercase_table(ischema.pg_key_constraints).alias('R') #information_schema.constraint_column_usage: the referenced column
+
+        # Primary key constraints
+        s = sql.select([C.c.column_name, TC.c.constraint_type],
+            sql.and_(TC.c.constraint_name == C.c.constraint_name,
+                     C.c.table_name == tablename,
+                     C.c.table_schema == current_schema)
+        )
+        c = connection.execute(s)
+        for row in c:
+            if 'PRIMARY' in row[TC.c.constraint_type.name]:
+                pkeys.append(row[0])
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, current_schema,
+                                              create=True)
+            table_cache['primary_keys'] = pkeys
+        return pkeys
+
+    def get_foreign_keys(self, connection, tablename, schemaname=None,
+                                                            info_cache=None):
+        import sqlalchemy.dialects.information_schema as ischema
+        current_schema = schemaname or self.get_default_schema_name(connection)
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, current_schema)
+            if table_cache and 'foreign_keys' in table_cache.keys():
+                return table_cache.get('foreign_keys')
+        # Add constraints
+        RR = self.uppercase_table(ischema.ref_constraints)    #information_schema.referential_constraints
+        TC = self.uppercase_table(ischema.constraints)        #information_schema.table_constraints
+        C  = self.uppercase_table(ischema.pg_key_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column
+        R  = self.uppercase_table(ischema.pg_key_constraints).alias('R') #information_schema.constraint_column_usage: the referenced column
+
+        # Foreign key constraints
+        s = sql.select([C.c.column_name,
+                        R.c.table_schema, R.c.table_name, R.c.column_name,
+                        RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule],
+                       sql.and_(C.c.table_name == tablename,
+                                C.c.table_schema == current_schema,
+                                C.c.constraint_name == RR.c.constraint_name,
+                                R.c.constraint_name == RR.c.unique_constraint_name,
+                                C.c.ordinal_position == R.c.ordinal_position
+                                ),
+                       order_by = [RR.c.constraint_name, R.c.ordinal_position])
+        rows = connection.execute(s).fetchall()
+
+        # group rows by constraint ID, to handle multi-column FKs
+        fkeys = []
+        fknm, scols, rcols = (None, [], [])
+        for r in rows:
+            scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
+            if rfknm != fknm:
+                if fknm:
+                    fkeys.append({
+                        'name' : fknm,
+                        'constrained_columns' : scols,
+                        'referred_schema' : rschema,
+                        'referred_table' : rtbl,
+                        'referred_columns' : rcols
+                    })
+                fknm, scols, rcols = (rfknm, [], [])
+            if not scol in scols:
+                scols.append(scol)
+            if not (rschema, rtbl, rcol) in rcols:
+                rcols.append((rschema, rtbl, rcol))
+        if fknm and scols:
+            fkeys.append({
+                'name' : fknm,
+                'constrained_columns' : scols,
+                'referred_schema' : rschema,
+                'referred_table' : rtbl,
+                'referred_columns' : rcols
+            })
+        if info_cache:
+            table_cache = info_cache.getTable(tablename, current_schema, create=True)
+            table_cache['foreign_keys'] = fkeys
+        return fkeys
 
+    def reflecttable(self, connection, table, include_columns):
+        import sqlalchemy.dialects.information_schema as ischema
+        # Get base columns
+        if table.schema is not None:
+            current_schema = table.schema
+        else:
+            current_schema = self.get_default_schema_name(connection)
+        info_cache = MSInfoCache()
+        columns = self.get_columns(connection, table.name, current_schema,
+                                   info_cache)
+        found_table = False
+        for cdict in columns:
+            name = cdict['name']
+            coltype = cdict['type']
+            nullable = cdict['nullable']
+            default = cdict['default']
+            colargs = cdict['attrs']
+            found_table = True
+            if include_columns and name not in include_columns:
+                continue
+            table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs))
         if not found_table:
             raise exc.NoSuchTableError(table.name)
 
@@ -1229,59 +1459,34 @@ class MSDialect(default.DefaultDialect):
                 # ignoring it, works just like before
                 pass
 
-        # Add constraints
-        RR = self.uppercase_table(ischema.ref_constraints)    #information_schema.referential_constraints
-        TC = self.uppercase_table(ischema.constraints)        #information_schema.table_constraints
-        C  = self.uppercase_table(ischema.pg_key_constraints).alias('C') #information_schema.constraint_column_usage: the constrained column
-        R  = self.uppercase_table(ischema.pg_key_constraints).alias('R') #information_schema.constraint_column_usage: the referenced column
-
         # Primary key constraints
-        s = sql.select([C.c.column_name, TC.c.constraint_type], sql.and_(TC.c.constraint_name == C.c.constraint_name,
-                                                                         C.c.table_name == table.name,
-                                                                         C.c.table_schema == (table.schema or current_schema)))
-        c = connection.execute(s)
-        for row in c:
-            if 'PRIMARY' in row[TC.c.constraint_type.name] and row[0] in table.c:
-                table.primary_key.add(table.c[row[0]])
+        pkeys = self.get_primary_keys(connection, table.name,
+                                      current_schema, info_cache)
+        for pkey in pkeys:
+            if pkey in table.c:
+                table.primary_key.add(table.c[pkey])
 
         # Foreign key constraints
-        s = sql.select([C.c.column_name,
-                        R.c.table_schema, R.c.table_name, R.c.column_name,
-                        RR.c.constraint_name, RR.c.match_option, RR.c.update_rule, RR.c.delete_rule],
-                       sql.and_(C.c.table_name == table.name,
-                                C.c.table_schema == (table.schema or current_schema),
-                                C.c.constraint_name == RR.c.constraint_name,
-                                R.c.constraint_name == RR.c.unique_constraint_name,
-                                C.c.ordinal_position == R.c.ordinal_position
-                                ),
-                       order_by = [RR.c.constraint_name, R.c.ordinal_position])
-        rows = connection.execute(s).fetchall()
-
         def _gen_fkref(table, rschema, rtbl, rcol):
             if rschema == current_schema and not table.schema:
                 return '.'.join([rtbl, rcol])
             else:
                 return '.'.join([rschema, rtbl, rcol])
 
-        # group rows by constraint ID, to handle multi-column FKs
-        fknm, scols, rcols = (None, [], [])
-        for r in rows:
-            scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
+        fkeys = self.get_foreign_keys(connection, table.name, current_schema,
+                                      info_cache)
+        for fkey_d in fkeys:
+            fknm = fkey_d['name']
+            scols = fkey_d['constrained_columns']
+            rschema = fkey_d['referred_schema']
+            rtbl = fkey_d['referred_table']
+            rcols = fkey_d['referred_columns']
             # if the reflected schema is the default schema then don't set it because this will
             # play into the metadata key causing duplicates.
             if rschema == current_schema and not table.schema:
-                schema.Table(rtbl, table.metadata, autoload=True, autoload_with=connection)
+                schema.Table(rtbl, table.metadata, autoload=True,
+                             autoload_with=connection)
             else:
-                schema.Table(rtbl, table.metadata, schema=rschema, autoload=True, autoload_with=connection)
-            if rfknm != fknm:
-                if fknm:
-                    table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True))
-                fknm, scols, rcols = (rfknm, [], [])
-            if not scol in scols:
-                scols.append(scol)
-            if not (rschema, rtbl, rcol) in rcols:
-                rcols.append((rschema, rtbl, rcol))
-
-        if fknm and scols:
+                schema.Table(rtbl, table.metadata, schema=rschema,
+                             autoload=True, autoload_with=connection)
             table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True))
-
index 746e2e94b1c602270c63b7270f3029b1837f5215..677cafee9585aea17b1a49659c0e45223a7350db 100644 (file)
@@ -32,13 +32,13 @@ class Inspector(object):
           [sqlalchemy.engine.base.#Connectable]
 
         """
-        self.info_cache = {}
         self.conn = conn
         # set the engine
         if hasattr(conn, 'engine'):
             self.engine = conn.engine
         else:
             self.engine = conn
+        self.info_cache = self.engine.dialect.info_cache()
 
     def default_schema_name(self):
         return self.engine.dialect.get_default_schema_name(self.conn)