]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement relationship AND criteria; global loader criteria
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Aug 2020 01:47:43 +0000 (21:47 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 6 Aug 2020 02:13:11 +0000 (22:13 -0400)
Added the ability to add arbitrary criteria to the ON clause generated
by a relationship attribute in a query, which applies to methods such
as :meth:`_query.Query.join` as well as loader options like
:func:`_orm.joinedload`.   Additionally, a "global" version of the option
allows limiting criteria to be applied to particular entities in
a query globally.

Documentation is minimal at this point, new examples will
be coming in a subsequent commit.

Some adjustments to execution options in how they are represented
in the ORMExecuteState as well as well as a few ORM tests that
forgot to get merged in a preceding commit.

Fixes: #4472
Change-Id: I2b8fc57092dedf35ebd16f6343ad0f0d7d332beb

23 files changed:
doc/build/changelog/unreleased_14/4472.rst [new file with mode: 0644]
doc/build/orm/loading_relationships.rst
doc/build/orm/query.rst
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/compiler.py
test/ext/test_baked.py
test/orm/inheritance/test_polymorphic_rel.py
test/orm/test_bundle.py
test/orm/test_cache_key.py
test/orm/test_events.py
test/orm/test_options.py
test/orm/test_relationship_criteria.py [new file with mode: 0644]
test/sql/test_compare.py

diff --git a/doc/build/changelog/unreleased_14/4472.rst b/doc/build/changelog/unreleased_14/4472.rst
new file mode 100644 (file)
index 0000000..6de5058
--- /dev/null
@@ -0,0 +1,19 @@
+.. change::
+    :tags: feature, orm
+    :tickets: 4472
+
+    Added the ability to add arbitrary criteria to the ON clause generated
+    by a relationship attribute in a query, which applies to methods such
+    as :meth:`_query.Query.join` as well as loader options like
+    :func:`_orm.joinedload`.   Additionally, a "global" version of the option
+    allows limiting criteria to be applied to particular entities in
+    a query globally.
+
+    .. seealso::
+
+        :ref:`loader_option_criteria`
+
+        :func:`_orm.with_loader_criteria`
+
+    .. TODO: add links to new examples section and session-related
+       documentation involving do_orm_execute event when merged
\ No newline at end of file
index 50d3cc51a79603f962edf2ebe4f0796d10720c61..8909d9a6eb1c6da996a099144c85cbb91e0f2a08 100644 (file)
@@ -112,13 +112,10 @@ the string name of an attribute against a parent, or for greater specificity
 can accommodate a class-bound attribute directly::
 
     # set children to load lazily
-    session.query(Parent).options(lazyload('children')).all()
-
-    # same, using class-bound attribute
     session.query(Parent).options(lazyload(Parent.children)).all()
 
     # set children to load eagerly with a join
-    session.query(Parent).options(joinedload('children')).all()
+    session.query(Parent).options(joinedload(Parent.children)).all()
 
 The loader options can also be "chained" using **method chaining**
 to specify how loading should occur further levels deep::
@@ -141,6 +138,48 @@ collections loaded.  When the ``children`` collection on a particular
 objects, but additionally apply eager loading to the ``subelements``
 collection on each member of ``children``.
 
+The above examples, using :class:`_orm.Query`, are now referred to as
+:term:`1.x style` queries.   The options system is available as well for
+:term:`2.0 style` queries using the :meth:`_sql.Select.options` method::
+
+  stmt = select(Parent).options(
+        lazyload(Parent.children).
+        subqueryload(Child.subelements))
+
+  result = session.execute(stmt)
+
+Under the hood, :class:`_orm.Query` is ultimately using the above
+:class:`_sql.select` based mechanism.
+
+
+.. _loader_option_criteria:
+
+Adding Criteria to loader options
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The relationship attributes used to indicate loader options include the
+ability to add additional filtering criteria to the ON clause of the join
+that's created, or to the WHERE criteria involved, depending on the loader
+strategy.  This can be achieved using the :meth:`.PropComparator.and_`
+method which will pass through an option such that loaded results are limited
+to the given filter criteria::
+
+    session.query(A).options(lazyload(A.bs.and_(B.id > 5)))
+
+When using limiting criteria, if a particular collection is already loaded
+it won't be refreshed; to ensure the new criteria takes place, apply
+the :meth:`_orm.Query.populate_existing` option::
+
+    session.query(A).options(lazyload(A.bs.and_(B.id > 5))).populate_existing()
+
+In order to add filtering criteria to all occurrences of an entity throughout
+a query, regardless of loader strategy or where it occurs in the loading
+process, see the :func:`_orm.with_loader_criteria` function.
+
+.. versionadded:: 1.4
+
+Specifying Sub-Options with Load.options()
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 Using method chaining, the loader style of each link in the path is explicitly
 stated.  To navigate along a path without changing the existing loader style
 of a particular attribute, the :func:`.defaultload` method/function may be used::
@@ -1263,6 +1302,7 @@ Relationship Loader API
 .. autofunction:: lazyload
 
 .. autoclass:: Load
+    :members:
 
 .. autofunction:: noload
 
index 3fddd6c341f6b3129a58a9466227f7aab495c84a..ed45a65e7870441952009aae56a50ddc6d5dd323 100644 (file)
@@ -44,6 +44,8 @@ ORM-Specific Query Constructs
 .. autoclass:: sqlalchemy.orm.strategy_options.Load
     :members:
 
+.. autofunction:: sqlalchemy.orm.with_loader_criteria
+
 .. autofunction:: join
 
 .. autofunction:: outerjoin
index 53545826bce9d84755fa432ed9de9817b1a2e717..fe9bbaf02dcf97482e0453785a61db43d316cfb7 100644 (file)
@@ -207,7 +207,6 @@ class ShardedSession(Session):
 
 
 def execute_and_instances(orm_context):
-
     if orm_context.is_select:
         load_options = active_options = orm_context.load_options
         update_options = None
@@ -237,8 +236,8 @@ def execute_and_instances(orm_context):
 
     if active_options._refresh_identity_token is not None:
         shard_id = active_options._refresh_identity_token
-    elif "_sa_shard_id" in orm_context.merged_execution_options:
-        shard_id = orm_context.merged_execution_options["_sa_shard_id"]
+    elif "_sa_shard_id" in orm_context.execution_options:
+        shard_id = orm_context.execution_options["_sa_shard_id"]
     elif "shard_id" in orm_context.bind_arguments:
         shard_id = orm_context.bind_arguments["shard_id"]
     else:
index 32ec60322a57c3a0bc66ddbad422e6e044e6c85a..4581038389f3792c1f626306beb014f3671bfb17 100644 (file)
@@ -48,6 +48,7 @@ from .strategy_options import Load  # noqa
 from .util import aliased  # noqa
 from .util import Bundle  # noqa
 from .util import join  # noqa
+from .util import LoaderCriteriaOption  # noqa
 from .util import object_mapper  # noqa
 from .util import outerjoin  # noqa
 from .util import polymorphic_union  # noqa
@@ -101,6 +102,8 @@ def create_session(bind=None, **kwargs):
     return Session(bind=bind, **kwargs)
 
 
+with_loader_criteria = public_factory(LoaderCriteriaOption, ".orm")
+
 relationship = public_factory(RelationshipProperty, ".orm.relationship")
 
 
index 6dd95a5a90f0b9d379b18e019390e452c5a66869..2e1b9dc75ee716207b1268f4bd29983b205b69a6 100644 (file)
@@ -50,6 +50,7 @@ from .. import inspection
 from .. import util
 from ..sql import base as sql_base
 from ..sql import roles
+from ..sql import traversals
 from ..sql import visitors
 
 
@@ -58,6 +59,7 @@ class QueryableAttribute(
     interfaces._MappedAttribute,
     interfaces.InspectionAttr,
     interfaces.PropComparator,
+    traversals.HasCopyInternals,
     roles.JoinTargetRole,
     roles.OnClauseRole,
     sql_base.Immutable,
@@ -91,6 +93,7 @@ class QueryableAttribute(
         impl=None,
         comparator=None,
         of_type=None,
+        extra_criteria=(),
     ):
         self.class_ = class_
         self.key = key
@@ -98,6 +101,7 @@ class QueryableAttribute(
         self.impl = impl
         self.comparator = comparator
         self._of_type = of_type
+        self._extra_criteria = extra_criteria
 
         manager = manager_of_class(class_)
         # manager is None in the case of AliasedClass
@@ -114,6 +118,7 @@ class QueryableAttribute(
         ("key", visitors.ExtendedInternalTraversal.dp_string),
         ("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
         ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+        ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
     ]
 
     def __reduce__(self):
@@ -240,6 +245,29 @@ class QueryableAttribute(
             impl=self.impl,
             comparator=self.comparator.of_type(entity),
             of_type=inspection.inspect(entity),
+            extra_criteria=self._extra_criteria,
+        )
+
+    def and_(self, *other):
+        return QueryableAttribute(
+            self.class_,
+            self.key,
+            self._parententity,
+            impl=self.impl,
+            comparator=self.comparator.and_(*other),
+            of_type=self._of_type,
+            extra_criteria=self._extra_criteria + other,
+        )
+
+    def _clone(self, **kw):
+        return QueryableAttribute(
+            self.class_,
+            self.key,
+            self._parententity,
+            impl=self.impl,
+            comparator=self.comparator,
+            of_type=self._of_type,
+            extra_criteria=self._extra_criteria,
         )
 
     def label(self, name):
index 96725e55b1ac5ceb9ba0b02e2fe27e4d4dd4469e..a35b2f9fdf7f79b942f5f355568a08404f66b587 100644 (file)
@@ -11,9 +11,9 @@ from .base import _is_aliased_class
 from .interfaces import ORMColumnsClauseRole
 from .path_registry import PathRegistry
 from .util import _entity_corresponds_to
+from .util import _ORMJoin
 from .util import aliased
 from .util import Bundle
-from .util import join as orm_join
 from .util import ORMAdapter
 from .. import exc as sa_exc
 from .. import future
@@ -78,7 +78,6 @@ class QueryContext(object):
         _yield_per = None
         _refresh_state = None
         _lazy_loaded_from = None
-        _params = _EMPTY_DICT
 
     def __init__(
         self,
@@ -308,6 +307,9 @@ class ORMFromStatementCompileState(ORMCompileState):
     multi_row_eager_loaders = False
     compound_eager_adapter = None
 
+    extra_criteria_entities = _EMPTY_DICT
+    eager_joins = _EMPTY_DICT
+
     @classmethod
     def create_for_statement(cls, statement_container, compiler, **kw):
 
@@ -338,6 +340,7 @@ class ORMFromStatementCompileState(ORMCompileState):
 
         if toplevel and statement_container._with_options:
             self.attributes = {"_unbound_load_dedupes": set()}
+            self.global_attributes = compiler._global_attributes
 
             for opt in statement_container._with_options:
                 if opt._is_compile_state:
@@ -345,6 +348,7 @@ class ORMFromStatementCompileState(ORMCompileState):
 
         else:
             self.attributes = {}
+            self.global_attributes = compiler._global_attributes
 
         if statement_container._with_context_options:
             for fn, key in statement_container._with_context_options:
@@ -352,8 +356,6 @@ class ORMFromStatementCompileState(ORMCompileState):
 
         self.primary_columns = []
         self.secondary_columns = []
-        self.eager_joins = {}
-        self.single_inh_entities = {}
         self.create_eager_joins = []
         self._fallback_from_clauses = []
 
@@ -423,11 +425,15 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
     def create_for_statement(cls, statement, compiler, **kw):
         """compiler hook, we arrive here from compiler.visit_select() only."""
 
+        self = cls.__new__(cls)
+
         if compiler is not None:
             toplevel = not compiler.stack
             compiler._rewrites_selected_columns = True
+            self.global_attributes = compiler._global_attributes
         else:
             toplevel = True
+            self.global_attributes = {}
 
         select_statement = statement
 
@@ -437,8 +443,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
             statement._compile_options
         )
 
-        self = cls.__new__(cls)
-
         if select_statement._execution_options:
             # execution options should not impact the compilation of a
             # query, and at the moment subqueryloader is putting some things
@@ -516,7 +520,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         self.primary_columns = []
         self.secondary_columns = []
         self.eager_joins = {}
-        self.single_inh_entities = {}
+        self.extra_criteria_entities = {}
         self.create_eager_joins = []
         self._fallback_from_clauses = []
 
@@ -634,7 +638,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
         if self.compile_options._enable_single_crit:
 
-            self._adjust_for_single_inheritance()
+            self._adjust_for_extra_criteria()
 
         if not self.primary_columns:
             if self.compile_options._only_load_props:
@@ -1408,6 +1412,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
             left, right, onclause, prop, create_aliases, aliased_generation
         )
 
+        if not r_info.is_selectable:
+            extra_criteria = self._get_extra_criteria(r_info)
+        else:
+            extra_criteria = ()
+
         if replace_from_obj_index is not None:
             # splice into an existing element in the
             # self._from_obj list
@@ -1416,12 +1425,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
             self.from_clauses = (
                 self.from_clauses[:replace_from_obj_index]
                 + [
-                    orm_join(
+                    _ORMJoin(
                         left_clause,
                         right,
                         onclause,
                         isouter=outerjoin,
                         full=full,
+                        _extra_criteria=extra_criteria,
                     )
                 ]
                 + self.from_clauses[replace_from_obj_index + 1 :]
@@ -1440,8 +1450,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                 left_clause = left
 
             self.from_clauses = self.from_clauses + [
-                orm_join(
-                    left_clause, r_info, onclause, isouter=outerjoin, full=full
+                _ORMJoin(
+                    left_clause,
+                    r_info,
+                    onclause,
+                    isouter=outerjoin,
+                    full=full,
+                    _extra_criteria=extra_criteria,
                 )
             ]
 
@@ -1848,8 +1863,23 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
             or kwargs.get("group_by", False)
         )
 
-    def _adjust_for_single_inheritance(self):
-        """Apply single-table-inheritance filtering.
+    def _get_extra_criteria(self, ext_info):
+        if (
+            "additional_entity_criteria",
+            ext_info.mapper,
+        ) in self.global_attributes:
+            return tuple(
+                ae._resolve_where_criteria(ext_info)
+                for ae in self.global_attributes[
+                    ("additional_entity_criteria", ext_info.mapper)
+                ]
+                if ae.include_aliases or ae.entity is ext_info
+            )
+        else:
+            return ()
+
+    def _adjust_for_extra_criteria(self):
+        """Apply extra criteria filtering.
 
         For all distinct single-table-inheritance mappers represented in
         the columns clause of this query, as well as the "select from entity",
@@ -1857,38 +1887,50 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         clause of the given QueryContext such that only the appropriate
         subtypes are selected from the total results.
 
+        Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+        associated with the global context.
+
         """
 
         for fromclause in self.from_clauses:
             ext_info = fromclause._annotations.get("parententity", None)
             if (
                 ext_info
-                and ext_info.mapper._single_table_criterion is not None
-                and ext_info not in self.single_inh_entities
+                and (
+                    ext_info.mapper._single_table_criterion is not None
+                    or ("additional_entity_criteria", ext_info.mapper)
+                    in self.global_attributes
+                )
+                and ext_info not in self.extra_criteria_entities
             ):
 
-                self.single_inh_entities[ext_info] = (
+                self.extra_criteria_entities[ext_info] = (
                     ext_info,
                     ext_info._adapter if ext_info.is_aliased_class else None,
                 )
 
-        search = set(self.single_inh_entities.values())
+        search = set(self.extra_criteria_entities.values())
 
         for (ext_info, adapter) in search:
             if ext_info in self._join_entities:
                 continue
+
             single_crit = ext_info.mapper._single_table_criterion
+
+            additional_entity_criteria = self._get_extra_criteria(ext_info)
+
             if single_crit is not None:
+                additional_entity_criteria += (single_crit,)
+
+            current_adapter = self._get_current_adapter()
+            for crit in additional_entity_criteria:
                 if adapter:
-                    single_crit = adapter.traverse(single_crit)
+                    crit = adapter.traverse(crit)
 
-                current_adapter = self._get_current_adapter()
                 if current_adapter:
-                    single_crit = sql_util._deep_annotate(
-                        single_crit, {"_orm_adapt": True}
-                    )
-                    single_crit = current_adapter(single_crit, False)
-                self._where_criteria += (single_crit,)
+                    crit = sql_util._deep_annotate(crit, {"_orm_adapt": True})
+                    crit = current_adapter(crit, False)
+                self._where_criteria += (crit,)
 
 
 def _column_descriptions(query_or_select_stmt, compile_state=None):
@@ -2205,9 +2247,13 @@ class _MapperEntity(_QueryEntity):
         adapter = self._get_entity_clauses(compile_state)
 
         single_table_crit = self.mapper._single_table_criterion
-        if single_table_crit is not None:
+        if (
+            single_table_crit is not None
+            or ("additional_entity_criteria", self.mapper)
+            in compile_state.global_attributes
+        ):
             ext_info = self.entity_zero
-            compile_state.single_inh_entities[ext_info] = (
+            compile_state.extra_criteria_entities[ext_info] = (
                 ext_info,
                 ext_info._adapter if ext_info.is_aliased_class else None,
             )
@@ -2528,8 +2574,13 @@ class _ORMColumnEntity(_ColumnEntity):
         ezero = self.entity_zero
 
         single_table_crit = self.mapper._single_table_criterion
-        if single_table_crit is not None:
-            compile_state.single_inh_entities[ezero] = (
+        if (
+            single_table_crit is not None
+            or ("additional_entity_criteria", self.mapper)
+            in compile_state.global_attributes
+        ):
+
+            compile_state.extra_criteria_entities[ezero] = (
                 ezero,
                 ezero._adapter if ezero.is_aliased_class else None,
             )
index 4cf820ae3250c37f3c33eff45d63545d32aec52a..068c8507322205387bc78ce9b36fba47688fe707 100644 (file)
@@ -480,6 +480,32 @@ class PropComparator(operators.ColumnOperators):
 
         return self.operate(PropComparator.of_type_op, class_)
 
+    def and_(self, *criteria):
+        """Add additional criteria to the ON clause that's represented by this
+        relationship attribute.
+
+        E.g.::
+
+
+            stmt = select(User).join(
+                User.addresses.and_(Address.email_address != 'foo')
+            )
+
+            stmt = select(User).options(
+                joinedload(User.addresses.and_(Address.email_address != 'foo'))
+            )
+
+        .. versionadded:: 1.4
+
+        .. seealso::
+
+            :ref:`loader_option_criteria`
+
+            :func:`.with_loader_criteria`
+
+        """
+        return self.operate(operators.and_, *criteria)
+
     def any(self, criterion=None, **kwargs):
         r"""Return true if this collection contains any member that meets the
         given criterion.
index d60c03bdcb1640225ddef50c3f95667c7e41c106..68ca0365b7892b94075e31189544d9ca775e2da3 100644 (file)
@@ -1997,6 +1997,20 @@ class Query(
                     filter(a1.email_address == 'ed@foo.com').\
                     filter(a2.email_address == 'ed@bar.com')
 
+        **Augmenting Built-in ON Clauses**
+
+        As a substitute for providing a full custom ON condition for an
+        existing relationship, the :meth:`_orm.PropComparator.and_` function
+        may be applied to a relationship attribute to augment additional
+        criteria into the ON clause; the additional criteria will be combined
+        with the default criteria using AND::
+
+            q = session.query(User).join(
+                User.addresses.and_(Address.email_address != 'foo@bar.com')
+            )
+
+        .. versionadded:: 1.4
+
         **Joining to Tables and Subqueries**
 
 
index cb490b7d7a8c67c57af2208b6d64305ff3a67621..794b9422c4dc01b77440bc89e2eb00144cf7e62a 100644 (file)
@@ -1115,9 +1115,15 @@ class RelationshipProperty(StrategizedProperty):
         """
 
         _of_type = None
+        _extra_criteria = ()
 
         def __init__(
-            self, prop, parentmapper, adapt_to_entity=None, of_type=None
+            self,
+            prop,
+            parentmapper,
+            adapt_to_entity=None,
+            of_type=None,
+            extra_criteria=(),
         ):
             """Construction of :class:`.RelationshipProperty.Comparator`
             is internal to the ORM's attribute mechanics.
@@ -1128,6 +1134,7 @@ class RelationshipProperty(StrategizedProperty):
             self._adapt_to_entity = adapt_to_entity
             if of_type:
                 self._of_type = of_type
+            self._extra_criteria = extra_criteria
 
         def adapt_to_entity(self, adapt_to_entity):
             return self.__class__(
@@ -1191,6 +1198,7 @@ class RelationshipProperty(StrategizedProperty):
                 source_polymorphic=True,
                 of_type_entity=of_type_entity,
                 alias_secondary=True,
+                extra_criteria=self._extra_criteria,
             )
             if sj is not None:
                 return pj & sj
@@ -1202,12 +1210,30 @@ class RelationshipProperty(StrategizedProperty):
 
             See :meth:`.PropComparator.of_type` for an example.
 
+
             """
             return RelationshipProperty.Comparator(
                 self.property,
                 self._parententity,
                 adapt_to_entity=self._adapt_to_entity,
                 of_type=cls,
+                extra_criteria=self._extra_criteria,
+            )
+
+        def and_(self, *other):
+            """Add AND criteria.
+
+            See :meth:`.PropComparator.and_` for an example.
+
+            .. versionadded:: 1.4
+
+            """
+            return RelationshipProperty.Comparator(
+                self.property,
+                self._parententity,
+                adapt_to_entity=self._adapt_to_entity,
+                of_type=self._of_type,
+                extra_criteria=self._extra_criteria + other,
             )
 
         def in_(self, other):
@@ -2439,6 +2465,7 @@ class RelationshipProperty(StrategizedProperty):
         dest_selectable=None,
         of_type_entity=None,
         alias_secondary=False,
+        extra_criteria=(),
     ):
 
         aliased = False
@@ -2489,7 +2516,11 @@ class RelationshipProperty(StrategizedProperty):
             target_adapter,
             dest_selectable,
         ) = self._join_condition.join_targets(
-            source_selectable, dest_selectable, aliased, single_crit
+            source_selectable,
+            dest_selectable,
+            aliased,
+            single_crit,
+            extra_criteria,
         )
         if source_selectable is None:
             source_selectable = self.parent.local_table
@@ -3427,7 +3458,12 @@ class JoinCondition(object):
         )
 
     def join_targets(
-        self, source_selectable, dest_selectable, aliased, single_crit=None
+        self,
+        source_selectable,
+        dest_selectable,
+        aliased,
+        single_crit=None,
+        extra_criteria=(),
     ):
         """Given a source and destination selectable, create a
         join between them.
@@ -3463,6 +3499,12 @@ class JoinCondition(object):
             else:
                 primaryjoin = primaryjoin & single_crit
 
+        if extra_criteria:
+            if secondaryjoin is not None:
+                secondaryjoin = secondaryjoin & sql.and_(*extra_criteria)
+            else:
+                primaryjoin = primaryjoin & sql.and_(*extra_criteria)
+
         if aliased:
             if secondary is not None:
                 secondary = secondary._anonymous_fromclause(flat=True)
index 339c57bdcede1960b5ee0917b5a111b597870f90..e9d4ac2c67a2eef7d3d6f188dd03b9f6958f31ec 100644 (file)
@@ -102,23 +102,26 @@ CLOSED = util.symbol("CLOSED")
 
 
 class ORMExecuteState(util.MemoizedSlots):
-    """Stateful object used for the :meth:`.SessionEvents.do_orm_execute`
+    """Represents a call to the :meth:`_orm.Session.execute` method, as passed
+    to the :meth:`.SessionEvents.do_orm_execute` event hook.
 
     .. versionadded:: 1.4
 
+
     """
 
     __slots__ = (
         "session",
         "statement",
         "parameters",
-        "_execution_options",
-        "_merged_execution_options",
+        "execution_options",
+        "local_execution_options",
         "bind_arguments",
         "_compile_state_cls",
         "_starting_event_idx",
         "_events_todo",
         "_future",
+        "_update_execution_options",
     )
 
     def __init__(
@@ -135,7 +138,10 @@ class ORMExecuteState(util.MemoizedSlots):
         self.session = session
         self.statement = statement
         self.parameters = parameters
-        self._execution_options = execution_options
+        self.local_execution_options = execution_options
+        self.execution_options = statement._execution_options.union(
+            execution_options
+        )
         self.bind_arguments = bind_arguments
         self._compile_state_cls = compile_state_cls
         self._events_todo = list(events_todo)
@@ -182,9 +188,8 @@ class ORMExecuteState(util.MemoizedSlots):
 
         .. seealso::
 
-            :ref:`examples_caching` - includes example use of the
-            :meth:`.SessionEvents.do_orm_execute` hook as well as the
-            :meth:`.ORMExecuteState.invoke_query` method.
+            :ref:`do_orm_execute_re_executing` - background and examples on the
+            appropriate usage of :meth:`_orm.ORMExecuteState.invoke_statement`.
 
 
         """
@@ -203,11 +208,9 @@ class ORMExecuteState(util.MemoizedSlots):
         else:
             _params = self.parameters
 
+        _execution_options = self.local_execution_options
         if execution_options:
-            _execution_options = dict(self._execution_options)
-            _execution_options.update(execution_options)
-        else:
-            _execution_options = self._execution_options
+            _execution_options = _execution_options.union(execution_options)
 
         return self.session.execute(
             statement,
@@ -255,42 +258,9 @@ class ORMExecuteState(util.MemoizedSlots):
     def _is_crud(self):
         return isinstance(self.statement, (dml.Update, dml.Delete))
 
-    @property
-    def execution_options(self):
-        """Placeholder for execution options.
-
-        Raises an informative message, as there are local options
-        vs. merged options that can be viewed, via the
-        :attr:`.ORMExecuteState.local_execution_options` and
-        :attr:`.ORMExecuteState.merged_execution_options` methods.
-
-
-        """
-        raise AttributeError(
-            "Please use .local_execution_options or "
-            ".merged_execution_options"
-        )
-
-    @property
-    def local_execution_options(self):
-        """Dictionary view of the execution options passed to the
-        :meth:`.Session.execute` method.  This does not include options
-        that may be associated with the statement being invoked.
-
-        """
-        return util.immutabledict(self._execution_options)
-
-    @property
-    def merged_execution_options(self):
-        """Dictionary view of all execution options merged together;
-        this includes those of the statement as well as those passed to
-        :meth:`.Session.execute`, with the local options taking precedence.
-
-        """
-        return self._merged_execution_options
-
-    def _memoized_attr__merged_execution_options(self):
-        return self.statement._execution_options.union(self._execution_options)
+    def update_execution_options(self, **opts):
+        # TODO: no coverage
+        self.local_execution_options = self.local_execution_options.union(opts)
 
     def _orm_compile_options(self):
         opts = self.statement._compile_options
@@ -328,6 +298,20 @@ class ORMExecuteState(util.MemoizedSlots):
         else:
             return None
 
+    @property
+    def is_relationship_load(self):
+        """Return True if this load is loading objects on behalf of a
+        relationship.
+
+        This means, the loader in effect is either a LazyLoader,
+        SelectInLoader, SubqueryLoader, or similar, and the entire
+        SELECT statement being emitted is on behalf of a relationship
+        load.
+
+        """
+        path = self.loader_strategy_path
+        return path is not None and not path.is_root
+
     @property
     def load_options(self):
         """Return the load_options that will be used for this execution."""
@@ -337,7 +321,7 @@ class ORMExecuteState(util.MemoizedSlots):
                 "This ORM execution is not against a SELECT statement "
                 "so there are no load options."
             )
-        return self._execution_options.get(
+        return self.execution_options.get(
             "_sa_orm_load_options", context.QueryContext.default_load_options
         )
 
@@ -351,7 +335,7 @@ class ORMExecuteState(util.MemoizedSlots):
                 "This ORM execution is not against an UPDATE or DELETE "
                 "statement so there are no update options."
             )
-        return self._execution_options.get(
+        return self.execution_options.get(
             "_sa_orm_update_options",
             persistence.BulkUDCompileState.default_update_options,
         )
@@ -1003,8 +987,6 @@ class Session(_SessionClassMethods):
 
             :ref:`migration_20_toplevel`
 
-            :ref:`migration_20_result_rows`
-
         :param info: optional dictionary of arbitrary data to be associated
            with this :class:`.Session`.  Is available via the
            :attr:`.Session.info` attribute.  Note the dictionary is copied at
@@ -1282,7 +1264,7 @@ class Session(_SessionClassMethods):
         the operation will release the current SAVEPOINT but not commit
         the outermost database transaction.
 
-        If :term:`2.x-style` use is in effect via the
+        If :term:`2.0-style` use is in effect via the
         :paramref:`_orm.Session.future` flag, the outermost database
         transaction is committed unconditionally, automatically releasing any
         SAVEPOINTs in effect.
@@ -1416,7 +1398,7 @@ class Session(_SessionClassMethods):
         self,
         statement,
         params=None,
-        execution_options=util.immutabledict(),
+        execution_options=util.EMPTY_DICT,
         bind_arguments=None,
         future=False,
         _parent_execute_state=None,
@@ -1576,6 +1558,8 @@ class Session(_SessionClassMethods):
         else:
             compile_state_cls = None
 
+        execution_options = util.coerce_to_immutabledict(execution_options)
+
         if compile_state_cls is not None:
             (
                 statement,
@@ -1591,8 +1575,11 @@ class Session(_SessionClassMethods):
         else:
             bind_arguments.setdefault("clause", statement)
             if future:
-                execution_options = util.immutabledict().merge_with(
-                    execution_options, {"future_result": True}
+                # not sure if immutabledict is working w/ this syntax
+                # execution_options =
+                # execution_options.union(future_result=True)
+                execution_options = execution_options.union(
+                    {"future_result": True}
                 )
 
         if _parent_execute_state:
@@ -1619,6 +1606,10 @@ class Session(_SessionClassMethods):
                 if result:
                     return result
 
+            # TODO: coverage for this pattern
+            statement = orm_exec_state.statement
+            execution_options = orm_exec_state.local_execution_options
+
         bind = self.get_bind(**bind_arguments)
 
         conn = self._connection_for_bind(bind, close_with_result=True)
index 44f303feed970add24c5cc93753ff9da7d6ab38c..53166bd9184992b1df09edbd63e79696259fd775 100644 (file)
@@ -1975,6 +1975,7 @@ class JoinedLoader(AbstractRelationshipLoader):
                 clauses,
                 innerjoin,
                 chained_from_outerjoin,
+                loadopt._extra_criteria if loadopt else (),
             )
         )
 
@@ -1993,6 +1994,7 @@ class JoinedLoader(AbstractRelationshipLoader):
         clauses,
         innerjoin,
         chained_from_outerjoin,
+        extra_criteria,
     ):
         if parentmapper is None:
             localparent = query_entity.mapper
@@ -2081,6 +2083,17 @@ class JoinedLoader(AbstractRelationshipLoader):
             or query_entity.entity_zero.represents_outer_join
         )
 
+        extra_join_criteria = extra_criteria
+        additional_entity_criteria = compile_state.global_attributes.get(
+            ("additional_entity_criteria", self.mapper), ()
+        )
+        if additional_entity_criteria:
+            extra_join_criteria += tuple(
+                ae._resolve_where_criteria(self.mapper)
+                for ae in additional_entity_criteria
+                if ae.propagate_to_loaders
+            )
+
         if attach_on_outside:
             # this is the "classic" eager join case.
             eagerjoin = orm_util._ORMJoin(
@@ -2092,11 +2105,12 @@ class JoinedLoader(AbstractRelationshipLoader):
                 or (chained_from_outerjoin and isinstance(towrap, sql.Join)),
                 _left_memo=self.parent,
                 _right_memo=self.mapper,
+                _extra_criteria=extra_join_criteria,
             )
         else:
             # all other cases are innerjoin=='nested' approach
             eagerjoin = self._splice_nested_inner_join(
-                path, towrap, clauses, onclause,
+                path, towrap, clauses, onclause, extra_join_criteria
             )
 
         compile_state.eager_joins[query_entity_key] = eagerjoin
@@ -2128,7 +2142,7 @@ class JoinedLoader(AbstractRelationshipLoader):
             )
 
     def _splice_nested_inner_join(
-        self, path, join_obj, clauses, onclause, splicing=False
+        self, path, join_obj, clauses, onclause, extra_criteria, splicing=False
     ):
 
         if splicing is False:
@@ -2137,7 +2151,12 @@ class JoinedLoader(AbstractRelationshipLoader):
             assert isinstance(join_obj, orm_util._ORMJoin)
         elif isinstance(join_obj, sql.selectable.FromGrouping):
             return self._splice_nested_inner_join(
-                path, join_obj.element, clauses, onclause, splicing,
+                path,
+                join_obj.element,
+                clauses,
+                onclause,
+                extra_criteria,
+                splicing,
             )
         elif not isinstance(join_obj, orm_util._ORMJoin):
             if path[-2] is splicing:
@@ -2148,18 +2167,29 @@ class JoinedLoader(AbstractRelationshipLoader):
                     isouter=False,
                     _left_memo=splicing,
                     _right_memo=path[-1].mapper,
+                    _extra_criteria=extra_criteria,
                 )
             else:
                 # only here if splicing == True
                 return None
 
         target_join = self._splice_nested_inner_join(
-            path, join_obj.right, clauses, onclause, join_obj._right_memo,
+            path,
+            join_obj.right,
+            clauses,
+            onclause,
+            extra_criteria,
+            join_obj._right_memo,
         )
         if target_join is None:
             right_splice = False
             target_join = self._splice_nested_inner_join(
-                path, join_obj.left, clauses, onclause, join_obj._left_memo,
+                path,
+                join_obj.left,
+                clauses,
+                onclause,
+                extra_criteria,
+                join_obj._left_memo,
             )
             if target_join is None:
                 # should only return None when recursively called,
index b405153b92d067f60787ad3cf9fba23f9ded3e79..b3913ec5bf05910a7ad857bb4959ee35519e8d26 100644 (file)
@@ -78,6 +78,7 @@ class Load(Generative, LoaderOption):
         ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key),
         ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
         ("_of_type", visitors.ExtendedInternalTraversal.dp_multi),
+        ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
         (
             "_context_cache_key",
             visitors.ExtendedInternalTraversal.dp_has_cache_key_tuples,
@@ -101,6 +102,7 @@ class Load(Generative, LoaderOption):
         load.context = {}
         load.local_opts = {}
         load._of_type = None
+        load._extra_criteria = ()
         return load
 
     @property
@@ -124,6 +126,7 @@ class Load(Generative, LoaderOption):
     strategy = None
     propagate_to_loaders = False
     _of_type = None
+    _extra_criteria = ()
 
     def process_compile_state(self, compile_state):
         if not compile_state.compile_options._enable_eagerloads:
@@ -248,6 +251,9 @@ class Load(Generative, LoaderOption):
                 else:
                     return None
 
+            if attr._extra_criteria:
+                self._extra_criteria = attr._extra_criteria
+
             if getattr(attr, "_of_type", None):
                 ac = attr._of_type
                 ext_info = of_type_info = inspect(ac)
@@ -356,6 +362,7 @@ class Load(Generative, LoaderOption):
         cloned = self._clone_for_bind_strategy(attr, strategy, "relationship")
         self.path = cloned.path
         self._of_type = cloned._of_type
+        self._extra_criteria = cloned._extra_criteria
         cloned.is_class_strategy = self.is_class_strategy = False
         self.propagate_to_loaders = cloned.propagate_to_loaders
 
@@ -413,6 +420,7 @@ class Load(Generative, LoaderOption):
             if existing:
                 if merge_opts:
                     existing.local_opts.update(self.local_opts)
+                    existing._extra_criteria += self._extra_criteria
             else:
                 path.set(context, "loader", self)
         else:
@@ -420,6 +428,7 @@ class Load(Generative, LoaderOption):
             path.set(context, "loader", self)
             if existing and existing.is_opts_only:
                 self.local_opts.update(existing.local_opts)
+                existing._extra_criteria += self._extra_criteria
 
     def _set_path_strategy(self):
         if not self.is_class_strategy and self.path.has_entity:
@@ -507,11 +516,13 @@ class _UnboundLoad(Load):
         self.path = ()
         self._to_bind = []
         self.local_opts = {}
+        self._extra_criteria = ()
 
     _cache_key_traversal = [
         ("path", visitors.ExtendedInternalTraversal.dp_multi_list),
         ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj),
         ("_to_bind", visitors.ExtendedInternalTraversal.dp_has_cache_key_list),
+        ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list),
         ("local_opts", visitors.ExtendedInternalTraversal.dp_plain_dict),
     ]
 
@@ -576,6 +587,7 @@ class _UnboundLoad(Load):
         if attr:
             path = path + (attr,)
         self.path = path
+        self._extra_criteria = getattr(attr, "_extra_criteria", ())
 
         return path
 
index 71ee295974797033dc5a2e0c0f511fd58575dcb3..82fad0815dbcac5c2dbd0ee8727429e51d7aaede 100644 (file)
@@ -23,6 +23,7 @@ from .base import object_state  # noqa
 from .base import state_attribute_str  # noqa
 from .base import state_class_str  # noqa
 from .base import state_str  # noqa
+from .interfaces import LoaderOption
 from .interfaces import MapperProperty  # noqa
 from .interfaces import ORMColumnsClauseRole
 from .interfaces import ORMEntityColumnsClauseRole
@@ -38,6 +39,7 @@ from ..engine.result import result_tuple
 from ..sql import base as sql_base
 from ..sql import coercions
 from ..sql import expression
+from ..sql import lambdas
 from ..sql import roles
 from ..sql import util as sql_util
 from ..sql import visitors
@@ -854,6 +856,184 @@ class AliasedInsp(
             return "aliased(%s)" % (self._target.__name__,)
 
 
+class LoaderCriteriaOption(LoaderOption):
+    """Add additional WHERE criteria to the load for all occurrences of
+    a particular entity.
+
+    :class:`_orm.LoaderCriteriaOption` is invoked using the
+    :func:`_orm.with_loader_criteria` function; see that function for
+    details.
+
+    .. versionadded:: 1.4
+
+    """
+
+    _traverse_internals = [
+        ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj),
+        ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key),
+        ("where_criteria", visitors.InternalTraversal.dp_clauseelement),
+        ("include_aliases", visitors.InternalTraversal.dp_boolean),
+        ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean),
+    ]
+
+    def __init__(
+        self,
+        entity_or_base,
+        where_criteria,
+        loader_only=False,
+        include_aliases=False,
+        propagate_to_loaders=True,
+    ):
+        """Add additional WHERE criteria to the load for all occurrences of
+        a particular entity.
+
+        .. versionadded:: 1.4
+
+        The :func:`_orm.with_loader_criteria` option is intended to add
+        limiting criteria to a particular kind of entity in a query,
+        **globally**, meaning it will apply to the entity as it appears
+        in the SELECT query as well as within any subqueries, join
+        conditions, and relationship loads, including both eager and lazy
+        loaders, without the need for it to be specified in any particular
+        part of the query.    The rendering logic uses the same system used by
+        single table inheritance to ensure a certain discriminator is applied
+        to a table.
+
+        E.g., using :term:`2.0-style` queries, we can limit the way the
+        ``User.addresses`` collection is loaded, regardless of the kind
+        of loading used::
+
+            from sqlalchemy.orm import with_loader_criteria
+
+            stmt = select(User).options(
+                selectinload(User.addresses),
+                with_loader_criteria(Address, Address.email_address != 'foo'))
+            )
+
+        Above, the "selectinload" for ``User.addresses`` will apply the
+        given filtering criteria to the WHERE clause.
+
+        Another example, where the filtering will be applied to the
+        ON clause of the join, in this example using :term:`1.x style`
+        queries::
+
+            q = session.query(User).outerjoin(User.addresses).options(
+                with_loader_criteria(Address, Address.email_address != 'foo'))
+            )
+
+        The primary purpose of :func:`_orm.with_loader_criteria` is to use
+        it in the :meth:`_orm.SessionEvents.do_orm_execute` event handler
+        to ensure that all occurrences of a particular entity are filtered
+        in a certain way, such as filtering for access control roles.    It
+        also can be used to apply criteria to relationship loads.  In the
+        example below, we can apply a certain set of rules to all queries
+        emitted by a particular :class:`_orm.Session`::
+
+            session = Session(bind=engine)
+
+            @event.listens_for("do_orm_execute", session)
+            def _add_filtering_criteria(execute_state):
+                execute_state.statement = execute_state.statement.options(
+                    with_loader_criteria(
+                        SecurityRole,
+                        lambda cls: cls.role.in_(['some_role']),
+                        include_aliases=True
+                    )
+                )
+
+        The given class will expand to include all mapped subclass and
+        need not itself be a mapped class.
+
+
+        :param entity_or_base: a mapped class, or a class that is a super
+         class of a particular set of mapped classes, to which the rule
+         will apply.
+
+        :param where_criteria: a Core SQL expression that applies limiting
+         criteria.   This may also be a "lambda:" or Python function that
+         accepts a target class as an argument, when the given class is
+         a base with many different mapped subclasses.
+
+        :param include_aliases: if True, apply the rule to :func:`_orm.aliased`
+         constructs as well.
+
+        :param propagate_to_loaders: defaults to True, apply to relationship
+         loaders such as lazy loaders.
+
+
+        .. seealso::
+
+            :ref:`examples_session_orm_events` - includes examples of using
+            :func:`_orm.with_loader_criteria`.
+
+            :ref:`do_orm_execute_global_criteria` - basic example on how to
+            combine :func:`_orm.with_loader_criteria` with the
+            :meth:`_orm.SessionEvents.do_orm_execute` event.
+
+        """
+        entity = inspection.inspect(entity_or_base, False)
+        if entity is None:
+            self.root_entity = entity_or_base
+            self.entity = None
+        else:
+            self.root_entity = None
+            self.entity = entity
+
+        if callable(where_criteria):
+            self.deferred_where_criteria = True
+            self.where_criteria = lambdas.DeferredLambdaElement(
+                where_criteria,
+                roles.WhereHavingRole,
+                lambda_args=(
+                    self.root_entity
+                    if self.root_entity is not None
+                    else self.entity.entity,
+                ),
+            )
+        else:
+            self.deferred_where_criteria = False
+            self.where_criteria = coercions.expect(
+                roles.WhereHavingRole, where_criteria
+            )
+
+        self.include_aliases = include_aliases
+        self.propagate_to_loaders = propagate_to_loaders
+
+    def _all_mappers(self):
+        if self.entity:
+            for ent in self.entity.mapper.self_and_descendants:
+                yield ent
+        else:
+            stack = list(self.root_entity.__subclasses__())
+            while stack:
+                subclass = stack.pop(0)
+                ent = inspection.inspect(subclass)
+                if ent:
+                    for mp in ent.mapper.self_and_descendants:
+                        yield mp
+                else:
+                    stack.extend(subclass.__subclasses__())
+
+    def _resolve_where_criteria(self, ext_info):
+        if self.deferred_where_criteria:
+            return self.where_criteria._resolve_with_args(ext_info.entity)
+        else:
+            return self.where_criteria
+
+    def process_compile_state(self, compile_state):
+        """Apply a modification to a given :class:`.CompileState`."""
+
+        # if options to limit the criteria to immediate query only,
+        # use compile_state.attributes instead
+
+        for mp in self._all_mappers():
+            load_criteria = compile_state.global_attributes.setdefault(
+                ("additional_entity_criteria", mp), []
+            )
+
+            load_criteria.append(self)
+
+
 inspection._inspects(AliasedClass)(lambda target: target._aliased_insp)
 inspection._inspects(AliasedInsp)(lambda target: target)
 
