]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support comments on MSSQL
authorDaniel Hall <daniel.hall@moesol.com>
Sat, 30 Jul 2022 19:12:20 +0000 (15:12 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 10 Aug 2022 19:11:59 +0000 (21:11 +0200)
Added support table and column comments on MSSQL when
creating a table. Added support for reflecting table comments.
Thanks to Daniel Hall for the help in this pull request.

Fixes: #7844
Closes: #8225
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8225
Pull-request-sha: 540f4eb6395f9feed4b4240e3d22f539021948e9

Change-Id: I69f48c6dda4e00ec3d82fdeff13f3df9d735b7b0

doc/build/changelog/unreleased_20/7844.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mssql/information_schema.py
test/dialect/mssql/test_reflection.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_20/7844.rst b/doc/build/changelog/unreleased_20/7844.rst
new file mode 100644 (file)
index 0000000..88fae45
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: usecase, mssql
+    :tickets: 7844
+
+    Added support table and column comments on MSSQL when
+    creating a table. Added support for reflecting table comments.
+    Thanks to Daniel Hall for the help in this pull request.
index f98df1c20ba458e4a572eb9cdaae778b397239fb..c85d21ef7aa238e4c45af8c6cb5646c37eb7afaf 100644 (file)
@@ -2596,6 +2596,62 @@ class MSDDLCompiler(compiler.DDLCompiler):
             text += " PERSISTED"
         return text
 
+    def visit_set_table_comment(self, create):
+        schema = self.preparer.schema_for_object(create.element)
+        schema_name = schema if schema else self.dialect.default_schema_name
+        return (
+            "execute sp_addextendedproperty 'MS_Description', "
+            "{0}, 'schema', {1}, 'table', {2}".format(
+                self.sql_compiler.render_literal_value(
+                    create.element.comment, sqltypes.NVARCHAR()
+                ),
+                self.preparer.quote_schema(schema_name),
+                self.preparer.format_table(create.element, use_schema=False),
+            )
+        )
+
+    def visit_drop_table_comment(self, drop):
+        schema = self.preparer.schema_for_object(drop.element)
+        schema_name = schema if schema else self.dialect.default_schema_name
+        return (
+            "execute sp_dropextendedproperty 'MS_Description', 'schema', "
+            "{0}, 'table', {1}".format(
+                self.preparer.quote_schema(schema_name),
+                self.preparer.format_table(drop.element, use_schema=False),
+            )
+        )
+
+    def visit_set_column_comment(self, create):
+        schema = self.preparer.schema_for_object(create.element.table)
+        schema_name = schema if schema else self.dialect.default_schema_name
+        return (
+            "execute sp_addextendedproperty 'MS_Description', "
+            "{0}, 'schema', {1}, 'table', {2}, 'column', {3}".format(
+                self.sql_compiler.render_literal_value(
+                    create.element.comment, sqltypes.NVARCHAR()
+                ),
+                self.preparer.quote_schema(schema_name),
+                self.preparer.format_table(
+                    create.element.table, use_schema=False
+                ),
+                self.preparer.format_column(create.element),
+            )
+        )
+
+    def visit_drop_column_comment(self, drop):
+        schema = self.preparer.schema_for_object(drop.element.table)
+        schema_name = schema if schema else self.dialect.default_schema_name
+        return (
+            "execute sp_dropextendedproperty 'MS_Description', 'schema', "
+            "{0}, 'table', {1}, 'column', {2}".format(
+                self.preparer.quote_schema(schema_name),
+                self.preparer.format_table(
+                    drop.element.table, use_schema=False
+                ),
+                self.preparer.format_column(drop.element),
+            )
+        )
+
     def visit_create_sequence(self, create, **kw):
         prefix = None
         if create.element.data_type is not None:
@@ -2789,6 +2845,8 @@ class MSDialect(default.DefaultDialect):
     supports_default_values = True
     supports_empty_insert = False
 
+    supports_comments = True
+
     # supports_native_uuid is partial here, so we implement our
     # own impl type
 
@@ -3254,6 +3312,33 @@ class MSDialect(default.DefaultDialect):
         else:
             raise exc.NoSuchTableError(f"{owner}.{viewname}")
 
+    @reflection.cache
+    def get_table_comment(self, connection, table_name, schema=None, **kw):
+        schema_name = schema if schema else self.default_schema_name
+        COMMENT_SQL = """
+            SELECT cast(com.value as nvarchar(max))
+            FROM fn_listextendedproperty('MS_Description',
+                'schema', :schema, 'table', :table, NULL, NULL
+            ) as com;
+        """
+
+        comment = connection.execute(
+            sql.text(COMMENT_SQL).bindparams(
+                sql.bindparam("schema", schema_name, ischema.CoerceUnicode()),
+                sql.bindparam("table", table_name, ischema.CoerceUnicode()),
+            )
+        ).scalar()
+        if comment:
+            return {"text": comment}
+        else:
+            return self._default_or_error(
+                connection,
+                table_name,
+                None,
+                ReflectionDefaults.table_comment,
+                **kw,
+            )
+
     def _temp_table_name_like_pattern(self, tablename):
         # LIKE uses '%' to match zero or more characters and '_' to match any
         # single character. We want to match literal underscores, so T-SQL
@@ -3314,24 +3399,6 @@ class MSDialect(default.DefaultDialect):
             whereclause = columns.c.table_name == tablename
             full_name = columns.c.table_name
 
-        join = columns.join(
-            computed_cols,
-            onclause=sql.and_(
-                computed_cols.c.object_id == func.object_id(full_name),
-                computed_cols.c.name
-                == columns.c.column_name.collate("DATABASE_DEFAULT"),
-            ),
-            isouter=True,
-        ).join(
-            identity_cols,
-            onclause=sql.and_(
-                identity_cols.c.object_id == func.object_id(full_name),
-                identity_cols.c.name
-                == columns.c.column_name.collate("DATABASE_DEFAULT"),
-            ),
-            isouter=True,
-        )
-
         if self._supports_nvarchar_max:
             computed_definition = computed_cols.c.definition
         else:
@@ -3340,17 +3407,53 @@ class MSDialect(default.DefaultDialect):
                 computed_cols.c.definition, NVARCHAR(4000)
             )
 
