]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Major revisals to lambdas
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 12 Dec 2020 23:56:58 +0000 (18:56 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 16 Dec 2020 23:50:47 +0000 (18:50 -0500)
1. Improve coercions._deep_is_literal to check sequences
for clause elements, thus allowing a phrase like
lambda: col.in_([literal("x"), literal("y")]) to be handled

2. revise closure variable caching completely.   All variables
entering must be part of a closure cache key or rejected.
only objects that can be resolved to HasCacheKey or FunctionType
are accepted; all other types are rejected.  This adds a high
degree of strictness to lambdas and will make them a little more
awkward to use in some cases, however prevents several classes
of critical issues:

a. previously, a lambda that had an expression derived from
some kind of state, like "self.x", or "execution_context.session.foo"
would produce a closure cache key from "self" or "execution_context",
objects that can very well be per-execution and would therefore
cause AnalyzedFunction objects to overflow. (memory won't leak
as it looks like an LRUCache is already used for these)

b. a lambda, such as one used within DeferredLambdaElement, that
produces different SQL expressions based on the arguments
(which is in fact what it's supposed to do), however it would
through the use of conditionals produce different bound parameter
combinations, leading to literal parameters not tracked properly.
These are now rejected as uncacheable whereas previously they would
again be part of the closure cache key, causing an overflow of
AnalyzedFunction objects.

3. Ensure non-mapped mixins are handled correctly by
with_loader_criteria().

4. Fixed bug in lambda SQL system where we are not supposed to allow a Python
function to be embedded in the lambda, since we can't predict a bound value
from it.   While there was an error condition added for this, it was not
tested and wasn't working; an informative error is now raised.

5. new docs for lambdas

6. consolidated changelog for all of these

Fixes: #5760
Fixes: #5765
Fixes: #5766
Fixes: #5768
Fixes: #5770
Change-Id: Iedaa636c3225fad496df23b612c516c8ab247ab7

19 files changed:
doc/build/changelog/unreleased_14/5760.rst [new file with mode: 0644]
doc/build/changelog/unreleased_14/5761.rst [deleted file]
doc/build/changelog/unreleased_14/5762.rst [deleted file]
doc/build/changelog/unreleased_14/5763.rst [deleted file]
doc/build/changelog/unreleased_14/5764.rst [deleted file]
doc/build/core/connections.rst
doc/build/errors.rst
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/lambdas.py
test/aaa_profiling/test_orm.py
test/orm/test_lambdas.py
test/orm/test_relationship_criteria.py
test/orm/test_selectin_relations.py
test/profiles.txt
test/sql/test_compare.py
test/sql/test_lambdas.py

diff --git a/doc/build/changelog/unreleased_14/5760.rst b/doc/build/changelog/unreleased_14/5760.rst
new file mode 100644 (file)
index 0000000..053eb5a
--- /dev/null
@@ -0,0 +1,69 @@
+.. change::
+    :tags: bug, sql, orm
+    :tickets: 5760, 5763, 5765, 5768, 5770
+
+    A wide variety of fixes to the "lambda SQL" feature introduced at
+    :ref:`engine_lambda_caching` have been implemented based on user feedback,
+    with an emphasis on its use within the :func:`_orm.with_loader_criteria`
+    feature where it is most prominently used [ticket:5760]:
+
+    * fixed issue where boolean True/False values referred to in the
+      closure variables of the lambda would cause failures [ticket:5763]
+
+    * Repaired a non-working detection for Python functions embedded in the
+      lambda that produce bound values; this case is likely not supportable
+      so raises an informative error, where the function should be invoked
+      outside the lambda itself.  New documentation has been added to
+      further detail this behavior. [ticket:5770]
+
+    * The lambda system by default now rejects the use of non-SQL elements
+      within the closure variables of the lambda entirely, where the error
+      suggests the two options of either explicitly ignoring closure variables
+      that are not SQL parameters, or specifying a specific set of values to be
+      considered as part of the cache key based on hash value.   This critically
+      prevents the lambda system from assuming that arbitrary objects within
+      the lambda's closure are appropriate for caching while also refusing to
+      ignore them by default, preventing the case where their state might
+      not be constant and have an impact on the SQL construct produced.
+      The error message is comprehensive and new documentation has been
+      added to further detail this behavior. [ticket:5765]
+
+    * Fixed support for the edge case where an ``in_()`` expression
+      against a list of SQL elements, such as :func:`_sql.literal` objects,
+      would fail to be accommodated correctly. [ticket:5768]
+
+
+.. change::
+    :tags: bug, orm
+    :tickets: 5760, 5766, 5762, 5761, 5764
+
+    Related to the fixes for the lambda criteria system within Core, a variety
+    of fixes have been implemented within the ORM for the
+    :func:`_orm.with_loader_criteria` feature as well as the
+    :meth:`_orm.SessionEvents.do_orm_execute` event handler that is often
+    used in conjunction [ticket:5760]:
+
+
+    * fixed issue where :func:`_orm.with_loader_criteria` function would fail
+      if the given entity or base included non-mapped mixins in its descending
+      class hierarchy [ticket:5766]
+
+    * The :func:`_orm.with_loader_criteria` feature is now unconditionally
+      disabled for the case of ORM "refresh" operations, including loads
+      of deferred or expired column attributes as well as for explicit
+      operations like :meth:`_orm.Session.refresh`.  These loads are necessarily
+      based on primary key identity where additional WHERE criteria is
+      never appropriate.  [ticket:5762]
+
+    * Added new attribute :attr:`_orm.ORMExecuteState.is_column_load` to indicate
+      to a :meth:`_orm.SessionEvents.do_orm_execute` handler that a particular
+      operation is a primary-key-directed column attribute load, where additional
+      criteria should not be added.  The :func:`_orm.with_loader_criteria`
+      function as above ignores these in any case now.  [ticket:5761]
+
+    * Fixed issue where the :attr:`_orm.ORMExecuteState.is_relationship_load`
+      attribute would not be set correctly for many lazy loads as well as all
+      selectinloads.  The flag is essential in order to test if options should
+      be added to statements or if they would already have been propagated via
+      relationship loads.  [ticket:5764]
+
diff --git a/doc/build/changelog/unreleased_14/5761.rst b/doc/build/changelog/unreleased_14/5761.rst
deleted file mode 100644 (file)
index 9e36f2a..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-.. change::
-       :tags: bug, orm
-       :tickets: 5761
-
-       Added new attribute :attr:`_orm.ORMExecuteState.is_column_load` to indicate
-       that a :meth:`_orm.SessionEvents.do_orm_execute` handler that a particular
-       operation is a primary-key-directed column attribute load, such as from an
-       expiration or a deferred attribute, and that WHERE criteria or additional
-       loader options should not be added to the query.   This has been added to
-       the examples which illustrate the :func:`_orm.with_loader_criteria` option.
\ No newline at end of file
diff --git a/doc/build/changelog/unreleased_14/5762.rst b/doc/build/changelog/unreleased_14/5762.rst
deleted file mode 100644 (file)
index 7b5a90c..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-.. change::
-       :tags: bug, orm
-       :tickets: 5762
-
-       The :func:`_orm.with_loader_criteria` option has been modified so that it
-       will never apply its criteria to the SELECT statement for an ORM refresh
-       operation, such as that invoked by :meth:`_orm.Session.refresh` or whenever
-       an expired attribute is loaded.   These queries are only against the
-       primary key row of the object that is already present in memory so there
-       should not be additional criteria added.
\ No newline at end of file
diff --git a/doc/build/changelog/unreleased_14/5763.rst b/doc/build/changelog/unreleased_14/5763.rst
deleted file mode 100644 (file)
index e395b6f..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-.. change::
-       :tags: bug, orm
-       :tickets: 5763
-
-       Fixed bug in lambda SQL feature, used by ORM
-       :meth:`_orm.with_loader_criteria` as well as available generally in the SQL
-       expression language, where assigning a boolean value True/False to a
-       variable would cause the query-time expression calculation to fail, as it
-       would produce a SQL expression not compatible with a bound value.
\ No newline at end of file
diff --git a/doc/build/changelog/unreleased_14/5764.rst b/doc/build/changelog/unreleased_14/5764.rst
deleted file mode 100644 (file)
index 29753fa..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-.. change::
-       :tags: orm, bug
-       :tickets: 5764
-
-       Fixed issue where the :attr:`_orm.ORMExecuteState.is_relationship_load`
-       attribute would not be set correctly for many lazy loads, all
-       selectinloads, etc.  The flag is essential in order to test if options
-       should be added to statements or if they would already have been propagated
-       via relationship loads.
\ No newline at end of file
index aeefd27407bc0d7121a05cf71eb8363f53aa7004..629bec333abafdc70b6c89648c7474475615132d 100644 (file)
@@ -1114,124 +1114,414 @@ Using Lambdas to add significant speed gains to statement production
    not appropriate for novice Python developers.  The lambda approach can be
    applied at a later time to existing code with a minimal amount of effort.
 
-The caching system has in its roots the SQLAlchemy :ref:`"baked query"
-<baked_toplevel>` extension, which made novel use of Python lambdas in order to
-produce SQL statements that were intrinsically cacheable, while at the same
-time decreasing not just the overhead involved to compile the statement into
-SQL, but also the overhead in constructing the statement object from a Python
-perspective. The new caching in SQLAlchemy by default does not substantially
-optimize the construction of SQL constructs.  This refers to the Python
-overhead taken up to construct the statement object itself before it is
-compiled or executed, such as the :class:`_sql.Select` object used in the
-example below::
+Python functions, typically expressed as lambdas, may be used to generate
+SQL expressions which are cacheable based on the Python code location of
+the lambda function itself as well as the closure variables within the
+lambda.   The rationale is to allow caching of not only the SQL string-compiled
+form of a SQL expression construct as is SQLAlchemy's normal behavior when
+the lambda system isn't used, but also the in-Python composition
+of the SQL expression construct itself, which also has some degree of
+Python overhead.
+
+The lambda SQL expression feature is available as a performance enhancing
+feature, and is also optionally used in the :func:`_orm.with_loader_criteria`
+ORM option in order to provide a generic SQL fragment.
+
+Synopsis
+^^^^^^^^
+
+Lambda statements are constructed using the :func:`_sql.lambda_stmt` function,
+which returns an instance of :class:`_sql.StatementLambdaElement`, which is
+itself an executable statement construct.    Additional modifiers and criteria
+are added to the object using the Python addition operator ``+``, or
+alternatively the :meth:`_sql.StatementLambdaElement.add_criteria` method which
+allows for more options.
+
+It is assumed that the :func:`_sql.lambda_stmt` construct is being invoked
+within an enclosing function or method that expects to be used many times
+within an application, so that subsequent executions beyond the first one
+can take advantage of the compiled SQL being cached.  When the lambda is
+constructed inside of an enclosing function in Python it is then subject
+to also having closure variables, which are significant to the whole
+approach::
+
+    from sqlalchemy import lambda_stmt
 
     def run_my_statement(connection, parameter):
-        stmt = select(table)
-        stmt = stmt.where(table.c.col == parameter)
-        stmt = stmt.order_by(table.c.id)
+        stmt = lambda_stmt(lambda: select(table))
+        stmt += lambda s: s.where(table.c.col == parameter)
+        stmt += lambda s: s.order_by(table.c.id)
 
         return connection.execute(stmt)
 
-Above, in order to construct ``stmt``, we see three Python functions or methods
-``select()``, ``.where()`` and ``.order_by()`` being invoked directly, and
-additionally there is a Python method invoked when we construct ``table.c.col
-== 'foo'``, as the expression language overrides the ``__eq__()`` method to
-produce a SQL construct.   Within each of these calls is a series of argument
-checking and internal construction logic that makes use of many more Python
-function calls.   With intense production of thousands of statement objects,
-these function calls can add up.   Using the recipe for profiling at
-:ref:`faq_code_profiling`, the above Python code within the scope of the
-``select()`` call down to the ``.order_by()`` call uses 73 Python function
-calls to produce.
-
-Additionally, statement caching requires that a cache key be generated against
-the above statement, which must be composed of all elements within the
-statement that uniquely identify the SQL that it would produce.  Measuring
-this process for the above statement takes another 40 Python function calls.
-
-In order to ensure the full performance gains of the prior "baked query"
-extension are still available, the "lambda:" system used by baked queries has
-been adapted into a more capable and easier to use system as an intrinsic part
-of the SQLAlchemy Core expression language (which by extension then includes
-ORM queries, which as of SQLAlchemy 1.4 using 2.0-style APIs may also be
-invoked directly from SQLAlchemy Core expression objects).   We can
-adapt our statement above to be built using "lambdas" by making use of the
-:func:`_sql.lambda_stmt` element.  Using this approach, we indicate that the
-:func:`_sql.select` should be returned by a lambda.  We can then add new
-criteria to the statement by composing further lambdas onto the object in a
-similar manner as how "baked queries" worked::
-
-  from sqlalchemy import lambda_stmt
-
-  def run_my_statement(connection, parameter):
-      stmt = lambda_stmt(lambda: select(table))
-      stmt += lambda s: s.where(table.c.col == parameter)
-      stmt += lambda s: s.order_by(table.c.id)
-
-      return connection.execute(stmt)
-
-  result = run_my_statement(some_connection, "some parameter")
-
-The above code produces a :class:`.StatementLambdaElement`, which behaves like a
-Core SQL construct but defers the construction of the statement in most
-cases until it is needed by the compiler.  If the statement is already cached,
-the lambdas will not be called.
-
-The cache key is based on the **Python source code location of each lambda
-itself**, which in the Python interpreter is essentially the ``__code__``
-element of the Python function. This means that the lambda approach should only
-be used inside of a function where the lambdas themselves will be the **same
-lambdas each time, from a Python source code perspective**.
-
-The execution process for the above lambda will **extract literal parameters**
-from the statement each time, without needing to actually run the lambdas. In
-the above example, each time the variable ``parameter`` is used within the
-lambda to generate the WHERE clause of the statement, while the actual lambda
-present will not actually be run, the value of ``parameter`` will be tracked
-and the current value of the variable will be used within the statement
-parameters at execution time.  This is a feature that was not possible with the
-"baked query" extension and involves the use of up-front analysis of the
-incoming ``__code__`` object to determine how parameters can be extracted from
-future lambdas against that same code object.
-
-More simply, this means it's safe for the lambda statement
-to use arbitrary literal parameters, which don't modify the structure
-of the statement, on each invocation::
-
-  def run_my_statement(connection, parameter):
-      stmt = lambda_stmt(lambda: select(table))
-      stmt += lambda s: s.where(table.c.col == parameter)
-      stmt += lambda s: s.order_by(table.c.id)
-
-      return connection.execute(stmt)
-
-However, it's not safe for an individual lambda so modify the SQL structure
-of the statement across calls::
-
-  # incorrect example
-  def run_my_statement(connection, parameter, add_criteria=False):
-      stmt = lambda_stmt(lambda: select(table))
-
-      # will not be cached correctly as add_criteria changes
-      stmt += lambda s: s.where(
-          and_(add_criteria, table.c.col == parameter)
-          if add_criteria
-          else s.where(table.c.col == parameter)
-      )
-
-      stmt += lambda s: s.order_by(table.c.id)
-
-      return connection.execute(stmt)
-
-The lambda statements indicated above will invoke all of the lambdas the first
-time they are constructed; subsequent to that, the lambdas will not be invoked.
-On these subsequent runs, a lambda construct will use far fewer Python function
-calls in order to construct the un-cached object as well as to generate the
-cache key after the first call.  The above statement using lambdas takes only
-41 Python function calls to generate the whole structure as well as to produce
-the cache key, including the extraction of the bound parameters. This is
-compared to a total of about 115 Python function calls for the non-lambda
-version.
+    with engine.connect() as conn:
+        result = run_my_statement(conn, "some parameter")
+
+Above, the three ``lambda`` callables that are used to define the structure
+of a SELECT statement are invoked exactly once, and the resulting SQL
+string cached in the compilation cache of the engine.   From that point
+forward, the ``run_my_statement()`` function may be invoked any number
+of times and the ``lambda`` callables within it will not be called, only
+used as cache keys to retrieve the already-compiled SQL.
+
+.. note::  It is important to note that there is already SQL caching in place
+   when the lambda system is not used.   The lambda system only adds an
+   additional layer of work reduction per SQL statement invoked by caching
+   the building up of the SQL construct itself and also using a simpler
+   cache key.
+
+
+Quick Guidelines for Lambdas
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Above all, the emphasis within the lambda SQL system is ensuring that there
+is never a mismatch between the cache key generated for a lambda and the
+SQL string it will produce.   The :class:`_sql.LambdaElement` and related
+objects will run and analyze the given lambda in order to calculate how
+it should be cached on each run, trying to detect any potential problems.
+Basic guidelines include:
+
+* **Any kind of statement is supported** - while it's expected that
+  :func:`_sql.select` constructs are the prime use case for :func:`_sql.lambda_stmt`,
+  DML statements such as :func:`_sql.insert` and :func:`_sql.update` are
+  equally usable::
+
+    def upd(id_, newname):
+        stmt = lambda_stmt(lambda: users.update())
+        stmt += lambda s: s.values(name=newname)
+        stmt += lambda s: s.where(users.c.id==id_)
+        return stmt
+
+    with engine.begin() as conn:
+        conn.execute(upd(7, "foo"))
+
+  ..
+
+* **ORM use cases directly supported as well** - the :func:`_sql.lambda_stmt`
+  can accommodate ORM functionality completely and may be used directly with
+  :meth:`_orm.Session.execute`::
+
+    def select_user(session, name):
+        stmt = lambda_stmt(lambda: select(User))
+        stmt += lambda s: s.where(User.name == name)
+
+        row = session.execute(stmt).first()
+        return row
+
+  ..
+
+* **Bound parameters are automatically accommodated** - in contrast to SQLAlchemy's
+  previous "baked query" system, the lambda SQL system accommodates for
+  Python literal values which become SQL bound parameters automatically.
+  This means that even though a given lambda runs only once, the values that
+  become bound parameters are extracted from the **closure** of the lambda
+  on every run:
+
+  .. sourcecode:: pycon+sql
+
+        >>> def my_stmt(x, y):
+        ...     stmt = lambda_stmt(lambda: select(func.max(x, y)))
+        ...     return stmt
+        ...
+        >>> engine = create_engine("sqlite://", echo=True)
+        >>> with engine.connect() as conn:
+        ...     print(conn.scalar(my_stmt(5, 10)))
+        ...     print(conn.scalar(my_stmt(12, 8)))
+        ...
+        {opensql}SELECT max(?, ?) AS max_1
+        [generated in 0.00057s] (5, 10){stop}
+        10
+        {opensql}SELECT max(?, ?) AS max_1
+        [cached since 0.002059s ago] (12, 8){stop}
+        12
+
+  Above, :class:`_sql.StatementLambdaElement` extracted the values of ``x``
+  and ``y`` from the **closure** of the lambda that is generated each time
+  ``my_stmt()`` is invoked; these were substituted into the cached SQL
+  construct as the values of the parameters.
+
+* **The lambda should ideally produce an identical SQL structure in all cases** -
+  Avoid using conditionals or custom callables inside of lambdas that might make
+  it produce different SQL based on inputs; if a function might conditionally
+  use two different SQL fragments, use two separate lambdas::
+
+        # **Don't** do this:
+
+        def my_stmt(parameter, thing=False):
+            stmt = lambda_stmt(lambda: select(table))
+            stmt += (
+                lambda s: s.where(table.c.x > parameter) if thing
+                else s.where(table.c.y == parameter)
+            )
+            return stmt
+
+        # **Do** do this:
+
+        def my_stmt(parameter, thing=False):
+            stmt = lambda_stmt(lambda: select(table))
+            if thing:
+                stmt += lambda s: s.where(table.c.x > parameter)
+            else:
+                stmt += lambda s: s.where(table.c.y == parameter)
+            return stmt
+
+  There are a variety of failures which can occur if the lambda does not
+  produce a consistent SQL construct and some are not trivially detectable
+  right now.
+
+* **Don't use functions inside the lambda to produce bound values** - the
+  bound value tracking approach requires that the actual value to be used in
+  the SQL statement be locally present in the closure of the lambda.  This is
+  not possible if values are generated from other functions, and the
+  :class:`_sql.LambdaElement` should normally raise an error if this is
+  attempted::
+
+    >>> def my_stmt(x, y):
+    ...     def get_x():
+    ...         return x
+    ...     def get_y():
+    ...         return y
+    ...
+    ...     stmt = lambda_stmt(lambda: select(func.max(get_x(), get_y())))
+    ...     return stmt
+    ...
+    >>> with engine.connect() as conn:
+    ...     print(conn.scalar(my_stmt(5, 10)))
+    ...
+    Traceback (most recent call last):
+      # ...
+    sqlalchemy.exc.InvalidRequestError: Can't invoke Python callable get_x()
+    inside of lambda expression argument at
+    <code object <lambda> at 0x7fed15f350e0, file "<stdin>", line 6>;
+    lambda SQL constructs should not invoke functions from closure variables
+    to produce literal values since the lambda SQL system normally extracts
+    bound values without actually invoking the lambda or any functions within it.
+
+  Above, the use of ``get_x()`` and ``get_y()``, if they are necessary, should
+  occur **outside** of the lambda and assigned to a local closure variable::
+
+    >>> def my_stmt(x, y):
+    ...     def get_x():
+    ...         return x
+    ...     def get_y():
+    ...         return y
+    ...
+    ...     x_param, y_param = get_x(), get_y()
+    ...     stmt = lambda_stmt(lambda: select(func.max(x_param, y_param)))
+    ...     return stmt
+
+  ..
+
+* **Avoid referring to non-SQL constructs inside of lambdas as they are not
+  cacheable by default** - this issue refers to how the :class:`_sql.LambdaElement`
+  creates a cache key from other closure variables within the statement.  In order
+  to provide the best guarantee of an accurate cache key, all objects located
+  in the closure of the lambda are considered to be significant, and none
+  will be assumed to be appropriate for a cache key by default.
+  So the following example will also raise a rather detailed error message::
+
+    >>> class Foo:
+    ...     def __init__(self, x, y):
+    ...         self.x = x
+    ...         self.y = y
+    ...
+    >>> def my_stmt(foo):
+    ...     stmt = lambda_stmt(lambda: select(func.max(foo.x, foo.y)))
+    ...     return stmt
+    ...
+    >>> with engine.connect() as conn:
+    ...    print(conn.scalar(my_stmt(Foo(5, 10))))
+    ...
+    Traceback (most recent call last):
+      # ...
+    sqlalchemy.exc.InvalidRequestError: Closure variable named 'foo' inside of
+    lambda callable <code object <lambda> at 0x7fed15f35450, file
+    "<stdin>", line 2> does not refer to a cachable SQL element, and also
+    does not appear to be serving as a SQL literal bound value based on the
+    default SQL expression returned by the function.  This variable needs to
+    remain outside the scope of a SQL-generating lambda so that a proper cache
+    key may be generated from the lambda's state.  Evaluate this variable
+    outside of the lambda, set track_on=[<elements>] to explicitly select
+    closure elements to track, or set track_closure_variables=False to exclude
+    closure variables from being part of the cache key.
+
+  The above error indicates that :class:`_sql.LambdaElement` will not assume
+  that the ``Foo`` object passed in will continue to behave the same in all
+  cases.    It also won't assume it can use ``Foo`` as part of the cache key
+  by default; if it were to use the ``Foo`` object as part of the cache key,
+  if there were many different ``Foo`` objects this would fill up the cache
+  with duplicate information, and would also hold long-lasting references to
+  all of these objects.
+
+  The best way to resolve the above situation is to not refer to ``foo``
+  inside of the lambda, and refer to it **outside** instead::
+
+    >>> def my_stmt(foo):
+    ...     x_param, y_param = foo.x, foo.y
+    ...     stmt = lambda_stmt(lambda: select(func.max(x_param, y_param)))
+    ...     return stmt
+
+  In some situations, if the SQL structure of the lambda is guaranteed to
+  never change based on input, it is possible to pass ``track_closure_variables=False``
+  which will disable any tracking of closure variables other than those
+  used for bound parameters::
+
+    >>> def my_stmt(foo):
+    ...     stmt = lambda_stmt(
+    ...         lambda: select(func.max(foo.x, foo.y)),
+    ...         track_closure_variables=False
+    ...     )
+    ...     return stmt
+
+  There is also the option to add objects to the element to explicitly form
+  part of the cache key, using the ``track_on`` parameter; using this parameter
+  allows specific values to serve as the cache key and will also prevent other
+  closure variables from being considered.  This is useful for cases where part
+  of the SQL being constructed originates from a contextual object of some sort
+  that may have many different values.  In the example below, the first
+  segment of the SELECT statement will disable tracking of the ``foo`` variable,
+  whereas the second segment will explicitly track ``self`` as part of the
+  cache key::
+
+    >>> def my_stmt(self, foo):
+    ...     stmt = lambda_stmt(
+    ...         lambda: select(*self.column_expressions),
+    ...         track_closure_variables=False
+    ...     )
+    ...     stmt = stmt.add_criteria(
+    ...         lambda: self.where_criteria,
+    ...         track_on=[self]
+    ...     )
+    ...     return stmt
+
+  Using ``track_on`` means the given objects will be stored long term in the
+  lambda's internal cache and will have strong references for as long as the
+  cache doesn't clear out those objects (an LRU scheme of 1000 entries is used
+  by default).
+
+  ..
+
+
+Cache Key Generation
+^^^^^^^^^^^^^^^^^^^^
+
+In order to understand some of the options and behaviors which occur
+with lambda SQL constructs, an understanding of the caching system
+is helpful.
+
+SQLAlchemy's caching system normally generates a cache key from a given
+SQL expression construct by producing a structure that represents all the
+state within the construct::
+
+    >>> from sqlalchemy import select, column
+    >>> stmt = select(column('q'))
+    >>> cache_key = stmt._generate_cache_key()
+    >>> print(cache_key)  # somewhat paraphrased
+    CacheKey(key=(
+      '0',
+      <class 'sqlalchemy.sql.selectable.Select'>,
+      '_raw_columns',
+      (
+        (
+          '1',
+          <class 'sqlalchemy.sql.elements.ColumnClause'>,
+          'name',
+          'q',
+          'type',
+          (
+            <class 'sqlalchemy.sql.sqltypes.NullType'>,
+          ),
+        ),
+      ),
+      # a few more elements are here, and many more for a more
+      # complicated SELECT statement
+    ),)
+
+
+The above key is stored in the cache which is essentially a dictionary, and the
+value is a construct that among other things stores the string form of the SQL
+statement, in this case the phrase "SELECT q".  We can observe that even for an
+extremely short query the cache key is pretty verbose as it has to represent
+everything that may vary about what's being rendered and potentially executed.
+
+The lambda construction system by contrast creates a different kind of cache
+key::
+
+    >>> from sqlalchemy import lambda_stmt
+    >>> stmt = lambda_stmt(lambda: select(column("q")))
+    >>> cache_key = stmt._generate_cache_key()
+    >>> print(cache_key)
+    CacheKey(key=(
+      <code object <lambda> at 0x7fed1617c710, file "<stdin>", line 1>,
+      <class 'sqlalchemy.sql.lambdas.StatementLambdaElement'>,
+    ),)
+
+Above, we see a cache key that is vastly shorter than that of the non-lambda
+statement, and additionally that production of the ``select(column("q"))``
+construct itself was not even necessary; the Python lambda itself contains
+an attribute called ``__code__`` which refers to a Python code object that
+within the runtime of the application is immutable and permanent.
+
+When the lambda also includes closure variables, in the normal case that these
+variables refer to SQL constructs such as column objects, they become
+part of the cache key, or if they refer to literal values that will be bound
+parameters, they are placed in a separate element of the cache key::
+
+    >>> def my_stmt(parameter):
+    ...     col = column("q")
+    ...     stmt = lambda_stmt(lambda: select(col))
+    ...     stmt += lambda s: s.where(col == parameter)
+    ...     return stmt
+
+The above :class:`_sql.StatementLambdaElement` includes two lambdas, both
+of which refer to the ``col`` closure variable, so the cache key will
+represent both of these segments as well as the ``column()`` object::
+
+    >>> stmt = my_stmt(5)
+    >>> key = stmt._generate_cache_key()
+    >>> print(key)
+    CacheKey(key=(
+      <code object <lambda> at 0x7f07323c50e0, file "<stdin>", line 3>,
+      (
+        '0',
+        <class 'sqlalchemy.sql.elements.ColumnClause'>,
+        'name',
+        'q',
+        'type',
+        (
+          <class 'sqlalchemy.sql.sqltypes.NullType'>,
+        ),
+      ),
+      <code object <lambda> at 0x7f07323c5190, file "<stdin>", line 4>,
+      <class 'sqlalchemy.sql.lambdas.LinkedLambdaElement'>,
+      (
+        '0',
+        <class 'sqlalchemy.sql.elements.ColumnClause'>,
+        'name',
+        'q',
+        'type',
+        (
+          <class 'sqlalchemy.sql.sqltypes.NullType'>,
+        ),
+      ),
+      (
+        '0',
+        <class 'sqlalchemy.sql.elements.ColumnClause'>,
+        'name',
+        'q',
+        'type',
+        (
+          <class 'sqlalchemy.sql.sqltypes.NullType'>,
+        ),
+      ),
+    ),)
+
+
+The second part of the cache key has retrieved the bound parameters that will
+be used when the statement is invoked::
+
+    >>> key.bindparams
+    [BindParameter('%(139668884281280 parameter)s', 5, type_=Integer())]
+
 
 For a series of examples of "lambda" caching with performance comparisons,
 see the "short_selects" test suite within the :ref:`examples_performance`
index a52444766db6fb9e2ebe9d62499fc431713737d8..67a8a29b0f5aafb33ac84b5cab0ce67dac708ccd 100644 (file)
@@ -802,6 +802,7 @@ therefore requires that :meth:`_expression.SelectBase.subquery` is used::
 
   :ref:`change_4617`
 
+
 Object Relational Mapping
 =========================
 
index 98c57149d341dd2cfbb5c225f7e671475ae0ac43..6838011b14ad08d47ddf1082affd1089e55391a6 100644 (file)
@@ -540,6 +540,9 @@ class AbstractRelationshipLoader(LoaderStrategy):
         self.target = self.parent_property.target
         self.uselist = self.parent_property.uselist
 
+    def _size_alert(self, lru_cache):
+        util.warn("LRU cache size alert for loader strategy: %s" % self)
+
 
 @log.class_logger
 @relationships.RelationshipProperty.strategy_for(do_nothing=True)
@@ -884,7 +887,11 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
         ]
 
     def _memoized_attr__query_cache(self):