@@ -1270,6 +1450,7 @@ class _ORMJoin(expression.Join):
         full=False,
         _left_memo=None,
         _right_memo=None,
+        _extra_criteria=(),
     ):
         left_info = inspection.inspect(left)
 
@@ -1291,6 +1472,7 @@ class _ORMJoin(expression.Join):
         if isinstance(onclause, attributes.QueryableAttribute):
             on_selectable = onclause.comparator._source_selectable()
             prop = onclause.property
+            _extra_criteria += onclause._extra_criteria
         elif isinstance(onclause, MapperProperty):
             # used internally by joined eager loader...possibly not ideal
             prop = onclause
@@ -1319,6 +1501,7 @@ class _ORMJoin(expression.Join):
                 source_polymorphic=True,
                 of_type_entity=right_info,
                 alias_secondary=True,
+                extra_criteria=_extra_criteria,
             )
 
             if sj is not None:
@@ -1331,6 +1514,7 @@ class _ORMJoin(expression.Join):
                     onclause = sj
             else:
                 onclause = pj
+
             self._target_adapter = target_adapter
 
         expression.Join.__init__(self, left, right, onclause, isouter, full)
index ac4055bdf17ecc8c082d47fb486ae6fbacb670a0..b8984316c65bb92e34ee220988ec34e3982fe336 100644 (file)
@@ -792,6 +792,10 @@ class SQLCompiler(Compiled):
     def prefetch(self):
         return list(self.insert_prefetch + self.update_prefetch)
 
