]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure _ORMJoin transfers parententity from left side
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 27 Oct 2022 02:59:51 +0000 (22:59 -0400)
committermike bayer <mike_mp@zzzcomputing.com>
Fri, 28 Oct 2022 16:09:20 +0000 (16:09 +0000)
Fixed bug involving :class:`.Select` constructs which used a combination of
:meth:`.Select.select_from` with an ORM entity followed by
:meth:`.Select.join` against the entity sent in
:meth:`.Select.select_from`, as well as using plain
:meth:`.Select.join_from`, which when combined with a columns clause that
didn't explicitly include that entity would then cause "automatic WHERE
criteria" features such as the IN expression required for a single-table
inheritance subclass, as well as the criteria set up by the
:func:`_orm.with_loader_criteria` option, to not be rendered for that
entity. The correct entity is now transferred to the :class:`.Join` object
that's generated internally, so that the criteria against the left
side entity is correctly added.

Fixes: #8721
Change-Id: I8266430063e2c72071b7262fdd5ec5079fbcba3e

doc/build/changelog/unreleased_14/8721.rst [new file with mode: 0644]
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/util.py
test/orm/inheritance/test_single.py
test/orm/test_relationship_criteria.py

diff --git a/doc/build/changelog/unreleased_14/8721.rst b/doc/build/changelog/unreleased_14/8721.rst
new file mode 100644 (file)
index 0000000..e6d7f4b
--- /dev/null
@@ -0,0 +1,17 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 8721
+
+    Fixed bug involving :class:`.Select` constructs which used a combination of
+    :meth:`.Select.select_from` with an ORM entity followed by
+    :meth:`.Select.join` against the entity sent in
+    :meth:`.Select.select_from`, as well as using plain
+    :meth:`.Select.join_from`, which when combined with a columns clause that
+    didn't explicitly include that entity would then cause "automatic WHERE
+    criteria" features such as the IN expression required for a single-table
+    inheritance subclass, as well as the criteria set up by the
+    :func:`_orm.with_loader_criteria` option, to not be rendered for that
+    entity. The correct entity is now transferred to the :class:`.Join` object
+    that's generated internally, so that the criteria against the left
+    side entity is correctly added.
+
index f0013259339bc529e1336cfd0894d06009d63d44..30119d9d79e4813ed83d90e648cd342b42fea807 100644 (file)
@@ -2282,8 +2282,16 @@ def join(
                 join(User.addresses).\
                 filter(Address.email_address=='foo@bar.com')
 
-    See :ref:`orm_queryguide_joins` for information on modern usage
-    of ORM level joins.
+    .. warning:: using :func:`_orm.join` directly may not work properly
+       with modern ORM options such as :func:`_orm.with_loader_criteria`.
+       It is strongly recommended to use the idiomatic join patterns
+       provided by methods such as :meth:`.Select.join` and
+       :meth:`.Select.join_from` when creating ORM joins.
+
+    .. seealso::
+
+        :ref:`orm_queryguide_joins` - in the :ref:`queryguide_toplevel` for
+        background on idiomatic ORM join patterns
 
     """
     return _ORMJoin(left, right, onclause, isouter, full)
index f8c7ba7143efdb8dd16166e8d911767fd505c9f9..8dca8375690de71ae16b8d9f3120b612e867d117 100644 (file)
@@ -2246,6 +2246,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
         for fromclause in self.from_clauses:
             ext_info = fromclause._annotations.get("parententity", None)
+
             if (
                 ext_info
                 and (
index 481a71f8e908e78b0f5bd0846fba5c97b015e5a3..3302feb70a807bfb1744bc784986f85f80306993 100644 (file)
@@ -68,6 +68,7 @@ from ..sql import lambdas
 from ..sql import roles
 from ..sql import util as sql_util
 from ..sql import visitors
+from ..sql._typing import is_selectable
 from ..sql.annotation import SupportsCloneAnnotations
 from ..sql.base import ColumnCollection
 from ..sql.cache_key import HasCacheKey
@@ -1704,6 +1705,24 @@ class _ORMJoin(expression.Join):
 
             self._target_adapter = target_adapter
 
+            # we don't use the normal coercions logic for _ORMJoin
+            # (probably should), so do some gymnastics to get the entity.
+            # logic here is for #8721, which was a major bug in 1.4
+            # for almost two years, not reported/fixed until 1.4.43 (!)
+            if is_selectable(left_info):
+                parententity = left_selectable._annotations.get(
+                    "parententity", None
+                )
+            elif insp_is_mapper(left_info) or insp_is_aliased_class(left_info):
+                parententity = left_info
+            else:
+                parententity = None
+
+            if parententity is not None:
+                self._annotations = self._annotations.union(
+                    {"parententity": parententity}
+                )
+
         augment_onclause = onclause is None and _extra_criteria
         expression.Join.__init__(self, left, right, onclause, isouter, full)
 
index 86422000f7f014bb5b083c5e8c0e4941dcd4d01e..2384d7e2da2ddacc02f05c0218b9487829b7da9a 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy import true
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import Bundle
 from sqlalchemy.orm import column_property
+from sqlalchemy.orm import join as orm_join
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
@@ -366,6 +367,124 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest):
             "WHERE employees_1.type IN (__[POSTCOMPILE_type_1])",
         )
 
+    @testing.combinations(
+        (
+            lambda Engineer, Report: select(Report)
+            .select_from(Engineer)
+            .join(Engineer.reports),
+        ),
+        (
+            lambda Engineer, Report: select(Report).select_from(
+                orm_join(Engineer, Report, Engineer.reports)
+            ),
+        ),
+        (
+            lambda Engineer, Report: select(Report).join_from(
+                Engineer, Report, Engineer.reports
+            ),
+        ),
+        argnames="stmt_fn",
+    )
+    @testing.combinations(True, False, argnames="alias_engineer")
+    def test_select_from_w_join_left(self, stmt_fn, alias_engineer):
+        """test #8721"""
+
+        Engineer = self.classes.Engineer
+        Report = self.classes.Report
+
+        if alias_engineer:
+            Engineer = aliased(Engineer)
+        stmt = testing.resolve_lambda(
+            stmt_fn, Engineer=Engineer, Report=Report
+        )
+
+        if alias_engineer:
+            self.assert_compile(
+                stmt,
+                "SELECT reports.report_id, reports.employee_id, reports.name "
+                "FROM employees AS employees_1 JOIN reports "
+                "ON employees_1.employee_id = reports.employee_id "
+                "WHERE employees_1.type IN (__[POSTCOMPILE_type_1])",
+            )
+        else:
+            self.assert_compile(
+                stmt,
+                "SELECT reports.report_id, reports.employee_id, reports.name "
+                "FROM employees JOIN reports ON employees.employee_id = "
+                "reports.employee_id "
+                "WHERE employees.type IN (__[POSTCOMPILE_type_1])",
+            )
+
+    @testing.combinations(
+        (
+            lambda Engineer, Report: select(
+                Report.report_id, Engineer.employee_id
+            )
+            .select_from(Engineer)
+            .join(Engineer.reports),
+        ),
+        (
+            lambda Engineer, Report: select(
+                Report.report_id, Engineer.employee_id
+            ).select_from(orm_join(Engineer, Report, Engineer.reports)),
+        ),
+        (
+            lambda Engineer, Report: select(
+                Report.report_id, Engineer.employee_id
+            ).join_from(Engineer, Report, Engineer.reports),
+        ),
+    )
+    def test_select_from_w_join_left_including_entity(self, stmt_fn):
+        """test #8721"""
+
+        Engineer = self.classes.Engineer
+        Report = self.classes.Report
+        stmt = testing.resolve_lambda(
+            stmt_fn, Engineer=Engineer, Report=Report
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT reports.report_id, employees.employee_id "
+            "FROM employees JOIN reports ON employees.employee_id = "
+            "reports.employee_id "
+            "WHERE employees.type IN (__[POSTCOMPILE_type_1])",
+        )
+
+    @testing.combinations(
+        (
+            lambda Engineer, Report: select(Report).join(
+                Report.employee.of_type(Engineer)
+            ),
+        ),
+        (
+            lambda Engineer, Report: select(Report).select_from(
+                orm_join(Report, Engineer, Report.employee.of_type(Engineer))
+            )
+        ),
+        (
+            lambda Engineer, Report: select(Report).join_from(
+                Report, Engineer, Report.employee.of_type(Engineer)
+            ),
+        ),
+    )
+    def test_select_from_w_join_right(self, stmt_fn):
+        """test #8721"""
+
+        Engineer = self.classes.Engineer
+        Report = self.classes.Report
+        stmt = testing.resolve_lambda(
+            stmt_fn, Engineer=Engineer, Report=Report
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT reports.report_id, reports.employee_id, reports.name "
+            "FROM reports JOIN employees ON employees.employee_id = "
+            "reports.employee_id AND employees.type "
+            "IN (__[POSTCOMPILE_type_1])",
+        )
+
     def test_from_statement_select(self):
         Engineer = self.classes.Engineer
 
index baa19e31ee0243ba028893c1018395eac4ce9fbc..f50be2e564cf40ae9b79d4a24e9c87f57413e957 100644 (file)
@@ -16,6 +16,7 @@ from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import defer
+from sqlalchemy.orm import join as orm_join
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import lazyload
 from sqlalchemy.orm import registry
@@ -264,6 +265,144 @@ class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
             "WHERE users.name != :name_1",
         )
 
