]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Added MSSQL support for introspecting the default schema name for the logged in user...
authorMichael Trier <mtrier@gmail.com>
Tue, 23 Dec 2008 06:01:09 +0000 (06:01 +0000)
committerMichael Trier <mtrier@gmail.com>
Tue, 23 Dec 2008 06:01:09 +0000 (06:01 +0000)
CHANGES
lib/sqlalchemy/databases/mssql.py

diff --git a/CHANGES b/CHANGES
index 020459071d89755f5000f25025d82143b7e5f31b..c278e62d17a72224b92a913ef9571003984454dc 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -202,6 +202,10 @@ CHANGES
       new doc section "Custom Comparators".
     
 - mssql
+    - ``get_default_schema_name`` is now reflected from the
+      database based on the user's default schema. This only works
+      with MSSQL 2005 and later. [ticket:1258]
+
     - Added collation support through the use of a new collation
       argument. This is supported on the following types: char,
       nchar, varchar, nvarchar, text, ntext. [ticket:1248]
index 7d2cae5b906271174db786ab1dc122b20b87e665..ecf8b246282780eb6412d952979552f16317ccd1 100644 (file)
@@ -955,7 +955,25 @@ class MSSQLDialect(default.DefaultDialect):
             newobj.dialect = self
         return newobj
 
+    @base.connection_memoize(('dialect', 'default_schema_name'))
     def get_default_schema_name(self, connection):
+        query = "SELECT user_name() as user_name;"
+        user_name = connection.scalar(sql.text(query))
+        if user_name is not None:
+            # now, get the default schema
+            query = """
+            SELECT default_schema_name FROM
+            sys.database_principals
+            WHERE name = :user_name
+            AND type = 'S'
+            """
+            try:
+                default_schema_name = connection.scalar(sql.text(query),
+                                                    user_name=user_name)
+                if default_schema_name is not None:
+                    return default_schema_name
+            except:
+                pass
         return self.schema_name
 
     def table_names(self, connection, schema):