]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [feature] SQL Server dialect can be given
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Aug 2012 20:51:14 +0000 (16:51 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Aug 2012 20:51:14 +0000 (16:51 -0400)
database-qualified schema names,
i.e. "schema='mydatabase.dbo'"; reflection
operations will detect this, split the schema
among the "." to get the owner separately,
and emit a "USE mydatabase" statement before
reflecting targets within the "dbo" owner;
the existing database returned from
DB_NAME() is then restored.

CHANGES
lib/sqlalchemy/dialects/mssql/base.py
test/dialect/test_mssql.py

diff --git a/CHANGES b/CHANGES
index 70c20ee189e3c641d2da2a255a04e5edb79d2e46..068fcede56475b819fff0e8c6105cd7cf1403a96 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -390,6 +390,16 @@ underneath "0.7.xx".
     this.  [ticket:2363]
 
 - mssql
+  - [feature] SQL Server dialect can be given
+    database-qualified schema names,
+    i.e. "schema='mydatabase.dbo'"; reflection
+    operations will detect this, split the schema
+    among the "." to get the owner separately,
+    and emit a "USE mydatabase" statement before
+    reflecting targets within the "dbo" owner;
+    the existing database returned from
+    DB_NAME() is then restored.
+
   - [bug] removed legacy behavior whereby
     a column comparison to a scalar SELECT via
     == would coerce to an IN with the SQL server
index 668b32d14373b6f21ede3691c29942f032d03b35..91f396d23f50f6ba81eab459bdee015ca2d1d78c 100644 (file)
@@ -180,6 +180,7 @@ from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \
                                 FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\
                                 VARBINARY
 
+from ...util import update_wrapper
 from . import information_schema as ischema
 
 MS_2008_VERSION = (10,)
@@ -1063,6 +1064,38 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer):
         result = '.'.join([self.quote(x, force) for x in schema.split('.')])
         return result
 
+def _db_plus_owner_listing(fn):
+    def wrap(dialect, connection, schema=None, **kw):
+        dbname, owner = _owner_plus_db(dialect, schema)
+        return _switch_db(dbname, connection, fn, dialect, connection,
+                            dbname, owner, schema, **kw)
+    return update_wrapper(wrap, fn)
+
+def _db_plus_owner(fn):
+    def wrap(dialect, connection, tablename, schema=None, **kw):
+        dbname, owner = _owner_plus_db(dialect, schema)
+        return _switch_db(dbname, connection, fn, dialect, connection,
+                            tablename, dbname, owner, schema, **kw)
+    return update_wrapper(wrap, fn)
+
+def _switch_db(dbname, connection, fn, *arg, **kw):
+    if dbname:
+        current_db = connection.scalar("select db_name()")
+        connection.execute("use %s" % dbname)
+    try:
+        return fn(*arg, **kw)
+    finally:
+        if dbname:
+            connection.execute("use %s" % current_db)
+
+def _owner_plus_db(dialect, schema):
+    if not schema:
+        return None, dialect.default_schema_name
+    elif "." in schema:
+        return schema.split(".", 1)
+    else:
+        return None, schema
+
 class MSDialect(default.DefaultDialect):
     name = 'mssql'
     supports_default_values = True
@@ -1130,7 +1163,7 @@ class MSDialect(default.DefaultDialect):
             self.implicit_returning = True
 
     def _get_default_schema_name(self, connection):
