]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fixes for lambda expressions and relationship loaders
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 12 Dec 2020 00:01:12 +0000 (19:01 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 12 Dec 2020 03:01:57 +0000 (22:01 -0500)
Fixed bug in lambda SQL feature, used by ORM
:meth:`_orm.with_loader_criteria` as well as available generally in the SQL
expression language, where assigning a boolean value True/False to a
variable would cause the query-time expression calculation to fail, as it
would produce a SQL expression not compatible with a bound value.

Fixed issue where the :attr:`_orm.ORMExecuteState.is_relationship_load`
parameter would not be set correctly for many lazy loads, all
selectinloads, etc.  The flag is essential in order to test if options
should be added to statements or if they would already have been propagated
via relationship loads.

Fixes: #5763
Fixes: #5764
Change-Id: I66aafbef193f892ff75ede0670698647b7475482

doc/build/changelog/unreleased_14/5763.rst [new file with mode: 0644]
doc/build/changelog/unreleased_14/5764.rst [new file with mode: 0644]
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/lambdas.py
test/orm/test_events.py
test/sql/test_lambdas.py
test/sql/test_utils.py

diff --git a/doc/build/changelog/unreleased_14/5763.rst b/doc/build/changelog/unreleased_14/5763.rst
new file mode 100644 (file)
index 0000000..e395b6f
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+       :tags: bug, orm
+       :tickets: 5763
+
+       Fixed bug in lambda SQL feature, used by ORM
+       :meth:`_orm.with_loader_criteria` as well as available generally in the SQL
+       expression language, where assigning a boolean value True/False to a
+       variable would cause the query-time expression calculation to fail, as it
+       would produce a SQL expression not compatible with a bound value.
\ No newline at end of file
diff --git a/doc/build/changelog/unreleased_14/5764.rst b/doc/build/changelog/unreleased_14/5764.rst
new file mode 100644 (file)
index 0000000..29753fa
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+       :tags: orm, bug
+       :tickets: 5764
+
+       Fixed issue where the :attr:`_orm.ORMExecuteState.is_relationship_load`
+       attribute would not be set correctly for many lazy loads, all
+       selectinloads, etc.  The flag is essential in order to test if options
+       should be added to statements or if they would already have been propagated
+       via relationship loads.
\ No newline at end of file
index d7a2cb4092eac02aebd6188ffe10d9a79a24aacc..334283bb965c31d8088a4bae33da2a6634f30ef4 100644 (file)
@@ -1296,7 +1296,6 @@ class Query(
         self._set_select_from([fromclause], set_entity_from)
         self._compile_options += {
             "_enable_single_crit": False,
-            "_statement": None,
         }
 
         # this enables clause adaptation for non-ORM
@@ -2620,7 +2619,6 @@ class Query(
             roles.SelectStatementRole, statement, apply_propagate_attrs=self
         )
         self._statement = statement
-        self._compile_options += {"_statement": statement}
 
     def first(self):
         """Return the first result of this ``Query`` or
index f6943cc5f1c6bd5ca289dba6b43d16e0ae300a6b..7b5fa2c733c7ba98852292eb29b0228e7fd9184a 100644 (file)
@@ -275,7 +275,7 @@ class ORMExecuteState(util.MemoizedSlots):
         if not self.is_select:
             return None
         opts = self.statement._compile_options
-        if isinstance(opts, context.ORMCompileState.default_compile_options):
+        if opts.isinstance(context.ORMCompileState.default_compile_options):
             return opts
         else:
             return None
index 7f7bab68255d7260d497c27c787b206ab6586052..98c57149d341dd2cfbb5c225f7e671475ae0ac43 100644 (file)
@@ -939,9 +939,14 @@ class LazyLoader(AbstractRelationshipLoader, util.MemoizedSlots):
                 )
 
             stmt += lambda stmt: stmt.options(*opts)
-            stmt += lambda stmt: stmt._update_compile_options(
-                {"_current_path": effective_path}
-            )
+        else:
+            # this path is used if there are not already any options
+            # in the query, but an event may want to add them
+            effective_path = state.mapper._path_registry[self.parent_property]
+
+        stmt += lambda stmt: stmt._update_compile_options(
+            {"_current_path": effective_path}
+        )
 
         if use_get:
             if self._raise_on_sql:
@@ -2732,6 +2737,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
                 orm_util.Bundle("pk", *pk_cols), effective_entity
             )
             .apply_labels()
+            ._set_compile_options(ORMCompileState.default_compile_options)
             ._set_propagate_attrs(
                 {
                     "compile_state_plugin": "orm",
@@ -2769,7 +2775,6 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
         q = q.add_criteria(
             lambda q: q.filter(in_expr.in_(sql.bindparam("primary_keys")))
         )
-
         # a test which exercises what these comments talk about is
         # test_selectin_relations.py -> test_twolevel_selectin_w_polymorphic
         #
index ff44ab27c908519408da9b8a3866fca79f652b6c..5178a7ab1393aada9a233f73060c98ce15780788 100644 (file)
@@ -553,6 +553,14 @@ class _MetaOptions(type):
 
     def __add__(self, other):
         o1 = self()
+
+        if set(other).difference(self._cache_attrs):
+            raise TypeError(
+                "dictionary contains attributes not covered by "
+                "Options class %s: %r"
+                % (self, set(other).difference(self._cache_attrs))
+            )
+
         o1.__dict__.update(other)
         return o1
 
@@ -566,6 +574,14 @@ class Options(util.with_metaclass(_MetaOptions)):
     def __add__(self, other):
         o1 = self.__class__.__new__(self.__class__)
         o1.__dict__.update(self.__dict__)
+
+        if set(other).difference(self._cache_attrs):
+            raise TypeError(
+                "dictionary contains attributes not covered by "
+                "Options class %s: %r"
+                % (self, set(other).difference(self._cache_attrs))
+            )
+
         o1.__dict__.update(other)
         return o1
 
@@ -589,6 +605,10 @@ class Options(util.with_metaclass(_MetaOptions)):
             ),
         )
 
+    @classmethod
+    def isinstance(cls, klass):
+        return issubclass(cls, klass)
+
     @hybridmethod
     def add_to_element(self, name, value):
         return self + {name: getattr(self, name) + value}
index 676152781daecb504e3942f922e54e89b06fe994..aafdda4ce14fe27cf8292399abb680b7e362d870 100644 (file)
@@ -1021,7 +1021,12 @@ class PyWrapper(ColumnOperators):
     def __getattribute__(self, key):
         if key.startswith("_sa_"):
             return object.__getattribute__(self, key[4:])
-        elif key in ("__clause_element__", "operate", "reverse_operate"):
+        elif key in (
+            "__clause_element__",
+            "operate",
+            "reverse_operate",
+            "__class__",
+        ):
             return object.__getattribute__(self, key)
 
         if key.startswith("__"):
index bc72d2f2131b4a658d6af59a94da02ffc940bc8e..a046ba34c7629972ba44ddb61d32d9b3ce3ec2bb 100644 (file)
@@ -21,8 +21,10 @@ from sqlalchemy.orm import Mapper
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import query
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import subqueryload
 from sqlalchemy.orm.mapper import _mapper_registry
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
@@ -168,14 +170,10 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
             },
         )
 
-    def test_flags(self):
-        User, Address = self.classes("User", "Address")
-
-        sess = Session(testing.db, future=True)
-
+    def _flag_fixture(self, session):
         canary = Mock()
 
-        @event.listens_for(sess, "do_orm_execute")
+        @event.listens_for(session, "do_orm_execute")
         def do_orm_execute(ctx):
 
             if not ctx.is_select:
@@ -197,17 +195,21 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
                 else None,
             )
 
-        u1 = sess.execute(select(User).filter_by(id=7)).scalar_one()
+        return canary
 
-        u1.addresses
+    def test_select_flags(self):
+        User, Address = self.classes("User", "Address")
+
+        sess = Session(testing.db, future=True)
+
+        canary = self._flag_fixture(sess)
+
+        u1 = sess.execute(select(User).filter_by(id=7)).scalar_one()
 
         sess.expire(u1)
 
         eq_(u1.name, "jack")
 
-        sess.execute(delete(User).filter_by(id=18))
-        sess.execute(update(User).filter_by(id=18).values(name="eighteen"))
-
         eq_(
             canary.mock_calls,
             [
@@ -220,6 +222,32 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
                     is_column_load=False,
                     lazy_loaded_from=None,
                 ),
+                call.options(
+                    is_select=True,
+                    is_update=False,
+                    is_delete=False,
+                    is_orm_statement=True,
+                    is_relationship_load=False,
+                    is_column_load=True,
+                    lazy_loaded_from=None,
+                ),
+            ],
+        )
+
+    def test_lazyload_flags(self):
+        User, Address = self.classes("User", "Address")
+
+        sess = Session(testing.db, future=True)
+
+        canary = self._flag_fixture(sess)
+
+        u1 = sess.execute(select(User).filter_by(id=7)).scalar_one()
+
+        u1.addresses
+
+        eq_(
+            canary.mock_calls,
+            [
                 call.options(
                     is_select=True,
                     is_update=False,
@@ -227,17 +255,107 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
                     is_orm_statement=True,
                     is_relationship_load=False,
                     is_column_load=False,
+                    lazy_loaded_from=None,
+                ),
+                call.options(
+                    is_select=True,
+                    is_update=False,
+                    is_delete=False,
+                    is_orm_statement=True,
+                    is_relationship_load=True,
+                    is_column_load=False,
                     lazy_loaded_from=u1._sa_instance_state,
                 ),
+            ],
+        )
+
+    def test_selectinload_flags(self):
+        User, Address = self.classes("User", "Address")
+
+        sess = Session(testing.db, future=True)
+
+        canary = self._flag_fixture(sess)
+
+        u1 = sess.execute(
+            select(User).filter_by(id=7).options(selectinload(User.addresses))
+        ).scalar_one()
+
+        assert "addresses" in u1.__dict__
+
+        eq_(
+            canary.mock_calls,
+            [
+                call.options(
+                    is_select=True,
+                    is_update=False,
+                    is_delete=False,
+                    is_orm_statement=True,
+                    is_relationship_load=False,
+                    is_column_load=False,
+                    lazy_loaded_from=None,
+                ),
+                call.options(
+                    is_select=True,
+                    is_update=False,
+                    is_delete=False,
+                    is_orm_statement=True,
+                    is_relationship_load=True,
+                    is_column_load=False,
+                    lazy_loaded_from=None,
+                ),
+            ],
+        )
+
+    def test_subqueryload_flags(self):
+        User, Address = self.classes("User", "Address")
+
+        sess = Session(testing.db, future=True)
+
+        canary = self._flag_fixture(sess)
+
+        u1 = sess.execute(
+            select(User).filter_by(id=7).options(subqueryload(User.addresses))
+        ).scalar_one()
+
+        assert "addresses" in u1.__dict__
+
+        eq_(
+            canary.mock_calls,
+            [
                 call.options(
                     is_select=True,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=True,
                     is_relationship_load=False,
-                    is_column_load=True,
+                    is_column_load=False,
+                    lazy_loaded_from=None,
+                ),
+                call.options(
+                    is_select=True,
+                    is_update=False,
+                    is_delete=False,
+                    is_orm_statement=True,
+                    is_relationship_load=True,
+                    is_column_load=False,
                     lazy_loaded_from=None,
                 ),
+            ],
+        )
+
+    def test_update_delete_flags(self):
+        User, Address = self.classes("User", "Address")
+
+        sess = Session(testing.db, future=True)
+
+        canary = self._flag_fixture(sess)
+
+        sess.execute(delete(User).filter_by(id=18))
+        sess.execute(update(User).filter_by(id=18).values(name="eighteen"))
+
+        eq_(
+            canary.mock_calls,
+            [
                 call.options(
                     is_select=False,
                     is_update=False,
index c283e804e50a9f820fc60b3d550f1453207e5e56..a70dc051165a8217525760dc1c4a13ea795960db 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import ne_
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.types import Boolean
 from sqlalchemy.types import Integer
 from sqlalchemy.types import String
 
@@ -77,6 +78,41 @@ class DeferredLambdaTest(
             checkparams={"global_x_1": 10, "global_y_1": 9},
         )
 
+    def test_boolean_constants(self):
+        t1 = table("t1", column("q"), column("p"))
+
+        def go():
+            xy = True
+            stmt = select(t1).where(lambda: t1.c.q == xy)
+            return stmt
+
+        self.assert_compile(
+            go(), "SELECT t1.q, t1.p FROM t1 WHERE t1.q = :xy_1"
+        )
+
+    def test_execute_boolean(self, boolean_table_fixture, connection):
+        boolean_data = boolean_table_fixture
+
+        connection.execute(
+            boolean_data.insert(),
+            [{"id": 1, "data": True}, {"id": 2, "data": False}],
+        )
+
+        xy = True
+
+        def go():
+            stmt = select(lambda: boolean_data.c.id).where(
+                lambda: boolean_data.c.data == xy
+            )
+            return connection.execute(stmt)
+
+        result = go()
+        eq_(result.all(), [(1,)])
+
+        xy = False
+        result = go()
+        eq_(result.all(), [(2,)])
+
     def test_stale_checker_embedded(self):
         def go(x):
 
@@ -761,6 +797,15 @@ class DeferredLambdaTest(
         )
         return users, addresses
 
+    @testing.metadata_fixture()
+    def boolean_table_fixture(self, metadata):
+        return Table(
+            "boolean_data",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", Boolean),
+        )
+
     def test_adapt_select(self, user_address_fixture):
         users, addresses = user_address_fixture
 
index a4b76f35d0378311916a961ab366f8b44817aeb9..24a149ece6ae664d66371c317b04293b63f66784 100644 (file)
@@ -15,6 +15,7 @@ from sqlalchemy.sql import util as sql_util
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 
 
@@ -57,6 +58,34 @@ class MiscTest(fixtures.TestBase):
             {common, calias, subset_select},
         )
 
+    def test_incompatible_options_add_clslevel(self):
+        class opt1(sql_base.CacheableOptions):
+            _cache_key_traversal = []
+            foo = "bar"
+
+        with expect_raises_message(
+            TypeError,
+            "dictionary contains attributes not covered by "
+            "Options class .*opt1.* .*'bar'.*",
+        ):
+            o1 = opt1
+
+            o1 += {"foo": "f", "bar": "b"}
+
+    def test_incompatible_options_add_instancelevel(self):
+        class opt1(sql_base.CacheableOptions):
+            _cache_key_traversal = []
+            foo = "bar"
+
+        o1 = opt1(foo="bat")
+
+        with expect_raises_message(
+            TypeError,
+            "dictionary contains attributes not covered by "
+            "Options class .*opt1.* .*'bar'.*",
+        ):
+            o1 += {"foo": "f", "bar": "b"}
+
     def test_options_merge(self):
         class opt1(sql_base.CacheableOptions):
             _cache_key_traversal = []