]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add partial index predicate to SQLiteDialect.get_indexes() result
authorTobias Pfeiffer <tgp@preferred.jp>
Mon, 28 Nov 2022 12:52:31 +0000 (07:52 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Nov 2022 01:53:20 +0000 (20:53 -0500)
Added support for reflection of expression-oriented WHERE criteria included
in indexes on the SQLite dialect, in a manner similar to that of the
PostgreSQL dialect. Pull request courtesy Tobias Pfeiffer.

Fixes: #8804
Closes: #8806
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8806
Pull-request-sha: 539dfcb372360911b69aed2a804698bb1a2220b1

Change-Id: I0e34d47dbe2b9c1da6fce531363084843e5127a3
(cherry picked from commit ed39e846cd8ae2714c47fc3d563582f72483df0c)

doc/build/changelog/unreleased_14/8804.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/testing/suite/test_reflection.py
test/dialect/test_sqlite.py

diff --git a/doc/build/changelog/unreleased_14/8804.rst b/doc/build/changelog/unreleased_14/8804.rst
new file mode 100644 (file)
index 0000000..c3f91a1
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: usecase, sqlite
+    :tickets: 8804
+
+    Added support for reflection of expression-oriented WHERE criteria included
+    in indexes on the SQLite dialect, in a manner similar to that of the
+    PostgreSQL dialect. Pull request courtesy Tobias Pfeiffer.
index 612d8f9063257722f622d692f772499eb7fb6af6..24166717a410a9e3b7c18da858fd2051d1eced79 100644 (file)
@@ -821,6 +821,7 @@ from ... import exc
 from ... import processors
 from ... import schema as sa_schema
 from ... import sql
+from ... import text
 from ... import types as sqltypes
 from ... import util
 from ...engine import default
@@ -2474,6 +2475,21 @@ class SQLiteDialect(default.DefaultDialect):
         )
         indexes = []
 
+        # regular expression to extract the filter predicate of a partial
+        # index. this could fail to extract the predicate correctly on
+        # indexes created like
+        #   CREATE INDEX i ON t (col || ') where') WHERE col <> ''
+        # but as this function does not support expression-based indexes
+        # this case does not occur.
+        partial_pred_re = re.compile(r"\)\s+where\s+(.+)", re.IGNORECASE)
+
+        if schema:
+            schema_expr = "%s." % self.identifier_preparer.quote_identifier(
+                schema
+            )
+        else:
+            schema_expr = ""
+
         include_auto_indexes = kw.pop("include_auto_indexes", False)
         for row in pragma_indexes:
             # ignore implicit primary key index.
@@ -2482,7 +2498,38 @@ class SQLiteDialect(default.DefaultDialect):
                 "sqlite_autoindex"
             ):
                 continue
-            indexes.append(dict(name=row[1], column_names=[], unique=row[2]))
+            indexes.append(
+                dict(
+                    name=row[1],
+                    column_names=[],
+                    unique=row[2],
+                    dialect_options={},
+                )
+            )
+
+            # check partial indexes
+            if row[4]:
+                s = (
+                    "SELECT sql FROM %(schema)ssqlite_master "
+                    "WHERE name = ? "
+                    "AND type = 'index'" % {"schema": schema_expr}
+                )
+                rs = connection.exec_driver_sql(s, (row[1],))
+                index_sql = rs.scalar()
+                predicate_match = partial_pred_re.search(index_sql)
+                if predicate_match is None:
+                    # unless the regex is broken this case shouldn't happen
+                    # because we know this is a partial index, so the
+                    # definition sql should match the regex
+                    util.warn(
+                        "Failed to look up filter predicate of "
+                        "partial index %s" % row[1]
+                    )
+                else:
+                    predicate = predicate_match.group(1)
+                    indexes[-1]["dialect_options"]["sqlite_where"] = text(
+                        predicate
+                    )
 
         # loop thru unique indexes to get the column names.
         for idx in list(indexes):
