]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add tests for issue #8168; slight internal adjustments
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Dec 2022 17:02:37 +0000 (12:02 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Dec 2022 22:22:04 +0000 (17:22 -0500)
The issue in #8168 was improved, but not completely fixed,
by #8456.

This includes some small changes to ORM context that
are a prerequisite for getting ORM adaptation to be
better.   Have these in 2.0.0b4 so that we have at
least a better starting point.

References: #8168
Change-Id: I51dbe333b156048836d074fbba1d850f9eb67fd2

lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/util.py
test/orm/inheritance/test_assorted_poly.py

index 621b3e5d74d8d905504565bba34690f8b0a19970..b5b326bca2c7f2e79bf8fa5d9759e9d7c970f1ce 100644 (file)
@@ -520,13 +520,16 @@ class ORMCompileState(AbstractORMCompileState):
             for mp in ext_info.mapper.iterate_to_root():
                 self._mapper_loads_polymorphically_with(
                     mp,
-                    sql_util.ColumnAdapter(selectable, mp._equivalent_columns),
+                    ORMAdapter(
+                        mp, mp._equivalent_columns, selectable=selectable
+                    ),
                 )
 
     def _mapper_loads_polymorphically_with(self, mapper, adapter):
         for m2 in mapper._with_polymorphic_mappers or [mapper]:
             self._polymorphic_adapters[m2] = adapter
-            for m in m2.iterate_to_root():  # TODO: redundant ?
+
+            for m in m2.iterate_to_root():
                 self._polymorphic_adapters[m.local_table] = adapter
 
     @classmethod
@@ -1673,13 +1676,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
                 left = onclause._parententity
 
-                alias = self._polymorphic_adapters.get(left, None)
-
-                # could be None or could be ColumnAdapter also
-                if isinstance(alias, ORMAdapter) and alias.mapper.isa(left):
-                    left = alias.aliased_class
-                    onclause = getattr(left, onclause.key)
-
                 prop = onclause.property
                 if not isinstance(onclause, attributes.QueryableAttribute):
                     onclause = prop
index c0fc11d2330f5617ac998dfa441eb3446336c8c1..6ed8e22727298b28804510298dd6d36efa37d291 100644 (file)
@@ -107,6 +107,7 @@ if typing.TYPE_CHECKING:
     from ..sql.selectable import _ColumnsClauseElement
     from ..sql.selectable import Alias
     from ..sql.selectable import Select
+    from ..sql.selectable import Selectable
     from ..sql.selectable import Subquery
     from ..sql.visitors import anon_map
     from ..util.typing import _AnnotationScanType
@@ -456,10 +457,12 @@ class ORMAdapter(sql_util.ColumnAdapter):
         adapt_required: bool = False,
         allow_label_resolve: bool = True,
         anonymize_labels: bool = False,
+        selectable: Optional[Selectable] = None,
     ):
 
         self.mapper = entity.mapper
-        selectable = entity.selectable
+        if selectable is None:
+            selectable = entity.selectable
         if insp_is_aliased_class(entity):
             self.is_aliased_class = True
             self.aliased_insp = entity
@@ -478,6 +481,10 @@ class ORMAdapter(sql_util.ColumnAdapter):
         )
 
     def _include_fn(self, elem):
+        # TODO: we still have cases where we should return False here
+        # yet we are not able to reliably detect without false positives.
+        # see issue #8168
+
         entity = elem._annotations.get("parentmapper", None)
 
         return not entity or entity.isa(self.mapper) or self.mapper.isa(entity)
index 14cbe24562dceeaf12b35a4fe8cbd9f349442ff8..d162428ec35552bd71d18cb43446a4f1e8fe0d7c 100644 (file)
@@ -1139,7 +1139,6 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal):
         if isinstance(col, FromClause) and not isinstance(
             col, functions.FunctionElement
         ):
-
             if self.selectable.is_derived_from(col):
                 if self.adapt_from_selectables:
                     for adp in self.adapt_from_selectables:
index 71592a22c3a3431ca8d53c77cebfe9531f71d209..8ec36d2993f67d0d5dad0f6683cad87aa438d072 100644 (file)
@@ -30,6 +30,7 @@ from sqlalchemy.testing import AssertsExecutionResults
 from sqlalchemy.testing import config
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import skip_test
 from sqlalchemy.testing.fixtures import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.provision import normalize_sequence
@@ -2399,3 +2400,150 @@ class CorrelateExceptWPolyAdaptTest(
             "LEFT OUTER JOIN s2 ON s1.id = s2.id "
             "JOIN c ON c.id = s1.common_id WHERE c.id = :id_1",
         )
