]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement PropComparator.and_() for remaining options
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Oct 2020 18:29:57 +0000 (14:29 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Oct 2020 18:34:22 +0000 (14:34 -0400)
In c7b489b25802f7a25ef78d0731411295c611cc1c we implemented
with_loader_criteria() for everyone as well as PropComparator.and_()
for joinedload() and join(), but forgot to do anything for
lazyload(), selectinload(), or subqueryload().  Even though
I actually documented it in terms of lazyload().

Fixes: #4472
Change-Id: I0ef410a83c34e63b9c9c9c3277c0063d8971ec14

lib/sqlalchemy/orm/strategies.py
test/orm/test_relationship_criteria.py

index 371f923eef711d1b769132d93b442c523a7bd010..1d4709726689900e4767d62eb868c21333224f42 100644 (file)
@@ -774,7 +774,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
             "'%s' is not available due to lazy='%s'" % (self, lazy)
         )
 
-    def _load_for_state(self, state, passive):
+    def _load_for_state(self, state, passive, loadopt=None):
 
         if not state.key and (
             (
@@ -788,7 +788,9 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
         pending = not state.key
         primary_key_identity = None
 
-        if (not passive & attributes.SQL_OK and not self.use_get) or (
+        use_get = self.use_get and (not loadopt or not loadopt._extra_criteria)
+
+        if (not passive & attributes.SQL_OK and not use_get) or (
             not passive & attributes.NON_PERSISTENT_OK and pending
         ):
             return attributes.PASSIVE_NO_RESULT
@@ -804,7 +806,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
                 # for history purposes or otherwise returning
                 # PASSIVE_NO_RESULT, don't raise.  This is also a
                 # history-related flag
-                not self.use_get
+                not use_get
                 or passive & attributes.RELATED_OBJECT_OK
             )
         ):
@@ -824,7 +826,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
 
         # if we have a simple primary key load, check the
         # identity map without generating a Query at all
-        if self.use_get:
+        if use_get:
             primary_key_identity = self._get_ident_for_use_get(
                 session, state, passive
             )
@@ -863,7 +865,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
                 return attributes.PASSIVE_NO_RESULT
 
         return self._emit_lazyload(
-            session, state, primary_key_identity, passive
+            session, state, primary_key_identity, passive, loadopt
         )
 
     def _get_ident_for_use_get(self, session, state, passive):
@@ -885,7 +887,9 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
         return util.LRUCache(30)
 
     @util.preload_module("sqlalchemy.orm.strategy_options")
-    def _emit_lazyload(self, session, state, primary_key_identity, passive):
+    def _emit_lazyload(
+        self, session, state, primary_key_identity, passive, loadopt
+    ):
         strategy_options = util.preloaded.orm_strategy_options
 
         stmt = sql.lambda_stmt(
@@ -918,18 +922,28 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
         if pending or passive & attributes.NO_AUTOFLUSH:
             stmt += lambda stmt: stmt.execution_options(autoflush=False)
 
-        if state.load_options:
+        use_get = self.use_get
+
+        if state.load_options or (loadopt and loadopt._extra_criteria):
 
             effective_path = state.load_path[self.parent_property]
 
             opts = list(state.load_options)
 
+            if loadopt and loadopt._extra_criteria:
+                use_get = False
+                opts += (
+                    orm_util.LoaderCriteriaOption(
+                        self.entity, sql.and_(*loadopt._extra_criteria)
+                    ),
+                )
+
             stmt += lambda stmt: stmt.options(*opts)
             stmt += lambda stmt: stmt._update_compile_options(
                 {"_current_path": effective_path}
             )
 
-        if self.use_get:
+        if use_get:
             if self._raise_on_sql:
                 self._invoke_raise_load(state, passive, "raise_on_sql")
 
@@ -1023,7 +1037,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
     ):
         key = self.key
 
-        if not self.is_class_level:
+        if not self.is_class_level or (loadopt and loadopt._extra_criteria):
             # we are not the primary manager for this attribute
             # on this class - set up a
             # per-instance lazyloader, which will override the
@@ -1034,7 +1048,7 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
             # class-level lazyloader installed.
             set_lazy_callable = (
                 InstanceState._instance_level_callable_processor
-            )(mapper.class_manager, LoadLazyAttribute(key, self), key)
+            )(mapper.class_manager, LoadLazyAttribute(key, self, loadopt), key)
 
             populators["new"].append((self.key, set_lazy_callable))
         elif context.populate_existing or mapper.always_refresh:
@@ -1056,9 +1070,10 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
 class LoadLazyAttribute(object):
     """serializable loader object used by LazyLoader"""
 
-    def __init__(self, key, initiating_strategy):
+    def __init__(self, key, initiating_strategy, loadopt):
         self.key = key
         self.strategy_key = initiating_strategy.strategy_key