+    @testing.combinations(
+        (
+            lambda User, Address: select(Address)
+            .select_from(User)
+            .join(User.addresses)
+            .options(with_loader_criteria(User, User.name != "name")),
+        ),
+        (
+            lambda User, Address: select(Address)
+            .select_from(orm_join(User, Address, User.addresses))
+            .options(with_loader_criteria(User, User.name != "name")),
+        ),
+        (
+            lambda User, Address: select(Address)
+            .join_from(User, Address, User.addresses)
+            .options(with_loader_criteria(User, User.name != "name")),
+        ),
+        argnames="stmt_fn",
+    )
+    @testing.combinations(True, False, argnames="alias_user")
+    def test_criteria_select_from_w_join_left(
+        self, user_address_fixture, stmt_fn, alias_user
+    ):
+        """test #8721"""
+        User, Address = user_address_fixture
+
+        if alias_user:
+            User = aliased(User)
+
+        stmt = testing.resolve_lambda(stmt_fn, User=User, Address=Address)
+
+        if alias_user:
+            self.assert_compile(
+                stmt,
+                "SELECT addresses.id, addresses.user_id, "
+                "addresses.email_address FROM users AS users_1 "
+                "JOIN addresses ON users_1.id = addresses.user_id "
+                "WHERE users_1.name != :name_1",
+            )
+        else:
+            self.assert_compile(
+                stmt,
+                "SELECT addresses.id, addresses.user_id, "
+                "addresses.email_address "
+                "FROM users JOIN addresses ON users.id = addresses.user_id "
+                "WHERE users.name != :name_1",
+            )
+
+    @testing.combinations(
+        (
+            lambda User, Address: select(Address.id, User.id)
+            .select_from(User)
+            .join(User.addresses)
+            .options(with_loader_criteria(User, User.name != "name")),
+        ),
+        (
+            lambda User, Address: select(Address.id, User.id)
+            .select_from(orm_join(User, Address, User.addresses))
+            .options(with_loader_criteria(User, User.name != "name")),
+        ),
+        (
+            lambda User, Address: select(Address.id, User.id)
+            .join_from(User, Address, User.addresses)
+            .options(with_loader_criteria(User, User.name != "name")),
+        ),
+        argnames="stmt_fn",
+    )
+    @testing.combinations(True, False, argnames="alias_user")
+    def test_criteria_select_from_w_join_left_including_entity(
+        self, user_address_fixture, stmt_fn, alias_user
+    ):
+        """test #8721"""
+        User, Address = user_address_fixture
+
+        if alias_user:
+            User = aliased(User)
+
+        stmt = testing.resolve_lambda(stmt_fn, User=User, Address=Address)
+
+        if alias_user:
+            self.assert_compile(
+                stmt,
+                "SELECT addresses.id, users_1.id AS id_1 "
+                "FROM users AS users_1 JOIN addresses "
+                "ON users_1.id = addresses.user_id "
+                "WHERE users_1.name != :name_1",
+            )
+        else:
+            self.assert_compile(
+                stmt,
+                "SELECT addresses.id, users.id AS id_1 "
+                "FROM users JOIN addresses ON users.id = addresses.user_id "
+                "WHERE users.name != :name_1",
+            )
+
+    @testing.combinations(
+        (
+            lambda User, Address: select(Address)
+            .select_from(User)
+            .join(User.addresses)
+            .options(
+                with_loader_criteria(Address, Address.email_address != "email")
+            ),
+        ),
+        (
+            # for orm_join(), this is set up before we have the context
+            # available that allows with_loader_criteria to be set up
+            # correctly
+            lambda User, Address: select(Address)
+            .select_from(orm_join(User, Address, User.addresses))
+            .options(
+                with_loader_criteria(Address, Address.email_address != "email")
+            ),
+            testing.fails("not implemented right now"),
+        ),
+        (
+            lambda User, Address: select(Address)
+            .join_from(User, Address, User.addresses)
+            .options(
+                with_loader_criteria(Address, Address.email_address != "email")
+            ),
+        ),
+        argnames="stmt_fn",
+    )
+    def test_criteria_select_from_w_join_right(
+        self, user_address_fixture, stmt_fn
+    ):
+        """test #8721"""
+        User, Address = user_address_fixture
+
+        stmt = testing.resolve_lambda(stmt_fn, User=User, Address=Address)
+        self.assert_compile(
+            stmt,
+            "SELECT addresses.id, addresses.user_id, addresses.email_address "
+            "FROM users JOIN addresses ON users.id = addresses.user_id "
+            "AND addresses.email_address != :email_address_1",
+        )
+
     @testing.combinations(
         "select",
         "joined",