]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
set up is_from_statement and others for FromStatement
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Apr 2024 21:54:22 +0000 (17:54 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 1 Apr 2024 23:29:30 +0000 (19:29 -0400)
Added new attribute :attr:`_orm.ORMExecuteState.is_from_statement`, to
detect statements of the form ``select().from_statement()``, and also
enhanced ``FromStatement`` to set :attr:`_orm.ORMExecuteState.is_select`,
:attr:`_orm.ORMExecuteState.is_insert`,
:attr:`_orm.ORMExecuteState.is_update`, and
:attr:`_orm.ORMExecuteState.is_delete` according to the element that is
sent to the :meth:`_sql.Select.from_statement` method itself.

Fixes: #11220
Change-Id: I3bf9e7e22fa2955d772b3b6ad636ed93a60916ae
(cherry picked from commit d3222a31b8df97a454b37a32881dd484a06e5742)

doc/build/changelog/unreleased_20/11220.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql/base.py
test/orm/test_events.py

diff --git a/doc/build/changelog/unreleased_20/11220.rst b/doc/build/changelog/unreleased_20/11220.rst
new file mode 100644 (file)
index 0000000..4f04cbf
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11220
+
+    Added new attribute :attr:`_orm.ORMExecuteState.is_from_statement`, to
+    detect statements of the form ``select().from_statement()``, and also
+    enhanced ``FromStatement`` to set :attr:`_orm.ORMExecuteState.is_select`,
+    :attr:`_orm.ORMExecuteState.is_insert`,
+    :attr:`_orm.ORMExecuteState.is_update`, and
+    :attr:`_orm.ORMExecuteState.is_delete` according to the element that is
+    sent to the :meth:`_sql.Select.from_statement` method itself.
index 3056016e729b0fe72d9703e8dea7c6ef32424643..fcd01e659161ee016f0eec9ba324293d33874678 100644 (file)
@@ -888,6 +888,8 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
         ("_compile_options", InternalTraversal.dp_has_cache_key)
     ]
 
+    is_from_statement = True
+
     def __init__(
         self,
         entities: Iterable[_ColumnsClauseArgument[Any]],
@@ -905,6 +907,10 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
         ]
         self.element = element
         self.is_dml = element.is_dml
+        self.is_select = element.is_select
+        self.is_delete = element.is_delete
+        self.is_insert = element.is_insert
+        self.is_update = element.is_update
         self._label_style = (
             element._label_style if is_select_base(element) else None
         )
index 3eba5aaf411b73ba203dd8c1dbfd60ab8d5f3230..acc6895e86f36fad32ab17ea81537a39650127ae 100644 (file)
@@ -575,22 +575,67 @@ class ORMExecuteState(util.MemoizedSlots):
 
     @property
     def is_select(self) -> bool:
-        """return True if this is a SELECT operation."""
+        """return True if this is a SELECT operation.
+
+        .. versionchanged:: 2.0.30 - the attribute is also True for a
+           :meth:`_sql.Select.from_statement` construct that is itself against
+           a :class:`_sql.Select` construct, such as
+           ``select(Entity).from_statement(select(..))``
+
+        """
         return self.statement.is_select
 
+    @property
+    def is_from_statement(self) -> bool:
+        """return True if this operation is a
+        :meth:`_sql.Select.from_statement` operation.
+
+        This is independent from :attr:`_orm.ORMExecuteState.is_select`, as a
+        ``select().from_statement()`` construct can be used with
+        INSERT/UPDATE/DELETE RETURNING types of statements as well.
+        :attr:`_orm.ORMExecuteState.is_select` will only be set if the
+        :meth:`_sql.Select.from_statement` is itself against a
+        :class:`_sql.Select` construct.
+
+        .. versionadded:: 2.0.30
+
+        """
+        return self.statement.is_from_statement
+
     @property
     def is_insert(self) -> bool:
-        """return True if this is an INSERT operation."""
+        """return True if this is an INSERT operation.
+
+        .. versionchanged:: 2.0.30 - the attribute is also True for a
+           :meth:`_sql.Select.from_statement` construct that is itself against
+           a :class:`_sql.Insert` construct, such as
+           ``select(Entity).from_statement(insert(..))``
+
+        """
         return self.statement.is_dml and self.statement.is_insert
 
     @property
     def is_update(self) -> bool:
-        """return True if this is an UPDATE operation."""
+        """return True if this is an UPDATE operation.
+
+        .. versionchanged:: 2.0.30 - the attribute is also True for a
+           :meth:`_sql.Select.from_statement` construct that is itself against
+           a :class:`_sql.Update` construct, such as
+           ``select(Entity).from_statement(update(..))``
+
+        """
         return self.statement.is_dml and self.statement.is_update
 
     @property
     def is_delete(self) -> bool:
-        """return True if this is a DELETE operation."""
+        """return True if this is a DELETE operation.
+
+        .. versionchanged:: 2.0.30 - the attribute is also True for a
+           :meth:`_sql.Select.from_statement` construct that is itself against
+           a :class:`_sql.Delete` construct, such as
+           ``select(Entity).from_statement(delete(..))``
+
+        """
         return self.statement.is_dml and self.statement.is_delete
 
     @property
index 5eb32e30dd4d1331845937f5eaf4513764c2c268..1a65b653ea2dc36a6acf5d874ed02c12a1803d8f 100644 (file)
@@ -1029,6 +1029,7 @@ class Executable(roles.StatementRole):
     ]
 
     is_select = False
+    is_from_statement = False
     is_update = False
     is_insert = False
     is_text = False
