]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
moved reflecttable to inspector for mssql
authorRandall Smith <randall@tnr.cc>
Sat, 25 Apr 2009 07:35:03 +0000 (07:35 +0000)
committerRandall Smith <randall@tnr.cc>
Sat, 25 Apr 2009 07:35:03 +0000 (07:35 +0000)
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/engine/reflection.py

index 8df54bcafe280de67d6148518587d1f4b7087214..fee22ad11b290f685ba46685fcf33a557cd955e0 100644 (file)
@@ -1218,9 +1218,8 @@ class MSDialect(default.DefaultDialect):
         s = sql.select([columns],
                    current_schema
                        and sql.and_(columns.c.table_name==tablename, columns.c.table_schema==current_schema)
-                       or columns.c.table_name==table_name,
+                       or columns.c.table_name==tablename,
                    order_by=[columns.c.ordinal_position])
-
         c = connection.execute(s)
         cols = []
         while True:
@@ -1257,8 +1256,6 @@ class MSDialect(default.DefaultDialect):
 
             coltype = coltype(**kwargs)
             colargs = []
-            if default is not None:
-                colargs.append(sa_schema.DefaultClause(sql.text(default)))
             cdict = {
                 'name' : name,
                 'type' : coltype,
@@ -1267,6 +1264,40 @@ class MSDialect(default.DefaultDialect):
                 'attrs' : colargs
             }
             cols.append(cdict)
+        # autoincrement and identity
+        colmap = {}
+        for col in cols:
+            colmap[col['name']] = col
+        # We also run an sp_columns to check for identity columns:
+        cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (tablename, current_schema))
+        ic = None
+        while True:
+            row = cursor.fetchone()
+            if row is None:
+                break
+            (col_name, type_name) = row[3], row[5]
+            if type_name.endswith("identity") and col_name in colmap:
+                ic = col_name
+                colmap[col_name]['autoincrement'] = True
+                colmap[col_name]['sequence'] = dict(
+                                    name='%s_identity' % col_name)
+                break
+        cursor.close()
+        if not ic is None:
+            try:
+                # is this table_fullname reliable?
+                table_fullname = "%s.%s" % (current_schema, tablename)
+                cursor = connection.execute("select ident_seed(?), ident_incr(?)", table_fullname, table_fullname)
+                row = cursor.fetchone()
+                cursor.close()
+                if not row is None:
+                    colmap[ic]['sequence'].update({
+                        'start' : int(row[0]),
+                        'increment' : int(row[1])
+                    })
+            except:
+                # ignoring it, works just like before
+                pass
         return cols
 
     @reflection.cache
@@ -1333,6 +1364,10 @@ class MSDialect(default.DefaultDialect):
             if not rcol in rcols:
                 rcols.append(rcol)
         if fknm and scols:
