From: Mike Bayer Date: Tue, 7 Aug 2012 20:51:14 +0000 (-0400) Subject: - [feature] SQL Server dialect can be given X-Git-Tag: rel_0_8_0b1~281 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=c94756cce81a940a6a6f09e1fdf8ccfe8d1c45c1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - [feature] SQL Server dialect can be given database-qualified schema names, i.e. "schema='mydatabase.dbo'"; reflection operations will detect this, split the schema among the "." to get the owner separately, and emit a "USE mydatabase" statement before reflecting targets within the "dbo" owner; the existing database returned from DB_NAME() is then restored. --- diff --git a/CHANGES b/CHANGES index 70c20ee189..068fcede56 100644 --- a/CHANGES +++ b/CHANGES @@ -390,6 +390,16 @@ underneath "0.7.xx". this. [ticket:2363] - mssql + - [feature] SQL Server dialect can be given + database-qualified schema names, + i.e. "schema='mydatabase.dbo'"; reflection + operations will detect this, split the schema + among the "." to get the owner separately, + and emit a "USE mydatabase" statement before + reflecting targets within the "dbo" owner; + the existing database returned from + DB_NAME() is then restored. + - [bug] removed legacy behavior whereby a column comparison to a scalar SELECT via == would coerce to an IN with the SQL server diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 668b32d143..91f396d23f 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -180,6 +180,7 @@ from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\ VARBINARY +from ...util import update_wrapper from . import information_schema as ischema MS_2008_VERSION = (10,) @@ -1063,6 +1064,38 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): result = '.'.join([self.quote(x, force) for x in schema.split('.')]) return result +def _db_plus_owner_listing(fn): + def wrap(dialect, connection, schema=None, **kw): + dbname, owner = _owner_plus_db(dialect, schema) + return _switch_db(dbname, connection, fn, dialect, connection, + dbname, owner, schema, **kw) + return update_wrapper(wrap, fn) + +def _db_plus_owner(fn): + def wrap(dialect, connection, tablename, schema=None, **kw): + dbname, owner = _owner_plus_db(dialect, schema) + return _switch_db(dbname, connection, fn, dialect, connection, + tablename, dbname, owner, schema, **kw) + return update_wrapper(wrap, fn) + +def _switch_db(dbname, connection, fn, *arg, **kw): + if dbname: + current_db = connection.scalar("select db_name()") + connection.execute("use %s" % dbname) + try: + return fn(*arg, **kw) + finally: + if dbname: + connection.execute("use %s" % current_db) + +def _owner_plus_db(dialect, schema): + if not schema: + return None, dialect.default_schema_name + elif "." in schema: + return schema.split(".", 1) + else: + return None, schema + class MSDialect(default.DefaultDialect): name = 'mssql' supports_default_values = True @@ -1130,7 +1163,7 @@ class MSDialect(default.DefaultDialect): self.implicit_returning = True def _get_default_schema_name(self, connection): - user_name = connection.scalar("SELECT user_name() as user_name;") + user_name = connection.scalar("SELECT user_name()") if user_name is not None: # now, get the default schema query = sql.text(""" @@ -1153,14 +1186,14 @@ class MSDialect(default.DefaultDialect): else: return column - def has_table(self, connection, tablename, schema=None): - current_schema = schema or self.default_schema_name + @_db_plus_owner + def has_table(self, connection, tablename, dbname, owner, schema): columns = ischema.columns - whereclause = self._unicode_cast(columns.c.table_name)==tablename - if current_schema: + whereclause = self._unicode_cast(columns.c.table_name) == tablename + if owner: whereclause = sql.and_(whereclause, - columns.c.table_schema==current_schema) + columns.c.table_schema == owner) s = sql.select([columns], whereclause) c = connection.execute(s) return c.first() is not None @@ -1174,12 +1207,12 @@ class MSDialect(default.DefaultDialect): return schema_names @reflection.cache - def get_table_names(self, connection, schema=None, **kw): - current_schema = schema or self.default_schema_name + @_db_plus_owner_listing + def get_table_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables s = sql.select([tables.c.table_name], sql.and_( - tables.c.table_schema == current_schema, + tables.c.table_schema == owner, tables.c.table_type == u'BASE TABLE' ), order_by=[tables.c.table_name] @@ -1188,12 +1221,12 @@ class MSDialect(default.DefaultDialect): return table_names @reflection.cache - def get_view_names(self, connection, schema=None, **kw): - current_schema = schema or self.default_schema_name + @_db_plus_owner_listing + def get_view_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables s = sql.select([tables.c.table_name], sql.and_( - tables.c.table_schema == current_schema, + tables.c.table_schema == owner, tables.c.table_type == u'VIEW' ), order_by=[tables.c.table_name] @@ -1202,15 +1235,13 @@ class MSDialect(default.DefaultDialect): return view_names @reflection.cache - def get_indexes(self, connection, tablename, schema=None, **kw): + @_db_plus_owner + def get_indexes(self, connection, tablename, dbname, owner, schema, **kw): # using system catalogs, don't support index reflection # below MS 2005 if self.server_version_info < MS_2005_VERSION: return [] - current_schema = schema or self.default_schema_name - full_tname = "%s.%s" % (current_schema, tablename) - rp = connection.execute( sql.text("select ind.index_id, ind.is_unique, ind.name " "from sys.indexes as ind join sys.tables as tab on " @@ -1222,20 +1253,20 @@ class MSDialect(default.DefaultDialect): bindparams=[ sql.bindparam('tabname', tablename, sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', current_schema, + sql.bindparam('schname', owner, sqltypes.String(convert_unicode=True)) ], typemap = { - 'name':sqltypes.Unicode() + 'name': sqltypes.Unicode() } ) ) indexes = {} for row in rp: indexes[row['index_id']] = { - 'name':row['name'], - 'unique':row['is_unique'] == 1, - 'column_names':[] + 'name': row['name'], + 'unique': row['is_unique'] == 1, + 'column_names': [] } rp = connection.execute( sql.text( @@ -1251,11 +1282,11 @@ class MSDialect(default.DefaultDialect): bindparams=[ sql.bindparam('tabname', tablename, sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', current_schema, + sql.bindparam('schname', owner, sqltypes.String(convert_unicode=True)) ], typemap = { - 'name':sqltypes.Unicode() + 'name': sqltypes.Unicode() } ), ) @@ -1266,9 +1297,8 @@ class MSDialect(default.DefaultDialect): return indexes.values() @reflection.cache - def get_view_definition(self, connection, viewname, schema=None, **kw): - current_schema = schema or self.default_schema_name - + @_db_plus_owner + def get_view_definition(self, connection, viewname, dbname, owner, schema, **kw): rp = connection.execute( sql.text( "select definition from sys.sql_modules as mod, " @@ -1281,7 +1311,7 @@ class MSDialect(default.DefaultDialect): bindparams=[ sql.bindparam('viewname', viewname, sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', current_schema, + sql.bindparam('schname', owner, sqltypes.String(convert_unicode=True)) ] ) @@ -1292,15 +1322,15 @@ class MSDialect(default.DefaultDialect): return view_def @reflection.cache - def get_columns(self, connection, tablename, schema=None, **kw): + @_db_plus_owner + def get_columns(self, connection, tablename, dbname, owner, schema, **kw): # Get base columns - current_schema = schema or self.default_schema_name columns = ischema.columns - if current_schema: - whereclause = sql.and_(columns.c.table_name==tablename, - columns.c.table_schema==current_schema) + if owner: + whereclause = sql.and_(columns.c.table_name == tablename, + columns.c.table_schema == owner) else: - whereclause = columns.c.table_name==tablename + whereclause = columns.c.table_name == tablename s = sql.select([columns], whereclause, order_by=[columns.c.ordinal_position]) @@ -1361,7 +1391,7 @@ class MSDialect(default.DefaultDialect): # We also run an sp_columns to check for identity columns: cursor = connection.execute("sp_columns @table_name = '%s', " "@table_owner = '%s'" - % (tablename, current_schema)) + % (tablename, owner)) ic = None while True: row = cursor.fetchone() @@ -1377,7 +1407,7 @@ class MSDialect(default.DefaultDialect): cursor.close() if ic is not None and self.server_version_info >= MS_2005_VERSION: - table_fullname = "%s.%s" % (current_schema, tablename) + table_fullname = "%s.%s" % (owner, tablename) cursor = connection.execute( "select ident_seed('%s'), ident_incr('%s')" % (table_fullname, table_fullname) @@ -1386,52 +1416,36 @@ class MSDialect(default.DefaultDialect): row = cursor.first() if row is not None and row[0] is not None: colmap[ic]['sequence'].update({ - 'start' : int(row[0]), - 'increment' : int(row[1]) + 'start': int(row[0]), + 'increment': int(row[1]) }) return cols @reflection.cache - def get_pk_constraint(self, connection, tablename, schema=None, **kw): - current_schema = schema or self.default_schema_name + @_db_plus_owner + def get_pk_constraint(self, connection, tablename, dbname, owner, schema, **kw): pkeys = [] - # information_schema.referential_constraints - RR = ischema.ref_constraints - # information_schema.table_constraints TC = ischema.constraints - # information_schema.constraint_column_usage: - # the constrained column - C = ischema.key_constraints.alias('C') - # information_schema.constraint_column_usage: - # the referenced column - R = ischema.key_constraints.alias('R') + C = ischema.key_constraints.alias('C') # Primary key constraints s = sql.select([C.c.column_name, TC.c.constraint_type], sql.and_(TC.c.constraint_name == C.c.constraint_name, C.c.table_name == tablename, - C.c.table_schema == current_schema) + C.c.table_schema == owner) ) c = connection.execute(s) for row in c: if 'PRIMARY' in row[TC.c.constraint_type.name]: pkeys.append(row[0]) - return {'constrained_columns':pkeys, 'name':None} + return {'constrained_columns': pkeys, 'name': None} @reflection.cache - def get_foreign_keys(self, connection, tablename, schema=None, **kw): - current_schema = schema or self.default_schema_name - # Add constraints - #information_schema.referential_constraints + @_db_plus_owner + def get_foreign_keys(self, connection, tablename, dbname, owner, schema, **kw): RR = ischema.ref_constraints - # information_schema.table_constraints - TC = ischema.constraints - # information_schema.constraint_column_usage: - # the constrained column - C = ischema.key_constraints.alias('C') - # information_schema.constraint_column_usage: - # the referenced column - R = ischema.key_constraints.alias('R') + C = ischema.key_constraints.alias('C') + R = ischema.key_constraints.alias('R') # Foreign key constraints s = sql.select([C.c.column_name, @@ -1440,13 +1454,13 @@ class MSDialect(default.DefaultDialect): RR.c.update_rule, RR.c.delete_rule], sql.and_(C.c.table_name == tablename, - C.c.table_schema == current_schema, + C.c.table_schema == owner, C.c.constraint_name == RR.c.constraint_name, R.c.constraint_name == RR.c.unique_constraint_name, C.c.ordinal_position == R.c.ordinal_position ), - order_by = [ + order_by= [ RR.c.constraint_name, R.c.ordinal_position]) @@ -1457,11 +1471,11 @@ class MSDialect(default.DefaultDialect): def fkey_rec(): return { - 'name' : None, - 'constrained_columns' : [], - 'referred_schema' : None, - 'referred_table' : None, - 'referred_columns' : [] + 'name': None, + 'constrained_columns': [], + 'referred_schema': None, + 'referred_table': None, + 'referred_columns': [] } fkeys = util.defaultdict(fkey_rec) @@ -1473,8 +1487,9 @@ class MSDialect(default.DefaultDialect): rec['name'] = rfknm if not rec['referred_table']: rec['referred_table'] = rtbl - - if schema is not None or current_schema != rschema: + if schema is not None or owner != rschema: + if dbname: + rschema = dbname + "." + rschema rec['referred_schema'] = rschema local_cols, remote_cols = \ diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index b254a98e85..7627b65839 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -673,6 +673,41 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert isinstance(t1.c.id.type, Integer) assert isinstance(t1.c.data.type, types.NullType) + + @testing.provide_metadata + def test_db_qualified_items(self): + metadata = self.metadata + Table('foo', metadata, Column('id', Integer, primary_key=True)) + Table('bar', metadata, + Column('id', Integer, primary_key=True), + Column('foo_id', Integer, ForeignKey('foo.id', name="fkfoo")) + ) + metadata.create_all() + + dbname = testing.db.scalar("select db_name()") + owner = testing.db.scalar("SELECT user_name()") + + inspector = inspect(testing.db) + bar_via_db = inspector.get_foreign_keys( + "bar", schema="%s.%s" % (dbname, owner)) + eq_( + bar_via_db, + [{ + 'referred_table': 'foo', + 'referred_columns': ['id'], + 'referred_schema': 'test.dbo', + 'name': 'fkfoo', + 'constrained_columns': ['foo_id']}] + ) + + assert testing.db.has_table("bar", schema="test.dbo") + + m2 = MetaData() + Table('bar', m2, schema="test.dbo", autoload=True, + autoload_with=testing.db) + eq_(m2.tables["test.dbo.foo"].schema, "test.dbo") + + @testing.provide_metadata def test_indexes_cols(self): metadata = self.metadata