+        object_id = func.object_id(full_name)
+
         s = (
             sql.select(
-                columns,
+                columns.c.column_name,
+                columns.c.data_type,
+                columns.c.is_nullable,
+                columns.c.character_maximum_length,
+                columns.c.numeric_precision,
+                columns.c.numeric_scale,
+                columns.c.column_default,
+                columns.c.collation_name,
                 computed_definition,
                 computed_cols.c.is_persisted,
                 identity_cols.c.is_identity,
                 identity_cols.c.seed_value,
                 identity_cols.c.increment_value,
+                ischema.extended_properties.c.value.label("comment"),
+            )
+            .select_from(columns)
+            .outerjoin(
+                computed_cols,
+                onclause=sql.and_(
+                    computed_cols.c.object_id == object_id,
+                    computed_cols.c.name
+                    == columns.c.column_name.collate("DATABASE_DEFAULT"),
+                ),
+            )
+            .outerjoin(
+                identity_cols,
+                onclause=sql.and_(
+                    identity_cols.c.object_id == object_id,
+                    identity_cols.c.name
+                    == columns.c.column_name.collate("DATABASE_DEFAULT"),
+                ),
+            )
+            .outerjoin(
+                ischema.extended_properties,
+                onclause=sql.and_(
+                    ischema.extended_properties.c["class"] == 1,
+                    ischema.extended_properties.c.major_id == object_id,
+                    ischema.extended_properties.c.minor_id
+                    == columns.c.ordinal_position,
+                    ischema.extended_properties.c.name == "MS_Description",
+                ),
             )
             .where(whereclause)
-            .select_from(join)
             .order_by(columns.c.ordinal_position)
         )
 
@@ -3371,6 +3474,7 @@ class MSDialect(default.DefaultDialect):
             is_identity = row[identity_cols.c.is_identity]
             identity_start = row[identity_cols.c.seed_value]
             identity_increment = row[identity_cols.c.increment_value]
+            comment = row[ischema.extended_properties.c.value]
 
             coltype = self.ischema_names.get(type_, None)
 
@@ -3412,6 +3516,7 @@ class MSDialect(default.DefaultDialect):
                 "nullable": nullable,
                 "default": default,
                 "autoincrement": is_identity is not None,
+                "comment": comment,
             }
 
             if definition is not None and is_persisted is not None:
