]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Raise if unique() not applied to 2.0 joined eager load results
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 21 Sep 2020 21:28:03 +0000 (17:28 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 21 Sep 2020 23:55:59 +0000 (19:55 -0400)
The automatic uniquing of rows on the client side is turned off for the new
:term:`2.0 style` of ORM querying.  This improves both clarity and
performance.  However, uniquing of rows on the client side is generally
necessary when using joined eager loading for collections, as there
will be duplicates of the primary entity for each element in the
collection because a join was used.  This uniquing must now be manually
enabled and can be achieved using the new
:meth:`_engine.Result.unique` modifier.   To avoid silent failure, the ORM
explicitly requires the method be called when the result of an ORM
query in 2.0 style makes use of joined load collections.    The newer
:func:`_orm.selectinload` strategy is likely preferable for eager loading
of collections in any case.

This changeset also fixes an issue where ORM-style "single entity"
results would not apply unique() correctly if results were returned
as tuples.

Fixes: #4395
Change-Id: Ie62e0cb68ef2a6d2120e968b79575a70d057212e

doc/build/changelog/migration_20.rst
doc/build/changelog/unreleased_14/4395.rst [new file with mode: 0644]
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/orm/loading.py
test/aaa_profiling/test_orm.py
test/base/test_result.py
test/orm/test_eager_relations.py
test/orm/test_query.py
test/orm/test_relationship_criteria.py

index 04a60ead19ef94bd91d33ba2281bc93a01567631..bf54e64e80142f20ba25b881bf9e9a059f48cd0d 100644 (file)
@@ -1092,6 +1092,8 @@ it will be fully transparent.   Applications that wish to reduce statement
 building latency even further to the levels currently offered by the "baked"
 system can opt to use the "lambda" constructs.
 
+.. _joinedload_not_uniqued:
+
 ORM Rows not uniquified by default
 ===================================
 
diff --git a/doc/build/changelog/unreleased_14/4395.rst b/doc/build/changelog/unreleased_14/4395.rst
new file mode 100644 (file)
index 0000000..7d1ebfa
--- /dev/null
@@ -0,0 +1,20 @@
+.. change::
+    :tags: orm, change
+    :tickets: 4395
+
+    The automatic uniquing of rows on the client side is turned off for the new
+    :term:`2.0 style` of ORM querying.  This improves both clarity and
+    performance.  However, uniquing of rows on the client side is generally
+    necessary when using joined eager loading for collections, as there
+    will be duplicates of the primary entity for each element in the
+    collection because a join was used.  This uniquing must now be manually
+    enabled and can be achieved using the new
+    :meth:`_engine.Result.unique` modifier.   To avoid silent failure, the ORM
+    explicitly requires the method be called when the result of an ORM
+    query in 2.0 style makes use of joined load collections.    The newer
+    :func:`_orm.selectinload` strategy is likely preferable for eager loading
+    of collections in any case.
+
+    .. seealso::
+
+        :ref:`joinedload_not_uniqued`
index 9b0bdf9a3e8f9bbdc7a420574336bf1e1805c80c..56abca9a9f7be0c9dec809c8228b041c6f627887 100644 (file)
@@ -654,7 +654,10 @@ class ResultInternal(InPlaceGenerative):
         )
 
         if not strategy and self._metadata._unique_filters:
-            if real_result._source_supports_scalars:
+            if (
+                real_result._source_supports_scalars
+                and not self._generate_rows
+            ):
                 strategy = self._metadata._unique_filters[0]
             else:
                 filters = self._metadata._unique_filters
index d3971414719d32d61ec87015dff26e552bb07a16..a7dd1c5478659456a675c25d5407d96b4c262f2f 100644 (file)
@@ -142,10 +142,25 @@ def instances(cursor, context):
         dynamic_yield_per=cursor.context._is_server_side,
     )
 
+    # filtered and single_entity are used to indicate to legacy Query that the
+    # query has ORM entities, so legacy deduping and scalars should be called
+    # on the result.
     result._attributes = result._attributes.union(
         dict(filtered=filtered, is_single_entity=single_entity)
     )
 
+    # multi_row_eager_loaders OTOH is specific to joinedload.
+    if context.compile_state.multi_row_eager_loaders:
+
+        def require_unique(obj):
+            raise sa_exc.InvalidRequestError(
+                "The unique() method must be invoked on this Result, "
+                "as it contains results that include joined eager loads "
+                "against collections"
+            )
+
+        result._unique_filter_state = (None, require_unique)
+
     if context.yield_per:
         result.yield_per(context.yield_per)
 
