--- /dev/null
+from __future__ import annotations
+
+import uuid
+
+from sqlalchemy import Computed
+from sqlalchemy import delete
+from sqlalchemy import FetchedValue
+from sqlalchemy import insert
+from sqlalchemy import Integer
+from sqlalchemy import literal
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy import update
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import Session
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.entities import ComparableEntity
+from sqlalchemy.testing.fixtures import fixture_session
+from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import Table
+
+
+class LoadFromReturningTest(fixtures.MappedTest):
+ __backend__ = True
+ __requires__ = ("insert_returning",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "users",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("name", String(32)),
+ Column("age_int", Integer),
+ )
+
+ @classmethod
+ def setup_classes(cls):
+ class User(cls.Comparable):
+ pass
+
+ class Address(cls.Comparable):
+ pass
+
+ @classmethod
+ def insert_data(cls, connection):
+ users = cls.tables.users
+
+ connection.execute(
+ users.insert(),
+ [
+ dict(id=1, name="john", age_int=25),
+ dict(id=2, name="jack", age_int=47),
+ dict(id=3, name="jill", age_int=29),
+ dict(id=4, name="jane", age_int=37),
+ ],
+ )
+
+ @classmethod
+ def setup_mappers(cls):
+ User = cls.classes.User
+ users = cls.tables.users
+
+ cls.mapper_registry.map_imperatively(
+ User,
+ users,
+ properties={
+ "age": users.c.age_int,
+ },
+ )
+
+ @testing.requires.update_returning
+ @testing.combinations(True, False, argnames="use_from_statement")
+ def test_load_from_update(self, connection, use_from_statement):
+ User = self.classes.User
+
+ stmt = (
+ update(User)
+ .where(User.name.in_(["jack", "jill"]))
+ .values(age=User.age + 5)
+ .returning(User)
+ )
+
+ if use_from_statement:
+ # this is now a legacy-ish case, because as of 2.0 you can just
+ # use returning() directly to get the objects back.
+ #
+ # when from_statement is used, the UPDATE statement is no
+ # longer interpreted by
+ # BulkUDCompileState.orm_pre_session_exec or
+ # BulkUDCompileState.orm_setup_cursor_result. The compilation
+ # level routines still take place though
+ stmt = select(User).from_statement(stmt)
+
+ with Session(connection) as sess:
+ rows = sess.execute(stmt).scalars().all()
+
+ eq_(
+ rows,
+ [User(name="jack", age=52), User(name="jill", age=34)],
+ )
+
+ @testing.combinations(
+ ("single",),
+ ("multiple", testing.requires.multivalues_inserts),
+ argnames="params",
+ )
+ @testing.combinations(True, False, argnames="use_from_statement")
+ def test_load_from_insert(self, connection, params, use_from_statement):
+ User = self.classes.User
+
+ if params == "multiple":
+ values = [
+ {User.id: 5, User.age: 25, User.name: "spongebob"},
+ {User.id: 6, User.age: 30, User.name: "patrick"},
+ {User.id: 7, User.age: 35, User.name: "squidward"},
+ ]
+ elif params == "single":
+ values = {User.id: 5, User.age: 25, User.name: "spongebob"}
+ else:
+ assert False
+
+ stmt = insert(User).values(values).returning(User)
+
+ if use_from_statement:
+ stmt = select(User).from_statement(stmt)
+
+ with Session(connection) as sess:
+ rows = sess.execute(stmt).scalars().all()
+
+ if params == "multiple":
+ eq_(
+ rows,
+ [
+ User(name="spongebob", age=25),
+ User(name="patrick", age=30),
+ User(name="squidward", age=35),
+ ],
+ )
+ elif params == "single":
+ eq_(
+ rows,
+ [User(name="spongebob", age=25)],
+ )
+ else:
+ assert False
+
+ @testing.requires.delete_returning
+ @testing.combinations(True, False, argnames="use_from_statement")
+ def test_load_from_delete(self, connection, use_from_statement):
+ User = self.classes.User
+
+ stmt = (
+ delete(User).where(User.name.in_(["jack", "jill"])).returning(User)
+ )
+
+ if use_from_statement:
+ stmt = select(User).from_statement(stmt)
+
+ with Session(connection) as sess:
+ rows = sess.execute(stmt).scalars().all()
+
+ eq_(
+ rows,
+ [User(name="jack", age=47), User(name="jill", age=29)],
+ )
+
+ # TODO: state of above objects should be "deleted"
+
+
+class OnUpdatePopulationTest(fixtures.TestBase):
+ __backend__ = True
+
+ @testing.variation("populate_existing", [True, False])
+ @testing.variation(
+ "use_onupdate",
+ [
+ "none",
+ "server",
+ "callable",
+ "clientsql",
+ ("computed", testing.requires.computed_columns),
+ ],
+ )
+ @testing.variation(
+ "use_returning",
+ [
+ ("returning", testing.requires.update_returning),
+ ("defaults", testing.requires.update_returning),
+ "none",
+ ],
+ )
+ @testing.variation("synchronize", ["auto", "fetch", "evaluate"])
+ @testing.variation("pk_order", ["first", "middle"])
+ def test_update_populate_existing(
+ self,
+ decl_base,
+ populate_existing,
+ use_onupdate,
+ use_returning,
+ synchronize,
+ pk_order,
+ ):
+ """test #11912 and #11917"""
+
+ class Employee(ComparableEntity, decl_base):
+ __tablename__ = "employee"
+
+ if pk_order.first:
+ uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True)
+ user_name: Mapped[str] = mapped_column(String(200), nullable=False)
+
+ if pk_order.middle:
+ uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True)
+
+ if use_onupdate.server:
+ some_server_value: Mapped[str] = mapped_column(
+ server_onupdate=FetchedValue()
+ )
+ elif use_onupdate.callable:
+ some_server_value: Mapped[str] = mapped_column(
+ onupdate=lambda: "value 2"
+ )
+ elif use_onupdate.clientsql:
+ some_server_value: Mapped[str] = mapped_column(
+ onupdate=literal("value 2")
+ )
+ elif use_onupdate.computed:
+ some_server_value: Mapped[str] = mapped_column(
+ String(255),
+ Computed(user_name + " computed value"),
+ nullable=True,
+ )
+ else:
+ some_server_value: Mapped[str]
+
+ decl_base.metadata.create_all(testing.db)
+ s = fixture_session()
+
+ uuid1 = uuid.uuid4()
+
+ if use_onupdate.computed:
+ server_old_value, server_new_value = (
+ "e1 old name computed value",
+ "e1 new name computed value",
+ )
+ e1 = Employee(uuid=uuid1, user_name="e1 old name")
+ else:
+ server_old_value, server_new_value = ("value 1", "value 2")
+ e1 = Employee(
+ uuid=uuid1,
+ user_name="e1 old name",
+ some_server_value="value 1",
+ )
+ s.add(e1)
+ s.flush()
+
+ stmt = (
+ update(Employee)
+ .values(user_name="e1 new name")
+ .where(Employee.uuid == uuid1)
+ )
+
+ if use_returning.returning:
+ stmt = stmt.returning(Employee)
+ elif use_returning.defaults:
+ # NOTE: the return_defaults case here has not been analyzed for
+ # #11912 or #11917. future enhancements may change its behavior
+ stmt = stmt.return_defaults()
+
+ # perform out of band UPDATE on server value to simulate
+ # a computed col
+ if use_onupdate.none or use_onupdate.server:
+ s.connection().execute(
+ update(Employee.__table__).values(some_server_value="value 2")
+ )
+
+ execution_options = {}
+
+ if populate_existing:
+ execution_options["populate_existing"] = True
+
+ if synchronize.evaluate:
+ execution_options["synchronize_session"] = "evaluate"
+ if synchronize.fetch:
+ execution_options["synchronize_session"] = "fetch"
+
+ if use_returning.returning:
+ rows = s.scalars(stmt, execution_options=execution_options)
+ else:
+ s.execute(stmt, execution_options=execution_options)
+
+ if (
+ use_onupdate.clientsql
+ or use_onupdate.server
+ or use_onupdate.computed
+ ):
+ if not use_returning.defaults:
+ # if server-side onupdate was generated, the col should have
+ # been expired
+ assert "some_server_value" not in e1.__dict__
+
+ # and refreshes when called. this is even if we have RETURNING
+ # rows we didn't fetch yet.
+ eq_(e1.some_server_value, server_new_value)
+ else:
+ # using return defaults here is not expiring. have not
+ # researched why, it may be because the explicit
+ # return_defaults interferes with the ORMs call
+ assert "some_server_value" in e1.__dict__
+ eq_(e1.some_server_value, server_old_value)
+
+ elif use_onupdate.callable:
+ if not use_returning.defaults or not synchronize.fetch:
+ # for python-side onupdate, col is populated with local value
+ assert "some_server_value" in e1.__dict__
+
+ # and is refreshed
+ eq_(e1.some_server_value, server_new_value)
+ else:
+ assert "some_server_value" in e1.__dict__
+
+ # and is not refreshed
+ eq_(e1.some_server_value, server_old_value)
+
+ else:
+ # no onupdate, then the value was not touched yet,
+ # even if we used RETURNING with populate_existing, because
+ # we did not fetch the rows yet
+ assert "some_server_value" in e1.__dict__
+ eq_(e1.some_server_value, server_old_value)
+
+ # now see if we can fetch rows
+ if use_returning.returning:
+
+ if populate_existing or not use_onupdate.none:
+ eq_(
+ set(rows),
+ {
+ Employee(
+ uuid=uuid1,
+ user_name="e1 new name",
+ some_server_value=server_new_value,
+ ),
+ },
+ )
+
+ else:
+ # if no populate existing and no server default, that column
+ # is not touched at all
+ eq_(
+ set(rows),
+ {
+ Employee(
+ uuid=uuid1,
+ user_name="e1 new name",
+ some_server_value=server_old_value,
+ ),
+ },
+ )
+
+ if use_returning.defaults:
+ # as mentioned above, the return_defaults() case here remains
+ # unanalyzed.
+ if synchronize.fetch or (
+ use_onupdate.clientsql
+ or use_onupdate.server
+ or use_onupdate.computed
+ or use_onupdate.none
+ ):
+ eq_(e1.some_server_value, server_old_value)
+ else:
+ eq_(e1.some_server_value, server_new_value)
+
+ elif (
+ populate_existing and use_returning.returning
+ ) or not use_onupdate.none:
+ eq_(e1.some_server_value, server_new_value)
+ else:
+ # no onupdate specified, and no populate existing with returning,
+ # the attribute is not refreshed
+ eq_(e1.some_server_value, server_old_value)
+
+ # do a full expire, now the new value is definitely there
+ s.commit()
+ s.expire_all()
+ eq_(e1.some_server_value, server_new_value)
+
+
+class PGIssue11849Test(fixtures.DeclarativeMappedTest):
+ __backend__ = True
+ __only_on__ = ("postgresql",)
+
+ @classmethod
+ def setup_classes(cls):
+
+ from sqlalchemy.dialects.postgresql import JSONB
+
+ Base = cls.DeclarativeBasic
+
+ class TestTbl(Base):
+ __tablename__ = "testtbl"
+
+ test_id = Column(Integer, primary_key=True)
+ test_field = Column(JSONB)
+
+ def test_issue_11849(self):
+ TestTbl = self.classes.TestTbl
+
+ session = fixture_session()
+
+ obj = TestTbl(
+ test_id=1, test_field={"test1": 1, "test2": "2", "test3": [3, "3"]}
+ )
+ session.add(obj)
+
+ query = (
+ update(TestTbl)
+ .where(TestTbl.test_id == 1)
+ .values(test_field=TestTbl.test_field + {"test3": {"test4": 4}})
+ )
+ session.execute(query)
+
+ # not loaded
+ assert "test_field" not in obj.__dict__
+
+ # synchronizes on load
+ eq_(obj.test_field, {"test1": 1, "test2": "2", "test3": {"test4": 4}})
from __future__ import annotations
-import uuid
-
from sqlalchemy import Boolean
from sqlalchemy import case
from sqlalchemy import column
-from sqlalchemy import Computed
from sqlalchemy import delete
from sqlalchemy import event
from sqlalchemy import exc
-from sqlalchemy import FetchedValue
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import lambda_stmt
-from sqlalchemy import literal
from sqlalchemy import literal_column
from sqlalchemy import MetaData
from sqlalchemy import or_
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import immediateload
from sqlalchemy.orm import joinedload
-from sqlalchemy.orm import Mapped
-from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
-from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import synonym
from sqlalchemy.orm import with_loader_criteria
from sqlalchemy.sql.dml import Delete
from sqlalchemy.sql.selectable import Select
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
-from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import not_in
from sqlalchemy.testing.assertions import expect_raises_message
from sqlalchemy.testing.assertsql import CompiledSQL
-from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
delete_stmt = m1.mock_calls[0][1][0]
eq_(delete_stmt.dialect_kwargs, delete_args)
-
-
-class InheritTest(fixtures.DeclarativeMappedTest):
- run_inserts = "each"
-
- run_deletes = "each"
- __backend__ = True
-
- @classmethod
- def setup_classes(cls):
- Base = cls.DeclarativeBasic
-
- class Person(Base):
- __tablename__ = "person"
- id = Column(
- Integer, primary_key=True, test_needs_autoincrement=True
- )
- type = Column(String(50))
- name = Column(String(50))
-
- class Engineer(Person):
- __tablename__ = "engineer"
- id = Column(Integer, ForeignKey("person.id"), primary_key=True)
- engineer_name = Column(String(50))
-
- class Programmer(Engineer):
- __tablename__ = "programmer"
- id = Column(Integer, ForeignKey("engineer.id"), primary_key=True)
- primary_language = Column(String(50))
-
- class Manager(Person):
- __tablename__ = "manager"
- id = Column(Integer, ForeignKey("person.id"), primary_key=True)
- manager_name = Column(String(50))
-
- @classmethod
- def insert_data(cls, connection):
- Engineer, Person, Manager, Programmer = (
- cls.classes.Engineer,
- cls.classes.Person,
- cls.classes.Manager,
- cls.classes.Programmer,
- )
- s = Session(connection)
- s.add_all(
- [
- Engineer(name="e1", engineer_name="e1"),
- Manager(name="m1", manager_name="m1"),
- Engineer(name="e2", engineer_name="e2"),
- Person(name="p1"),
- Programmer(
- name="pp1", engineer_name="pp1", primary_language="python"
- ),
- ]
- )
- s.commit()
-
- @testing.only_on(["mysql", "mariadb"], "Multi table update")
- def test_update_from_join_no_problem(self):
- person = self.classes.Person.__table__
- engineer = self.classes.Engineer.__table__
-
- sess = fixture_session()
- sess.query(person.join(engineer)).filter(person.c.name == "e2").update(
- {person.c.name: "updated", engineer.c.engineer_name: "e2a"},
- )
- obj = sess.execute(
- select(self.classes.Engineer).filter(
- self.classes.Engineer.name == "updated"
- )
- ).scalar()
- eq_(obj.name, "updated")
- eq_(obj.engineer_name, "e2a")
-
- @testing.combinations(None, "fetch", "evaluate")
- def test_update_sub_table_only(self, synchronize_session):
- Engineer = self.classes.Engineer
- s = Session(testing.db)
- s.query(Engineer).update(
- {"engineer_name": "e5"}, synchronize_session=synchronize_session
- )
-
- eq_(s.query(Engineer.engineer_name).all(), [("e5",), ("e5",), ("e5",)])
-
- @testing.combinations(None, "fetch", "evaluate")
- def test_update_sub_sub_table_only(self, synchronize_session):
- Programmer = self.classes.Programmer
- s = Session(testing.db)
- s.query(Programmer).update(
- {"primary_language": "c++"},
- synchronize_session=synchronize_session,
- )
-
- eq_(
- s.query(Programmer.primary_language).all(),
- [
- ("c++",),
- ],
- )
-
- @testing.requires.update_from
- @testing.combinations(None, "fetch", "fetch_w_hint", "evaluate")
- def test_update_from(self, synchronize_session):
- """test an UPDATE that uses multiple tables.
-
- The limitation that MariaDB has with DELETE does not apply here at the
- moment as MariaDB doesn't support UPDATE..RETURNING at all. However,
- the logic from DELETE is still implemented in persistence.py. If
- MariaDB adds UPDATE...RETURNING, then it may be useful. SQLite,
- PostgreSQL, MSSQL all support UPDATE..FROM however RETURNING seems to
- function correctly for all three.
-
- """
- Engineer = self.classes.Engineer
- Person = self.classes.Person
- s = Session(testing.db)
-
- # we don't have any backends with this combination right now.
- db_has_hypothetical_limitation = (
- testing.db.dialect.update_returning
- and not testing.db.dialect.update_returning_multifrom
- )
-
- e2 = s.query(Engineer).filter_by(name="e2").first()
-
- with self.sql_execution_asserter() as asserter:
- eq_(e2.engineer_name, "e2")
- q = (
- s.query(Engineer)
- .filter(Engineer.id == Person.id)
- .filter(Person.name == "e2")
- )
- if synchronize_session == "fetch_w_hint":
- q.execution_options(is_update_from=True).update(
- {"engineer_name": "e5"},
- synchronize_session="fetch",
- )
- elif (
- synchronize_session == "fetch"
- and db_has_hypothetical_limitation
- ):
- with expect_raises_message(
- exc.CompileError,
- 'Dialect ".*" does not support RETURNING with '
- "UPDATE..FROM;",
- ):
- q.update(
- {"engineer_name": "e5"},
- synchronize_session=synchronize_session,
- )
- return
- else:
- q.update(
- {"engineer_name": "e5"},
- synchronize_session=synchronize_session,
- )
-
- if synchronize_session is None:
- eq_(e2.engineer_name, "e2")
- else:
- eq_(e2.engineer_name, "e5")
-
- if synchronize_session in ("fetch", "fetch_w_hint") and (
- db_has_hypothetical_limitation
- or not testing.db.dialect.update_returning
- ):
- asserter.assert_(
- CompiledSQL(
- "SELECT person.id FROM person INNER JOIN engineer "
- "ON person.id = engineer.id WHERE engineer.id = person.id "
- "AND person.name = %s",
- [{"name_1": "e2"}],
- dialect="mariadb",
- ),
- CompiledSQL(
- "UPDATE engineer, person SET engineer.engineer_name=%s "
- "WHERE engineer.id = person.id AND person.name = %s",
- [{"engineer_name": "e5", "name_1": "e2"}],
- dialect="mariadb",
- ),
- )
- elif synchronize_session in ("fetch", "fetch_w_hint"):
- asserter.assert_(
- CompiledSQL(
- "UPDATE engineer SET engineer_name=%(engineer_name)s "
- "FROM person WHERE engineer.id = person.id "
- "AND person.name = %(name_1)s RETURNING engineer.id",
- [{"engineer_name": "e5", "name_1": "e2"}],
- dialect="postgresql",
- ),
- )
- else:
- asserter.assert_(
- CompiledSQL(
- "UPDATE engineer SET engineer_name=%(engineer_name)s "
- "FROM person WHERE engineer.id = person.id "
- "AND person.name = %(name_1)s",
- [{"engineer_name": "e5", "name_1": "e2"}],
- dialect="postgresql",
- ),
- )
-
- eq_(
- set(s.query(Person.name, Engineer.engineer_name)),
- {("e1", "e1"), ("e2", "e5"), ("pp1", "pp1")},
- )
-
- @testing.requires.delete_using
- @testing.combinations(None, "fetch", "fetch_w_hint", "evaluate")
- def test_delete_using(self, synchronize_session):
- """test a DELETE that uses multiple tables.
-
- due to a limitation in MariaDB, we have an up front "hint" that needs
- to be passed for this backend if DELETE USING is to be used in
- conjunction with "fetch" strategy, so that we know before compilation
- that we won't be able to use RETURNING.
-
- """
-
- Engineer = self.classes.Engineer
- Person = self.classes.Person
- s = Session(testing.db)
-
- db_has_mariadb_limitation = (
- testing.db.dialect.delete_returning
- and not testing.db.dialect.delete_returning_multifrom
- )
-
- e2 = s.query(Engineer).filter_by(name="e2").first()
-
- with self.sql_execution_asserter() as asserter:
- assert e2 in s
-
- q = (
- s.query(Engineer)
- .filter(Engineer.id == Person.id)
- .filter(Person.name == "e2")
- )
-
- if synchronize_session == "fetch_w_hint":
- q.execution_options(is_delete_using=True).delete(
- synchronize_session="fetch"
- )
- elif synchronize_session == "fetch" and db_has_mariadb_limitation:
- with expect_raises_message(
- exc.CompileError,
- 'Dialect ".*" does not support RETURNING with '
- "DELETE..USING;",
- ):
- q.delete(synchronize_session=synchronize_session)
- return
- else:
- q.delete(synchronize_session=synchronize_session)
-
- if synchronize_session is None:
- assert e2 in s
- else:
- assert e2 not in s
-
- if synchronize_session in ("fetch", "fetch_w_hint") and (
- db_has_mariadb_limitation
- or not testing.db.dialect.delete_returning
- ):
- asserter.assert_(
- CompiledSQL(
- "SELECT person.id FROM person INNER JOIN engineer ON "
- "person.id = engineer.id WHERE engineer.id = person.id "
- "AND person.name = %s",
- [{"name_1": "e2"}],
- dialect="mariadb",
- ),
- CompiledSQL(
- "DELETE FROM engineer USING engineer, person WHERE "
- "engineer.id = person.id AND person.name = %s",
- [{"name_1": "e2"}],
- dialect="mariadb",
- ),
- )
- elif synchronize_session in ("fetch", "fetch_w_hint"):
- asserter.assert_(
- CompiledSQL(
- "DELETE FROM engineer USING person WHERE "
- "engineer.id = person.id AND person.name = %(name_1)s "
- "RETURNING engineer.id",
- [{"name_1": "e2"}],
- dialect="postgresql",
- ),
- )
- else:
- asserter.assert_(
- CompiledSQL(
- "DELETE FROM engineer USING person WHERE "
- "engineer.id = person.id AND person.name = %(name_1)s",
- [{"name_1": "e2"}],
- dialect="postgresql",
- ),
- )
-
- # delete actually worked
- eq_(
- set(s.query(Person.name, Engineer.engineer_name)),
- {("pp1", "pp1"), ("e1", "e1")},
- )
-
- @testing.only_on(["mysql", "mariadb"], "Multi table update")
- @testing.requires.delete_using
- @testing.combinations(None, "fetch", "evaluate")
- def test_update_from_multitable(self, synchronize_session):
- Engineer = self.classes.Engineer
- Person = self.classes.Person
- s = Session(testing.db)
- s.query(Engineer).filter(Engineer.id == Person.id).filter(
- Person.name == "e2"
- ).update(
- {Person.name: "e22", Engineer.engineer_name: "e55"},
- synchronize_session=synchronize_session,
- )
-
- eq_(
- set(s.query(Person.name, Engineer.engineer_name)),
- {("e1", "e1"), ("e22", "e55"), ("pp1", "pp1")},
- )
-
-
-class InheritWPolyTest(fixtures.TestBase, AssertsCompiledSQL):
- __dialect__ = "default"
-
- @testing.fixture
- def inherit_fixture(self, decl_base):
- def go(poly_type):
-
- class Person(decl_base):
- __tablename__ = "person"
- id = Column(Integer, primary_key=True)
- type = Column(String(50))
- name = Column(String(50))
-
- if poly_type.wpoly:
- __mapper_args__ = {"with_polymorphic": "*"}
-
- class Engineer(Person):
- __tablename__ = "engineer"
- id = Column(Integer, ForeignKey("person.id"), primary_key=True)
- engineer_name = Column(String(50))
-
- if poly_type.inline:
- __mapper_args__ = {"polymorphic_load": "inline"}
-
- return Person, Engineer
-
- return go
-
- @testing.variation("poly_type", ["wpoly", "inline", "none"])
- def test_update_base_only(self, poly_type, inherit_fixture):
- Person, Engineer = inherit_fixture(poly_type)
-
- self.assert_compile(
- update(Person).values(name="n1"), "UPDATE person SET name=:name"
- )
-
- @testing.variation("poly_type", ["wpoly", "inline", "none"])
- def test_delete_base_only(self, poly_type, inherit_fixture):
- Person, Engineer = inherit_fixture(poly_type)
-
- self.assert_compile(delete(Person), "DELETE FROM person")
-
- self.assert_compile(
- delete(Person).where(Person.id == 7),
- "DELETE FROM person WHERE person.id = :id_1",
- )
-
-
-class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest):
- __backend__ = True
-
- @classmethod
- def setup_classes(cls):
- Base = cls.DeclarativeBasic
-
- class Staff(Base):
- __tablename__ = "staff"
- position = Column(String(10), nullable=False)
- id = Column(
- Integer, primary_key=True, test_needs_autoincrement=True
- )
- name = Column(String(5))
- stats = Column(String(5))
- __mapper_args__ = {"polymorphic_on": position}
-
- class Sales(Staff):
- sales_stats = Column(String(5))
- __mapper_args__ = {"polymorphic_identity": "sales"}
-
- class Support(Staff):
- support_stats = Column(String(5))
- __mapper_args__ = {"polymorphic_identity": "support"}
-
- @classmethod
- def insert_data(cls, connection):
- with sessionmaker(connection).begin() as session:
- Sales, Support = (
- cls.classes.Sales,
- cls.classes.Support,
- )
- session.add_all(
- [
- Sales(name="n1", sales_stats="1", stats="a"),
- Sales(name="n2", sales_stats="2", stats="b"),
- Support(name="n1", support_stats="3", stats="c"),
- Support(name="n2", support_stats="4", stats="d"),
- ]
- )
-
- @testing.combinations(
- ("fetch", False),
- ("fetch", True),
- ("evaluate", False),
- ("evaluate", True),
- )
- def test_update(self, fetchstyle, newstyle):
- Staff, Sales, Support = self.classes("Staff", "Sales", "Support")
-
- sess = fixture_session()
-
- en1, en2 = (
- sess.execute(select(Sales).order_by(Sales.sales_stats))
- .scalars()
- .all()
- )
- mn1, mn2 = (
- sess.execute(select(Support).order_by(Support.support_stats))
- .scalars()
- .all()
- )
-
- if newstyle:
- sess.execute(
- update(Sales)
- .filter_by(name="n1")
- .values(stats="p")
- .execution_options(synchronize_session=fetchstyle)
- )
- else:
- sess.query(Sales).filter_by(name="n1").update(
- {"stats": "p"}, synchronize_session=fetchstyle
- )
-
- eq_(en1.stats, "p")
- eq_(mn1.stats, "c")
- eq_(
- sess.execute(
- select(Staff.position, Staff.name, Staff.stats).order_by(
- Staff.id
- )
- ).all(),
- [
- ("sales", "n1", "p"),
- ("sales", "n2", "b"),
- ("support", "n1", "c"),
- ("support", "n2", "d"),
- ],
- )
-
- @testing.combinations(
- ("fetch", False),
- ("fetch", True),
- ("evaluate", False),
- ("evaluate", True),
- )
- def test_delete(self, fetchstyle, newstyle):
- Staff, Sales, Support = self.classes("Staff", "Sales", "Support")
-
- sess = fixture_session()
- en1, en2 = sess.query(Sales).order_by(Sales.sales_stats).all()
- mn1, mn2 = sess.query(Support).order_by(Support.support_stats).all()
-
- if newstyle:
- sess.execute(
- delete(Sales)
- .filter_by(name="n1")
- .execution_options(synchronize_session=fetchstyle)
- )
- else:
- sess.query(Sales).filter_by(name="n1").delete(
- synchronize_session=fetchstyle
- )
- assert en1 not in sess
- assert en2 in sess
- assert mn1 in sess
- assert mn2 in sess
-
- eq_(
- sess.execute(
- select(Staff.position, Staff.name, Staff.stats).order_by(
- Staff.id
- )
- ).all(),
- [
- ("sales", "n2", "b"),
- ("support", "n1", "c"),
- ("support", "n2", "d"),
- ],
- )
-
-
-class LoadFromReturningTest(fixtures.MappedTest):
- __backend__ = True
- __requires__ = ("insert_returning",)
-
- @classmethod
- def define_tables(cls, metadata):
- Table(
- "users",
- metadata,
- Column(
- "id", Integer, primary_key=True, test_needs_autoincrement=True
- ),
- Column("name", String(32)),
- Column("age_int", Integer),
- )
-
- @classmethod
- def setup_classes(cls):
- class User(cls.Comparable):
- pass
-
- class Address(cls.Comparable):
- pass
-
- @classmethod
- def insert_data(cls, connection):
- users = cls.tables.users
-
- connection.execute(
- users.insert(),
- [
- dict(id=1, name="john", age_int=25),
- dict(id=2, name="jack", age_int=47),
- dict(id=3, name="jill", age_int=29),
- dict(id=4, name="jane", age_int=37),
- ],
- )
-
- @classmethod
- def setup_mappers(cls):
- User = cls.classes.User
- users = cls.tables.users
-
- cls.mapper_registry.map_imperatively(
- User,
- users,
- properties={
- "age": users.c.age_int,
- },
- )
-
- @testing.requires.update_returning
- @testing.combinations(True, False, argnames="use_from_statement")
- def test_load_from_update(self, connection, use_from_statement):
- User = self.classes.User
-
- stmt = (
- update(User)
- .where(User.name.in_(["jack", "jill"]))
- .values(age=User.age + 5)
- .returning(User)
- )
-
- if use_from_statement:
- # this is now a legacy-ish case, because as of 2.0 you can just
- # use returning() directly to get the objects back.
- #
- # when from_statement is used, the UPDATE statement is no
- # longer interpreted by
- # BulkUDCompileState.orm_pre_session_exec or
- # BulkUDCompileState.orm_setup_cursor_result. The compilation
- # level routines still take place though
- stmt = select(User).from_statement(stmt)
-
- with Session(connection) as sess:
- rows = sess.execute(stmt).scalars().all()
-
- eq_(
- rows,
- [User(name="jack", age=52), User(name="jill", age=34)],
- )
-
- @testing.combinations(
- ("single",),
- ("multiple", testing.requires.multivalues_inserts),
- argnames="params",
- )
- @testing.combinations(True, False, argnames="use_from_statement")
- def test_load_from_insert(self, connection, params, use_from_statement):
- User = self.classes.User
-
- if params == "multiple":
- values = [
- {User.id: 5, User.age: 25, User.name: "spongebob"},
- {User.id: 6, User.age: 30, User.name: "patrick"},
- {User.id: 7, User.age: 35, User.name: "squidward"},
- ]
- elif params == "single":
- values = {User.id: 5, User.age: 25, User.name: "spongebob"}
- else:
- assert False
-
- stmt = insert(User).values(values).returning(User)
-
- if use_from_statement:
- stmt = select(User).from_statement(stmt)
-
- with Session(connection) as sess:
- rows = sess.execute(stmt).scalars().all()
-
- if params == "multiple":
- eq_(
- rows,
- [
- User(name="spongebob", age=25),
- User(name="patrick", age=30),
- User(name="squidward", age=35),
- ],
- )
- elif params == "single":
- eq_(
- rows,
- [User(name="spongebob", age=25)],
- )
- else:
- assert False
-
- @testing.requires.delete_returning
- @testing.combinations(True, False, argnames="use_from_statement")
- def test_load_from_delete(self, connection, use_from_statement):
- User = self.classes.User
-
- stmt = (
- delete(User).where(User.name.in_(["jack", "jill"])).returning(User)
- )
-
- if use_from_statement:
- stmt = select(User).from_statement(stmt)
-
- with Session(connection) as sess:
- rows = sess.execute(stmt).scalars().all()
-
- eq_(
- rows,
- [User(name="jack", age=47), User(name="jill", age=29)],
- )
-
- # TODO: state of above objects should be "deleted"
-
-
-class OnUpdatePopulationTest(fixtures.TestBase):
- __backend__ = True
-
- @testing.variation("populate_existing", [True, False])
- @testing.variation(
- "use_onupdate",
- [
- "none",
- "server",
- "callable",
- "clientsql",
- ("computed", testing.requires.computed_columns),
- ],
- )
- @testing.variation(
- "use_returning",
- [
- ("returning", testing.requires.update_returning),
- ("defaults", testing.requires.update_returning),
- "none",
- ],
- )
- @testing.variation("synchronize", ["auto", "fetch", "evaluate"])
- @testing.variation("pk_order", ["first", "middle"])
- def test_update_populate_existing(
- self,
- decl_base,
- populate_existing,
- use_onupdate,
- use_returning,
- synchronize,
- pk_order,
- ):
- """test #11912 and #11917"""
-
- class Employee(ComparableEntity, decl_base):
- __tablename__ = "employee"
-
- if pk_order.first:
- uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True)
- user_name: Mapped[str] = mapped_column(String(200), nullable=False)
-
- if pk_order.middle:
- uuid: Mapped[uuid.UUID] = mapped_column(primary_key=True)
-
- if use_onupdate.server:
- some_server_value: Mapped[str] = mapped_column(
- server_onupdate=FetchedValue()
- )
- elif use_onupdate.callable:
- some_server_value: Mapped[str] = mapped_column(
- onupdate=lambda: "value 2"
- )
- elif use_onupdate.clientsql:
- some_server_value: Mapped[str] = mapped_column(
- onupdate=literal("value 2")
- )
- elif use_onupdate.computed:
- some_server_value: Mapped[str] = mapped_column(
- String(255),
- Computed(user_name + " computed value"),
- nullable=True,
- )
- else:
- some_server_value: Mapped[str]
-
- decl_base.metadata.create_all(testing.db)
- s = fixture_session()
-
- uuid1 = uuid.uuid4()
-
- if use_onupdate.computed:
- server_old_value, server_new_value = (
- "e1 old name computed value",
- "e1 new name computed value",
- )
- e1 = Employee(uuid=uuid1, user_name="e1 old name")
- else:
- server_old_value, server_new_value = ("value 1", "value 2")
- e1 = Employee(
- uuid=uuid1,
- user_name="e1 old name",
- some_server_value="value 1",
- )
- s.add(e1)
- s.flush()
-
- stmt = (
- update(Employee)
- .values(user_name="e1 new name")
- .where(Employee.uuid == uuid1)
- )
-
- if use_returning.returning:
- stmt = stmt.returning(Employee)
- elif use_returning.defaults:
- # NOTE: the return_defaults case here has not been analyzed for
- # #11912 or #11917. future enhancements may change its behavior
- stmt = stmt.return_defaults()
-
- # perform out of band UPDATE on server value to simulate
- # a computed col
- if use_onupdate.none or use_onupdate.server:
- s.connection().execute(
- update(Employee.__table__).values(some_server_value="value 2")
- )
-
- execution_options = {}
-
- if populate_existing:
- execution_options["populate_existing"] = True
-
- if synchronize.evaluate:
- execution_options["synchronize_session"] = "evaluate"
- if synchronize.fetch:
- execution_options["synchronize_session"] = "fetch"
-
- if use_returning.returning:
- rows = s.scalars(stmt, execution_options=execution_options)
- else:
- s.execute(stmt, execution_options=execution_options)
-
- if (
- use_onupdate.clientsql
- or use_onupdate.server
- or use_onupdate.computed
- ):
- if not use_returning.defaults:
- # if server-side onupdate was generated, the col should have
- # been expired
- assert "some_server_value" not in e1.__dict__
-
- # and refreshes when called. this is even if we have RETURNING
- # rows we didn't fetch yet.
- eq_(e1.some_server_value, server_new_value)
- else:
- # using return defaults here is not expiring. have not
- # researched why, it may be because the explicit
- # return_defaults interferes with the ORMs call
- assert "some_server_value" in e1.__dict__
- eq_(e1.some_server_value, server_old_value)
-
- elif use_onupdate.callable:
- if not use_returning.defaults or not synchronize.fetch:
- # for python-side onupdate, col is populated with local value
- assert "some_server_value" in e1.__dict__
-
- # and is refreshed
- eq_(e1.some_server_value, server_new_value)
- else:
- assert "some_server_value" in e1.__dict__
-
- # and is not refreshed
- eq_(e1.some_server_value, server_old_value)
-
- else:
- # no onupdate, then the value was not touched yet,
- # even if we used RETURNING with populate_existing, because
- # we did not fetch the rows yet
- assert "some_server_value" in e1.__dict__
- eq_(e1.some_server_value, server_old_value)
-
- # now see if we can fetch rows
- if use_returning.returning:
-
- if populate_existing or not use_onupdate.none:
- eq_(
- set(rows),
- {
- Employee(
- uuid=uuid1,
- user_name="e1 new name",
- some_server_value=server_new_value,
- ),
- },
- )
-
- else:
- # if no populate existing and no server default, that column
- # is not touched at all
- eq_(
- set(rows),
- {
- Employee(
- uuid=uuid1,
- user_name="e1 new name",
- some_server_value=server_old_value,
- ),
- },
- )
-
- if use_returning.defaults:
- # as mentioned above, the return_defaults() case here remains
- # unanalyzed.
- if synchronize.fetch or (
- use_onupdate.clientsql
- or use_onupdate.server
- or use_onupdate.computed
- or use_onupdate.none
- ):
- eq_(e1.some_server_value, server_old_value)
- else:
- eq_(e1.some_server_value, server_new_value)
-
- elif (
- populate_existing and use_returning.returning
- ) or not use_onupdate.none:
- eq_(e1.some_server_value, server_new_value)
- else:
- # no onupdate specified, and no populate existing with returning,
- # the attribute is not refreshed
- eq_(e1.some_server_value, server_old_value)
-
- # do a full expire, now the new value is definitely there
- s.commit()
- s.expire_all()
- eq_(e1.some_server_value, server_new_value)
-
-
-class PGIssue11849Test(fixtures.DeclarativeMappedTest):
- __backend__ = True
- __only_on__ = ("postgresql",)
-
- @classmethod
- def setup_classes(cls):
-
- from sqlalchemy.dialects.postgresql import JSONB
-
- Base = cls.DeclarativeBasic
-
- class TestTbl(Base):
- __tablename__ = "testtbl"
-
- test_id = Column(Integer, primary_key=True)
- test_field = Column(JSONB)
-
- def test_issue_11849(self):
- TestTbl = self.classes.TestTbl
-
- session = fixture_session()
-
- obj = TestTbl(
- test_id=1, test_field={"test1": 1, "test2": "2", "test3": [3, "3"]}
- )
- session.add(obj)
-
- query = (
- update(TestTbl)
- .where(TestTbl.test_id == 1)
- .values(test_field=TestTbl.test_field + {"test3": {"test4": 4}})
- )
- session.execute(query)
-
- # not loaded
- assert "test_field" not in obj.__dict__
-
- # synchronizes on load
- eq_(obj.test_field, {"test1": 1, "test2": "2", "test3": {"test4": 4}})
--- /dev/null
+from __future__ import annotations
+
+from sqlalchemy import delete
+from sqlalchemy import exc
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy import update
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.assertions import expect_raises_message
+from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.fixtures import fixture_session
+from sqlalchemy.testing.schema import Column
+
+
+class InheritTest(fixtures.DeclarativeMappedTest):
+ run_inserts = "each"
+
+ run_deletes = "each"
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ Base = cls.DeclarativeBasic
+
+ class Person(Base):
+ __tablename__ = "person"
+ id = Column(
+ Integer, primary_key=True, test_needs_autoincrement=True
+ )
+ type = Column(String(50))
+ name = Column(String(50))
+
+ class Engineer(Person):
+ __tablename__ = "engineer"
+ id = Column(Integer, ForeignKey("person.id"), primary_key=True)
+ engineer_name = Column(String(50))
+
+ class Programmer(Engineer):
+ __tablename__ = "programmer"
+ id = Column(Integer, ForeignKey("engineer.id"), primary_key=True)
+ primary_language = Column(String(50))
+
+ class Manager(Person):
+ __tablename__ = "manager"
+ id = Column(Integer, ForeignKey("person.id"), primary_key=True)
+ manager_name = Column(String(50))
+
+ @classmethod
+ def insert_data(cls, connection):
+ Engineer, Person, Manager, Programmer = (
+ cls.classes.Engineer,
+ cls.classes.Person,
+ cls.classes.Manager,
+ cls.classes.Programmer,
+ )
+ s = Session(connection)
+ s.add_all(
+ [
+ Engineer(name="e1", engineer_name="e1"),
+ Manager(name="m1", manager_name="m1"),
+ Engineer(name="e2", engineer_name="e2"),
+ Person(name="p1"),
+ Programmer(
+ name="pp1", engineer_name="pp1", primary_language="python"
+ ),
+ ]
+ )
+ s.commit()
+
+ @testing.only_on(["mysql", "mariadb"], "Multi table update")
+ def test_update_from_join_no_problem(self):
+ person = self.classes.Person.__table__
+ engineer = self.classes.Engineer.__table__
+
+ sess = fixture_session()
+ sess.query(person.join(engineer)).filter(person.c.name == "e2").update(
+ {person.c.name: "updated", engineer.c.engineer_name: "e2a"},
+ )
+ obj = sess.execute(
+ select(self.classes.Engineer).filter(
+ self.classes.Engineer.name == "updated"
+ )
+ ).scalar()
+ eq_(obj.name, "updated")
+ eq_(obj.engineer_name, "e2a")
+
+ @testing.combinations(None, "fetch", "evaluate")
+ def test_update_sub_table_only(self, synchronize_session):
+ Engineer = self.classes.Engineer
+ s = Session(testing.db)
+ s.query(Engineer).update(
+ {"engineer_name": "e5"}, synchronize_session=synchronize_session
+ )
+
+ eq_(s.query(Engineer.engineer_name).all(), [("e5",), ("e5",), ("e5",)])
+
+ @testing.combinations(None, "fetch", "evaluate")
+ def test_update_sub_sub_table_only(self, synchronize_session):
+ Programmer = self.classes.Programmer
+ s = Session(testing.db)
+ s.query(Programmer).update(
+ {"primary_language": "c++"},
+ synchronize_session=synchronize_session,
+ )
+
+ eq_(
+ s.query(Programmer.primary_language).all(),
+ [
+ ("c++",),
+ ],
+ )
+
+ @testing.requires.update_from
+ @testing.combinations(None, "fetch", "fetch_w_hint", "evaluate")
+ def test_update_from(self, synchronize_session):
+ """test an UPDATE that uses multiple tables.
+
+ The limitation that MariaDB has with DELETE does not apply here at the
+ moment as MariaDB doesn't support UPDATE..RETURNING at all. However,
+ the logic from DELETE is still implemented in persistence.py. If
+ MariaDB adds UPDATE...RETURNING, then it may be useful. SQLite,
+ PostgreSQL, MSSQL all support UPDATE..FROM however RETURNING seems to
+ function correctly for all three.
+
+ """
+ Engineer = self.classes.Engineer
+ Person = self.classes.Person
+ s = Session(testing.db)
+
+ # we don't have any backends with this combination right now.
+ db_has_hypothetical_limitation = (
+ testing.db.dialect.update_returning
+ and not testing.db.dialect.update_returning_multifrom
+ )
+
+ e2 = s.query(Engineer).filter_by(name="e2").first()
+
+ with self.sql_execution_asserter() as asserter:
+ eq_(e2.engineer_name, "e2")
+ q = (
+ s.query(Engineer)
+ .filter(Engineer.id == Person.id)
+ .filter(Person.name == "e2")
+ )
+ if synchronize_session == "fetch_w_hint":
+ q.execution_options(is_update_from=True).update(
+ {"engineer_name": "e5"},
+ synchronize_session="fetch",
+ )
+ elif (
+ synchronize_session == "fetch"
+ and db_has_hypothetical_limitation
+ ):
+ with expect_raises_message(
+ exc.CompileError,
+ 'Dialect ".*" does not support RETURNING with '
+ "UPDATE..FROM;",
+ ):
+ q.update(
+ {"engineer_name": "e5"},
+ synchronize_session=synchronize_session,
+ )
+ return
+ else:
+ q.update(
+ {"engineer_name": "e5"},
+ synchronize_session=synchronize_session,
+ )
+
+ if synchronize_session is None:
+ eq_(e2.engineer_name, "e2")
+ else:
+ eq_(e2.engineer_name, "e5")
+
+ if synchronize_session in ("fetch", "fetch_w_hint") and (
+ db_has_hypothetical_limitation
+ or not testing.db.dialect.update_returning
+ ):
+ asserter.assert_(
+ CompiledSQL(
+ "SELECT person.id FROM person INNER JOIN engineer "
+ "ON person.id = engineer.id WHERE engineer.id = person.id "
+ "AND person.name = %s",
+ [{"name_1": "e2"}],
+ dialect="mariadb",
+ ),
+ CompiledSQL(
+ "UPDATE engineer, person SET engineer.engineer_name=%s "
+ "WHERE engineer.id = person.id AND person.name = %s",
+ [{"engineer_name": "e5", "name_1": "e2"}],
+ dialect="mariadb",
+ ),
+ )
+ elif synchronize_session in ("fetch", "fetch_w_hint"):
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE engineer SET engineer_name=%(engineer_name)s "
+ "FROM person WHERE engineer.id = person.id "
+ "AND person.name = %(name_1)s RETURNING engineer.id",
+ [{"engineer_name": "e5", "name_1": "e2"}],
+ dialect="postgresql",
+ ),
+ )
+ else:
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE engineer SET engineer_name=%(engineer_name)s "
+ "FROM person WHERE engineer.id = person.id "
+ "AND person.name = %(name_1)s",
+ [{"engineer_name": "e5", "name_1": "e2"}],
+ dialect="postgresql",
+ ),
+ )
+
+ eq_(
+ set(s.query(Person.name, Engineer.engineer_name)),
+ {("e1", "e1"), ("e2", "e5"), ("pp1", "pp1")},
+ )
+
+ @testing.requires.delete_using
+ @testing.combinations(None, "fetch", "fetch_w_hint", "evaluate")
+ def test_delete_using(self, synchronize_session):
+ """test a DELETE that uses multiple tables.
+
+ due to a limitation in MariaDB, we have an up front "hint" that needs
+ to be passed for this backend if DELETE USING is to be used in
+ conjunction with "fetch" strategy, so that we know before compilation
+ that we won't be able to use RETURNING.
+
+ """
+
+ Engineer = self.classes.Engineer
+ Person = self.classes.Person
+ s = Session(testing.db)
+
+ db_has_mariadb_limitation = (
+ testing.db.dialect.delete_returning
+ and not testing.db.dialect.delete_returning_multifrom
+ )
+
+ e2 = s.query(Engineer).filter_by(name="e2").first()
+
+ with self.sql_execution_asserter() as asserter:
+ assert e2 in s
+
+ q = (
+ s.query(Engineer)
+ .filter(Engineer.id == Person.id)
+ .filter(Person.name == "e2")
+ )
+
+ if synchronize_session == "fetch_w_hint":
+ q.execution_options(is_delete_using=True).delete(
+ synchronize_session="fetch"
+ )
+ elif synchronize_session == "fetch" and db_has_mariadb_limitation:
+ with expect_raises_message(
+ exc.CompileError,
+ 'Dialect ".*" does not support RETURNING with '
+ "DELETE..USING;",
+ ):
+ q.delete(synchronize_session=synchronize_session)
+ return
+ else:
+ q.delete(synchronize_session=synchronize_session)
+
+ if synchronize_session is None:
+ assert e2 in s
+ else:
+ assert e2 not in s
+
+ if synchronize_session in ("fetch", "fetch_w_hint") and (
+ db_has_mariadb_limitation
+ or not testing.db.dialect.delete_returning
+ ):
+ asserter.assert_(
+ CompiledSQL(
+ "SELECT person.id FROM person INNER JOIN engineer ON "
+ "person.id = engineer.id WHERE engineer.id = person.id "
+ "AND person.name = %s",
+ [{"name_1": "e2"}],
+ dialect="mariadb",
+ ),
+ CompiledSQL(
+ "DELETE FROM engineer USING engineer, person WHERE "
+ "engineer.id = person.id AND person.name = %s",
+ [{"name_1": "e2"}],
+ dialect="mariadb",
+ ),
+ )
+ elif synchronize_session in ("fetch", "fetch_w_hint"):
+ asserter.assert_(
+ CompiledSQL(
+ "DELETE FROM engineer USING person WHERE "
+ "engineer.id = person.id AND person.name = %(name_1)s "
+ "RETURNING engineer.id",
+ [{"name_1": "e2"}],
+ dialect="postgresql",
+ ),
+ )
+ else:
+ asserter.assert_(
+ CompiledSQL(
+ "DELETE FROM engineer USING person WHERE "
+ "engineer.id = person.id AND person.name = %(name_1)s",
+ [{"name_1": "e2"}],
+ dialect="postgresql",
+ ),
+ )
+
+ # delete actually worked
+ eq_(
+ set(s.query(Person.name, Engineer.engineer_name)),
+ {("pp1", "pp1"), ("e1", "e1")},
+ )
+
+ @testing.only_on(["mysql", "mariadb"], "Multi table update")
+ @testing.requires.delete_using
+ @testing.combinations(None, "fetch", "evaluate")
+ def test_update_from_multitable(self, synchronize_session):
+ Engineer = self.classes.Engineer
+ Person = self.classes.Person
+ s = Session(testing.db)
+ s.query(Engineer).filter(Engineer.id == Person.id).filter(
+ Person.name == "e2"
+ ).update(
+ {Person.name: "e22", Engineer.engineer_name: "e55"},
+ synchronize_session=synchronize_session,
+ )
+
+ eq_(
+ set(s.query(Person.name, Engineer.engineer_name)),
+ {("e1", "e1"), ("e22", "e55"), ("pp1", "pp1")},
+ )
+
+
+class InheritWPolyTest(fixtures.TestBase, AssertsCompiledSQL):
+ __dialect__ = "default"
+
+ @testing.fixture
+ def inherit_fixture(self, decl_base):
+ def go(poly_type):
+
+ class Person(decl_base):
+ __tablename__ = "person"
+ id = Column(Integer, primary_key=True)
+ type = Column(String(50))
+ name = Column(String(50))
+
+ if poly_type.wpoly:
+ __mapper_args__ = {"with_polymorphic": "*"}
+
+ class Engineer(Person):
+ __tablename__ = "engineer"
+ id = Column(Integer, ForeignKey("person.id"), primary_key=True)
+ engineer_name = Column(String(50))
+
+ if poly_type.inline:
+ __mapper_args__ = {"polymorphic_load": "inline"}
+
+ return Person, Engineer
+
+ return go
+
+ @testing.variation("poly_type", ["wpoly", "inline", "none"])
+ def test_update_base_only(self, poly_type, inherit_fixture):
+ Person, Engineer = inherit_fixture(poly_type)
+
+ self.assert_compile(
+ update(Person).values(name="n1"), "UPDATE person SET name=:name"
+ )
+
+ @testing.variation("poly_type", ["wpoly", "inline", "none"])
+ def test_delete_base_only(self, poly_type, inherit_fixture):
+ Person, Engineer = inherit_fixture(poly_type)
+
+ self.assert_compile(delete(Person), "DELETE FROM person")
+
+ self.assert_compile(
+ delete(Person).where(Person.id == 7),
+ "DELETE FROM person WHERE person.id = :id_1",
+ )
+
+
+class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest):
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ Base = cls.DeclarativeBasic
+
+ class Staff(Base):
+ __tablename__ = "staff"
+ position = Column(String(10), nullable=False)
+ id = Column(
+ Integer, primary_key=True, test_needs_autoincrement=True
+ )
+ name = Column(String(5))
+ stats = Column(String(5))
+ __mapper_args__ = {"polymorphic_on": position}
+
+ class Sales(Staff):
+ sales_stats = Column(String(5))
+ __mapper_args__ = {"polymorphic_identity": "sales"}
+
+ class Support(Staff):
+ support_stats = Column(String(5))
+ __mapper_args__ = {"polymorphic_identity": "support"}
+
+ @classmethod
+ def insert_data(cls, connection):
+ with sessionmaker(connection).begin() as session:
+ Sales, Support = (
+ cls.classes.Sales,
+ cls.classes.Support,
+ )
+ session.add_all(
+ [
+ Sales(name="n1", sales_stats="1", stats="a"),
+ Sales(name="n2", sales_stats="2", stats="b"),
+ Support(name="n1", support_stats="3", stats="c"),
+ Support(name="n2", support_stats="4", stats="d"),
+ ]
+ )
+
+ @testing.combinations(
+ ("fetch", False),
+ ("fetch", True),
+ ("evaluate", False),
+ ("evaluate", True),
+ )
+ def test_update(self, fetchstyle, newstyle):
+ Staff, Sales, Support = self.classes("Staff", "Sales", "Support")
+
+ sess = fixture_session()
+
+ en1, en2 = (
+ sess.execute(select(Sales).order_by(Sales.sales_stats))
+ .scalars()
+ .all()
+ )
+ mn1, mn2 = (
+ sess.execute(select(Support).order_by(Support.support_stats))
+ .scalars()
+ .all()
+ )
+
+ if newstyle:
+ sess.execute(
+ update(Sales)
+ .filter_by(name="n1")
+ .values(stats="p")
+ .execution_options(synchronize_session=fetchstyle)
+ )
+ else:
+ sess.query(Sales).filter_by(name="n1").update(
+ {"stats": "p"}, synchronize_session=fetchstyle
+ )
+
+ eq_(en1.stats, "p")
+ eq_(mn1.stats, "c")
+ eq_(
+ sess.execute(
+ select(Staff.position, Staff.name, Staff.stats).order_by(
+ Staff.id
+ )
+ ).all(),
+ [
+ ("sales", "n1", "p"),
+ ("sales", "n2", "b"),
+ ("support", "n1", "c"),
+ ("support", "n2", "d"),
+ ],
+ )
+
+ @testing.combinations(
+ ("fetch", False),
+ ("fetch", True),
+ ("evaluate", False),
+ ("evaluate", True),
+ )
+ def test_delete(self, fetchstyle, newstyle):
+ Staff, Sales, Support = self.classes("Staff", "Sales", "Support")
+
+ sess = fixture_session()
+ en1, en2 = sess.query(Sales).order_by(Sales.sales_stats).all()
+ mn1, mn2 = sess.query(Support).order_by(Support.support_stats).all()
+
+ if newstyle:
+ sess.execute(
+ delete(Sales)
+ .filter_by(name="n1")
+ .execution_options(synchronize_session=fetchstyle)
+ )
+ else:
+ sess.query(Sales).filter_by(name="n1").delete(
+ synchronize_session=fetchstyle
+ )
+ assert en1 not in sess
+ assert en2 in sess
+ assert mn1 in sess
+ assert mn2 in sess
+
+ eq_(
+ sess.execute(
+ select(Staff.position, Staff.name, Staff.stats).order_by(
+ Staff.id
+ )
+ ).all(),
+ [
+ ("sales", "n2", "b"),
+ ("support", "n1", "c"),
+ ("support", "n2", "d"),
+ ],
+ )