-        return util.LRUCache(30)
+        # cache is per lazy loader; stores not only cached SQL but also
+        # sqlalchemy.sql.lambdas.AnalyzedCode and
+        # sqlalchemy.sql.lambdas.AnalyzedFunction objects which are generated
+        # from the StatementLambda used.
+        return util.LRUCache(30, size_alert=self._size_alert)
 
     @util.preload_module("sqlalchemy.orm.strategy_options")
     def _emit_lazyload(
@@ -912,8 +919,11 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
         }
 
         if self.parent_property.secondary is not None:
-            stmt += lambda stmt: stmt.select_from(
-                self.mapper, self.parent_property.secondary
+            stmt = stmt.add_criteria(
+                lambda stmt: stmt.select_from(
+                    self.mapper, self.parent_property.secondary
+                ),
+                track_on=[self.parent_property],
             )
 
         pending = not state.key
@@ -961,7 +971,9 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
             )
 
         if self._order_by:
-            stmt += lambda stmt: stmt.order_by(*self._order_by)
+            stmt = stmt.add_criteria(
+                lambda stmt: stmt.order_by(*self._order_by), track_on=[self]
+            )
 
         def _lazyload_reverse(compile_context):
             for rev in self.parent_property._reverse_property:
@@ -978,8 +990,11 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
                         ]
                     ).lazyload(rev.key).process_compile_state(compile_context)
 
