]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
include mssql_clustered dialect_options when reflecting - issue #8288
authorJohn Lennox <john.lennox@comcast.net>
Sun, 21 Aug 2022 05:35:44 +0000 (01:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Aug 2022 21:52:46 +0000 (17:52 -0400)
Implemented reflection of the "clustered index" flag ``mssql_clustered``
for the SQL Server dialect. Pull request courtesy John Lennox.

Fixes: #8288
Closes: #8289
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8289
Pull-request-sha: 1bb57352e3e31d8fb7de69ab5e60e5464949f640

Change-Id: Ife367066328f9e47ad823e4098647964a18e21e8

doc/build/changelog/unreleased_20/8288.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/testing/suite/test_reflection.py
test/dialect/mssql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_20/8288.rst b/doc/build/changelog/unreleased_20/8288.rst
new file mode 100644 (file)
index 0000000..f2f775d
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: usecase, mssql
+    :tickets: 8288
+
+    Implemented reflection of the "clustered index" flag ``mssql_clustered``
+    for the SQL Server dialect. Pull request courtesy John Lennox.
index c85d21ef7aa238e4c45af8c6cb5646c37eb7afaf..94e0826118f9a5038790892db98ce25d10fb6df9 100644 (file)
@@ -3218,6 +3218,9 @@ class MSDialect(default.DefaultDialect):
         rp = connection.execution_options(future_result=True).execute(
             sql.text(
                 "select ind.index_id, ind.is_unique, ind.name, "
+                "case when ind.index_id = 1 "
+                "then cast(1 as bit) "
+                "else cast(0 as bit) end as is_clustered, "
                 f"{filter_definition} "
                 "from sys.indexes as ind join sys.tables as tab on "
                 "ind.object_id=tab.object_id "
@@ -3240,6 +3243,7 @@ class MSDialect(default.DefaultDialect):
                 "unique": row["is_unique"] == 1,
                 "column_names": [],
                 "include_columns": [],
+                "dialect_options": {"mssql_clustered": row["is_clustered"]},
             }
 
             if row["filter_definition"] is not None:
@@ -3566,7 +3570,15 @@ class MSDialect(default.DefaultDialect):
         # Primary key constraints
         s = (
             sql.select(
-                C.c.column_name, TC.c.constraint_type, C.c.constraint_name
+                C.c.column_name,
+                TC.c.constraint_type,
+                C.c.constraint_name,
+                func.objectproperty(
+                    func.object_id(
+                        C.c.table_schema + "." + C.c.constraint_name
+                    ),
+                    "CnstIsClustKey",
+                ).label("is_clustered"),
             )
             .where(
                 sql.and_(
@@ -3580,13 +3592,20 @@ class MSDialect(default.DefaultDialect):
         )
         c = connection.execution_options(future_result=True).execute(s)
         constraint_name = None
+        is_clustered = None
         for row in c.mappings():
             if "PRIMARY" in row[TC.c.constraint_type.name]:
                 pkeys.append(row["COLUMN_NAME"])
                 if constraint_name is None:
                     constraint_name = row[C.c.constraint_name.name]
+                if is_clustered is None:
+                    is_clustered = row["is_clustered"]
         if pkeys:
-            return {"constrained_columns": pkeys, "name": constraint_name}
+            return {
+                "constrained_columns": pkeys,
+                "name": constraint_name,
+                "dialect_options": {"mssql_clustered": is_clustered},
+            }
         else:
             return self._default_or_error(
                 connection,
index a3737a91a1463531e3618ff10d814da1aa7b8a8d..7e54ee57a0d222f6334af8b54e202f53de5b6c63 100644 (file)
@@ -2396,20 +2396,25 @@ class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase):
 
         insp = inspect(connection)
 
+        get_indexes = insp.get_indexes("t")
         eq_(
-            insp.get_indexes("t"),
+            get_indexes,
             [
                 {
                     "name": "t_idx",
                     "column_names": ["x"],
                     "include_columns": ["y"],
                     "unique": False,
-                    "dialect_options": {
-                        "%s_include" % connection.engine.name: ["y"]
-                    },
+                    "dialect_options": mock.ANY,
                 }
             ],
         )
+        eq_(
+            get_indexes[0]["dialect_options"][
+                "%s_include" % connection.engine.name
+            ],
+            ["y"],
+        )
 
         t2 = Table("t", MetaData(), autoload_with=connection)
         eq_(
index cd8742b9c8519040bc82f6b830276d077fda2efd..f682538b378131eb20c4cd57f865042a59c92ef3 100644 (file)
@@ -571,10 +571,166 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
 
         t2 = Table("t", MetaData(), autoload_with=connection)
         idx = list(sorted(t2.indexes, key=lambda idx: idx.name))[0]
+        self.assert_compile(
+            CreateIndex(idx),
+            "CREATE NONCLUSTERED INDEX idx_x ON t (x) WHERE ([x]='test')",
+        )
+
+    def test_index_reflection_clustered(self, metadata, connection):
+        """
+        when the result of get_indexes() is used to build an index it should
+        include the CLUSTERED keyword when appropriate
+        """
+        t1 = Table(
+            "t",
+            metadata,
+            Column("id", Integer),
+            Column("x", types.String(20)),
+            Column("y", types.Integer),
+        )
+        Index("idx_x", t1.c.x, mssql_clustered=True)
+        Index("idx_y", t1.c.y)
+        metadata.create_all(connection)
+        ind = testing.db.dialect.get_indexes(connection, "t", None)
+
+        clustered_index = ""
+        for ix in ind:
+            if ix["dialect_options"]["mssql_clustered"]:
+                clustered_index = ix["name"]
+
+        eq_(clustered_index, "idx_x")
+
+        t2 = Table("t", MetaData(), autoload_with=connection)
+        idx = list(sorted(t2.indexes, key=lambda idx: idx.name))[0]
+
+        self.assert_compile(
+            CreateIndex(idx), "CREATE CLUSTERED INDEX idx_x ON t (x)"
+        )
+
+    def test_index_reflection_filtered_and_clustered(
+        self, metadata, connection
+    ):
+        """
+        table with one filtered index and one clustered index so each index
+        will have different dialect_options keys
+        """
+        t1 = Table(
+            "t",
+            metadata,
+            Column("id", Integer),
+            Column("x", types.String(20)),
+            Column("y", types.Integer),
+        )
+        Index("idx_x", t1.c.x, mssql_clustered=True)
+        Index("idx_y", t1.c.y, mssql_where=t1.c.y >= 5)
+        metadata.create_all(connection)
+        ind = testing.db.dialect.get_indexes(connection, "t", None)
+
+        clustered_index = ""
+        for ix in ind:
+            if ix["dialect_options"]["mssql_clustered"]:
+                clustered_index = ix["name"]
+
+        eq_(clustered_index, "idx_x")
+
+        filtered_indexes = []
+        for ix in ind:
+            if "dialect_options" in ix:
+                if "mssql_where" in ix["dialect_options"]:
+                    filtered_indexes.append(
+                        ix["dialect_options"]["mssql_where"]
+                    )
+
+        eq_(sorted(filtered_indexes), ["([y]>=(5))"])
+
+        t2 = Table("t", MetaData(), autoload_with=connection)
+        clustered_idx = list(
+            sorted(t2.indexes, key=lambda clustered_idx: clustered_idx.name)
+        )[0]
+        filtered_idx = list(
+            sorted(t2.indexes, key=lambda filtered_idx: filtered_idx.name)
+        )[1]
+
+        self.assert_compile(
+            CreateIndex(clustered_idx), "CREATE CLUSTERED INDEX idx_x ON t (x)"
+        )
 
         self.assert_compile(
-            CreateIndex(idx), "CREATE INDEX idx_x ON t (x) WHERE ([x]='test')"
+            CreateIndex(filtered_idx),
+            "CREATE NONCLUSTERED INDEX idx_y ON t (y) WHERE ([y]>=(5))",
+        )
+
+    def test_index_reflection_nonclustered(self, metadata, connection):
+        """
+        one index created by specifying mssql_clustered=False
+        one created without specifying mssql_clustered property so it will
+        use default of NONCLUSTERED.
+        When reflected back mssql_clustered=False should be included in both
+        """
+        t1 = Table(
+            "t",
+            metadata,
+            Column("id", Integer),
+            Column("x", types.String(20)),
+            Column("y", types.Integer),
+        )
+        Index("idx_x", t1.c.x, mssql_clustered=False)
+        Index("idx_y", t1.c.y)
+        metadata.create_all(connection)
+        ind = testing.db.dialect.get_indexes(connection, "t", None)
+
+        for ix in ind:
+            assert ix["dialect_options"]["mssql_clustered"] == False
+
+        t2 = Table("t", MetaData(), autoload_with=connection)
+        idx = list(sorted(t2.indexes, key=lambda idx: idx.name))[0]
+
+        self.assert_compile(
+            CreateIndex(idx), "CREATE NONCLUSTERED INDEX idx_x ON t (x)"
+        )
+
+    def test_primary_key_reflection_clustered(self, metadata, connection):
+        """
+        A primary key will be clustered by default if no other clustered index
+        exists.
+        When reflected back, mssql_clustered=True should be present.
+        """
+        t1 = Table(
+            "t",
+            metadata,
+            Column("id", Integer),
+            Column("x", types.String(20)),
+            Column("y", types.Integer),
         )
+        PrimaryKeyConstraint(t1.c.id, name="pk_t")
+
+        metadata.create_all(connection)
+        pk_reflect = testing.db.dialect.get_pk_constraint(
+            connection, "t", None
+        )
+
+        assert pk_reflect["dialect_options"]["mssql_clustered"] == True
+
+    def test_primary_key_reflection_nonclustered(self, metadata, connection):
+        """
+        Nonclustered primary key should include mssql_clustered=False
+        when reflected back
+        """
+        t1 = Table(
+            "t",
+            metadata,
+            Column("id", Integer),
+            Column("x", types.String(20)),
+            Column("y", types.Integer),
+        )
+        PrimaryKeyConstraint(t1.c.id, name="pk_t", mssql_clustered=False)
+
+        metadata.create_all(connection)
+        pk_reflect = testing.db.dialect.get_pk_constraint(
+            connection, "t", None
+        )
+
+        assert pk_reflect["dialect_options"]["mssql_clustered"] == False
 
     def test_max_ident_in_varchar_not_present(self, metadata, connection):
         """test [ticket:3504].