]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- start moving get_default_schema_name to an initialized var
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Jul 2009 00:18:57 +0000 (00:18 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 6 Jul 2009 00:18:57 +0000 (00:18 +0000)
- fix up MSSQL foreign key reflection ala oracle

lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mssql/pyodbc.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/test/testing.py
test/engine/test_reflection.py

index a5fdae2b7306e5ad041ce17fb18d70e29f1499b6..fdec5741db79833a774e358ecc5a9ce0a5834eaa 100644 (file)
@@ -1134,6 +1134,9 @@ class MSDialect(default.DefaultDialect):
         pass
 
     def get_default_schema_name(self, connection):
+        return self.default_schema_name
+        
+    def _get_default_schema_name(self, connection):
         user_name = connection.scalar("SELECT user_name() as user_name;")
         if user_name is not None:
             # now, get the default schema
@@ -1157,7 +1160,7 @@ class MSDialect(default.DefaultDialect):
 
 
     def has_table(self, connection, tablename, schema=None):
-        current_schema = schema or self.get_default_schema_name(connection)
+        current_schema = schema or self.default_schema_name
         columns = ischema.columns
         s = sql.select([columns],
                    current_schema
@@ -1179,7 +1182,7 @@ class MSDialect(default.DefaultDialect):
 
     @reflection.cache
     def get_table_names(self, connection, schema=None, **kw):
-        current_schema = schema or self.get_default_schema_name(connection)
+        current_schema = schema or self.default_schema_name
         tables = ischema.tables
         s = sql.select([tables.c.table_name],
             sql.and_(
@@ -1193,7 +1196,7 @@ class MSDialect(default.DefaultDialect):
 
     @reflection.cache
     def get_view_names(self, connection, schema=None, **kw):
-        current_schema = schema or self.get_default_schema_name(connection)
+        current_schema = schema or self.default_schema_name
         tables = ischema.tables
         s = sql.select([tables.c.table_name],
             sql.and_(
@@ -1208,7 +1211,7 @@ class MSDialect(default.DefaultDialect):
     # The cursor reports it is closed after executing the sp.
     @reflection.cache
     def get_indexes(self, connection, tablename, schema=None, **kw):
-        current_schema = schema or self.get_default_schema_name(connection)
+        current_schema = schema or self.default_schema_name
         full_tname = "%s.%s" % (current_schema, tablename)
         indexes = []
         s = sql.text("exec sp_helpindex '%s'" % full_tname)
@@ -1227,7 +1230,7 @@ class MSDialect(default.DefaultDialect):
 
     @reflection.cache
     def get_view_definition(self, connection, viewname, schema=None, **kw):
-        current_schema = schema or self.get_default_schema_name(connection)
+        current_schema = schema or self.default_schema_name
         views = ischema.views
         s = sql.select([views.c.view_definition],
             sql.and_(
@@ -1243,7 +1246,7 @@ class MSDialect(default.DefaultDialect):
     @reflection.cache
     def get_columns(self, connection, tablename, schema=None, **kw):
         # Get base columns
-        current_schema = schema or self.get_default_schema_name(connection)
+        current_schema = schema or self.default_schema_name
         columns = ischema.columns
         s = sql.select([columns],
                    current_schema
@@ -1330,7 +1333,7 @@ class MSDialect(default.DefaultDialect):
 
     @reflection.cache
     def get_primary_keys(self, connection, tablename, schema=None, **kw):
-        current_schema = schema or self.get_default_schema_name(connection)
+        current_schema = schema or self.default_schema_name
         pkeys = []
         # Add constraints
         RR = ischema.ref_constraints    #information_schema.referential_constraints
@@ -1352,7 +1355,7 @@ class MSDialect(default.DefaultDialect):
 
     @reflection.cache
     def get_foreign_keys(self, connection, tablename, schema=None, **kw):
-        current_schema = schema or self.get_default_schema_name(connection)
+        current_schema = schema or self.default_schema_name
         # Add constraints
         RR = ischema.ref_constraints    #information_schema.referential_constraints
         TC = ischema.constraints        #information_schema.table_constraints
@@ -1370,40 +1373,40 @@ class MSDialect(default.DefaultDialect):
                                 C.c.ordinal_position == R.c.ordinal_position
                                 ),
                        order_by = [RR.c.constraint_name, R.c.ordinal_position])
-        rows = connection.execute(s).fetchall()
+        
 
         # group rows by constraint ID, to handle multi-column FKs
         fkeys = []
         fknm, scols, rcols = (None, [], [])
-        for r in rows:
+        
+        def fkey_rec():
+            return {
+                'name' : None,
+                'constrained_columns' : [],
+                'referred_schema' : None,
+                'referred_table' : None,
+                'referred_columns' : []
+            }
+
+        fkeys = util.defaultdict(fkey_rec)
+        
+        for r in connection.execute(s).fetchall():
             scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r
-            if rfknm != fknm:
-                if fknm:
-                    fkeys.append({
-                        'name' : fknm,
-                        'constrained_columns' : scols,
-                        'referred_schema' : rschema,
-                        'referred_table' : rtbl,
-                        'referred_columns' : rcols
-                    })
-                fknm, scols, rcols = (rfknm, [], [])
-            if not scol in scols:
-                scols.append(scol)
-            if not rcol in rcols:
-                rcols.append(rcol)
-        if fknm and scols:
-            # don't return the remote schema if no schema was specified and it
-            # is the default
-            if schema is None and current_schema == rschema:
-                rschema = None
-            fkeys.append({
-                'name' : fknm,
-                'constrained_columns' : scols,
-                'referred_schema' : rschema,
-                'referred_table' : rtbl,
-                'referred_columns' : rcols
-            })
-        return fkeys
+
+            rec = fkeys[rfknm]
+            rec['name'] = rfknm
+            if not rec['referred_table']:
+                rec['referred_table'] = rtbl
+
+                if schema is not None or current_schema != rschema:
+                    rec['referred_schema'] = rschema
+            
+            local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns']
+            
+            local_cols.append(scol)
+            remote_cols.append(rcol)
+
+        return fkeys.values()
 
     def reflecttable(self, connection, table, include_columns):
 
index 1b754b2b49fbc6d84b846bf2759a8ae0521b6a61..550f26e6762e032657f0c855fa895a6a974aa4b7 100644 (file)
@@ -62,8 +62,8 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect):
         self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset')
         
     def initialize(self, connection):
+        super(MSDialect_pyodbc, self).initialize(connection)
         pyodbc = self.dbapi
-        self.server_version_info = self._get_server_version_info(connection)
         
         dbapi_con = connection.connection
         
index dfe59924622e1f9600ee9d8813df9ea736d1684a..a57042796909d47299e093c9a3b56aa614708a7b 100644 (file)
@@ -311,7 +311,11 @@ class Dialect(object):
         raise NotImplementedError()
 
     def get_default_schema_name(self, connection):
-        """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`."""
+        """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`.
+        
+        DEPRECATED.  moving this towards dialect.default_schema_name (not complete).
+        
+        """
 
         raise NotImplementedError()
 
index f3969eb57e0fb5b1fdf3afb51d3fcb6aa436a3a6..413de171a20a711a756fab47cd4b83121cbc4545 100644 (file)
@@ -89,7 +89,9 @@ class DefaultDialect(base.Dialect):
     def initialize(self, connection):
         if hasattr(self, '_get_server_version_info'):
             self.server_version_info = self._get_server_version_info(connection)
-
+        if hasattr(self, '_get_default_schema_name'):
+            self.default_schema_name = self._get_default_schema_name(connection)
+        
     @classmethod
     def type_descriptor(cls, typeobj):
         """Provide a database-specific ``TypeEngine`` object, given
index 077acd3940c454935c50858eec39c24ba2dcd54f..57b140580239fe915b2df1619f5dea37ab68ae5f 100644 (file)
@@ -626,7 +626,7 @@ class ComparesTables(object):
             elif against(('mysql', '<', (5, 0))):
                 # ignore reflection of bogus db-generated DefaultClause()
                 pass
-            elif not c.primary_key or not against('postgres'):
+            elif not c.primary_key or not against('postgres', 'mssql'):
                 #print repr(c)
                 assert reflected_c.default is None, reflected_c.default
 
index 6b9760d38d7946eea9d920741632b8d2f7a769de..3d115b9a3d188fe5bd52ff450f7235900fd442f4 100644 (file)
@@ -595,7 +595,7 @@ class ReflectionTest(TestBase, ComparesTables):
             m9.reflect()
             self.assert_(not m9.tables)
 
-    @testing.fails_on_everything_except('postgres', 'mysql', 'sqlite', 'oracle')
+    @testing.fails_on_everything_except('postgres', 'mysql', 'sqlite', 'oracle', 'mssql')
     def test_index_reflection(self):
         m1 = MetaData(testing.db)
         t1 = Table('party', m1,