-        stmt += lambda stmt: stmt._add_context_option(
-            _lazyload_reverse, self.parent_property
+        stmt = stmt.add_criteria(
+            lambda stmt: stmt._add_context_option(
+                _lazyload_reverse, self.parent_property
+            ),
+            track_on=[self],
         )
 
         lazy_clause, params = self._generate_lazy_clause(state, passive)
@@ -2587,7 +2602,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
         ).init_class_attribute(mapper)
 
     def _memoized_attr__query_cache(self):
-        return util.LRUCache(30)
+        return util.LRUCache(30, size_alert=self._size_alert)
 
     def create_row_processor(
         self,
@@ -2763,13 +2778,13 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
             # in the non-omit_join case, the Bundle is against the annotated/
             # mapped column of the parent entity, but the #4347 issue does not
             # occur in this case.
-            pa = self._parent_alias
             q = q.add_criteria(
-                lambda q: q.select_from(pa).join(
-                    getattr(pa, self.parent_property.key).of_type(
-                        effective_entity
-                    )
-                )
+                lambda q: q.select_from(self._parent_alias).join(
+                    getattr(
+                        self._parent_alias, self.parent_property.key
+                    ).of_type(effective_entity)
+                ),
+                track_on=[self],
             )
 
         q = q.add_criteria(
@@ -2849,7 +2864,8 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
                 q = q.add_criteria(
                     lambda q: q._add_context_option(
                         _setup_outermost_orderby, self.parent_property
-                    )
+                    ),
+                    track_on=[self],
                 )
 
         if query_info.load_only_child:
index c9437d1b2eacef428065f55b4b9aad7761c9896a..88f9a34d056e4a723537b82d1f6c09dc37de12d9 100644 (file)
@@ -885,6 +885,7 @@ class LoaderCriteriaOption(CriteriaOption):
         loader_only=False,
         include_aliases=False,
         propagate_to_loaders=True,
+        track_closure_variables=True,
     ):
         """Add additional WHERE criteria to the load for all occurrences of
         a particular entity.
@@ -993,6 +994,14 @@ class LoaderCriteriaOption(CriteriaOption):
             combine :func:`_orm.with_loader_criteria` with the
             :meth:`_orm.SessionEvents.do_orm_execute` event.
 
+        :param track_closure_variables: when False, closure variables inside
+         of a lambda expression will not be used as part of
+         any cache key.    This allows more complex expressions to be used
+         inside of a lambda expression but requires that the lambda ensures
+         it returns the identical SQL every time given a particular class.
+
+         .. versionadded:: 1.4.0b2
+
         """
         entity = inspection.inspect(entity_or_base, False)
         if entity is None:
@@ -1012,6 +1021,9 @@ class LoaderCriteriaOption(CriteriaOption):
                     if self.root_entity is not None
                     else self.entity.entity,
                 ),
+                opts=lambdas.LambdaOptions(
+                    track_closure_variables=track_closure_variables
+                ),
             )
         else:
             self.deferred_where_criteria = False
@@ -1030,7 +1042,7 @@ class LoaderCriteriaOption(CriteriaOption):
             stack = list(self.root_entity.__subclasses__())
             while stack:
                 subclass = stack.pop(0)
