From: Gaurav Sharma Date: Wed, 3 Jun 2026 12:52:09 +0000 (-0400) Subject: Implement native multi-table reflection API for the mssql dialect X-Git-Url: http://git.ipfire.org/gitweb/index.cgi?a=commitdiff_plain;h=c84c7b2ffccecb5249fcc8ec01c400ca2867dbe5;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Implement native multi-table reflection API for the mssql dialect ### Description Adds 5 native `get_multi_*` reflection methods (columns, pk, fk, indexes, table_comment) for the MSSQL dialect, replacing the per-table loop in `_default_multi_reflect`. Single-table methods now delegate to the multi versions (PG/Oracle pattern); legacy per-table SQL is retained as `_internal_get_*` helpers, used only for tempdb reflection. Not implemented here: `get_multi_unique_constraints`, `get_multi_check_constraints`, `get_multi_table_options` -MSSQL has no single-table counterparts to delegate from. Happy to add as a follow-up. ### Performance Measured with `test/perf/many_table_reflection.py` against SQL Server 2022 (Docker, localhost, pyodbc + ODBC Driver 18) on a 250-table fixture, 15-50 cols, with PKs/FKs/indexes/comments: | | single | multi | speedup | |---|---|---|---| | `get_columns` | 1.33s | 0.35s | 3.8x | | `get_pk_constraint` | 1.41s | 0.08s | 18x | | `get_foreign_keys` | 7.07s | 0.19s | 37x | | `get_indexes` | 0.80s | 0.09s | 8.6x | | `get_table_comment` | 0.71s | 0.05s | 14x | | **MetaData.reflect** | **12.62s** | **1.15s** | **11x** | ### Checklist This pull request is: - [x] A new feature implementation - Fixes: #8430 - Tests added in `test/dialect/mssql/test_reflection.py` - Changelog entry: `doc/build/changelog/unreleased_21/8430.rst` Closes: #13297 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13297 Pull-request-sha: 2c6c69f159225f85312e39d99222279ef8cbadd1 Change-Id: I525c60fc5ece94dd250f376b05b64b09e65ca0d7 --- diff --git a/doc/build/changelog/unreleased_21/8430.rst b/doc/build/changelog/unreleased_21/8430.rst new file mode 100644 index 
0000000000..9467c8a92a --- /dev/null +++ b/doc/build/changelog/unreleased_21/8430.rst @@ -0,0 +1,17 @@ +.. change:: + :tags: performance, mssql, reflection + :tickets: 8430 + + Implemented native multi-table reflection methods for the SQL Server + dialect, providing :meth:`.MSDialect.get_multi_columns`, + :meth:`.MSDialect.get_multi_pk_constraint`, + :meth:`.MSDialect.get_multi_foreign_keys`, + :meth:`.MSDialect.get_multi_indexes` and + :meth:`.MSDialect.get_multi_table_comment`. Previously the SQL Server + dialect relied on the default dialect implementation + which calls the per-table methods in a loop; the new implementations + issue a single bulk query per object type against the ``sys.*`` + catalog views, avoiding the per-table round trips. The single-table + reflection methods are now thin wrappers over the multi-table ones, + matching the pattern used by the PostgreSQL and Oracle dialects. + Pull request courtesy Gaurav Sharma. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index e6c7ceda70..83ab7a78fd 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -961,6 +961,7 @@ from __future__ import annotations import codecs import datetime +from functools import lru_cache import operator import re from typing import Any @@ -982,13 +983,14 @@ from ... import text from ... 
import util from ...engine import cursor as _cursor from ...engine import default +from ...engine import ObjectKind +from ...engine import ObjectScope from ...engine import reflection from ...engine.reflection import ReflectionDefaults from ...sql import coercions from ...sql import compiler from ...sql import elements from ...sql import expression -from ...sql import func from ...sql import quoted_name from ...sql import roles from ...sql import sqltypes @@ -2967,6 +2969,24 @@ def _db_plus_owner(fn): return update_wrapper(wrap, fn) +def _db_plus_owner_multi(fn): + def wrap(dialect, connection, schema=None, **kw): + dbname, owner = _owner_plus_db(dialect, schema) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + dbname, + owner, + schema, + **kw, + ) + + return update_wrapper(wrap, fn) + + def _switch_db(dbname, connection, fn, *arg, **kw): if dbname: current_db = connection.exec_driver_sql("select db_name()").scalar() @@ -3464,126 +3484,6 @@ class MSDialect(default.DefaultDialect): return c.first() is not None - def _default_or_error(self, connection, tablename, owner, method, **kw): - # TODO: try to avoid having to run a separate query here - if self._internal_has_table(connection, tablename, owner, **kw): - return method() - else: - raise exc.NoSuchTableError(f"{owner}.{tablename}") - - @reflection.cache - @_db_plus_owner - def get_indexes(self, connection, tablename, dbname, owner, schema, **kw): - filter_definition = ( - "ind.filter_definition" - if self.server_version_info >= MS_2008_VERSION - else "NULL as filter_definition" - ) - rp = connection.execution_options(future_result=True).execute( - sql.text(f""" -select - ind.index_id, - ind.is_unique, - ind.name, - ind.type, - {filter_definition} -from - sys.indexes as ind -join sys.tables as tab on - ind.object_id = tab.object_id -join sys.schemas as sch on - sch.schema_id = tab.schema_id -where - tab.name = :tabname - and sch.name = :schname - and ind.is_primary_key = 0 - and 
ind.type != 0 -order by - ind.name - """) - .bindparams( - sql.bindparam("tabname", tablename, ischema.CoerceUnicode()), - sql.bindparam("schname", owner, ischema.CoerceUnicode()), - ) - .columns(name=sqltypes.Unicode()) - ) - indexes = {} - for row in rp.mappings(): - indexes[row["index_id"]] = current = { - "name": row["name"], - "unique": row["is_unique"] == 1, - "column_names": [], - "include_columns": [], - "dialect_options": {}, - } - - do = current["dialect_options"] - index_type = row["type"] - if index_type in {1, 2}: - do["mssql_clustered"] = index_type == 1 - if index_type in {5, 6}: - do["mssql_clustered"] = index_type == 5 - do["mssql_columnstore"] = True - if row["filter_definition"] is not None: - do["mssql_where"] = row["filter_definition"] - - rp = connection.execution_options(future_result=True).execute( - sql.text(""" -select - ind_col.index_id, - col.name, - ind_col.is_included_column -from - sys.columns as col -join sys.tables as tab on - tab.object_id = col.object_id -join sys.index_columns as ind_col on - ind_col.column_id = col.column_id - and ind_col.object_id = tab.object_id -join sys.schemas as sch on - sch.schema_id = tab.schema_id -where - tab.name = :tabname - and sch.name = :schname -order by - ind_col.index_id, - ind_col.key_ordinal - """) - .bindparams( - sql.bindparam("tabname", tablename, ischema.CoerceUnicode()), - sql.bindparam("schname", owner, ischema.CoerceUnicode()), - ) - .columns(name=sqltypes.Unicode()) - ) - for row in rp.mappings(): - if row["index_id"] not in indexes: - continue - index_def = indexes[row["index_id"]] - is_colstore = index_def["dialect_options"].get("mssql_columnstore") - is_clustered = index_def["dialect_options"].get("mssql_clustered") - if not (is_colstore and is_clustered): - # a clustered columnstore index includes all columns but does - # not want them in the index definition - if row["is_included_column"] and not is_colstore: - # a noncludsted columnstore index reports that includes - # columns 
but requires that are listed as normal columns - index_def["include_columns"].append(row["name"]) - else: - index_def["column_names"].append(row["name"]) - for index_info in indexes.values(): - # NOTE: "root level" include_columns is legacy, now part of - # dialect_options (issue #7382) - index_info["dialect_options"]["mssql_include"] = index_info[ - "include_columns" - ] - - if indexes: - return list(indexes.values()) - else: - return self._default_or_error( - connection, tablename, owner, ReflectionDefaults.indexes, **kw - ) - @reflection.cache @_db_plus_owner def get_view_definition( @@ -3606,38 +3506,6 @@ order by else: raise exc.NoSuchTableError(f"{owner}.{viewname}") - @reflection.cache - def get_table_comment(self, connection, table_name, schema=None, **kw): - if not self.supports_comments: - raise NotImplementedError( - "Can't get table comments on current SQL Server version in use" - ) - - schema_name = schema if schema else self.default_schema_name - COMMENT_SQL = """ - SELECT cast(com.value as nvarchar(max)) - FROM fn_listextendedproperty('MS_Description', - 'schema', :schema, 'table', :table, NULL, NULL - ) as com; - """ - - comment = connection.execute( - sql.text(COMMENT_SQL).bindparams( - sql.bindparam("schema", schema_name, ischema.CoerceUnicode()), - sql.bindparam("table", table_name, ischema.CoerceUnicode()), - ) - ).scalar() - if comment: - return {"text": comment} - else: - return self._default_or_error( - connection, - table_name, - None, - ReflectionDefaults.table_comment, - **kw, - ) - def _temp_table_name_like_pattern(self, tablename): # LIKE uses '%' to match zero or more characters and '_' to match any # single character. 
We want to match literal underscores, so T-SQL @@ -3673,9 +3541,110 @@ order by % tablename ) from ne - @reflection.cache - @_db_plus_owner - def get_columns(self, connection, tablename, dbname, owner, schema, **kw): + def _parse_column_info( + self, + name, + type_, + nullable, + maxlen, + numericprec, + numericscale, + default, + collation, + definition, + is_persisted, + is_identity, + identity_start, + identity_increment, + comment, + base_type=None, + ): + # Try to resolve the user type first (e.g., "sysname"), + # then fall back to the base type (e.g., "nvarchar"). + # base_type may be None for CLR types (geography, geometry, + # hierarchyid) which have no corresponding base type. + coltype = self.ischema_names.get(type_, None) + if coltype is None and base_type is not None and base_type != type_: + coltype = self.ischema_names.get(base_type, None) + kwargs = {} + + if coltype in (MSBinary, MSVarBinary, sqltypes.LargeBinary): + kwargs["length"] = maxlen if maxlen != -1 else None + elif coltype in (MSString, MSChar, MSText): + kwargs["length"] = maxlen if maxlen != -1 else None + if collation: + kwargs["collation"] = collation + elif coltype in (MSNVarchar, MSNChar, MSNText): + kwargs["length"] = maxlen // 2 if maxlen != -1 else None + if collation: + kwargs["collation"] = collation + + if coltype is None: + if base_type is not None and base_type != type_: + util.warn( + "Did not recognize type '%s' (user type) or '%s' " + "(base type) of column '%s'" % (type_, base_type, name) + ) + else: + util.warn( + "Did not recognize type '%s' of column '%s'" + % (type_, name) + ) + coltype = sqltypes.NULLTYPE + else: + if issubclass(coltype, sqltypes.NumericCommon): + kwargs["precision"] = numericprec + if not issubclass(coltype, sqltypes.Float): + kwargs["scale"] = numericscale + coltype = coltype(**kwargs) + + cdict = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": is_identity is not None, + "comment": comment, + } 
+ + if definition is not None and is_persisted is not None: + cdict["computed"] = { + "sqltext": definition, + "persisted": is_persisted, + } + + if is_identity is not None: + if identity_start is None or identity_increment is None: + cdict["identity"] = {} + else: + if isinstance( + coltype, (sqltypes.BigInteger, sqltypes.Integer) + ): + start = int(identity_start) + increment = int(identity_increment) + else: + start = identity_start + increment = identity_increment + cdict["identity"] = { + "start": start, + "increment": increment, + } + + return cdict + + @lru_cache() + def _columns_select(self): + """Build the unified Core sys.* select for column reflection. + + Returns a ``select()`` that includes ``table_name`` and ``owner`` + in the result so callers can group rows by table. The caller + applies the appropriate WHERE clause for either the single-table + case (``sys_columns.c.object_id == X``) or the multi-table case + (``sys_schemas.c.name == owner AND sys_objects.c.name.in_(names)``). + + Used by :meth:`.get_multi_columns` (and indirectly by + :meth:`.get_columns` via the multi delegation). + """ sys_columns = ischema.sys_columns sys_types = ischema.sys_types sys_base_types = ischema.sys_types.alias("base_types") @@ -3683,30 +3652,8 @@ order by computed_cols = ischema.computed_columns identity_cols = ischema.identity_columns extended_properties = ischema.extended_properties - - # to access sys tables, need an object_id. - # object_id() can normally match to the unquoted name even if it - # has special characters. however it also accepts quoted names, - # which means for the special case that the name itself has - # "quotes" (e.g. brackets for SQL Server) we need to "quote" (e.g. - # bracket) that name anyway. 
Fixed as part of #12654 - - is_temp_table = tablename.startswith("#") - if is_temp_table: - owner, tablename = self._get_internal_temp_table_name( - connection, tablename - ) - - object_id_tokens = [self.identifier_preparer.quote(tablename)] - if owner: - object_id_tokens.insert(0, self.identifier_preparer.quote(owner)) - - if is_temp_table: - object_id_tokens.insert(0, "tempdb") - - object_id = func.object_id(".".join(object_id_tokens)) - - whereclause = sys_columns.c.object_id == object_id + sys_objects = ischema.sys_objects + sys_schemas = ischema.sys_schemas if self._supports_nvarchar_max: computed_definition = computed_cols.c.definition @@ -3718,23 +3665,34 @@ order by s = ( sql.select( - sys_columns.c.name, - sys_types.c.name, + sys_objects.c.name.label("table_name"), + sys_schemas.c.name.label("owner"), + sys_columns.c.name.label("column_name"), + sys_types.c.name.label("type_name"), sys_base_types.c.name.label("base_type"), sys_columns.c.is_nullable, sys_columns.c.max_length, sys_columns.c.precision, sys_columns.c.scale, - sys_default_constraints.c.definition, + sys_default_constraints.c.definition.label("default_value"), sys_columns.c.collation_name, - computed_definition, + computed_definition.label("computed_definition"), computed_cols.c.is_persisted, identity_cols.c.is_identity, identity_cols.c.seed_value, identity_cols.c.increment_value, extended_properties.c.value.label("comment"), + sys_columns.c.column_id, ) .select_from(sys_columns) + .join( + sys_objects, + onclause=sys_columns.c.object_id == sys_objects.c.object_id, + ) + .join( + sys_schemas, + onclause=sys_objects.c.schema_id == sys_schemas.c.schema_id, + ) .join( sys_types, onclause=sys_columns.c.user_type_id @@ -3781,200 +3739,13 @@ order by sys_columns.c.column_id == extended_properties.c.minor_id, ), ) - .where(whereclause) - .order_by(sys_columns.c.column_id) + .order_by(sys_objects.c.name, sys_columns.c.column_id) ) + return s - if is_temp_table: - exec_opts = {"schema_translate_map": 
{"sys": "tempdb.sys"}} - else: - exec_opts = {"schema_translate_map": {}} - c = connection.execution_options(**exec_opts).execute(s) - - cols = [] - for row in c.mappings(): - name = row[sys_columns.c.name] - type_ = row[sys_types.c.name] - base_type = row["base_type"] - nullable = row[sys_columns.c.is_nullable] == 1 - maxlen = row[sys_columns.c.max_length] - numericprec = row[sys_columns.c.precision] - numericscale = row[sys_columns.c.scale] - default = row[sys_default_constraints.c.definition] - collation = row[sys_columns.c.collation_name] - definition = row[computed_definition] - is_persisted = row[computed_cols.c.is_persisted] - is_identity = row[identity_cols.c.is_identity] - identity_start = row[identity_cols.c.seed_value] - identity_increment = row[identity_cols.c.increment_value] - comment = row[extended_properties.c.value] - - # Try to resolve the user type first (e.g., "sysname"), - # then fall back to the base type (e.g., "nvarchar"). - # base_type may be None for CLR types (geography, geometry, - # hierarchyid) which have no corresponding base type. 
- coltype = self.ischema_names.get(type_, None) - if ( - coltype is None - and base_type is not None - and base_type != type_ - ): - coltype = self.ischema_names.get(base_type, None) - - kwargs = {} - - if coltype in ( - MSBinary, - MSVarBinary, - sqltypes.LargeBinary, - ): - kwargs["length"] = maxlen if maxlen != -1 else None - elif coltype in ( - MSString, - MSChar, - MSText, - ): - kwargs["length"] = maxlen if maxlen != -1 else None - if collation: - kwargs["collation"] = collation - elif coltype in ( - MSNVarchar, - MSNChar, - MSNText, - ): - kwargs["length"] = maxlen // 2 if maxlen != -1 else None - if collation: - kwargs["collation"] = collation - - if coltype is None: - if base_type is not None and base_type != type_: - util.warn( - "Did not recognize type '%s' (user type) or '%s' " - "(base type) of column '%s'" % (type_, base_type, name) - ) - else: - util.warn( - "Did not recognize type '%s' of column '%s'" - % (type_, name) - ) - coltype = sqltypes.NULLTYPE - else: - if issubclass(coltype, sqltypes.NumericCommon): - kwargs["precision"] = numericprec - - if not issubclass(coltype, sqltypes.Float): - kwargs["scale"] = numericscale - - coltype = coltype(**kwargs) - cdict = { - "name": name, - "type": coltype, - "nullable": nullable, - "default": default, - "autoincrement": is_identity is not None, - "comment": comment, - } - - if definition is not None and is_persisted is not None: - cdict["computed"] = { - "sqltext": definition, - "persisted": is_persisted, - } - - if is_identity is not None: - # identity_start and identity_increment are Decimal or None - if identity_start is None or identity_increment is None: - cdict["identity"] = {} - else: - if isinstance(coltype, sqltypes.BigInteger): - start = int(identity_start) - increment = int(identity_increment) - elif isinstance(coltype, sqltypes.Integer): - start = int(identity_start) - increment = int(identity_increment) - else: - start = identity_start - increment = identity_increment - - cdict["identity"] = 
{ - "start": start, - "increment": increment, - } - - cols.append(cdict) - - if cols: - return cols - else: - return self._default_or_error( - connection, tablename, owner, ReflectionDefaults.columns, **kw - ) - - @reflection.cache - @_db_plus_owner - def get_pk_constraint( - self, connection, tablename, dbname, owner, schema, **kw - ): - pkeys = [] - TC = ischema.constraints - C = ischema.key_constraints.alias("C") - - # Primary key constraints - s = ( - sql.select( - C.c.column_name, - TC.c.constraint_type, - C.c.constraint_name, - func.objectproperty( - func.object_id( - C.c.table_schema + "." + C.c.constraint_name - ), - "CnstIsClustKey", - ).label("is_clustered"), - ) - .where( - sql.and_( - TC.c.constraint_name == C.c.constraint_name, - TC.c.table_schema == C.c.table_schema, - C.c.table_name == tablename, - C.c.table_schema == owner, - ), - ) - .order_by(TC.c.constraint_name, C.c.ordinal_position) - ) - c = connection.execution_options(future_result=True).execute(s) - constraint_name = None - is_clustered = None - for row in c.mappings(): - if "PRIMARY" in row[TC.c.constraint_type.name]: - pkeys.append(row["COLUMN_NAME"]) - if constraint_name is None: - constraint_name = row[C.c.constraint_name.name] - if is_clustered is None: - is_clustered = row["is_clustered"] - if pkeys: - return { - "constrained_columns": pkeys, - "name": constraint_name, - "dialect_options": {"mssql_clustered": is_clustered}, - } - else: - return self._default_or_error( - connection, - tablename, - owner, - ReflectionDefaults.pk_constraint, - **kw, - ) - - @reflection.cache - @_db_plus_owner - def get_foreign_keys( - self, connection, tablename, dbname, owner, schema, **kw - ): - # Foreign key constraints - s = ( - text("""\ + @staticmethod + def _fk_query_sql(fk_info_where, extra_cols=""): + return """\ WITH fk_info AS ( SELECT ischema_ref_con.constraint_schema, @@ -3992,11 +3763,11 @@ WITH fk_info AS ( INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS ischema_ref_con INNER JOIN 
INFORMATION_SCHEMA.KEY_COLUMN_USAGE ischema_key_col ON - ischema_key_col.table_schema = ischema_ref_con.constraint_schema + ischema_key_col.table_schema = + ischema_ref_con.constraint_schema AND ischema_key_col.constraint_name = - ischema_ref_con.constraint_name - WHERE ischema_key_col.table_name = :tablename - AND ischema_key_col.table_schema = :owner + ischema_ref_con.constraint_name + WHERE %(fk_info_where)s ), constraint_info AS ( SELECT @@ -4019,22 +3790,19 @@ index_info AS ( sys.columns.name AS column_name FROM sys.indexes - INNER JOIN - sys.objects ON - sys.objects.object_id = sys.indexes.object_id - INNER JOIN - sys.schemas ON - sys.schemas.schema_id = sys.objects.schema_id - INNER JOIN - sys.index_columns ON - sys.index_columns.object_id = sys.objects.object_id + INNER JOIN sys.objects + ON sys.objects.object_id = sys.indexes.object_id + INNER JOIN sys.schemas + ON sys.schemas.schema_id = sys.objects.schema_id + INNER JOIN sys.index_columns + ON sys.index_columns.object_id = sys.objects.object_id AND sys.index_columns.index_id = sys.indexes.index_id - INNER JOIN - sys.columns ON - sys.columns.object_id = sys.indexes.object_id + INNER JOIN sys.columns + ON sys.columns.object_id = sys.indexes.object_id AND sys.columns.column_id = sys.index_columns.column_id ) SELECT + %(extra_cols)s fk_info.constraint_schema, fk_info.constraint_name, fk_info.ordinal_position, @@ -4054,6 +3822,7 @@ index_info AS ( AND constraint_info.ordinal_position = fk_info.ordinal_position UNION SELECT + %(extra_cols)s fk_info.constraint_schema, fk_info.constraint_name, fk_info.ordinal_position, @@ -4074,15 +3843,533 @@ index_info AS ( ORDER BY fk_info.constraint_schema, fk_info.constraint_name, fk_info.ordinal_position -""") +""" % { + "fk_info_where": fk_info_where, + "extra_cols": extra_cols, + } + + # --- multi-reflection API --- + + def _partition_filter_names( + self, connection, owner, filter_names, scope, kind + ): + """Split caller-supplied ``filter_names`` between the bulk + 
``sys.*`` query path (regular objects in the user database) and + the temp object path (objects living in ``tempdb``), and resolve + the existing-in-catalog names for the bulk path in one shot. + + Returns the tuple + ``(run_bulk, regular_names, multi_object_names, temp_names)``: + + * ``run_bulk`` -- whether the multi method should run its bulk + ``sys.*`` query at all. + * ``regular_names`` -- list of caller-supplied names (with the + caller's casing) for the bulk path, used to build the + case-insensitive name map. Empty list when no regular names + were requested. + * ``multi_object_names`` -- list of names (with server-side + casing) that actually exist in the catalog and match the + requested ``kind``. This is what the bulk SQL's IN clause + gets filtered with. Empty list when no matches. + * ``temp_names`` -- list of ``#``-prefixed temp object names + that the caller asked for and that the requested + ``scope``/``kind`` allows. + + ``scope=ObjectScope.DEFAULT`` excludes temp objects entirely. + ``scope=ObjectScope.TEMPORARY`` excludes regular objects. + ``ObjectKind.VIEW`` (without TABLE) excludes temp because SQL + Server does not have temp views in the same form. SQL Server + also has no ``get_temp_table_names`` implementation, so temp + objects can only be reflected when the caller supplies their + names explicitly. 
+ """ + include_default = scope is not ObjectScope.TEMPORARY + include_temp = ( + scope is not ObjectScope.DEFAULT and ObjectKind.TABLE in kind + ) + + if filter_names: + if include_temp: + temp_names = [n for n in filter_names if n.startswith("#")] + else: + temp_names = [] + + if include_default: + regular_names = [ + n for n in filter_names if not n.startswith("#") + ] + run_bulk = bool(regular_names) + else: + regular_names = [] + run_bulk = False + else: + temp_names = [] + regular_names = [] + run_bulk = include_default + + if not run_bulk: + return (run_bulk, regular_names, [], temp_names) + + type_filter = [] + if ObjectKind.TABLE in kind: + type_filter.append("'U'") + if ObjectKind.VIEW in kind: + type_filter.append("'V'") + # SQL Server does not support materialized views, so ignore them + if not type_filter: + return (run_bulk, regular_names, [], temp_names) + + query = ( + "SELECT o.name " + "FROM sys.objects o " + "JOIN sys.schemas s ON o.schema_id = s.schema_id " + "WHERE s.name = :owner " + "AND o.type IN (%s)" % ", ".join(type_filter) + ) + params = [sql.bindparam("owner", owner, ischema.CoerceUnicode())] + if regular_names: + query += " AND o.name IN :filter_names" + params.append( + sql.bindparam("filter_names", regular_names, expanding=True) + ) + + rp = connection.execute(sql.text(query).bindparams(*params)) + multi_object_names = [row[0] for row in rp] + + return (run_bulk, regular_names, multi_object_names, temp_names) + + @staticmethod + def _multi_name_map(filter_names): + """Build a case-insensitive ``server_name -> user_name`` mapper. + + MSSQL object names are case-insensitive under the default + collation, so a user may pass ``"sOmEtAbLe"`` for a table + physically stored as ``"SomeTable"``. The bulk SQL returns the + server-side casing, but result keys must use the user-supplied + casing so that downstream code (e.g. + :meth:`.Inspector.reflect_table`) can find the entry. + + Returns a callable ``(server_name) -> user_name``. 
When + ``filter_names`` is empty/None, returns identity. + """ + if not filter_names: + return lambda n: n + lookup = {n.lower(): n for n in filter_names} + return lambda n: lookup.get(n.lower(), n) + + @staticmethod + def _value_or_raise(data, table, schema): + """Unwrap a single ``(schema, table)`` entry from a multi-method + result, raising :exc:`.NoSuchTableError` when missing. + + Mirrors PostgreSQL's helper of the same name. Used by the + single-table reflection wrappers that delegate to the multi + implementation. + """ + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + @_db_plus_owner_multi + def get_multi_columns( + self, + connection, + dbname, + owner, + schema, + filter_names, + scope, + kind, + **kw, + ): + ( + run_bulk, + regular_names, + multi_object_names, + temp_names, + ) = self._partition_filter_names( + connection, owner, filter_names, scope, kind + ) + + result = {} + + if run_bulk and multi_object_names: + name_map = self._multi_name_map(regular_names) + self._fetch_multi_columns( + connection, + owner=owner, + names=multi_object_names, + schema=schema, + name_map=name_map, + result=result, + exec_opts={"schema_translate_map": {}}, + ) + + if temp_names: + self._fetch_multi_columns_temp( + connection, + temp_names=temp_names, + schema=schema, + result=result, + ) + + return result.items() + + def _fetch_multi_columns( + self, connection, owner, names, schema, name_map, result, exec_opts + ): + """Execute the unified columns select for a single catalog pass. + + Used by :meth:`.get_multi_columns` for both the main-database pass + (no ``schema_translate_map``) and the tempdb pass (with + ``schema_translate_map={"sys": "tempdb.sys"}`` applied). 
+ """ + s = self._columns_select().where( + ischema.sys_schemas.c.name == owner, + ischema.sys_objects.c.name.in_( + sql.bindparam("filter_names", expanding=True) + ), + ) + + rp = connection.execute( + s, + {"filter_names": names}, + execution_options=exec_opts, + ) + + for row in rp.mappings(): + table_name = name_map(row["table_name"]) + cdict = self._parse_column_info( + name=row["column_name"], + type_=row["type_name"], + base_type=row["base_type"], + nullable=row["is_nullable"] == 1, + maxlen=row["max_length"], + numericprec=row["precision"], + numericscale=row["scale"], + default=row["default_value"], + collation=row["collation_name"], + definition=row["computed_definition"], + is_persisted=row["is_persisted"], + is_identity=row["is_identity"], + identity_start=row["seed_value"], + identity_increment=row["increment_value"], + comment=row["comment"], + ) + result.setdefault((schema, table_name), []).append(cdict) + + for n in names: + key = (schema, name_map(n)) + if key not in result: + result[key] = ReflectionDefaults.columns() + + def _fetch_multi_columns_temp( + self, connection, temp_names, schema, result + ): + """Run the unified columns select against tempdb for temp names. + + Resolves each ``#name`` to its mangled tempdb name (one round + trip per temp via :meth:`._get_internal_temp_table_name`), then + runs a single bulk select against ``tempdb.sys.*`` via + ``schema_translate_map``. The result rows (keyed by mangled + name) are mapped back to the original ``#name`` for the caller. 
+ """ + resolved_by_owner, original_by_mangled = self._resolve_temp_names( + connection, temp_names + ) + + for temp_owner, mangled_names in resolved_by_owner.items(): + + def temp_name_map(server_name, _by=original_by_mangled): + return _by.get(server_name, server_name) + + self._fetch_multi_columns( + connection, + owner=temp_owner, + names=mangled_names, + schema=schema, + name_map=temp_name_map, + result=result, + exec_opts={"schema_translate_map": {"sys": "tempdb.sys"}}, + ) + + def _resolve_temp_names(self, connection, temp_names): + """Resolve user-facing temp names (``#foo``) to their mangled + tempdb names, grouped by owner. + + Returns ``(resolved_by_owner, original_by_mangled)``: + + * ``resolved_by_owner`` -- ``{owner: [mangled_name, ...]}``, + one bulk query per owner (typically just ``'dbo'``). + * ``original_by_mangled`` -- ``{mangled_name: original_#name}`` + so caller can map result rows back to the user-supplied name. + + Names that don't resolve (no matching temp table in tempdb) are + silently dropped. The single-table reflection wrappers handle + the resulting absence by raising :exc:`.NoSuchTableError` via + :meth:`._value_or_raise`. 
+ """ + resolved_by_owner = {} + original_by_mangled = {} + for name in temp_names: + try: + temp_owner, mangled = self._get_internal_temp_table_name( + connection, name + ) + except exc.NoSuchTableError: + continue + resolved_by_owner.setdefault(temp_owner, []).append(mangled) + original_by_mangled[mangled] = name + return resolved_by_owner, original_by_mangled + + @_db_plus_owner_multi + def get_multi_pk_constraint( + self, + connection, + dbname, + owner, + schema, + filter_names, + scope, + kind, + **kw, + ): + ( + run_bulk, + regular_names, + multi_object_names, + temp_names, + ) = self._partition_filter_names( + connection, owner, filter_names, scope, kind + ) + + result = {} + + if run_bulk and multi_object_names: + name_map = self._multi_name_map(regular_names) + self._fetch_multi_pk_constraint( + connection, + owner=owner, + names=multi_object_names, + schema=schema, + name_map=name_map, + result=result, + exec_opts={"schema_translate_map": {}}, + ) + + if temp_names: + self._fetch_multi_pk_constraint_temp( + connection, + temp_names=temp_names, + schema=schema, + result=result, + ) + + return result.items() + + @lru_cache() + def _pk_constraint_select(self): + """Build the unified Core sys.* select for PK constraint reflection. + + Returns a ``select()`` that includes ``table_name`` and ``owner`` + in the result so callers can group rows by table. The clustered + flag comes directly from ``sys.indexes.type``(1 = clustered). 
+ """ + sys_key_constraints = ischema.sys_key_constraints + sys_indexes = ischema.sys_indexes + sys_index_columns = ischema.sys_index_columns + sys_columns = ischema.sys_columns + sys_objects = ischema.sys_objects + sys_schemas = ischema.sys_schemas + + s = ( + sql.select( + sys_objects.c.name.label("table_name"), + sys_schemas.c.name.label("owner"), + sys_key_constraints.c.name.label("constraint_name"), + sys_columns.c.name.label("column_name"), + sys_index_columns.c.key_ordinal, + sys_indexes.c.type.label("index_type"), + ) + .select_from(sys_key_constraints) + .join( + sys_objects, + onclause=sys_key_constraints.c.parent_object_id + == sys_objects.c.object_id, + ) + .join( + sys_schemas, + onclause=sys_objects.c.schema_id == sys_schemas.c.schema_id, + ) + .join( + sys_index_columns, + onclause=sql.and_( + sys_index_columns.c.object_id + == sys_key_constraints.c.parent_object_id, + sys_index_columns.c.index_id + == sys_key_constraints.c.unique_index_id, + ), + ) + .join( + sys_columns, + onclause=sql.and_( + sys_columns.c.object_id == sys_index_columns.c.object_id, + sys_columns.c.column_id == sys_index_columns.c.column_id, + ), + ) + .join( + sys_indexes, + onclause=sql.and_( + sys_indexes.c.object_id + == sys_key_constraints.c.parent_object_id, + sys_indexes.c.index_id + == sys_key_constraints.c.unique_index_id, + ), + ) + .where(sys_key_constraints.c.type == "PK") + .order_by(sys_objects.c.name, sys_index_columns.c.key_ordinal) + ) + return s + + def _fetch_multi_pk_constraint( + self, connection, owner, names, schema, name_map, result, exec_opts + ): + """Execute the unified pk_constraint select for one catalog pass. + + Used by :meth:`.get_multi_pk_constraint` for both the main DB + pass and the tempdb pass (with ``schema_translate_map`` set). 
+ """ + s = self._pk_constraint_select().where( + ischema.sys_schemas.c.name == owner, + ischema.sys_objects.c.name.in_( + sql.bindparam("filter_names", expanding=True) + ), + ) + + rp = connection.execute( + s, + {"filter_names": names}, + execution_options=exec_opts, + ) + + for row in rp.mappings(): + table_name = name_map(row["table_name"]) + key = (schema, table_name) + if key not in result: + result[key] = { + "constrained_columns": [], + "name": row["constraint_name"], + "dialect_options": { + "mssql_clustered": row["index_type"] == 1 + }, + } + result[key]["constrained_columns"].append(row["column_name"]) + + for n in names: + key = (schema, name_map(n)) + if key not in result: + result[key] = ReflectionDefaults.pk_constraint() + + def _fetch_multi_pk_constraint_temp( + self, connection, temp_names, schema, result + ): + """Run the unified pk_constraint select against tempdb.""" + resolved_by_owner, original_by_mangled = self._resolve_temp_names( + connection, temp_names + ) + + for temp_owner, mangled_names in resolved_by_owner.items(): + + def temp_name_map(server_name, _by=original_by_mangled): + return _by.get(server_name, server_name) + + self._fetch_multi_pk_constraint( + connection, + owner=temp_owner, + names=mangled_names, + schema=schema, + name_map=temp_name_map, + result=result, + exec_opts={"schema_translate_map": {"sys": "tempdb.sys"}}, + ) + + @_db_plus_owner_multi + def get_multi_foreign_keys( + self, + connection, + dbname, + owner, + schema, + filter_names, + scope, + kind, + **kw, + ): + ( + run_bulk, + regular_names, + multi_object_names, + temp_names, + ) = self._partition_filter_names( + connection, owner, filter_names, scope, kind + ) + + final = {} + + if run_bulk and multi_object_names: + name_map = self._multi_name_map(regular_names) + self._fetch_multi_foreign_keys( + connection, + dbname=dbname, + owner=owner, + names=multi_object_names, + schema=schema, + name_map=name_map, + final=final, + ) + + # FK reflection for temp tables: 
the INFORMATION_SCHEMA query in + # _fk_query_sql does not see tempdb objects (it queries the main + # DB's catalog). The pre-existing single-table behavior was to + # return empty foreign_keys for temp tables that exist. We + # preserve that here by checking existence and returning the + # default (empty) reflection for each temp that is reachable. + for name in temp_names: + if self._internal_has_table(connection, name, owner="dbo"): + final[(schema, name)] = ReflectionDefaults.foreign_keys() + + return final.items() + + def _fetch_multi_foreign_keys( + self, connection, dbname, owner, names, schema, name_map, final + ): + """Execute the unified FK query for the main-DB pass. + + Uses the _fk_query_sql sql text, restricted to the provided owner and + IN-list of names. Groups result rows by table and constraint + name into the final dict. + """ + rp = connection.execute( + sql.text( + self._fk_query_sql( + fk_info_where=( + "ischema_key_col.table_schema = :owner" + "\n AND ischema_key_col.table_name" + " IN :filter_names" + ), + extra_cols="fk_info.table_name,", + ) + ) .bindparams( - sql.bindparam("tablename", tablename, ischema.CoerceUnicode()), sql.bindparam("owner", owner, ischema.CoerceUnicode()), + sql.bindparam("filter_names", names, expanding=True), ) .columns( constraint_schema=sqltypes.Unicode(), constraint_name=sqltypes.Unicode(), - table_schema=sqltypes.Unicode(), table_name=sqltypes.Unicode(), constrained_column=sqltypes.Unicode(), referred_table_schema=sqltypes.Unicode(), @@ -4091,41 +4378,40 @@ index_info AS ( ) ) - # group rows by constraint ID, to handle multi-column FKs - fkeys = util.defaultdict( - lambda: { - "name": None, - "constrained_columns": [], - "referred_schema": None, - "referred_table": None, - "referred_columns": [], - "options": {}, - } - ) - - for r in connection.execute(s).all(): + grouped = {} + for r in rp.all(): ( - _, # constraint schema + table_name, + _, # constraint_schema rfknm, - _, # ordinal position + _, # ordinal 
scol, rschema, rtbl, rcol, - # TODO: we support match= for foreign keys so - # we can support this also, PG has match=FULL for example - # but this seems to not be a valid value for SQL Server - _, # match rule + _, # match fkuprule, fkdelrule, ) = r - rec = fkeys[rfknm] + key = (schema, name_map(table_name)) + if key not in grouped: + grouped[key] = util.defaultdict( + lambda: { + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + "options": {}, + } + ) + + rec = grouped[key][rfknm] rec["name"] = rfknm if fkuprule != "NO ACTION": rec["options"]["onupdate"] = fkuprule - if fkdelrule != "NO ACTION": rec["options"]["ondelete"] = fkdelrule @@ -4136,21 +4422,459 @@ index_info AS ( rschema = dbname + "." + rschema rec["referred_schema"] = rschema - local_cols, remote_cols = ( - rec["constrained_columns"], - rec["referred_columns"], + rec["constrained_columns"].append(scol) + rec["referred_columns"].append(rcol) + + for key, fk_dict in grouped.items(): + final[key] = list(fk_dict.values()) + + for n in names: + key = (schema, name_map(n)) + if key not in final: + final[key] = ReflectionDefaults.foreign_keys() + + @_db_plus_owner_multi + def get_multi_indexes( + self, + connection, + dbname, + owner, + schema, + filter_names, + scope, + kind, + **kw, + ): + ( + run_bulk, + regular_names, + multi_object_names, + temp_names, + ) = self._partition_filter_names( + connection, owner, filter_names, scope, kind + ) + + result = {} + + if run_bulk and multi_object_names: + name_map = self._multi_name_map(regular_names) + self._fetch_multi_indexes( + connection, + owner=owner, + names=multi_object_names, + schema=schema, + name_map=name_map, + result=result, + exec_opts={"schema_translate_map": {}}, ) - local_cols.append(scol) - remote_cols.append(rcol) + if temp_names: + self._fetch_multi_indexes_temp( + connection, + temp_names=temp_names, + schema=schema, + result=result, + ) - if fkeys: - return 
list(fkeys.values()) + return result.items() + + @lru_cache() + def _indexes_metadata_select(self): + """Build the Core sys.* select for index metadata (one row per + index per table). + + Used by :meth:`._fetch_multi_indexes` for both the main-DB pass + and the tempdb pass (via ``schema_translate_map``). Replaces a + previous ``sql.text()`` body that did not honor + ``schema_translate_map`` (which only rewrites Core + schema-bearing objects, not literal SQL text). + """ + sys_indexes = ischema.sys_indexes + sys_objects = ischema.sys_objects + sys_schemas = ischema.sys_schemas + + if self.server_version_info >= MS_2008_VERSION: + filter_definition = sys_indexes.c.filter_definition else: - return self._default_or_error( + filter_definition = sql.null() + + s = ( + sql.select( + sys_objects.c.name.label("table_name"), + sys_indexes.c.index_id, + sys_indexes.c.is_unique, + sys_indexes.c.name, + sys_indexes.c.type, + filter_definition.label("filter_definition"), + ) + .select_from(sys_indexes) + .join( + sys_objects, + onclause=sys_indexes.c.object_id == sys_objects.c.object_id, + ) + .join( + sys_schemas, + onclause=sys_schemas.c.schema_id == sys_objects.c.schema_id, + ) + .where(sys_indexes.c.is_primary_key == 0) + .where(sys_indexes.c.type != 0) + .order_by(sys_objects.c.name, sys_indexes.c.name) + ) + return s + + @lru_cache() + def _indexes_columns_select(self): + """Build the Core sys.* select for index columns (one row per + index column per index per table).""" + sys_index_columns = ischema.sys_index_columns + sys_columns = ischema.sys_columns + sys_objects = ischema.sys_objects + sys_schemas = ischema.sys_schemas + + s = ( + sql.select( + sys_objects.c.name.label("table_name"), + sys_index_columns.c.index_id, + sys_columns.c.name, + sys_index_columns.c.is_included_column, + ) + .select_from(sys_index_columns) + .join( + sys_columns, + onclause=sql.and_( + sys_columns.c.object_id == sys_index_columns.c.object_id, + sys_columns.c.column_id == 
sys_index_columns.c.column_id, + ), + ) + .join( + sys_objects, + onclause=sys_objects.c.object_id + == sys_index_columns.c.object_id, + ) + .join( + sys_schemas, + onclause=sys_schemas.c.schema_id == sys_objects.c.schema_id, + ) + .order_by( + sys_objects.c.name, + sys_index_columns.c.index_id, + sys_index_columns.c.key_ordinal, + ) + ) + return s + + def _fetch_multi_indexes( + self, connection, owner, names, schema, name_map, result, exec_opts + ): + """Execute the unified indexes queries for one catalog pass. + + Runs two Core sys.* selects (index metadata + index columns), + grouped by table. Used for both the main-DB pass and the + tempdb pass (with ``schema_translate_map={"sys": "tempdb.sys"}`` + applied). + """ + meta_q = self._indexes_metadata_select().where( + ischema.sys_schemas.c.name == owner, + ischema.sys_objects.c.name.in_( + sql.bindparam("filter_names", expanding=True) + ), + ) + + rp = connection.execute( + meta_q, + {"filter_names": names}, + execution_options=exec_opts, + ) + + # {table_name: {index_id: index_dict}} + indexes_by_table = {} + for row in rp.mappings(): + tname = name_map(row["table_name"]) + if tname not in indexes_by_table: + indexes_by_table[tname] = {} + + current = { + "name": row["name"], + "unique": row["is_unique"] == 1, + "column_names": [], + # NOTE: this is legacy, this is part of + # dialect_options now as of #7382 + "include_columns": [], + "dialect_options": {}, + } + do = current["dialect_options"] + index_type = row["type"] + if index_type in {1, 2}: + do["mssql_clustered"] = index_type == 1 + if index_type in {5, 6}: + do["mssql_clustered"] = index_type == 5 + do["mssql_columnstore"] = True + if row["filter_definition"] is not None: + do["mssql_where"] = row["filter_definition"] + + indexes_by_table[tname][row["index_id"]] = current + + cols_q = self._indexes_columns_select().where( + ischema.sys_schemas.c.name == owner, + ischema.sys_objects.c.name.in_( + sql.bindparam("filter_names", expanding=True) + ), + ) + + 
rp2 = connection.execute( + cols_q, + {"filter_names": names}, + execution_options=exec_opts, + ) + + for row in rp2.mappings(): + tname = name_map(row["table_name"]) + idx_id = row["index_id"] + if tname not in indexes_by_table: + continue + if idx_id not in indexes_by_table[tname]: + continue + index_def = indexes_by_table[tname][idx_id] + is_colstore = index_def["dialect_options"].get("mssql_columnstore") + is_clustered = index_def["dialect_options"].get("mssql_clustered") + if not (is_colstore and is_clustered): + # a clustered columnstore index includes all columns but does + # not want them in the index definition + if row["is_included_column"] and not is_colstore: + # a nonclustered columnstore index reports its columns + # as included, but they must be listed as normal columns + index_def["include_columns"].append(row["name"]) + else: + index_def["column_names"].append(row["name"]) + + for tname, idx_dict in indexes_by_table.items(): + for index_info in idx_dict.values(): + index_info["dialect_options"]["mssql_include"] = index_info[ + "include_columns" + ] + result[(schema, tname)] = list(idx_dict.values()) + + for n in names: + key = (schema, name_map(n)) + if key not in result: + result[key] = ReflectionDefaults.indexes() + + def _fetch_multi_indexes_temp( + self, connection, temp_names, schema, result + ): + """Run the unified indexes queries against tempdb.""" + resolved_by_owner, original_by_mangled = self._resolve_temp_names( + connection, temp_names + ) + + for temp_owner, mangled_names in resolved_by_owner.items(): + + def temp_name_map(server_name, _by=original_by_mangled): + return _by.get(server_name, server_name) + + self._fetch_multi_indexes( connection, - tablename, - owner, - ReflectionDefaults.foreign_keys, - **kw, + owner=temp_owner, + names=mangled_names, + schema=schema, + name_map=temp_name_map, + result=result, + exec_opts={"schema_translate_map": {"sys": "tempdb.sys"}}, + ) + + @_db_plus_owner_multi + def get_multi_table_comment( + 
self, + connection, + dbname, + owner, + schema, + filter_names, + scope, + kind, + **kw, + ): + if not self.supports_comments: + raise NotImplementedError( + "Can't get table comments on current SQL Server " + "version in use" + ) + + ( + run_bulk, + regular_names, + multi_object_names, + temp_names, + ) = self._partition_filter_names( + connection, owner, filter_names, scope, kind + ) + + result = {} + + if run_bulk and multi_object_names: + name_map = self._multi_name_map(regular_names) + self._fetch_multi_table_comment( + connection, + owner=owner, + names=multi_object_names, + schema=schema, + name_map=name_map, + result=result, + exec_opts={"schema_translate_map": {}}, + ) + + if temp_names: + self._fetch_multi_table_comment_temp( + connection, + temp_names=temp_names, + schema=schema, + result=result, + ) + + return result.items() + + @lru_cache() + def _table_comment_select(self): + """Build the Core sys.* select for table comment reflection.""" + sys_objects = ischema.sys_objects + sys_schemas = ischema.sys_schemas + extended_properties = ischema.extended_properties + + s = ( + sql.select( + sys_objects.c.name.label("table_name"), + extended_properties.c.value.label("comment"), + ) + .select_from(sys_objects) + .join( + sys_schemas, + onclause=sys_objects.c.schema_id == sys_schemas.c.schema_id, + ) + .outerjoin( + extended_properties, + onclause=sql.and_( + extended_properties.c["class"] == 1, + extended_properties.c.name == "MS_Description", + extended_properties.c.major_id == sys_objects.c.object_id, + extended_properties.c.minor_id == 0, + ), + ) + .where(sys_objects.c.type.in_(("U", "V"))) + ) + return s + + def _fetch_multi_table_comment( + self, connection, owner, names, schema, name_map, result, exec_opts + ): + """Execute the unified table comment query for one catalog pass. + + Core sys.* select so ``schema_translate_map`` actually rewrites + the catalog reference for the tempdb pass. 
+ """ + q = self._table_comment_select().where( + ischema.sys_schemas.c.name == owner, + ischema.sys_objects.c.name.in_( + sql.bindparam("filter_names", expanding=True) + ), + ) + + rp = connection.execute( + q, {"filter_names": names}, execution_options=exec_opts + ) + + for row in rp.mappings(): + table_name = name_map(row["table_name"]) + comment = row["comment"] + result[(schema, table_name)] = ( + {"text": comment} if comment else {"text": None} + ) + + for n in names: + key = (schema, name_map(n)) + if key not in result: + result[key] = ReflectionDefaults.table_comment() + + def _fetch_multi_table_comment_temp( + self, connection, temp_names, schema, result + ): + """Run the unified table comment query against tempdb.""" + resolved_by_owner, original_by_mangled = self._resolve_temp_names( + connection, temp_names + ) + + for temp_owner, mangled_names in resolved_by_owner.items(): + + def temp_name_map(server_name, _by=original_by_mangled): + return _by.get(server_name, server_name) + + self._fetch_multi_table_comment( + connection, + owner=temp_owner, + names=mangled_names, + schema=schema, + name_map=temp_name_map, + result=result, + exec_opts={"schema_translate_map": {"sys": "tempdb.sys"}}, ) + + # --- Single-table reflection wrappers (delegate to multi) --- + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + data = self.get_multi_pk_constraint( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + data 
= self.get_multi_foreign_keys( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + data = self.get_multi_indexes( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + data = self.get_multi_table_comment( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index 5249f3e39f..0f1a56f559 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -282,3 +282,64 @@ extended_properties = Table( Column("value", NVarcharSqlVariant), schema="sys", ) + +sys_schemas = Table( + "schemas", + ischema, + Column("schema_id", Integer), + Column("name", CoerceUnicode), + Column("principal_id", Integer), + schema="sys", +) + +sys_objects = Table( + "objects", + ischema, + Column("object_id", Integer), + Column("name", CoerceUnicode), + Column("schema_id", Integer), + Column("parent_object_id", Integer), + Column("type", String), # CHAR(2) + Column("type_desc", CoerceUnicode), + schema="sys", +) + +sys_key_constraints = Table( + "key_constraints", + ischema, + Column("object_id", Integer), + Column("name", CoerceUnicode), + Column("schema_id", Integer), + Column("parent_object_id", Integer), + Column("type", String), # CHAR(2) ('PK', 'UQ') + Column("unique_index_id", Integer), + schema="sys", +) + +sys_indexes = Table( + "indexes", + ischema, + 
Column("object_id", Integer), + Column("index_id", Integer), + Column("name", CoerceUnicode), + Column("type", Integer), # TINYINT: 1=clustered, 2=nonclustered + Column("type_desc", CoerceUnicode), + Column("is_unique", Boolean), + Column("is_primary_key", Boolean), + Column("is_unique_constraint", Boolean), + Column("filter_definition", CoerceUnicode), + schema="sys", +) + +sys_index_columns = Table( + "index_columns", + ischema, + Column("object_id", Integer), + Column("index_id", Integer), + Column("index_column_id", Integer), + Column("column_id", Integer), + Column("key_ordinal", Integer), + Column("is_descending_key", Boolean), + Column("is_included_column", Boolean), + schema="sys", +) diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index 0c9f1485a0..59f5245e72 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -22,6 +22,8 @@ from sqlalchemy import types as sqltypes from sqlalchemy.dialects import mssql from sqlalchemy.dialects.mssql import base from sqlalchemy.dialects.mssql.information_schema import tables +from sqlalchemy.engine import ObjectKind +from sqlalchemy.engine import ObjectScope from sqlalchemy.pool import NullPool from sqlalchemy.schema import CreateIndex from sqlalchemy.testing import AssertsCompiledSQL @@ -293,6 +295,177 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): [(2, "bar", datetime.datetime(2020, 2, 2, 2, 2, 2))], ) + def test_get_multi_columns_temp_table(self, metadata, connection): + """Direct ``get_multi_columns`` API works for temp tables across + scope/kind combinations. + + Regression coverage for the case where a ``#``-prefixed name was + passed to a multi method outside the autoload-shaped + ``scope=ANY/kind=ANY`` call. 
+ """ + tt = Table( + "#mr_tmp", + metadata, + Column("id", Integer, primary_key=True), + Column("name", mssql.NVARCHAR(50)), + ) + tt.create(connection) + + insp = inspect(connection) + + # ANY scope, TABLE kind: should return the temp table + r = dict( + insp.get_multi_columns( + filter_names=["#mr_tmp"], + scope=ObjectScope.ANY, + kind=ObjectKind.TABLE, + ) + ) + eq_(set(r.keys()), {(None, "#mr_tmp")}) + eq_( + [c["name"] for c in r[(None, "#mr_tmp")]], + ["id", "name"], + ) + + # TEMPORARY scope: should also return the temp table + r = dict( + insp.get_multi_columns( + filter_names=["#mr_tmp"], + scope=ObjectScope.TEMPORARY, + kind=ObjectKind.TABLE, + ) + ) + eq_(set(r.keys()), {(None, "#mr_tmp")}) + + # DEFAULT scope: must EXCLUDE the temp table + r = dict( + insp.get_multi_columns( + filter_names=["#mr_tmp"], + scope=ObjectScope.DEFAULT, + kind=ObjectKind.TABLE, + ) + ) + eq_(r, {}) + + # VIEW kind: must EXCLUDE the temp table (no temp views on mssql) + r = dict( + insp.get_multi_columns( + filter_names=["#mr_tmp"], + scope=ObjectScope.ANY, + kind=ObjectKind.VIEW, + ) + ) + eq_(r, {}) + + def test_get_multi_pk_constraint_temp_table(self, metadata, connection): + tt = Table( + "#mr_pk_tmp", + metadata, + Column("id", Integer, primary_key=True), + Column("val", Integer), + ) + tt.create(connection) + + insp = inspect(connection) + r = dict( + insp.get_multi_pk_constraint( + filter_names=["#mr_pk_tmp"], + scope=ObjectScope.ANY, + kind=ObjectKind.TABLE, + ) + ) + in_((None, "#mr_pk_tmp"), r) + + def test_temp_reflection_does_not_leak_translate_map( + self, metadata, connection + ): + """Reflecting a temp table must not leave ``schema_translate_map`` + applied on the caller's connection. + + The unified multi reflection path runs the tempdb pass with + ``schema_translate_map={"sys": "tempdb.sys"}`` as an + execute-level option. Applying it on the connection (e.g. 
via + ``connection.execution_options(...)``) mutates the connection + in place and poisons subsequent Core sys.* queries. + """ + tt = Table( + "#mr_leak_tmp", + metadata, + Column("id", Integer, primary_key=True), + Column("data", mssql.NVARCHAR(50)), + ) + tt.create(connection) + + before = dict(connection.get_execution_options()) + + insp = inspect(connection) + # Each of these would route via the tempdb pass. + insp.get_multi_columns( + filter_names=["#mr_leak_tmp"], + scope=ObjectScope.ANY, + kind=ObjectKind.TABLE, + ) + insp.get_multi_pk_constraint( + filter_names=["#mr_leak_tmp"], + scope=ObjectScope.ANY, + kind=ObjectKind.TABLE, + ) + insp.get_multi_indexes( + filter_names=["#mr_leak_tmp"], + scope=ObjectScope.ANY, + kind=ObjectKind.TABLE, + ) + insp.get_multi_table_comment( + filter_names=["#mr_leak_tmp"], + scope=ObjectScope.ANY, + kind=ObjectKind.TABLE, + ) + + after = dict(connection.get_execution_options()) + eq_(before, after) + + def test_temp_reflection_with_caller_translate_map( + self, metadata, connection + ): + """Temp table reflection must work even when the caller has a + ``schema_translate_map`` set on the connection that contains a + ``"sys"`` key. + + Statement-level execution options DO NOT override connection + options for ``schema_translate_map`` (the connection wins + during option merge). Execute-level options DO override, so the + tempdb pass must pass its map at execute time, not on the + statement or the connection. + """ + tt = Table( + "#mr_hostile_map", + metadata, + Column("id", Integer, primary_key=True), + Column("val", Integer), + ) + tt.create(connection) + + # Hostile setup: caller has set a translate_map with a "sys" + # key for their own purposes. Our temp pass must still reach + # tempdb.sys.* despite this. 
+ hostile = connection.execution_options( + schema_translate_map={"sys": "INFORMATION_SCHEMA"} + ) + + insp = inspect(hostile) + r = dict( + insp.get_multi_columns( + filter_names=["#mr_hostile_map"], + scope=ObjectScope.ANY, + kind=ObjectKind.TABLE, + ) + ) + eq_(set(r.keys()), {(None, "#mr_hostile_map")}) + eq_( + [c["name"] for c in r[(None, "#mr_hostile_map")]], + ["id", "val"], + ) + @testing.combinations( ("local_temp", "#tmp", True), ("global_temp", "##tmp", True),