]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support adapt_on_names for with_polymorphic
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 8 Mar 2022 14:34:09 +0000 (09:34 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 8 Mar 2022 18:25:46 +0000 (13:25 -0500)
Added :paramref:`_orm.with_polymorphic.adapt_on_names` to the
:func:`_orm.with_polymorphic` function, which allows a polymorphic load
(typically with concrete mapping) to be stated against an alternative
selectable that will adapt to the original mapped selectable on column
names alone.

Fixes: #7805
Change-Id: I933e180a489fec8a6f4916d1622d444dd4434f30

doc/build/changelog/unreleased_14/7805.rst [new file with mode: 0644]
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/util.py
test/orm/inheritance/test_concrete.py

diff --git a/doc/build/changelog/unreleased_14/7805.rst b/doc/build/changelog/unreleased_14/7805.rst
new file mode 100644 (file)
index 0000000..2d29402
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 7805
+
+    Added :paramref:`_orm.with_polymorphic.adapt_on_names` to the
+    :func:`_orm.with_polymorphic` function, which allows a polymorphic load
+    (typically with concrete mapping) to be stated against an alternative
+    selectable that will adapt to the original mapped selectable on column
+    names alone.
index 8e05c6ef2898493ad55b6dc57c60eb2a04789c3a..8d5fb91d08f71a7fcfa1a26a0967cd2623f890f8 100644 (file)
@@ -2110,6 +2110,7 @@ def with_polymorphic(
     flat=False,
     polymorphic_on=None,
     aliased=False,
+    adapt_on_names=False,
     innerjoin=False,
     _use_mapper_path=False,
 ):
@@ -2173,6 +2174,15 @@ def with_polymorphic(
 
     :param innerjoin: if True, an INNER JOIN will be used.  This should
        only be specified if querying for one specific subtype only
+
+    :param adapt_on_names: Passes through the
+      :paramref:`_orm.aliased.adapt_on_names`
+      parameter to the aliased object.  This may be useful in situations where
+      the given selectable is not directly related to the existing mapped
+      selectable.
+
+      .. versionadded:: 1.4.33
+
     """
     return AliasedInsp._with_polymorphic_factory(
         base,
@@ -2180,6 +2190,7 @@ def with_polymorphic(
         selectable=selectable,
         flat=flat,
         polymorphic_on=polymorphic_on,
+        adapt_on_names=adapt_on_names,
         aliased=aliased,
         innerjoin=innerjoin,
         _use_mapper_path=_use_mapper_path,
index e00e05954653b17509cb6a07312de538efa6fb78..d4faf10e33b2db825d0c635384b7e536b759e426 100644 (file)
@@ -757,7 +757,9 @@ class AliasedInsp(
             # are not even the thing we are mapping, such as embedded
             # selectables in subqueries or CTEs.  See issue #6060
             adapt_from_selectables=[
-                m.selectable for m in self.with_polymorphic_mappers
+                m.selectable
+                for m in self.with_polymorphic_mappers
+                if not adapt_on_names
             ],
         )
 
@@ -810,6 +812,7 @@ class AliasedInsp(
         polymorphic_on=None,
         aliased=False,
         innerjoin=False,
+        adapt_on_names=False,
         _use_mapper_path=False,
     ):
 
@@ -830,6 +833,7 @@ class AliasedInsp(
             base,
             selectable,
             with_polymorphic_mappers=mappers,
+            adapt_on_names=adapt_on_names,
             with_polymorphic_discriminator=polymorphic_on,
             use_mapper_path=_use_mapper_path,
             represents_outer_join=not innerjoin,
index c5031ed59644a1a34f377fdaa9e164bf06d150c4..ab6d79383c8b0a348ea5c5bac7d300319c08bfdc 100644 (file)
@@ -5,24 +5,34 @@ from sqlalchemy import null
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import union
 from sqlalchemy import union_all
+from sqlalchemy.ext.declarative import AbstractConcreteBase
 from sqlalchemy.ext.hybrid import hybrid_property
+from sqlalchemy.orm import aliased
 from sqlalchemy.orm import attributes
 from sqlalchemy.orm import class_mapper
 from sqlalchemy.orm import clear_mappers
+from sqlalchemy.orm import composite
 from sqlalchemy.orm import configure_mappers
+from sqlalchemy.orm import contains_eager
+from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import polymorphic_union
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
+from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
+from test.orm.test_events import _RemoveListeners
 
 
 class ConcreteTest(fixtures.MappedTest):
@@ -1434,3 +1444,235 @@ class ColKeysTest(fixtures.MappedTest):
         eq_(sess.get(Refugee, 2).name, "refugee2")
         eq_(sess.get(Office, 1).name, "office1")
         eq_(sess.get(Office, 2).name, "office2")
+
+
+class AdaptOnNamesTest(_RemoveListeners, fixtures.DeclarativeMappedTest):
+    """test the full integration case for #7805"""
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+        Basic = cls.Basic
+
+        class Metadata(ComparableEntity, Base):
+            __tablename__ = "metadata"
+            id = Column(
+                Integer,
+                primary_key=True,
+            )
+
+            some_data = Column(String(50))
+
+        class BaseObj(ComparableEntity, AbstractConcreteBase, Base):
+            """abstract concrete base with a custom polymorphic_union.
+
+            Additionally, at query time it needs to use a new version of this
+            union each time in order to add filter criteria.  this is because
+            polymorphic_union() is of course very inefficient in its form
+            and if someone actually has to use this, it's likely better for
+            filter criteria to be within each sub-select.   The current use
+            case here does not really have easy answers as we don't have
+            a built-in widget that does this.  The complexity / little use
+            ratio doesn't justify it unfortunately.
+
+            This use case might be easier if we were mapped to something that
+            can be adapted. however, we are using adapt_on_names here as this
+            is usually what's more accessible to someone trying to get into
+            this, or at least we should make that feature work as well as it
+            can.
+
+            """
+
+            @declared_attr
+            def id(cls):
+                return Column(Integer, primary_key=True)
+
+            @declared_attr
+            def metadata_id(cls):
+                return Column(ForeignKey(Metadata.id), nullable=False)
+
+            @classmethod
+            def _create_polymorphic_union(cls, mappers, discriminator_name):
+                return cls.make_statement().subquery()
+
+            @declared_attr
+            def related_metadata(cls):
+                return relationship(Metadata)
+
+            @classmethod
+            def make_statement(cls, *filter_cond, include_metadata=False):
+
+                a_stmt = (
+                    select(
+                        A.id,
+                        A.metadata_id,
+                        A.thing1,
+                        A.x1,
+                        A.y1,
+                        null().label("thing2"),
+                        null().label("x2"),
+                        null().label("y2"),
+                        literal("a").label("type"),
+                    )
+                    .join(Metadata)
+                    .filter(*filter_cond)
+                )
+                if include_metadata:
+                    a_stmt = a_stmt.add_columns(Metadata.__table__)
+
+                b_stmt = (
+                    select(
+                        B.id,
+                        B.metadata_id,
+                        null().label("thing1"),
+                        null().label("x1"),
+                        null().label("y1"),
+                        B.thing2,
+                        B.x2,
+                        B.y2,
+                        literal("b").label("type"),
+                    )
+                    .join(Metadata)
+                    .filter(*filter_cond)
+                )
+                if include_metadata:
+                    b_stmt = b_stmt.add_columns(Metadata.__table__)
+
+                return union(a_stmt, b_stmt)
+
+        class XYThing(Basic):
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+            def __composite_values__(self):
+                return (self.x, self.y)
+
+            def __eq__(self, other):
+                return (
+                    isinstance(other, XYThing)
+                    and other.x == self.x
+                    and other.y == self.y
+                )
+
+            def __ne__(self, other):
+                return not self.__eq__(other)
+
+        class A(BaseObj):
+            __tablename__ = "a"
+            thing1 = Column(String(50))
+            comp1 = composite(
+                XYThing, Column("x1", Integer), Column("y1", Integer)
+            )
+
+            __mapper_args__ = {"polymorphic_identity": "a", "concrete": True}
+
+        class B(BaseObj):
+            __tablename__ = "b"
+            thing2 = Column(String(50))
+            comp2 = composite(
+                XYThing, Column("x2", Integer), Column("y2", Integer)
+            )
+
+            __mapper_args__ = {"polymorphic_identity": "b", "concrete": True}
+
+    @classmethod
+    def insert_data(cls, connection):
+        Metadata, A, B = cls.classes("Metadata", "A", "B")
+        XYThing = cls.classes.XYThing
+
+        with Session(connection) as sess:
+            sess.add_all(
+                [
+                    Metadata(id=1, some_data="m1"),
+                    Metadata(id=2, some_data="m2"),
+                ]
+            )
+            sess.flush()
+
+            sess.add_all(
+                [
+                    A(
+                        id=5,
+                        metadata_id=1,
+                        thing1="thing1",
+                        comp1=XYThing(1, 2),
+                    ),
+                    B(
+                        id=6,
+                        metadata_id=2,
+                        thing2="thing2",
+                        comp2=XYThing(3, 4),
+                    ),
+                ]
+            )
+            sess.commit()
+
+    def test_contains_eager(self):
+        Metadata, A, B = self.classes("Metadata", "A", "B")
+        BaseObj = self.classes.BaseObj
+        XYThing = self.classes.XYThing
+
+        alias = BaseObj.make_statement(
+            Metadata.id < 3, include_metadata=True
+        ).subquery()
+        ac = with_polymorphic(
+            BaseObj,
+            [A, B],
+            selectable=alias,
+            adapt_on_names=True,
+        )
+
+        mt = aliased(Metadata, alias=alias)
+
+        sess = fixture_session()
+
+        with self.sql_execution_asserter() as asserter:
+            objects = sess.scalars(
+                select(ac)
+                .options(
+                    contains_eager(ac.A.related_metadata.of_type(mt)),
+                    contains_eager(ac.B.related_metadata.of_type(mt)),
+                )
+                .order_by(ac.id)
+            ).all()
+
+            eq_(
+                objects,
+                [
+                    A(
+                        id=5,
+                        metadata_id=1,
+                        thing1="thing1",
+                        comp1=XYThing(1, 2),
+                        related_metadata=Metadata(id=1, some_data="m1"),
+                    ),
+                    B(
+                        id=6,
+                        metadata_id=2,
+                        thing2="thing2",
+                        comp2=XYThing(3, 4),
+                        related_metadata=Metadata(id=2, some_data="m2"),
+                    ),
+                ],
+            )
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT anon_1.id, anon_1.metadata_id, anon_1.thing1, "
+                "anon_1.x1, anon_1.y1, anon_1.thing2, anon_1.x2, anon_1.y2, "
+                "anon_1.type, anon_1.id_1, anon_1.some_data FROM "
+                "(SELECT a.id AS id, a.metadata_id AS metadata_id, "
+                "a.thing1 AS thing1, a.x1 AS x1, a.y1 AS y1, "
+                "NULL AS thing2, NULL AS x2, NULL AS y2, :param_1 AS type, "
+                "metadata.id AS id_1, metadata.some_data AS some_data "
+                "FROM a JOIN metadata ON metadata.id = a.metadata_id "
+                "WHERE metadata.id < :id_2 UNION SELECT b.id AS id, "
+                "b.metadata_id AS metadata_id, NULL AS thing1, NULL AS x1, "
+                "NULL AS y1, b.thing2 AS thing2, b.x2 AS x2, b.y2 AS y2, "
+                ":param_2 AS type, metadata.id AS id_1, "
+                "metadata.some_data AS some_data FROM b "
+                "JOIN metadata ON metadata.id = b.metadata_id "
+                "WHERE metadata.id < :id_3) AS anon_1 ORDER BY anon_1.id",
+                [{"param_1": "a", "id_2": 3, "param_2": "b", "id_3": 3}],
+            )
+        )