]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixed compile for mssql dialect
authorAdiorz <adiorz90@gmail.com>
Mon, 7 Dec 2020 15:03:53 +0000 (15:03 +0000)
committerAdiorz <adiorz90@gmail.com>
Mon, 7 Dec 2020 17:08:30 +0000 (17:08 +0000)
Fixed string compilation when both mssql_include and mssql_where are used

Fixes: #5751
lib/sqlalchemy/dialects/mssql/base.py
test/dialect/mssql/test_compiler.py

index 9addbf31fd6564e5d2d607814273294a1da9b90f..911e1791aea85691d90522d9e86bf54819a465e4 100644 (file)
@@ -2286,18 +2286,6 @@ class MSDDLCompiler(compiler.DDLCompiler):
             ),
         )
 
-        whereclause = index.dialect_options["mssql"]["where"]
-
-        if whereclause is not None:
-            whereclause = coercions.expect(
-                roles.DDLExpressionRole, whereclause
-            )
-
-            where_compiled = self.sql_compiler.process(
-                whereclause, include_table=False, literal_binds=True
-            )
-            text += " WHERE " + where_compiled
-
         # handle other included columns
         if index.dialect_options["mssql"]["include"]:
             inclusions = [
@@ -2311,6 +2299,18 @@ class MSDDLCompiler(compiler.DDLCompiler):
                 [preparer.quote(c.name) for c in inclusions]
             )
 
+        whereclause = index.dialect_options["mssql"]["where"]
+
+        if whereclause is not None:
+            whereclause = coercions.expect(
+                roles.DDLExpressionRole, whereclause
+            )
+
+            where_compiled = self.sql_compiler.process(
+                whereclause, include_table=False, literal_binds=True
+            )
+            text += " WHERE " + where_compiled
+
         return text
 
     def visit_drop_index(self, drop):
index 568d346f5c91086f11b47271d5d41aae29679e44..eea401189903c5b226d2c5b4cf6f1460505098ca 100644 (file)
@@ -1281,6 +1281,29 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             schema.CreateIndex(idx), "CREATE INDEX foo ON test (x) INCLUDE (y)"
         )
 
+    def test_index_include_where(self):
+        metadata = MetaData()
+        tbl = Table(
+            "test",
+            metadata,
+            Column("x", Integer),
+            Column("y", Integer),
+            Column("z", Integer),
+        )
+        idx = Index("foo", tbl.c.x, mssql_include=[tbl.c.y],
+                    mssql_where=tbl.c.y > 1)
+        self.assert_compile(
+            schema.CreateIndex(idx),
+            "CREATE INDEX foo ON test (x) INCLUDE (y) WHERE y > 1"
+        )
+
+        idx = Index("foo", tbl.c.x, mssql_include=[tbl.c.y],
+                    mssql_where="y > 1")
+        self.assert_compile(
+            schema.CreateIndex(idx),
+            "CREATE INDEX foo ON test (x) INCLUDE (y) WHERE y > 1"
+        )
+
     def test_try_cast(self):
         metadata = MetaData()
         t1 = Table("t1", metadata, Column("id", Integer, primary_key=True))