]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Optimize MySQL foreign key reflection
authorFederico Caselli <cfederico87@gmail.com>
Sat, 12 Oct 2024 12:58:26 +0000 (14:58 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Sat, 12 Oct 2024 12:58:26 +0000 (14:58 +0200)
Improved foreign keys reflection logic in MySQL 8+ to use a better
optimized query. The previous query could be quite slow in databases
with a large number of columns.

Fixes: #11975
Change-Id: Ie8bcd810d4b37abf7fd5e497596e0ade52c3f82e

doc/build/changelog/unreleased_20/11975.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
test/dialect/mysql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_20/11975.rst b/doc/build/changelog/unreleased_20/11975.rst
new file mode 100644 (file)
index 0000000..708a23a
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: mysql, performance
+    :tickets: 11975
+
+    Improved foreign keys reflection logic in MySQL 8+ to use a better
+    optimized query. The previous query could be quite slow in databases
+    with a large number of columns.
index b2b8c6536a73a048ac5b8ee3d35fb726e7087345..c834495759e9c3f5c04f878df243fe8456dfaf78 100644 (file)
@@ -3070,29 +3070,47 @@ class MySQLDialect(default.DefaultDialect):
                 return s
 
         default_schema_name = connection.dialect.default_schema_name
-        col_tuples = [
-            (
-                lower(rec["referred_schema"] or default_schema_name),
-                lower(rec["referred_table"]),
-                col_name,
-            )
-            for rec in fkeys
-            for col_name in rec["referred_columns"]
-        ]
 
-        if col_tuples:
-            correct_for_wrong_fk_case = connection.execute(
-                sql.text(
-                    """
-                    select table_schema, table_name, column_name
-                    from information_schema.columns
-                    where (table_schema, table_name, lower(column_name)) in
-                    :table_data;
-                """
-                ).bindparams(sql.bindparam("table_data", expanding=True)),
-                dict(table_data=col_tuples),
+        # NOTE: using (table_schema, table_name, lower(column_name)) in (...)
+        # is very slow since mysql does not seem able to properly use indexse.
+        # Unpack the where condition instead.
+        schema_by_table_by_column = defaultdict(lambda: defaultdict(list))
+        for rec in fkeys:
+            sch = lower(rec["referred_schema"] or default_schema_name)
+            tbl = lower(rec["referred_table"])
+            for col_name in rec["referred_columns"]:
+                schema_by_table_by_column[sch][tbl].append(col_name)
+
+        if schema_by_table_by_column:
+
+            condition = sql.or_(
+                *(
+                    sql.and_(
+                        _info_columns.c.table_schema == schema,
+                        sql.or_(
+                            *(
+                                sql.and_(
+                                    _info_columns.c.table_name == table,
+                                    sql.func.lower(
+                                        _info_columns.c.column_name
+                                    ).in_(columns),
+                                )
+                                for table, columns in tables.items()
+                            )
+                        ),
+                    )
+                    for schema, tables in schema_by_table_by_column.items()
+                )
             )
 
+            select = sql.select(
+                _info_columns.c.table_schema,
+                _info_columns.c.table_name,
+                _info_columns.c.column_name,
+            ).where(condition)
+
+            correct_for_wrong_fk_case = connection.execute(select)
+
             # in casing=0, table name and schema name come back in their
             # exact case.
             # in casing=1, table name and schema name come back in lower
@@ -3465,3 +3483,12 @@ class _DecodingRow:
             return item.decode(self.charset)
         else:
             return item
+
+
+_info_columns = sql.table(
+    "columns",
+    sql.column("table_schema", VARCHAR(64)),
+    sql.column("table_name", VARCHAR(64)),
+    sql.column("column_name", VARCHAR(64)),
+    schema="information_schema",
+)
index 4fa472ce1ae7cba0faa280ca9c64726cd5cfdfbb..92cf3818e2477f5ecf2347ba74fb09cacfeb2741 100644 (file)
@@ -1197,7 +1197,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect._casing = casing
             dialect.default_schema_name = "Test"
             connection = mock.Mock(
-                dialect=dialect, execute=lambda stmt, params: ischema
+                dialect=dialect, execute=lambda stmt: ischema
             )
             dialect._correct_for_mysql_bugs_88718_96365(fkeys, connection)
             eq_(