From ddbd9dafffdedf6fb464947394c81c8b02153e14 Mon Sep 17 00:00:00 2001 From: Daniel Hall Date: Sat, 30 Jul 2022 15:12:20 -0400 Subject: [PATCH] Support comments on MSSQL Added support table and column comments on MSSQL when creating a table. Added support for reflecting table comments. Thanks to Daniel Hall for the help in this pull request. Fixes: #7844 Closes: #8225 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8225 Pull-request-sha: 540f4eb6395f9feed4b4240e3d22f539021948e9 Change-Id: I69f48c6dda4e00ec3d82fdeff13f3df9d735b7b0 --- doc/build/changelog/unreleased_20/7844.rst | 7 + lib/sqlalchemy/dialects/mssql/base.py | 145 +++++++++++++++--- .../dialects/mssql/information_schema.py | 35 ++++- test/dialect/mssql/test_reflection.py | 41 +++++ test/requirements.py | 2 +- 5 files changed, 204 insertions(+), 26 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/7844.rst diff --git a/doc/build/changelog/unreleased_20/7844.rst b/doc/build/changelog/unreleased_20/7844.rst new file mode 100644 index 0000000000..88fae454b6 --- /dev/null +++ b/doc/build/changelog/unreleased_20/7844.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: usecase, mssql + :tickets: 7844 + + Added support table and column comments on MSSQL when + creating a table. Added support for reflecting table comments. + Thanks to Daniel Hall for the help in this pull request. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index f98df1c20b..c85d21ef7a 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2596,6 +2596,62 @@ class MSDDLCompiler(compiler.DDLCompiler): text += " PERSISTED" return text + def visit_set_table_comment(self, create): + schema = self.preparer.schema_for_object(create.element) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_addextendedproperty 'MS_Description', " + "{0}, 'schema', {1}, 'table', {2}".format( + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.NVARCHAR() + ), + self.preparer.quote_schema(schema_name), + self.preparer.format_table(create.element, use_schema=False), + ) + ) + + def visit_drop_table_comment(self, drop): + schema = self.preparer.schema_for_object(drop.element) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_dropextendedproperty 'MS_Description', 'schema', " + "{0}, 'table', {1}".format( + self.preparer.quote_schema(schema_name), + self.preparer.format_table(drop.element, use_schema=False), + ) + ) + + def visit_set_column_comment(self, create): + schema = self.preparer.schema_for_object(create.element.table) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_addextendedproperty 'MS_Description', " + "{0}, 'schema', {1}, 'table', {2}, 'column', {3}".format( + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.NVARCHAR() + ), + self.preparer.quote_schema(schema_name), + self.preparer.format_table( + create.element.table, use_schema=False + ), + self.preparer.format_column(create.element), + ) + ) + + def visit_drop_column_comment(self, drop): + schema = self.preparer.schema_for_object(drop.element.table) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_dropextendedproperty 'MS_Description', 'schema', " + "{0}, 'table', {1}, 'column', {2}".format( + self.preparer.quote_schema(schema_name), + self.preparer.format_table( + drop.element.table, use_schema=False + ), + self.preparer.format_column(drop.element), + ) + ) + def visit_create_sequence(self, create, **kw): prefix = None if create.element.data_type is not None: @@ -2789,6 +2845,8 @@ class MSDialect(default.DefaultDialect): supports_default_values = True supports_empty_insert = False + supports_comments = True + # supports_native_uuid is partial here, so we implement our # own impl type @@ -3254,6 +3312,33 @@ class MSDialect(default.DefaultDialect): else: raise exc.NoSuchTableError(f"{owner}.{viewname}") + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + schema_name = schema if schema else self.default_schema_name + COMMENT_SQL = """ + SELECT cast(com.value as nvarchar(max)) + FROM fn_listextendedproperty('MS_Description', + 'schema', :schema, 'table', :table, NULL, NULL + ) as com; + """ + + comment = connection.execute( + sql.text(COMMENT_SQL).bindparams( + sql.bindparam("schema", schema_name, ischema.CoerceUnicode()), + sql.bindparam("table", table_name, ischema.CoerceUnicode()), + ) + ).scalar() + if comment: + return {"text": comment} + else: + return self._default_or_error( + connection, + table_name, + None, + ReflectionDefaults.table_comment, + **kw, + ) + def _temp_table_name_like_pattern(self, tablename): # LIKE uses '%' to match zero or more characters and '_' to match any # single character. We want to match literal underscores, so T-SQL @@ -3314,24 +3399,6 @@ class MSDialect(default.DefaultDialect): whereclause = columns.c.table_name == tablename full_name = columns.c.table_name - join = columns.join( - computed_cols, - onclause=sql.and_( - computed_cols.c.object_id == func.object_id(full_name), - computed_cols.c.name - == columns.c.column_name.collate("DATABASE_DEFAULT"), - ), - isouter=True, - ).join( - identity_cols, - onclause=sql.and_( - identity_cols.c.object_id == func.object_id(full_name), - identity_cols.c.name - == columns.c.column_name.collate("DATABASE_DEFAULT"), - ), - isouter=True, - ) - if self._supports_nvarchar_max: computed_definition = computed_cols.c.definition else: @@ -3340,17 +3407,53 @@ class MSDialect(default.DefaultDialect): computed_cols.c.definition, NVARCHAR(4000) ) + object_id = func.object_id(full_name) + s = ( sql.select( - columns, + columns.c.column_name, + columns.c.data_type, + columns.c.is_nullable, + columns.c.character_maximum_length, + columns.c.numeric_precision, + columns.c.numeric_scale, + columns.c.column_default, + columns.c.collation_name, computed_definition, computed_cols.c.is_persisted, identity_cols.c.is_identity, identity_cols.c.seed_value, identity_cols.c.increment_value, + ischema.extended_properties.c.value.label("comment"), + ) + .select_from(columns) + .outerjoin( + computed_cols, + onclause=sql.and_( + computed_cols.c.object_id == object_id, + computed_cols.c.name + == columns.c.column_name.collate("DATABASE_DEFAULT"), + ), + ) + .outerjoin( + identity_cols, + onclause=sql.and_( + identity_cols.c.object_id == object_id, + identity_cols.c.name + == columns.c.column_name.collate("DATABASE_DEFAULT"), + ), + ) + .outerjoin( + ischema.extended_properties, + onclause=sql.and_( + ischema.extended_properties.c["class"] == 1, + ischema.extended_properties.c.major_id == object_id, + ischema.extended_properties.c.minor_id + == columns.c.ordinal_position, + ischema.extended_properties.c.name == "MS_Description", + ), ) .where(whereclause) - .select_from(join) .order_by(columns.c.ordinal_position) ) @@ -3371,6 +3474,7 @@ class MSDialect(default.DefaultDialect): is_identity = row[identity_cols.c.is_identity] identity_start = row[identity_cols.c.seed_value] identity_increment = row[identity_cols.c.increment_value] + comment = row[ischema.extended_properties.c.value] coltype = self.ischema_names.get(type_, None) @@ -3412,6 +3516,7 @@ class MSDialect(default.DefaultDialect): "nullable": nullable, "default": default, "autoincrement": is_identity is not None, + "comment": comment, } if definition is not None and is_persisted is not None: diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index b7e560bf14..33ab1f9929 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -6,7 +6,6 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors - from ... import cast from ... import Column from ... import MetaData @@ -16,6 +15,7 @@ from ...sql import expression from ...types import Boolean from ...types import Integer from ...types import Numeric +from ...types import NVARCHAR from ...types import String from ...types import TypeDecorator from ...types import Unicode @@ -198,7 +198,7 @@ sequences = Table( ) -class IdentitySqlVariant(TypeDecorator): +class NumericSqlVariant(TypeDecorator): r"""This type casts sql_variant columns in the identity_columns view to numeric. This is required because: @@ -220,9 +220,34 @@ identity_columns = Table( Column("object_id", Integer), Column("name", CoerceUnicode), Column("is_identity", Boolean), - Column("seed_value", IdentitySqlVariant), - Column("increment_value", IdentitySqlVariant), - Column("last_value", IdentitySqlVariant), + Column("seed_value", NumericSqlVariant), + Column("increment_value", NumericSqlVariant), + Column("last_value", NumericSqlVariant), Column("is_not_for_replication", Boolean), schema="sys", ) + + +class NVarcharSqlVariant(TypeDecorator): + """This type casts sql_variant columns in the extended_properties view + to nvarchar. This is required because pyodbc does not support sql_variant + """ + + impl = Unicode + cache_ok = True + + def column_expression(self, colexpr): + return cast(colexpr, NVARCHAR) + + +extended_properties = Table( + "extended_properties", + ischema, + Column("class", Integer), # TINYINT + Column("class_desc", CoerceUnicode), + Column("major_id", Integer), + Column("minor_id", Integer), + Column("name", CoerceUnicode), + Column("value", NVarcharSqlVariant), + schema="sys", +) diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index 73fd6bf1ca..cd8742b9c8 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -161,6 +161,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): "nullable": False, "default": None, "autoincrement": False, + "comment": None, }, { "name": "data", @@ -170,6 +171,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): "nullable": True, "default": None, "autoincrement": False, + "comment": None, }, ], ) @@ -403,6 +405,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): "nullable": False, "default": None, "autoincrement": False, + "comment": None, } ], ) @@ -597,6 +600,44 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): is_(col["type"].length, None) in_("max", str(col["type"].compile(dialect=connection.dialect))) + def test_comments(self, metadata, connection): + Table( + "tbl_with_comments", + metadata, + Column( + "id", + types.Integer, + primary_key=True, + comment="pk comment 🔑", + ), + Column("no_comment", types.Integer), + Column( + "has_comment", + types.String(20), + comment="has the comment § méil 📧", + ), + comment="table comment çòé 🐍", + ) + metadata.create_all(connection) + insp = inspect(connection) + eq_( + insp.get_table_comment("tbl_with_comments"), + {"text": "table comment çòé 🐍"}, + ) + + cols = { + col["name"]: col["comment"] + for col in insp.get_columns("tbl_with_comments") + } + eq_( + cols, + { + "id": "pk comment 🔑", + "no_comment": None, + "has_comment": "has the comment § méil 📧", + }, + ) + class InfoCoerceUnicodeTest(fixtures.TestBase, AssertsCompiledSQL): def test_info_unicode_cast_no_2000(self): diff --git a/test/requirements.py b/test/requirements.py index 4475784440..9001b52367 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -168,7 +168,7 @@ class DefaultRequirements(SuiteRequirements): @property def comment_reflection(self): - return only_on(["postgresql", "mysql", "mariadb", "oracle"]) + return only_on(["postgresql", "mysql", "mariadb", "oracle", "mssql"]) @property def constraint_comment_reflection(self): -- 2.47.2