From: Mike Bayer
Date: Sat, 12 Dec 2020 23:56:58 +0000 (-0500)
Subject: Major revisals to lambdas
X-Git-Tag: rel_1_4_0b2~95
X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=77c9534dcaf3723f7b2baf42442eda3e1d8c3332;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

Major revisals to lambdas

1. Improve coercions._deep_is_literal to check sequences for clause elements, thus allowing a phrase like lambda: col.in_([literal("x"), literal("y")]) to be handled

2. Revise closure variable caching completely. All variables entering must be part of a closure cache key or rejected. Only objects that can be resolved to HasCacheKey or FunctionType are accepted; all other types are rejected. This adds a high degree of strictness to lambdas and will make them a little more awkward to use in some cases, but prevents several classes of critical issues:

   a. previously, a lambda that had an expression derived from some kind of state, like "self.x", or "execution_context.session.foo" would produce a closure cache key from "self" or "execution_context", objects that can very well be per-execution and would therefore cause AnalyzedFunction objects to overflow. (memory won't leak as it looks like an LRUCache is already used for these)

   b. a lambda, such as one used within DeferredLambdaElement, that produces different SQL expressions based on the arguments (which is in fact what it's supposed to do) would, through the use of conditionals, produce different bound parameter combinations, leading to literal parameters not being tracked properly. These are now rejected as uncacheable whereas previously they would again be part of the closure cache key, causing an overflow of AnalyzedFunction objects.

3. Ensure non-mapped mixins are handled correctly by with_loader_criteria().

4. Fixed bug in lambda SQL system where we are not supposed to allow a Python function to be embedded in the lambda, since we can't predict a bound value from it.
While there was an error condition added for this, it was not tested and wasn't working; an informative error is now raised.

5. new docs for lambdas

6. consolidated changelog for all of these

Fixes: #5760
Fixes: #5765
Fixes: #5766
Fixes: #5768
Fixes: #5770
Change-Id: Iedaa636c3225fad496df23b612c516c8ab247ab7
---

diff --git a/doc/build/changelog/unreleased_14/5760.rst b/doc/build/changelog/unreleased_14/5760.rst
new file mode 100644
index 0000000000..053eb5a0e3
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/5760.rst
@@ -0,0 +1,69 @@
+.. change::
+    :tags: bug, sql, orm
+    :tickets: 5760, 5763, 5765, 5768, 5770
+
+    A wide variety of fixes to the "lambda SQL" feature introduced at
+    :ref:`engine_lambda_caching` have been implemented based on user feedback,
+    with an emphasis on its use within the :func:`_orm.with_loader_criteria`
+    feature where it is most prominently used [ticket:5760]:
+
+    * fixed issue where boolean True/False values referred to in the
+      closure variables of the lambda would cause failures [ticket:5763]
+
+    * Repaired a non-working detection for Python functions embedded in the
+      lambda that produce bound values; this case is likely not supportable
+      so raises an informative error, where the function should be invoked
+      outside the lambda itself. New documentation has been added to
+      further detail this behavior. [ticket:5770]
+
+    * The lambda system by default now rejects the use of non-SQL elements
+      within the closure variables of the lambda entirely, where the error
+      suggests the two options of either explicitly ignoring closure variables
+      that are not SQL parameters, or specifying a specific set of values to be
+      considered as part of the cache key based on hash value.
This critically
+      prevents the lambda system from assuming that arbitrary objects within
+      the lambda's closure are appropriate for caching while also refusing to
+      ignore them by default, preventing the case where their state might
+      not be constant and have an impact on the SQL construct produced.
+      The error message is comprehensive and new documentation has been
+      added to further detail this behavior. [ticket:5765]
+
+    * Fixed support for the edge case where an ``in_()`` expression
+      against a list of SQL elements, such as :func:`_sql.literal` objects,
+      would fail to be accommodated correctly. [ticket:5768]
+
+
+.. change::
+    :tags: bug, orm
+    :tickets: 5760, 5766, 5762, 5761, 5764
+
+    Related to the fixes for the lambda criteria system within Core, the
+    ORM has implemented a variety of fixes for the
+    :func:`_orm.with_loader_criteria` feature as well as the
+    :meth:`_orm.SessionEvents.do_orm_execute` event handler that is often
+    used in conjunction [ticket:5760]:
+
+
+    * fixed issue where :func:`_orm.with_loader_criteria` function would fail
+      if the given entity or base included non-mapped mixins in its descending
+      class hierarchy [ticket:5766]
+
+    * The :func:`_orm.with_loader_criteria` feature is now unconditionally
+      disabled for the case of ORM "refresh" operations, including loads
+      of deferred or expired column attributes as well as for explicit
+      operations like :meth:`_orm.Session.refresh`. These loads are necessarily
+      based on primary key identity where additional WHERE criteria is
+      never appropriate. [ticket:5762]
+
+    * Added new attribute :attr:`_orm.ORMExecuteState.is_column_load` to indicate
+      to a :meth:`_orm.SessionEvents.do_orm_execute` handler that a particular
+      operation is a primary-key-directed column attribute load, where additional
+      criteria should not be added. The :func:`_orm.with_loader_criteria`
+      function as above ignores these in any case now.
[ticket:5761] + + * Fixed issue where the :attr:`_orm.ORMExecuteState.is_relationship_load` + attribute would not be set correctly for many lazy loads as well as all + selectinloads. The flag is essential in order to test if options should + be added to statements or if they would already have been propagated via + relationship loads. [ticket:5764] + diff --git a/doc/build/changelog/unreleased_14/5761.rst b/doc/build/changelog/unreleased_14/5761.rst deleted file mode 100644 index 9e36f2a898..0000000000 --- a/doc/build/changelog/unreleased_14/5761.rst +++ /dev/null @@ -1,10 +0,0 @@ -.. change:: - :tags: bug, orm - :tickets: 5761 - - Added new attribute :attr:`_orm.ORMExecuteState.is_column_load` to indicate - that a :meth:`_orm.SessionEvents.do_orm_execute` handler that a particular - operation is a primary-key-directed column attribute load, such as from an - expiration or a deferred attribute, and that WHERE criteria or additional - loader options should not be added to the query. This has been added to - the examples which illustrate the :func:`_orm.with_loader_criteria` option. \ No newline at end of file diff --git a/doc/build/changelog/unreleased_14/5762.rst b/doc/build/changelog/unreleased_14/5762.rst deleted file mode 100644 index 7b5a90cdf1..0000000000 --- a/doc/build/changelog/unreleased_14/5762.rst +++ /dev/null @@ -1,10 +0,0 @@ -.. change:: - :tags: bug, orm - :tickets: 5762 - - The :func:`_orm.with_loader_criteria` option has been modified so that it - will never apply its criteria to the SELECT statement for an ORM refresh - operation, such as that invoked by :meth:`_orm.Session.refresh` or whenever - an expired attribute is loaded. These queries are only against the - primary key row of the object that is already present in memory so there - should not be additional criteria added. 
\ No newline at end of file diff --git a/doc/build/changelog/unreleased_14/5763.rst b/doc/build/changelog/unreleased_14/5763.rst deleted file mode 100644 index e395b6fcfe..0000000000 --- a/doc/build/changelog/unreleased_14/5763.rst +++ /dev/null @@ -1,9 +0,0 @@ -.. change:: - :tags: bug, orm - :tickets: 5763 - - Fixed bug in lambda SQL feature, used by ORM - :meth:`_orm.with_loader_criteria` as well as available generally in the SQL - expression language, where assigning a boolean value True/False to a - variable would cause the query-time expression calculation to fail, as it - would produce a SQL expression not compatible with a bound value. \ No newline at end of file diff --git a/doc/build/changelog/unreleased_14/5764.rst b/doc/build/changelog/unreleased_14/5764.rst deleted file mode 100644 index 29753fafe6..0000000000 --- a/doc/build/changelog/unreleased_14/5764.rst +++ /dev/null @@ -1,9 +0,0 @@ -.. change:: - :tags: orm, bug - :tickets: 5764 - - Fixed issue where the :attr:`_orm.ORMExecuteState.is_relationship_load` - attribute would not be set correctly for many lazy loads, all - selectinloads, etc. The flag is essential in order to test if options - should be added to statements or if they would already have been propagated - via relationship loads. \ No newline at end of file diff --git a/doc/build/core/connections.rst b/doc/build/core/connections.rst index aeefd27407..629bec333a 100644 --- a/doc/build/core/connections.rst +++ b/doc/build/core/connections.rst @@ -1114,124 +1114,414 @@ Using Lambdas to add significant speed gains to statement production not appropriate for novice Python developers. The lambda approach can be applied to at a later time to existing code with a minimal amount of effort. 
-The caching system has in its roots the SQLAlchemy :ref:`"baked query" -` extension, which made novel use of Python lambdas in order to -produce SQL statements that were intrinsically cacheable, while at the same -time decreasing not just the overhead involved to compile the statement into -SQL, but also the overhead in constructing the statement object from a Python -perspective. The new caching in SQLAlchemy by default does not substantially -optimize the construction of SQL constructs. This refers to the Python -overhead taken up to construct the statement object itself before it is -compiled or executed, such as the :class:`_sql.Select` object used in the -example below:: +Python functions, typically expressed as lambdas, may be used to generate +SQL expressions which are cacheable based on the Python code location of +the lambda function itself as well as the closure variables within the +lambda. The rationale is to allow caching of not only the SQL string-compiled +form of a SQL expression construct as is SQLAlchemy's normal behavior when +the lambda system isn't used, but also the in-Python composition +of the SQL expression construct itself, which also has some degree of +Python overhead. + +The lambda SQL expression feature is available as a performance enhancing +feature, and is also optionally used in the :func:`_orm.with_loader_criteria` +ORM option in order to provide a generic SQL fragment. + +Synopsis +^^^^^^^^ + +Lambda statements are constructed using the :func:`_sql.lambda_stmt` function, +which returns an instance of :class:`_sql.StatementLambdaElement`, which is +itself an executable statement construct. Additional modifiers and criteria +are added to the object using the Python addition operator ``+``, or +alternatively the :meth:`_sql.StatementLambdaElement.add_criteria` method which +allows for more options. 
+ +It is assumed that the :func:`_sql.lambda_stmt` construct is being invoked +within an enclosing function or method that expects to be used many times +within an application, so that subsequent executions beyond the first one +can take advantage of the compiled SQL being cached. When the lambda is +constructed inside of an enclosing function in Python it is then subject +to also having closure variables, which are significant to the whole +approach:: + + from sqlalchemy import lambda_stmt def run_my_statement(connection, parameter): - stmt = select(table) - stmt = stmt.where(table.c.col == parameter) - stmt = stmt.order_by(table.c.id) + stmt = lambda_stmt(lambda: select(table)) + stmt += lambda s: s.where(table.c.col == parameter) + stmt += lambda s: s.order_by(table.c.id) return connection.execute(stmt) -Above, in order to construct ``stmt``, we see three Python functions or methods -``select()``, ``.where()`` and ``.order_by()`` being invoked directly, and -additionally there is a Python method invoked when we construct ``table.c.col -== 'foo'``, as the expression language overrides the ``__eq__()`` method to -produce a SQL construct. Within each of these calls is a series of argument -checking and internal construction logic that makes use of many more Python -function calls. With intense production of thousands of statement objects, -these function calls can add up. Using the recipe for profiling at -:ref:`faq_code_profiling`, the above Python code within the scope of the -``select()`` call down to the ``.order_by()`` call uses 73 Python function -calls to produce. - -Additionally, statement caching requires that a cache key be generated against -the above statement, which must be composed of all elements within the -statement that uniquely identify the SQL that it would produce. Measuring -this process for the above statement takes another 40 Python function calls. 
- -In order to ensure the full performance gains of the prior "baked query" -extension are still available, the "lambda:" system used by baked queries has -been adapted into a more capable and easier to use system as an intrinsic part -of the SQLAlchemy Core expression language (which by extension then includes -ORM queries, which as of SQLAlchemy 1.4 using 2.0-style APIs may also be -invoked directly from SQLAlchemy Core expression objects). We can -adapt our statement above to be built using "lambdas" by making use of the -:func:`_sql.lambda_stmt` element. Using this approach, we indicate that the -:func:`_sql.select` should be returned by a lambda. We can then add new -criteria to the statement by composing further lambdas onto the object in a -similar manner as how "baked queries" worked:: - - from sqlalchemy import lambda_stmt - - def run_my_statement(connection, parameter): - stmt = lambda_stmt(lambda: select(table)) - stmt += lambda s: s.where(table.c.col == parameter) - stmt += lambda s: s.order_by(table.c.id) - - return connection.execute(stmt) - - result = run_my_statement(some_connection, "some parameter") - -The above code produces a :class:`.StatementLambdaElement`, which behaves like a -Core SQL construct but defers the construction of the statement in most -cases until it is needed by the compiler. If the statement is already cached, -the lambdas will not be called. - -The cache key is based on the **Python source code location of each lambda -itself**, which in the Python interpreter is essentially the ``__code__`` -element of the Python function. This means that the lambda approach should only -be used inside of a function where the lambdas themselves will be the **same -lambdas each time, from a Python source code perspective**. - -The execution process for the above lambda will **extract literal parameters** -from the statement each time, without needing to actually run the lambdas. 
In -the above example, each time the variable ``parameter`` is used within the -lambda to generate the WHERE clause of the statement, while the actual lambda -present will not actually be run, the value of ``parameter`` will be tracked -and the current value of the variable will be used within the statement -parameters at execution time. This is a feature that was not possible with the -"baked query" extension and involves the use of up-front analysis of the -incoming ``__code__`` object to determine how parameters can be extracted from -future lambdas against that same code object. - -More simply, this means it's safe for the lambda statement -to use arbitrary literal parameters, which don't modify the structure -of the statement, on each invocation:: - - def run_my_statement(connection, parameter): - stmt = lambda_stmt(lambda: select(table)) - stmt += lambda s: s.where(table.c.col == parameter) - stmt += lambda s: s.order_by(table.c.id) - - return connection.execute(stmt) - -However, it's not safe for an individual lambda so modify the SQL structure -of the statement across calls:: - - # incorrect example - def run_my_statement(connection, parameter, add_criteria=False): - stmt = lambda_stmt(lambda: select(table)) - - # will not be cached correctly as add_criteria changes - stmt += lambda s: s.where( - and_(add_criteria, table.c.col == parameter) - if add_criteria - else s.where(table.c.col == parameter) - ) - - stmt += lambda s: s.order_by(table.c.id) - - return connection.execute(stmt) - -The lambda statements indicated above will invoke all of the lambdas the first -time they are constructed; subsequent to that, the lambdas will not be invoked. -On these subsequent runs, a lambda construct will use far fewer Python function -calls in order to construct the un-cached object as well as to generate the -cache key after the first call. 
The above statement using lambdas takes only
-41 Python function calls to generate the whole structure as well as to produce
-the cache key, including the extraction of the bound parameters. This is
-compared to a total of about 115 Python function calls for the non-lambda
-version.
+    with engine.connect() as conn:
+        result = run_my_statement(conn, "some parameter")
+
+Above, the three ``lambda`` callables that are used to define the structure
+of a SELECT statement are invoked exactly once, and the resulting SQL
+string is cached in the compilation cache of the engine. From that point
+forward, the ``run_my_statement()`` function may be invoked any number
+of times and the ``lambda`` callables within it will not be called, only
+used as cache keys to retrieve the already-compiled SQL.
+
+.. note:: It is important to note that there is already SQL caching in place
+   when the lambda system is not used. The lambda system only adds an
+   additional layer of work reduction per SQL statement invoked by caching
+   the building up of the SQL construct itself and also using a simpler
+   cache key.
+
+
+Quick Guidelines for Lambdas
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Above all, the emphasis within the lambda SQL system is ensuring that there
+is never a mismatch between the cache key generated for a lambda and the
+SQL string it will produce. The :class:`_sql.LambdaElement` and related
+objects will run and analyze the given lambda in order to calculate how
+it should be cached on each run, trying to detect any potential problems.
+Basic guidelines include:
+
+* **Any kind of statement is supported** - while it's expected that
+  :func:`_sql.select` constructs are the prime use case for :func:`_sql.lambda_stmt`,
+  DML statements such as :func:`_sql.insert` and :func:`_sql.update` are
+  equally usable::
+
+    def upd(id_, newname):
+        stmt = lambda_stmt(lambda: users.update())
+        stmt += lambda s: s.values(name=newname)
+        stmt += lambda s: s.where(users.c.id == id_)
+        return stmt
+
+    with engine.begin() as conn:
+        conn.execute(upd(7, "foo"))
+
+  ..
+
+* **ORM use cases directly supported as well** - the :func:`_sql.lambda_stmt`
+  can accommodate ORM functionality completely and be used directly with
+  :meth:`_orm.Session.execute`::
+
+    def select_user(session, name):
+        stmt = lambda_stmt(lambda: select(User))
+        stmt += lambda s: s.where(User.name == name)
+
+        row = session.execute(stmt).first()
+        return row
+
+  ..
+
+* **Bound parameters are automatically accommodated** - in contrast to SQLAlchemy's
+  previous "baked query" system, the lambda SQL system accommodates
+  Python literal values which become SQL bound parameters automatically.
+  This means that even though a given lambda runs only once, the values that
+  become bound parameters are extracted from the **closure** of the lambda
+  on every run:
+
+  .. sourcecode:: pycon+sql
+
+    >>> def my_stmt(x, y):
+    ...     stmt = lambda_stmt(lambda: select(func.max(x, y)))
+    ...     return stmt
+    ...
+    >>> engine = create_engine("sqlite://", echo=True)
+    >>> with engine.connect() as conn:
+    ...     print(conn.scalar(my_stmt(5, 10)))
+    ...     print(conn.scalar(my_stmt(12, 8)))
+    ...
+    {opensql}SELECT max(?, ?) AS max_1
+    [generated in 0.00057s] (5, 10){stop}
+    10
+    {opensql}SELECT max(?, ?)
AS max_1
+    [cached since 0.002059s ago] (12, 8){stop}
+    12
+
+  Above, :class:`_sql.StatementLambdaElement` extracted the values of ``x``
+  and ``y`` from the **closure** of the lambda that is generated each time
+  ``my_stmt()`` is invoked; these were substituted into the cached SQL
+  construct as the values of the parameters.
+
+* **The lambda should ideally produce an identical SQL structure in all cases** -
+  Avoid using conditionals or custom callables inside of lambdas that might make
+  it produce different SQL based on inputs; if a function might conditionally
+  use two different SQL fragments, use two separate lambdas::
+
+    # **Don't** do this:
+
+    def my_stmt(parameter, thing=False):
+        stmt = lambda_stmt(lambda: select(table))
+        stmt += (
+            lambda s: s.where(table.c.x > parameter) if thing
+            else s.where(table.c.y == parameter)
+        )
+        return stmt
+
+    # **Do** do this:
+
+    def my_stmt(parameter, thing=False):
+        stmt = lambda_stmt(lambda: select(table))
+        if thing:
+            stmt += lambda s: s.where(table.c.x > parameter)
+        else:
+            stmt += lambda s: s.where(table.c.y == parameter)
+        return stmt
+
+  There are a variety of failures which can occur if the lambda does not
+  produce a consistent SQL construct and some are not trivially detectable
+  right now.
+
+* **Don't use functions inside the lambda to produce bound values** - the
+  bound value tracking approach requires that the actual value to be used in
+  the SQL statement be locally present in the closure of the lambda. This is
+  not possible if values are generated from other functions, and the
+  :class:`_sql.LambdaElement` should normally raise an error if this is
+  attempted::
+
+    >>> def my_stmt(x, y):
+    ...     def get_x():
+    ...         return x
+    ...     def get_y():
+    ...         return y
+    ...
+    ...     stmt = lambda_stmt(lambda: select(func.max(get_x(), get_y())))
+    ...     return stmt
+    ...
+    >>> with engine.connect() as conn:
+    ...     print(conn.scalar(my_stmt(5, 10)))
+    ...
+    Traceback (most recent call last):
+      # ...
+    sqlalchemy.exc.InvalidRequestError: Can't invoke Python callable get_x()
+    inside of lambda expression argument at
+    <code object <lambda> at 0x7fed15f350e0, file "<stdin>", line 6>;
+    lambda SQL constructs should not invoke functions from closure variables
+    to produce literal values since the lambda SQL system normally extracts
+    bound values without actually invoking the lambda or any functions within it.
+
+  Above, the use of ``get_x()`` and ``get_y()``, if they are necessary, should
+  occur **outside** of the lambda and be assigned to a local closure variable::
+
+    >>> def my_stmt(x, y):
+    ...     def get_x():
+    ...         return x
+    ...     def get_y():
+    ...         return y
+    ...
+    ...     x_param, y_param = get_x(), get_y()
+    ...     stmt = lambda_stmt(lambda: select(func.max(x_param, y_param)))
+    ...     return stmt
+
+  ..
+
+* **Avoid referring to non-SQL constructs inside of lambdas as they are not
+  cacheable by default** - this issue refers to how the :class:`_sql.LambdaElement`
+  creates a cache key from other closure variables within the statement. In order
+  to provide the best guarantee of an accurate cache key, all objects located
+  in the closure of the lambda are considered to be significant, and none
+  will be assumed to be appropriate for a cache key by default.
+  So the following example will also raise a rather detailed error message::
+
+    >>> class Foo:
+    ...     def __init__(self, x, y):
+    ...         self.x = x
+    ...         self.y = y
+    ...
+    >>> def my_stmt(foo):
+    ...     stmt = lambda_stmt(lambda: select(func.max(foo.x, foo.y)))
+    ...     return stmt
+    ...
+    >>> with engine.connect() as conn:
+    ...     print(conn.scalar(my_stmt(Foo(5, 10))))
+    ...
+    Traceback (most recent call last):
+      # ...
+    sqlalchemy.exc.InvalidRequestError: Closure variable named 'foo' inside of
+    lambda callable <code object <lambda> at 0x7fed15f35450, file
+    "<stdin>", line 2> does not refer to a cachable SQL element, and also
+    does not appear to be serving as a SQL literal bound value based on the
+    default SQL expression returned by the function.
This variable needs to
+    remain outside the scope of a SQL-generating lambda so that a proper cache
+    key may be generated from the lambda's state. Evaluate this variable
+    outside of the lambda, set track_on=[<elements>] to explicitly select
+    closure elements to track, or set track_closure_variables=False to exclude
+    closure variables from being part of the cache key.
+
+  The above error indicates that :class:`_sql.LambdaElement` will not assume
+  that the ``Foo`` object passed in will continue to behave the same in all
+  cases. It also won't assume it can use ``Foo`` as part of the cache key
+  by default; if it were to use the ``Foo`` object as part of the cache key,
+  and there were many different ``Foo`` objects, this would fill up the cache
+  with duplicate information, and would also hold long-lasting references to
+  all of these objects.
+
+  The best way to resolve the above situation is to not refer to ``foo``
+  inside of the lambda, and refer to it **outside** instead::
+
+    >>> def my_stmt(foo):
+    ...     x_param, y_param = foo.x, foo.y
+    ...     stmt = lambda_stmt(lambda: select(func.max(x_param, y_param)))
+    ...     return stmt
+
+  In some situations, if the SQL structure of the lambda is guaranteed to
+  never change based on input, it is possible to pass ``track_closure_variables=False``,
+  which will disable any tracking of closure variables other than those
+  used for bound parameters::
+
+    >>> def my_stmt(foo):
+    ...     stmt = lambda_stmt(
+    ...         lambda: select(func.max(foo.x, foo.y)),
+    ...         track_closure_variables=False
+    ...     )
+    ...     return stmt
+
+  There is also the option to add objects to the element to explicitly form
+  part of the cache key, using the ``track_on`` parameter; using this parameter
+  allows specific values to serve as the cache key and will also prevent other
+  closure variables from being considered. This is useful for cases where part
+  of the SQL being constructed originates from a contextual object of some sort
+  that may have many different values.
In the example below, the first + segment of the SELECT statement will disable tracking of the ``foo`` variable, + whereas the second segment will explicitly track ``self`` as part of the + cache key:: + + >>> def my_stmt(self, foo): + ... stmt = lambda_stmt( + ... lambda: select(*self.column_expressions), + ... track_closure_variables=False + ... ) + ... stmt = stmt.add_criteria( + ... lambda: self.where_criteria, + ... track_on=[self] + ... ) + ... return stmt + + Using ``track_on`` means the given objects will be stored long term in the + lambda's internal cache and will have strong references for as long as the + cache doesn't clear out those objects (an LRU scheme of 1000 entries is used + by default). + + .. + + +Cache Key Generation +^^^^^^^^^^^^^^^^^^^^ + +In order to understand some of the options and behaviors which occur +with lambda SQL constructs, an understanding of the caching system +is helpful. + +SQLAlchemy's caching system normally generates a cache key from a given +SQL expression construct by producing a structure that represents all the +state within the construct:: + + >>> from sqlalchemy import select, column + >>> stmt = select(column('q')) + >>> cache_key = stmt._generate_cache_key() + >>> print(cache_key) # somewhat paraphrased + CacheKey(key=( + '0', + , + '_raw_columns', + ( + ( + '1', + , + 'name', + 'q', + 'type', + ( + , + ), + ), + ), + # a few more elements are here, and many more for a more + # complicated SELECT statement + ),) + + +The above key is stored in the cache which is essentially a dictionary, and the +value is a construct that among other things stores the string form of the SQL +statement, in this case the phrase "SELECT q". We can observe that even for an +extremely short query the cache key is pretty verbose as it has to represent +everything that may vary about what's being rendered and potentially executed. 
+ +The lambda construction system by contrast creates a different kind of cache +key:: + + >>> from sqlalchemy import lambda_stmt + >>> stmt = lambda_stmt(lambda: select(column("q"))) + >>> cache_key = stmt._generate_cache_key() + >>> print(cache_key) + CacheKey(key=( + at 0x7fed1617c710, file "", line 1>, + , + ),) + +Above, we see a cache key that is vastly shorter than that of the non-lambda +statement, and additionally that production of the ``select(column("q"))`` +construct itself was not even necessary; the Python lambda itself contains +an attribute called ``__code__`` which refers to a Python code object that +within the runtime of the application is immutable and permanent. + +When the lambda also includes closure variables, in the normal case that these +variables refer to SQL constructs such as column objects, they become +part of the cache key, or if they refer to literal values that will be bound +parameters, they are placed in a separate element of the cache key:: + + >>> def my_stmt(parameter): + ... col = column("q") + ... stmt = lambda_stmt(lambda: select(col)) + ... stmt += lambda s: s.where(col == parameter) + ... 
return stmt + +The above :class:`_sql.StatementLambdaElement` includes two lambdas, both +of which refer to the ``col`` closure variable, so the cache key will +represent both of these segments as well as the ``column()`` object:: + + >>> stmt = my_stmt(5) + >>> key = stmt._generate_cache_key() + >>> print(key) + CacheKey(key=( + at 0x7f07323c50e0, file "", line 3>, + ( + '0', + , + 'name', + 'q', + 'type', + ( + , + ), + ), + at 0x7f07323c5190, file "", line 4>, + , + ( + '0', + , + 'name', + 'q', + 'type', + ( + , + ), + ), + ( + '0', + , + 'name', + 'q', + 'type', + ( + , + ), + ), + ),) + + +The second part of the cache key has retrieved the bound parameters that will +be used when the statement is invoked:: + + >>> key.bindparams + [BindParameter('%(139668884281280 parameter)s', 5, type_=Integer())] + For a series of examples of "lambda" caching with performance comparisons, see the "short_selects" test suite within the :ref:`examples_performance` diff --git a/doc/build/errors.rst b/doc/build/errors.rst index a52444766d..67a8a29b0f 100644 --- a/doc/build/errors.rst +++ b/doc/build/errors.rst @@ -802,6 +802,7 @@ therefore requires that :meth:`_expression.SelectBase.subquery` is used:: :ref:`change_4617` + Object Relational Mapping ========================= diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 98c57149d3..6838011b14 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -540,6 +540,9 @@ class AbstractRelationshipLoader(LoaderStrategy): self.target = self.parent_property.target self.uselist = self.parent_property.uselist + def _size_alert(self, lru_cache): + util.warn("LRU cache size alert for loader strategy: %s" % self) + @log.class_logger @relationships.RelationshipProperty.strategy_for(do_nothing=True) @@ -884,7 +887,11 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): ] def _memoized_attr__query_cache(self): - return util.LRUCache(30) + # cache is per lazy 
loader; stores not only cached SQL but also + # sqlalchemy.sql.lambdas.AnalyzedCode and + # sqlalchemy.sql.lambdas.AnalyzedFunction objects which are generated + # from the StatementLambda used. + return util.LRUCache(30, size_alert=self._size_alert) @util.preload_module("sqlalchemy.orm.strategy_options") def _emit_lazyload( @@ -912,8 +919,11 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): } if self.parent_property.secondary is not None: - stmt += lambda stmt: stmt.select_from( - self.mapper, self.parent_property.secondary + stmt = stmt.add_criteria( + lambda stmt: stmt.select_from( + self.mapper, self.parent_property.secondary + ), + track_on=[self.parent_property], ) pending = not state.key @@ -961,7 +971,9 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): ) if self._order_by: - stmt += lambda stmt: stmt.order_by(*self._order_by) + stmt = stmt.add_criteria( + lambda stmt: stmt.order_by(*self._order_by), track_on=[self] + ) def _lazyload_reverse(compile_context): for rev in self.parent_property._reverse_property: @@ -978,8 +990,11 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): ] ).lazyload(rev.key).process_compile_state(compile_context) - stmt += lambda stmt: stmt._add_context_option( - _lazyload_reverse, self.parent_property + stmt = stmt.add_criteria( + lambda stmt: stmt._add_context_option( + _lazyload_reverse, self.parent_property + ), + track_on=[self], ) lazy_clause, params = self._generate_lazy_clause(state, passive) @@ -2587,7 +2602,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): ).init_class_attribute(mapper) def _memoized_attr__query_cache(self): - return util.LRUCache(30) + return util.LRUCache(30, size_alert=self._size_alert) def create_row_processor( self, @@ -2763,13 +2778,13 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): # in the non-omit_join case, the Bundle is against the annotated/ # mapped column of the parent entity, but the #4347 issue does not # occur in this 
case. - pa = self._parent_alias q = q.add_criteria( - lambda q: q.select_from(pa).join( - getattr(pa, self.parent_property.key).of_type( - effective_entity - ) - ) + lambda q: q.select_from(self._parent_alias).join( + getattr( + self._parent_alias, self.parent_property.key + ).of_type(effective_entity) + ), + track_on=[self], ) q = q.add_criteria( @@ -2849,7 +2864,8 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): q = q.add_criteria( lambda q: q._add_context_option( _setup_outermost_orderby, self.parent_property - ) + ), + track_on=[self], ) if query_info.load_only_child: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index c9437d1b2e..88f9a34d05 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -885,6 +885,7 @@ class LoaderCriteriaOption(CriteriaOption): loader_only=False, include_aliases=False, propagate_to_loaders=True, + track_closure_variables=True, ): """Add additional WHERE criteria to the load for all occurrences of a particular entity. @@ -993,6 +994,14 @@ class LoaderCriteriaOption(CriteriaOption): combine :func:`_orm.with_loader_criteria` with the :meth:`_orm.SessionEvents.do_orm_execute` event. + :param track_closure_variables: when False, closure variables inside + of a lambda expression will not be validated used as part of + any cache key. This allows more complex expressions to be used + inside of a lambda expression but requires that the lambda ensures + it returns the identical SQL every time given a particular class. + + .. 
versionadded:: 1.4.0b2 + """ entity = inspection.inspect(entity_or_base, False) if entity is None: @@ -1012,6 +1021,9 @@ class LoaderCriteriaOption(CriteriaOption): if self.root_entity is not None else self.entity.entity, ), + opts=lambdas.LambdaOptions( + track_closure_variables=track_closure_variables + ), ) else: self.deferred_where_criteria = False @@ -1030,7 +1042,7 @@ class LoaderCriteriaOption(CriteriaOption): stack = list(self.root_entity.__subclasses__()) while stack: subclass = stack.pop(0) - ent = inspection.inspect(subclass) + ent = inspection.inspect(subclass, raiseerr=False) if ent: for mp in ent.mapper.self_and_descendants: yield mp diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index bdd807438c..43c89ee823 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -7,7 +7,6 @@ import numbers import re -import types from . import operators from . import roles @@ -56,6 +55,15 @@ def _deep_is_literal(element): """ + if isinstance(element, collections_abc.Sequence) and not isinstance( + element, str + ): + for elem in element: + if not _deep_is_literal(elem): + return False + else: + return True + return ( not isinstance( element, @@ -66,7 +74,6 @@ def _deep_is_literal(element): not isinstance(element, type) or not issubclass(element, HasCacheKey) ) - and not isinstance(element, types.FunctionType) ) @@ -109,9 +116,8 @@ def expect(role, element, apply_propagate_attrs=None, argname=None, **kw): return lambdas.LambdaElement( element, role, + lambdas.LambdaOptions(**kw), apply_propagate_attrs=apply_propagate_attrs, - argname=argname, - **kw ) # major case is that we are given a ClauseElement already, skip more diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 86611baeb1..ab8701dd65 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1383,7 +1383,6 @@ class BindParameter(roles.InElementRole, ColumnElement): """Return a copy of this 
:class:`.BindParameter` with the given value set. """ - cloned = self._clone(maintain_key=maintain_key) cloned.value = value cloned.callable = None diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index aafdda4ce1..3f0ca477e6 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -19,6 +19,7 @@ from . import traversals from . import type_api from . import visitors from .base import _clone +from .base import Options from .operators import ColumnOperators from .. import exc from .. import inspection @@ -28,7 +29,24 @@ from ..util import collections_abc _closure_per_cache_key = util.LRUCache(1000) -def lambda_stmt(lmb, **opts): +class LambdaOptions(Options): + enable_tracking = True + track_closure_variables = True + track_on = None + global_track_bound_values = True + track_bound_values = True + lambda_cache = None + + +def lambda_stmt( + lmb, + enable_tracking=True, + track_closure_variables=True, + track_on=None, + global_track_bound_values=True, + track_bound_values=True, + lambda_cache=None, +): """Produce a SQL statement that is cached as a lambda. The Python code object within the lambda is scanned for both Python @@ -49,6 +67,29 @@ def lambda_stmt(lmb, **opts): .. versionadded:: 1.4 + :param lmb: a Python function, typically a lambda, which takes no arguments + and returns a SQL expression construct + :param enable_tracking: when False, all scanning of the given lambda for + changes in closure variables or bound parameters is disabled. Use for + a lambda that produces the identical results in all cases with no + parameterization. + :param track_closure_variables: when False, changes in closure variables + within the lambda will not be scanned. Use for a lambda where the + state of its closure variables will never change the SQL structure + returned by the lambda. + :param track_bound_values: when False, bound parameter tracking will + be disabled for the given lambda. 
Use for a lambda that either does + not produce any bound values, or where the initial bound values never + change. + :param global_track_bound_values: when False, bound parameter tracking + will be disabled for the entire statement including additional links + added via the :meth:`_sql.StatementLambdaElement.add_criteria` method. + :param lambda_cache: a dictionary or other mapping-like object where + information about the lambda's Python code as well as the tracked closure + variables in the lambda itself will be stored. Defaults + to a global LRU cache. This cache is independent of the "compiled_cache" + used by the :class:`_engine.Connection` object. + .. seealso:: :ref:`engine_lambda_caching` @@ -56,7 +97,18 @@ def lambda_stmt(lmb, **opts): """ - return StatementLambdaElement(lmb, roles.CoerceTextStatementRole, **opts) + return StatementLambdaElement( + lmb, + roles.CoerceTextStatementRole, + LambdaOptions( + enable_tracking=enable_tracking, + track_on=track_on, + track_closure_variables=track_closure_variables, + global_track_bound_values=global_track_bound_values, + track_bound_values=track_bound_values, + lambda_cache=lambda_cache, + ), + ) class LambdaElement(elements.ClauseElement): @@ -94,38 +146,39 @@ class LambdaElement(elements.ClauseElement): def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self.fn.__code__) - def __init__(self, fn, role, apply_propagate_attrs=None, **kw): + def __init__( + self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None + ): self.fn = fn self.role = role self.tracker_key = (fn.__code__,) + self.opts = opts if apply_propagate_attrs is None and ( role is roles.CoerceTextStatementRole ): apply_propagate_attrs = self - rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, kw) + rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts) if apply_propagate_attrs is not None: propagate_attrs = rec.propagate_attrs if propagate_attrs: apply_propagate_attrs._propagate_attrs = propagate_attrs - def 
_retrieve_tracker_rec(self, fn, apply_propagate_attrs, kw): - lambda_cache = kw.get("lambda_cache", _closure_per_cache_key) + def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts): + lambda_cache = opts.lambda_cache + if lambda_cache is None: + lambda_cache = _closure_per_cache_key tracker_key = self.tracker_key fn = self.fn closure = fn.__closure__ - tracker = AnalyzedCode.get( fn, self, - kw, - track_bound_values=kw.get("track_bound_values", True), - enable_tracking=kw.get("enable_tracking", True), - track_on=kw.get("track_on", None), + opts, ) self._resolved_bindparams = bindparams = [] @@ -133,10 +186,11 @@ class LambdaElement(elements.ClauseElement): anon_map = traversals.anon_map() cache_key = tuple( [ - getter(closure, kw, anon_map, bindparams) + getter(closure, opts, anon_map, bindparams) for getter in tracker.closure_trackers ] ) + if self.parent_lambda is not None: cache_key = self.parent_lambda.closure_cache_key + cache_key @@ -148,9 +202,7 @@ class LambdaElement(elements.ClauseElement): rec = None if rec is None: - rec = AnalyzedFunction( - tracker, self, apply_propagate_attrs, kw, fn - ) + rec = AnalyzedFunction(tracker, self, apply_propagate_attrs, fn) rec.closure_bindparams = bindparams lambda_cache[tracker_key + cache_key] = rec else: @@ -213,14 +265,13 @@ class LambdaElement(elements.ClauseElement): bindparam_lookup = {b.key: b for b in self._resolved_bindparams} def replace(thing): - if ( - isinstance(thing, elements.BindParameter) - and thing.key in bindparam_lookup - ): - bind = bindparam_lookup[thing.key] - if thing.expanding: - bind.expanding = True - return bind + if isinstance(thing, elements.BindParameter): + + if thing.key in bindparam_lookup: + bind = bindparam_lookup[thing.key] + if thing.expanding: + bind.expanding = True + return bind if self._rec.is_sequence: expr = [ @@ -268,7 +319,6 @@ class LambdaElement(elements.ClauseElement): if self._resolved_bindparams: bindparams.extend(self._resolved_bindparams) - return cache_key 
def _invoke_user_fn(self, fn, *arg): @@ -285,10 +335,9 @@ class DeferredLambdaElement(LambdaElement): """ - def __init__(self, fn, role, lambda_args=(), **kw): + def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()): self.lambda_args = lambda_args - self.coerce_kw = kw - super(DeferredLambdaElement, self).__init__(fn, role, **kw) + super(DeferredLambdaElement, self).__init__(fn, role, opts) def _invoke_user_fn(self, fn, *arg): return fn(*self.lambda_args) @@ -297,10 +346,30 @@ class DeferredLambdaElement(LambdaElement): tracker_fn = self._rec.tracker_instrumented_fn expr = tracker_fn(*lambda_args) - expr = coercions.expect(self.role, expr, **self.coerce_kw) - - if self._resolved_bindparams: - expr = self._setup_binds_for_tracked_expr(expr) + expr = coercions.expect(self.role, expr) + + expr = self._setup_binds_for_tracked_expr(expr) + + # this validation is getting very close, but not quite, to achieving + # #5767. The problem is if the base lambda uses an unnamed column + # as is very common with mixins, the parameter name is different + # and it produces a false positive; that is, for the documented case + # that is exactly what people will be doing, it doesn't work, so + # I'm not really sure how to handle this right now. 
+ # expected_binds = [ + # b._orig_key + # for b in self._rec.expr._generate_cache_key()[1] + # if b.required + # ] + # got_binds = [ + # b._orig_key for b in expr._generate_cache_key()[1] if b.required + # ] + # if expected_binds != got_binds: + # raise exc.InvalidRequestError( + # "Lambda callable at %s produced a different set of bound " + # "parameters than its original run: %s" + # % (self.fn.__code__, ", ".join(got_binds)) + # ) # TODO: TEST TEST TEST, this is very out there for deferred_copy_internals in self._transforms: @@ -312,7 +381,9 @@ class DeferredLambdaElement(LambdaElement): self, clone=_clone, deferred_copy_internals=None, **kw ): super(DeferredLambdaElement, self)._copy_internals( - clone=clone, deferred_copy_internals=deferred_copy_internals, **kw + clone=clone, + deferred_copy_internals=deferred_copy_internals, # **kw + opts=kw, ) # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know @@ -347,33 +418,60 @@ class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement): """ - def __init__(self, fn, parent_lambda, **kw): - self._default_kw = default_kw = {} - global_track_bound_values = kw.pop("global_track_bound_values", None) - if global_track_bound_values is not None: - default_kw["track_bound_values"] = global_track_bound_values - kw["track_bound_values"] = global_track_bound_values + def __add__(self, other): + return self.add_criteria(other) - if "lambda_cache" in kw: - default_kw["lambda_cache"] = kw["lambda_cache"] + def add_criteria( + self, + other, + enable_tracking=True, + track_on=None, + track_closure_variables=True, + track_bound_values=True, + ): + """Add new criteria to this :class:`_sql.StatementLambdaElement`. + + E.g.:: + + >>> def my_stmt(parameter): + ... stmt = lambda_stmt( + ... lambda: select(table.c.x, table.c.y), + ... ) + ... stmt = stmt.add_criteria( + ... lambda: table.c.x > parameter + ... ) + ... 
return stmt + + The :meth:`_sql.StatementLambdaElement.add_criteria` method is + equivalent to using the Python addition operator to add a new + lambda, except that additional arguments may be added including + ``track_closure_values`` and ``track_on``:: + + >>> def my_stmt(self, foo): + ... stmt = lambda_stmt( + ... lambda: select(func.max(foo.x, foo.y)), + ... track_closure_variables=False + ... ) + ... stmt = stmt.add_criteria( + ... lambda: self.where_criteria, + ... track_on=[self] + ... ) + ... return stmt + + See :func:`_sql.lambda_stmt` for a description of the parameters + accepted. - super(StatementLambdaElement, self).__init__(fn, parent_lambda, **kw) + """ - def __add__(self, other): - return LinkedLambdaElement( - other, parent_lambda=self, **self._default_kw + opts = self.opts + dict( + enable_tracking=enable_tracking, + track_closure_variables=track_closure_variables, + global_track_bound_values=self.opts.global_track_bound_values, + track_on=track_on, + track_bound_values=track_bound_values, ) - def add_criteria(self, other, **kw): - if self._default_kw: - if kw: - default_kw = self._default_kw.copy() - default_kw.update(kw) - kw = default_kw - else: - kw = self._default_kw - - return LinkedLambdaElement(other, parent_lambda=self, **kw) + return LinkedLambdaElement(other, parent_lambda=self, opts=opts) def _execute_on_connection( self, connection, multiparams, params, execution_options @@ -461,14 +559,13 @@ class LinkedLambdaElement(StatementLambdaElement): role = None - def __init__(self, fn, parent_lambda, **kw): - self._default_kw = parent_lambda._default_kw - + def __init__(self, fn, parent_lambda, opts): + self.opts = opts self.fn = fn self.parent_lambda = parent_lambda self.tracker_key = parent_lambda.tracker_key + (fn.__code__,) - self._retrieve_tracker_rec(fn, self, kw) + self._retrieve_tracker_rec(fn, self, opts) self._propagate_attrs = parent_lambda._propagate_attrs def _invoke_user_fn(self, fn, *arg): @@ -497,20 +594,17 @@ class 
AnalyzedCode(object): ) return analyzed - def __init__( - self, - fn, - lambda_element, - lambda_kw, - track_bound_values=True, - enable_tracking=True, - track_on=None, - ): + def __init__(self, fn, lambda_element, opts): closure = fn.__closure__ - self.track_closure_variables = not track_on + self.track_bound_values = ( + opts.track_bound_values and opts.global_track_bound_values + ) + enable_tracking = opts.enable_tracking + track_on = opts.track_on + track_closure_variables = opts.track_closure_variables - self.track_bound_values = track_bound_values + self.track_closure_variables = track_closure_variables and not track_on # a list of callables generated from _bound_parameter_getter_* # functions. Each of these uses a PyWrapper object to retrieve @@ -533,7 +627,7 @@ class AnalyzedCode(object): if closure: self._init_closure(fn) - self._setup_additional_closure_trackers(fn, lambda_element, lambda_kw) + self._setup_additional_closure_trackers(fn, lambda_element, opts) def _init_track_on(self, track_on): self.closure_trackers.extend( @@ -590,13 +684,11 @@ class AnalyzedCode(object): if track_closure_variables: closure_trackers.append( self._cache_key_getter_closure_variable( - closure_index, cell.cell_contents + fn, fv, closure_index, cell.cell_contents ) ) - def _setup_additional_closure_trackers( - self, fn, lambda_element, lambda_kw - ): + def _setup_additional_closure_trackers(self, fn, lambda_element, opts): # an additional step is to actually run the function, then # go through the PyWrapper objects that were set up to catch a bound # parameter. 
then if they *didn't* make a param, oh they're another @@ -607,7 +699,6 @@ class AnalyzedCode(object): self, lambda_element, None, - lambda_kw, fn, ) @@ -616,7 +707,7 @@ class AnalyzedCode(object): for pywrapper in analyzed_function.closure_pywrappers: if not pywrapper._sa__has_param: closure_trackers.append( - self._cache_key_getter_tracked_literal(pywrapper) + self._cache_key_getter_tracked_literal(fn, pywrapper) ) @classmethod @@ -625,7 +716,7 @@ class AnalyzedCode(object): if is_clause_element: while not isinstance( - element, (elements.ClauseElement, schema.SchemaItem) + element, (elements.ClauseElement, schema.SchemaItem, type) ): try: element = element.__clause_element__() @@ -688,17 +779,25 @@ class AnalyzedCode(object): """ if isinstance(elem, traversals.HasCacheKey): - def get(closure, kw, anon_map, bindparams): - return kw["track_on"][idx]._gen_cache_key(anon_map, bindparams) + def get(closure, opts, anon_map, bindparams): + return opts.track_on[idx]._gen_cache_key(anon_map, bindparams) else: - def get(closure, kw, anon_map, bindparams): - return kw["track_on"][idx] + def get(closure, opts, anon_map, bindparams): + return opts.track_on[idx] return get - def _cache_key_getter_closure_variable(self, idx, cell_contents): + def _cache_key_getter_closure_variable( + self, + fn, + variable_name, + idx, + cell_contents, + use_clause_element=False, + use_inspect=False, + ): """Return a getter that will extend a cache key with new entries from the ``__closure__`` collection of a particular lambda. 
@@ -706,29 +805,90 @@ class AnalyzedCode(object): if isinstance(cell_contents, traversals.HasCacheKey): - def get(closure, kw, anon_map, bindparams): - return closure[idx].cell_contents._gen_cache_key( - anon_map, bindparams - ) + def get(closure, opts, anon_map, bindparams): + + obj = closure[idx].cell_contents + if use_inspect: + obj = inspection.inspect(obj) + elif use_clause_element: + while hasattr(obj, "__clause_element__"): + if not getattr(obj, "is_clause_element", False): + obj = obj.__clause_element__() + + return obj._gen_cache_key(anon_map, bindparams) elif isinstance(cell_contents, types.FunctionType): - def get(closure, kw, anon_map, bindparams): + def get(closure, opts, anon_map, bindparams): return closure[idx].cell_contents.__code__ - elif cell_contents.__hash__ is None: - # this covers dict, etc. - def get(closure, kw, anon_map, bindparams): - return () + elif isinstance(cell_contents, collections_abc.Sequence): + + def get(closure, opts, anon_map, bindparams): + contents = closure[idx].cell_contents + + try: + return tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in contents + ) + except AttributeError as ae: + self._raise_for_uncacheable_closure_variable( + variable_name, fn, from_=ae + ) else: + # if the object is a mapped class or aliased class, or some + # other object in the ORM realm of things like that, imitate + # the logic used in coercions.expect() to roll it down to the + # SQL element + element = cell_contents + is_clause_element = False + while hasattr(element, "__clause_element__"): + is_clause_element = True + if not getattr(element, "is_clause_element", False): + element = element.__clause_element__() + else: + break - def get(closure, kw, anon_map, bindparams): - return closure[idx].cell_contents + if not is_clause_element: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + return self._cache_key_getter_closure_variable( + fn, variable_name, idx, insp, use_inspect=True + ) + else: + 
return self._cache_key_getter_closure_variable( + fn, variable_name, idx, element, use_clause_element=True + ) + + self._raise_for_uncacheable_closure_variable(variable_name, fn) return get - def _cache_key_getter_tracked_literal(self, pytracker): + def _raise_for_uncacheable_closure_variable( + self, variable_name, fn, from_=None + ): + util.raise_( + exc.InvalidRequestError( + "Closure variable named '%s' inside of lambda callable %s " + "does not refer to a cachable SQL element, and also does not " + "appear to be serving as a SQL literal bound value based on " + "the default " + "SQL expression returned by the function. This variable " + "needs to remain outside the scope of a SQL-generating lambda " + "so that a proper cache key may be generated from the " + "lambda's state. Evaluate this variable outside of the " + "lambda, set track_on=[] to explicitly select " + "closure elements to track, or set " + "track_closure_variables=False to exclude " + "closure variables from being part of the cache key." + % (variable_name, fn.__code__), + ), + from_=from_, + ) + + def _cache_key_getter_tracked_literal(self, fn, pytracker): """Return a getter that will extend a cache key with new entries from the ``__closure__`` collection of a particular lambda. @@ -741,33 +901,11 @@ class AnalyzedCode(object): elem = pytracker._sa__to_evaluate closure_index = pytracker._sa__closure_index + variable_name = pytracker._sa__name - if isinstance(elem, set): - raise exc.ArgumentError( - "Can't create a cache key for lambda closure variable " - '"%s" because it\'s a set. try using a list' - % pytracker._sa__name - ) - - elif isinstance(elem, list): - - def get(closure, kw, anon_map, bindparams): - return tuple( - elem._gen_cache_key(anon_map, bindparams) - for elem in closure[closure_index].cell_contents - ) - - elif elem.__hash__ is None: - # this covers dict, etc. 
- def get(closure, kw, anon_map, bindparams): - return () - - else: - - def get(closure, kw, anon_map, bindparams): - return closure[closure_index].cell_contents - - return get + return self._cache_key_getter_closure_variable( + fn, variable_name, closure_index, elem + ) class AnalyzedFunction(object): @@ -789,7 +927,6 @@ class AnalyzedFunction(object): analyzed_code, lambda_element, apply_propagate_attrs, - kw, fn, ): self.analyzed_code = analyzed_code @@ -799,7 +936,7 @@ class AnalyzedFunction(object): self._instrument_and_run_function(lambda_element) - self._coerce_expression(lambda_element, apply_propagate_attrs, kw) + self._coerce_expression(lambda_element, apply_propagate_attrs) def _instrument_and_run_function(self, lambda_element): analyzed_code = self.analyzed_code @@ -832,13 +969,19 @@ class AnalyzedFunction(object): if closure_index is not None: value = closure[closure_index].cell_contents new_closure[name] = bind = PyWrapper( - name, value, closure_index=closure_index + fn, + name, + value, + closure_index=closure_index, + track_bound_values=( + self.analyzed_code.track_bound_values + ), ) if track_closure_variables: closure_pywrappers.append(bind) else: value = fn.__globals__[name] - new_globals[name] = bind = PyWrapper(name, value) + new_globals[name] = bind = PyWrapper(fn, name, value) # rewrite the original fn. things that look like they will # become bound parameters are wrapped in a PyWrapper. @@ -863,7 +1006,7 @@ class AnalyzedFunction(object): # variable. self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn) - def _coerce_expression(self, lambda_element, apply_propagate_attrs, kw): + def _coerce_expression(self, lambda_element, apply_propagate_attrs): """Run the tracker-generated expression through coercion rules. 
After the user-defined lambda has been invoked to produce a statement @@ -882,7 +1025,6 @@ class AnalyzedFunction(object): lambda_element.role, sub_expr, apply_propagate_attrs=apply_propagate_attrs, - **kw ) for sub_expr in expr ] @@ -892,7 +1034,6 @@ class AnalyzedFunction(object): lambda_element.role, expr, apply_propagate_attrs=apply_propagate_attrs, - **kw ) self.is_sequence = False else: @@ -956,7 +1097,16 @@ class PyWrapper(ColumnOperators): """ - def __init__(self, name, to_evaluate, closure_index=None, getter=None): + def __init__( + self, + fn, + name, + to_evaluate, + closure_index=None, + getter=None, + track_bound_values=True, + ): + self.fn = fn self._name = name self._to_evaluate = to_evaluate self._param = None @@ -964,28 +1114,35 @@ class PyWrapper(ColumnOperators): self._bind_paths = {} self._getter = getter self._closure_index = closure_index + self.track_bound_values = track_bound_values def __call__(self, *arg, **kw): elem = object.__getattribute__(self, "_to_evaluate") value = elem(*arg, **kw) - if coercions._deep_is_literal(value) and not isinstance( - # TODO: coverage where an ORM option or similar is here - value, - traversals.HasCacheKey, + if ( + self._sa_track_bound_values + and coercions._deep_is_literal(value) + and not isinstance( + # TODO: coverage where an ORM option or similar is here + value, + traversals.HasCacheKey, + ) ): - # TODO: we can instead scan the arguments and make sure they - # are all Python literals - - # TODO: coverage name = object.__getattribute__(self, "_name") raise exc.InvalidRequestError( "Can't invoke Python callable %s() inside of lambda " - "expression argument; lambda cache keys should not call " - "regular functions since the caching " - "system does not track the values of the arguments passed " - "to the functions. Call the function outside of the lambda " - "and assign to a local variable that is used in the lambda." 
- % (name) + "expression argument at %s; lambda SQL constructs should " + "not invoke functions from closure variables to produce " + "literal values since the " + "lambda SQL system normally extracts bound values without " + "actually " + "invoking the lambda or any functions within it. Call the " + "function outside of the " + "lambda and assign to a local variable that is used in the " + "lambda as a closure variable, or set " + "track_bound_values=False if the return value of this " + "function is used in some other way other than a SQL bound " + "value." % (name, self._sa_fn.__code__) ) else: return value @@ -1018,6 +1175,14 @@ class PyWrapper(ColumnOperators): param.type = type_api._resolve_value_to_type(to_evaluate) return param._with_value(to_evaluate, maintain_key=True) + def __bool__(self): + to_evaluate = object.__getattribute__(self, "_to_evaluate") + return bool(to_evaluate) + + def __nonzero__(self): + to_evaluate = object.__getattribute__(self, "_to_evaluate") + return bool(to_evaluate) + def __getattribute__(self, key): if key.startswith("_sa_"): return object.__getattribute__(self, key[4:]) @@ -1026,6 +1191,7 @@ class PyWrapper(ColumnOperators): "operate", "reverse_operate", "__class__", + "__dict__", ): return object.__getattribute__(self, key) @@ -1064,8 +1230,10 @@ class PyWrapper(ColumnOperators): elem = object.__getattribute__(self, "_to_evaluate") value = getter(elem) - if coercions._deep_is_literal(value): - wrapper = PyWrapper(key, value, getter=getter) + rolled_down_value = AnalyzedCode._roll_down_to_literal(value) + + if coercions._deep_is_literal(rolled_down_value): + wrapper = PyWrapper(self._sa_fn, key, value, getter=getter) bind_paths[bind_path_key] = wrapper return wrapper else: diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index f87f610745..5476729612 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -703,7 +703,7 @@ class SelectInEagerLoadTest(NoCache, 
fixtures.MappedTest): # this is because the test was previously making use of the same # loader option state repeatedly without rebuilding it. - @profiling.function_call_count() + @profiling.function_call_count(warmup=1) def go(): for i in range(100): obj = q.all() diff --git a/test/orm/test_lambdas.py b/test/orm/test_lambdas.py index d4fae7f6fe..b190f46d63 100644 --- a/test/orm/test_lambdas.py +++ b/test/orm/test_lambdas.py @@ -8,6 +8,7 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import update from sqlalchemy.future import select +from sqlalchemy.orm import aliased from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload @@ -132,19 +133,62 @@ class LambdaTest(QueryTest, AssertsCompiledSQL): fn = random.choice([go1, go2]) fn() - def test_entity_round_trip(self, plain_fixture): + @testing.combinations( + (True, True), + (True, False), + (False, False), + argnames="use_aliased,use_indirect_access", + ) + def test_entity_round_trip( + self, plain_fixture, use_aliased, use_indirect_access + ): User, Address = plain_fixture s = Session(testing.db, future=True) - def query(names): - stmt = lambda_stmt( - lambda: select(User) - .where(User.name.in_(names)) - .options(selectinload(User.addresses)) - ) + (lambda s: s.order_by(User.id)) + if use_aliased: + if use_indirect_access: - return s.execute(stmt) + def query(names): + class Foo(object): + def __init__(self): + self.u1 = aliased(User) + + f1 = Foo() + + stmt = lambda_stmt( + lambda: select(f1.u1) + .where(f1.u1.name.in_(names)) + .options(selectinload(f1.u1.addresses)), + track_on=[f1.u1], + ).add_criteria( + lambda s: s.order_by(f1.u1.id), track_on=[f1.u1] + ) + + return s.execute(stmt) + + else: + + def query(names): + u1 = aliased(User) + stmt = lambda_stmt( + lambda: select(u1) + .where(u1.name.in_(names)) + .options(selectinload(u1.addresses)) + ) + (lambda s: s.order_by(u1.id)) + + return s.execute(stmt) + + else: + + 
def query(names): + stmt = lambda_stmt( + lambda: select(User) + .where(User.name.in_(names)) + .options(selectinload(User.addresses)) + ) + (lambda s: s.order_by(User.id)) + + return s.execute(stmt) def go1(): r1 = query(["ed"]) diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py index 87589d3be1..2d01410234 100644 --- a/test/orm/test_relationship_criteria.py +++ b/test/orm/test_relationship_criteria.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import defer from sqlalchemy.orm import joinedload from sqlalchemy.orm import lazyload from sqlalchemy.orm import mapper +from sqlalchemy.orm import registry from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session @@ -84,6 +85,32 @@ class _Fixtures(_fixtures.FixtureTest): ) return HasFoob, UserWFoob + @testing.fixture + def multi_mixin_fixture(self): + orders, items = self.tables.orders, self.tables.items + order_items = self.tables.order_items + + class HasFoob(object): + description = Column(String) + + class HasBat(HasFoob): + some_nothing = Column(Integer) + + class Order(HasFoob, self.Comparable): + pass + + class Item(HasBat, self.Comparable): + pass + + base = registry() + base.map_imperatively( + Order, + orders, + properties={"items": relationship("Item", secondary=order_items)}, + ) + base.map_imperatively(Item, items) + return HasFoob, Order, Item + class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): """ @@ -598,6 +625,66 @@ class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): eq_(s.execute(stmt).scalars().all(), [UserWFoob(name=name)]) + def test_unnamed_param_dont_fail(self, multi_mixin_fixture): + HasFoob, Order, Item = multi_mixin_fixture + + def go(stmt, value): + return stmt.options( + with_loader_criteria( + HasFoob, + lambda cls: cls.description == "order 3", + include_aliases=True, + ) + ) + + with Session(testing.db) as sess: + for i in range(10): + name = random.choice(["order 1", 
"order 3", "order 5"]) + + statement = select(Order) + stmt = go(statement, name) + + eq_( + sess.execute(stmt).scalars().all(), + [Order(description="order 3")], + ) + + def test_caching_and_binds_lambda_more_mixins(self, multi_mixin_fixture): + # By including non-mapped mixin HasBat in the middle of the + # hierarchy, we test issue #5766 + HasFoob, Order, Item = multi_mixin_fixture + + def go(stmt, value): + return stmt.options( + with_loader_criteria( + HasFoob, + lambda cls: cls.description == value, + include_aliases=True, + ) + ) + + with Session(testing.db) as sess: + for i in range(10): + name = random.choice(["order 1", "order 3", "order 5"]) + + statement = select(Order) + stmt = go(statement, name) + + eq_( + sess.execute(stmt).scalars().all(), + [Order(description=name)], + ) + + name = random.choice(["item 1", "item 3", "item 5"]) + + statement = select(Item) + stmt = go(statement, name) + + eq_( + sess.execute(stmt).scalars().all(), + [Item(description=name)], + ) + def test_never_for_refresh(self, user_address_fixture): User, Address = user_address_fixture diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py index c759425647..4b486f8717 100644 --- a/test/orm/test_selectin_relations.py +++ b/test/orm/test_selectin_relations.py @@ -619,6 +619,13 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): "subqueryload": subqueryload, } + # NOTE: make sure this test continues to run many different + # combinations for the *same* mappers above; that is, don't tear the + # mappers down and build them up for every "config". This allows + # testing of the LRUCache that's associated with LazyLoader + # and SelectInLoader and how they interact with the lambda query + # API, which stores AnalyzedFunction objects in this cache. 
+ for o, i, k, count in configs: options = [] if o in callables: @@ -629,7 +636,6 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): options.append( callables[k](User.orders, Order.items, Item.keywords) ) - self._do_query_tests(options, count) def _do_mapper_test(self, configs): @@ -716,6 +722,7 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): ) self.assert_sql_count(testing.db, go, count) + return eq_( sess.query(User) diff --git a/test/profiles.txt b/test/profiles.txt index 71788209aa..d96f802f08 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -370,8 +370,6 @@ test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.8_s # TEST: test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results -test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_cextensions 233595 -test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_2.7_sqlite_pysqlite_dbapiunicode_nocextensions 251018 test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_cextensions 246134 test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results x86_64_linux_cpython_3.8_sqlite_pysqlite_dbapiunicode_nocextensions 264752 diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 6fa961e4d4..70281d4e89 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -57,6 +57,7 @@ from sqlalchemy.sql.functions import GenericFunction from sqlalchemy.sql.functions import ReturnTypeFromArgs from sqlalchemy.sql.lambdas import lambda_stmt from sqlalchemy.sql.lambdas import LambdaElement +from sqlalchemy.sql.lambdas import LambdaOptions from sqlalchemy.sql.selectable import _OffsetLimitParam from sqlalchemy.sql.selectable import AliasedReturnsRows from sqlalchemy.sql.selectable import FromGrouping @@ -859,7 +860,9 @@ 
class CoreFixtures(object): d = {"g": random.randint(40, 45)} return LambdaElement( - lambda: and_(table_a.c.b == d["g"]), roles.WhereHavingRole + lambda: and_(table_a.c.b == d["g"]), + roles.WhereHavingRole, + opts=LambdaOptions(track_closure_variables=False), ) def seven(): diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index a70dc05116..e8e4a8d2a7 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -11,6 +11,8 @@ from sqlalchemy.sql import column from sqlalchemy.sql import join from sqlalchemy.sql import lambda_stmt from sqlalchemy.sql import lambdas +from sqlalchemy.sql import literal +from sqlalchemy.sql import null from sqlalchemy.sql import roles from sqlalchemy.sql import select from sqlalchemy.sql import table @@ -27,7 +29,7 @@ from sqlalchemy.types import Integer from sqlalchemy.types import String -class DeferredLambdaTest( +class LambdaElementTest( fixtures.TestBase, testing.AssertsExecutionResults, AssertsCompiledSQL ): __dialect__ = "default" @@ -274,6 +276,75 @@ class DeferredLambdaTest( checkparams={"x_1": 10, "x_2": 15}, ) + def test_conditional_must_be_tracked(self): + tab = table("foo", column("id"), column("col")) + + def run_my_statement(parameter, add_criteria=False): + stmt = lambda_stmt(lambda: select(tab)) + + stmt = stmt.add_criteria( + lambda s: s.where(tab.c.col > parameter) + if add_criteria + else s.where(tab.c.col == parameter), + ) + + stmt += lambda s: s.order_by(tab.c.id) + + return stmt + + assert_raises_message( + exc.InvalidRequestError, + "Closure variable named 'add_criteria' inside of lambda callable", + run_my_statement, + 5, + False, + ) + + def test_boolean_conditionals(self): + + tab = table("foo", column("id"), column("col")) + + def run_my_statement(parameter, add_criteria=False): + stmt = lambda_stmt(lambda: select(tab)) + + stmt = stmt.add_criteria( + lambda s: s.where(tab.c.col > parameter) + if add_criteria + else s.where(tab.c.col == parameter), + track_on=[add_criteria], + ) + 
+ stmt += lambda s: s.order_by(tab.c.id) + + return stmt + + c1 = run_my_statement(5, False) + c2 = run_my_statement(10, True) + c3 = run_my_statement(18, False) + + ck1 = c1._generate_cache_key() + ck2 = c2._generate_cache_key() + ck3 = c3._generate_cache_key() + + eq_(ck1[0], ck3[0]) + ne_(ck1[0], ck2[0]) + + self.assert_compile( + c1, + "SELECT foo.id, foo.col FROM foo WHERE " + "foo.col = :parameter_1 ORDER BY foo.id", + ) + self.assert_compile( + c2, + "SELECT foo.id, foo.col FROM foo " + "WHERE foo.col > :parameter_1 ORDER BY foo.id", + ) + self.assert_compile( + c3, + "SELECT foo.id, foo.col FROM foo WHERE " + "foo.col = :parameter_1 ORDER BY foo.id", + ) + def test_stmt_lambda_plain_customtrack(self): c2 = column("y") @@ -487,10 +558,11 @@ class DeferredLambdaTest( self.assert_compile(s1, "SELECT x WHERE :x") - def test_stmt_lambda_w_additional_hashable_variants(self): - # note a Python 2 old style class would fail here because it - # isn't hashable. right now we do a hard check for __hash__ which - # will raise if the attr isn't present + def test_reject_plain_object(self): + # with #5765 we move to no longer allow closure variables that + # refer to unknown types of objects inside the lambda. these have + # to be resolved outside of the lambda because we otherwise can't + # be sure they can be safely used as cache keys. class Thing(object): def __init__(self, col_expr): self.col_expr = col_expr @@ -501,6 +573,83 @@ class DeferredLambdaTest( return stmt + c1 = Thing(column("x")) + assert_raises_message( + exc.InvalidRequestError, + "Closure variable named 'thing' inside of lambda callable", + go, + c1, + 5, + ) + + def test_plain_object_ok_w_tracking_disabled(self): + # with #5765 we move to no longer allow closure variables that + # refer to unknown types of objects inside the lambda. these have + # to be resolved outside of the lambda because we otherwise can't + # be sure they can be safely used as cache keys. 
+ class Thing(object): + def __init__(self, col_expr): + self.col_expr = col_expr + + def go(thing, q): + stmt = lambdas.lambda_stmt( + lambda: select(thing.col_expr), track_closure_variables=False + ) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(thing.col_expr == q), + track_closure_variables=False, + ) + + return stmt + + c1 = Thing(column("x")) + c2 = Thing(column("y")) + + s1 = go(c1, 5) + s2 = go(c2, 10) + s3 = go(c1, 8) + s4 = go(c2, 12) + + self.assert_compile( + s1, "SELECT x WHERE x = :q_1", checkparams={"q_1": 5} + ) + # note this is wrong, because no tracking + self.assert_compile( + s2, "SELECT x WHERE x = :q_1", checkparams={"q_1": 10} + ) + self.assert_compile( + s3, "SELECT x WHERE x = :q_1", checkparams={"q_1": 8} + ) + # also wrong + self.assert_compile( + s4, "SELECT x WHERE x = :q_1", checkparams={"q_1": 12} + ) + + s1key = s1._generate_cache_key() + s2key = s2._generate_cache_key() + s3key = s3._generate_cache_key() + s4key = s4._generate_cache_key() + + # all one cache key + eq_(s1key[0], s3key[0]) + eq_(s2key[0], s4key[0]) + eq_(s1key[0], s2key[0]) + + def test_plain_object_used_outside_lambda(self): + # test the above 'test_reject_plain_object' with the expected + # workaround + + class Thing(object): + def __init__(self, col_expr): + self.col_expr = col_expr + + def go(thing, q): + col_expr = thing.col_expr + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt += lambda stmt: stmt.where(col_expr == q) + + return stmt + c1 = Thing(column("x")) c2 = Thing(column("y")) @@ -538,13 +687,92 @@ class DeferredLambdaTest( opts = {column("x"), column("y")} assert_raises_message( - exc.ArgumentError, - 'Can\'t create a cache key for lambda closure variable "opts" ' - "because it's a set. 
try using a list", + exc.InvalidRequestError, + "Closure variable named 'opts' inside of lambda callable ", stmt.__add__, lambda stmt: stmt.options(*opts), ) + def test_detect_embedded_callables_one(self): + t1 = table("t1", column("q")) + + x = 1 + + def go(): + def foo(): + return x + + stmt = select(t1).where(lambda: t1.c.q == foo()) + return stmt + + assert_raises_message( + exc.InvalidRequestError, + r"Can't invoke Python callable foo\(\) inside of lambda " + "expression ", + go, + ) + + def test_detect_embedded_callables_two(self): + t1 = table("t1", column("q"), column("y")) + + def go(): + def foo(): + return t1.c.y + + stmt = select(t1).where(lambda: t1.c.q == foo()) + return stmt + + self.assert_compile( + go(), "SELECT t1.q, t1.y FROM t1 WHERE t1.q = t1.y" + ) + + def test_detect_embedded_callables_three(self): + t1 = table("t1", column("q"), column("y")) + + def go(): + def foo(): + t1.c.y + + stmt = select(t1).where(lambda: t1.c.q == getattr(t1.c, "y")) + return stmt + + self.assert_compile( + go(), "SELECT t1.q, t1.y FROM t1 WHERE t1.q = t1.y" + ) + + def test_detect_embedded_callables_four(self): + t1 = table("t1", column("q")) + + x = 1 + + def go(): + def foo(): + return x + + stmt = select(t1).where( + lambdas.LambdaElement( + lambda: t1.c.q == foo(), + roles.WhereHavingRole, + lambdas.LambdaOptions(track_bound_values=False), + ) + ) + return stmt + + self.assert_compile( + go(), + "SELECT t1.q FROM t1 WHERE t1.q = :q_1", + checkparams={"q_1": 1}, + ) + + # we're not tracking it + x = 2 + + self.assert_compile( + go(), + "SELECT t1.q FROM t1 WHERE t1.q = :q_1", + checkparams={"q_1": 1}, + ) + def test_stmt_lambda_w_list_of_opts(self): def go(opts): stmt = lambdas.lambda_stmt(lambda: select(column("x"))) @@ -755,6 +983,23 @@ class DeferredLambdaTest( }, ) + def test_in_columnelement(self): + # test issue #5768 + + def go(): + v = [literal("a"), literal("b")] + expr1 = select(1).where(lambda: column("q").in_(v)) + return expr1 + + 
self.assert_compile(go(), "SELECT 1 WHERE q IN (:param_1, :param_2)") + + self.assert_compile( + go(), + "SELECT 1 WHERE q IN (:param_1, :param_2)", + render_postcompile=True, + checkparams={"param_1": "a", "param_2": "b"}, + ) + def test_select_columns_clause(self): t1 = table("t1", column("q"), column("p")) @@ -854,14 +1099,28 @@ class DeferredLambdaTest( expr, ) - def test_dict_literal_keys(self, user_address_fixture): + def test_reject_dict_literal_keys(self, user_address_fixture): users, addresses = user_address_fixture names = {"x": "some name"} lmb = lambda: users.c.name == names["x"] # noqa - expr = coercions.expect(roles.WhereHavingRole, lmb) + assert_raises_message( + exc.InvalidRequestError, + "Closure variable named 'names' inside of lambda callable", + coercions.expect, + roles.WhereHavingRole, + lmb, + ) + + def test_dict_literal_keys_proper_use(self, user_address_fixture): + users, addresses = user_address_fixture + names = {"x": "some name"} + x = names["x"] + lmb = lambda: users.c.name == x # noqa + + expr = coercions.expect(roles.WhereHavingRole, lmb) self.assert_compile( expr, "users.name = :x_1", @@ -1158,7 +1417,7 @@ class DeferredLambdaTest( ), ) - def test_cache_key_thing(self): + def test_cache_key_bindparam_matches(self): t1 = table("t1", column("q"), column("p")) def go(x): @@ -1169,3 +1428,294 @@ class DeferredLambdaTest( is_(expr1._generate_cache_key().bindparams[0], expr1._resolved.right) is_(expr2._generate_cache_key().bindparams[0], expr2._resolved.right) + + def test_cache_key_instance_variable_issue_incorrect(self): + t1 = table("t1", column("q"), column("p")) + + class Foo(object): + def __init__(self, value): + self.value = value + + def go(foo): + return coercions.expect( + roles.WhereHavingRole, lambda: t1.c.q == foo.value + ) + + assert_raises_message( + exc.InvalidRequestError, + "Closure variable named 'foo' inside of lambda callable", + go, + Foo(5), + ) + + def test_cache_key_instance_variable_issue_correct_one(self): + t1 
= table("t1", column("q"), column("p")) + + class Foo(object): + def __init__(self, value): + self.value = value + + def go(foo): + value = foo.value + return coercions.expect( + roles.WhereHavingRole, lambda: t1.c.q == value + ) + + expr1 = go(Foo(5)) + expr2 = go(Foo(10)) + + c1 = expr1._generate_cache_key() + c2 = expr2._generate_cache_key() + eq_(c1, c2) + + def test_cache_key_instance_variable_issue_correct_two(self): + t1 = table("t1", column("q"), column("p")) + + class Foo(object): + def __init__(self, value): + self.value = value + + def go(foo): + return coercions.expect( + roles.WhereHavingRole, + lambda: t1.c.q == foo.value, + track_on=[self], + ) + + expr1 = go(Foo(5)) + expr2 = go(Foo(10)) + + c1 = expr1._generate_cache_key() + c2 = expr2._generate_cache_key() + eq_(c1, c2) + + def test_insert_statement(self, user_address_fixture): + users, addresses = user_address_fixture + + def ins(id_, name): + stmt = lambda_stmt(lambda: users.insert()) + stmt += lambda s: s.values(id=id_, name=name) + return stmt + + with testing.db.begin() as conn: + conn.execute(ins(12, "foo")) + + eq_( + conn.execute(select(users).where(users.c.id == 12)).first(), + (12, "foo"), + ) + + def test_update_statement(self, user_address_fixture): + users, addresses = user_address_fixture + + def upd(id_, newname): + stmt = lambda_stmt(lambda: users.update()) + stmt += lambda s: s.values(name=newname) + stmt += lambda s: s.where(users.c.id == id_) + return stmt + + with testing.db.begin() as conn: + conn.execute(users.insert().values(id=7, name="bar")) + conn.execute(upd(7, "foo")) + + eq_( + conn.execute(select(users).where(users.c.id == 7)).first(), + (7, "foo"), + ) + + +class DeferredLambdaElementTest( + fixtures.TestBase, testing.AssertsExecutionResults, AssertsCompiledSQL +): + __dialect__ = "default" + + @testing.fails("wontfix issue #5767") + def test_detect_change_in_binds_no_tracking(self): + t1 = table("t1", column("q"), column("p")) + t2 = table("t2", column("q"), 
column("p")) + + vv = [1, 2, 3] + # lambda produces either "t1 IN vv" or "NULL" based on the + # argument. will not produce a consistent cache key + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q.in_(vv) if tab.name == "t2" else null(), + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions(track_closure_variables=False), + ) + + self.assert_compile(elem.expr, "NULL") + + assert_raises_message( + exc.InvalidRequestError, + r"Lambda callable at %s produced " + "a different set of bound parameters " + "than its original run: vv" % (elem.fn.__code__), + elem._resolve_with_args, + t2, + ) + + def test_detect_change_in_binds_tracking_positive(self): + t1 = table("t1", column("q"), column("p")) + + vv = [1, 2, 3] + + # lambda produces either "t1 IN vv" or "NULL" based on the + # argument. will not produce a consistent cache key + assert_raises_message( + exc.InvalidRequestError, + "Closure variable named 'vv' inside of lambda callable", + lambdas.DeferredLambdaElement, + lambda tab: tab.c.q.in_(vv) if tab.name == "t2" else None, + roles.WhereHavingRole, + opts=lambdas.LambdaOptions, + lambda_args=(t1,), + ) + + @testing.fails("wontfix issue #5767") + def test_detect_change_in_binds_tracking_negative(self): + t1 = table("t1", column("q"), column("p")) + t2 = table("t2", column("q"), column("p")) + + vv = [1, 2, 3] + qq = [3, 4, 5] + + # lambda produces either "t1 IN vv" or "t2 IN qq" based on the + # argument. 
will not produce a consistent cache key + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q.in_(vv) + if tab.name == "t1" + else tab.c.q.in_(qq), + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions(track_closure_variables=False), + ) + + self.assert_compile(elem.expr, "t1.q IN ([POSTCOMPILE_vv_1])") + + assert_raises_message( + exc.InvalidRequestError, + r"Lambda callable at %s produced " + "a different set of bound parameters " + "than its original run: qq" % (elem.fn.__code__), + elem._resolve_with_args, + t2, + ) + + def _fixture_one(self, t1): + vv = [1, 2, 3] + + def go(): + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q.in_(vv), + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + def _fixture_two(self, t1): + def go(): + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q == "x", + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + def _fixture_three(self, t1): + def go(): + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q != "x", + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + def _fixture_four(self, t1): + def go(): + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q.in_([1, 2, 3]), + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + def _fixture_five(self, t1): + def go(): + x = "x" + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q == x, + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + def _fixture_six(self, t1): + def go(): + x = "x" + elem = lambdas.DeferredLambdaElement( + lambda tab: tab.c.q != x, + roles.WhereHavingRole, + lambda_args=(t1,), + opts=lambdas.LambdaOptions, + ) + return elem + + return go + + @testing.combinations( + ("_fixture_one",), + 
("_fixture_two",), + ("_fixture_three",), + ("_fixture_four",), + ("_fixture_five",), + ("_fixture_six",), + ) + def test_cache_key_many_different_args(self, fixture_name): + t1 = table("t1", column("q"), column("p")) + t2 = table("t2", column("q"), column("p")) + t3 = table("t3", column("q"), column("p")) + + go = getattr(self, fixture_name)(t1) + + g1 = go() + g2 = go() + + g1key = g1._generate_cache_key() + g2key = g2._generate_cache_key() + eq_(g1key[0], g2key[0]) + + e1 = go()._resolve_with_args(t1) + e2 = go()._resolve_with_args(t2) + e3 = go()._resolve_with_args(t3) + + e1key = e1._generate_cache_key() + e2key = e2._generate_cache_key() + e3key = e3._generate_cache_key() + + e12 = go()._resolve_with_args(t1) + e32 = go()._resolve_with_args(t3) + + e12key = e12._generate_cache_key() + e32key = e32._generate_cache_key() + + ne_(e1key[0], e2key[0]) + ne_(e2key[0], e3key[0]) + + eq_(e12key[0], e1key[0]) + eq_(e32key[0], e3key[0])
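The tests above exercise the stricter closure rule from the commit message: a closure variable that is a plain object (a `Thing`, a `Foo`, a dict, a set) is rejected, and the expected workaround is to resolve the attribute outside the lambda. What follows is a minimal standalone sketch of that idea in plain Python. It is not SQLAlchemy's implementation; `classify_closure`, `ClosureRejected` and the `_SCALARS` tuple are hypothetical names invented for illustration, and the real system accepts a different set of types (objects resolvable to HasCacheKey, or FunctionType).

```python
import types


class ClosureRejected(Exception):
    """Hypothetical error for this sketch; SQLAlchemy raises
    InvalidRequestError in the tests above."""


# Assumption for this sketch: plain scalars stand in for values that
# would be extracted as bound parameters rather than keyed on.
_SCALARS = (int, float, str, bytes, bool, type(None))


def classify_closure(fn, track_closure_variables=True):
    """Split fn's closure variables into cache key parts and bound values.

    Returns (key_parts, bound_values). The key always starts with the
    function's code object, so two lambdas created from the same source
    line share a key regardless of the scalar values they close over.
    """
    key_parts = [fn.__code__]
    bound_values = {}

    if not track_closure_variables or fn.__closure__ is None:
        # Tracking disabled: nothing is inspected, nothing is rejected.
        return tuple(key_parts), bound_values

    # co_freevars and __closure__ are aligned position by position.
    for name, cell in zip(fn.__code__.co_freevars, fn.__closure__):
        value = cell.cell_contents
        if isinstance(value, _SCALARS):
            bound_values[name] = value
        elif isinstance(value, types.FunctionType):
            key_parts.append(value.__code__)
        else:
            raise ClosureRejected(
                "Closure variable named %r inside of lambda callable "
                "is of an unsupported type %s"
                % (name, type(value).__name__)
            )
    return tuple(key_parts), bound_values


class Thing:
    def __init__(self, value):
        self.value = value


def rejected(thing):
    # The lambda closes over the whole object, so it is rejected.
    return lambda: thing.value


def accepted(thing):
    # Workaround: resolve the attribute outside the lambda.
    value = thing.value
    return lambda: value
```

In this sketch, `accepted(Thing(5))` and `accepted(Thing(10))` yield the same key with different bound values, which mirrors what test_cache_key_instance_variable_issue_correct_one asserts with `eq_(c1, c2)`. Passing `track_closure_variables=False` skips inspection entirely, which is why test_plain_object_ok_w_tracking_disabled notes that the resulting SQL can be wrong.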