From: Mike Bayer Date: Mon, 1 Apr 2024 21:54:22 +0000 (-0400) Subject: set up is_from_statement and others for FromStatement X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d3222a31b8df97a454b37a32881dd484a06e5742;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git set up is_from_statement and others for FromStatement Added new attribute :attr:`_orm.ORMExecuteState.is_from_statement`, to detect statements of the form ``select().from_statement()``, and also enhanced ``FromStatement`` to set :attr:`_orm.ORMExecuteState.is_select`, :attr:`_orm.ORMExecuteState.is_insert`, :attr:`_orm.ORMExecuteState.is_update`, and :attr:`_orm.ORMExecuteState.is_delete` according to the element that is sent to the :meth:`_sql.Select.from_statement` method itself. Fixes: #11220 Change-Id: I3bf9e7e22fa2955d772b3b6ad636ed93a60916ae --- diff --git a/doc/build/changelog/unreleased_20/11220.rst b/doc/build/changelog/unreleased_20/11220.rst new file mode 100644 index 0000000000..4f04cbf23d --- /dev/null +++ b/doc/build/changelog/unreleased_20/11220.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, orm + :tickets: 11220 + + Added new attribute :attr:`_orm.ORMExecuteState.is_from_statement`, to + detect statements of the form ``select().from_statement()``, and also + enhanced ``FromStatement`` to set :attr:`_orm.ORMExecuteState.is_select`, + :attr:`_orm.ORMExecuteState.is_insert`, + :attr:`_orm.ORMExecuteState.is_update`, and + :attr:`_orm.ORMExecuteState.is_delete` according to the element that is + sent to the :meth:`_sql.Select.from_statement` method itself. diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index dba3435a26..b62aae7b74 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -897,6 +897,8 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[Unpack[_Ts]]): ("_compile_options", InternalTraversal.dp_has_cache_key) ] + is_from_statement = True + def __init__( self, entities: Iterable[_ColumnsClauseArgument[Any]], @@ -914,6 +916,10 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[Unpack[_Ts]]): ] self.element = element self.is_dml = element.is_dml + self.is_select = element.is_select + self.is_delete = element.is_delete + self.is_insert = element.is_insert + self.is_update = element.is_update self._label_style = ( element._label_style if is_select_base(element) else None ) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 61006ccf0a..13b906fe24 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -580,22 +580,67 @@ class ORMExecuteState(util.MemoizedSlots): @property def is_select(self) -> bool: - """return True if this is a SELECT operation.""" + """return True if this is a SELECT operation. + + .. versionchanged:: 2.0.30 - the attribute is also True for a + :meth:`_sql.Select.from_statement` construct that is itself against + a :class:`_sql.Select` construct, such as + ``select(Entity).from_statement(select(..))`` + + """ return self.statement.is_select + @property + def is_from_statement(self) -> bool: + """return True if this operation is a + :meth:`_sql.Select.from_statement` operation. + + This is independent from :attr:`_orm.ORMExecuteState.is_select`, as a + ``select().from_statement()`` construct can be used with + INSERT/UPDATE/DELETE RETURNING types of statements as well. + :attr:`_orm.ORMExecuteState.is_select` will only be set if the + :meth:`_sql.Select.from_statement` is itself against a + :class:`_sql.Select` construct. + + .. versionadded:: 2.0.30 + + """ + return self.statement.is_from_statement + @property def is_insert(self) -> bool: - """return True if this is an INSERT operation.""" + """return True if this is an INSERT operation. + + .. versionchanged:: 2.0.30 - the attribute is also True for a + :meth:`_sql.Select.from_statement` construct that is itself against + a :class:`_sql.Insert` construct, such as + ``select(Entity).from_statement(insert(..))`` + + """ return self.statement.is_dml and self.statement.is_insert @property def is_update(self) -> bool: - """return True if this is an UPDATE operation.""" + """return True if this is an UPDATE operation. + + .. versionchanged:: 2.0.30 - the attribute is also True for a + :meth:`_sql.Select.from_statement` construct that is itself against + a :class:`_sql.Update` construct, such as + ``select(Entity).from_statement(update(..))`` + + """ return self.statement.is_dml and self.statement.is_update @property def is_delete(self) -> bool: - """return True if this is a DELETE operation.""" + """return True if this is a DELETE operation. + + .. versionchanged:: 2.0.30 - the attribute is also True for a + :meth:`_sql.Select.from_statement` construct that is itself against + a :class:`_sql.Delete` construct, such as + ``select(Entity).from_statement(delete(..))`` + + """ return self.statement.is_dml and self.statement.is_delete @property diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a7bc18c5a4..923e849589 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -1028,6 +1028,7 @@ class Executable(roles.StatementRole): ] is_select = False + is_from_statement = False is_update = False is_insert = False is_text = False diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 3af6aad86a..5e1672b526 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -385,6 +385,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=ctx.bind_mapper, all_mappers=ctx.all_mappers, is_select=ctx.is_select, + is_from_statement=ctx.is_from_statement, + is_insert=ctx.is_insert, is_update=ctx.is_update, is_delete=ctx.is_delete, is_orm_statement=ctx.is_orm_statement, @@ -421,6 +423,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=None, all_mappers=[], is_select=is_select, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=False, @@ -451,6 +455,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=inspect(User), all_mappers=[inspect(User), inspect(Address)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -475,6 +481,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=None, all_mappers=[], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=False, @@ -501,6 +509,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=inspect(User), all_mappers=[inspect(User)], # Address not in results is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -531,6 +541,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -542,6 +554,54 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, + is_update=False, + is_delete=False, + is_orm_statement=True, + is_relationship_load=False, + is_column_load=True, + lazy_loaded_from=None, + ), + ], + ) + + def test_select_from_statement_flags(self): + User, Address = self.classes("User", "Address") + + sess = Session(testing.db, future=True) + + canary = self._flag_fixture(sess) + + s1 = select(User).filter_by(id=7) + u1 = sess.execute(select(User).from_statement(s1)).scalar_one() + + sess.expire(u1) + + eq_(u1.name, "jack") + + eq_( + canary.mock_calls, + [ + call.options( + bind_mapper=inspect(User), + all_mappers=[inspect(User)], + is_select=True, + is_from_statement=True, + is_insert=False, + is_update=False, + is_delete=False, + is_orm_statement=True, + is_relationship_load=False, + is_column_load=False, + lazy_loaded_from=None, + ), + call.options( + bind_mapper=inspect(User), + all_mappers=[inspect(User)], + is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -570,6 +630,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -581,6 +643,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=inspect(Address), all_mappers=[inspect(Address)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -611,6 +675,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -622,6 +688,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=inspect(Address), all_mappers=[inspect(Address)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -652,6 +720,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -663,6 +733,8 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=inspect(Address), all_mappers=[inspect(Address), inspect(User)], is_select=True, + is_from_statement=False, + is_insert=False, is_update=False, is_delete=False, is_orm_statement=True, @@ -673,24 +745,44 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): ], ) - def test_update_delete_flags(self): + @testing.variation( + "stmt_type", + [ + ("insert", testing.requires.insert_returning), + ("update", testing.requires.update_returning), + ("delete", testing.requires.delete_returning), + ], + ) + @testing.variation("from_stmt", [True, False]) + def test_update_delete_flags(self, stmt_type, from_stmt): User, Address = self.classes("User", "Address") sess = Session(testing.db, future=True) canary = self._flag_fixture(sess) - sess.execute( - delete(User) - .filter_by(id=18) - .execution_options(synchronize_session="evaluate") - ) - sess.execute( - update(User) - .filter_by(id=18) - .values(name="eighteen") - .execution_options(synchronize_session="evaluate") - ) + if stmt_type.delete: + stmt = ( + delete(User) + .filter_by(id=18) + .execution_options(synchronize_session="evaluate") + ) + elif stmt_type.update: + stmt = ( + update(User) + .filter_by(id=18) + .values(name="eighteen") + .execution_options(synchronize_session="evaluate") + ) + elif stmt_type.insert: + stmt = insert(User).values(name="eighteen") + else: + stmt_type.fail() + + if from_stmt: + stmt = select(User).from_statement(stmt.returning(User)) + + sess.execute(stmt) eq_( canary.mock_calls, @@ -699,19 +791,10 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): bind_mapper=inspect(User), all_mappers=[inspect(User)], is_select=False, - is_update=False, - is_delete=True, - is_orm_statement=True, - is_relationship_load=False, - is_column_load=False, - lazy_loaded_from=None, - ), - call.options( - bind_mapper=inspect(User), - all_mappers=[inspect(User)], - is_select=False, - is_update=True, - is_delete=False, + is_from_statement=bool(from_stmt), + is_insert=stmt_type.insert, + is_update=stmt_type.update, + is_delete=stmt_type.delete, is_orm_statement=True, is_relationship_load=False, is_column_load=False,