index b7e560bf144357a1b4f7dba8fd58f6a120562f98..33ab1f99290b0984555d8492d02195e96083dfa1 100644 (file)
@@ -6,7 +6,6 @@
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: ignore-errors
 
-
 from ... import cast
 from ... import Column
 from ... import MetaData
@@ -16,6 +15,7 @@ from ...sql import expression
 from ...types import Boolean
 from ...types import Integer
 from ...types import Numeric
+from ...types import NVARCHAR
 from ...types import String
 from ...types import TypeDecorator
 from ...types import Unicode
@@ -198,7 +198,7 @@ sequences = Table(
 )
 
 
-class IdentitySqlVariant(TypeDecorator):
+class NumericSqlVariant(TypeDecorator):
     r"""This type casts sql_variant columns in the identity_columns view
     to numeric. This is required because:
 
@@ -220,9 +220,34 @@ identity_columns = Table(
     Column("object_id", Integer),
     Column("name", CoerceUnicode),
     Column("is_identity", Boolean),
-    Column("seed_value", IdentitySqlVariant),
-    Column("increment_value", IdentitySqlVariant),
-    Column("last_value", IdentitySqlVariant),
+    Column("seed_value", NumericSqlVariant),
+    Column("increment_value", NumericSqlVariant),
+    Column("last_value", NumericSqlVariant),
     Column("is_not_for_replication", Boolean),
     schema="sys",
 )
+
+
+class NVarcharSqlVariant(TypeDecorator):
+    """This type casts sql_variant columns in the extended_properties view
+    to nvarchar. This is required because pyodbc does not support sql_variant
+    """
+
+    impl = Unicode
+    cache_ok = True
+
+    def column_expression(self, colexpr):
+        return cast(colexpr, NVARCHAR)
+
+
+extended_properties = Table(
+    "extended_properties",
+    ischema,
+    Column("class", Integer),  # TINYINT
+    Column("class_desc", CoerceUnicode),
+    Column("major_id", Integer),
+    Column("minor_id", Integer),
+    Column("name", CoerceUnicode),
+    Column("value", NVarcharSqlVariant),
+    schema="sys",
+)
index 73fd6bf1caa277368346f8f94f4846350596970c..cd8742b9c8519040bc82f6b830276d077fda2efd 100644 (file)
@@ -161,6 +161,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
                             "nullable": False,
                             "default": None,
                             "autoincrement": False,
+                            "comment": None,
                         },
                         {
                             "name": "data",
@@ -170,6 +171,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
                             "nullable": True,
                             "default": None,
                             "autoincrement": False,
+                            "comment": None,
                         },
                     ],
                 )
@@ -403,6 +405,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
                         "nullable": False,
                         "default": None,
                         "autoincrement": False,
+                        "comment": None,
                     }
                 ],
             )
@@ -597,6 +600,44 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
             is_(col["type"].length, None)
             in_("max", str(col["type"].compile(dialect=connection.dialect)))
 
+    def test_comments(self, metadata, connection):
+        Table(
+            "tbl_with_comments",
+            metadata,
+            Column(
+                "id",
+                types.Integer,
+                primary_key=True,
+                comment="pk comment 🔑",
+            ),
+            Column("no_comment", types.Integer),
+            Column(
+                "has_comment",
+                types.String(20),
+                comment="has the comment § méil 📧",
+            ),
+            comment="table comment çòé 🐍",
+        )
+        metadata.create_all(connection)
+        insp = inspect(connection)
+        eq_(
+            insp.get_table_comment("tbl_with_comments"),
+            {"text": "table comment çòé 🐍"},
+        )
+
+        cols = {
+            col["name"]: col["comment"]
+            for col in insp.get_columns("tbl_with_comments")
+        }
+        eq_(
+            cols,
+            {
+                "id": "pk comment 🔑",
+                "no_comment": None,
+                "has_comment": "has the comment § méil 📧",
+            },
+        )
+
 
 class InfoCoerceUnicodeTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_info_unicode_cast_no_2000(self):
index 4475784440523eb2ffab3be0e126ce456fe15972..9001b52367ff52d50cc8ce519be2574953279f23 100644 (file)
@@ -168,7 +168,7 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def comment_reflection(self):
-        return only_on(["postgresql", "mysql", "mariadb", "oracle"])
+        return only_on(["postgresql", "mysql", "mariadb", "oracle", "mssql"])
 
     @property
     def constraint_comment_reflection(self):