index fb1edd38fcbe98800eed3ea4bc7c292f2202eedc..30a02472cfb686d8c7a2776eaf339af1781d0134 100644 (file)
@@ -893,7 +893,7 @@ class JoinedEagerLoadTest(NoCache, fixtures.MappedTest):
                 obj = ORMCompileState.orm_setup_cursor_result(
                     sess, compile_state.statement, {}, exec_opts, {}, r,
                 )
-                list(obj)
+                list(obj.unique())
                 sess.close()
 
         go()
index 7281a66945e3a43f520bf9f78eb24044b55e345b..0136b6e2966cade51c810468b849b6c49e106be9 100644 (file)
@@ -1031,6 +1031,45 @@ class OnlyScalarsTest(fixtures.TestBase):
 
         eq_(r.all(), [1, 2, 1, 1, 4])
 
+    def test_scalar_mode_mfiltered_unique_rows_all(self, no_tuple_fixture):
+        metadata = result.SimpleResultMetaData(
+            ["a", "b", "c"], _unique_filters=[int]
+        )
+
+        r = result.ChunkedIteratorResult(
+            metadata, no_tuple_fixture, source_supports_scalars=True,
+        )
+
+        r = r.unique()
+
+        eq_(r.all(), [(1,), (2,), (4,)])
+
+    def test_scalar_mode_mfiltered_unique_mappings_all(self, no_tuple_fixture):
+        metadata = result.SimpleResultMetaData(
+            ["a", "b", "c"], _unique_filters=[int]
+        )
+
+        r = result.ChunkedIteratorResult(
+            metadata, no_tuple_fixture, source_supports_scalars=True,
+        )
+
+        r = r.unique()
+
+        eq_(r.mappings().all(), [{"a": 1}, {"a": 2}, {"a": 4}])
+
+    def test_scalar_mode_mfiltered_unique_scalars_all(self, no_tuple_fixture):
+        metadata = result.SimpleResultMetaData(
+            ["a", "b", "c"], _unique_filters=[int]
+        )
+
+        r = result.ChunkedIteratorResult(
+            metadata, no_tuple_fixture, source_supports_scalars=True,
+        )
+
+        r = r.scalars().unique()
+
+        eq_(r.all(), [1, 2, 4])
+
     def test_scalar_mode_unique_scalars_all(self, no_tuple_fixture):
         metadata = result.SimpleResultMetaData(["a", "b", "c"])
 
index 00d98d2b5927b288493554537cb23f0c1de4790a..a699cfa63449f3329a2a327ce057a5731c28bd5c 100644 (file)
@@ -32,6 +32,7 @@ from sqlalchemy.sql import operators
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import in_
 from sqlalchemy.testing import is_
@@ -2813,6 +2814,107 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
         )
 
 
+class SelectUniqueTest(_fixtures.FixtureTest):
+    run_inserts = "once"
+    run_deletes = None
+
+    @classmethod
+    def setup_mappers(cls):
+        cls._setup_stock_mapping()
+
+    def test_many_to_one(self):
+        Address = self.classes.Address
+
+        stmt = (
+            select(Address)
+            .options(joinedload(Address.user))
+            .order_by(Address.id)
+        )
+
+        s = create_session()
+        result = s.execute(stmt)
+
+        eq_(result.scalars().all(), self.static.address_user_result)
+
+    def test_unique_error(self):
+        User = self.classes.User
+
+        stmt = select(User).options(joinedload(User.addresses))
+        s = create_session()
+        result = s.execute(stmt)
+
+        with expect_raises_message(
+            sa.exc.InvalidRequestError,
+            r"The unique\(\) method must be invoked on this Result",
+        ):
+            result.all()
+
+    def test_unique_tuples_single_entity(self):
+        User = self.classes.User
+
+        stmt = (
+            select(User).options(joinedload(User.addresses)).order_by(User.id)
+        )
+        s = create_session()
+        result = s.execute(stmt)
+
+        eq_(
+            result.unique().all(),
+            [(u,) for u in self.static.user_address_result],
+        )
+
+    def test_unique_scalars_single_entity(self):
+        User = self.classes.User
+
+        stmt = (
+            select(User).options(joinedload(User.addresses)).order_by(User.id)
+        )
+        s = create_session()
+        result = s.execute(stmt)
+
+        eq_(result.scalars().unique().all(), self.static.user_address_result)
+
+    def test_unique_tuples_multiple_entity(self):
+        User = self.classes.User
+        Address = self.classes.Address
+
+        stmt = (
+            select(User, Address)
+            .join(User.addresses)
+            .options(joinedload(User.addresses))
+            .order_by(User.id, Address.id)
+        )
+        s = create_session()
+        result = s.execute(stmt)
+
+        eq_(
+            result.unique().all(),
+            [
+                (u, a)
+                for u in self.static.user_address_result
+                for a in u.addresses
+            ],
+        )
+
+    def test_unique_scalars_multiple_entity(self):
+        User = self.classes.User
+        Address = self.classes.Address
+
+        stmt = (
+            select(User, Address)
+            .join(User.addresses)
+            .options(joinedload(User.addresses))
+            .order_by(User.id)
+        )
+        s = create_session()
+        result = s.execute(stmt)
+
+        eq_(
+            result.scalars().unique().all(),
+            [u for u in self.static.user_address_result if u.addresses],
+        )
+
+
 class InnerJoinSplicingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
     __dialect__ = "default"
     __backend__ = True  # exercise hardcore join nesting on backends