+    @util.memoized_property
+    def _global_attributes(self):
+        return {}
+
     @util.memoized_instancemethod
     def _init_cte_state(self):
         """Initialize collections related to CTEs only if
index 6279dcf55e1e6684a4e4c214828b75adf480f30a..c8e83bbd74ee77380cb54ae3c2b748b0f1bce102 100644 (file)
@@ -1017,8 +1017,8 @@ class CustomIntegrationTest(testing.AssertsCompiledSQL, BakedTest):
                 if ckey:
                     break
             else:
-                if "_cache_key" in orm_context.merged_execution_options:
-                    ckey = orm_context.merged_execution_options["_cache_key"]
+                if "_cache_key" in orm_context.execution_options:
+                    ckey = orm_context.execution_options["_cache_key"]
 
             if ckey is not None:
                 return get_value(
index e33e95cc0339c8552c9ed41fb8e905599902848e..86e0bd360bbff87a33c27cba197c2869a4fd1bba 100644 (file)
@@ -1302,6 +1302,28 @@ class _PolymorphicTestBase(object):
             [e1, e3],
         )
 
+    def test_join_and_thru_polymorphic_nonaliased_one(self):
+        sess = create_session()
+        eq_(
+            sess.query(Company)
+            .join(Company.employees)
+            .join(Person.paperwork.and_(Paperwork.description.like("%#2%")))
+            .all(),
+            [c1],
+        )
+
+    def test_join_and_thru_polymorphic_aliased_one(self):
+        sess = create_session()
+        ea = aliased(Person)
+        pa = aliased(Paperwork)
+        eq_(
+            sess.query(Company)
+            .join(ea, Company.employees)
+            .join(pa, ea.paperwork.and_(pa.description.like("%#2%")))
+            .all(),
+            [c1],
+        )
+
     def test_join_through_polymorphic_nonaliased_one(self):
         sess = create_session()
         eq_(
index f4af840946441647faa01480cf4ff62130cba109..9d1d0b61b7d48a957ca1b2ef6e41647c63c8b950 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import select
 from sqlalchemy import String
+from sqlalchemy import testing
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import Bundle
 from sqlalchemy.orm import mapper
@@ -186,6 +187,35 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL):
             ],
         )
 
+    def test_multi_bundle_future(self):
+        Data = self.classes.Data
+        Other = self.classes.Other
+
+        d1 = aliased(Data)
+
+        b1 = Bundle("b1", d1.d1, d1.d2)
+        b2 = Bundle("b2", Data.d1, Other.o1)
+
+        sess = Session(testing.db, future=True)
+
+        stmt = (
+            select(b1, b2)
+            .join(Data.others)
+            .join(d1, d1.id == Data.id)
+            .filter(b1.c.d1 == "d3d1")
+        )
+
+        eq_(
+            sess.execute(stmt).all(),
+            [
+                (("d3d1", "d3d2"), ("d3d1", "d3o0")),
+                (("d3d1", "d3d2"), ("d3d1", "d3o1")),
+                (("d3d1", "d3d2"), ("d3d1", "d3o2")),
+                (("d3d1", "d3d2"), ("d3d1", "d3o3")),
+                (("d3d1", "d3d2"), ("d3d1", "d3o4")),
+            ],
+        )
+
     def test_single_entity(self):
         Data = self.classes.Data
         sess = Session()
@@ -197,6 +227,18 @@ class BundleTest(fixtures.MappedTest, AssertsCompiledSQL):
             [("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")],
         )
 
+    def test_single_entity_future(self):
+        Data = self.classes.Data
+        sess = Session(testing.db, future=True)
+
+        b1 = Bundle("b1", Data.d1, Data.d2, single_entity=True)
+
+        stmt = select(b1).filter(b1.c.d1.between("d3d1", "d5d1"))
+        eq_(
+            sess.execute(stmt).scalars().all(),
+            [("d3d1", "d3d2"), ("d4d1", "d4d2"), ("d5d1", "d5d2")],
+        )
+
     def test_single_entity_flag_but_multi_entities(self):
         Data = self.classes.Data
         sess = Session()
index 02b1b9fbf8fb2fc47928c2ba583df7cf16be61f3..45a60a5cb92ff6848b69ab66bd475c4e6609e269 100644 (file)
@@ -15,6 +15,7 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import subqueryload
+from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.sql.base import CacheableOptions
 from sqlalchemy.sql.visitors import InternalTraversal
@@ -65,6 +66,62 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
             compare_values=True,
         )
 
+    def test_loader_criteria(self):
+        User, Address = self.classes("User", "Address")
+
+        from sqlalchemy import Column, Integer, String
+
+        class Foo(object):
+            id = Column(Integer)
+            name = Column(String)
+
+        self._run_cache_key_fixture(
+            lambda: (
+                with_loader_criteria(User, User.name != "somename"),
+                with_loader_criteria(User, User.id != 5),
+                with_loader_criteria(User, lambda cls: cls.id == 10),
+                with_loader_criteria(Address, Address.id != 5),
+                with_loader_criteria(Foo, lambda cls: cls.id == 10),
+            ),
+            compare_values=True,
+        )
+
+    def test_loader_criteria_bound_param_thing(self):
+        from sqlalchemy import Column, Integer
+
+        class Foo(object):
+            id = Column(Integer)
+
+        def go(param):
+            return with_loader_criteria(Foo, lambda cls: cls.id == param)
+
+        g1 = go(10)
+        g2 = go(20)
+
+        ck1 = g1._generate_cache_key()
+        ck2 = g2._generate_cache_key()
+
+        eq_(ck1.key, ck2.key)
+        eq_(ck1.bindparams[0].key, ck2.bindparams[0].key)
+        eq_(ck1.bindparams[0].value, 10)
+        eq_(ck2.bindparams[0].value, 20)
+
+    def test_instrumented_attributes(self):
+        User, Address, Keyword, Order, Item = self.classes(
+            "User", "Address", "Keyword", "Order", "Item"
+        )
+
+        self._run_cache_key_fixture(
+            lambda: (
+                User.addresses,
+                User.addresses.of_type(aliased(Address)),
+                User.orders,
+                User.orders.and_(Order.id != 5),
+                User.orders.and_(Order.description != "somename"),
+            ),
+            compare_values=True,
+        )
+
     def test_unbound_options(self):
         User, Address, Keyword, Order, Item = self.classes(
             "User", "Address", "Keyword", "Order", "Item"
@@ -75,6 +132,10 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
                 joinedload(User.addresses),
                 joinedload(User.addresses.of_type(aliased(Address))),
                 joinedload("addresses"),
+                joinedload(User.orders),
+                joinedload(User.orders.and_(Order.id != 5)),
+                joinedload(User.orders.and_(Order.id == 5)),
+                joinedload(User.orders.and_(Order.description != "somename")),
                 joinedload(User.orders).selectinload("items"),
                 joinedload(User.orders).selectinload(Order.items),
                 defer(User.id),
@@ -110,6 +171,10 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
                     User.addresses.of_type(aliased(Address))
                 ),
                 Load(User).joinedload(User.orders),
+                Load(User).joinedload(User.orders.and_(Order.id != 5)),
+                Load(User).joinedload(
+                    User.orders.and_(Order.description != "somename")
+                ),
                 Load(User).defer(User.id),
                 Load(User).subqueryload("addresses"),
                 Load(Address).defer("id"),
@@ -169,6 +234,9 @@ class CacheKeyTest(CacheKeyFixture, _fixtures.FixtureTest):
                 select(User).join(Address, User.addresses),
                 select(User).join(a1, User.addresses),
                 select(User).join(User.addresses.of_type(a1)),
+                select(User).join(
+                    User.addresses.and_(Address.email_address == "foo")
+                ),
                 select(User)
                 .join(Address, User.addresses)
                 .join_from(User, Order),
index b68e0d2e652232f638e42bdadcf17e4f96e10b0e..df48cfe63c5552ddeaec8cdddd0e4dafb668c75b 100644 (file)
@@ -2,6 +2,8 @@ import sqlalchemy as sa
 from sqlalchemy import event
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
+from sqlalchemy import literal_column
+from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.ext.declarative import declarative_base
@@ -47,6 +49,170 @@ class _RemoveListeners(object):
         super(_RemoveListeners, self).teardown()
 
 
+class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
+    run_setup_mappers = "once"
+    run_inserts = "once"
+    run_deletes = None
+
+    @classmethod
+    def setup_mappers(cls):
+        cls._setup_stock_mapping()
+
+    def _caching_session_fixture(self):
+
+        cache = {}
+
+        maker = sessionmaker(testing.db, future=True)
+
+        def get_value(cache_key, cache, createfunc):
+            if cache_key in cache:
+                return cache[cache_key]()
+            else:
+                cache[cache_key] = retval = createfunc().freeze()
+                return retval()
+
+        @event.listens_for(maker, "do_orm_execute", retval=True)
+        def do_orm_execute(orm_context):
+            ckey = None
+            for opt in orm_context.user_defined_options:
+                ckey = opt.get_cache_key(orm_context)
+                if ckey:
+                    break
+            else:
+                if "cache_key" in orm_context.execution_options:
+                    ckey = orm_context.execution_options["cache_key"]
+
+            if ckey is not None:
+                return get_value(ckey, cache, orm_context.invoke_statement,)
+
+        return maker()
+
+    def test_cache_option(self):
+        User, Address = self.classes("User", "Address")
+
+        with self.sql_execution_asserter(testing.db) as asserter:
+
+            with self._caching_session_fixture() as session:
+                stmt = (
+                    select(User)
+                    .where(User.id == 7)
+                    .execution_options(cache_key="user7")
+                )
+
+                result = session.execute(stmt)
+
+                eq_(
+                    result.scalars().all(),
+                    [User(id=7, addresses=[Address(id=1)])],
+                )
+
+                result = session.execute(stmt)
+
+                eq_(
+                    result.scalars().all(),
+                    [User(id=7, addresses=[Address(id=1)])],
+                )
+
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT users.id, users.name FROM users "
+                "WHERE users.id = :id_1",
+                [{"id_1": 7}],
+            ),
+            CompiledSQL(
+                "SELECT addresses.id AS addresses_id, addresses.user_id AS "
+                "addresses_user_id, "
+                "addresses.email_address AS addresses_email_address "
+                "FROM addresses WHERE :param_1 = addresses.user_id "
+                "ORDER BY addresses.id",
+                [{"param_1": 7}],
+            ),
+        )
+
+    def test_chained_events_one(self):
+
+        sess = Session(testing.db, future=True)
+
+        @event.listens_for(sess, "do_orm_execute")
+        def one(ctx):
+            ctx.update_execution_options(one=True)
+
+        @event.listens_for(sess, "do_orm_execute")
+        def two(ctx):
+            ctx.update_execution_options(two=True)
+
+        @event.listens_for(sess, "do_orm_execute")
+        def three(ctx):
+            ctx.update_execution_options(three=True)
+
+        @event.listens_for(sess, "do_orm_execute")
+        def four(ctx):
+            ctx.update_execution_options(four=True)
+
+        result = sess.execute(select(literal_column("1")))
+
+        eq_(
+            result.context.execution_options,
+            {
+                "four": True,
+                "future_result": True,
+                "one": True,
+                "three": True,
+                "two": True,
+            },
+        )
+
+    def test_chained_events_two(self):
+
+        sess = Session(testing.db, future=True)
+
+        def added(ctx):
+            ctx.update_execution_options(added_evt=True)
+
+        @event.listens_for(sess, "do_orm_execute")
+        def one(ctx):
+            ctx.update_execution_options(one=True)
+
+        @event.listens_for(sess, "do_orm_execute", retval=True)
+        def two(ctx):
+            ctx.update_execution_options(two=True)
+            return ctx.invoke_statement(
+                statement=ctx.statement.execution_options(statement_two=True)
+            )
+
+        @event.listens_for(sess, "do_orm_execute")
+        def three(ctx):
+            ctx.update_execution_options(three=True)
+
+        @event.listens_for(sess, "do_orm_execute")
+        def four(ctx):
+            ctx.update_execution_options(four=True)
+            return ctx.invoke_statement(
+                statement=ctx.statement.execution_options(statement_four=True)
+            )
+
+        @event.listens_for(sess, "do_orm_execute")
+        def five(ctx):
+            ctx.update_execution_options(five=True)
+
+        result = sess.execute(select(literal_column("1")), _add_event=added)
+
+        eq_(
+            result.context.execution_options,
+            {
+                "statement_two": True,
+                "statement_four": True,
+                "future_result": True,
+                "one": True,
+                "two": True,
+                "three": True,
+                "four": True,
+                "five": True,
+                "added_evt": True,
+            },
+        )
+
+
 class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest):
     run_inserts = None
 
index 208db9d85dba9e10fee36e3dfbd1543dab7540d7..b5a6e3b29103c7c1cf2c288fa97e861eee8e2e61 100644 (file)
@@ -1391,6 +1391,7 @@ class PickleTest(PathTest, QueryTest):
                 "propagate_to_loaders": True,
                 "_of_type": None,
                 "_to_bind": to_bind,
+                "_extra_criteria": (),
             },
         )
 
diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py
new file mode 100644 (file)
index 0000000..c4bcf04
--- /dev/null
@@ -0,0 +1,867 @@
+import datetime
+import random
+
+from sqlalchemy import Column
+from sqlalchemy import DateTime
+from sqlalchemy import event
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import orm
+from sqlalchemy import select
+from sqlalchemy import sql
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import mapper
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import selectinload
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import with_loader_criteria
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing.assertsql import CompiledSQL
+from test.orm import _fixtures
+
+
+class _Fixtures(_fixtures.FixtureTest):
+    @testing.fixture
+    def user_address_fixture(self):
+        users, Address, addresses, User = (
+            self.tables.users,
+            self.classes.Address,
+            self.tables.addresses,
+            self.classes.User,
+        )
+
+        mapper(
+            User,
+            users,
+            properties={
+                "addresses": relationship(
+                    mapper(Address, addresses), order_by=Address.id
+                )
+            },
+        )
+        return User, Address
+
+    @testing.fixture
+    def order_item_fixture(self):
+        Order, Item = self.classes("Order", "Item")
+        orders, items, order_items = self.tables(
+            "orders", "items", "order_items"
+        )
+
+        mapper(
+            Order,
+            orders,
+            properties={
+                # m2m
+                "items": relationship(
+                    Item, secondary=order_items, order_by=items.c.id
+                ),
+            },
+        )
+        mapper(Item, items)
+
+        return Order, Item
+
+    @testing.fixture
+    def mixin_fixture(self):
+        users = self.tables.users
+
+        class HasFoob(object):
+            name = Column(String)
+
+        class UserWFoob(HasFoob, self.Comparable):
+            pass
+
+        mapper(
+            UserWFoob, users,
+        )
+        return HasFoob, UserWFoob
+
+
+class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
+    """
+    combinations:
+
+
+        with_loader_criteria
+            # for these we have mapper_criteria
+
+            select(mapper)  # select_mapper
+            select(mapper.col, mapper.col)  # select_mapper_col
+            select(func.count()).select_from(mapper)  # select_from_mapper
+            select(a).join(mapper, a.target)  # select_join_mapper
+            select(a).options(joinedload(a.target))  # select_joinedload_mapper
+
+
+            # for these we have aliased_criteria, inclaliased_criteria
+
+            select(aliased)  # select_aliased
+            select(aliased.col, aliased.col)  # select_aliased_col
+            select(func.count()).select_from(aliased) # select_from_aliased
+            select(a).join(aliased, a.target)  # select_join_aliased
+            select(a).options(joinedload(a.target.of_type(aliased))
+            # select_joinedload_aliased
+
+    """
+
+    __dialect__ = "default"
+
+    def test_select_mapper_mapper_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        stmt = select(User).options(
+            with_loader_criteria(User, User.name != "name")
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT users.id, users.name "
+            "FROM users WHERE users.name != :name_1",
+        )
+
+    def test_select_from_mapper_mapper_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        stmt = (
+            select(sql.func.count())
+            .select_from(User)
+            .options(with_loader_criteria(User, User.name != "name"))
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT count(*) AS count_1 FROM users "
+            "WHERE users.name != :name_1",
+        )
+
+    def test_select_mapper_columns_mapper_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        stmt = select(User.id, User.name).options(
+            with_loader_criteria(User, User.name != "name")
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT users.id, users.name "
+            "FROM users WHERE users.name != :name_1",
+        )
+
+    def test_select_join_mapper_mapper_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        stmt = (
+            select(User)
+            .join(User.addresses)
+            .options(
+                with_loader_criteria(Address, Address.email_address != "name")
+            )
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT users.id, users.name FROM users "
+            "JOIN addresses ON users.id = addresses.user_id "
+            "AND addresses.email_address != :email_address_1",
+        )
+
+    def test_select_joinm2m_mapper_mapper_criteria(self, order_item_fixture):
+        Order, Item = order_item_fixture
+
+        stmt = (
+            select(Order)
+            .join(Order.items)
+            .options(
+                with_loader_criteria(Item, Item.description != "description")
+            )
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT orders.id, orders.user_id, orders.address_id, "
+            "orders.description, orders.isopen FROM orders "
+            "JOIN order_items AS order_items_1 "
+            "ON orders.id = order_items_1.order_id "
+            "JOIN items ON items.id = order_items_1.item_id "
+            "AND items.description != :description_1",
+        )
+
+    def test_select_joinedload_mapper_mapper_criteria(
+        self, user_address_fixture
+    ):
+        User, Address = user_address_fixture
+
+        stmt = select(User).options(
+            joinedload(User.addresses),
+            with_loader_criteria(Address, Address.email_address != "name"),
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT users.id, users.name, addresses_1.id AS id_1, "
+            "addresses_1.user_id, addresses_1.email_address "
+            "FROM users LEFT OUTER JOIN addresses AS addresses_1 "
+            "ON users.id = addresses_1.user_id "
+            "AND addresses_1.email_address != :email_address_1 "
+            "ORDER BY addresses_1.id",
+        )
+
+    def test_select_selectinload_mapper_mapper_criteria(
+        self, user_address_fixture
+    ):
+        User, Address = user_address_fixture
+
+        stmt = select(User).options(
+            selectinload(User.addresses),
+            with_loader_criteria(Address, Address.email_address != "name"),
+        )
+
+        s = Session(testing.db, future=True)
+
+        with self.sql_execution_asserter() as asserter:
+
+            s.execute(stmt).all()
+
+        asserter.assert_(
+            CompiledSQL("SELECT users.id, users.name FROM users", [],),
+            CompiledSQL(
+                "SELECT addresses.user_id AS addresses_user_id, addresses.id "
+                "AS addresses_id, addresses.email_address "
+                "AS addresses_email_address FROM addresses "
+                "WHERE addresses.user_id IN ([POSTCOMPILE_primary_keys]) "
+                "AND addresses.email_address != :email_address_1 "
+                "ORDER BY addresses.id",
+                [{"primary_keys": [7, 8, 9, 10], "email_address_1": "name"}],
+            ),
+        )
+
+    def test_select_lazyload_mapper_mapper_criteria(
+        self, user_address_fixture
+    ):
+        User, Address = user_address_fixture
+
+        stmt = (
+            select(User)
+            .options(
+                with_loader_criteria(Address, Address.email_address != "name"),
+            )
+            .order_by(User.id)
+        )
+
+        s = Session(testing.db, future=True)
+
+        with self.sql_execution_asserter() as asserter:
+            for u in s.execute(stmt).scalars():
+                u.addresses
+
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT users.id, users.name FROM users ORDER BY users.id", [],
+            ),
+            CompiledSQL(
+                "SELECT addresses.id AS addresses_id, "
+                "addresses.user_id AS addresses_user_id, "
+                "addresses.email_address AS addresses_email_address "
+                "FROM addresses WHERE :param_1 = addresses.user_id "
+                "AND addresses.email_address != :email_address_1 "
+                "ORDER BY addresses.id",
+                [{"param_1": 7, "email_address_1": "name"}],
+            ),
+            CompiledSQL(
+                "SELECT addresses.id AS addresses_id, "
+                "addresses.user_id AS addresses_user_id, "
+                "addresses.email_address AS addresses_email_address "
+                "FROM addresses WHERE :param_1 = addresses.user_id "
+                "AND addresses.email_address != :email_address_1 "
+                "ORDER BY addresses.id",
+                [{"param_1": 8, "email_address_1": "name"}],
+            ),
+            CompiledSQL(
+                "SELECT addresses.id AS addresses_id, "
+                "addresses.user_id AS addresses_user_id, "
+                "addresses.email_address AS addresses_email_address "
+                "FROM addresses WHERE :param_1 = addresses.user_id "
+                "AND addresses.email_address != :email_address_1 "
+                "ORDER BY addresses.id",
+                [{"param_1": 9, "email_address_1": "name"}],
+            ),
+            CompiledSQL(
+                "SELECT addresses.id AS addresses_id, "
+                "addresses.user_id AS addresses_user_id, "
+                "addresses.email_address AS addresses_email_address "
+                "FROM addresses WHERE :param_1 = addresses.user_id "
+                "AND addresses.email_address != :email_address_1 "
+                "ORDER BY addresses.id",
+                [{"param_1": 10, "email_address_1": "name"}],
+            ),
+        )
+
+    def test_select_aliased_inclaliased_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        u1 = aliased(User)
+        stmt = select(u1).options(
+            with_loader_criteria(
+                User, User.name != "name", include_aliases=True
+            )
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT users_1.id, users_1.name "
+            "FROM users AS users_1 WHERE users_1.name != :name_1",
+        )
+
+    def test_select_from_aliased_inclaliased_criteria(
+        self, user_address_fixture
+    ):
+        User, Address = user_address_fixture
+
+        u1 = aliased(User)
+        stmt = (
+            select(sql.func.count())
+            .select_from(u1)
+            .options(
+                with_loader_criteria(
+                    User, User.name != "name", include_aliases=True
+                )
+            )
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT count(*) AS count_1 FROM users AS users_1 "
+            "WHERE users_1.name != :name_1",
+        )
+
+    def test_select_aliased_columns_inclaliased_criteria(
+        self, user_address_fixture
+    ):
+        User, Address = user_address_fixture
+
+        u1 = aliased(User)
+        stmt = select(u1.id, u1.name).options(
+            with_loader_criteria(
+                User, User.name != "name", include_aliases=True
+            )
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT users_1.id, users_1.name "
+            "FROM users AS users_1 WHERE users_1.name != :name_1",
+        )
+
+    def test_select_join_aliased_inclaliased_criteria(
+        self, user_address_fixture
+    ):
+        User, Address = user_address_fixture
+
+        a1 = aliased(Address)
+        stmt = (
+            select(User)
+            .join(User.addresses.of_type(a1))
+            .options(
+                with_loader_criteria(
+                    Address,
+                    Address.email_address != "name",
+                    include_aliases=True,
+                )
+            )
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT users.id, users.name FROM users "
+            "JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id "
+            "AND addresses_1.email_address != :email_address_1",
+        )
+
+    def test_select_joinm2m_aliased_inclaliased_criteria(
+        self, order_item_fixture
+    ):
+        Order, Item = order_item_fixture
+
+        i1 = aliased(Item)
+
+        stmt = (
+            select(Order)
+            .join(Order.items.of_type(i1))
+            .options(
+                with_loader_criteria(
+                    Item,
+                    Item.description != "description",
+                    include_aliases=True,
+                )
+            )
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT orders.id, orders.user_id, orders.address_id, "
+            "orders.description, orders.isopen FROM orders "
+            "JOIN order_items AS order_items_1 "
+            "ON orders.id = order_items_1.order_id "
+            "JOIN items AS items_1 ON items_1.id = order_items_1.item_id "
+            "AND items_1.description != :description_1",
+        )
+
+    def test_select_aliased_aliased_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        u1 = aliased(User)
+        stmt = select(u1).options(with_loader_criteria(u1, u1.name != "name"))
+
+        self.assert_compile(
+            stmt,
+            "SELECT users_1.id, users_1.name "
+            "FROM users AS users_1 WHERE users_1.name != :name_1",
+        )
+
+    def test_select_aliased_columns_aliased_criteria(
+        self, user_address_fixture
+    ):
+        User, Address = user_address_fixture
+
+        u1 = aliased(User)
+        stmt = select(u1.id, u1.name).options(
+            with_loader_criteria(u1, u1.name != "name")
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT users_1.id, users_1.name "
+            "FROM users AS users_1 WHERE users_1.name != :name_1",
+        )
+
+    def test_joinedload_global_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        s = Session(testing.db, future=True)
+
+        stmt = select(User).options(
+            joinedload(User.addresses),
+            with_loader_criteria(Address, Address.email_address != "email"),
+        )
+
+        with self.sql_execution_asserter() as asserter:
+
+            s.execute(stmt)
+
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT users.id, users.name, addresses_1.id AS id_1, "
+                "addresses_1.user_id, addresses_1.email_address FROM "
+                "users LEFT OUTER JOIN addresses AS addresses_1 "
+                "ON users.id = addresses_1.user_id "
+                "AND addresses_1.email_address != :email_address_1 "
+                "ORDER BY addresses_1.id",
+                [{"email_address_1": "email"}],
+            ),
+        )
+
+    def test_query_count_global_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        s = Session(testing.db)
+
+        q = s.query(User).options(with_loader_criteria(User, User.id != 8))
+
+        with self.sql_execution_asserter() as asserter:
+            q.count()
+
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT count(*) AS count_1 FROM (SELECT "
+                "users.id AS users_id, users.name AS users_name "
+                "FROM users WHERE users.id != :id_1) AS anon_1",
+                [{"id_1": 8}],
+            ),
+        )
+
+    def test_query_count_after_the_fact_global_criteria(
+        self, user_address_fixture
+    ):
+        User, Address = user_address_fixture
+
+        s = Session(testing.db)
+
+        # this essentially tests that the query.from_self() which takes
+        # place in count() is one that can still be affected by
+        # the loader criteria, meaning it has to be an ORM query
+
+        q = s.query(User)
+
+        @event.listens_for(s, "do_orm_execute")
+        def add_criteria(orm_context):
+            orm_context.statement = orm_context.statement.options(
+                with_loader_criteria(User, User.id != 8)
+            )
+
+        with self.sql_execution_asserter() as asserter:
+            q.count()
+
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT count(*) AS count_1 FROM (SELECT "
+                "users.id AS users_id, users.name AS users_name "
+                "FROM users WHERE users.id != :id_1) AS anon_1",
+                [{"id_1": 8}],
+            ),
+        )
+
+    def test_select_count_subquery_global_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        stmt = select(User).subquery()
+
+        stmt = (
+            select(sql.func.count())
+            .select_from(stmt)
+            .options(with_loader_criteria(User, User.id != 8))
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT count(*) AS count_1 FROM (SELECT users.id AS id, "
+            "users.name AS name FROM users WHERE users.id != :id_1) AS anon_1",
+        )
+
+    def test_query_outerjoin_global_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        s = Session(testing.db)
+
+        q = (
+            s.query(User, Address)
+            .outerjoin(User.addresses)
+            .options(
+                with_loader_criteria(
+                    Address, ~Address.email_address.like("ed@%"),
+                )
+            )
+            .order_by(User.id)
+        )
+
+        self.assert_compile(
+            q,
+            "SELECT users.id AS users_id, users.name AS users_name, "
+            "addresses.id AS addresses_id, "
+            "addresses.user_id AS addresses_user_id, "
+            "addresses.email_address AS addresses_email_address "
+            "FROM users LEFT OUTER JOIN addresses "
+            "ON users.id = addresses.user_id AND "
+            "addresses.email_address NOT LIKE :email_address_1 "
+            "ORDER BY users.id",
+        )
+        eq_(
+            q.all(),
+            [
+                (User(id=7), Address(id=1)),
+                (User(id=8), None),  # three addresses not here
+                (User(id=9), Address(id=5)),
+                (User(id=10), None),
+            ],
+        )
+
+    def test_caching_and_binds_lambda(self, mixin_fixture):
+        HasFoob, UserWFoob = mixin_fixture
+
+        statement = select(UserWFoob).filter(UserWFoob.id < 10)
+
+        def go(value):
+            return statement.options(
+                with_loader_criteria(
+                    HasFoob,
+                    lambda cls: cls.name == value,
+                    include_aliases=True,
+                )
+            )
+
+        s = Session(testing.db, future=True)
+
+        for i in range(10):
+            name = random.choice(["ed", "fred", "jack"])
+            stmt = go(name)
+
+            eq_(s.execute(stmt).scalars().all(), [UserWFoob(name=name)])
+
+
+class TemporalFixtureTest(testing.fixtures.DeclarativeMappedTest):
+    @classmethod
+    def setup_classes(cls):
+        class HasTemporal(object):
+            """Mixin that identifies a class as having a timestamp column"""
+
+            timestamp = Column(
+                DateTime, default=datetime.datetime.utcnow, nullable=False
+            )
+
+        cls.HasTemporal = HasTemporal
+
+        def temporal_range(range_lower, range_upper):
+            return with_loader_criteria(
+                HasTemporal,
+                lambda cls: cls.timestamp.between(range_lower, range_upper),
+                include_aliases=True,
+            )
+
+        cls.temporal_range = staticmethod(temporal_range)
+
+        class Parent(HasTemporal, cls.DeclarativeBasic):
+            __tablename__ = "parent"
+            id = Column(Integer, primary_key=True)
+            children = relationship("Child", order_by="Child.id")
+
+        class Child(HasTemporal, cls.DeclarativeBasic):
+            __tablename__ = "child"
+            id = Column(Integer, primary_key=True)
+            parent_id = Column(
+                Integer, ForeignKey("parent.id"), nullable=False
+            )
+
+    @classmethod
+    def insert_data(cls, connection):
+        Parent, Child = cls.classes("Parent", "Child")
+
+        sess = Session(connection)
+        c1, c2, c3, c4, c5 = [
+            Child(timestamp=datetime.datetime(2009, 10, 15, 12, 00, 00)),
+            Child(timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00)),
+            Child(timestamp=datetime.datetime(2009, 10, 20, 12, 00, 00)),
+            Child(timestamp=datetime.datetime(2009, 10, 12, 12, 00, 00)),
+            Child(timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00)),
+        ]
+
+        p1 = Parent(
+            timestamp=datetime.datetime(2009, 10, 15, 12, 00, 00),
+            children=[c1, c2, c3],
+        )
+        p2 = Parent(
+            timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00),
+            children=[c4, c5],
+        )
+
+        sess.add_all([p1, p2])
+        sess.commit()
+
+    @testing.combinations((True,), (False,), argnames="use_caching")
+    @testing.combinations(
+        (None,),
+        (orm.lazyload,),
+        (orm.joinedload,),
+        (orm.subqueryload,),
+        (orm.selectinload,),
+        argnames="loader_strategy",
+    )
+    def test_same_relatinship_load_different_range(
+        self, use_caching, loader_strategy
+    ):
+        """This is the first test that exercises lazy loading, which uses
+        a lambda select, which then needs to transform the select to have
+        different bound parameters if it's not cached (or generate a working
+        list of parameters if it is), which then calls into a
+        with_loader_crieria that itself has another lambda inside of it,
+        which means we have to traverse and replace that lambda's expression,
+        but we can't evaluate it until compile time, so the inner lambda
+        holds onto the "transform" function so it can run it as needed.
+        this makes use of a new feature in visitors that exports a
+        "run this traversal later" function.
+
+        All of these individual features, cloning lambdaelements,
+        running replacement traversals later, are very new and need a lot
+        of tests, most likely in test/sql/test_lambdas.py.
+
+        the test is from the "temporal_range" example which is the whole
+        use case this feature is designed for and it is a whopper.
+
+
+        """
+        Parent, Child = self.classes("Parent", "Child")
+        temporal_range = self.temporal_range
+
+        if use_caching:
+            Parent.children.property.bake_queries = True
+            eng = testing.db
+        else:
+            Parent.children.property.bake_queries = False
+            eng = testing.db.execution_options(compiled_cache=None)
+
+        sess = Session(eng, future=True)
+
+        if loader_strategy:
+            loader_options = (loader_strategy(Parent.children),)
+        else:
+            loader_options = ()
+
+        p1 = sess.execute(
+            select(Parent).filter(
+                Parent.timestamp == datetime.datetime(2009, 10, 15, 12, 00, 00)
+            )
+        ).scalar()
+        c1, c2 = p1.children[0:2]
+        c2_id = c2.id
+
+        p2 = sess.execute(
+            select(Parent).filter(
+                Parent.timestamp == datetime.datetime(2009, 10, 17, 12, 00, 00)
+            )
+        ).scalar()
+        c5 = p2.children[1]
+
+        parents = (
+            sess.execute(
+                select(Parent)
+                .execution_options(populate_existing=True)
+                .options(
+                    temporal_range(
+                        datetime.datetime(2009, 10, 16, 12, 00, 00),
+                        datetime.datetime(2009, 10, 18, 12, 00, 00),
+                    ),
+                    *loader_options
+                )
+            )
+            .scalars()
+            .all()
+        )
+
+        assert parents[0] == p2
+        assert parents[0].children == [c5]
+
+        parents = (
+            sess.execute(
+                select(Parent)
+                .execution_options(populate_existing=True)
+                .join(Parent.children)
+                .filter(Child.id == c2_id)
+                .options(
+                    temporal_range(
+                        datetime.datetime(2009, 10, 15, 11, 00, 00),
+                        datetime.datetime(2009, 10, 18, 12, 00, 00),
+                    ),
+                    *loader_options
+                )
+            )
+            .scalars()
+            .all()
+        )
+
+        assert parents[0] == p1
+        assert parents[0].children == [c1, c2]
+
+
+class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    @testing.fixture
+    def user_address_fixture(self):
+        users, Address, addresses, User = (
+            self.tables.users,
+            self.classes.Address,
+            self.tables.addresses,
+            self.classes.User,
+        )
+
+        mapper(
+            User,
+            users,
+            properties={
+                "addresses": relationship(
+                    mapper(Address, addresses), order_by=Address.id
+                )
+            },
+        )
+        return User, Address
+
+    def test_joinedload_local_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        s = Session(testing.db, future=True)
+
+        stmt = select(User).options(
+            joinedload(User.addresses.and_(Address.email_address != "email")),
+        )
+
+        with self.sql_execution_asserter() as asserter:
+
+            s.execute(stmt)
+
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT users.id, users.name, addresses_1.id AS id_1, "
+                "addresses_1.user_id, addresses_1.email_address FROM "
+                "users LEFT OUTER JOIN addresses AS addresses_1 "
+                "ON users.id = addresses_1.user_id "
+                "AND addresses_1.email_address != :email_address_1 "
+                "ORDER BY addresses_1.id",
+                [{"email_address_1": "email"}],
+            ),
+        )
+
+    def test_query_join_local_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        s = Session(testing.db)
+
+        q = s.query(User).join(
+            User.addresses.and_(Address.email_address != "email")
+        )
+
+        self.assert_compile(
+            q,
+            "SELECT users.id AS users_id, users.name AS users_name "
+            "FROM users JOIN addresses ON users.id = addresses.user_id "
+            "AND addresses.email_address != :email_address_1",
+        )
+
+    def test_select_join_local_criteria(self, user_address_fixture):
+        User, Address = user_address_fixture
+
+        stmt = select(User).join(
+            User.addresses.and_(Address.email_address != "email")
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT users.id, users.name FROM users JOIN addresses "
+            "ON users.id = addresses.user_id "
+            "AND addresses.email_address != :email_address_1",
+        )
+
+    def test_select_joinm2m_local_criteria(self, order_item_fixture):
+        Order, Item = order_item_fixture
+
+        stmt = select(Order).join(
+            Order.items.and_(Item.description != "description")
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT orders.id, orders.user_id, orders.address_id, "
+            "orders.description, orders.isopen "
+            "FROM orders JOIN order_items AS order_items_1 "
+            "ON orders.id = order_items_1.order_id "
+            "JOIN items ON items.id = order_items_1.item_id "
+            "AND items.description != :description_1",
+        )
+
+    def test_select_joinm2m_aliased_local_criteria(self, order_item_fixture):
+        Order, Item = order_item_fixture
+
+        i1 = aliased(Item)
+        stmt = select(Order).join(
+            Order.items.of_type(i1).and_(i1.description != "description")
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT orders.id, orders.user_id, orders.address_id, "
+            "orders.description, orders.isopen "
+            "FROM orders JOIN order_items AS order_items_1 "
+            "ON orders.id = order_items_1.order_id "
+            "JOIN items AS items_1 ON items_1.id = order_items_1.item_id "
+            "AND items_1.description != :description_1",
+        )
index b573accbd947a2d71382e0668d63acba70c9c4b2..7aad2cab86234dbb0fe0d42989be60756b46a45b 100644 (file)
@@ -1512,3 +1512,23 @@ class CompareClausesTest(fixtures.TestBase):
         is_true(x_p_a.compare(x_p))
         is_true(x_p.compare(x_p_a))
         is_false(x_p_a.compare(x_a))
+
+
+class ExecutableFlagsTest(fixtures.TestBase):
+    @testing.combinations(
+        (select(column("a")),),
+        (table("q", column("a")).insert(),),
+        (table("q", column("a")).update(),),
+        (table("q", column("a")).delete(),),
+        (lambda_stmt(lambda: select(column("a"))),),
+    )
+    def test_is_select(self, case):
+        if isinstance(case, LambdaElement):
+            resolved_case = case._resolved
+        else:
+            resolved_case = case
+
+        if isinstance(resolved_case, Select):
+            is_true(case.is_select)
+        else:
+            is_false(case.is_select)