From 871e66e058fafae8f54d68359370c022d51059e1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 21 Nov 2025 13:44:31 -0500 Subject: [PATCH] break up test_update_delete_where into several files it's grown to 3K lines and i need to add more tests Change-Id: Ic55c50446d57789e3e3cc58e6a99239c1bfa8328 --- test/orm/dml/test_orm_upd_del_assorted.py | 433 ++++++++ ...ete_where.py => test_orm_upd_del_basic.py} | 923 ------------------ test/orm/dml/test_orm_upd_del_inheritance.py | 522 ++++++++++ 3 files changed, 955 insertions(+), 923 deletions(-) create mode 100644 test/orm/dml/test_orm_upd_del_assorted.py rename test/orm/dml/{test_update_delete_where.py => test_orm_upd_del_basic.py} (73%) create mode 100644 test/orm/dml/test_orm_upd_del_inheritance.py diff --git a/test/orm/dml/test_orm_upd_del_assorted.py b/test/orm/dml/test_orm_upd_del_assorted.py new file mode 100644 index 0000000000..53fe0acc55 --- /dev/null +++ b/test/orm/dml/test_orm_upd_del_assorted.py @@ -0,0 +1,433 @@ +from __future__ import annotations + +import uuid + +from sqlalchemy import Computed +from sqlalchemy import delete +from sqlalchemy import FetchedValue +from sqlalchemy import insert +from sqlalchemy import Integer +from sqlalchemy import literal +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import testing +from sqlalchemy import update +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import Session +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.entities import ComparableEntity +from sqlalchemy.testing.fixtures import fixture_session +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table + + +class LoadFromReturningTest(fixtures.MappedTest): + __backend__ = True + __requires__ = ("insert_returning",) + + @classmethod + def define_tables(cls, metadata): + Table( + "users", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("name", String(32)), + Column("age_int", Integer), + ) + + @classmethod + def setup_classes(cls): + class User(cls.Comparable): + pass + + class Address(cls.Comparable): + pass + + @classmethod + def insert_data(cls, connection): + users = cls.tables.users + + connection.execute( + users.insert(), + [ + dict(id=1, name="john", age_int=25), + dict(id=2, name="jack", age_int=47), + dict(id=3, name="jill", age_int=29), + dict(id=4, name="jane", age_int=37), + ], + ) + + @classmethod + def setup_mappers(cls): + User = cls.classes.User + users = cls.tables.users + + cls.mapper_registry.map_imperatively( + User, + users, + properties={ + "age": users.c.age_int, + }, + ) + + @testing.requires.update_returning + @testing.combinations(True, False, argnames="use_from_statement") + def test_load_from_update(self, connection, use_from_statement): + User = self.classes.User + + stmt = ( + update(User) + .where(User.name.in_(["jack", "jill"])) + .values(age=User.age + 5) + .returning(User) + ) + + if use_from_statement: + # this is now a legacy-ish case, because as of 2.0 you can just + # use returning() directly to get the objects back. + # + # when from_statement is used, the UPDATE statement is no + # longer interpreted by + # BulkUDCompileState.orm_pre_session_exec or + # BulkUDCompileState.orm_setup_cursor_result. The compilation + # level routines still take place though + stmt = select(User).from_statement(stmt) + + with Session(connection) as sess: + rows = sess.execute(stmt).scalars().all() + + eq_( + rows, + [User(name="jack", age=52), User(name="jill", age=34)], + ) + + @testing.combinations( + ("single",), + ("multiple", testing.requires.multivalues_inserts), + argnames="params", + ) + @testing.combinations(True, False, argnames="use_from_statement") + def test_load_from_insert(self, connection, params, use_from_statement): + User = self.classes.User + + if params == "multiple": + values = [ + {User.id: 5, User.age: 25, User.name: "spongebob"}, + {User.id: 6, User.age: 30, User.name: "patrick"}, + {User.id: 7, User.age: 35, User.name: "squidward"}, + ] + elif params == "single": + values = {User.id: 5, User.age: 25, User.name: "spongebob"} + else: + assert False + + stmt = insert(User).values(values).returning(User) + + if use_from_statement: + stmt = select(User).from_statement(stmt) + + with Session(connection) as sess: + rows = sess.execute(stmt).scalars().all() + + if params == "multiple": + eq_( + rows, + [ + User(name="spongebob", age=25), + User(name="patrick", age=30), + User(name="squidward", age=35), + ], + ) + elif params == "single": + eq_( + rows, + [User(name="spongebob", age=25)], + ) + else: + assert False + + @testing.requires.delete_returning + @testing.combinations(True, False, argnames="use_from_statement") + def test_load_from_delete(self, connection, use_from_statement): + User = self.classes.User + + stmt = ( + delete(User).where(User.name.in_(["jack", "jill"])).returning(User) + ) + + if use_from_statement: + stmt = select(User).from_statement(stmt) + + with Session(connection) as sess: + rows = sess.execute(stmt).scalars().all() + + eq_( + rows, + [User(name="jack", age=47), User(name="jill", age=29)], + ) + + # TODO: state of above objects should be "deleted" + + +class OnUpdatePopulationTest(fixtures.TestBase): + __backend__ = True + + @testing.variation("populate_existing", [True, False]) + @testing.variation( + "use_onupdate", + [ + "none", + "server", + "callable", + "clientsql", + ("computed", testing.requires.computed_columns), + ], + ) + @testing.variation( + "use_returning", + [ + ("returning", testing.requires.update_returning), + ("defaults", testing.requires.update_returning), + "none", + ], + ) + @testing.variation("synchronize", ["auto", "fetch", "evaluate"]) + @testing.variation("pk_order", ["first", "middle"]) + def test_update_populate_existing( + self, + decl_base, + populate_existing, + use_onupdate, + use_returning, + synchronize, + pk_order, + ): + """test #11912 and #11917""" + + class Employee(ComparableEntity, decl_base): + __tablename__ = "employee" + + if pk_order.first: + uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True) + user_name: Mapped[str] = mapped_column(String(200), nullable=False) + + if pk_order.middle: + uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True) + + if use_onupdate.server: + some_server_value: Mapped[str] = mapped_column( + server_onupdate=FetchedValue() + ) + elif use_onupdate.callable: + some_server_value: Mapped[str] = mapped_column( + onupdate=lambda: "value 2" + ) + elif use_onupdate.clientsql: + some_server_value: Mapped[str] = mapped_column( + onupdate=literal("value 2") + ) + elif use_onupdate.computed: + some_server_value: Mapped[str] = mapped_column( + String(255), + Computed(user_name + " computed value"), + nullable=True, + ) + else: + some_server_value: Mapped[str] + + decl_base.metadata.create_all(testing.db) + s = fixture_session() + + uuid1 = uuid.uuid4() + + if use_onupdate.computed: + server_old_value, server_new_value = ( + "e1 old name computed value", + "e1 new name computed value", + ) + e1 = Employee(uuid=uuid1, user_name="e1 old name") + else: + server_old_value, server_new_value = ("value 1", "value 2") + e1 = Employee( + uuid=uuid1, + user_name="e1 old name", + some_server_value="value 1", + ) + s.add(e1) + s.flush() + + stmt = ( + update(Employee) + .values(user_name="e1 new name") + .where(Employee.uuid == uuid1) + ) + + if use_returning.returning: + stmt = stmt.returning(Employee) + elif use_returning.defaults: + # NOTE: the return_defaults case here has not been analyzed for + # #11912 or #11917. future enhancements may change its behavior + stmt = stmt.return_defaults() + + # perform out of band UPDATE on server value to simulate + # a computed col + if use_onupdate.none or use_onupdate.server: + s.connection().execute( + update(Employee.__table__).values(some_server_value="value 2") + ) + + execution_options = {} + + if populate_existing: + execution_options["populate_existing"] = True + + if synchronize.evaluate: + execution_options["synchronize_session"] = "evaluate" + if synchronize.fetch: + execution_options["synchronize_session"] = "fetch" + + if use_returning.returning: + rows = s.scalars(stmt, execution_options=execution_options) + else: + s.execute(stmt, execution_options=execution_options) + + if ( + use_onupdate.clientsql + or use_onupdate.server + or use_onupdate.computed + ): + if not use_returning.defaults: + # if server-side onupdate was generated, the col should have + # been expired + assert "some_server_value" not in e1.__dict__ + + # and refreshes when called. this is even if we have RETURNING + # rows we didn't fetch yet. + eq_(e1.some_server_value, server_new_value) + else: + # using return defaults here is not expiring. have not + # researched why, it may be because the explicit + # return_defaults interferes with the ORMs call + assert "some_server_value" in e1.__dict__ + eq_(e1.some_server_value, server_old_value) + + elif use_onupdate.callable: + if not use_returning.defaults or not synchronize.fetch: + # for python-side onupdate, col is populated with local value + assert "some_server_value" in e1.__dict__ + + # and is refreshed + eq_(e1.some_server_value, server_new_value) + else: + assert "some_server_value" in e1.__dict__ + + # and is not refreshed + eq_(e1.some_server_value, server_old_value) + + else: + # no onupdate, then the value was not touched yet, + # even if we used RETURNING with populate_existing, because + # we did not fetch the rows yet + assert "some_server_value" in e1.__dict__ + eq_(e1.some_server_value, server_old_value) + + # now see if we can fetch rows + if use_returning.returning: + + if populate_existing or not use_onupdate.none: + eq_( + set(rows), + { + Employee( + uuid=uuid1, + user_name="e1 new name", + some_server_value=server_new_value, + ), + }, + ) + + else: + # if no populate existing and no server default, that column + # is not touched at all + eq_( + set(rows), + { + Employee( + uuid=uuid1, + user_name="e1 new name", + some_server_value=server_old_value, + ), + }, + ) + + if use_returning.defaults: + # as mentioned above, the return_defaults() case here remains + # unanalyzed. + if synchronize.fetch or ( + use_onupdate.clientsql + or use_onupdate.server + or use_onupdate.computed + or use_onupdate.none + ): + eq_(e1.some_server_value, server_old_value) + else: + eq_(e1.some_server_value, server_new_value) + + elif ( + populate_existing and use_returning.returning + ) or not use_onupdate.none: + eq_(e1.some_server_value, server_new_value) + else: + # no onupdate specified, and no populate existing with returning, + # the attribute is not refreshed + eq_(e1.some_server_value, server_old_value) + + # do a full expire, now the new value is definitely there + s.commit() + s.expire_all() + eq_(e1.some_server_value, server_new_value) + + +class PGIssue11849Test(fixtures.DeclarativeMappedTest): + __backend__ = True + __only_on__ = ("postgresql",) + + @classmethod + def setup_classes(cls): + + from sqlalchemy.dialects.postgresql import JSONB + + Base = cls.DeclarativeBasic + + class TestTbl(Base): + __tablename__ = "testtbl" + + test_id = Column(Integer, primary_key=True) + test_field = Column(JSONB) + + def test_issue_11849(self): + TestTbl = self.classes.TestTbl + + session = fixture_session() + + obj = TestTbl( + test_id=1, test_field={"test1": 1, "test2": "2", "test3": [3, "3"]} + ) + session.add(obj) + + query = ( + update(TestTbl) + .where(TestTbl.test_id == 1) + .values(test_field=TestTbl.test_field + {"test3": {"test4": 4}}) + ) + session.execute(query) + + # not loaded + assert "test_field" not in obj.__dict__ + + # synchronizes on load + eq_(obj.test_field, {"test1": 1, "test2": "2", "test3": {"test4": 4}}) diff --git a/test/orm/dml/test_update_delete_where.py b/test/orm/dml/test_orm_upd_del_basic.py similarity index 73% rename from test/orm/dml/test_update_delete_where.py rename to test/orm/dml/test_orm_upd_del_basic.py index 3463fae907..4df674facd 100644 --- a/test/orm/dml/test_update_delete_where.py +++ b/test/orm/dml/test_orm_upd_del_basic.py @@ -1,22 +1,17 @@ from __future__ import annotations -import uuid - from sqlalchemy import Boolean from sqlalchemy import case from sqlalchemy import column -from sqlalchemy import Computed from sqlalchemy import delete from sqlalchemy import event from sqlalchemy import exc -from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import lambda_stmt -from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import or_ @@ -32,12 +27,9 @@ from sqlalchemy.orm import Bundle from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import immediateload from sqlalchemy.orm import joinedload -from sqlalchemy.orm import Mapped -from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session -from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import synonym from sqlalchemy.orm import with_loader_criteria from sqlalchemy.sql.dml import Delete @@ -45,7 +37,6 @@ from sqlalchemy.sql.dml import Update from sqlalchemy.sql.selectable import Select from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message -from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures @@ -53,7 +44,6 @@ from sqlalchemy.testing import in_ from sqlalchemy.testing import not_in from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.assertsql import CompiledSQL -from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -2728,916 +2718,3 @@ class ExpressionUpdateDeleteTest(fixtures.MappedTest): delete_stmt = m1.mock_calls[0][1][0] eq_(delete_stmt.dialect_kwargs, delete_args) - - -class InheritTest(fixtures.DeclarativeMappedTest): - run_inserts = "each" - - run_deletes = "each" - __backend__ = True - - @classmethod - def setup_classes(cls): - Base = cls.DeclarativeBasic - - class Person(Base): - __tablename__ = "person" - id = Column( - Integer, primary_key=True, test_needs_autoincrement=True - ) - type = Column(String(50)) - name = Column(String(50)) - - class Engineer(Person): - __tablename__ = "engineer" - id = Column(Integer, ForeignKey("person.id"), primary_key=True) - engineer_name = Column(String(50)) - - class Programmer(Engineer): - __tablename__ = "programmer" - id = Column(Integer, ForeignKey("engineer.id"), primary_key=True) - primary_language = Column(String(50)) - - class Manager(Person): - __tablename__ = "manager" - id = Column(Integer, ForeignKey("person.id"), primary_key=True) - manager_name = Column(String(50)) - - @classmethod - def insert_data(cls, connection): - Engineer, Person, Manager, Programmer = ( - cls.classes.Engineer, - cls.classes.Person, - cls.classes.Manager, - cls.classes.Programmer, - ) - s = Session(connection) - s.add_all( - [ - Engineer(name="e1", engineer_name="e1"), - Manager(name="m1", manager_name="m1"), - Engineer(name="e2", engineer_name="e2"), - Person(name="p1"), - Programmer( - name="pp1", engineer_name="pp1", primary_language="python" - ), - ] - ) - s.commit() - - @testing.only_on(["mysql", "mariadb"], "Multi table update") - def test_update_from_join_no_problem(self): - person = self.classes.Person.__table__ - engineer = self.classes.Engineer.__table__ - - sess = fixture_session() - sess.query(person.join(engineer)).filter(person.c.name == "e2").update( - {person.c.name: "updated", engineer.c.engineer_name: "e2a"}, - ) - obj = sess.execute( - select(self.classes.Engineer).filter( - self.classes.Engineer.name == "updated" - ) - ).scalar() - eq_(obj.name, "updated") - eq_(obj.engineer_name, "e2a") - - @testing.combinations(None, "fetch", "evaluate") - def test_update_sub_table_only(self, synchronize_session): - Engineer = self.classes.Engineer - s = Session(testing.db) - s.query(Engineer).update( - {"engineer_name": "e5"}, synchronize_session=synchronize_session - ) - - eq_(s.query(Engineer.engineer_name).all(), [("e5",), ("e5",), ("e5",)]) - - @testing.combinations(None, "fetch", "evaluate") - def test_update_sub_sub_table_only(self, synchronize_session): - Programmer = self.classes.Programmer - s = Session(testing.db) - s.query(Programmer).update( - {"primary_language": "c++"}, - synchronize_session=synchronize_session, - ) - - eq_( - s.query(Programmer.primary_language).all(), - [ - ("c++",), - ], - ) - - @testing.requires.update_from - @testing.combinations(None, "fetch", "fetch_w_hint", "evaluate") - def test_update_from(self, synchronize_session): - """test an UPDATE that uses multiple tables. - - The limitation that MariaDB has with DELETE does not apply here at the - moment as MariaDB doesn't support UPDATE..RETURNING at all. However, - the logic from DELETE is still implemented in persistence.py. If - MariaDB adds UPDATE...RETURNING, then it may be useful. SQLite, - PostgreSQL, MSSQL all support UPDATE..FROM however RETURNING seems to - function correctly for all three. - - """ - Engineer = self.classes.Engineer - Person = self.classes.Person - s = Session(testing.db) - - # we don't have any backends with this combination right now. - db_has_hypothetical_limitation = ( - testing.db.dialect.update_returning - and not testing.db.dialect.update_returning_multifrom - ) - - e2 = s.query(Engineer).filter_by(name="e2").first() - - with self.sql_execution_asserter() as asserter: - eq_(e2.engineer_name, "e2") - q = ( - s.query(Engineer) - .filter(Engineer.id == Person.id) - .filter(Person.name == "e2") - ) - if synchronize_session == "fetch_w_hint": - q.execution_options(is_update_from=True).update( - {"engineer_name": "e5"}, - synchronize_session="fetch", - ) - elif ( - synchronize_session == "fetch" - and db_has_hypothetical_limitation - ): - with expect_raises_message( - exc.CompileError, - 'Dialect ".*" does not support RETURNING with ' - "UPDATE..FROM;", - ): - q.update( - {"engineer_name": "e5"}, - synchronize_session=synchronize_session, - ) - return - else: - q.update( - {"engineer_name": "e5"}, - synchronize_session=synchronize_session, - ) - - if synchronize_session is None: - eq_(e2.engineer_name, "e2") - else: - eq_(e2.engineer_name, "e5") - - if synchronize_session in ("fetch", "fetch_w_hint") and ( - db_has_hypothetical_limitation - or not testing.db.dialect.update_returning - ): - asserter.assert_( - CompiledSQL( - "SELECT person.id FROM person INNER JOIN engineer " - "ON person.id = engineer.id WHERE engineer.id = person.id " - "AND person.name = %s", - [{"name_1": "e2"}], - dialect="mariadb", - ), - CompiledSQL( - "UPDATE engineer, person SET engineer.engineer_name=%s " - "WHERE engineer.id = person.id AND person.name = %s", - [{"engineer_name": "e5", "name_1": "e2"}], - dialect="mariadb", - ), - ) - elif synchronize_session in ("fetch", "fetch_w_hint"): - asserter.assert_( - CompiledSQL( - "UPDATE engineer SET engineer_name=%(engineer_name)s " - "FROM person WHERE engineer.id = person.id " - "AND person.name = %(name_1)s RETURNING engineer.id", - [{"engineer_name": "e5", "name_1": "e2"}], - dialect="postgresql", - ), - ) - else: - asserter.assert_( - CompiledSQL( - "UPDATE engineer SET engineer_name=%(engineer_name)s " - "FROM person WHERE engineer.id = person.id " - "AND person.name = %(name_1)s", - [{"engineer_name": "e5", "name_1": "e2"}], - dialect="postgresql", - ), - ) - - eq_( - set(s.query(Person.name, Engineer.engineer_name)), - {("e1", "e1"), ("e2", "e5"), ("pp1", "pp1")}, - ) - - @testing.requires.delete_using - @testing.combinations(None, "fetch", "fetch_w_hint", "evaluate") - def test_delete_using(self, synchronize_session): - """test a DELETE that uses multiple tables. - - due to a limitation in MariaDB, we have an up front "hint" that needs - to be passed for this backend if DELETE USING is to be used in - conjunction with "fetch" strategy, so that we know before compilation - that we won't be able to use RETURNING. - - """ - - Engineer = self.classes.Engineer - Person = self.classes.Person - s = Session(testing.db) - - db_has_mariadb_limitation = ( - testing.db.dialect.delete_returning - and not testing.db.dialect.delete_returning_multifrom - ) - - e2 = s.query(Engineer).filter_by(name="e2").first() - - with self.sql_execution_asserter() as asserter: - assert e2 in s - - q = ( - s.query(Engineer) - .filter(Engineer.id == Person.id) - .filter(Person.name == "e2") - ) - - if synchronize_session == "fetch_w_hint": - q.execution_options(is_delete_using=True).delete( - synchronize_session="fetch" - ) - elif synchronize_session == "fetch" and db_has_mariadb_limitation: - with expect_raises_message( - exc.CompileError, - 'Dialect ".*" does not support RETURNING with ' - "DELETE..USING;", - ): - q.delete(synchronize_session=synchronize_session) - return - else: - q.delete(synchronize_session=synchronize_session) - - if synchronize_session is None: - assert e2 in s - else: - assert e2 not in s - - if synchronize_session in ("fetch", "fetch_w_hint") and ( - db_has_mariadb_limitation - or not testing.db.dialect.delete_returning - ): - asserter.assert_( - CompiledSQL( - "SELECT person.id FROM person INNER JOIN engineer ON " - "person.id = engineer.id WHERE engineer.id = person.id " - "AND person.name = %s", - [{"name_1": "e2"}], - dialect="mariadb", - ), - CompiledSQL( - "DELETE FROM engineer USING engineer, person WHERE " - "engineer.id = person.id AND person.name = %s", - [{"name_1": "e2"}], - dialect="mariadb", - ), - ) - elif synchronize_session in ("fetch", "fetch_w_hint"): - asserter.assert_( - CompiledSQL( - "DELETE FROM engineer USING person WHERE " - "engineer.id = person.id AND person.name = %(name_1)s " - "RETURNING engineer.id", - [{"name_1": "e2"}], - dialect="postgresql", - ), - ) - else: - asserter.assert_( - CompiledSQL( - "DELETE FROM engineer USING person WHERE " - "engineer.id = person.id AND person.name = %(name_1)s", - [{"name_1": "e2"}], - dialect="postgresql", - ), - ) - - # delete actually worked - eq_( - set(s.query(Person.name, Engineer.engineer_name)), - {("pp1", "pp1"), ("e1", "e1")}, - ) - - @testing.only_on(["mysql", "mariadb"], "Multi table update") - @testing.requires.delete_using - @testing.combinations(None, "fetch", "evaluate") - def test_update_from_multitable(self, synchronize_session): - Engineer = self.classes.Engineer - Person = self.classes.Person - s = Session(testing.db) - s.query(Engineer).filter(Engineer.id == Person.id).filter( - Person.name == "e2" - ).update( - {Person.name: "e22", Engineer.engineer_name: "e55"}, - synchronize_session=synchronize_session, - ) - - eq_( - set(s.query(Person.name, Engineer.engineer_name)), - {("e1", "e1"), ("e22", "e55"), ("pp1", "pp1")}, - ) - - -class InheritWPolyTest(fixtures.TestBase, AssertsCompiledSQL): - __dialect__ = "default" - - @testing.fixture - def inherit_fixture(self, decl_base): - def go(poly_type): - - class Person(decl_base): - __tablename__ = "person" - id = Column(Integer, primary_key=True) - type = Column(String(50)) - name = Column(String(50)) - - if poly_type.wpoly: - __mapper_args__ = {"with_polymorphic": "*"} - - class Engineer(Person): - __tablename__ = "engineer" - id = Column(Integer, ForeignKey("person.id"), primary_key=True) - engineer_name = Column(String(50)) - - if poly_type.inline: - __mapper_args__ = {"polymorphic_load": "inline"} - - return Person, Engineer - - return go - - @testing.variation("poly_type", ["wpoly", "inline", "none"]) - def test_update_base_only(self, poly_type, inherit_fixture): - Person, Engineer = inherit_fixture(poly_type) - - self.assert_compile( - update(Person).values(name="n1"), "UPDATE person SET name=:name" - ) - - @testing.variation("poly_type", ["wpoly", "inline", "none"]) - def test_delete_base_only(self, poly_type, inherit_fixture): - Person, Engineer = inherit_fixture(poly_type) - - self.assert_compile(delete(Person), "DELETE FROM person") - - self.assert_compile( - delete(Person).where(Person.id == 7), - "DELETE FROM person WHERE person.id = :id_1", - ) - - -class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest): - __backend__ = True - - @classmethod - def setup_classes(cls): - Base = cls.DeclarativeBasic - - class Staff(Base): - __tablename__ = "staff" - position = Column(String(10), nullable=False) - id = Column( - Integer, primary_key=True, test_needs_autoincrement=True - ) - name = Column(String(5)) - stats = Column(String(5)) - __mapper_args__ = {"polymorphic_on": position} - - class Sales(Staff): - sales_stats = Column(String(5)) - __mapper_args__ = {"polymorphic_identity": "sales"} - - class Support(Staff): - support_stats = Column(String(5)) - __mapper_args__ = {"polymorphic_identity": "support"} - - @classmethod - def insert_data(cls, connection): - with sessionmaker(connection).begin() as session: - Sales, Support = ( - cls.classes.Sales, - cls.classes.Support, - ) - session.add_all( - [ - Sales(name="n1", sales_stats="1", stats="a"), - Sales(name="n2", sales_stats="2", stats="b"), - Support(name="n1", support_stats="3", stats="c"), - Support(name="n2", support_stats="4", stats="d"), - ] - ) - - @testing.combinations( - ("fetch", False), - ("fetch", True), - ("evaluate", False), - ("evaluate", True), - ) - def test_update(self, fetchstyle, newstyle): - Staff, Sales, Support = self.classes("Staff", "Sales", "Support") - - sess = fixture_session() - - en1, en2 = ( - sess.execute(select(Sales).order_by(Sales.sales_stats)) - .scalars() - .all() - ) - mn1, mn2 = ( - sess.execute(select(Support).order_by(Support.support_stats)) - .scalars() - .all() - ) - - if newstyle: - sess.execute( - update(Sales) - .filter_by(name="n1") - .values(stats="p") - .execution_options(synchronize_session=fetchstyle) - ) - else: - sess.query(Sales).filter_by(name="n1").update( - {"stats": "p"}, synchronize_session=fetchstyle - ) - - eq_(en1.stats, "p") - eq_(mn1.stats, "c") - eq_( - sess.execute( - select(Staff.position, Staff.name, Staff.stats).order_by( - Staff.id - ) - ).all(), - [ - ("sales", "n1", "p"), - ("sales", "n2", "b"), - ("support", "n1", "c"), - ("support", "n2", "d"), - ], - ) - - @testing.combinations( - ("fetch", False), - ("fetch", True), - ("evaluate", False), - ("evaluate", True), - ) - def test_delete(self, fetchstyle, newstyle): - Staff, Sales, Support = self.classes("Staff", "Sales", "Support") - - sess = fixture_session() - en1, en2 = sess.query(Sales).order_by(Sales.sales_stats).all() - mn1, mn2 = sess.query(Support).order_by(Support.support_stats).all() - - if newstyle: - sess.execute( - delete(Sales) - .filter_by(name="n1") - .execution_options(synchronize_session=fetchstyle) - ) - else: - sess.query(Sales).filter_by(name="n1").delete( - synchronize_session=fetchstyle - ) - assert en1 not in sess - assert en2 in sess - assert mn1 in sess - assert mn2 in sess - - eq_( - sess.execute( - select(Staff.position, Staff.name, Staff.stats).order_by( - Staff.id - ) - ).all(), - [ - ("sales", "n2", "b"), - ("support", "n1", "c"), - ("support", "n2", "d"), - ], - ) - - -class LoadFromReturningTest(fixtures.MappedTest): - __backend__ = True - __requires__ = ("insert_returning",) - - @classmethod - def define_tables(cls, metadata): - Table( - "users", - metadata, - Column( - "id", Integer, primary_key=True, test_needs_autoincrement=True - ), - Column("name", String(32)), - Column("age_int", Integer), - ) - - @classmethod - def setup_classes(cls): - class User(cls.Comparable): - pass - - class Address(cls.Comparable): - pass - - @classmethod - def insert_data(cls, connection): - users = cls.tables.users - - connection.execute( - users.insert(), - [ - dict(id=1, name="john", age_int=25), - dict(id=2, name="jack", age_int=47), - dict(id=3, name="jill", age_int=29), - dict(id=4, name="jane", age_int=37), - ], - ) - - @classmethod - def setup_mappers(cls): - User = cls.classes.User - users = cls.tables.users - - cls.mapper_registry.map_imperatively( - User, - users, - properties={ - "age": users.c.age_int, - }, - ) - - @testing.requires.update_returning - @testing.combinations(True, False, argnames="use_from_statement") - def test_load_from_update(self, connection, use_from_statement): - User = self.classes.User - - stmt = ( - update(User) - .where(User.name.in_(["jack", "jill"])) - .values(age=User.age + 5) - .returning(User) - ) - - if use_from_statement: - # this is now a legacy-ish case, because as of 2.0 you can just - # use returning() directly to get the objects back. - # - # when from_statement is used, the UPDATE statement is no - # longer interpreted by - # BulkUDCompileState.orm_pre_session_exec or - # BulkUDCompileState.orm_setup_cursor_result. The compilation - # level routines still take place though - stmt = select(User).from_statement(stmt) - - with Session(connection) as sess: - rows = sess.execute(stmt).scalars().all() - - eq_( - rows, - [User(name="jack", age=52), User(name="jill", age=34)], - ) - - @testing.combinations( - ("single",), - ("multiple", testing.requires.multivalues_inserts), - argnames="params", - ) - @testing.combinations(True, False, argnames="use_from_statement") - def test_load_from_insert(self, connection, params, use_from_statement): - User = self.classes.User - - if params == "multiple": - values = [ - {User.id: 5, User.age: 25, User.name: "spongebob"}, - {User.id: 6, User.age: 30, User.name: "patrick"}, - {User.id: 7, User.age: 35, User.name: "squidward"}, - ] - elif params == "single": - values = {User.id: 5, User.age: 25, User.name: "spongebob"} - else: - assert False - - stmt = insert(User).values(values).returning(User) - - if use_from_statement: - stmt = select(User).from_statement(stmt) - - with Session(connection) as sess: - rows = sess.execute(stmt).scalars().all() - - if params == "multiple": - eq_( - rows, - [ - User(name="spongebob", age=25), - User(name="patrick", age=30), - User(name="squidward", age=35), - ], - ) - elif params == "single": - eq_( - rows, - [User(name="spongebob", age=25)], - ) - else: - assert False - - @testing.requires.delete_returning - @testing.combinations(True, False, argnames="use_from_statement") - def test_load_from_delete(self, connection, use_from_statement): - User = self.classes.User - - stmt = ( - delete(User).where(User.name.in_(["jack", "jill"])).returning(User) - ) - - if use_from_statement: - stmt = select(User).from_statement(stmt) - - with Session(connection) as sess: - rows = sess.execute(stmt).scalars().all() - - eq_( - rows, - [User(name="jack", age=47), User(name="jill", age=29)], - ) - - # TODO: state of above objects should be "deleted" - - -class OnUpdatePopulationTest(fixtures.TestBase): - __backend__ = True - - @testing.variation("populate_existing", [True, False]) - @testing.variation( - "use_onupdate", - [ - "none", - "server", - "callable", - "clientsql", - ("computed", testing.requires.computed_columns), - ], - ) - @testing.variation( - "use_returning", - [ - ("returning", testing.requires.update_returning), - ("defaults", testing.requires.update_returning), - "none", - ], - ) - @testing.variation("synchronize", ["auto", "fetch", "evaluate"]) - @testing.variation("pk_order", ["first", "middle"]) - def test_update_populate_existing( - self, - decl_base, - populate_existing, - use_onupdate, - use_returning, - synchronize, - pk_order, - ): - """test #11912 and #11917""" - - class Employee(ComparableEntity, decl_base): - __tablename__ = "employee" - - if pk_order.first: - uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True) - user_name: Mapped[str] = mapped_column(String(200), nullable=False) - - if pk_order.middle: - uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True) - - if use_onupdate.server: - some_server_value: Mapped[str] = mapped_column( - server_onupdate=FetchedValue() - ) - elif use_onupdate.callable: - some_server_value: Mapped[str] = mapped_column( - onupdate=lambda: "value 2" - ) - elif use_onupdate.clientsql: - some_server_value: Mapped[str] = mapped_column( - onupdate=literal("value 2") - ) - elif use_onupdate.computed: - some_server_value: Mapped[str] = mapped_column( - String(255), - Computed(user_name + " computed value"), - nullable=True, - ) - else: - some_server_value: Mapped[str] - - decl_base.metadata.create_all(testing.db) - s = fixture_session() - - uuid1 = uuid.uuid4() - - if use_onupdate.computed: - server_old_value, server_new_value = ( - "e1 old name computed value", - "e1 new name computed value", - ) - e1 = Employee(uuid=uuid1, user_name="e1 old name") - else: - server_old_value, server_new_value = ("value 1", "value 2") - e1 = Employee( - uuid=uuid1, - user_name="e1 old name", - some_server_value="value 1", - ) - s.add(e1) - s.flush() - - stmt = ( - update(Employee) - .values(user_name="e1 new name") - .where(Employee.uuid == uuid1) - ) - - if use_returning.returning: - stmt = stmt.returning(Employee) - elif use_returning.defaults: - # NOTE: the return_defaults case here has not been analyzed for - # #11912 or #11917. future enhancements may change its behavior - stmt = stmt.return_defaults() - - # perform out of band UPDATE on server value to simulate - # a computed col - if use_onupdate.none or use_onupdate.server: - s.connection().execute( - update(Employee.__table__).values(some_server_value="value 2") - ) - - execution_options = {} - - if populate_existing: - execution_options["populate_existing"] = True - - if synchronize.evaluate: - execution_options["synchronize_session"] = "evaluate" - if synchronize.fetch: - execution_options["synchronize_session"] = "fetch" - - if use_returning.returning: - rows = s.scalars(stmt, execution_options=execution_options) - else: - s.execute(stmt, execution_options=execution_options) - - if ( - use_onupdate.clientsql - or use_onupdate.server - or use_onupdate.computed - ): - if not use_returning.defaults: - # if server-side onupdate was generated, the col should have - # been expired - assert "some_server_value" not in e1.__dict__ - - # and refreshes when called. this is even if we have RETURNING - # rows we didn't fetch yet. - eq_(e1.some_server_value, server_new_value) - else: - # using return defaults here is not expiring. have not - # researched why, it may be because the explicit - # return_defaults interferes with the ORMs call - assert "some_server_value" in e1.__dict__ - eq_(e1.some_server_value, server_old_value) - - elif use_onupdate.callable: - if not use_returning.defaults or not synchronize.fetch: - # for python-side onupdate, col is populated with local value - assert "some_server_value" in e1.__dict__ - - # and is refreshed - eq_(e1.some_server_value, server_new_value) - else: - assert "some_server_value" in e1.__dict__ - - # and is not refreshed - eq_(e1.some_server_value, server_old_value) - - else: - # no onupdate, then the value was not touched yet, - # even if we used RETURNING with populate_existing, because - # we did not fetch the rows yet - assert "some_server_value" in e1.__dict__ - eq_(e1.some_server_value, server_old_value) - - # now see if we can fetch rows - if use_returning.returning: - - if populate_existing or not use_onupdate.none: - eq_( - set(rows), - { - Employee( - uuid=uuid1, - user_name="e1 new name", - some_server_value=server_new_value, - ), - }, - ) - - else: - # if no populate existing and no server default, that column - # is not touched at all - eq_( - set(rows), - { - Employee( - uuid=uuid1, - user_name="e1 new name", - some_server_value=server_old_value, - ), - }, - ) - - if use_returning.defaults: - # as mentioned above, the return_defaults() case here remains - # unanalyzed. - if synchronize.fetch or ( - use_onupdate.clientsql - or use_onupdate.server - or use_onupdate.computed - or use_onupdate.none - ): - eq_(e1.some_server_value, server_old_value) - else: - eq_(e1.some_server_value, server_new_value) - - elif ( - populate_existing and use_returning.returning - ) or not use_onupdate.none: - eq_(e1.some_server_value, server_new_value) - else: - # no onupdate specified, and no populate existing with returning, - # the attribute is not refreshed - eq_(e1.some_server_value, server_old_value) - - # do a full expire, now the new value is definitely there - s.commit() - s.expire_all() - eq_(e1.some_server_value, server_new_value) - - -class PGIssue11849Test(fixtures.DeclarativeMappedTest): - __backend__ = True - __only_on__ = ("postgresql",) - - @classmethod - def setup_classes(cls): - - from sqlalchemy.dialects.postgresql import JSONB - - Base = cls.DeclarativeBasic - - class TestTbl(Base): - __tablename__ = "testtbl" - - test_id = Column(Integer, primary_key=True) - test_field = Column(JSONB) - - def test_issue_11849(self): - TestTbl = self.classes.TestTbl - - session = fixture_session() - - obj = TestTbl( - test_id=1, test_field={"test1": 1, "test2": "2", "test3": [3, "3"]} - ) - session.add(obj) - - query = ( - update(TestTbl) - .where(TestTbl.test_id == 1) - .values(test_field=TestTbl.test_field + {"test3": {"test4": 4}}) - ) - session.execute(query) - - # not loaded - assert "test_field" not in obj.__dict__ - - # synchronizes on load - eq_(obj.test_field, {"test1": 1, "test2": "2", "test3": {"test4": 4}}) diff --git a/test/orm/dml/test_orm_upd_del_inheritance.py b/test/orm/dml/test_orm_upd_del_inheritance.py new file mode 100644 index 0000000000..8d2c1e4bb0 --- /dev/null +++ b/test/orm/dml/test_orm_upd_del_inheritance.py @@ -0,0 +1,522 @@ +from __future__ import annotations + +from sqlalchemy import delete +from sqlalchemy import exc +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import testing +from sqlalchemy import update +from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker +from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.assertions import expect_raises_message +from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import fixture_session +from sqlalchemy.testing.schema import Column + + +class InheritTest(fixtures.DeclarativeMappedTest): + run_inserts = "each" + + run_deletes = "each" + __backend__ = True + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Person(Base): + __tablename__ = "person" + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + type = Column(String(50)) + name = Column(String(50)) + + class Engineer(Person): + __tablename__ = "engineer" + id = Column(Integer, ForeignKey("person.id"), primary_key=True) + engineer_name = Column(String(50)) + + class Programmer(Engineer): + __tablename__ = "programmer" + id = Column(Integer, ForeignKey("engineer.id"), primary_key=True) + primary_language = Column(String(50)) + + class Manager(Person): + __tablename__ = "manager" + id = Column(Integer, ForeignKey("person.id"), primary_key=True) + manager_name = Column(String(50)) + + @classmethod + def insert_data(cls, connection): + Engineer, Person, Manager, Programmer = ( + cls.classes.Engineer, + cls.classes.Person, + cls.classes.Manager, + cls.classes.Programmer, + ) + s = Session(connection) + s.add_all( + [ + Engineer(name="e1", engineer_name="e1"), + Manager(name="m1", manager_name="m1"), + Engineer(name="e2", engineer_name="e2"), + Person(name="p1"), + Programmer( + name="pp1", engineer_name="pp1", primary_language="python" + ), + ] + ) + s.commit() + + @testing.only_on(["mysql", "mariadb"], "Multi table update") + def test_update_from_join_no_problem(self): + person = self.classes.Person.__table__ + engineer = self.classes.Engineer.__table__ + + sess = fixture_session() + sess.query(person.join(engineer)).filter(person.c.name == "e2").update( + {person.c.name: "updated", engineer.c.engineer_name: "e2a"}, + ) + obj = sess.execute( + select(self.classes.Engineer).filter( + self.classes.Engineer.name == "updated" + ) + ).scalar() + eq_(obj.name, "updated") + eq_(obj.engineer_name, "e2a") + + @testing.combinations(None, "fetch", "evaluate") + def test_update_sub_table_only(self, synchronize_session): + Engineer = self.classes.Engineer + s = Session(testing.db) + s.query(Engineer).update( + {"engineer_name": "e5"}, synchronize_session=synchronize_session + ) + + eq_(s.query(Engineer.engineer_name).all(), [("e5",), ("e5",), ("e5",)]) + + @testing.combinations(None, "fetch", "evaluate") + def test_update_sub_sub_table_only(self, synchronize_session): + Programmer = self.classes.Programmer + s = Session(testing.db) + s.query(Programmer).update( + {"primary_language": "c++"}, + synchronize_session=synchronize_session, + ) + + eq_( + s.query(Programmer.primary_language).all(), + [ + ("c++",), + ], + ) + + @testing.requires.update_from + @testing.combinations(None, "fetch", "fetch_w_hint", "evaluate") + def test_update_from(self, synchronize_session): + """test an UPDATE that uses multiple tables. + + The limitation that MariaDB has with DELETE does not apply here at the + moment as MariaDB doesn't support UPDATE..RETURNING at all. However, + the logic from DELETE is still implemented in persistence.py. If + MariaDB adds UPDATE...RETURNING, then it may be useful. SQLite, + PostgreSQL, MSSQL all support UPDATE..FROM however RETURNING seems to + function correctly for all three. + + """ + Engineer = self.classes.Engineer + Person = self.classes.Person + s = Session(testing.db) + + # we don't have any backends with this combination right now. + db_has_hypothetical_limitation = ( + testing.db.dialect.update_returning + and not testing.db.dialect.update_returning_multifrom + ) + + e2 = s.query(Engineer).filter_by(name="e2").first() + + with self.sql_execution_asserter() as asserter: + eq_(e2.engineer_name, "e2") + q = ( + s.query(Engineer) + .filter(Engineer.id == Person.id) + .filter(Person.name == "e2") + ) + if synchronize_session == "fetch_w_hint": + q.execution_options(is_update_from=True).update( + {"engineer_name": "e5"}, + synchronize_session="fetch", + ) + elif ( + synchronize_session == "fetch" + and db_has_hypothetical_limitation + ): + with expect_raises_message( + exc.CompileError, + 'Dialect ".*" does not support RETURNING with ' + "UPDATE..FROM;", + ): + q.update( + {"engineer_name": "e5"}, + synchronize_session=synchronize_session, + ) + return + else: + q.update( + {"engineer_name": "e5"}, + synchronize_session=synchronize_session, + ) + + if synchronize_session is None: + eq_(e2.engineer_name, "e2") + else: + eq_(e2.engineer_name, "e5") + + if synchronize_session in ("fetch", "fetch_w_hint") and ( + db_has_hypothetical_limitation + or not testing.db.dialect.update_returning + ): + asserter.assert_( + CompiledSQL( + "SELECT person.id FROM person INNER JOIN engineer " + "ON person.id = engineer.id WHERE engineer.id = person.id " + "AND person.name = %s", + [{"name_1": "e2"}], + dialect="mariadb", + ), + CompiledSQL( + "UPDATE engineer, person SET engineer.engineer_name=%s " + "WHERE engineer.id = person.id AND person.name = %s", + [{"engineer_name": "e5", "name_1": "e2"}], + dialect="mariadb", + ), + ) + elif synchronize_session in ("fetch", "fetch_w_hint"): + asserter.assert_( + CompiledSQL( + "UPDATE engineer SET engineer_name=%(engineer_name)s " + "FROM person WHERE engineer.id = person.id " + "AND person.name = %(name_1)s RETURNING engineer.id", + [{"engineer_name": "e5", "name_1": "e2"}], + dialect="postgresql", + ), + ) + else: + asserter.assert_( + CompiledSQL( + "UPDATE engineer SET engineer_name=%(engineer_name)s " + "FROM person WHERE engineer.id = person.id " + "AND person.name = %(name_1)s", + [{"engineer_name": "e5", "name_1": "e2"}], + dialect="postgresql", + ), + ) + + eq_( + set(s.query(Person.name, Engineer.engineer_name)), + {("e1", "e1"), ("e2", "e5"), ("pp1", "pp1")}, + ) + + @testing.requires.delete_using + @testing.combinations(None, "fetch", "fetch_w_hint", "evaluate") + def test_delete_using(self, synchronize_session): + """test a DELETE that uses multiple tables. + + due to a limitation in MariaDB, we have an up front "hint" that needs + to be passed for this backend if DELETE USING is to be used in + conjunction with "fetch" strategy, so that we know before compilation + that we won't be able to use RETURNING. + + """ + + Engineer = self.classes.Engineer + Person = self.classes.Person + s = Session(testing.db) + + db_has_mariadb_limitation = ( + testing.db.dialect.delete_returning + and not testing.db.dialect.delete_returning_multifrom + ) + + e2 = s.query(Engineer).filter_by(name="e2").first() + + with self.sql_execution_asserter() as asserter: + assert e2 in s + + q = ( + s.query(Engineer) + .filter(Engineer.id == Person.id) + .filter(Person.name == "e2") + ) + + if synchronize_session == "fetch_w_hint": + q.execution_options(is_delete_using=True).delete( + synchronize_session="fetch" + ) + elif synchronize_session == "fetch" and db_has_mariadb_limitation: + with expect_raises_message( + exc.CompileError, + 'Dialect ".*" does not support RETURNING with ' + "DELETE..USING;", + ): + q.delete(synchronize_session=synchronize_session) + return + else: + q.delete(synchronize_session=synchronize_session) + + if synchronize_session is None: + assert e2 in s + else: + assert e2 not in s + + if synchronize_session in ("fetch", "fetch_w_hint") and ( + db_has_mariadb_limitation + or not testing.db.dialect.delete_returning + ): + asserter.assert_( + CompiledSQL( + "SELECT person.id FROM person INNER JOIN engineer ON " + "person.id = engineer.id WHERE engineer.id = person.id " + "AND person.name = %s", + [{"name_1": "e2"}], + dialect="mariadb", + ), + CompiledSQL( + "DELETE FROM engineer USING engineer, person WHERE " + "engineer.id = person.id AND person.name = %s", + [{"name_1": "e2"}], + dialect="mariadb", + ), + ) + elif synchronize_session in ("fetch", "fetch_w_hint"): + asserter.assert_( + CompiledSQL( + "DELETE FROM engineer USING person WHERE " + "engineer.id = person.id AND person.name = %(name_1)s " + "RETURNING engineer.id", + [{"name_1": "e2"}], + dialect="postgresql", + ), + ) + else: + asserter.assert_( + CompiledSQL( + "DELETE FROM engineer USING person WHERE " + "engineer.id = person.id AND person.name = %(name_1)s", + [{"name_1": "e2"}], + dialect="postgresql", + ), + ) + + # delete actually worked + eq_( + set(s.query(Person.name, Engineer.engineer_name)), + {("pp1", "pp1"), ("e1", "e1")}, + ) + + @testing.only_on(["mysql", "mariadb"], "Multi table update") + @testing.requires.delete_using + @testing.combinations(None, "fetch", "evaluate") + def test_update_from_multitable(self, synchronize_session): + Engineer = self.classes.Engineer + Person = self.classes.Person + s = Session(testing.db) + s.query(Engineer).filter(Engineer.id == Person.id).filter( + Person.name == "e2" + ).update( + {Person.name: "e22", Engineer.engineer_name: "e55"}, + synchronize_session=synchronize_session, + ) + + eq_( + set(s.query(Person.name, Engineer.engineer_name)), + {("e1", "e1"), ("e22", "e55"), ("pp1", "pp1")}, + ) + + +class InheritWPolyTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = "default" + + @testing.fixture + def inherit_fixture(self, decl_base): + def go(poly_type): + + class Person(decl_base): + __tablename__ = "person" + id = Column(Integer, primary_key=True) + type = Column(String(50)) + name = Column(String(50)) + + if poly_type.wpoly: + __mapper_args__ = {"with_polymorphic": "*"} + + class Engineer(Person): + __tablename__ = "engineer" + id = Column(Integer, ForeignKey("person.id"), primary_key=True) + engineer_name = Column(String(50)) + + if poly_type.inline: + __mapper_args__ = {"polymorphic_load": "inline"} + + return Person, Engineer + + return go + + @testing.variation("poly_type", ["wpoly", "inline", "none"]) + def test_update_base_only(self, poly_type, inherit_fixture): + Person, Engineer = inherit_fixture(poly_type) + + self.assert_compile( + update(Person).values(name="n1"), "UPDATE person SET name=:name" + ) + + @testing.variation("poly_type", ["wpoly", "inline", "none"]) + def test_delete_base_only(self, poly_type, inherit_fixture): + Person, Engineer = inherit_fixture(poly_type) + + self.assert_compile(delete(Person), "DELETE FROM person") + + self.assert_compile( + delete(Person).where(Person.id == 7), + "DELETE FROM person WHERE person.id = :id_1", + ) + + +class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest): + __backend__ = True + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Staff(Base): + __tablename__ = "staff" + position = Column(String(10), nullable=False) + id = Column( + Integer, primary_key=True, test_needs_autoincrement=True + ) + name = Column(String(5)) + stats = Column(String(5)) + __mapper_args__ = {"polymorphic_on": position} + + class Sales(Staff): + sales_stats = Column(String(5)) + __mapper_args__ = {"polymorphic_identity": "sales"} + + class Support(Staff): + support_stats = Column(String(5)) + __mapper_args__ = {"polymorphic_identity": "support"} + + @classmethod + def insert_data(cls, connection): + with sessionmaker(connection).begin() as session: + Sales, Support = ( + cls.classes.Sales, + cls.classes.Support, + ) + session.add_all( + [ + Sales(name="n1", sales_stats="1", stats="a"), + Sales(name="n2", sales_stats="2", stats="b"), + Support(name="n1", support_stats="3", stats="c"), + Support(name="n2", support_stats="4", stats="d"), + ] + ) + + @testing.combinations( + ("fetch", False), + ("fetch", True), + ("evaluate", False), + ("evaluate", True), + ) + def test_update(self, fetchstyle, newstyle): + Staff, Sales, Support = self.classes("Staff", "Sales", "Support") + + sess = fixture_session() + + en1, en2 = ( + sess.execute(select(Sales).order_by(Sales.sales_stats)) + .scalars() + .all() + ) + mn1, mn2 = ( + sess.execute(select(Support).order_by(Support.support_stats)) + .scalars() + .all() + ) + + if newstyle: + sess.execute( + update(Sales) + .filter_by(name="n1") + .values(stats="p") + .execution_options(synchronize_session=fetchstyle) + ) + else: + sess.query(Sales).filter_by(name="n1").update( + {"stats": "p"}, synchronize_session=fetchstyle + ) + + eq_(en1.stats, "p") + eq_(mn1.stats, "c") + eq_( + sess.execute( + select(Staff.position, Staff.name, Staff.stats).order_by( + Staff.id + ) + ).all(), + [ + ("sales", "n1", "p"), + ("sales", "n2", "b"), + ("support", "n1", "c"), + ("support", "n2", "d"), + ], + ) + + @testing.combinations( + ("fetch", False), + ("fetch", True), + ("evaluate", False), + ("evaluate", True), + ) + def test_delete(self, fetchstyle, newstyle): + Staff, Sales, Support = self.classes("Staff", "Sales", "Support") + + sess = fixture_session() + en1, en2 = sess.query(Sales).order_by(Sales.sales_stats).all() + mn1, mn2 = sess.query(Support).order_by(Support.support_stats).all() + + if newstyle: + sess.execute( + delete(Sales) + .filter_by(name="n1") + .execution_options(synchronize_session=fetchstyle) + ) + else: + sess.query(Sales).filter_by(name="n1").delete( + synchronize_session=fetchstyle + ) + assert en1 not in sess + assert en2 in sess + assert mn1 in sess + assert mn2 in sess + + eq_( + sess.execute( + select(Staff.position, Staff.name, Staff.stats).order_by( + Staff.id + ) + ).all(), + [ + ("sales", "n2", "b"), + ("support", "n1", "c"), + ("support", "n2", "d"), + ], + ) -- 2.47.3