]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Correct for CTE correspondence w/ aliased CTE
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Mar 2021 23:38:10 +0000 (19:38 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Mar 2021 23:39:23 +0000 (19:39 -0400)
Fixed regression where the :func:`_orm.joinedload` loader strategy would
not successfully joinedload to a mapper that is mapper against a
:class:`.CTE` construct.

Fixes: #6172
Change-Id: I667e46d00d4209dab5a89171118a00a7c30fb542

doc/build/changelog/unreleased_14/6172.rst [new file with mode: 0644]
lib/sqlalchemy/sql/selectable.py
test/orm/test_eager_relations.py
test/orm/test_lazy_relations.py
test/orm/test_selectin_relations.py
test/orm/test_subquery_relations.py
test/sql/test_selectable.py

diff --git a/doc/build/changelog/unreleased_14/6172.rst b/doc/build/changelog/unreleased_14/6172.rst
new file mode 100644 (file)
index 0000000..ac5dc5c
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, regression, orm
+    :tickets: 6172
+
+    Fixed regression where the :func:`_orm.joinedload` loader strategy would
+    not successfully joinedload to a mapper that is mapper against a
+    :class:`.CTE` construct.
index 7c53f437c6b4b503388c0cda6174f2495e6ea137..a2e5780f8a50804a207542ec8dced55f1ee16784 100644 (file)
@@ -2057,6 +2057,12 @@ class CTE(Generative, HasPrefixes, HasSuffixes, AliasedReturnsRows):
             self._suffixes = _suffixes
         super(CTE, self)._init(selectable, name=name)
 
+    def _populate_column_collection(self):
+        if self._cte_alias is not None:
+            self._cte_alias._generate_fromclause_column_proxies(self)
+        else:
+            self.element._generate_fromclause_column_proxies(self)
+
     def alias(self, name=None, flat=False):
         """Return an :class:`_expression.Alias` of this
         :class:`_expression.CTE`.
index 4cb6932cbc86ed0dce7efe700f1931cdb53827c4..c0213cf2271d320481c0764bc8ae66ae4acf11f6 100644 (file)
@@ -1000,6 +1000,36 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
 
         self.assert_sql_count(testing.db, go, 1)
 
+    @testing.combinations(
+        ("plain",), ("cte", testing.requires.ctes), ("subquery",), id_="s"
+    )
+    def test_map_to_cte_subq(self, type_):
+        User, Address = self.classes("User", "Address")
+        users, addresses = self.tables("users", "addresses")
+
+        if type_ == "plain":
+            target = users
+        elif type_ == "cte":
+            target = select(users).cte()
+        elif type_ == "subquery":
+            target = select(users).subquery()
+
+        mapper(
+            User,
+            target,
+            properties={"addresses": relationship(Address, backref="user")},
+        )
+        mapper(Address, addresses)
+
+        sess = fixture_session()
+
+        q = (
+            sess.query(Address)
+            .options(joinedload(Address.user))
+            .order_by(Address.id)
+        )
+        eq_(q.all(), self.static.address_user_result)
+
     def test_no_false_hits(self):
         """Eager loaders don't interpret main table columns as
         part of their eager load."""
index 47b7a8ff86bd7bc9d4e9547a5da7e1ed869a59f0..a0b92a28a7624dd8c77c4fe1c008b43b63fdca4a 100644 (file)
@@ -664,6 +664,32 @@ class LazyTest(_fixtures.FixtureTest):
             .all(),
         )
 
+    @testing.combinations(
+        ("plain",), ("cte", testing.requires.ctes), ("subquery",), id_="s"
+    )
+    def test_map_to_cte_subq(self, type_):
+        User, Address = self.classes("User", "Address")
+        users, addresses = self.tables("users", "addresses")
+
+        if type_ == "plain":
+            target = users
+        elif type_ == "cte":
+            target = select(users).cte()
+        elif type_ == "subquery":
+            target = select(users).subquery()
+
+        mapper(
+            User,
+            target,
+            properties={"addresses": relationship(Address, backref="user")},
+        )
+        mapper(Address, addresses)
+
+        sess = fixture_session()
+
+        q = sess.query(Address).order_by(Address.id)
+        eq_(q.all(), self.static.address_user_result)
+
     def test_many_to_many(self):
         keywords, items, item_keywords, Keyword, Item = (
             self.tables.keywords,
index 4895c7d3a0780b4d64f5b26768e92e4348792966..21d5e827d97ab00f5f57de6be84c061f3a0fad6c 100644 (file)
@@ -1080,6 +1080,36 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         else:
             self.assert_sql_count(testing.db, go, 6)
 
+    @testing.combinations(
+        ("plain",), ("cte", testing.requires.ctes), ("subquery",), id_="s"
+    )
+    def test_map_to_cte_subq(self, type_):
+        User, Address = self.classes("User", "Address")
+        users, addresses = self.tables("users", "addresses")
+
+        if type_ == "plain":
+            target = users
+        elif type_ == "cte":
+            target = select(users).cte()
+        elif type_ == "subquery":
+            target = select(users).subquery()
+
+        mapper(
+            User,
+            target,
+            properties={"addresses": relationship(Address, backref="user")},
+        )
+        mapper(Address, addresses)
+
+        sess = fixture_session()
+
+        q = (
+            sess.query(Address)
+            .options(selectinload(Address.user))
+            .order_by(Address.id)
+        )
+        eq_(q.all(), self.static.address_user_result)
+
     def test_limit(self):
         """Limit operations combined with lazy-load relationships."""
 
index 150cee2225392411ecf423a4ef85fde352184709..5b1ac7df0b5cf1fd1a7bfa279b9b9d4acc5d8e49 100644 (file)
@@ -1117,6 +1117,36 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         else:
             self.assert_sql_count(testing.db, go, 6)
 
+    @testing.combinations(
+        ("plain",), ("cte", testing.requires.ctes), ("subquery",), id_="s"
+    )
+    def test_map_to_cte_subq(self, type_):
+        User, Address = self.classes("User", "Address")
+        users, addresses = self.tables("users", "addresses")
+
+        if type_ == "plain":
+            target = users
+        elif type_ == "cte":
+            target = select(users).cte()
+        elif type_ == "subquery":
+            target = select(users).subquery()
+
+        mapper(
+            User,
+            target,
+            properties={"addresses": relationship(Address, backref="user")},
+        )
+        mapper(Address, addresses)
+
+        sess = fixture_session()
+
+        q = (
+            sess.query(Address)
+            .options(subqueryload(Address.user))
+            .order_by(Address.id)
+        )
+        eq_(q.all(), self.static.address_user_result)
+
     def test_limit(self):
         """Limit operations combined with lazy-load relationships."""
 
index 458b8f7822fc64d1687b762caf213b92dba4a153..b98487933cd346bdef1cf10075611a6f0556df29 100644 (file)
@@ -161,6 +161,46 @@ class SelectableTest(
             s1, "SELECT (SELECT table1.col1 FROM table1) AS foo"
         )
 
+    @testing.combinations(("cte",), ("subquery",), argnames="type_")
+    @testing.combinations(
+        ("onelevel",), ("twolevel",), ("middle",), argnames="path"
+    )
+    @testing.combinations((True,), (False,), argnames="require_embedded")
+    def test_subquery_cte_correspondence(self, type_, require_embedded, path):
+        stmt = select(table1)
+
+        if type_ == "cte":
+            cte1 = stmt.cte()
+        elif type_ == "subquery":
+            cte1 = stmt.subquery()
+
+        if path == "onelevel":
+            is_(
+                cte1.corresponding_column(
+                    table1.c.col1, require_embedded=require_embedded
+                ),
+                cte1.c.col1,
+            )
+        elif path == "twolevel":
+            cte2 = cte1.alias()
+
+            is_(
+                cte2.corresponding_column(
+                    table1.c.col1, require_embedded=require_embedded
+                ),
+                cte2.c.col1,
+            )
+
+        elif path == "middle":
+            cte2 = cte1.alias()
+
+            is_(
+                cte2.corresponding_column(
+                    cte1.c.col1, require_embedded=require_embedded
+                ),
+                cte2.c.col1,
+            )
+
     def test_labels_anon_w_separate_key(self):
         label = select(table1.c.col1).label(None)
         label.key = "bar"