]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Extract table names when comparing to nrte error
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Feb 2021 15:43:16 +0000 (10:43 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Feb 2021 15:58:06 +0000 (10:58 -0500)
Fixed issue where the process of joining two tables could fail if one of
the tables had an unrelated, unresolvable foreign key constraint which
would raise :class:`_exc.NoReferenceError` within the join process, which
nonetheless could be bypassed to allow the join to complete. The logic
which tested the exception for signficance within the process would make
assumptions about the construct which would fail.

Fixes: #5952
Change-Id: I492dacd082ddcf8abb1310ed447a6ed734595bb7

doc/build/changelog/unreleased_13/5952.rst [new file with mode: 0644]
lib/sqlalchemy/sql/selectable.py
test/sql/test_selectable.py

diff --git a/doc/build/changelog/unreleased_13/5952.rst b/doc/build/changelog/unreleased_13/5952.rst
new file mode 100644 (file)
index 0000000..7166e92
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 5952
+
+    Fixed issue where the process of joining two tables could fail if one of
+    the tables had an unrelated, unresolvable foreign key constraint which
+    would raise :class:`_exc.NoReferenceError` within the join process, which
+    nonetheless could be bypassed to allow the join to complete. The logic
+    which tested the exception for signficance within the process would make
+    assumptions about the construct which would fail.
+
index ee2c4dafc59bd40bb8bf46860dddbaeb6b753a00..7c53f437c6b4b503388c0cda6174f2495e6ea137 100644 (file)
@@ -1229,9 +1229,12 @@ class Join(roles.DMLTableRole, FromClause):
         return bool(constraints)
 
     @classmethod
+    @util.preload_module("sqlalchemy.sql.util")
     def _joincond_scan_left_right(
         cls, a, a_subset, b, consider_as_foreign_keys
     ):
+        sql_util = util.preloaded.sql_util
+
         a = coercions.expect(roles.FromClauseRole, a)
         b = coercions.expect(roles.FromClauseRole, b)
 
@@ -1251,7 +1254,8 @@ class Join(roles.DMLTableRole, FromClause):
                 try:
                     col = fk.get_referent(left)
                 except exc.NoReferenceError as nrte:
-                    if nrte.table_name == left.name:
+                    table_names = {t.name for t in sql_util.find_tables(left)}
+                    if nrte.table_name in table_names:
                         raise
                     else:
                         continue
@@ -1270,7 +1274,8 @@ class Join(roles.DMLTableRole, FromClause):
                     try:
                         col = fk.get_referent(b)
                     except exc.NoReferenceError as nrte:
-                        if nrte.table_name == b.name:
+                        table_names = {t.name for t in sql_util.find_tables(b)}
+                        if nrte.table_name in table_names:
                             raise
                         else:
                             continue
index 9f0c72247c0b52b893c4efbb9e6962a3c1a3cc0b..762146a3d30adacff2e23b4f49ba056cb7267647 100644 (file)
@@ -1834,6 +1834,55 @@ class JoinConditionTest(fixtures.TestBase, AssertsCompiledSQL):
         assert sql_util.join_condition(t1, t2).compare(t1.c.x == t2.c.id)
         assert sql_util.join_condition(t2, t1).compare(t1.c.x == t2.c.id)
 
+    def test_join_cond_no_such_unrelated_table_dont_compare_names(self):
+        m = MetaData()
+        t1 = Table(
+            "t1",
+            m,
+            Column("y", Integer, ForeignKey("t22.id")),
+            Column("x", Integer, ForeignKey("t2.id")),
+            Column("q", Integer, ForeignKey("t22.id")),
+        )
+        t2 = Table(
+            "t2",
+            m,
+            Column("id", Integer),
+            Column("t3id", ForeignKey("t3.id")),
+            Column("z", ForeignKey("t33.id")),
+        )
+        t3 = Table(
+            "t3", m, Column("id", Integer), Column("q", ForeignKey("t4.id"))
+        )
+
+        j1 = t1.join(t2)
+
+        assert sql_util.join_condition(j1, t3).compare(t2.c.t3id == t3.c.id)
+
+    def test_join_cond_no_such_unrelated_column_dont_compare_names(self):
+        m = MetaData()
+        t1 = Table(
+            "t1",
+            m,
+            Column("x", Integer, ForeignKey("t2.id")),
+        )
+        t2 = Table(
+            "t2",
+            m,
+            Column("id", Integer),
+            Column("t3id", ForeignKey("t3.id")),
+            Column("q", ForeignKey("t5.q")),
+        )
+        t3 = Table(
+            "t3", m, Column("id", Integer), Column("t4id", ForeignKey("t4.id"))
+        )
+        t4 = Table("t4", m, Column("id", Integer))
+        Table("t5", m, Column("id", Integer))
+        j1 = t1.join(t2)
+
+        j2 = t3.join(t4)
+
+        assert sql_util.join_condition(j1, j2).compare(t2.c.t3id == t3.c.id)
+
     def test_join_cond_no_such_related_table(self):
         m1 = MetaData()
         m2 = MetaData()