]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix has_table() for mssql temporary tables
authorGord Thompson <gord@gordthompson.com>
Fri, 18 Sep 2020 22:33:17 +0000 (16:33 -0600)
committerGord Thompson <gord@gordthompson.com>
Sat, 19 Sep 2020 16:18:52 +0000 (10:18 -0600)
Fixes: #5597
Fixes the issue where :meth:`_reflection.has_table` always returns
``False`` for temporary tables.

Change-Id: I03ab04c849a157ce8fd28c07ec3bf4407b0f2c94

doc/build/changelog/unreleased_14/5597.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
test/dialect/mssql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_14/5597.rst b/doc/build/changelog/unreleased_14/5597.rst
new file mode 100644 (file)
index 0000000..ee9343b
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: mssql, bug, schema
+    :tickets: 5597
+
+    Fixed an issue where :meth:`_reflection.has_table` always returned
+    ``False`` for temporary tables.
\ No newline at end of file
index 7564536a537b731d38af5828c205d8fa5e4acfca..2cbdc19aacc24e25064143bcebf8e68d1def71b9 100644 (file)
@@ -2756,21 +2756,34 @@ class MSDialect(default.DefaultDialect):
 
     @_db_plus_owner
     def has_table(self, connection, tablename, dbname, owner, schema):
-        tables = ischema.tables
+        if tablename.startswith("#"):  # temporary table
+            tables = ischema.mssql_temp_table_columns
+            result = connection.execute(
+                sql.select(tables.c.table_name)
+                .where(
+                    tables.c.table_name.like(
+                        self._temp_table_name_like_pattern(tablename)
+                    )
+                )
+                .limit(1)
+            )
+            return result.scalar() is not None
+        else:
+            tables = ischema.tables
 
-        s = sql.select(tables.c.table_name).where(
-            sql.and_(
-                tables.c.table_type == "BASE TABLE",
-                tables.c.table_name == tablename,
+            s = sql.select(tables.c.table_name).where(
+                sql.and_(
+                    tables.c.table_type == "BASE TABLE",
+                    tables.c.table_name == tablename,
+                )
             )
-        )
 
-        if owner:
-            s = s.where(tables.c.table_schema == owner)
+            if owner:
+                s = s.where(tables.c.table_schema == owner)
 
-        c = connection.execute(s)
+            c = connection.execute(s)
 
-        return c.first() is not None
+            return c.first() is not None
 
     @_db_plus_owner
     def has_sequence(self, connection, sequencename, dbname, owner, schema):
@@ -2937,6 +2950,9 @@ class MSDialect(default.DefaultDialect):
             view_def = rp.scalar()
             return view_def
 
+    def _temp_table_name_like_pattern(self, tablename):
+        return tablename + (("___%") if not tablename.startswith("##") else "")
+
     def _get_internal_temp_table_name(self, connection, tablename):
         # it's likely that schema is always "dbo", but since we can
         # get it here, let's get it.
@@ -2950,10 +2966,7 @@ class MSDialect(default.DefaultDialect):
                     "from tempdb.information_schema.tables "
                     "where table_name like :p1"
                 ),
-                {
-                    "p1": tablename
-                    + (("___%") if not tablename.startswith("##") else "")
-                },
+                {"p1": self._temp_table_name_like_pattern(tablename)},
             ).one()
         except exc.MultipleResultsFound as me:
             util.raise_(
index bd64bedcbc88dcd36ed8ddbadcabc88458f362d6..c7d012f5bf1252d92557178b10b886ab5b40555e 100644 (file)
@@ -299,6 +299,22 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
                 result, [(2, "bar", datetime.datetime(2020, 2, 2, 2, 2, 2))],
             )
 
+    @testing.provide_metadata
+    @testing.combinations(
+        ("local_temp", "#tmp", True),
+        ("global_temp", "##tmp", True),
+        ("nonexistent", "#no_es_bueno", False),
+        id_="iaa",
+        argnames="table_name, exists",
+    )
+    def test_has_table_temporary(self, connection, table_name, exists):
+        if exists:
+            tt = Table(table_name, self.metadata, Column("id", Integer),)
+            tt.create(connection)
+
+        found_it = testing.db.dialect.has_table(connection, table_name)
+        eq_(found_it, exists)
+
     @testing.provide_metadata
     def test_db_qualified_items(self):
         metadata = self.metadata