From: Federico Caselli Date: Tue, 25 May 2021 20:48:54 +0000 (+0200) Subject: Test some complex update cases in the pg dialect X-Git-Tag: rel_1_4_16~2^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=85d66c5adf32abbb8fe5035d810c11e3f23e752d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Test some complex update cases in the pg dialect Change-Id: I2323e155e78aa8e1e00359b103974fb8d27d80eb --- diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index a517ad1ac0..2f91580a93 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -9,6 +9,7 @@ from sqlalchemy import Date from sqlalchemy import delete from sqlalchemy import Enum from sqlalchemy import exc +from sqlalchemy import Float from sqlalchemy import func from sqlalchemy import Identity from sqlalchemy import Index @@ -25,6 +26,7 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import Text from sqlalchemy import text +from sqlalchemy import tuple_ from sqlalchemy import types as sqltypes from sqlalchemy import update from sqlalchemy.dialects import postgresql @@ -1950,6 +1952,158 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): schema.CreateIndex(idx), "CREATE INDEX foo ON test (x) INCLUDE (y)" ) + @testing.fixture + def update_tables(self): + self.weather = table( + "weather", + column("temp_lo", Integer), + column("temp_hi", Integer), + column("prcp", Integer), + column("city", String), + column("date", Date), + ) + self.accounts = table( + "accounts", + column("sales_id", Integer), + column("sales_person", Integer), + column("contact_first_name", String), + column("contact_last_name", String), + column("name", String), + ) + self.salesmen = table( + "salesmen", + column("id", Integer), + column("first_name", String), + column("last_name", String), + ) + self.employees = table( + "employees", + column("id", Integer), + column("sales_count", String), + ) + + # from examples at https://www.postgresql.org/docs/current/sql-update.html + def test_difficult_update_1(self, update_tables): + update = ( + self.weather.update() + .where(self.weather.c.city == "San Francisco") + .where(self.weather.c.date == "2003-07-03") + .values( + { + tuple_( + self.weather.c.temp_lo, + self.weather.c.temp_hi, + self.weather.c.prcp, + ): tuple_( + self.weather.c.temp_lo + 1, + self.weather.c.temp_lo + 15, + literal_column("DEFAULT"), + ) + } + ) + ) + + self.assert_compile( + update, + "UPDATE weather SET (temp_lo, temp_hi, prcp)=(weather.temp_lo + " + "%(temp_lo_1)s, weather.temp_lo + %(temp_lo_2)s, DEFAULT) " + "WHERE weather.city = %(city_1)s AND weather.date = %(date_1)s", + { + "city_1": "San Francisco", + "date_1": "2003-07-03", + "temp_lo_1": 1, + "temp_lo_2": 15, + }, + ) + + def test_difficult_update_2(self, update_tables): + update = self.accounts.update().values( + { + tuple_( + self.accounts.c.contact_first_name, + self.accounts.c.contact_last_name, + ): select( + self.salesmen.c.first_name, self.salesmen.c.last_name + ) + .where(self.salesmen.c.id == self.accounts.c.sales_id) + .scalar_subquery() + } + ) + + self.assert_compile( + update, + "UPDATE accounts SET (contact_first_name, contact_last_name)=" + "(SELECT salesmen.first_name, salesmen.last_name FROM " + "salesmen WHERE salesmen.id = accounts.sales_id)", + ) + + def test_difficult_update_3(self, update_tables): + update = ( + self.employees.update() + .values( + { + self.employees.c.sales_count: self.employees.c.sales_count + + 1 + } + ) + .where( + self.employees.c.id + == select(self.accounts.c.sales_person) + .where(self.accounts.c.name == "Acme Corporation") + .scalar_subquery() + ) + ) + + self.assert_compile( + update, + "UPDATE employees SET sales_count=(employees.sales_count " + "+ %(sales_count_1)s) WHERE employees.id = (SELECT " + "accounts.sales_person FROM accounts WHERE " + "accounts.name = %(name_1)s)", + {"sales_count_1": 1, "name_1": "Acme Corporation"}, + ) + + def test_difficult_update_4(self): + summary = table( + "summary", + column("group_id", Integer), + column("sum_y", Float), + column("sum_x", Float), + column("avg_x", Float), + column("avg_y", Float), + ) + data = table( + "data", + column("group_id", Integer), + column("x", Float), + column("y", Float), + ) + + update = summary.update().values( + { + tuple_( + summary.c.sum_x, + summary.c.sum_y, + summary.c.avg_x, + summary.c.avg_y, + ): select( + func.sum(data.c.x), + func.sum(data.c.y), + func.avg(data.c.x), + func.avg(data.c.y), + ) + .where(data.c.group_id == summary.c.group_id) + .scalar_subquery() + } + ) + self.assert_compile( + update, + "UPDATE summary SET (sum_x, sum_y, avg_x, avg_y)=" + "(SELECT sum(data.x) AS sum_1, sum(data.y) AS sum_2, " + "avg(data.x) AS avg_1, avg(data.y) AS avg_2 FROM data " + "WHERE data.group_id = summary.group_id)", + ) + class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = postgresql.dialect()