-                ent = inspection.inspect(subclass)
+                ent = inspection.inspect(subclass, raiseerr=False)
                 if ent:
                     for mp in ent.mapper.self_and_descendants:
                         yield mp
index bdd807438ce81a70100ab1d188bc396395b54f3e..43c89ee8237ecdc4664a6449ca9f6da61e59c317 100644 (file)
@@ -7,7 +7,6 @@
 
 import numbers
 import re
-import types
 
 from . import operators
 from . import roles
@@ -56,6 +55,15 @@ def _deep_is_literal(element):
 
     """
 
+    if isinstance(element, collections_abc.Sequence) and not isinstance(
+        element, str
+    ):
+        for elem in element:
+            if not _deep_is_literal(elem):
+                return False
+        else:
+            return True
+
     return (
         not isinstance(
             element,
@@ -66,7 +74,6 @@ def _deep_is_literal(element):
             not isinstance(element, type)
             or not issubclass(element, HasCacheKey)
         )
-        and not isinstance(element, types.FunctionType)
     )
 
 
@@ -109,9 +116,8 @@ def expect(role, element, apply_propagate_attrs=None, argname=None, **kw):
         return lambdas.LambdaElement(
             element,
             role,
+            lambdas.LambdaOptions(**kw),
             apply_propagate_attrs=apply_propagate_attrs,
-            argname=argname,
-            **kw
         )
 
     # major case is that we are given a ClauseElement already, skip more
index 86611baeb155fd70a7a387c3f63e113814b986b4..ab8701dd6566689540218df9daa3536ef7397041 100644 (file)
@@ -1383,7 +1383,6 @@ class BindParameter(roles.InElementRole, ColumnElement):
         """Return a copy of this :class:`.BindParameter` with the given value
         set.
         """
-
         cloned = self._clone(maintain_key=maintain_key)
         cloned.value = value
         cloned.callable = None
index aafdda4ce14fe27cf8292399abb680b7e362d870..3f0ca477e653af17503736caae427fb040794fde 100644 (file)
@@ -19,6 +19,7 @@ from . import traversals
 from . import type_api
 from . import visitors
 from .base import _clone
+from .base import Options
 from .operators import ColumnOperators
 from .. import exc
 from .. import inspection
@@ -28,7 +29,24 @@ from ..util import collections_abc
 _closure_per_cache_key = util.LRUCache(1000)
 
 
-def lambda_stmt(lmb, **opts):
+class LambdaOptions(Options):
+    enable_tracking = True
+    track_closure_variables = True
+    track_on = None
+    global_track_bound_values = True
+    track_bound_values = True
+    lambda_cache = None
+
+
+def lambda_stmt(
+    lmb,
+    enable_tracking=True,
+    track_closure_variables=True,
+    track_on=None,
+    global_track_bound_values=True,
+    track_bound_values=True,
+    lambda_cache=None,
+):
     """Produce a SQL statement that is cached as a lambda.
 
     The Python code object within the lambda is scanned for both Python
@@ -49,6 +67,29 @@ def lambda_stmt(lmb, **opts):
 
     .. versionadded:: 1.4
 
+    :param lmb: a Python function, typically a lambda, which takes no arguments
+     and returns a SQL expression construct
+    :param enable_tracking: when False, all scanning of the given lambda for
+     changes in closure variables or bound parameters is disabled.  Use for
+     a lambda that produces the identical results in all cases with no
+     parameterization.
+    :param track_closure_variables: when False, changes in closure variables
+     within the lambda will not be scanned.   Use for a lambda where the
+     state of its closure variables will never change the SQL structure
+     returned by the lambda.
+    :param track_bound_values: when False, bound parameter tracking will
+     be disabled for the given lambda.  Use for a lambda that either does
+     not produce any bound values, or where the initial bound values never
+     change.
+    :param global_track_bound_values: when False, bound parameter tracking
+     will be disabled for the entire statement including additional links
+     added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
+    :param lambda_cache: a dictionary or other mapping-like object where
+     information about the lambda's Python code as well as the tracked closure
+     variables in the lambda itself will be stored.   Defaults
+     to a global LRU cache.  This cache is independent of the "compiled_cache"
+     used by the :class:`_engine.Connection` object.
+
     .. seealso::
 
         :ref:`engine_lambda_caching`
@@ -56,7 +97,18 @@ def lambda_stmt(lmb, **opts):
 
     """
 
-    return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts)
+    return StatementLambdaElement(
+        lmb,
+        roles.CoerceTextStatementRole,
+        LambdaOptions(
+            enable_tracking=enable_tracking,
+            track_on=track_on,
+            track_closure_variables=track_closure_variables,
+            global_track_bound_values=global_track_bound_values,
+            track_bound_values=track_bound_values,
+            lambda_cache=lambda_cache,
+        ),
+    )
 
 
 class LambdaElement(elements.ClauseElement):
@@ -94,38 +146,39 @@ class LambdaElement(elements.ClauseElement):
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
 
-    def __init__(self, fn, role, apply_propagate_attrs=None, **kw):
+    def __init__(
+        self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
+    ):
         self.fn = fn
         self.role = role
         self.tracker_key = (fn.__code__,)
+        self.opts = opts
 
         if apply_propagate_attrs is None and (
             role is roles.CoerceTextStatementRole
         ):
             apply_propagate_attrs = self
 
-        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw)
+        rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)
 
         if apply_propagate_attrs is not None:
             propagate_attrs = rec.propagate_attrs
             if propagate_attrs:
                 apply_propagate_attrs._propagate_attrs = propagate_attrs
 
-    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw):
-        lambda_cache = kw.get("lambda_cache", _closure_per_cache_key)
+    def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
+        lambda_cache = opts.lambda_cache
+        if lambda_cache is None:
+            lambda_cache = _closure_per_cache_key
 
         tracker_key = self.tracker_key
 
         fn = self.fn
         closure = fn.__closure__
-
         tracker = AnalyzedCode.get(
             fn,
             self,
-            kw,
-            track_bound_values=kw.get("track_bound_values", True),
-            enable_tracking=kw.get("enable_tracking", True),
-            track_on=kw.get("track_on", None),
+            opts,
         )
 
         self._resolved_bindparams = bindparams = []
@@ -133,10 +186,11 @@ class LambdaElement(elements.ClauseElement):
         anon_map = traversals.anon_map()
         cache_key = tuple(
             [
-                getter(closure, kw, anon_map, bindparams)
+                getter(closure, opts, anon_map, bindparams)
                 for getter in tracker.closure_trackers
             ]
         )
+
         if self.parent_lambda is not None:
             cache_key = self.parent_lambda.closure_cache_key + cache_key
 
@@ -148,9 +202,7 @@ class LambdaElement(elements.ClauseElement):
             rec = None
 
         if rec is None:
-            rec = AnalyzedFunction(
-                tracker, self, apply_propagate_attrs, kw, fn
-            )
+            rec = AnalyzedFunction(tracker, self, apply_propagate_attrs, fn)
             rec.closure_bindparams = bindparams
             lambda_cache[tracker_key + cache_key] = rec
         else:
@@ -213,14 +265,13 @@ class LambdaElement(elements.ClauseElement):
         bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
 
         def replace(thing):
-            if (
-                isinstance(thing, elements.BindParameter)
-                and thing.key in bindparam_lookup
-            ):
-                bind = bindparam_lookup[thing.key]
-                if thing.expanding:
-                    bind.expanding = True
-                return bind
+            if isinstance(thing, elements.BindParameter):
+
+                if thing.key in bindparam_lookup:
+                    bind = bindparam_lookup[thing.key]
+                    if thing.expanding:
+                        bind.expanding = True
+                    return bind
 
         if self._rec.is_sequence:
             expr = [
@@ -268,7 +319,6 @@ class LambdaElement(elements.ClauseElement):
 
         if self._resolved_bindparams:
             bindparams.extend(self._resolved_bindparams)
-
         return cache_key
 
     def _invoke_user_fn(self, fn, *arg):
@@ -285,10 +335,9 @@ class DeferredLambdaElement(LambdaElement):
 
     """
 
-    def __init__(self, fn, role, lambda_args=(), **kw):
+    def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
         self.lambda_args = lambda_args
-        self.coerce_kw = kw
-        super(DeferredLambdaElement, self).__init__(fn, role, **kw)
+        super(DeferredLambdaElement, self).__init__(fn, role, opts)
 
     def _invoke_user_fn(self, fn, *arg):
         return fn(*self.lambda_args)
@@ -297,10 +346,30 @@ class DeferredLambdaElement(LambdaElement):
         tracker_fn = self._rec.tracker_instrumented_fn
         expr = tracker_fn(*lambda_args)
 
-        expr = coercions.expect(self.role, expr, **self.coerce_kw)
-
-        if self._resolved_bindparams:
-            expr = self._setup_binds_for_tracked_expr(expr)
+        expr = coercions.expect(self.role, expr)
+
+        expr = self._setup_binds_for_tracked_expr(expr)
+
+        # this validation is getting very close, but not quite, to achieving
+        # #5767.  The problem is if the base lambda uses an unnamed column
+        # as is very common with mixins, the parameter name is different
+        # and it produces a false positive; that is, for the documented case
+        # that is exactly what people will be doing, it doesn't work, so
+        # I'm not really sure how to handle this right now.
+        # expected_binds = [
+        #    b._orig_key
+        #    for b in self._rec.expr._generate_cache_key()[1]
+        #    if b.required
+        # ]
+        # got_binds = [
+        #    b._orig_key for b in expr._generate_cache_key()[1] if b.required
+        # ]
+        # if expected_binds != got_binds:
+        #    raise exc.InvalidRequestError(
+        #        "Lambda callable at %s produced a different set of bound "
+        #        "parameters than its original run: %s"
+        #        % (self.fn.__code__, ", ".join(got_binds))
+        #    )
 
         # TODO: TEST TEST TEST, this is very out there
         for deferred_copy_internals in self._transforms:
@@ -312,7 +381,9 @@ class DeferredLambdaElement(LambdaElement):
         self, clone=_clone, deferred_copy_internals=None, **kw
     ):
         super(DeferredLambdaElement, self)._copy_internals(
-            clone=clone, deferred_copy_internals=deferred_copy_internals, **kw
+            clone=clone,
+            deferred_copy_internals=deferred_copy_internals,  # **kw
+            opts=kw,
         )
 
         # TODO: A LOT A LOT of tests.   for _resolve_with_args, we don't know
@@ -347,33 +418,60 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
 
     """
 
-    def __init__(self, fn, parent_lambda, **kw):
-        self._default_kw = default_kw = {}
-        global_track_bound_values = kw.pop("global_track_bound_values", None)
-        if global_track_bound_values is not None:
-            default_kw["track_bound_values"] = global_track_bound_values
-            kw["track_bound_values"] = global_track_bound_values
+    def __add__(self, other):
+        return self.add_criteria(other)
 
-        if "lambda_cache" in kw:
-            default_kw["lambda_cache"] = kw["lambda_cache"]
+    def add_criteria(
+        self,
+        other,
+        enable_tracking=True,
+        track_on=None,
+        track_closure_variables=True,
+        track_bound_values=True,
+    ):
+        """Add new criteria to this :class:`_sql.StatementLambdaElement`.
+
+        E.g.::
+
+            >>> def my_stmt(parameter):
+            ...     stmt = lambda_stmt(
+            ...         lambda: select(table.c.x, table.c.y),
+            ...     )
+            ...     stmt = stmt.add_criteria(
+            ...         lambda: table.c.x > parameter
+            ...     )
+            ...     return stmt
+
+        The :meth:`_sql.StatementLambdaElement.add_criteria` method is
+        equivalent to using the Python addition operator to add a new
+        lambda, except that additional arguments may be added including
+        ``track_closure_variables`` and ``track_on``::
+
+            >>> def my_stmt(self, foo):
+            ...     stmt = lambda_stmt(
+            ...         lambda: select(func.max(foo.x, foo.y)),
+            ...         track_closure_variables=False
+            ...     )
+            ...     stmt = stmt.add_criteria(
+            ...         lambda: self.where_criteria,
+            ...         track_on=[self]
+            ...     )
+            ...     return stmt
+
+        See :func:`_sql.lambda_stmt` for a description of the parameters
+        accepted.
 