index 3af6aad86aae9eb5bfacc81e53d5cf970da3a32f..5e1672b526b62cabec4c7856a51968bd95e2de74 100644 (file)
@@ -385,6 +385,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                 bind_mapper=ctx.bind_mapper,
                 all_mappers=ctx.all_mappers,
                 is_select=ctx.is_select,
+                is_from_statement=ctx.is_from_statement,
+                is_insert=ctx.is_insert,
                 is_update=ctx.is_update,
                 is_delete=ctx.is_delete,
                 is_orm_statement=ctx.is_orm_statement,
@@ -421,6 +423,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=None,
                     all_mappers=[],
                     is_select=is_select,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=False,
@@ -451,6 +455,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=inspect(User),
                     all_mappers=[inspect(User), inspect(Address)],
                     is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=True,
@@ -475,6 +481,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=None,
                     all_mappers=[],
                     is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=False,
@@ -501,6 +509,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=inspect(User),
                     all_mappers=[inspect(User)],  # Address not in results
                     is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=True,
@@ -531,6 +541,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=inspect(User),
                     all_mappers=[inspect(User)],
                     is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=True,
@@ -542,6 +554,54 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=inspect(User),
                     all_mappers=[inspect(User)],
                     is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
+                    is_update=False,
+                    is_delete=False,
+                    is_orm_statement=True,
+                    is_relationship_load=False,
+                    is_column_load=True,
+                    lazy_loaded_from=None,
+                ),
+            ],
+        )
+
+    def test_select_from_statement_flags(self):
+        User, Address = self.classes("User", "Address")
+
+        sess = Session(testing.db, future=True)
+
+        canary = self._flag_fixture(sess)
+
+        s1 = select(User).filter_by(id=7)
+        u1 = sess.execute(select(User).from_statement(s1)).scalar_one()
+
+        sess.expire(u1)
+
+        eq_(u1.name, "jack")
+
+        eq_(
+            canary.mock_calls,
+            [
+                call.options(
+                    bind_mapper=inspect(User),
+                    all_mappers=[inspect(User)],
+                    is_select=True,
+                    is_from_statement=True,
+                    is_insert=False,
+                    is_update=False,
+                    is_delete=False,
+                    is_orm_statement=True,
+                    is_relationship_load=False,
+                    is_column_load=False,
+                    lazy_loaded_from=None,
+                ),
+                call.options(
+                    bind_mapper=inspect(User),
+                    all_mappers=[inspect(User)],
+                    is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=True,
@@ -570,6 +630,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=inspect(User),
                     all_mappers=[inspect(User)],
                     is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=True,
@@ -581,6 +643,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=inspect(Address),
                     all_mappers=[inspect(Address)],
                     is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=True,
@@ -611,6 +675,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=inspect(User),
                     all_mappers=[inspect(User)],
                     is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=True,
@@ -622,6 +688,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=inspect(Address),
                     all_mappers=[inspect(Address)],
                     is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=True,
@@ -652,6 +720,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=inspect(User),
                     all_mappers=[inspect(User)],
                     is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=True,
@@ -663,6 +733,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=inspect(Address),
                     all_mappers=[inspect(Address), inspect(User)],
                     is_select=True,
+                    is_from_statement=False,
+                    is_insert=False,
                     is_update=False,
                     is_delete=False,
                     is_orm_statement=True,
@@ -673,24 +745,44 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
             ],
         )
 
-    def test_update_delete_flags(self):
+    @testing.variation(
+        "stmt_type",
+        [
+            ("insert", testing.requires.insert_returning),
+            ("update", testing.requires.update_returning),
+            ("delete", testing.requires.delete_returning),
+        ],
+    )
+    @testing.variation("from_stmt", [True, False])
+    def test_update_delete_flags(self, stmt_type, from_stmt):
         User, Address = self.classes("User", "Address")
 
         sess = Session(testing.db, future=True)
 
         canary = self._flag_fixture(sess)
 
-        sess.execute(
-            delete(User)
-            .filter_by(id=18)
-            .execution_options(synchronize_session="evaluate")
-        )
-        sess.execute(
-            update(User)
-            .filter_by(id=18)
-            .values(name="eighteen")
-            .execution_options(synchronize_session="evaluate")
-        )
+        if stmt_type.delete:
+            stmt = (
+                delete(User)
+                .filter_by(id=18)
+                .execution_options(synchronize_session="evaluate")
+            )
+        elif stmt_type.update:
+            stmt = (
+                update(User)
+                .filter_by(id=18)
+                .values(name="eighteen")
+                .execution_options(synchronize_session="evaluate")
+            )
+        elif stmt_type.insert:
+            stmt = insert(User).values(name="eighteen")
+        else:
+            stmt_type.fail()
+
+        if from_stmt:
+            stmt = select(User).from_statement(stmt.returning(User))
+
+        sess.execute(stmt)
 
         eq_(
             canary.mock_calls,
@@ -699,19 +791,10 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
                     bind_mapper=inspect(User),
                     all_mappers=[inspect(User)],
                     is_select=False,
-                    is_update=False,
-                    is_delete=True,
-                    is_orm_statement=True,
-                    is_relationship_load=False,
-                    is_column_load=False,
-                    lazy_loaded_from=None,
-                ),
-                call.options(
-                    bind_mapper=inspect(User),
-                    all_mappers=[inspect(User)],
-                    is_select=False,
-                    is_update=True,
-                    is_delete=False,
+                    is_from_statement=bool(from_stmt),
+                    is_insert=stmt_type.insert,
+                    is_update=stmt_type.update,
+                    is_delete=stmt_type.delete,
                     is_orm_statement=True,
                     is_relationship_load=False,
                     is_column_load=False,