From: Mike Bayer Date: Sat, 12 Dec 2020 00:01:12 +0000 (-0500) Subject: Fixes for lambda expressions and relationship loaders X-Git-Tag: rel_1_4_0b2~104^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ed20e2f95f52a072d0c6b09af095b4cda0436d38;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fixes for lambda expressions and relationship loaders Fixed bug in lambda SQL feature, used by ORM :meth:`_orm.with_loader_criteria` as well as available generally in the SQL expression language, where assigning a boolean value True/False to a variable would cause the query-time expression calculation to fail, as it would produce a SQL expression not compatible with a bound value. Fixed issue where the :attr:`_orm.ORMExecuteState.is_relationship_load` parameter would not be set correctly for many lazy loads, all selectinloads, etc. The flag is essential in order to test if options should be added to statements or if they would already have been propagated via relationship loads. Fixes: #5763 Fixes: #5764 Change-Id: I66aafbef193f892ff75ede0670698647b7475482 --- diff --git a/doc/build/changelog/unreleased_14/5763.rst b/doc/build/changelog/unreleased_14/5763.rst new file mode 100644 index 0000000000..e395b6fcfe --- /dev/null +++ b/doc/build/changelog/unreleased_14/5763.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, orm + :tickets: 5763 + + Fixed bug in lambda SQL feature, used by ORM + :meth:`_orm.with_loader_criteria` as well as available generally in the SQL + expression language, where assigning a boolean value True/False to a + variable would cause the query-time expression calculation to fail, as it + would produce a SQL expression not compatible with a bound value. \ No newline at end of file diff --git a/doc/build/changelog/unreleased_14/5764.rst b/doc/build/changelog/unreleased_14/5764.rst new file mode 100644 index 0000000000..29753fafe6 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5764.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: orm, bug + :tickets: 5764 + + Fixed issue where the :attr:`_orm.ORMExecuteState.is_relationship_load` + attribute would not be set correctly for many lazy loads, all + selectinloads, etc. The flag is essential in order to test if options + should be added to statements or if they would already have been propagated + via relationship loads. \ No newline at end of file diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index d7a2cb4092..334283bb96 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1296,7 +1296,6 @@ class Query( self._set_select_from([fromclause], set_entity_from) self._compile_options += { "_enable_single_crit": False, - "_statement": None, } # this enables clause adaptation for non-ORM @@ -2620,7 +2619,6 @@ class Query( roles.SelectStatementRole, statement, apply_propagate_attrs=self ) self._statement = statement - self._compile_options += {"_statement": statement} def first(self): """Return the first result of this ``Query`` or diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index f6943cc5f1..7b5fa2c733 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -275,7 +275,7 @@ class ORMExecuteState(util.MemoizedSlots): if not self.is_select: return None opts = self.statement._compile_options - if isinstance(opts, context.ORMCompileState.default_compile_options): + if opts.isinstance(context.ORMCompileState.default_compile_options): return opts else: return None diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 7f7bab6825..98c57149d3 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -939,9 +939,14 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots): ) stmt += lambda stmt: stmt.options(*opts) - stmt += lambda stmt: stmt._update_compile_options( - {"_current_path": effective_path} - ) + else: + # this path is used if there are not already any options + # in the query, but an event may want to add them + effective_path = state.mapper._path_registry[self.parent_property] + + stmt += lambda stmt: stmt._update_compile_options( + {"_current_path": effective_path} + ) if use_get: if self._raise_on_sql: @@ -2732,6 +2737,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): orm_util.Bundle("pk", *pk_cols), effective_entity ) .apply_labels() + ._set_compile_options(ORMCompileState.default_compile_options) ._set_propagate_attrs( { "compile_state_plugin": "orm", @@ -2769,7 +2775,6 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): q = q.add_criteria( lambda q: q.filter(in_expr.in_(sql.bindparam("primary_keys"))) ) - # a test which exercises what these comments talk about is # test_selectin_relations.py -> test_twolevel_selectin_w_polymorphic # diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index ff44ab27c9..5178a7ab13 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -553,6 +553,14 @@ class _MetaOptions(type): def __add__(self, other): o1 = self() + + if set(other).difference(self._cache_attrs): + raise TypeError( + "dictionary contains attributes not covered by " + "Options class %s: %r" + % (self, set(other).difference(self._cache_attrs)) + ) + o1.__dict__.update(other) return o1 @@ -566,6 +574,14 @@ class Options(util.with_metaclass(_MetaOptions)): def __add__(self, other): o1 = self.__class__.__new__(self.__class__) o1.__dict__.update(self.__dict__) + + if set(other).difference(self._cache_attrs): + raise TypeError( + "dictionary contains attributes not covered by " + "Options class %s: %r" + % (self, set(other).difference(self._cache_attrs)) + ) + o1.__dict__.update(other) return o1 @@ -589,6 +605,10 @@ class Options(util.with_metaclass(_MetaOptions)): ), ) + @classmethod + def isinstance(cls, klass): + return issubclass(cls, klass) + @hybridmethod def add_to_element(self, name, value): return self + {name: getattr(self, name) + value} diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 676152781d..aafdda4ce1 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -1021,7 +1021,12 @@ class PyWrapper(ColumnOperators): def __getattribute__(self, key): if key.startswith("_sa_"): return object.__getattribute__(self, key[4:]) - elif key in ("__clause_element__", "operate", "reverse_operate"): + elif key in ( + "__clause_element__", + "operate", + "reverse_operate", + "__class__", + ): return object.__getattribute__(self, key) if key.startswith("__"): diff --git a/test/orm/test_events.py b/test/orm/test_events.py index bc72d2f213..a046ba34c7 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -21,8 +21,10 @@ from sqlalchemy.orm import Mapper from sqlalchemy.orm import mapper from sqlalchemy.orm import query from sqlalchemy.orm import relationship +from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import subqueryload from sqlalchemy.orm.mapper import _mapper_registry from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message @@ -168,14 +170,10 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): }, ) - def test_flags(self): - User, Address = self.classes("User", "Address") - - sess = Session(testing.db, future=True) - + def _flag_fixture(self, session): canary = Mock() - @event.listens_for(sess, "do_orm_execute") + @event.listens_for(session, "do_orm_execute") def do_orm_execute(ctx): if not ctx.is_select: @@ -197,17 +195,21 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): else None, ) - u1 = sess.execute(select(User).filter_by(id=7)).scalar_one() + return canary - u1.addresses + def test_select_flags(self): + User, Address = self.classes("User", "Address") + + sess = Session(testing.db, future=True) + + canary = self._flag_fixture(sess) + + u1 = sess.execute(select(User).filter_by(id=7)).scalar_one() sess.expire(u1) eq_(u1.name, "jack") - sess.execute(delete(User).filter_by(id=18)) - sess.execute(update(User).filter_by(id=18).values(name="eighteen")) - eq_( canary.mock_calls, [ @@ -220,6 +222,32 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): is_column_load=False, lazy_loaded_from=None, ), + call.options( + is_select=True, + is_update=False, + is_delete=False, + is_orm_statement=True, + is_relationship_load=False, + is_column_load=True, + lazy_loaded_from=None, + ), + ], + ) + + def test_lazyload_flags(self): + User, Address = self.classes("User", "Address") + + sess = Session(testing.db, future=True) + + canary = self._flag_fixture(sess) + + u1 = sess.execute(select(User).filter_by(id=7)).scalar_one() + + u1.addresses + + eq_( + canary.mock_calls, + [ call.options( is_select=True, is_update=False, @@ -227,17 +255,107 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): is_orm_statement=True, is_relationship_load=False, is_column_load=False, + lazy_loaded_from=None, + ), + call.options( + is_select=True, + is_update=False, + is_delete=False, + is_orm_statement=True, + is_relationship_load=True, + is_column_load=False, lazy_loaded_from=u1._sa_instance_state, ), + ], + ) + + def test_selectinload_flags(self): + User, Address = self.classes("User", "Address") + + sess = Session(testing.db, future=True) + + canary = self._flag_fixture(sess) + + u1 = sess.execute( + select(User).filter_by(id=7).options(selectinload(User.addresses)) + ).scalar_one() + + assert "addresses" in u1.__dict__ + + eq_( + canary.mock_calls, + [ + call.options( + is_select=True, + is_update=False, + is_delete=False, + is_orm_statement=True, + is_relationship_load=False, + is_column_load=False, + lazy_loaded_from=None, + ), + call.options( + is_select=True, + is_update=False, + is_delete=False, + is_orm_statement=True, + is_relationship_load=True, + is_column_load=False, + lazy_loaded_from=None, + ), + ], + ) + + def test_subqueryload_flags(self): + User, Address = self.classes("User", "Address") + + sess = Session(testing.db, future=True) + + canary = self._flag_fixture(sess) + + u1 = sess.execute( + select(User).filter_by(id=7).options(subqueryload(User.addresses)) + ).scalar_one() + + assert "addresses" in u1.__dict__ + + eq_( + canary.mock_calls, + [ call.options( is_select=True, is_update=False, is_delete=False, is_orm_statement=True, is_relationship_load=False, - is_column_load=True, + is_column_load=False, + lazy_loaded_from=None, + ), + call.options( + is_select=True, + is_update=False, + is_delete=False, + is_orm_statement=True, + is_relationship_load=True, + is_column_load=False, lazy_loaded_from=None, ), + ], + ) + + def test_update_delete_flags(self): + User, Address = self.classes("User", "Address") + + sess = Session(testing.db, future=True) + + canary = self._flag_fixture(sess) + + sess.execute(delete(User).filter_by(id=18)) + sess.execute(update(User).filter_by(id=18).values(name="eighteen")) + + eq_( + canary.mock_calls, + [ call.options( is_select=False, is_update=False, diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index c283e804e5..a70dc05116 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -22,6 +22,7 @@ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import ne_ from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.types import Boolean from sqlalchemy.types import Integer from sqlalchemy.types import String @@ -77,6 +78,41 @@ class DeferredLambdaTest( checkparams={"global_x_1": 10, "global_y_1": 9}, ) + def test_boolean_constants(self): + t1 = table("t1", column("q"), column("p")) + + def go(): + xy = True + stmt = select(t1).where(lambda: t1.c.q == xy) + return stmt + + self.assert_compile( + go(), "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :xy_1" + ) + + def test_execute_boolean(self, boolean_table_fixture, connection): + boolean_data = boolean_table_fixture + + connection.execute( + boolean_data.insert(), + [{"id": 1, "data": True}, {"id": 2, "data": False}], + ) + + xy = True + + def go(): + stmt = select(lambda: boolean_data.c.id).where( + lambda: boolean_data.c.data == xy + ) + return connection.execute(stmt) + + result = go() + eq_(result.all(), [(1,)]) + + xy = False + result = go() + eq_(result.all(), [(2,)]) + def test_stale_checker_embedded(self): def go(x): @@ -761,6 +797,15 @@ class DeferredLambdaTest( ) return users, addresses + @testing.metadata_fixture() + def boolean_table_fixture(self, metadata): + return Table( + "boolean_data", + metadata, + Column("id", Integer, primary_key=True), + Column("data", Boolean), + ) + def test_adapt_select(self, user_address_fixture): users, addresses = user_address_fixture diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index a4b76f35d0..24a149ece6 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -15,6 +15,7 @@ from sqlalchemy.sql import util as sql_util from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures @@ -57,6 +58,34 @@ class MiscTest(fixtures.TestBase): {common, calias, subset_select}, ) + def test_incompatible_options_add_clslevel(self): + class opt1(sql_base.CacheableOptions): + _cache_key_traversal = [] + foo = "bar" + + with expect_raises_message( + TypeError, + "dictionary contains attributes not covered by " + "Options class .*opt1.* .*'bar'.*", + ): + o1 = opt1 + + o1 += {"foo": "f", "bar": "b"} + + def test_incompatible_options_add_instancelevel(self): + class opt1(sql_base.CacheableOptions): + _cache_key_traversal = [] + foo = "bar" + + o1 = opt1(foo="bat") + + with expect_raises_message( + TypeError, + "dictionary contains attributes not covered by " + "Options class .*opt1.* .*'bar'.*", + ): + o1 += {"foo": "f", "bar": "b"} + def test_options_merge(self): class opt1(sql_base.CacheableOptions): _cache_key_traversal = []