+
+
+class Issue8168Test(AssertsCompiledSQL, fixtures.TestBase):
+    """tests for #8168 which was fixed by #8456"""
+
+    __dialect__ = "default"
+
+    @testing.fixture
+    def mapping(self, decl_base):
+        Base = decl_base
+
+        def go(scenario, use_poly):
+            class Customer(Base):
+                __tablename__ = "customer"
+                id = Column(Integer, primary_key=True)
+                type = Column(String(20))
+
+                __mapper_args__ = {
+                    "polymorphic_on": "type",
+                    "polymorphic_identity": "customer",
+                }
+
+            class Store(Customer):
+                __tablename__ = "store"
+                id = Column(
+                    Integer, ForeignKey("customer.id"), primary_key=True
+                )
+                retailer_id = Column(Integer, ForeignKey("retailer.id"))
+                retailer = relationship(
+                    "Retailer",
+                    back_populates="stores",
+                    foreign_keys=[retailer_id],
+                )
+
+                __mapper_args__ = {
+                    "polymorphic_identity": "store",
+                    "polymorphic_load": "inline" if use_poly else None,
+                }
+
+            class Retailer(Customer):
+                __tablename__ = "retailer"
+                id = Column(
+                    Integer, ForeignKey("customer.id"), primary_key=True
+                )
+                stores = relationship(
+                    "Store",
+                    back_populates="retailer",
+                    foreign_keys=[Store.retailer_id],
+                )
+
+                if scenario.mapped_cls:
+                    store_tgt = corr_except = Store
+
+                elif scenario.table:
+                    corr_except = Store.__table__
+                    store_tgt = Store.__table__.c
+                elif scenario.table_alias:
+                    corr_except = Store.__table__.alias()
+                    store_tgt = corr_except.c
+                else:
+                    scenario.fail()
+
+                store_count = column_property(
+                    select(func.count(store_tgt.id))
+                    .where(store_tgt.retailer_id == id)
+                    .correlate_except(corr_except)
+                    .scalar_subquery()
+                )
+
+                __mapper_args__ = {"polymorphic_identity": "retailer"}
+
+            return Customer, Store, Retailer
+
+        yield go
+
+    @testing.variation("scenario", ["mapped_cls", "table", "table_alias"])
+    @testing.variation("use_poly", [True, False])
+    def test_select_attr_only(self, scenario, use_poly, mapping):
+        Customer, Store, Retailer = mapping(scenario, use_poly)
+
+        if scenario.mapped_cls:
+            self.assert_compile(
+                select(Retailer.store_count).select_from(Retailer),
+                "SELECT (SELECT count(store.id) AS count_1 "
+                "FROM customer JOIN store ON customer.id = store.id "
+                "WHERE store.retailer_id = retailer.id) AS anon_1 "
+                "FROM customer JOIN retailer ON customer.id = retailer.id",
+            )
+        elif scenario.table:
+            self.assert_compile(
+                select(Retailer.store_count).select_from(Retailer),
+                "SELECT (SELECT count(store.id) AS count_1 "
+                "FROM store "
+                "WHERE store.retailer_id = retailer.id) AS anon_1 "
+                "FROM customer JOIN retailer ON customer.id = retailer.id",
+            )
+        elif scenario.table_alias:
+            self.assert_compile(
+                select(Retailer.store_count).select_from(Retailer),
+                "SELECT (SELECT count(store_1.id) AS count_1 FROM store "
+                "AS store_1 "
+                "WHERE store_1.retailer_id = retailer.id) AS anon_1 "
+                "FROM customer JOIN retailer ON customer.id = retailer.id",
+            )
+        else:
+            scenario.fail()
+
+    @testing.variation("scenario", ["mapped_cls", "table", "table_alias"])
+    @testing.variation("use_poly", [True, False])
+    def test_select_cls(self, scenario, mapping, use_poly):
+        Customer, Store, Retailer = mapping(scenario, use_poly)
+
+        if scenario.mapped_cls:
+            # breaks for use_poly, but this is not totally unexpected
+            if use_poly:
+                skip_test("Case not working yet")
+            self.assert_compile(
+                select(Retailer),
+                "SELECT (SELECT count(store.id) AS count_1 FROM customer "
+                "JOIN store ON customer.id = store.id "
+                "WHERE store.retailer_id = retailer.id) AS anon_1, "
+                "retailer.id, customer.id AS id_1, customer.type "
+                "FROM customer JOIN retailer ON customer.id = retailer.id",
+            )
+        elif scenario.table:
+            # TODO: breaks for use_poly, and this should not happen.
+            # selecting from the Table should be honoring that
+            if use_poly:
+                skip_test("Case not working yet")
+            self.assert_compile(
+                select(Retailer),
+                "SELECT (SELECT count(store.id) AS count_1 FROM store "
+                "WHERE store.retailer_id = retailer.id) AS anon_1, "
+                "retailer.id, customer.id AS id_1, customer.type "
+                "FROM customer JOIN retailer ON customer.id = retailer.id",
+            )
+        elif scenario.table_alias:
+            self.assert_compile(
+                select(Retailer),
+                "SELECT (SELECT count(store_1.id) AS count_1 "
+                "FROM store AS store_1 WHERE store_1.retailer_id = "
+                "retailer.id) AS anon_1, retailer.id, customer.id AS id_1, "
+                "customer.type "
+                "FROM customer JOIN retailer ON customer.id = retailer.id",
+            )
+        else:
+            scenario.fail()