]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve reflection for mssql temporary tables
authorGord Thompson <gord@gordthompson.com>
Wed, 12 Aug 2020 20:46:59 +0000 (14:46 -0600)
committerGord Thompson <gord@gordthompson.com>
Tue, 1 Sep 2020 14:05:51 +0000 (08:05 -0600)
Fixes: #5506
Change-Id: I718474d76e3c630a1b71e07eaa20cefb104d11de

doc/build/changelog/unreleased_14/5506.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mssql/information_schema.py
lib/sqlalchemy/dialects/mssql/provision.py
lib/sqlalchemy/testing/provision.py
lib/sqlalchemy/testing/suite/test_reflection.py
test/dialect/mssql/test_reflection.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_14/5506.rst b/doc/build/changelog/unreleased_14/5506.rst
new file mode 100644 (file)
index 0000000..71b5732
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: usecase, mssql
+    :tickets: 5506
+
+    Added support for reflection of temporary tables with the SQL Server dialect.
+    Table names that are prefixed by a pound sign "#" are now introspected from
+    the MSSQL "tempdb" system catalog.
index f38c537fdcf167fc89db62953fdc00ae8e5f2fa8..ed17fb8631e5a4a96c358c32bb9babf819a46b02 100644 (file)
@@ -2913,11 +2913,46 @@ class MSDialect(default.DefaultDialect):
             view_def = rp.scalar()
             return view_def
 
+    def _get_internal_temp_table_name(self, connection, tablename):
+        result = connection.execute(
+            sql.text(
+                "select table_name "
+                "from tempdb.information_schema.tables "
+                "where table_name like :p1"
+            ),
+            {
+                "p1": tablename
+                + (("___%") if not tablename.startswith("##") else "")
+            },
+        ).fetchall()
+        if len(result) > 1:
+            raise exc.UnreflectableTableError(
+                "Found more than one temporary table named '%s' in tempdb "
+                "at this time. Cannot reliably resolve that name to its "
+                "internal table name." % tablename
+            )
+        elif len(result) == 0:
+            raise exc.NoSuchTableError(
+                "Unable to find a temporary table named '%s' in tempdb."
+                % tablename
+            )
+        else:
+            return result[0][0]
+
     @reflection.cache
     @_db_plus_owner
     def get_columns(self, connection, tablename, dbname, owner, schema, **kw):
+        is_temp_table = tablename.startswith("#")
+        if is_temp_table:
+            tablename = self._get_internal_temp_table_name(
+                connection, tablename
+            )
         # Get base columns
