]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Correct for SQL Server temp table owner
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Sep 2020 22:48:36 +0000 (18:48 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 15 Sep 2020 23:25:13 +0000 (19:25 -0400)
on my machine, the owner for a temp table comes out as
dbo, and i am testing against a CI machine.  im not sure
what happens on a CI machine except perhaps that it provisions
new databases is changing things.   in any case, since we
are searching the tempdb for the name, get the schema/owner also.

Also refines the test to use a single connection and a transaction
that rolls back, doesn't hang here but let's see what CI does.

Change-Id: I522596ccc526cdab14c516b9a566ff666ac57dd6

lib/sqlalchemy/dialects/mssql/base.py
test/dialect/mssql/test_reflection.py

index 519d74d89eade3d667faff71f3e2a887241ad266..7564536a537b731d38af5828c205d8fa5e4acfca 100644 (file)
@@ -2938,45 +2938,54 @@ class MSDialect(default.DefaultDialect):
             return view_def
 
     def _get_internal_temp_table_name(self, connection, tablename):
-        result = connection.execute(
-            sql.text(
-                "select table_name "
-                "from tempdb.information_schema.tables "
-                "where table_name like :p1"
-            ),
-            {
-                "p1": tablename
-                + (("___%") if not tablename.startswith("##") else "")
-            },
-        ).fetchall()
-        if len(result) > 1:
-            raise exc.UnreflectableTableError(
-                "Found more than one temporary table named '%s' in tempdb "
-                "at this time. Cannot reliably resolve that name to its "
-                "internal table name." % tablename
+        # it's likely that schema is always "dbo", but since we can
+        # get it here, let's get it.
+        # see https://stackoverflow.com/questions/8311959/
+        # specifying-schema-for-temporary-tables
+
+        try:
+            return connection.execute(
+                sql.text(
+                    "select table_schema, table_name "
+                    "from tempdb.information_schema.tables "
+                    "where table_name like :p1"
+                ),
+                {
+                    "p1": tablename
+                    + (("___%") if not tablename.startswith("##") else "")
+                },
+            ).one()
+        except exc.MultipleResultsFound as me:
+            util.raise_(
+                exc.UnreflectableTableError(
+                    "Found more than one temporary table named '%s' in tempdb "
+                    "at this time. Cannot reliably resolve that name to its "
+                    "internal table name." % tablename
+                ),
+                replace_context=me,
             )
-        elif len(result) == 0:
-            raise exc.NoSuchTableError(
-                "Unable to find a temporary table named '%s' in tempdb."
-                % tablename
+        except exc.NoResultFound as ne:
+            util.raise_(
+                exc.NoSuchTableError(
+                    "Unable to find a temporary table named '%s' in tempdb."
+                    % tablename
+                ),
+                replace_context=ne,
             )
-        else:
-            return result[0][0]
 
     @reflection.cache
     @_db_plus_owner
     def get_columns(self, connection, tablename, dbname, owner, schema, **kw):
         is_temp_table = tablename.startswith("#")
         if is_temp_table:
-            tablename = self._get_internal_temp_table_name(
+            owner, tablename = self._get_internal_temp_table_name(
                 connection, tablename
             )
-        # Get base columns
-        columns = (
-            ischema.mssql_temp_table_columns
-            if is_temp_table
-            else ischema.columns
-        )
+
+            columns = ischema.mssql_temp_table_columns
+        else:
+            columns = ischema.columns
+
         computed_cols = ischema.computed_columns
         if owner:
             whereclause = sql.and_(
@@ -3016,6 +3025,7 @@ class MSDialect(default.DefaultDialect):
         )
 
         c = connection.execution_options(future_result=True).execute(s)
+
         cols = []
         for row in c.mappings():
             name = row[columns.c.column_name]
index 19b8c187cb15bb18479b2ceab06fff3197703ac4..bd64bedcbc88dcd36ed8ddbadcabc88458f362d6 100644 (file)
@@ -25,6 +25,7 @@ from sqlalchemy.schema import CreateIndex
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import ComparesTables
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import in_
 from sqlalchemy.testing import is_
@@ -255,45 +256,48 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
         ("global_temp", "##tmp", True),
         ("nonexistent", "#no_es_bueno", False),
         id_="iaa",
+        argnames="table_name, exists",
     )
-    def test_temporary_table(self, table_name, exists):
+    def test_temporary_table(self, connection, table_name, exists):
         metadata = self.metadata
         if exists:
-            # TODO: why this test hangs when using the connection fixture?
-            with testing.db.connect() as conn:
-                tran = conn.begin()
-                conn.execute(
-                    (
-                        "CREATE TABLE %s "
-                        "(id int primary key, txt nvarchar(50), dt2 datetime2)"  # noqa
-                    )
-                    % table_name
-                )
-                conn.execute(
-                    (
-                        "INSERT INTO %s (id, txt, dt2) VALUES "
-                        "(1, N'foo', '2020-01-01 01:01:01'), "
-                        "(2, N'bar', '2020-02-02 02:02:02') "
-                    )
-                    % table_name
+            tt = Table(
+                table_name,
+                self.metadata,
+                Column("id", Integer, primary_key=True),
+                Column("txt", mssql.NVARCHAR(50)),
+                Column("dt2", mssql.DATETIME2),
+            )
+            tt.create(connection)
+            connection.execute(
+                tt.insert(),
+                [
+                    {
+                        "id": 1,
+                        "txt": u"foo",
+                        "dt2": datetime.datetime(2020, 1, 1, 1, 1, 1),
+                    },
+                    {
+                        "id": 2,
+                        "txt": u"bar",
+                        "dt2": datetime.datetime(2020, 2, 2, 2, 2, 2),
+                    },
+                ],
+            )
+
+        if not exists:
+            with expect_raises(exc.NoSuchTableError):
+                Table(
+                    table_name, metadata, autoload_with=connection,
                 )
-                tran.commit()
-                tran = conn.begin()
-                try:
-                    tmp_t = Table(
-                        table_name, metadata, autoload_with=testing.db,
-                    )
-                    tran.commit()
-                    result = conn.execute(
-                        tmp_t.select().where(tmp_t.c.id == 2)
-                    ).fetchall()
-                    eq_(
-                        result,
-                        [(2, "bar", datetime.datetime(2020, 2, 2, 2, 2, 2))],
-                    )
-                except exc.NoSuchTableError:
-                    if exists:
-                        raise
+        else:
+            tmp_t = Table(table_name, metadata, autoload_with=connection)
+            result = connection.execute(
+                tmp_t.select().where(tmp_t.c.id == 2)
+            ).fetchall()
+            eq_(
+                result, [(2, "bar", datetime.datetime(2020, 2, 2, 2, 2, 2))],
+            )
 
     @testing.provide_metadata
     def test_db_qualified_items(self):