@@ -2500,6 +2547,8 @@ class SQLiteDialect(default.DefaultDialect):
                     break
                 else:
                     idx["column_names"].append(row[2])
+
+        indexes.sort(key=lambda d: d["name"] or "~")  # sort None as last
         return indexes
 
     @reflection.cache
index 12949fe02bb16414216df341696ac398025d0d41..4e575046d37031d569b7dd87121d66a3c9fe2ecd 100644 (file)
@@ -1300,8 +1300,14 @@ class ComponentReflectionTestExtra(fixtures.TestBase):
         insp = inspect(connection)
 
         expected = [
-            {"name": "t_idx_2", "column_names": ["x"], "unique": False}
+            {
+                "name": "t_idx_2",
+                "column_names": ["x"],
+                "unique": False,
+                "dialect_options": {},
+            }
         ]
+
         if testing.requires.index_reflects_included_columns.enabled:
             expected[0]["include_columns"] = []
             expected[0]["dialect_options"] = {
@@ -1311,10 +1317,7 @@ class ComponentReflectionTestExtra(fixtures.TestBase):
         with expect_warnings(
             "Skipped unsupported reflection of expression-based index t_idx"
         ):
-            eq_(
-                insp.get_indexes("t"),
-                expected,
-            )
+            eq_(insp.get_indexes("t"), expected)
 
     @testing.requires.index_reflects_included_columns
     def test_reflect_covering_index(self, metadata, connection):
index 1f7a06dffb5c7995f66ee6667920dceccedec4df..3da4d6574b675161b97e95b77c383e47da71ac56 100644 (file)
@@ -2312,6 +2312,7 @@ class ConstraintReflectionTest(fixtures.TestBase):
                     "unique": 1,
                     "name": "sqlite_autoindex_o_1",
                     "column_names": ["foo"],
+                    "dialect_options": {},
                 }
             ],
         )
@@ -2327,10 +2328,60 @@ class ConstraintReflectionTest(fixtures.TestBase):
                     "unique": 0,
                     "name": u"ix_main_l_bar",
                     "column_names": [u"bar"],
+                    "dialect_options": {},
                 }
             ],
         )
 
+    def test_reflect_partial_indexes(self, connection):
+        connection.exec_driver_sql(
+            "create table foo_with_partial_index (x integer, y integer)"
+        )
+        connection.exec_driver_sql(
+            "create unique index ix_partial on "
+            "foo_with_partial_index (x) where y > 10"
+        )
+        connection.exec_driver_sql(
+            "create unique index ix_no_partial on "
+            "foo_with_partial_index (x)"
+        )
+        connection.exec_driver_sql(
+            "create unique index ix_partial2 on "
+            "foo_with_partial_index (x, y) where "
+            "y = 10 or abs(x) < 5"
+        )
+
+        inspector = inspect(connection)
+        indexes = inspector.get_indexes("foo_with_partial_index")
+        eq_(
+            indexes,
+            [
+                {
+                    "unique": 1,
+                    "name": "ix_no_partial",
+                    "column_names": ["x"],
+                    "dialect_options": {},
+                },
+                {
+                    "unique": 1,
+                    "name": "ix_partial",
+                    "column_names": ["x"],
+                    "dialect_options": {"sqlite_where": mock.ANY},
+                },
+                {
+                    "unique": 1,
+                    "name": "ix_partial2",
+                    "column_names": ["x", "y"],
+                    "dialect_options": {"sqlite_where": mock.ANY},
+                },
+            ],
+        )
+        eq_(indexes[1]["dialect_options"]["sqlite_where"].text, "y > 10")
+        eq_(
+            indexes[2]["dialect_options"]["sqlite_where"].text,
+            "y = 10 or abs(x) < 5",
+        )
+
     def test_unique_constraint_named(self):
         inspector = inspect(testing.db)
         eq_(