index 4dbaa91668ce43c3fab492bd901d62efa8bbf6da..cee583ffbefc32bdcb8d533ac9a30df522ed5a7b 100644 (file)
@@ -909,7 +909,7 @@ class GetTest(QueryTest):
             .execution_options(populate_existing=True)
         )
 
-        s.execute(stmt).scalars().all()
+        s.execute(stmt).unique().scalars().all()
 
         assert u.addresses[0].email_address == "jack@bean.com"
         assert u.orders[1].items[2].description == "item 5"
index c4bcf0404634f620ba2f79c51e63d7781b751767..ccee396a32f317bbf98d80264b64999e2d009151 100644 (file)
@@ -697,6 +697,9 @@ class TemporalFixtureTest(testing.fixtures.DeclarativeMappedTest):
         else:
             loader_options = ()
 
+        is_joined = (
+            loader_strategy and loader_strategy.__name__ == "joinedload"
+        )
         p1 = sess.execute(
             select(Parent).filter(
                 Parent.timestamp == datetime.datetime(2009, 10, 15, 12, 00, 00)
@@ -712,42 +715,40 @@ class TemporalFixtureTest(testing.fixtures.DeclarativeMappedTest):
         ).scalar()
         c5 = p2.children[1]
 
-        parents = (
-            sess.execute(
-                select(Parent)
-                .execution_options(populate_existing=True)
-                .options(
-                    temporal_range(
-                        datetime.datetime(2009, 10, 16, 12, 00, 00),
-                        datetime.datetime(2009, 10, 18, 12, 00, 00),
-                    ),
-                    *loader_options
-                )
+        result = sess.execute(
+            select(Parent)
+            .execution_options(populate_existing=True)
+            .options(
+                temporal_range(
+                    datetime.datetime(2009, 10, 16, 12, 00, 00),
+                    datetime.datetime(2009, 10, 18, 12, 00, 00),
+                ),
+                *loader_options
             )
-            .scalars()
-            .all()
         )
+        if is_joined:
+            result = result.unique()
+        parents = result.scalars().all()
 
         assert parents[0] == p2
         assert parents[0].children == [c5]
 
-        parents = (
-            sess.execute(
-                select(Parent)
-                .execution_options(populate_existing=True)
-                .join(Parent.children)
-                .filter(Child.id == c2_id)
-                .options(
-                    temporal_range(
-                        datetime.datetime(2009, 10, 15, 11, 00, 00),
-                        datetime.datetime(2009, 10, 18, 12, 00, 00),
-                    ),
-                    *loader_options
-                )
+        result = sess.execute(
+            select(Parent)
+            .execution_options(populate_existing=True)
+            .join(Parent.children)
+            .filter(Child.id == c2_id)
+            .options(
+                temporal_range(
+                    datetime.datetime(2009, 10, 15, 11, 00, 00),
+                    datetime.datetime(2009, 10, 18, 12, 00, 00),
+                ),
+                *loader_options
             )
-            .scalars()
-            .all()
         )
+        if is_joined:
+            result = result.unique()
+        parents = result.scalars().all()
 
         assert parents[0] == p1
         assert parents[0].children == [c1, c2]