From: Mike Bayer Date: Tue, 15 Sep 2020 22:48:36 +0000 (-0400) Subject: Correct for SQL Server temp table owner X-Git-Tag: rel_1_4_0b1~94^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=07ba8e0a37daeb4304e8fede43b13e402b01dbeb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Correct for SQL Server temp table owner on my machine, the owner for a temp table comes out as dbo, and i am testing against a CI machine. im not sure what happens on a CI machine except perhaps that it provisions new databases is changing things. in any case, since we are searching the tempdb for the name, get the schema/owner also. Also refines the test to use a single connection and a transaction that rolls back, doesn't hang here but let's see what CI does. Change-Id: I522596ccc526cdab14c516b9a566ff666ac57dd6 --- diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 519d74d89e..7564536a53 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2938,45 +2938,54 @@ class MSDialect(default.DefaultDialect): return view_def def _get_internal_temp_table_name(self, connection, tablename): - result = connection.execute( - sql.text( - "select table_name " - "from tempdb.information_schema.tables " - "where table_name like :p1" - ), - { - "p1": tablename - + (("___%") if not tablename.startswith("##") else "") - }, - ).fetchall() - if len(result) > 1: - raise exc.UnreflectableTableError( - "Found more than one temporary table named '%s' in tempdb " - "at this time. Cannot reliably resolve that name to its " - "internal table name." % tablename + # it's likely that schema is always "dbo", but since we can + # get it here, let's get it. + # see https://stackoverflow.com/questions/8311959/ + # specifying-schema-for-temporary-tables + + try: + return connection.execute( + sql.text( + "select table_schema, table_name " + "from tempdb.information_schema.tables " + "where table_name like :p1" + ), + { + "p1": tablename + + (("___%") if not tablename.startswith("##") else "") + }, + ).one() + except exc.MultipleResultsFound as me: + util.raise_( + exc.UnreflectableTableError( + "Found more than one temporary table named '%s' in tempdb " + "at this time. Cannot reliably resolve that name to its " + "internal table name." % tablename + ), + replace_context=me, ) - elif len(result) == 0: - raise exc.NoSuchTableError( - "Unable to find a temporary table named '%s' in tempdb." - % tablename + except exc.NoResultFound as ne: + util.raise_( + exc.NoSuchTableError( + "Unable to find a temporary table named '%s' in tempdb." + % tablename + ), + replace_context=ne, ) - else: - return result[0][0] @reflection.cache @_db_plus_owner def get_columns(self, connection, tablename, dbname, owner, schema, **kw): is_temp_table = tablename.startswith("#") if is_temp_table: - tablename = self._get_internal_temp_table_name( + owner, tablename = self._get_internal_temp_table_name( connection, tablename ) - # Get base columns - columns = ( - ischema.mssql_temp_table_columns - if is_temp_table - else ischema.columns - ) + + columns = ischema.mssql_temp_table_columns + else: + columns = ischema.columns + computed_cols = ischema.computed_columns if owner: whereclause = sql.and_( @@ -3016,6 +3025,7 @@ class MSDialect(default.DefaultDialect): ) c = connection.execution_options(future_result=True).execute(s) + cols = [] for row in c.mappings(): name = row[columns.c.column_name] diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index 19b8c187cb..bd64bedcbc 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -25,6 +25,7 @@ from sqlalchemy.schema import CreateIndex from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import ComparesTables from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ @@ -255,45 +256,48 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): ("global_temp", "##tmp", True), ("nonexistent", "#no_es_bueno", False), id_="iaa", + argnames="table_name, exists", ) - def test_temporary_table(self, table_name, exists): + def test_temporary_table(self, connection, table_name, exists): metadata = self.metadata if exists: - # TODO: why this test hangs when using the connection fixture? - with testing.db.connect() as conn: - tran = conn.begin() - conn.execute( - ( - "CREATE TABLE %s " - "(id int primary key, txt nvarchar(50), dt2 datetime2)" # noqa - ) - % table_name - ) - conn.execute( - ( - "INSERT INTO %s (id, txt, dt2) VALUES " - "(1, N'foo', '2020-01-01 01:01:01'), " - "(2, N'bar', '2020-02-02 02:02:02') " - ) - % table_name + tt = Table( + table_name, + self.metadata, + Column("id", Integer, primary_key=True), + Column("txt", mssql.NVARCHAR(50)), + Column("dt2", mssql.DATETIME2), + ) + tt.create(connection) + connection.execute( + tt.insert(), + [ + { + "id": 1, + "txt": u"foo", + "dt2": datetime.datetime(2020, 1, 1, 1, 1, 1), + }, + { + "id": 2, + "txt": u"bar", + "dt2": datetime.datetime(2020, 2, 2, 2, 2, 2), + }, + ], + ) + + if not exists: + with expect_raises(exc.NoSuchTableError): + Table( + table_name, metadata, autoload_with=connection, ) - tran.commit() - tran = conn.begin() - try: - tmp_t = Table( - table_name, metadata, autoload_with=testing.db, - ) - tran.commit() - result = conn.execute( - tmp_t.select().where(tmp_t.c.id == 2) - ).fetchall() - eq_( - result, - [(2, "bar", datetime.datetime(2020, 2, 2, 2, 2, 2))], - ) - except exc.NoSuchTableError: - if exists: - raise + else: + tmp_t = Table(table_name, metadata, autoload_with=connection) + result = connection.execute( + tmp_t.select().where(tmp_t.c.id == 2) + ).fetchall() + eq_( + result, [(2, "bar", datetime.datetime(2020, 2, 2, 2, 2, 2))], + ) @testing.provide_metadata def test_db_qualified_items(self):