]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Reflect mssql/postgresql filtered/partial indexes
authorRamonWill <ramonwilliams@hotmail.co.uk>
Thu, 20 Aug 2020 19:05:39 +0000 (15:05 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 12 Sep 2020 14:30:43 +0000 (10:30 -0400)
Added support for inspection / reflection of partial indexes / filtered
indexes, i.e. those which use the ``mssql_where`` or ``postgresql_where``
parameters, with :class:`_schema.Index`.   The entry is both part of the
dictionary returned by :meth:`.Inspector.get_indexes` as well as part of a
reflected :class:`_schema.Index` construct that was reflected.  Pull
request courtesy Ramon Williams.

**Have a nice day!**
Fixes: #4966
Closes: #5504
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5504
Pull-request-sha: b3018bac987081193b2e65cfdb6aeb7d5d270fcd

Change-Id: Icbb2f93d1545700718ccb5222097185b815f5dbc

doc/build/changelog/unreleased_14/4966.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/mssql/test_compiler.py
test/dialect/mssql/test_reflection.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_14/4966.rst b/doc/build/changelog/unreleased_14/4966.rst
new file mode 100644 (file)
index 0000000..e24ccbf
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: usecase, mssql, postgresql
+    :tickets: 4966
+
+    Added support for inspection / reflection of partial indexes / filtered
+    indexes, i.e. those which use the ``mssql_where`` or ``postgresql_where``
+    parameters, with :class:`_schema.Index`.   The entry is both part of the
+    dictionary returned by :meth:`.Inspector.get_indexes` as well as part of a
+    reflected :class:`_schema.Index` construct that was reflected.  Pull
+    request courtesy Ramon Williams.
index ab6e19cf4152a13e166fea50a86d02be4bcfeaef..e3d16f3f34f1af3c657048f6660e6e8e7b742469 100644 (file)
@@ -728,11 +728,13 @@ from ... import util
 from ...engine import cursor as _cursor
 from ...engine import default
 from ...engine import reflection
+from ...sql import coercions
 from ...sql import compiler
 from ...sql import elements
 from ...sql import expression
 from ...sql import func
 from ...sql import quoted_name
+from ...sql import roles
 from ...sql import util as sql_util
 from ...sql.type_api import to_instance
 from ...types import BIGINT
@@ -2205,6 +2207,10 @@ class MSDDLCompiler(compiler.DDLCompiler):
         whereclause = index.dialect_options["mssql"]["where"]
 
         if whereclause is not None:
+            whereclause = coercions.expect(
+                roles.DDLExpressionRole, whereclause
+            )
+
             where_compiled = self.sql_compiler.process(
                 whereclause, include_table=False, literal_binds=True
             )
@@ -2785,7 +2791,8 @@ class MSDialect(default.DefaultDialect):
 
         rp = connection.execution_options(future_result=True).execute(
             sql.text(
-                "select ind.index_id, ind.is_unique, ind.name "
+                "select ind.index_id, ind.is_unique, ind.name, "
+                "ind.filter_definition "
                 "from sys.indexes as ind join sys.tables as tab on "
                 "ind.object_id=tab.object_id "
                 "join sys.schemas as sch on sch.schema_id=tab.schema_id "
@@ -2806,6 +2813,12 @@ class MSDialect(default.DefaultDialect):
                 "unique": row["is_unique"] == 1,
                 "column_names": [],
             }
+
+            if row["filter_definition"] is not None:
+                indexes[row["index_id"]].setdefault("dialect_options", {})[
+                    "mssql_where"
+                ] = row["filter_definition"]
+
         rp = connection.execution_options(future_result=True).execute(
             sql.text(
                 "select ind_col.index_id, ind_col.object_id, col.name "
index c56cccd8dbbcc6ac8e151b48707dcdb88ea99847..4095114168b57158517c39c17f8c165ed0faadb5 100644 (file)
@@ -2131,6 +2131,10 @@ class PGDDLCompiler(compiler.DDLCompiler):
         whereclause = index.dialect_options["postgresql"]["where"]
 
         if whereclause is not None:
+            whereclause = coercions.expect(
+                roles.DDLExpressionRole, whereclause
+            )
+
             where_compiled = self.sql_compiler.process(
                 whereclause, include_table=False, literal_binds=True
             )
@@ -3459,9 +3463,10 @@ class PGDialect(default.DefaultDialect):
             IDX_SQL = """
               SELECT
                   i.relname as relname,
-                  ix.indisunique, ix.indexprs, ix.indpred,
+                  ix.indisunique, ix.indexprs,
                   a.attname, a.attnum, c.conrelid, ix.indkey::varchar,
                   ix.indoption::varchar, i.reloptions, am.amname,
+                  pg_get_expr(ix.indpred, ix.indrelid),
                   %s as indnkeyatts
               FROM
                   pg_class t
@@ -3504,7 +3509,6 @@ class PGDialect(default.DefaultDialect):
                 idx_name,
                 unique,
                 expr,
-                prd,
                 col,
                 col_num,
                 conrelid,
@@ -3512,6 +3516,7 @@ class PGDialect(default.DefaultDialect):
                 idx_option,
                 options,
                 amname,
+                filter_definition,
                 indnkeyatts,
             ) = row
 
@@ -3524,13 +3529,6 @@ class PGDialect(default.DefaultDialect):
                 sv_idx_name = idx_name
                 continue
 
-            if prd and not idx_name == sv_idx_name:
-                util.warn(
-                    "Predicate of partial index %s ignored during reflection"
-                    % idx_name
-                )
-                sv_idx_name = idx_name
-
             has_idx = idx_name in indexes
             index = indexes[idx_name]
             if col is not None:
@@ -3586,6 +3584,9 @@ class PGDialect(default.DefaultDialect):
                 if amname and amname != "btree":
                     index["amname"] = amname
 
+                if filter_definition:
+                    index["postgresql_where"] = filter_definition
+
         result = []
         for name, idx in indexes.items():
             entry = {
@@ -3608,6 +3609,10 @@ class PGDialect(default.DefaultDialect):
                 entry.setdefault("dialect_options", {})[
                     "postgresql_using"
                 ] = idx["amname"]
+            if "postgresql_where" in idx:
+                entry.setdefault("dialect_options", {})[
+                    "postgresql_where"
+                ] = idx["postgresql_where"]
             result.append(entry)
         return result
 
index 83a6108882eb8fd926787a0c5f2d73e5e93b915c..bedbe31c19e1cb841fb81a8e9c845a8774f19196 100644 (file)
@@ -1299,6 +1299,12 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "CREATE INDEX test_idx_data_1 ON test (data) WHERE data > 1",
         )
 
+        idx = Index("test_idx_data_1", tbl.c.data, mssql_where="data > 1")
+        self.assert_compile(
+            schema.CreateIndex(idx),
+            "CREATE INDEX test_idx_data_1 ON test (data) WHERE data > 1",
+        )
+
     def test_index_ordering(self):
         metadata = MetaData()
         tbl = Table(
index 176d3d2ecbaa79daac0189c2988889258898869d..0e1ee89661d4b126ec993a2abda6b2735d5dafcd 100644 (file)
@@ -17,6 +17,7 @@ from sqlalchemy.dialects import mssql
 from sqlalchemy.dialects.mssql import base
 from sqlalchemy.dialects.mssql.information_schema import CoerceUnicode
 from sqlalchemy.dialects.mssql.information_schema import tables
+from sqlalchemy.schema import CreateIndex
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import ComparesTables
 from sqlalchemy.testing import eq_
@@ -319,6 +320,36 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL):
 
         eq_(set(list(t2.indexes)[0].columns), set([t2.c["x col"], t2.c.y]))
 
+    @testing.provide_metadata
+    def test_indexes_with_filtered(self, connection):
+        metadata = self.metadata
+
+        t1 = Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("x", types.String(20)),
+            Column("y", types.Integer),
+        )
+        Index("idx_x", t1.c.x, mssql_where=t1.c.x == "test")
+        Index("idx_y", t1.c.y, mssql_where=t1.c.y >= 5)
+        metadata.create_all(connection)
+        ind = testing.db.dialect.get_indexes(connection, "t", None)
+
+        filtered_indexes = []
+        for ix in ind:
+            if "dialect_options" in ix:
+                filtered_indexes.append(ix["dialect_options"]["mssql_where"])
+
+        eq_(sorted(filtered_indexes), ["([x]='test')", "([y]>=(5))"])
+
+        t2 = Table("t", MetaData(), autoload_with=connection)
+        idx = list(sorted(t2.indexes, key=lambda idx: idx.name))[0]
+
+        self.assert_compile(
+            CreateIndex(idx), "CREATE INDEX idx_x ON t (x) WHERE ([x]='test')"
+        )
+
     @testing.provide_metadata
     def test_max_ident_in_varchar_not_present(self):
         """test [ticket:3504].
index 708dbe147a10570bf9c1b5b74d1abebbd3cd7a95..517d570c947b1c9cfcfef7510e0c67c12d20d169 100644 (file)
@@ -452,7 +452,20 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=postgresql.dialect(),
         )
 
+        idx3 = Index(
+            "test_idx2",
+            tbl.c.data,
+            postgresql_where="data > 'a' AND data < 'b''s'",
+        )
+        self.assert_compile(
+            schema.CreateIndex(idx3),
+            "CREATE INDEX test_idx2 ON testtbl (data) "
+            "WHERE data > 'a' AND data < 'b''s'",
+            dialect=postgresql.dialect(),
+        )
+
     def test_create_index_with_ops(self):
+
         m = MetaData()
         tbl = Table(
             "testtbl",
index ec9328c2fb1ea45a9fd2809329cfdbf9dd8bbb5c..2a214c7666ce6e7b51d1038de1ba07a333ffe5f7 100644 (file)
@@ -25,7 +25,9 @@ from sqlalchemy.dialects.postgresql import ExcludeConstraint
 from sqlalchemy.dialects.postgresql import INTEGER
 from sqlalchemy.dialects.postgresql import INTERVAL
 from sqlalchemy.dialects.postgresql import TSRANGE
+from sqlalchemy.schema import CreateIndex
 from sqlalchemy.sql.schema import CheckConstraint
+from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertions import assert_raises
@@ -419,7 +421,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             base.PGDialect.ischema_names = ischema_names
 
 
-class ReflectionTest(fixtures.TestBase):
+class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
     __only_on__ = "postgresql"
     __backend__ = True
 
@@ -880,7 +882,7 @@ class ReflectionTest(fixtures.TestBase):
 
     @testing.provide_metadata
     def test_index_reflection(self):
-        """ Reflecting partial & expression-based indexes should warn
+        """ Reflecting expression-based indexes should warn
         """
 
         metadata = self.metadata
@@ -926,12 +928,53 @@ class ReflectionTest(fixtures.TestBase):
             [
                 "Skipped unsupported reflection of "
                 "expression-based index idx1",
-                "Predicate of partial index idx2 ignored during " "reflection",
                 "Skipped unsupported reflection of "
                 "expression-based index idx3",
             ],
         )
 
+    @testing.provide_metadata
+    def test_index_reflection_partial(self, connection):
+        """Reflect the filter defintion on partial indexes
+        """
+
+        metadata = self.metadata
+
+        t1 = Table(
+            "table1",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(20)),
+            Column("x", Integer),
+        )
+        Index("idx1", t1.c.id, postgresql_where=t1.c.name == "test")
+        Index("idx2", t1.c.id, postgresql_where=t1.c.x >= 5)
+
+        metadata.create_all(connection)
+
+        ind = testing.db.dialect.get_indexes(connection, t1, None)
+
+        partial_definitions = []
+        for ix in ind:
+            if "dialect_options" in ix:
+                partial_definitions.append(
+                    ix["dialect_options"]["postgresql_where"]
+                )
+
+        eq_(
+            sorted(partial_definitions),
+            ["((name)::text = 'test'::text)", "(x >= 5)"],
+        )
+
+        t2 = Table("table1", MetaData(), autoload_with=connection)
+        idx = list(sorted(t2.indexes, key=lambda idx: idx.name))[0]
+
+        self.assert_compile(
+            CreateIndex(idx),
+            "CREATE INDEX idx1 ON table1 (id) "
+            "WHERE ((name)::text = 'test'::text)",
+        )
+
     @testing.fails_if("postgresql < 8.3", "index ordering not supported")
     @testing.provide_metadata
     def test_index_reflection_with_sorting(self):