-        columns = ischema.columns
+        columns = (
+            ischema.mssql_temp_table_columns
+            if is_temp_table
+            else ischema.columns
+        )
         computed_cols = ischema.computed_columns
         if owner:
             whereclause = sql.and_(
index 6cdde83865dc7e1285a93bbd3546ee3384d11f3e..f80110b7d46cff704749aaaab711e583a1852155 100644 (file)
@@ -5,9 +5,6 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-# TODO: should be using the sys. catalog with SQL Server, not information
-# schema
-
 from ... import cast
 from ... import Column
 from ... import MetaData
@@ -93,6 +90,25 @@ columns = Table(
     schema="INFORMATION_SCHEMA",
 )
 
+mssql_temp_table_columns = Table(
+    "COLUMNS",
+    ischema,
+    Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
+    Column("TABLE_NAME", CoerceUnicode, key="table_name"),
+    Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
+    Column("IS_NULLABLE", Integer, key="is_nullable"),
+    Column("DATA_TYPE", String, key="data_type"),
+    Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
+    Column(
+        "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
+    ),
+    Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
+    Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
+    Column("COLUMN_DEFAULT", Integer, key="column_default"),
+    Column("COLLATION_NAME", String, key="collation_name"),
+    schema="tempdb.INFORMATION_SCHEMA",
+)
+
 constraints = Table(
     "TABLE_CONSTRAINTS",
     ischema,
index a5131eae6afc9e2fbde28277b4879542e3313c8b..269eb164f70117a8b01f137c56e9a862b35eb965 100644 (file)
@@ -2,8 +2,10 @@ from ... import create_engine
 from ... import exc
 from ...testing.provision import create_db
 from ...testing.provision import drop_db
+from ...testing.provision import get_temp_table_name
 from ...testing.provision import log
 from ...testing.provision import run_reap_dbs
+from ...testing.provision import temp_table_keyword_args
 
 
 @create_db.for_db("mssql")
@@ -72,3 +74,13 @@ def _reap_mssql_dbs(url, idents):
         log.info(
             "Dropped %d out of %d stale databases detected", dropped, total
         )
+
+
+@temp_table_keyword_args.for_db("mssql")
+def _mssql_temp_table_keyword_args(cfg, eng):
+    return {}
+
+
+@get_temp_table_name.for_db("mssql")
+def _mssql_get_temp_table_name(cfg, eng, base_name):
+    return "#" + base_name
index 0edaae4909c17fac478e95a116382f7c95e8ee3a..8bdad357c14ab8bf9b8b79cde53fc78f6e425850 100644 (file)
@@ -296,3 +296,18 @@ def temp_table_keyword_args(cfg, eng):
     raise NotImplementedError(
         "no temp table keyword args routine for cfg: %s" % eng.url
     )
+
+
+@register.init
+def get_temp_table_name(cfg, eng, base_name):
+    """Specify table name for creating a temporary Table.
+
+    Dialect-specific implementations of this method will return the
+    name to use when creating a temporary table for testing,
+    e.g., in the define_temp_tables method of the
+    ComponentReflectionTest class in suite/test_reflection.py
+
+    Default to just the base name since that's what most dialects will
+    use. The mssql dialect's implementation will need a "#" prepended.
+    """
+    return base_name
index 151be757aa4cb042c0d87e055f7cacc359994bcf..94ec22c1e6bac3e3feb59601ac49af1a4f335e10 100644 (file)
@@ -8,6 +8,7 @@ from .. import eq_
 from .. import expect_warnings
 from .. import fixtures
 from .. import is_
+from ..provision import get_temp_table_name
 from ..provision import temp_table_keyword_args
 from ..schema import Column
 from ..schema import Table
@@ -442,8 +443,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
     @classmethod
     def define_temp_tables(cls, metadata):
         kw = temp_table_keyword_args(config, config.db)
+        table_name = get_temp_table_name(config, config.db, "user_tmp")
         user_tmp = Table(
-            "user_tmp",
+            table_name,
             metadata,
             Column("id", sa.INT, primary_key=True),
             Column("name", sa.VARCHAR(50)),
@@ -736,10 +738,11 @@ class ComponentReflectionTest(fixtures.TablesTest):
 
     @testing.requires.temp_table_reflection
     def test_get_temp_table_columns(self):
+        table_name = get_temp_table_name(config, config.db, "user_tmp")
         meta = MetaData(self.bind)
-        user_tmp = self.tables.user_tmp
+        user_tmp = self.tables[table_name]
         insp = inspect(meta.bind)
-        cols = insp.get_columns("user_tmp")
+        cols = insp.get_columns(table_name)
         self.assert_(len(cols) > 0, len(cols))
 
         for i, col in enumerate(user_tmp.columns):
@@ -1051,10 +1054,11 @@ class ComponentReflectionTest(fixtures.TablesTest):
             refl.pop("duplicates_index", None)
         eq_(reflected, [{"column_names": ["name"], "name": "user_tmp_uq"}])
 
-    @testing.requires.temp_table_reflection
+    @testing.requires.temp_table_reflect_indexes
     def test_get_temp_table_indexes(self):
         insp = inspect(self.bind)
-        indexes = insp.get_indexes("user_tmp")
+        table_name = get_temp_table_name(config, config.db, "user_tmp")
+        indexes = insp.get_indexes(table_name)
         for ind in indexes:
             ind.pop("dialect_options", None)
         eq_(
index 0bd8f7a5a022691405c428ab4a7c5266b1797b36..67bde6fb378e78da5083b857d5cdc483cfacdfa7 100644 (file)
@@ -1,7 +1,10 @@
 # -*- encoding: utf-8
+import datetime
+
 from sqlalchemy import Column
 from sqlalchemy import DDL
 from sqlalchemy import event
+from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import Index
 from sqlalchemy import inspect
@@ -245,6 +248,52 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
         )
         eq_(t.name, "ABCDEFGHIJKLMNOPQRSTUVWXYZ")
 
+    @testing.provide_metadata
+    @testing.combinations(
+        ("local_temp", "#tmp", True),
+        ("global_temp", "##tmp", True),
+        ("nonexistent", "#no_es_bueno", False),
+        id_="iaa",
+    )
+    def test_temporary_table(self, table_name, exists):
+        metadata = self.metadata
+        if exists:
+            # TODO: why this test hangs when using the connection fixture?
+            with testing.db.connect() as conn:
+                tran = conn.begin()
+                conn.execute(
+                    (
+                        "CREATE TABLE %s "
+                        "(id int primary key, txt nvarchar(50), dt2 datetime2)"  # noqa
+                    )
+                    % table_name
+                )
+                conn.execute(
+                    (
+                        "INSERT INTO %s (id, txt, dt2) VALUES "
+                        "(1, N'foo', '2020-01-01 01:01:01'), "
+                        "(2, N'bar', '2020-02-02 02:02:02') "
+                    )
+                    % table_name
+                )
+                tran.commit()
+                tran = conn.begin()
+                try:
+                    tmp_t = Table(
+                        table_name, metadata, autoload_with=testing.db,
+                    )
+                    tran.commit()
+                    result = conn.execute(
+                        tmp_t.select().where(tmp_t.c.id == 2)
+                    ).fetchall()
+                    eq_(
+                        result,
+                        [(2, "bar", datetime.datetime(2020, 2, 2, 2, 2, 2))],
+                    )
+                except exc.NoSuchTableError:
+                    if exists:
+                        raise
+
     @testing.provide_metadata
     def test_db_qualified_items(self):
         metadata = self.metadata
index 1c2561bbbe495b7f401224b176b1229ce60f5aa9..145d87d753c4342cd973eadf0bec85adde666e73 100644 (file)
@@ -257,14 +257,18 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def temporary_tables(self):
         """target database supports temporary tables"""
-        return skip_if(
-            ["mssql", "firebird", self._sqlite_file_db], "not supported (?)"
-        )
+        return skip_if(["firebird", self._sqlite_file_db], "not supported (?)")
 
     @property
     def temp_table_reflection(self):
         return self.temporary_tables
 
+    @property
+    def temp_table_reflect_indexes(self):
+        return skip_if(
+            ["mssql", "firebird", self._sqlite_file_db], "not supported (?)"
+        )
+
     @property
     def reflectable_autoincrement(self):
         """Target database must support tables that can automatically generate