-        super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw)
+        """
 
-    def __add__(self, other):
-        return LinkedLambdaElement(
-            other, parent_lambda=self, **self._default_kw
+        opts = self.opts + dict(
+            enable_tracking=enable_tracking,
+            track_closure_variables=track_closure_variables,
+            global_track_bound_values=self.opts.global_track_bound_values,
+            track_on=track_on,
+            track_bound_values=track_bound_values,
         )
 
-    def add_criteria(self, other, **kw):
-        if self._default_kw:
-            if kw:
-                default_kw = self._default_kw.copy()
-                default_kw.update(kw)
-                kw = default_kw
-            else:
-                kw = self._default_kw
-
-        return LinkedLambdaElement(other, parent_lambda=self, **kw)
+        return LinkedLambdaElement(other, parent_lambda=self, opts=opts)
 
     def _execute_on_connection(
         self, connection, multiparams, params, execution_options
@@ -461,14 +559,13 @@ class LinkedLambdaElement(StatementLambdaElement):
 
     role = None
 
-    def __init__(self, fn, parent_lambda, **kw):
-        self._default_kw = parent_lambda._default_kw
-
+    def __init__(self, fn, parent_lambda, opts):
+        self.opts = opts
         self.fn = fn
         self.parent_lambda = parent_lambda
 
         self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
-        self._retrieve_tracker_rec(fn, self, kw)
+        self._retrieve_tracker_rec(fn, self, opts)
         self._propagate_attrs = parent_lambda._propagate_attrs
 
     def _invoke_user_fn(self, fn, *arg):
@@ -497,20 +594,17 @@ class AnalyzedCode(object):
         )
         return analyzed
 
-    def __init__(
-        self,
-        fn,
-        lambda_element,
-        lambda_kw,
-        track_bound_values=True,
-        enable_tracking=True,
-        track_on=None,
-    ):
+    def __init__(self, fn, lambda_element, opts):
         closure = fn.__closure__
 
-        self.track_closure_variables = not track_on
+        self.track_bound_values = (
+            opts.track_bound_values and opts.global_track_bound_values
+        )
+        enable_tracking = opts.enable_tracking
+        track_on = opts.track_on
+        track_closure_variables = opts.track_closure_variables
 
-        self.track_bound_values = track_bound_values
+        self.track_closure_variables = track_closure_variables and not track_on
 
         # a list of callables generated from _bound_parameter_getter_*
         # functions.  Each of these uses a PyWrapper object to retrieve
@@ -533,7 +627,7 @@ class AnalyzedCode(object):
             if closure:
                 self._init_closure(fn)
 
-        self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw)
+        self._setup_additional_closure_trackers(fn, lambda_element, opts)
 
     def _init_track_on(self, track_on):
         self.closure_trackers.extend(
@@ -590,13 +684,11 @@ class AnalyzedCode(object):
                 if track_closure_variables:
                     closure_trackers.append(
                         self._cache_key_getter_closure_variable(
-                            closure_index, cell.cell_contents
+                            fn, fv, closure_index, cell.cell_contents
                         )
                     )
 
-    def _setup_additional_closure_trackers(
-        self, fn, lambda_element, lambda_kw
-    ):
+    def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
         # an additional step is to actually run the function, then
         # go through the PyWrapper objects that were set up to catch a bound
         # parameter.   then if they *didn't* make a param, oh they're another
@@ -607,7 +699,6 @@ class AnalyzedCode(object):
             self,
             lambda_element,
             None,
-            lambda_kw,
             fn,
         )
 
@@ -616,7 +707,7 @@ class AnalyzedCode(object):
         for pywrapper in analyzed_function.closure_pywrappers:
             if not pywrapper._sa__has_param:
                 closure_trackers.append(
-                    self._cache_key_getter_tracked_literal(pywrapper)
+                    self._cache_key_getter_tracked_literal(fn, pywrapper)
                 )
 
     @classmethod
@@ -625,7 +716,7 @@ class AnalyzedCode(object):
 
         if is_clause_element:
             while not isinstance(
-                element, (elements.ClauseElement, schema.SchemaItem)
+                element, (elements.ClauseElement, schema.SchemaItem, type)
             ):
                 try:
                     element = element.__clause_element__()
@@ -688,17 +779,25 @@ class AnalyzedCode(object):
         """
         if isinstance(elem, traversals.HasCacheKey):
 
-            def get(closure, kw, anon_map, bindparams):
-                return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams)
+            def get(closure, opts, anon_map, bindparams):
+                return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)
 
         else:
 
-            def get(closure, kw, anon_map, bindparams):
-                return kw["track_on"][idx]
+            def get(closure, opts, anon_map, bindparams):
+                return opts.track_on[idx]
 
         return get
 
-    def _cache_key_getter_closure_variable(self, idx, cell_contents):
+    def _cache_key_getter_closure_variable(
+        self,
+        fn,
+        variable_name,
+        idx,
+        cell_contents,
+        use_clause_element=False,
+        use_inspect=False,
+    ):
         """Return a getter that will extend a cache key with new entries
         from the ``__closure__`` collection of a particular lambda.
 
@@ -706,29 +805,90 @@ class AnalyzedCode(object):
 
         if isinstance(cell_contents, traversals.HasCacheKey):
 
-            def get(closure, kw, anon_map, bindparams):
-                return closure[idx].cell_contents._gen_cache_key(
-                    anon_map, bindparams
-                )
+            def get(closure, opts, anon_map, bindparams):
+
+                obj = closure[idx].cell_contents
+                if use_inspect:
+                    obj = inspection.inspect(obj)
+                elif use_clause_element:
+                    while hasattr(obj, "__clause_element__"):
+                        if not getattr(obj, "is_clause_element", False):
+                            obj = obj.__clause_element__()
+
+                return obj._gen_cache_key(anon_map, bindparams)
 
         elif isinstance(cell_contents, types.FunctionType):
 
-            def get(closure, kw, anon_map, bindparams):
+            def get(closure, opts, anon_map, bindparams):
                 return closure[idx].cell_contents.__code__
 
-        elif cell_contents.__hash__ is None:
-            # this covers dict, etc.
-            def get(closure, kw, anon_map, bindparams):
-                return ()
+        elif isinstance(cell_contents, collections_abc.Sequence):
+
+            def get(closure, opts, anon_map, bindparams):
+                contents = closure[idx].cell_contents
+
+                try:
+                    return tuple(
+                        elem._gen_cache_key(anon_map, bindparams)
+                        for elem in contents
+                    )
+                except AttributeError as ae:
+                    self._raise_for_uncacheable_closure_variable(
+                        variable_name, fn, from_=ae
+                    )
 
         else:
+            # if the object is a mapped class or aliased class, or some
+            # other object in the ORM realm of things like that, imitate
+            # the logic used in coercions.expect() to roll it down to the
+            # SQL element
+            element = cell_contents
+            is_clause_element = False
+            while hasattr(element, "__clause_element__"):
+                is_clause_element = True
+                if not getattr(element, "is_clause_element", False):
+                    element = element.__clause_element__()
+                else:
+                    break
 
-            def get(closure, kw, anon_map, bindparams):
-                return closure[idx].cell_contents
+            if not is_clause_element:
+                insp = inspection.inspect(element, raiseerr=False)
+                if insp is not None:
+                    return self._cache_key_getter_closure_variable(
+                        fn, variable_name, idx, insp, use_inspect=True
+                    )
+            else:
+                return self._cache_key_getter_closure_variable(
+                    fn, variable_name, idx, element, use_clause_element=True
+                )
+
+            self._raise_for_uncacheable_closure_variable(variable_name, fn)
 
         return get
 
-    def _cache_key_getter_tracked_literal(self, pytracker):
+    def _raise_for_uncacheable_closure_variable(
+        self, variable_name, fn, from_=None
+    ):
+        util.raise_(
+            exc.InvalidRequestError(
+                "Closure variable named '%s' inside of lambda callable %s "
+                "does not refer to a cachable SQL element, and also does not "
+                "appear to be serving as a SQL literal bound value based on "
+                "the default "
+                "SQL expression returned by the function.   This variable "
+                "needs to remain outside the scope of a SQL-generating lambda "
+                "so that a proper cache key may be generated from the "
+                "lambda's state.  Evaluate this variable outside of the "
+                "lambda, set track_on=[<elements>] to explicitly select "
+                "closure elements to track, or set "
+                "track_closure_variables=False to exclude "
+                "closure variables from being part of the cache key."
+                % (variable_name, fn.__code__),
+            ),
+            from_=from_,
+        )
+
+    def _cache_key_getter_tracked_literal(self, fn, pytracker):
         """Return a getter that will extend a cache key with new entries
         from the ``__closure__`` collection of a particular lambda.
 
@@ -741,33 +901,11 @@ class AnalyzedCode(object):
 
         elem = pytracker._sa__to_evaluate
         closure_index = pytracker._sa__closure_index
+        variable_name = pytracker._sa__name
 
-        if isinstance(elem, set):
-            raise exc.ArgumentError(
-                "Can't create a cache key for lambda closure variable "
-                '"%s" because it\'s a set.  try using a list'
-                % pytracker._sa__name
-            )
-
-        elif isinstance(elem, list):
-
-            def get(closure, kw, anon_map, bindparams):
-                return tuple(
-                    elem._gen_cache_key(anon_map, bindparams)
-                    for elem in closure[closure_index].cell_contents
-                )
-
-        elif elem.__hash__ is None:
-            # this covers dict, etc.
-            def get(closure, kw, anon_map, bindparams):
-                return ()
-
-        else:
-
-            def get(closure, kw, anon_map, bindparams):
-                return closure[closure_index].cell_contents
-
-        return get
+        return self._cache_key_getter_closure_variable(
+            fn, variable_name, closure_index, elem
+        )
 
 
 class AnalyzedFunction(object):
@@ -789,7 +927,6 @@ class AnalyzedFunction(object):
         analyzed_code,
         lambda_element,
         apply_propagate_attrs,
-        kw,
         fn,
     ):
         self.analyzed_code = analyzed_code
@@ -799,7 +936,7 @@ class AnalyzedFunction(object):
 
         self._instrument_and_run_function(lambda_element)
 
-        self._coerce_expression(lambda_element, apply_propagate_attrs, kw)
+        self._coerce_expression(lambda_element, apply_propagate_attrs)
 
     def _instrument_and_run_function(self, lambda_element):
         analyzed_code = self.analyzed_code
@@ -832,13 +969,19 @@ class AnalyzedFunction(object):
                 if closure_index is not None:
                     value = closure[closure_index].cell_contents
                     new_closure[name] = bind = PyWrapper(
-                        name, value, closure_index=closure_index
+                        fn,
+                        name,
+                        value,
+                        closure_index=closure_index,
+                        track_bound_values=(
+                            self.analyzed_code.track_bound_values
+                        ),
                     )
                     if track_closure_variables:
                         closure_pywrappers.append(bind)
                 else:
                     value = fn.__globals__[name]
-                    new_globals[name] = bind = PyWrapper(name, value)
+                    new_globals[name] = bind = PyWrapper(fn, name, value)
 
             # rewrite the original fn.   things that look like they will
             # become bound parameters are wrapped in a PyWrapper.
@@ -863,7 +1006,7 @@ class AnalyzedFunction(object):
             # variable.
             self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
 