+            # don't return the remote schema if no schema was specified and it
+            # is the default
+            if schema is None and current_schema == rschema:
+                rschema = None
             fkeys.append({
                 'name' : fknm,
                 'constrained_columns' : scols,
@@ -1343,87 +1378,9 @@ class MSDialect(default.DefaultDialect):
         return fkeys
 
     def reflecttable(self, connection, table, include_columns):
-        info_cache = {}
-
-        # Get base columns
-        if table.schema is not None:
-            current_schema = table.schema
-        else:
-            current_schema = self.get_default_schema_name(connection)
-        columns = self.get_columns(connection, table.name, current_schema, info_cache=info_cache)
-
-        found_table = False
-        for cdict in columns:
-            name = cdict['name']
-            coltype = cdict['type']
-            nullable = cdict['nullable']
-            default = cdict['default']
-            colargs = cdict['attrs']
-            found_table = True
-            if include_columns and name not in include_columns:
-                continue
-            table.append_column(sa_schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs))
-        if not found_table:
-            raise exc.NoSuchTableError(table.name)
 
-        # We also run an sp_columns to check for identity columns:
-        cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (table.name, current_schema))
-        ic = None
-        while True:
-            row = cursor.fetchone()
-            if row is None:
-                break
-            col_name, type_name = row[3], row[5]
-            if type_name.endswith("identity") and col_name in table.c:
-                ic = table.c[col_name]
-                ic.autoincrement = True
-                # setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute
-                ic.sequence = sa_schema.Sequence(ic.name + '_identity', 1, 1)
-                # MSSQL: only one identity per table allowed
-                cursor.close()
-                break
-        if not ic is None:
-            try:
-                cursor = connection.execute("select ident_seed(?), ident_incr(?)", table.fullname, table.fullname)
-                row = cursor.fetchone()
-                cursor.close()
-                if not row is None:
-                    ic.sequence.start = int(row[0])
-                    ic.sequence.increment = int(row[1])
-            except:
-                # ignoring it, works just like before
-                pass
-
-        # Primary key constraints
-        pkeys = self.get_primary_keys(connection, table.name,
-                                      current_schema, info_cache=info_cache)
-        for pkey in pkeys:
-            if pkey in table.c:
-                table.primary_key.add(table.c[pkey])
-
-        # Foreign key constraints
-        def _gen_fkref(table, rschema, rtbl, rcol):
-            if rschema == current_schema and not table.schema:
-                return '.'.join([rtbl, rcol])
-            else:
-                return '.'.join([rschema, rtbl, rcol])
-
-        fkeys = self.get_foreign_keys(connection, table.name, current_schema, info_cache=info_cache)
-        for fkey_d in fkeys:
-            fknm = fkey_d['name']
-            scols = fkey_d['constrained_columns']
-            rschema = fkey_d['referred_schema']
-            rtbl = fkey_d['referred_table']
-            rcols = fkey_d['referred_columns']
-            # if the reflected schema is the default schema then don't set it because this will
-            # play into the metadata key causing duplicates.
-            if rschema == current_schema and not table.schema:
-                sa_schema.Table(rtbl, table.metadata, autoload=True,
-                             autoload_with=connection)
-            else:
-                sa_schema.Table(rtbl, table.metadata, schema=rschema,
-                             autoload=True, autoload_with=connection)
-            table.append_constraint(sa_schema.ForeignKeyConstraint(scols, [_gen_fkref(table, rschema, rtbl, c) for c in rcols], fknm, link_to_name=True))
+        insp = reflection.Inspector.from_engine(connection)
+        return insp.reflecttable(table, include_columns)
 
 # fixme.  I added this for the tests to run. -Randall
 MSSQLDialect = MSDialect
index 372d0e3f49a57796db8fbba558ca1ef6b1a300d8..a21332cd8aa14286c18ec3971aaecf2570ecf984 100644 (file)
@@ -275,12 +275,15 @@ class Inspector(object):
             del tblkw[k]
             tblkw[str(k)] = v
 
-        # Py2K
-        if isinstance(schema, str):
-            schema = schema.decode(self.dialect.encoding)
-        if isinstance(table_name, str):
-            table_name = table_name.decode(self.dialect.encoding)
+        ### Py2K
+        # fixme
+        # This is breaking mssql, which can't bind unicode.
+        ##if isinstance(schema, str):
+        ##    schema = schema.decode(dialect.encoding)
+        ##if isinstance(table_name, str):
+        ##    table_name = table_name.decode(dialect.encoding)
         # end Py2K
+
         # columns
         found_table = False
         for col_d in self.get_columns(table_name, schema, **tblkw):
@@ -305,11 +308,19 @@ class Inspector(object):
                     colargs.append(sa_schema.DefaultClause(default))
                 else:
                     colargs.append(sa_schema.DefaultClause(sql.text(default)))
-            table.append_column(sa_schema.Column(name, coltype,
-                                nullable=nullable, *colargs, **col_kw))
+            col = sa_schema.Column(name, coltype,nullable=nullable, *colargs, **col_kw)
+            if 'sequence' in col_d:
+                seq = col_d['sequence']
+                col.sequence = sa_schema.Sequence(seq['name'], 1, 1)
+                if 'start' in seq:
+                    col.sequence.start = seq['start']
+                if 'increment' in seq:
+                    col.sequence.increment = seq['increment']
+            table.append_column(col)
 
         if not found_table:
             raise exc.NoSuchTableError(table.name)
+
         # Primary keys
         for pk in self.get_primary_keys(table_name, schema, **tblkw):
             if pk in table.c: