from sqlalchemy import delete
from sqlalchemy import Enum
from sqlalchemy import exc
+from sqlalchemy import Float
from sqlalchemy import func
from sqlalchemy import Identity
from sqlalchemy import Index
from sqlalchemy import testing
from sqlalchemy import Text
from sqlalchemy import text
+from sqlalchemy import tuple_
from sqlalchemy import types as sqltypes
from sqlalchemy import update
from sqlalchemy.dialects import postgresql
schema.CreateIndex(idx), "CREATE INDEX foo ON test (x) INCLUDE (y)"
)
+ @testing.fixture
+ def update_tables(self):
+ self.weather = table(
+ "weather",
+ column("temp_lo", Integer),
+ column("temp_hi", Integer),
+ column("prcp", Integer),
+ column("city", String),
+ column("date", Date),
+ )
+ self.accounts = table(
+ "accounts",
+ column("sales_id", Integer),
+ column("sales_person", Integer),
+ column("contact_first_name", String),
+ column("contact_last_name", String),
+ column("name", String),
+ )
+ self.salesmen = table(
+ "salesmen",
+ column("id", Integer),
+ column("first_name", String),
+ column("last_name", String),
+ )
+ self.employees = table(
+ "employees",
+ column("id", Integer),
+ column("sales_count", String),
+ )
+
+ # from examples at https://www.postgresql.org/docs/current/sql-update.html
+ def test_difficult_update_1(self, update_tables):
+ update = (
+ self.weather.update()
+ .where(self.weather.c.city == "San Francisco")
+ .where(self.weather.c.date == "2003-07-03")
+ .values(
+ {
+ tuple_(
+ self.weather.c.temp_lo,
+ self.weather.c.temp_hi,
+ self.weather.c.prcp,
+ ): tuple_(
+ self.weather.c.temp_lo + 1,
+ self.weather.c.temp_lo + 15,
+ literal_column("DEFAULT"),
+ )
+ }
+ )
+ )
+
+ self.assert_compile(
+ update,
+ "UPDATE weather SET (temp_lo, temp_hi, prcp)=(weather.temp_lo + "
+ "%(temp_lo_1)s, weather.temp_lo + %(temp_lo_2)s, DEFAULT) "
+ "WHERE weather.city = %(city_1)s AND weather.date = %(date_1)s",
+ {
+ "city_1": "San Francisco",
+ "date_1": "2003-07-03",
+ "temp_lo_1": 1,
+ "temp_lo_2": 15,
+ },
+ )
+
+ def test_difficult_update_2(self, update_tables):
+ update = self.accounts.update().values(
+ {
+ tuple_(
+ self.accounts.c.contact_first_name,
+ self.accounts.c.contact_last_name,
+ ): select(
+ self.salesmen.c.first_name, self.salesmen.c.last_name
+ )
+ .where(self.salesmen.c.id == self.accounts.c.sales_id)
+ .scalar_subquery()
+ }
+ )
+
+ self.assert_compile(
+ update,
+ "UPDATE accounts SET (contact_first_name, contact_last_name)="
+ "(SELECT salesmen.first_name, salesmen.last_name FROM "
+ "salesmen WHERE salesmen.id = accounts.sales_id)",
+ )
+
+ def test_difficult_update_3(self, update_tables):
+ update = (
+ self.employees.update()
+ .values(
+ {
+ self.employees.c.sales_count: self.employees.c.sales_count
+ + 1
+ }
+ )
+ .where(
+ self.employees.c.id
+ == select(self.accounts.c.sales_person)
+ .where(self.accounts.c.name == "Acme Corporation")
+ .scalar_subquery()
+ )
+ )
+
+ self.assert_compile(
+ update,
+ "UPDATE employees SET sales_count=(employees.sales_count "
+ "+ %(sales_count_1)s) WHERE employees.id = (SELECT "
+ "accounts.sales_person FROM accounts WHERE "
+ "accounts.name = %(name_1)s)",
+ {"sales_count_1": 1, "name_1": "Acme Corporation"},
+ )
+
+ def test_difficult_update_4(self):
+ summary = table(
+ "summary",
+ column("group_id", Integer),
+ column("sum_y", Float),
+ column("sum_x", Float),
+ column("avg_x", Float),
+ column("avg_y", Float),
+ )
+ data = table(
+ "data",
+ column("group_id", Integer),
+ column("x", Float),
+ column("y", Float),
+ )
+
+ update = summary.update().values(
+ {
+ tuple_(
+ summary.c.sum_x,
+ summary.c.sum_y,
+ summary.c.avg_x,
+ summary.c.avg_y,
+ ): select(
+ func.sum(data.c.x),
+ func.sum(data.c.y),
+ func.avg(data.c.x),
+ func.avg(data.c.y),
+ )
+ .where(data.c.group_id == summary.c.group_id)
+ .scalar_subquery()
+ }
+ )
+ self.assert_compile(
+ update,
+ "UPDATE summary SET (sum_x, sum_y, avg_x, avg_y)="
+ "(SELECT sum(data.x) AS sum_1, sum(data.y) AS sum_2, "
+ "avg(data.x) AS avg_1, avg(data.y) AS avg_2 FROM data "
+ "WHERE data.group_id = summary.group_id)",
+ )
+
class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()