-    def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw):
+    def _coerce_expression(self, lambda_element, apply_propagate_attrs):
         """Run the tracker-generated expression through coercion rules.
 
         After the user-defined lambda has been invoked to produce a statement
@@ -882,7 +1025,6 @@ class AnalyzedFunction(object):
                         lambda_element.role,
                         sub_expr,
                         apply_propagate_attrs=apply_propagate_attrs,
-                        **kw
                     )
                     for sub_expr in expr
                 ]
@@ -892,7 +1034,6 @@ class AnalyzedFunction(object):
                     lambda_element.role,
                     expr,
                     apply_propagate_attrs=apply_propagate_attrs,
-                    **kw
                 )
                 self.is_sequence = False
         else:
@@ -956,7 +1097,16 @@ class PyWrapper(ColumnOperators):
 
     """
 
-    def __init__(self, name, to_evaluate, closure_index=None, getter=None):
+    def __init__(
+        self,
+        fn,
+        name,
+        to_evaluate,
+        closure_index=None,
+        getter=None,
+        track_bound_values=True,
+    ):
+        self.fn = fn
         self._name = name
         self._to_evaluate = to_evaluate
         self._param = None
@@ -964,28 +1114,35 @@ class PyWrapper(ColumnOperators):
         self._bind_paths = {}
         self._getter = getter
         self._closure_index = closure_index
+        self.track_bound_values = track_bound_values
 
     def __call__(self, *arg, **kw):
         elem = object.__getattribute__(self, "_to_evaluate")
         value = elem(*arg, **kw)
-        if coercions._deep_is_literal(value) and not isinstance(
-            # TODO: coverage where an ORM option or similar is here
-            value,
-            traversals.HasCacheKey,
+        if (
+            self._sa_track_bound_values
+            and coercions._deep_is_literal(value)
+            and not isinstance(
+                # TODO: coverage where an ORM option or similar is here
+                value,
+                traversals.HasCacheKey,
+            )
         ):
-            # TODO: we can instead scan the arguments and make sure they
-            # are all Python literals
-
-            # TODO: coverage
             name = object.__getattribute__(self, "_name")
             raise exc.InvalidRequestError(
                 "Can't invoke Python callable %s() inside of lambda "
-                "expression argument; lambda cache keys should not call "
-                "regular functions since the caching "
-                "system does not track the values of the arguments passed "
-                "to the functions.  Call the function outside of the lambda "
-                "and assign to a local variable that is used in the lambda."
-                % (name)
+                "expression argument at %s; lambda SQL constructs should "
+                "not invoke functions from closure variables to produce "
+                "literal values since the "
+                "lambda SQL system normally extracts bound values without "
+                "actually "
+                "invoking the lambda or any functions within it.  Call the "
+                "function outside of the "
+                "lambda and assign to a local variable that is used in the "
+                "lambda as a closure variable, or set "
+                "track_bound_values=False if the return value of this "
+                "function is used in some other way other than a SQL bound "
+                "value." % (name, self._sa_fn.__code__)
             )
         else:
             return value
@@ -1018,6 +1175,14 @@ class PyWrapper(ColumnOperators):
             param.type = type_api._resolve_value_to_type(to_evaluate)
         return param._with_value(to_evaluate, maintain_key=True)
 
+    def __bool__(self):  # Python 3 truth-value hook
+        to_evaluate = object.__getattribute__(self, "_to_evaluate")
+        return bool(to_evaluate)  # truthiness of the wrapped value
+
+    def __nonzero__(self):  # Python 2 truth-value hook
+        to_evaluate = object.__getattribute__(self, "_to_evaluate")
+        return bool(to_evaluate)  # truthiness of the wrapped value
+
     def __getattribute__(self, key):
         if key.startswith("_sa_"):
             return object.__getattribute__(self, key[4:])
@@ -1026,6 +1191,7 @@ class PyWrapper(ColumnOperators):
             "operate",
             "reverse_operate",
             "__class__",
+            "__dict__",
         ):
             return object.__getattribute__(self, key)
 
@@ -1064,8 +1230,10 @@ class PyWrapper(ColumnOperators):
         elem = object.__getattribute__(self, "_to_evaluate")
         value = getter(elem)
 
-        if coercions._deep_is_literal(value):
-            wrapper = PyWrapper(key, value, getter=getter)
+        rolled_down_value = AnalyzedCode._roll_down_to_literal(value)
+
+        if coercions._deep_is_literal(rolled_down_value):
+            wrapper = PyWrapper(self._sa_fn, key, value, getter=getter)
             bind_paths[bind_path_key] = wrapper
             return wrapper
         else:
index f87f61074591ac57f7e92bdc8ac1b676b92e34a3..5476729612bcfff15159a7cf37918bc2015c2f2d 100644 (file)
@@ -703,7 +703,7 @@ class SelectInEagerLoadTest(NoCache, fixtures.MappedTest):
         # this is because the test was previously making use of the same
         # loader option state repeatedly without rebuilding it.
 
-        @profiling.function_call_count()
+        @profiling.function_call_count(warmup=1)
         def go():
             for i in range(100):
                 obj = q.all()
index d4fae7f6feaa307c541dc48830361d869f88fa61..b190f46d63c9f0e5244a0feced514caa1f2019eb 100644 (file)
@@ -8,6 +8,7 @@ from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import update
 from sqlalchemy.future import select
+from sqlalchemy.orm import aliased
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
@@ -132,19 +133,62 @@ class LambdaTest(QueryTest, AssertsCompiledSQL):
             fn = random.choice([go1, go2])
             fn()
 
-    def test_entity_round_trip(self, plain_fixture):
+    @testing.combinations(
+        (True, True),
+        (True, False),
+        (False, False),
+        argnames="use_aliased,use_indirect_access",
+    )
+    def test_entity_round_trip(
+        self, plain_fixture, use_aliased, use_indirect_access
+    ):
         User, Address = plain_fixture
 
         s = Session(testing.db, future=True)
 
-        def query(names):
-            stmt = lambda_stmt(
-                lambda: select(User)
-                .where(User.name.in_(names))
-                .options(selectinload(User.addresses))
-            ) + (lambda s: s.order_by(User.id))
+        if use_aliased:
+            if use_indirect_access:
 
-            return s.execute(stmt)
+                def query(names):
+                    class Foo(object):
+                        def __init__(self):
+                            self.u1 = aliased(User)
+
+                    f1 = Foo()
+
+                    stmt = lambda_stmt(
+                        lambda: select(f1.u1)
+                        .where(f1.u1.name.in_(names))
+                        .options(selectinload(f1.u1.addresses)),
+                        track_on=[f1.u1],
+                    ).add_criteria(
+                        lambda s: s.order_by(f1.u1.id), track_on=[f1.u1]
+                    )
+
+                    return s.execute(stmt)
+
+            else:
+
+                def query(names):
+                    u1 = aliased(User)
+                    stmt = lambda_stmt(
+                        lambda: select(u1)
+                        .where(u1.name.in_(names))
+                        .options(selectinload(u1.addresses))
+                    ) + (lambda s: s.order_by(u1.id))
+
+                    return s.execute(stmt)
+
+        else:
+
+            def query(names):
+                stmt = lambda_stmt(
+                    lambda: select(User)
+                    .where(User.name.in_(names))
+                    .options(selectinload(User.addresses))
+                ) + (lambda s: s.order_by(User.id))
+
+                return s.execute(stmt)
 
         def go1():
             r1 = query(["ed"])
index 87589d3be14a0779f70f8a487955bc8251fe0631..2d01410234b5b989ce26abe7483f7a9573ae55f5 100644 (file)
@@ -16,6 +16,7 @@ from sqlalchemy.orm import defer
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import lazyload
 from sqlalchemy.orm import mapper
+from sqlalchemy.orm import registry
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
@@ -84,6 +85,32 @@ class _Fixtures(_fixtures.FixtureTest):
         )
         return HasFoob, UserWFoob
 
