From c517f353d48cf71628b50e038098b563b026ce1d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 6 Jul 2009 00:18:57 +0000 Subject: [PATCH] - start moving get_default_schema_name to an initialized var - fix up MSSQL foreign key reflection ala oracle --- lib/sqlalchemy/dialects/mssql/base.py | 77 +++++++++++++------------ lib/sqlalchemy/dialects/mssql/pyodbc.py | 2 +- lib/sqlalchemy/engine/base.py | 6 +- lib/sqlalchemy/engine/default.py | 4 +- lib/sqlalchemy/test/testing.py | 2 +- test/engine/test_reflection.py | 2 +- 6 files changed, 51 insertions(+), 42 deletions(-) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index a5fdae2b73..fdec5741db 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1134,6 +1134,9 @@ class MSDialect(default.DefaultDialect): pass def get_default_schema_name(self, connection): + return self.default_schema_name + + def _get_default_schema_name(self, connection): user_name = connection.scalar("SELECT user_name() as user_name;") if user_name is not None: # now, get the default schema @@ -1157,7 +1160,7 @@ class MSDialect(default.DefaultDialect): def has_table(self, connection, tablename, schema=None): - current_schema = schema or self.get_default_schema_name(connection) + current_schema = schema or self.default_schema_name columns = ischema.columns s = sql.select([columns], current_schema @@ -1179,7 +1182,7 @@ class MSDialect(default.DefaultDialect): @reflection.cache def get_table_names(self, connection, schema=None, **kw): - current_schema = schema or self.get_default_schema_name(connection) + current_schema = schema or self.default_schema_name tables = ischema.tables s = sql.select([tables.c.table_name], sql.and_( @@ -1193,7 +1196,7 @@ class MSDialect(default.DefaultDialect): @reflection.cache def get_view_names(self, connection, schema=None, **kw): - current_schema = schema or self.get_default_schema_name(connection) + current_schema = schema or self.default_schema_name tables = ischema.tables s = sql.select([tables.c.table_name], sql.and_( @@ -1208,7 +1211,7 @@ class MSDialect(default.DefaultDialect): # The cursor reports it is closed after executing the sp. @reflection.cache def get_indexes(self, connection, tablename, schema=None, **kw): - current_schema = schema or self.get_default_schema_name(connection) + current_schema = schema or self.default_schema_name full_tname = "%s.%s" % (current_schema, tablename) indexes = [] s = sql.text("exec sp_helpindex '%s'" % full_tname) @@ -1227,7 +1230,7 @@ class MSDialect(default.DefaultDialect): @reflection.cache def get_view_definition(self, connection, viewname, schema=None, **kw): - current_schema = schema or self.get_default_schema_name(connection) + current_schema = schema or self.default_schema_name views = ischema.views s = sql.select([views.c.view_definition], sql.and_( @@ -1243,7 +1246,7 @@ class MSDialect(default.DefaultDialect): @reflection.cache def get_columns(self, connection, tablename, schema=None, **kw): # Get base columns - current_schema = schema or self.get_default_schema_name(connection) + current_schema = schema or self.default_schema_name columns = ischema.columns s = sql.select([columns], current_schema @@ -1330,7 +1333,7 @@ class MSDialect(default.DefaultDialect): @reflection.cache def get_primary_keys(self, connection, tablename, schema=None, **kw): - current_schema = schema or self.get_default_schema_name(connection) + current_schema = schema or self.default_schema_name pkeys = [] # Add constraints RR = ischema.ref_constraints #information_schema.referential_constraints @@ -1352,7 +1355,7 @@ class MSDialect(default.DefaultDialect): @reflection.cache def get_foreign_keys(self, connection, tablename, schema=None, **kw): - current_schema = schema or self.get_default_schema_name(connection) + current_schema = schema or self.default_schema_name # Add constraints RR = ischema.ref_constraints #information_schema.referential_constraints TC = ischema.constraints #information_schema.table_constraints @@ -1370,40 +1373,40 @@ class MSDialect(default.DefaultDialect): C.c.ordinal_position == R.c.ordinal_position ), order_by = [RR.c.constraint_name, R.c.ordinal_position]) - rows = connection.execute(s).fetchall() + # group rows by constraint ID, to handle multi-column FKs fkeys = [] fknm, scols, rcols = (None, [], []) - for r in rows: + + def fkey_rec(): + return { + 'name' : None, + 'constrained_columns' : [], + 'referred_schema' : None, + 'referred_table' : None, + 'referred_columns' : [] + } + + fkeys = util.defaultdict(fkey_rec) + + for r in connection.execute(s).fetchall(): scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r - if rfknm != fknm: - if fknm: - fkeys.append({ - 'name' : fknm, - 'constrained_columns' : scols, - 'referred_schema' : rschema, - 'referred_table' : rtbl, - 'referred_columns' : rcols - }) - fknm, scols, rcols = (rfknm, [], []) - if not scol in scols: - scols.append(scol) - if not rcol in rcols: - rcols.append(rcol) - if fknm and scols: - # don't return the remote schema if no schema was specified and it - # is the default - if schema is None and current_schema == rschema: - rschema = None - fkeys.append({ - 'name' : fknm, - 'constrained_columns' : scols, - 'referred_schema' : rschema, - 'referred_table' : rtbl, - 'referred_columns' : rcols - }) - return fkeys + + rec = fkeys[rfknm] + rec['name'] = rfknm + if not rec['referred_table']: + rec['referred_table'] = rtbl + + if schema is not None or current_schema != rschema: + rec['referred_schema'] = rschema + + local_cols, remote_cols = rec['constrained_columns'], rec['referred_columns'] + + local_cols.append(scol) + remote_cols.append(rcol) + + return fkeys.values() def reflecttable(self, connection, table, include_columns): diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 1b754b2b49..550f26e676 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -62,8 +62,8 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset') def initialize(self, connection): + super(MSDialect_pyodbc, self).initialize(connection) pyodbc = self.dbapi - self.server_version_info = self._get_server_version_info(connection) dbapi_con = connection.connection diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index dfe5992462..a570427969 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -311,7 +311,11 @@ class Dialect(object): raise NotImplementedError() def get_default_schema_name(self, connection): - """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`.""" + """Return the string name of the currently selected schema given a :class:`~sqlalchemy.engine.Connection`. + + DEPRECATED. moving this towards dialect.default_schema_name (not complete). + + """ raise NotImplementedError() diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index f3969eb57e..413de171a2 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -89,7 +89,9 @@ class DefaultDialect(base.Dialect): def initialize(self, connection): if hasattr(self, '_get_server_version_info'): self.server_version_info = self._get_server_version_info(connection) - + if hasattr(self, '_get_default_schema_name'): + self.default_schema_name = self._get_default_schema_name(connection) + @classmethod def type_descriptor(cls, typeobj): """Provide a database-specific ``TypeEngine`` object, given diff --git a/lib/sqlalchemy/test/testing.py b/lib/sqlalchemy/test/testing.py index 077acd3940..57b1405802 100644 --- a/lib/sqlalchemy/test/testing.py +++ b/lib/sqlalchemy/test/testing.py @@ -626,7 +626,7 @@ class ComparesTables(object): elif against(('mysql', '<', (5, 0))): # ignore reflection of bogus db-generated DefaultClause() pass - elif not c.primary_key or not against('postgres'): + elif not c.primary_key or not against('postgres', 'mssql'): #print repr(c) assert reflected_c.default is None, reflected_c.default diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 6b9760d38d..3d115b9a3d 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -595,7 +595,7 @@ class ReflectionTest(TestBase, ComparesTables): m9.reflect() self.assert_(not m9.tables) - @testing.fails_on_everything_except('postgres', 'mysql', 'sqlite', 'oracle') + @testing.fails_on_everything_except('postgres', 'mysql', 'sqlite', 'oracle', 'mssql') def test_index_reflection(self): m1 = MetaData(testing.db) t1 = Table('party', m1, -- 2.47.3