+        self.loadopt = loadopt
 
     def __call__(self, state, passive=attributes.PASSIVE_OFF):
         key = self.key
@@ -1066,7 +1081,7 @@ class LoadLazyAttribute(object):
         prop = instance_mapper._props[key]
         strategy = prop._strategies[self.strategy_key]
 
-        return strategy._load_for_state(state, passive)
+        return strategy._load_for_state(state, passive, loadopt=self.loadopt)
 
 
 class PostLoader(AbstractRelationshipLoader):
@@ -1376,12 +1391,28 @@ class SubqueryLoader(PostLoader):
         return q
 
     def _setup_options(
-        self, q, subq_path, rewritten_path, orig_query, effective_entity
+        self,
+        q,
+        subq_path,
+        rewritten_path,
+        orig_query,
+        effective_entity,
+        loadopt,
     ):
+
+        opts = orig_query._with_options
+
+        if loadopt and loadopt._extra_criteria:
+            opts += (
+                orm_util.LoaderCriteriaOption(
+                    self.entity, sql.and_(*loadopt._extra_criteria)
+                ),
+            )
+
         # propagate loader options etc. to the new query.
         # these will fire relative to subq_path.
         q = q._with_current_path(rewritten_path)
-        q = q.options(*orig_query._with_options)
+        q = q.options(*opts)
 
         return q
 
@@ -1586,7 +1617,7 @@ class SubqueryLoader(PostLoader):
         )
 
         q = self._setup_options(
-            q, subq_path, rewritten_path, orig_query, effective_entity
+            q, subq_path, rewritten_path, orig_query, effective_entity, loadopt
         )
         q = self._setup_outermost_orderby(q)
 
@@ -2627,10 +2658,11 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
             self.parent_property,
             self._load_for_path,
             effective_entity,
+            loadopt,
         )
 
     def _load_for_path(
-        self, context, path, states, load_only, effective_entity
+        self, context, path, states, load_only, effective_entity, loadopt
     ):
         if load_only and self.key not in load_only:
             return
@@ -2768,6 +2800,13 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
         effective_path = path[self.parent_property]
 
         options = orig_query._with_options