+    @testing.fixture
+    def multi_mixin_fixture(self):  # returns (HasFoob, Order, Item)
+        orders, items = self.tables.orders, self.tables.items
+        order_items = self.tables.order_items
+        # two-level mixin chain; neither mixin is passed to the registry
+        class HasFoob(object):
+            description = Column(String)
+
+        class HasBat(HasFoob):
+            some_nothing = Column(Integer)
+        # Order uses HasFoob directly; Item reaches it through HasBat
+        class Order(HasFoob, self.Comparable):
+            pass
+
+        class Item(HasBat, self.Comparable):
+            pass
+
+        base = registry()
+        base.map_imperatively(
+            Order,
+            orders,
+            properties={"items": relationship("Item", secondary=order_items)},
+        )
+        base.map_imperatively(Item, items)
+        return HasFoob, Order, Item
+
 
 class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
     """
@@ -598,6 +625,66 @@ class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
 
             eq_(s.execute(stmt).scalars().all(), [UserWFoob(name=name)])
 
+    def test_unnamed_param_dont_fail(self, multi_mixin_fixture):
+        HasFoob, Order, Item = multi_mixin_fixture
+        # the lambda compares against an inline literal; ``value`` is unused
+        def go(stmt, value):
+            return stmt.options(
+                with_loader_criteria(
+                    HasFoob,
+                    lambda cls: cls.description == "order 3",
+                    include_aliases=True,
+                )
+            )
+        # build and execute the statement ten times on a single Session
+        with Session(testing.db) as sess:
+            for i in range(10):
+                name = random.choice(["order 1", "order 3", "order 5"])
+
+                statement = select(Order)
+                stmt = go(statement, name)
+                # always the "order 3" row, regardless of ``name``
+                eq_(
+                    sess.execute(stmt).scalars().all(),
+                    [Order(description="order 3")],
+                )
+
+    def test_caching_and_binds_lambda_more_mixins(self, multi_mixin_fixture):
+        # Regression test for issue #5766: the non-mapped mixin HasBat sits
+        # between HasFoob and Item, and the criteria must still reach Item.
+        HasFoob, Order, Item = multi_mixin_fixture
+        # ``value`` is a closure variable of the criteria lambda
+        def go(stmt, value):
+            return stmt.options(
+                with_loader_criteria(
+                    HasFoob,
+                    lambda cls: cls.description == value,
+                    include_aliases=True,
+                )
+            )
+
+        with Session(testing.db) as sess:
+            for i in range(10):
+                name = random.choice(["order 1", "order 3", "order 5"])
+
+                statement = select(Order)
+                stmt = go(statement, name)
+                # the row returned must match the name chosen this time
+                eq_(
+                    sess.execute(stmt).scalars().all(),
+                    [Order(description=name)],
+                )
+                # repeat for Item, which inherits HasFoob via HasBat
+                name = random.choice(["item 1", "item 3", "item 5"])
+
+                statement = select(Item)
+                stmt = go(statement, name)
+
+                eq_(
+                    sess.execute(stmt).scalars().all(),
+                    [Item(description=name)],
+                )
+
     def test_never_for_refresh(self, user_address_fixture):
         User, Address = user_address_fixture
 
index c75942564720564aff1320b8f3df1a85630fb12b..4b486f87174a4b6f423056d4d8798c2afe4a189b 100644 (file)
@@ -619,6 +619,13 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
             "subqueryload": subqueryload,
         }
 
+        # NOTE: make sure this test continues to run many different
+        # combinations for the *same* mappers above; that is, don't tear the
+        # mappers down and build them up for every "config".  This allows
+        # testing of the LRUCache that's associated with LazyLoader
+        # and SelectInLoader and how they interact with the lambda query
+        # API, which stores AnalyzedFunction objects in this cache.
+
         for o, i, k, count in configs:
             options = []
             if o in callables:
@@ -629,7 +636,6 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
                 options.append(
                     callables[k](User.orders, Order.items, Item.keywords)
                 )
-
             self._do_query_tests(options, count)
 
     def _do_mapper_test(self, configs):
@@ -716,6 +722,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
             )
 
         self.assert_sql_count(testing.db, go, count)
+        return
 
         eq_(
             sess.query(User)
index 71788209aaca1c19a66b5ca5f2f82c2545f118b4..d96f802f08b9c26b0a102837edf492b165618380 100644 (file)
@@ -370,8 +370,6 @@ test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.8_s
 
 # TEST: test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results
 
-test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 233595
-test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 251018
 test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_cextensions 246134
 test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_nocextensions 264752
 
index 6fa961e4d4d0abffa7dd27cc4fb44baef0f57d00..70281d4e89d46f23786c956e6207eb8703194383 100644 (file)
@@ -57,6 +57,7 @@ from sqlalchemy.sql.functions import GenericFunction
 from sqlalchemy.sql.functions import ReturnTypeFromArgs
 from sqlalchemy.sql.lambdas import lambda_stmt
 from sqlalchemy.sql.lambdas import LambdaElement
+from sqlalchemy.sql.lambdas import LambdaOptions
 from sqlalchemy.sql.selectable import _OffsetLimitParam
 from sqlalchemy.sql.selectable import AliasedReturnsRows
 from sqlalchemy.sql.selectable import FromGrouping
@@ -859,7 +860,9 @@ class CoreFixtures(object):
             d = {"g": random.randint(40, 45)}
 
             return LambdaElement(
-                lambda: and_(table_a.c.b == d["g"]), roles.WhereHavingRole
+                lambda: and_(table_a.c.b == d["g"]),
+                roles.WhereHavingRole,
+                opts=LambdaOptions(track_closure_variables=False),
             )
 
         def seven():
index a70dc051165a8217525760dc1c4a13ea795960db..e8e4a8d2a77d99f2e075e6c0bcca00ba7c1971be 100644 (file)
@@ -11,6 +11,8 @@ from sqlalchemy.sql import column
 from sqlalchemy.sql import join
 from sqlalchemy.sql import lambda_stmt
 from sqlalchemy.sql import lambdas
+from sqlalchemy.sql import literal
+from sqlalchemy.sql import null
 from sqlalchemy.sql import roles
 from sqlalchemy.sql import select
 from sqlalchemy.sql import table
@@ -27,7 +29,7 @@ from sqlalchemy.types import Integer
 from sqlalchemy.types import String
 
 
-class DeferredLambdaTest(
+class LambdaElementTest(
     fixtures.TestBase, testing.AssertsExecutionResults, AssertsCompiledSQL
 ):
     __dialect__ = "default"
@@ -274,6 +276,75 @@ class DeferredLambdaTest(
             checkparams={"x_1": 10, "x_2": 15},
         )
 
+    def test_conditional_must_be_tracked(self):
+        tab = table("foo", column("id"), column("col"))
+        # a closure bool picks the WHERE branch and is not given in track_on
+        def run_my_statement(parameter, add_criteria=False):
+            stmt = lambda_stmt(lambda: select(tab))
+
+            stmt = stmt.add_criteria(
+                lambda s: s.where(tab.c.col > parameter)
+                if add_criteria
+                else s.where(tab.c.col == parameter),
+            )
+
+            stmt += lambda s: s.order_by(tab.c.id)
+
+            return stmt
+        # building the statement is expected to raise, naming the variable
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Closure variable named 'add_criteria' inside of lambda callable",
+            run_my_statement,
+            5,
+            False,
+        )
+
+    def test_boolean_conditionals(self):
+        # the flag that picks the WHERE branch is listed in track_on
+        tab = table("foo", column("id"), column("col"))
+
+        def run_my_statement(parameter, add_criteria=False):
+            stmt = lambda_stmt(lambda: select(tab))
+
+            stmt = stmt.add_criteria(
+                lambda s: s.where(tab.c.col > parameter)
+                if add_criteria
+                else s.where(tab.c.col == parameter),
+                track_on=[add_criteria],
+            )
+
+            stmt += lambda s: s.order_by(tab.c.id)
+
+            return stmt
+
+        c1 = run_my_statement(5, False)
+        c2 = run_my_statement(10, True)
+        c3 = run_my_statement(18, False)
+
+        ck1 = c1._generate_cache_key()
+        ck2 = c2._generate_cache_key()
+        ck3 = c3._generate_cache_key()
+        # equal flags give equal keys; a different flag gives a different key
+        eq_(ck1[0], ck3[0])
+        ne_(ck1[0], ck2[0])
+        # each statement renders the branch selected by its own flag
+        self.assert_compile(
+            c1,
+            "SELECT foo.id, foo.col FROM foo WHERE "
+            "foo.col = :parameter_1 ORDER BY foo.id",
+        )
+        self.assert_compile(
+            c2,
+            "SELECT foo.id, foo.col FROM foo "
+            "WHERE foo.col > :parameter_1 ORDER BY foo.id",
+        )
+        self.assert_compile(
+            c3,
+            "SELECT foo.id, foo.col FROM foo WHERE "
+            "foo.col = :parameter_1 ORDER BY foo.id",
+        )
+
     def test_stmt_lambda_plain_customtrack(self):
         c2 = column("y")
 
@@ -487,10 +558,11 @@ class DeferredLambdaTest(
 
         self.assert_compile(s1, "SELECT x WHERE :x")
 
-    def test_stmt_lambda_w_additional_hashable_variants(self):
-        # note a Python 2 old style class would fail here because it
-        # isn't hashable.   right now we do a hard check for __hash__ which
-        # will raise if the attr isn't present
+    def test_reject_plain_object(self):
+        # with #5765 we move to no longer allow closure variables that
+        # refer to unknown types of objects inside the lambda.  these have
+        # to be resolved outside of the lambda because we otherwise can't
+        # be sure they can be safely used as cache keys.
         class Thing(object):
             def __init__(self, col_expr):
                 self.col_expr = col_expr
@@ -501,6 +573,83 @@ class DeferredLambdaTest(
 
             return stmt
 
+        c1 = Thing(column("x"))
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Closure variable named 'thing' inside of lambda callable",
+            go,
+            c1,
+            5,
+        )
+
+    def test_plain_object_ok_w_tracking_disabled(self):
+        # Since #5765, a closure variable holding an arbitrary object is
+        # rejected by default.  Passing track_closure_variables=False opts
+        # out of that check: the statements build, but ``thing`` is not
+        # part of the cache key, which gives the wrong SQL asserted below.
+        class Thing(object):
+            def __init__(self, col_expr):
+                self.col_expr = col_expr
+
+        def go(thing, q):
+            stmt = lambdas.lambda_stmt(
+                lambda: select(thing.col_expr), track_closure_variables=False
+            )
+            stmt = stmt.add_criteria(
+                lambda stmt: stmt.where(thing.col_expr == q),
+                track_closure_variables=False,
+            )
+
+            return stmt
+
+        c1 = Thing(column("x"))
+        c2 = Thing(column("y"))
+
+        s1 = go(c1, 5)
+        s2 = go(c2, 10)
+        s3 = go(c1, 8)
+        s4 = go(c2, 12)
+
+        self.assert_compile(
+            s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5}
+        )
+        # wrong but expected: s2 was built from column "y", yet renders "x"
+        self.assert_compile(
+            s2, "SELECT x WHERE x = :q_1", checkparams={"q_1": 10}
+        )
+        self.assert_compile(
+            s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8}
+        )
+        # likewise: s4 was built from column "y"
+        self.assert_compile(
+            s4, "SELECT x WHERE x = :q_1", checkparams={"q_1": 12}
+        )
+
+        s1key = s1._generate_cache_key()
+        s2key = s2._generate_cache_key()
+        s3key = s3._generate_cache_key()
+        s4key = s4._generate_cache_key()
+
+        # all four statements share a single cache key
+        eq_(s1key[0], s3key[0])
+        eq_(s2key[0], s4key[0])
+        eq_(s1key[0], s2key[0])
+
+    def test_plain_object_used_outside_lambda(self):
+        # test the above 'test_reject_plain_object' with the expected
+        # workaround
+
+        class Thing(object):
+            def __init__(self, col_expr):
+                self.col_expr = col_expr
+
+        def go(thing, q):
+            col_expr = thing.col_expr
+            stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+            stmt += lambda stmt: stmt.where(col_expr == q)
+
+            return stmt
+
         c1 = Thing(column("x"))
         c2 = Thing(column("y"))
 
@@ -538,13 +687,92 @@ class DeferredLambdaTest(
         opts = {column("x"), column("y")}
 
         assert_raises_message(
-            exc.ArgumentError,
-            'Can\'t create a cache key for lambda closure variable "opts" '
-            "because it's a set.  try using a list",
+            exc.InvalidRequestError,
+            "Closure variable named 'opts' inside of lambda callable ",
             stmt.__add__,
             lambda stmt: stmt.options(*opts),
         )
 
+    def test_detect_embedded_callables_one(self):
+        t1 = table("t1", column("q"))
+        # foo() returns a plain Python value and is called inside the lambda
+        x = 1
+
+        def go():
+            def foo():
+                return x
+
+            stmt = select(t1).where(lambda: t1.c.q == foo())
+            return stmt
+        # with default options, building the statement must raise
+        assert_raises_message(
+            exc.InvalidRequestError,
+            r"Can't invoke Python callable foo\(\) inside of lambda "
+            "expression ",
+            go,
+        )
+
+    def test_detect_embedded_callables_two(self):
+        t1 = table("t1", column("q"), column("y"))
+        # here foo() returns a column expression rather than a literal
+        def go():
+            def foo():
+                return t1.c.y
+
+            stmt = select(t1).where(lambda: t1.c.q == foo())
+            return stmt
+        # no error: the returned column is rendered directly in the SQL
+        self.assert_compile(
+            go(), "SELECT t1.q, t1.y FROM t1 WHERE t1.q = t1.y"
+        )
+
+    def test_detect_embedded_callables_three(self):
+        t1 = table("t1", column("q"), column("y"))
+
+        def go():
+            # getattr() is a Python builtin; calling it inside the lambda
+            # is expected to compile without the embedded-callable error,
+            # and the column it returns is rendered directly.
+            stmt = select(t1).where(lambda: t1.c.q == getattr(t1.c, "y"))
+            return stmt
+
+        self.assert_compile(
+            go(), "SELECT t1.q, t1.y FROM t1 WHERE t1.q = t1.y"
+        )
+
+    def test_detect_embedded_callables_four(self):
+        t1 = table("t1", column("q"))
+        # foo() returns a literal, but bound value tracking is switched off
+        x = 1
+
+        def go():
+            def foo():
+                return x
+
+            stmt = select(t1).where(
+                lambdas.LambdaElement(
+                    lambda: t1.c.q == foo(),
+                    roles.WhereHavingRole,
+                    lambdas.LambdaOptions(track_bound_values=False),
+                )
+            )
+            return stmt
+        # allowed with track_bound_values=False; the parameter is x's value
+        self.assert_compile(
+            go(),
+            "SELECT t1.q FROM t1 WHERE t1.q = :q_1",
+            checkparams={"q_1": 1},
+        )
+
+        # bound values are untracked, so the new x is not picked up below
+        x = 2
+
+        self.assert_compile(
+            go(),
+            "SELECT t1.q FROM t1 WHERE t1.q = :q_1",
+            checkparams={"q_1": 1},
+        )
+
     def test_stmt_lambda_w_list_of_opts(self):
         def go(opts):
             stmt = lambdas.lambda_stmt(lambda: select(column("x")))
@@ -755,6 +983,23 @@ class DeferredLambdaTest(
             },
         )
 
+    def test_in_columnelement(self):
+        # issue #5768: a closure list made of literal() elements inside in_()
+
+        def go():
+            v = [literal("a"), literal("b")]
+            expr1 = select(1).where(lambda: column("q").in_(v))
+            return expr1
+
+        self.assert_compile(go(), "SELECT 1 WHERE q IN (:param_1, :param_2)")
+        # a second build, this time also checking the bound values
+        self.assert_compile(
+            go(),
+            "SELECT 1 WHERE q IN (:param_1, :param_2)",
+            render_postcompile=True,
+            checkparams={"param_1": "a", "param_2": "b"},
+        )
+
     def test_select_columns_clause(self):
         t1 = table("t1", column("q"), column("p"))
 
@@ -854,14 +1099,28 @@ class DeferredLambdaTest(
             expr,
         )
 
-    def test_dict_literal_keys(self, user_address_fixture):
+    def test_reject_dict_literal_keys(self, user_address_fixture):
         users, addresses = user_address_fixture
 
         names = {"x": "some name"}
         lmb = lambda: users.c.name == names["x"]  # noqa
 
-        expr = coercions.expect(roles.WhereHavingRole, lmb)
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Closure variable named 'names' inside of lambda callable",
+            coercions.expect,
+            roles.WhereHavingRole,
+            lmb,
+        )
+
+    def test_dict_literal_keys_proper_use(self, user_address_fixture):
+        users, addresses = user_address_fixture
 
+        names = {"x": "some name"}
+        x = names["x"]
+        lmb = lambda: users.c.name == x  # noqa
+
+        expr = coercions.expect(roles.WhereHavingRole, lmb)
         self.assert_compile(
             expr,
             "users.name = :x_1",
@@ -1158,7 +1417,7 @@ class DeferredLambdaTest(
             ),
         )
 
-    def test_cache_key_thing(self):
+    def test_cache_key_bindparam_matches(self):
         t1 = table("t1", column("q"), column("p"))
 
         def go(x):
@@ -1169,3 +1428,294 @@ class DeferredLambdaTest(
 
         is_(expr1._generate_cache_key().bindparams[0], expr1._resolved.right)
         is_(expr2._generate_cache_key().bindparams[0], expr2._resolved.right)
+
+    def test_cache_key_instance_variable_issue_incorrect(self):
+        t1 = table("t1", column("q"), column("p"))
+
+        class Foo(object):
+            def __init__(self, value):
+                self.value = value
+        # the lambda's closure variable is "foo" itself, a plain Foo object
+        def go(foo):
+            return coercions.expect(
+                roles.WhereHavingRole, lambda: t1.c.q == foo.value
+            )
+        # expected: InvalidRequestError naming closure variable 'foo'
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Closure variable named 'foo' inside of lambda callable",
+            go,
+            Foo(5),
+        )
+
+    def test_cache_key_instance_variable_issue_correct_one(self):
+        t1 = table("t1", column("q"), column("p"))
+
+        class Foo(object):
+            def __init__(self, value):
+                self.value = value
+        # foo.value is copied to a local first; the lambda closes over that
+        def go(foo):
+            value = foo.value
+            return coercions.expect(
+                roles.WhereHavingRole, lambda: t1.c.q == value
+            )
+
+        expr1 = go(Foo(5))
+        expr2 = go(Foo(10))
+        # Foo(5) and Foo(10) are expected to yield equal cache keys
+        c1 = expr1._generate_cache_key()
+        c2 = expr2._generate_cache_key()
+        eq_(c1, c2)
+
+    def test_cache_key_instance_variable_issue_correct_two(self):
+        t1 = table("t1", column("q"), column("p"))
+
+        class Foo(object):
+            def __init__(self, value):
+                self.value = value
+        # lambda still reads foo.value, but track_on=[self] is passed here
+        def go(foo):
+            return coercions.expect(
+                roles.WhereHavingRole,
+                lambda: t1.c.q == foo.value,
+                track_on=[self],
+            )
+
+        expr1 = go(Foo(5))
+        expr2 = go(Foo(10))
+        # Foo(5) and Foo(10) are expected to yield equal cache keys
+        c1 = expr1._generate_cache_key()
+        c2 = expr2._generate_cache_key()
+        eq_(c1, c2)
+
+    def test_insert_statement(self, user_address_fixture):
+        users, addresses = user_address_fixture
+        # INSERT built with lambda_stmt; id_/name are closure variables
+        def ins(id_, name):
+            stmt = lambda_stmt(lambda: users.insert())
+            stmt += lambda s: s.values(id=id_, name=name)
+            return stmt
+
+        with testing.db.begin() as conn:
+            conn.execute(ins(12, "foo"))
+            # read the row back; it should hold the values passed to ins()
+            eq_(
+                conn.execute(select(users).where(users.c.id == 12)).first(),
+                (12, "foo"),
+            )
+
+    def test_update_statement(self, user_address_fixture):
+        users, addresses = user_address_fixture
+        # UPDATE built from lambdas: one for values(), one for where()
+        def upd(id_, newname):
+            stmt = lambda_stmt(lambda: users.update())
+            stmt += lambda s: s.values(name=newname)
+            stmt += lambda s: s.where(users.c.id == id_)
+            return stmt
+
+        with testing.db.begin() as conn:
+            conn.execute(users.insert().values(id=7, name="bar"))
+            conn.execute(upd(7, "foo"))
+            # read the row back; name should have changed from "bar" to "foo"
+            eq_(
+                conn.execute(select(users).where(users.c.id == 7)).first(),
+                (7, "foo"),
+            )
+
+
+class DeferredLambdaElementTest(
+    fixtures.TestBase, testing.AssertsExecutionResults, AssertsCompiledSQL
+):
+    __dialect__ = "default"
+
+    @testing.fails("wontfix issue #5767")
+    def test_detect_change_in_binds_no_tracking(self):
+        t1 = table("t1", column("q"), column("p"))
+        t2 = table("t2", column("q"), column("p"))
+
+        vv = [1, 2, 3]
+        # lambda produces "t2.q IN vv" for t2, otherwise "NULL", based on
+        # the argument.  will not produce a consistent cache key
+        elem = lambdas.DeferredLambdaElement(
+            lambda tab: tab.c.q.in_(vv) if tab.name == "t2" else null(),
+            roles.WhereHavingRole,
+            lambda_args=(t1,),
+            opts=lambdas.LambdaOptions(track_closure_variables=False),
+        )
+        # lambda_args is (t1,), so elem.expr is the NULL branch
+        self.assert_compile(elem.expr, "NULL")
+        # desired: resolving with t2 raises, naming "vv"; see testing.fails
+        assert_raises_message(
+            exc.InvalidRequestError,
+            r"Lambda callable at %s produced "
+            "a different set of bound parameters "
+            "than its original run: vv" % (elem.fn.__code__),
+            elem._resolve_with_args,
+            t2,
+        )
+
+    def test_detect_change_in_binds_tracking_positive(self):
+        t1 = table("t1", column("q"), column("p"))
+
+        vv = [1, 2, 3]
+
+        # lambda produces "t2.q IN vv" for t2, otherwise None.  constructing
+        # it with the LambdaOptions class as opts must reject closure "vv"
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Closure variable named 'vv' inside of lambda callable",
+            lambdas.DeferredLambdaElement,
+            lambda tab: tab.c.q.in_(vv) if tab.name == "t2" else None,
+            roles.WhereHavingRole,
+            opts=lambdas.LambdaOptions,
+            lambda_args=(t1,),
+        )
+
+    @testing.fails("wontfix issue #5767")
+    def test_detect_change_in_binds_tracking_negative(self):
+        t1 = table("t1", column("q"), column("p"))
+        t2 = table("t2", column("q"), column("p"))
+
+        vv = [1, 2, 3]
+        qq = [3, 4, 5]
+
+        # lambda produces either "t1 IN vv" or "t2 IN qq" based on the
+        # argument.  will not produce a consistent cache key
+        elem = lambdas.DeferredLambdaElement(
+            lambda tab: tab.c.q.in_(vv)
+            if tab.name == "t1"
+            else tab.c.q.in_(qq),
+            roles.WhereHavingRole,
+            lambda_args=(t1,),
+            opts=lambdas.LambdaOptions(track_closure_variables=False),
+        )
+        # lambda_args is (t1,), so elem.expr is the "vv" branch
+        self.assert_compile(elem.expr, "t1.q IN ([POSTCOMPILE_vv_1])")
+        # desired: resolving with t2 raises, naming "qq"; see testing.fails
+        assert_raises_message(
+            exc.InvalidRequestError,
+            r"Lambda callable at %s produced "
+            "a different set of bound parameters "
+            "than its original run: qq" % (elem.fn.__code__),
+            elem._resolve_with_args,
+            t2,
+        )
+
+    def _fixture_one(self, t1):
+        vv = [1, 2, 3]
+
+        def go():
+            elem = lambdas.DeferredLambdaElement(
+                lambda tab: tab.c.q.in_(vv),  # closure list
+                roles.WhereHavingRole,
+                lambda_args=(t1,),
+                opts=lambdas.LambdaOptions,
+            )
+            return elem
+
+        return go
+
+    def _fixture_two(self, t1):
+        def go():
+            elem = lambdas.DeferredLambdaElement(
+                lambda tab: tab.c.q == "x",  # inline string literal, ==
+                roles.WhereHavingRole,
+                lambda_args=(t1,),
+                opts=lambdas.LambdaOptions,
+            )
+            return elem
+
+        return go
+
+    def _fixture_three(self, t1):
+        def go():
+            elem = lambdas.DeferredLambdaElement(
+                lambda tab: tab.c.q != "x",  # inline string literal, !=
+                roles.WhereHavingRole,
+                lambda_args=(t1,),
+                opts=lambdas.LambdaOptions,
+            )
+            return elem
+
+        return go
+
+    def _fixture_four(self, t1):
+        def go():
+            elem = lambdas.DeferredLambdaElement(
+                lambda tab: tab.c.q.in_([1, 2, 3]),  # inline list literal
+                roles.WhereHavingRole,
+                lambda_args=(t1,),
+                opts=lambdas.LambdaOptions,
+            )
+            return elem
+
+        return go
+
+    def _fixture_five(self, t1):
+        def go():
+            x = "x"
+            elem = lambdas.DeferredLambdaElement(
+                lambda tab: tab.c.q == x,  # closure variable, ==
+                roles.WhereHavingRole,
+                lambda_args=(t1,),
+                opts=lambdas.LambdaOptions,
+            )
+            return elem
+
+        return go
+
+    def _fixture_six(self, t1):
+        def go():
+            x = "x"
+            elem = lambdas.DeferredLambdaElement(
+                lambda tab: tab.c.q != x,  # closure variable, !=
+                roles.WhereHavingRole,
+                lambda_args=(t1,),
+                opts=lambdas.LambdaOptions,
+            )
+            return elem
+
+        return go
+
+    @testing.combinations(
+        ("_fixture_one",),
+        ("_fixture_two",),
+        ("_fixture_three",),
+        ("_fixture_four",),
+        ("_fixture_five",),
+        ("_fixture_six",),
+    )
+    def test_cache_key_many_different_args(self, fixture_name):
+        t1 = table("t1", column("q"), column("p"))
+        t2 = table("t2", column("q"), column("p"))
+        t3 = table("t3", column("q"), column("p"))
+        # fixture_name selects one of the _fixture_* methods on this class
+        go = getattr(self, fixture_name)(t1)
+        # two unresolved elements built the same way must share a key
+        g1 = go()
+        g2 = go()
+
+        g1key = g1._generate_cache_key()
+        g2key = g2._generate_cache_key()
+        eq_(g1key[0], g2key[0])
+        # resolve freshly built elements against three different tables
+        e1 = go()._resolve_with_args(t1)
+        e2 = go()._resolve_with_args(t2)
+        e3 = go()._resolve_with_args(t3)
+
+        e1key = e1._generate_cache_key()
+        e2key = e2._generate_cache_key()
+        e3key = e3._generate_cache_key()
+        # resolve again with t1 and t3 to compare against the first pass
+        e12 = go()._resolve_with_args(t1)
+        e32 = go()._resolve_with_args(t3)
+
+        e12key = e12._generate_cache_key()
+        e32key = e32._generate_cache_key()
+        # keys differ across tables but match when the table is the same
+        ne_(e1key[0], e2key[0])
+        ne_(e2key[0], e3key[0])
+
+        eq_(e12key[0], e1key[0])
+        eq_(e32key[0], e3key[0])