-        user_name = connection.scalar("SELECT user_name() as user_name;")
+        user_name = connection.scalar("SELECT user_name()")
         if user_name is not None:
             # now, get the default schema
             query = sql.text("""
@@ -1153,14 +1186,14 @@ class MSDialect(default.DefaultDialect):
         else:
             return column
 
-    def has_table(self, connection, tablename, schema=None):
-        current_schema = schema or self.default_schema_name
+    @_db_plus_owner
+    def has_table(self, connection, tablename, dbname, owner, schema):
         columns = ischema.columns
 
-        whereclause = self._unicode_cast(columns.c.table_name)==tablename
-        if current_schema:
+        whereclause = self._unicode_cast(columns.c.table_name) == tablename
+        if owner:
             whereclause = sql.and_(whereclause,
-                                   columns.c.table_schema==current_schema)
+                                   columns.c.table_schema == owner)
         s = sql.select([columns], whereclause)
         c = connection.execute(s)
         return c.first() is not None
@@ -1174,12 +1207,12 @@ class MSDialect(default.DefaultDialect):
         return schema_names
 
     @reflection.cache
-    def get_table_names(self, connection, schema=None, **kw):
-        current_schema = schema or self.default_schema_name
+    @_db_plus_owner_listing
+    def get_table_names(self, connection, dbname, owner, schema, **kw):
         tables = ischema.tables
         s = sql.select([tables.c.table_name],
             sql.and_(
-                tables.c.table_schema == current_schema,
+                tables.c.table_schema == owner,
                 tables.c.table_type == u'BASE TABLE'
             ),
             order_by=[tables.c.table_name]
@@ -1188,12 +1221,12 @@ class MSDialect(default.DefaultDialect):
         return table_names
 
     @reflection.cache
-    def get_view_names(self, connection, schema=None, **kw):
-        current_schema = schema or self.default_schema_name
+    @_db_plus_owner_listing
+    def get_view_names(self, connection, dbname, owner, schema, **kw):
         tables = ischema.tables
         s = sql.select([tables.c.table_name],
             sql.and_(
-                tables.c.table_schema == current_schema,
+                tables.c.table_schema == owner,
                 tables.c.table_type == u'VIEW'
             ),
             order_by=[tables.c.table_name]
@@ -1202,15 +1235,13 @@ class MSDialect(default.DefaultDialect):
         return view_names
 
     @reflection.cache
-    def get_indexes(self, connection, tablename, schema=None, **kw):
+    @_db_plus_owner
+    def get_indexes(self, connection, tablename, dbname, owner, schema, **kw):
         # using system catalogs, don't support index reflection
         # below MS 2005
         if self.server_version_info < MS_2005_VERSION:
             return []
 
-        current_schema = schema or self.default_schema_name
-        full_tname = "%s.%s" % (current_schema, tablename)
-
         rp = connection.execute(
             sql.text("select ind.index_id, ind.is_unique, ind.name "
                 "from sys.indexes as ind join sys.tables as tab on "
@@ -1222,20 +1253,20 @@ class MSDialect(default.DefaultDialect):
                 bindparams=[
                     sql.bindparam('tabname', tablename,
                                     sqltypes.String(convert_unicode=True)),
-                    sql.bindparam('schname', current_schema,
+                    sql.bindparam('schname', owner,
                                     sqltypes.String(convert_unicode=True))
                 ],
                 typemap = {
-                    'name':sqltypes.Unicode()
+                    'name': sqltypes.Unicode()
                 }
             )
         )
         indexes = {}
         for row in rp:
             indexes[row['index_id']] = {
-                'name':row['name'],
-                'unique':row['is_unique'] == 1,
-                'column_names':[]
+                'name': row['name'],
+                'unique': row['is_unique'] == 1,
+                'column_names': []
             }
         rp = connection.execute(
             sql.text(
@@ -1251,11 +1282,11 @@ class MSDialect(default.DefaultDialect):
                         bindparams=[
                             sql.bindparam('tabname', tablename,
                                     sqltypes.String(convert_unicode=True)),
-                            sql.bindparam('schname', current_schema,
+                            sql.bindparam('schname', owner,
                                     sqltypes.String(convert_unicode=True))
                         ],
                         typemap = {
-                            'name':sqltypes.Unicode()
+                            'name': sqltypes.Unicode()
                         }
                         ),
             )
@@ -1266,9 +1297,8 @@ class MSDialect(default.DefaultDialect):
         return indexes.values()
 
     @reflection.cache
-    def get_view_definition(self, connection, viewname, schema=None, **kw):
-        current_schema = schema or self.default_schema_name
-
+    @_db_plus_owner
+    def get_view_definition(self, connection, viewname, dbname, owner, schema, **kw):
         rp = connection.execute(
             sql.text(
                 "select definition from sys.sql_modules as mod, "
@@ -1281,7 +1311,7 @@ class MSDialect(default.DefaultDialect):
                 bindparams=[
                     sql.bindparam('viewname', viewname,
                             sqltypes.String(convert_unicode=True)),
-                    sql.bindparam('schname', current_schema,
+                    sql.bindparam('schname', owner,
                             sqltypes.String(convert_unicode=True))
                 ]
             )
@@ -1292,15 +1322,15 @@ class MSDialect(default.DefaultDialect):
             return view_def
 
     @reflection.cache
-    def get_columns(self, connection, tablename, schema=None, **kw):
+    @_db_plus_owner
+    def get_columns(self, connection, tablename, dbname, owner, schema, **kw):
         # Get base columns
-        current_schema = schema or self.default_schema_name
         columns = ischema.columns
-        if current_schema:
-            whereclause = sql.and_(columns.c.table_name==tablename,
-                                   columns.c.table_schema==current_schema)
+        if owner:
+            whereclause = sql.and_(columns.c.table_name == tablename,
+                                   columns.c.table_schema == owner)
         else:
-            whereclause = columns.c.table_name==tablename
+            whereclause = columns.c.table_name == tablename
         s = sql.select([columns], whereclause,
                         order_by=[columns.c.ordinal_position])
 
@@ -1361,7 +1391,7 @@ class MSDialect(default.DefaultDialect):
         # We also run an sp_columns to check for identity columns:
         cursor = connection.execute("sp_columns @table_name = '%s', "
                                     "@table_owner = '%s'"
-                                    % (tablename, current_schema))
+                                    % (tablename, owner))
         ic = None
         while True:
             row = cursor.fetchone()
@@ -1377,7 +1407,7 @@ class MSDialect(default.DefaultDialect):
         cursor.close()
 
         if ic is not None and self.server_version_info >= MS_2005_VERSION:
-            table_fullname = "%s.%s" % (current_schema, tablename)
+            table_fullname = "%s.%s" % (owner, tablename)
             cursor = connection.execute(
                 "select ident_seed('%s'), ident_incr('%s')"
                 % (table_fullname, table_fullname)
@@ -1386,52 +1416,36 @@ class MSDialect(default.DefaultDialect):
             row = cursor.first()
             if row is not None and row[0] is not None:
                 colmap[ic]['sequence'].update({
-                    'start' : int(row[0]),
-                    'increment' : int(row[1])
+                    'start': int(row[0]),
+                    'increment': int(row[1])
                 })
         return cols
 
     @reflection.cache
-    def get_pk_constraint(self, connection, tablename, schema=None, **kw):
-        current_schema = schema or self.default_schema_name
+    @_db_plus_owner
+    def get_pk_constraint(self, connection, tablename, dbname, owner, schema, **kw):
         pkeys = []
-        # information_schema.referential_constraints
-        RR = ischema.ref_constraints
-        # information_schema.table_constraints
         TC = ischema.constraints
-        # information_schema.constraint_column_usage:
-        # the constrained column
-        C  = ischema.key_constraints.alias('C')
-        # information_schema.constraint_column_usage:
-        # the referenced column
-        R  = ischema.key_constraints.alias('R')
+        C = ischema.key_constraints.alias('C')
 
         # Primary key constraints
         s = sql.select([C.c.column_name, TC.c.constraint_type],
             sql.and_(TC.c.constraint_name == C.c.constraint_name,
                      C.c.table_name == tablename,
-                     C.c.table_schema == current_schema)
+                     C.c.table_schema == owner)
         )
         c = connection.execute(s)
         for row in c:
             if 'PRIMARY' in row[TC.c.constraint_type.name]:
                 pkeys.append(row[0])
-        return {'constrained_columns':pkeys, 'name':None}
+        return {'constrained_columns': pkeys, 'name': None}
 
     @reflection.cache
-    def get_foreign_keys(self, connection, tablename, schema=None, **kw):
-        current_schema = schema or self.default_schema_name
-        # Add constraints
-        #information_schema.referential_constraints
+    @_db_plus_owner
+    def get_foreign_keys(self, connection, tablename, dbname, owner, schema, **kw):
         RR = ischema.ref_constraints
-        # information_schema.table_constraints
-        TC = ischema.constraints
-        # information_schema.constraint_column_usage:
-        # the constrained column
-        C  = ischema.key_constraints.alias('C')
-        # information_schema.constraint_column_usage:
-        # the referenced column
-        R  = ischema.key_constraints.alias('R')
+        C = ischema.key_constraints.alias('C')
+        R = ischema.key_constraints.alias('R')
 
         # Foreign key constraints
         s = sql.select([C.c.column_name,
@@ -1440,13 +1454,13 @@ class MSDialect(default.DefaultDialect):
                         RR.c.update_rule,
                         RR.c.delete_rule],
                        sql.and_(C.c.table_name == tablename,
-                                C.c.table_schema == current_schema,
+                                C.c.table_schema == owner,
                                 C.c.constraint_name == RR.c.constraint_name,
                                 R.c.constraint_name ==
                                                 RR.c.unique_constraint_name,
                                 C.c.ordinal_position == R.c.ordinal_position
                                 ),
-                       order_by = [
+                       order_by= [
                                     RR.c.constraint_name,
                                     R.c.ordinal_position])
 
@@ -1457,11 +1471,11 @@ class MSDialect(default.DefaultDialect):
 
         def fkey_rec():
             return {
-                'name' : None,
-                'constrained_columns' : [],
-                'referred_schema' : None,
-                'referred_table' : None,
-                'referred_columns' : []
+                'name': None,
+                'constrained_columns': [],
+                'referred_schema': None,
+                'referred_table': None,
+                'referred_columns': []
             }
 
         fkeys = util.defaultdict(fkey_rec)
@@ -1473,8 +1487,9 @@ class MSDialect(default.DefaultDialect):
             rec['name'] = rfknm
             if not rec['referred_table']:
                 rec['referred_table'] = rtbl
-
-                if schema is not None or current_schema != rschema:
+                if schema is not None or owner != rschema:
+                    if dbname:
+                        rschema = dbname + "." + rschema
                     rec['referred_schema'] = rschema
 
             local_cols, remote_cols = \
index b254a98e857c0b7ccfb71cb1131656ee1b754287..7627b65839a92aac458bbeff03a378f14b50e726 100644 (file)
@@ -673,6 +673,41 @@ class ReflectionTest(fixtures.TestBase, ComparesTables):
         assert isinstance(t1.c.id.type, Integer)
         assert isinstance(t1.c.data.type, types.NullType)
 
+
+    @testing.provide_metadata
+    def test_db_qualified_items(self):
+        metadata = self.metadata
+        Table('foo', metadata, Column('id', Integer, primary_key=True))
+        Table('bar', metadata,
+                Column('id', Integer, primary_key=True),
+                Column('foo_id', Integer, ForeignKey('foo.id', name="fkfoo"))
+            )
+        metadata.create_all()
+
+        dbname = testing.db.scalar("select db_name()")
+        owner = testing.db.scalar("SELECT user_name()")
+
+        inspector = inspect(testing.db)
+        bar_via_db = inspector.get_foreign_keys(
+                            "bar", schema="%s.%s" % (dbname, owner))
+        eq_(
+            bar_via_db,
+            [{
+                'referred_table': 'foo',
+                'referred_columns': ['id'],
+                'referred_schema': 'test.dbo',
+                'name': 'fkfoo',
+                'constrained_columns': ['foo_id']}]
+        )
+
+        assert testing.db.has_table("bar", schema="test.dbo")
+
+        m2 = MetaData()
+        Table('bar', m2, schema="test.dbo", autoload=True,
+                                autoload_with=testing.db)
+        eq_(m2.tables["test.dbo.foo"].schema, "test.dbo")
+
+
     @testing.provide_metadata
     def test_indexes_cols(self):
         metadata = self.metadata