+        if loadopt and loadopt._extra_criteria:
+            options += (
+                orm_util.LoaderCriteriaOption(
+                    effective_entity, sql.and_(*loadopt._extra_criteria)
+                ),
+            )
+
         q = q.add_criteria(
             lambda q: q.options(*options)._update_compile_options(
                 {"_current_path": effective_path}
index 1c7eb2e619ea18537b9611b23771bfe45f6d139d..7237dd264f924c30b594caf161986814fcdb4c01 100644 (file)
@@ -13,10 +13,12 @@ from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import lazyload
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
+from sqlalchemy.orm import subqueryload
 from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing.assertsql import CompiledSQL
@@ -783,18 +785,60 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
         )
         return User, Address
 
+    def _user_minus_edwood(self, User, Address):
+        return [
+            User(
+                addresses=[
+                    Address(email_address="jack@bean.com", id=1, user_id=7)
+                ],
+                id=7,
+                name="jack",
+            ),
+            User(
+                addresses=[
+                    Address(
+                        email_address="ed@bettyboop.com",
+                        id=3,
+                        user_id=8,
+                    ),
+                    Address(email_address="ed@lala.com", id=4, user_id=8),
+                ],
+                id=8,
+                name="ed",
+            ),
+            User(
+                addresses=[
+                    Address(email_address="fred@fred.com", id=5, user_id=9)
+                ],
+                id=9,
+                name="fred",
+            ),
+            User(addresses=[], id=10, name="chuck"),
+        ]
+
     def test_joinedload_local_criteria(self, user_address_fixture):
         User, Address = user_address_fixture
 
         s = Session(testing.db, future=True)
 
-        stmt = select(User).options(
-            joinedload(User.addresses.and_(Address.email_address != "email")),
+        stmt = (
+            select(User)
+            .options(
+                joinedload(
+                    User.addresses.and_(Address.email_address != "ed@wood.com")
+                ),
+            )
+            .order_by(User.id)
         )
 
         with self.sql_execution_asserter() as asserter:
 
-            s.execute(stmt)
+            result = s.execute(stmt)
+
+            eq_(
+                result.scalars().unique().all(),
+                self._user_minus_edwood(*user_address_fixture),
+            )
 
         asserter.assert_(
             CompiledSQL(
@@ -803,8 +847,159 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
                 "users LEFT OUTER JOIN addresses AS addresses_1 "
                 "ON users.id = addresses_1.user_id "
                 "AND addresses_1.email_address != :email_address_1 "
-                "ORDER BY addresses_1.id",
-                [{"email_address_1": "email"}],
+                "ORDER BY users.id, addresses_1.id",
+                [{"email_address_1": "ed@wood.com"}],
+            ),
+        )
+
+    def test_selectinload_local_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        s = Session(testing.db, future=True)
+
+        stmt = (
+            select(User)
+            .options(
+                selectinload(
+                    User.addresses.and_(Address.email_address != "ed@wood.com")
+                ),
+            )
+            .order_by(User.id)
+        )
+
+        with self.sql_execution_asserter() as asserter:
+
+            result = s.execute(stmt)
+
+            eq_(
+                result.scalars().unique().all(),
+                self._user_minus_edwood(*user_address_fixture),
+            )
+
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT users.id, users.name FROM users ORDER BY users.id"
+            ),
+            CompiledSQL(
+                "SELECT addresses.user_id AS addresses_user_id, "
+                "addresses.id AS addresses_id, addresses.email_address "
+                "AS addresses_email_address FROM addresses "
+                "WHERE addresses.user_id IN ([POSTCOMPILE_primary_keys]) "
+                "AND addresses.email_address != :email_address_1 "
+                "ORDER BY addresses.id",
+                [
+                    {
+                        "primary_keys": [7, 8, 9, 10],
+                        "email_address_1": "ed@wood.com",
+                    }
+                ],
+            ),
+        )
+
+    def test_lazyload_local_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        s = Session(testing.db, future=True)
+
+        stmt = (
+            select(User)
+            .options(
+                lazyload(
+                    User.addresses.and_(Address.email_address != "ed@wood.com")
+                ),
+            )
+            .order_by(User.id)
+        )
+
+        with self.sql_execution_asserter() as asserter:
+
+            result = s.execute(stmt)
+
+            eq_(
+                result.scalars().unique().all(),
+                self._user_minus_edwood(*user_address_fixture),
+            )
+
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT users.id, users.name FROM users ORDER BY users.id"
+            ),
+            CompiledSQL(
+                "SELECT addresses.id AS addresses_id, "
+                "addresses.user_id AS addresses_user_id, "
+                "addresses.email_address AS addresses_email_address "
+                "FROM addresses WHERE :param_1 = addresses.user_id "
+                "AND addresses.email_address != :email_address_1 "
+                "ORDER BY addresses.id",
+                [{"param_1": 7, "email_address_1": "ed@wood.com"}],
+            ),
+            CompiledSQL(
+                "SELECT addresses.id AS addresses_id, "
+                "addresses.user_id AS addresses_user_id, "
+                "addresses.email_address AS addresses_email_address "
+                "FROM addresses WHERE :param_1 = addresses.user_id "
+                "AND addresses.email_address != :email_address_1 "
+                "ORDER BY addresses.id",
+                [{"param_1": 8, "email_address_1": "ed@wood.com"}],
+            ),
+            CompiledSQL(
+                "SELECT addresses.id AS addresses_id, "
+                "addresses.user_id AS addresses_user_id, "
+                "addresses.email_address AS addresses_email_address "
+                "FROM addresses WHERE :param_1 = addresses.user_id "
+                "AND addresses.email_address != :email_address_1 "
+                "ORDER BY addresses.id",
+                [{"param_1": 9, "email_address_1": "ed@wood.com"}],
+            ),
+            CompiledSQL(
+                "SELECT addresses.id AS addresses_id, "
+                "addresses.user_id AS addresses_user_id, "
+                "addresses.email_address AS addresses_email_address "
+                "FROM addresses WHERE :param_1 = addresses.user_id "
+                "AND addresses.email_address != :email_address_1 "
+                "ORDER BY addresses.id",
+                [{"param_1": 10, "email_address_1": "ed@wood.com"}],
+            ),
+        )
+
+    def test_subqueryload_local_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        s = Session(testing.db, future=True)
+
+        stmt = (
+            select(User)
+            .options(
+                subqueryload(
+                    User.addresses.and_(Address.email_address != "ed@wood.com")
+                ),
+            )
+            .order_by(User.id)
+        )
+
+        with self.sql_execution_asserter() as asserter:
+
+            result = s.execute(stmt)
+
+            eq_(
+                result.scalars().unique().all(),
+                self._user_minus_edwood(*user_address_fixture),
+            )
+
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT users.id, users.name FROM users ORDER BY users.id"
+            ),
+            CompiledSQL(
+                "SELECT addresses.id AS addresses_id, addresses.user_id "
+                "AS addresses_user_id, addresses.email_address "
+                "AS addresses_email_address, anon_1.users_id "
+                "AS anon_1_users_id FROM (SELECT users.id AS users_id "
+                "FROM users) AS anon_1 JOIN addresses ON anon_1.users_id = "
+                "addresses.user_id AND "
+                "addresses.email_address != :email_address_1 "
+                "ORDER BY addresses.id",
+                [{"email_address_1": "ed@wood.com"}],
             ),
         )