]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support selectin_polymorphic w/ no fixed polymorphic_on
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Mar 2022 20:11:29 +0000 (15:11 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 8 Mar 2022 14:39:38 +0000 (09:39 -0500)
Fixed issue where the :func:`_orm.polymorphic_selectin` loader option would
not work with joined inheritance mappers that don't have a fixed
"polymorphic_on" column.   Additionally added test support for a wider
variety of usage patterns with this construct.

Fixed bug where :func:`_orm.composite` attributes would not work in
conjunction with the :func:`_orm.selectin_polymorphic` loader strategy for
joined table inheritance.

Fixes: #7799
Fixes: #7801
Change-Id: I7cfe32dfe844b188403b39545930c0aee71d0119

doc/build/changelog/unreleased_14/7799.rst [new file with mode: 0644]
doc/build/changelog/unreleased_14/7801.rst [new file with mode: 0644]
lib/sqlalchemy/orm/mapper.py
test/orm/inheritance/test_poly_loading.py

diff --git a/doc/build/changelog/unreleased_14/7799.rst b/doc/build/changelog/unreleased_14/7799.rst
new file mode 100644 (file)
index 0000000..0025473
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 7799
+
+    Fixed issue where the :func:`_orm.polymorphic_selectin` loader option would
+    not work with joined inheritance mappers that don't have a fixed
+    "polymorphic_on" column.   Additionally added test support for a wider
+    variety of usage patterns with this construct.
\ No newline at end of file
diff --git a/doc/build/changelog/unreleased_14/7801.rst b/doc/build/changelog/unreleased_14/7801.rst
new file mode 100644 (file)
index 0000000..4df3bdf
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 7801
+
+    Fixed bug where :func:`_orm.composite` attributes would not work in
+    conjunction with the :func:`_orm.selectin_polymorphic` loader strategy for
+    joined table inheritance.
+
index 5a34188a9cf28441ccd26fe40cb8397ba3ffaa90..011e7d2efc2b641f56e49081d13e872df9c965e4 100644 (file)
@@ -43,6 +43,7 @@ from .interfaces import InspectionAttr
 from .interfaces import MapperProperty
 from .interfaces import ORMEntityColumnsClauseRole
 from .interfaces import ORMFromClauseRole
+from .interfaces import StrategizedProperty
 from .path_registry import PathRegistry
 from .. import event
 from .. import exc as sa_exc
@@ -3077,8 +3078,11 @@ class Mapper(
 
         assert self.inherits
 
-        polymorphic_prop = self._columntoproperty[self.polymorphic_on]
-        keep_props = set([polymorphic_prop] + self._identity_key_props)
+        if self.polymorphic_on is not None:
+            polymorphic_prop = self._columntoproperty[self.polymorphic_on]
+            keep_props = set([polymorphic_prop] + self._identity_key_props)
+        else:
+            keep_props = set(self._identity_key_props)
 
         disable_opt = strategy_options.Load(entity)
         enable_opt = strategy_options.Load(entity)
@@ -3087,6 +3091,9 @@ class Mapper(
             if prop.parent is self or prop in keep_props:
                 # "enable" options, to turn on the properties that we want to
                 # load by default (subject to options from the query)
+                if not isinstance(prop, StrategizedProperty):
+                    continue
+
                 enable_opt = enable_opt._set_generic_strategy(
                     # convert string name to an attribute before passing
                     # to loader strategy
index d9d4a9a2214e8c9f0a03b89e10afaea0b2a76435..f03f15bd25110b129ddcc329e23c163802c32a7b 100644 (file)
@@ -2,10 +2,13 @@ from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import literal
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import union
 from sqlalchemy.orm import backref
+from sqlalchemy.orm import composite
 from sqlalchemy.orm import defaultload
 from sqlalchemy.orm import immediateload
 from sqlalchemy.orm import joinedload
@@ -26,6 +29,7 @@ from sqlalchemy.testing.assertsql import AllOf
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.assertsql import EachOf
 from sqlalchemy.testing.assertsql import Or
+from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from ._poly_fixtures import _Polymorphic
@@ -424,7 +428,10 @@ class TestGeometries(GeometryFixtureBase):
         with self.assert_statement_count(testing.db, 0):
             eq_(result, [d(d_data="d1"), e(e_data="e1")])
 
-    def test_threelevel_selectin_to_inline_awkward_alias_options(self):
+    @testing.combinations((True,), (False,))
+    def test_threelevel_selectin_to_inline_awkward_alias_options(
+        self, use_aliased_class
+    ):
         self._fixture_from_geometry(
             {
                 "a": {
@@ -455,57 +462,96 @@ class TestGeometries(GeometryFixtureBase):
         )
 
         c_alias = with_polymorphic(c, (d, e), poly)
-        q = (
-            sess.query(a)
-            .options(selectin_polymorphic(a, [b, c_alias]))
-            .order_by(a.id)
-        )
 
-        result = self.assert_sql_execution(
-            testing.db,
-            q.all,
-            CompiledSQL(
-                "SELECT a.id AS a_id, a.type AS a_type, "
-                "a.a_data AS a_a_data FROM a ORDER BY a.id",
-                {},
-            ),
-            Or(
-                # here, the test is that the adaptation of "a" takes place
+        if use_aliased_class:
+            opt = selectin_polymorphic(a, [b, c_alias])
+        else:
+            opt = selectin_polymorphic(
+                a,
+                [b, c_alias, d, e],
+            )
+        q = sess.query(a).options(opt).order_by(a.id)
+
+        if use_aliased_class:
+            result = self.assert_sql_execution(
+                testing.db,
+                q.all,
                 CompiledSQL(
-                    "SELECT poly.c_id AS poly_c_id, "
-                    "poly.a_type AS poly_a_type, "
-                    "poly.a_id AS poly_a_id, poly.c_c_data AS poly_c_c_data, "
-                    "poly.e_id AS poly_e_id, poly.e_e_data AS poly_e_e_data, "
-                    "poly.d_id AS poly_d_id, poly.d_d_data AS poly_d_d_data "
-                    "FROM (SELECT a.id AS a_id, a.type AS a_type, "
-                    "c.id AS c_id, "
-                    "c.c_data AS c_c_data, d.id AS d_id, "
-                    "d.d_data AS d_d_data, "
-                    "e.id AS e_id, e.e_data AS e_e_data FROM a JOIN c "
-                    "ON a.id = c.id LEFT OUTER JOIN d ON c.id = d.id "
-                    "LEFT OUTER JOIN e ON c.id = e.id) AS poly "
-                    "WHERE poly.a_id IN (__[POSTCOMPILE_primary_keys]) "
-                    "ORDER BY poly.a_id",
-                    [{"primary_keys": [1, 2]}],
+                    "SELECT a.id AS a_id, a.type AS a_type, "
+                    "a.a_data AS a_a_data FROM a ORDER BY a.id",
+                    {},
+                ),
+                Or(
+                    # here, the test is that the adaptation of "a" takes place
+                    CompiledSQL(
+                        "SELECT poly.c_id AS poly_c_id, "
+                        "poly.a_type AS poly_a_type, "
+                        "poly.a_id AS poly_a_id, poly.c_c_data "
+                        "AS poly_c_c_data, "
+                        "poly.e_id AS poly_e_id, poly.e_e_data "
+                        "AS poly_e_e_data, "
+                        "poly.d_id AS poly_d_id, poly.d_d_data "
+                        "AS poly_d_d_data "
+                        "FROM (SELECT a.id AS a_id, a.type AS a_type, "
+                        "c.id AS c_id, "
+                        "c.c_data AS c_c_data, d.id AS d_id, "
+                        "d.d_data AS d_d_data, "
+                        "e.id AS e_id, e.e_data AS e_e_data FROM a JOIN c "
+                        "ON a.id = c.id LEFT OUTER JOIN d ON c.id = d.id "
+                        "LEFT OUTER JOIN e ON c.id = e.id) AS poly "
+                        "WHERE poly.a_id IN (__[POSTCOMPILE_primary_keys]) "
+                        "ORDER BY poly.a_id",
+                        [{"primary_keys": [1, 2]}],
+                    ),
+                    CompiledSQL(
+                        "SELECT poly.c_id AS poly_c_id, "
+                        "poly.a_id AS poly_a_id, poly.a_type AS poly_a_type, "
+                        "poly.c_c_data AS poly_c_c_data, "
+                        "poly.d_id AS poly_d_id, poly.d_d_data "
+                        "AS poly_d_d_data, "
+                        "poly.e_id AS poly_e_id, poly.e_e_data "
+                        "AS poly_e_e_data "
+                        "FROM (SELECT a.id AS a_id, a.type AS a_type, "
+                        "c.id AS c_id, c.c_data AS c_c_data, d.id AS d_id, "
+                        "d.d_data AS d_d_data, e.id AS e_id, "
+                        "e.e_data AS e_e_data FROM a JOIN c ON a.id = c.id "
+                        "LEFT OUTER JOIN d ON c.id = d.id "
+                        "LEFT OUTER JOIN e ON c.id = e.id) AS poly "
+                        "WHERE poly.a_id IN (__[POSTCOMPILE_primary_keys]) "
+                        "ORDER BY poly.a_id",
+                        [{"primary_keys": [1, 2]}],
+                    ),
                 ),
+            )
+        else:
+            result = self.assert_sql_execution(
+                testing.db,
+                q.all,
                 CompiledSQL(
-                    "SELECT poly.c_id AS poly_c_id, "
-                    "poly.a_id AS poly_a_id, poly.a_type AS poly_a_type, "
-                    "poly.c_c_data AS poly_c_c_data, "
-                    "poly.d_id AS poly_d_id, poly.d_d_data AS poly_d_d_data, "
-                    "poly.e_id AS poly_e_id, poly.e_e_data AS poly_e_e_data "
-                    "FROM (SELECT a.id AS a_id, a.type AS a_type, "
-                    "c.id AS c_id, c.c_data AS c_c_data, d.id AS d_id, "
-                    "d.d_data AS d_d_data, e.id AS e_id, "
-                    "e.e_data AS e_e_data FROM a JOIN c ON a.id = c.id "
-                    "LEFT OUTER JOIN d ON c.id = d.id "
-                    "LEFT OUTER JOIN e ON c.id = e.id) AS poly "
-                    "WHERE poly.a_id IN (__[POSTCOMPILE_primary_keys]) "
-                    "ORDER BY poly.a_id",
-                    [{"primary_keys": [1, 2]}],
+                    "SELECT a.id AS a_id, a.type AS a_type, "
+                    "a.a_data AS a_a_data FROM a ORDER BY a.id",
+                    {},
                 ),
-            ),
-        )
+                AllOf(
+                    CompiledSQL(
+                        "SELECT d.id AS d_id, c.id AS c_id, a.id AS a_id, "
+                        "a.type AS a_type, d.d_data AS d_d_data FROM a "
+                        "JOIN c ON a.id = c.id JOIN d ON c.id = d.id "
+                        "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) "
+                        "ORDER BY a.id",
+                        [{"primary_keys": [1]}],
+                    ),
+                    CompiledSQL(
+                        "SELECT e.id AS e_id, c.id AS c_id, a.id AS a_id, "
+                        "a.type AS a_type, e.e_data AS e_e_data FROM a "
+                        "JOIN c ON a.id = c.id JOIN e ON c.id = e.id "
+                        "WHERE a.id IN (__[POSTCOMPILE_primary_keys]) "
+                        "ORDER BY a.id",
+                        [{"primary_keys": [2]}],
+                    ),
+                ),
+            )
+
         with self.assert_statement_count(testing.db, 0):
             eq_(result, [d(d_data="d1"), e(e_data="e1")])
 
@@ -929,3 +975,202 @@ class LazyLoaderTransfersOptsTest(fixtures.DeclarativeMappedTest):
         u = sess.execute(select(User).options(*opts)).scalars().one()
         address = u.address
         eq_(inspect(address).load_options, opts)
+
+
+class NoBaseWPPlusAliasedTest(
+    testing.AssertsExecutionResults, fixtures.TestBase
+):
+    """test for #7799"""
+
+    @testing.fixture
+    def mapping_fixture(self, registry, connection):
+        Base = registry.generate_base()
+
+        class BaseClass(Base):
+            __tablename__ = "baseclass"
+            id = Column(
+                Integer,
+                primary_key=True,
+                unique=True,
+            )
+
+        class A(BaseClass):
+            __tablename__ = "a"
+
+            id = Column(ForeignKey(BaseClass.id), primary_key=True)
+            thing1 = Column(String(50))
+
+            __mapper_args__ = {"polymorphic_identity": "a"}
+
+        class B(BaseClass):
+            __tablename__ = "b"
+
+            id = Column(ForeignKey(BaseClass.id), primary_key=True)
+            thing2 = Column(String(50))
+
+            __mapper_args__ = {"polymorphic_identity": "b"}
+
+        registry.metadata.create_all(connection)
+        with Session(connection) as sess:
+
+            sess.add_all(
+                [
+                    A(thing1="thing1_1"),
+                    A(thing1="thing1_2"),
+                    B(thing2="thing2_2"),
+                    B(thing2="thing2_3"),
+                    A(thing1="thing1_3"),
+                    A(thing1="thing1_4"),
+                    B(thing2="thing2_1"),
+                    B(thing2="thing2_4"),
+                ]
+            )
+
+            sess.commit()
+
+        return BaseClass, A, B
+
+    def test_wp(self, mapping_fixture, connection):
+        BaseClass, A, B = mapping_fixture
+
+        stmt = union(
+            select(A.id, literal("a").label("type")),
+            select(B.id, literal("b").label("type")),
+        ).subquery()
+
+        wp = with_polymorphic(
+            BaseClass,
+            [A, B],
+            selectable=stmt,
+            polymorphic_on=stmt.c.type,
+        )
+
+        session = Session(connection)
+
+        with self.sql_execution_asserter() as asserter:
+            result = session.scalars(
+                select(wp)
+                .options(selectin_polymorphic(wp, [A, B]))
+                .order_by(wp.id)
+            )
+            for obj in result:
+                if isinstance(obj, A):
+                    obj.thing1
+                else:
+                    obj.thing2
+
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT anon_1.id, anon_1.type FROM "
+                "(SELECT a.id AS id, :param_1 AS type FROM baseclass "
+                "JOIN a ON baseclass.id = a.id "
+                "UNION SELECT b.id AS id, :param_2 AS type "
+                "FROM baseclass JOIN b ON baseclass.id = b.id) AS anon_1 "
+                "ORDER BY anon_1.id",
+                [{"param_1": "a", "param_2": "b"}],
+            ),
+            AllOf(
+                CompiledSQL(
+                    "SELECT a.id AS a_id, baseclass.id AS baseclass_id, "
+                    "a.thing1 AS a_thing1 FROM baseclass "
+                    "JOIN a ON baseclass.id = a.id "
+                    "WHERE baseclass.id IN (__[POSTCOMPILE_primary_keys]) "
+                    "ORDER BY baseclass.id",
+                    {"primary_keys": [1, 2, 5, 6]},
+                ),
+                CompiledSQL(
+                    "SELECT b.id AS b_id, baseclass.id AS baseclass_id, "
+                    "b.thing2 AS b_thing2 FROM baseclass "
+                    "JOIN b ON baseclass.id = b.id "
+                    "WHERE baseclass.id IN (__[POSTCOMPILE_primary_keys]) "
+                    "ORDER BY baseclass.id",
+                    {"primary_keys": [3, 4, 7, 8]},
+                ),
+            ),
+        )
+
+
+class CompositeAttributesTest(fixtures.TestBase):
+    @testing.fixture
+    def mapping_fixture(self, registry, connection):
+        Base = registry.generate_base()
+
+        class BaseCls(Base):
+            __tablename__ = "base"
+            id = Column(
+                Integer, primary_key=True, test_needs_autoincrement=True
+            )
+            type = Column(String(50))
+
+            __mapper_args__ = {"polymorphic_on": type}
+
+        class XYThing:
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+            def __composite_values__(self):
+                return (self.x, self.y)
+
+            def __eq__(self, other):
+                return (
+                    isinstance(other, XYThing)
+                    and other.x == self.x
+                    and other.y == self.y
+                )
+
+            def __ne__(self, other):
+                return not self.__eq__(other)
+
+        class A(ComparableEntity, BaseCls):
+            __tablename__ = "a"
+            id = Column(ForeignKey(BaseCls.id), primary_key=True)
+            thing1 = Column(String(50))
+            comp1 = composite(
+                XYThing, Column("x1", Integer), Column("y1", Integer)
+            )
+
+            __mapper_args__ = {
+                "polymorphic_identity": "a",
+                "polymorphic_load": "selectin",
+            }
+
+        class B(ComparableEntity, BaseCls):
+            __tablename__ = "b"
+            id = Column(ForeignKey(BaseCls.id), primary_key=True)
+            thing2 = Column(String(50))
+            comp2 = composite(
+                XYThing, Column("x2", Integer), Column("y2", Integer)
+            )
+
+            __mapper_args__ = {
+                "polymorphic_identity": "b",
+                "polymorphic_load": "selectin",
+            }
+
+        registry.metadata.create_all(connection)
+
+        with Session(connection) as sess:
+            sess.add_all(
+                [
+                    A(id=1, thing1="thing1", comp1=XYThing(1, 2)),
+                    B(id=2, thing2="thing2", comp2=XYThing(3, 4)),
+                ]
+            )
+            sess.commit()
+
+        return BaseCls, A, B, XYThing
+
+    def test_load_composite(self, mapping_fixture, connection):
+        BaseCls, A, B, XYThing = mapping_fixture
+
+        with Session(connection) as sess:
+            rows = sess.scalars(select(BaseCls).order_by(BaseCls.id)).all()
+
+            eq_(
+                rows,
+                [
+                    A(id=1, thing1="thing1", comp1=XYThing(1, 2)),
+                    B(id=2, thing2="thing2", comp2=XYThing(3, 4)),
+                ],
+            )