]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Test some complex update cases in the pg dialect
authorFederico Caselli <cfederico87@gmail.com>
Tue, 25 May 2021 20:48:54 +0000 (22:48 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 27 May 2021 18:49:44 +0000 (20:49 +0200)
Change-Id: I2323e155e78aa8e1e00359b103974fb8d27d80eb

test/dialect/postgresql/test_compiler.py

index a517ad1ac057ea9c375010d74556424ec9c970f9..2f91580a93d19446c007c3b4fff84c43db3430b4 100644 (file)
@@ -9,6 +9,7 @@ from sqlalchemy import Date
 from sqlalchemy import delete
 from sqlalchemy import Enum
 from sqlalchemy import exc
+from sqlalchemy import Float
 from sqlalchemy import func
 from sqlalchemy import Identity
 from sqlalchemy import Index
@@ -25,6 +26,7 @@ from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import Text
 from sqlalchemy import text
+from sqlalchemy import tuple_
 from sqlalchemy import types as sqltypes
 from sqlalchemy import update
 from sqlalchemy.dialects import postgresql
@@ -1950,6 +1952,158 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             schema.CreateIndex(idx), "CREATE INDEX foo ON test (x) INCLUDE (y)"
         )
 
+    @testing.fixture
+    def update_tables(self):
+        self.weather = table(
+            "weather",
+            column("temp_lo", Integer),
+            column("temp_hi", Integer),
+            column("prcp", Integer),
+            column("city", String),
+            column("date", Date),
+        )
+        self.accounts = table(
+            "accounts",
+            column("sales_id", Integer),
+            column("sales_person", Integer),
+            column("contact_first_name", String),
+            column("contact_last_name", String),
+            column("name", String),
+        )
+        self.salesmen = table(
+            "salesmen",
+            column("id", Integer),
+            column("first_name", String),
+            column("last_name", String),
+        )
+        self.employees = table(
+            "employees",
+            column("id", Integer),
+            column("sales_count", String),
+        )
+
+    # from examples at https://www.postgresql.org/docs/current/sql-update.html
+    def test_difficult_update_1(self, update_tables):
+        update = (
+            self.weather.update()
+            .where(self.weather.c.city == "San Francisco")
+            .where(self.weather.c.date == "2003-07-03")
+            .values(
+                {
+                    tuple_(
+                        self.weather.c.temp_lo,
+                        self.weather.c.temp_hi,
+                        self.weather.c.prcp,
+                    ): tuple_(
+                        self.weather.c.temp_lo + 1,
+                        self.weather.c.temp_lo + 15,
+                        literal_column("DEFAULT"),
+                    )
+                }
+            )
+        )
+
+        self.assert_compile(
+            update,
+            "UPDATE weather SET (temp_lo, temp_hi, prcp)=(weather.temp_lo + "
+            "%(temp_lo_1)s, weather.temp_lo + %(temp_lo_2)s, DEFAULT) "
+            "WHERE weather.city = %(city_1)s AND weather.date = %(date_1)s",
+            {
+                "city_1": "San Francisco",
+                "date_1": "2003-07-03",
+                "temp_lo_1": 1,
+                "temp_lo_2": 15,
+            },
+        )
+
+    def test_difficult_update_2(self, update_tables):
+        update = self.accounts.update().values(
+            {
+                tuple_(
+                    self.accounts.c.contact_first_name,
+                    self.accounts.c.contact_last_name,
+                ): select(
+                    self.salesmen.c.first_name, self.salesmen.c.last_name
+                )
+                .where(self.salesmen.c.id == self.accounts.c.sales_id)
+                .scalar_subquery()
+            }
+        )
+
+        self.assert_compile(
+            update,
+            "UPDATE accounts SET (contact_first_name, contact_last_name)="
+            "(SELECT salesmen.first_name, salesmen.last_name FROM "
+            "salesmen WHERE salesmen.id = accounts.sales_id)",
+        )
+
+    def test_difficult_update_3(self, update_tables):
+        update = (
+            self.employees.update()
+            .values(
+                {
+                    self.employees.c.sales_count: self.employees.c.sales_count
+                    + 1
+                }
+            )
+            .where(
+                self.employees.c.id
+                == select(self.accounts.c.sales_person)
+                .where(self.accounts.c.name == "Acme Corporation")
+                .scalar_subquery()
+            )
+        )
+
+        self.assert_compile(
+            update,
+            "UPDATE employees SET sales_count=(employees.sales_count "
+            "+ %(sales_count_1)s) WHERE employees.id = (SELECT "
+            "accounts.sales_person FROM accounts WHERE "
+            "accounts.name = %(name_1)s)",
+            {"sales_count_1": 1, "name_1": "Acme Corporation"},
+        )
+
+    def test_difficult_update_4(self):
+        summary = table(
+            "summary",
+            column("group_id", Integer),
+            column("sum_y", Float),
+            column("sum_x", Float),
+            column("avg_x", Float),
+            column("avg_y", Float),
+        )
+        data = table(
+            "data",
+            column("group_id", Integer),
+            column("x", Float),
+            column("y", Float),
+        )
+
+        update = summary.update().values(
+            {
+                tuple_(
+                    summary.c.sum_x,
+                    summary.c.sum_y,
+                    summary.c.avg_x,
+                    summary.c.avg_y,
+                ): select(
+                    func.sum(data.c.x),
+                    func.sum(data.c.y),
+                    func.avg(data.c.x),
+                    func.avg(data.c.y),
+                )
+                .where(data.c.group_id == summary.c.group_id)
+                .scalar_subquery()
+            }
+        )
+        self.assert_compile(
+            update,
+            "UPDATE summary SET (sum_x, sum_y, avg_x, avg_y)="
+            "(SELECT sum(data.x) AS sum_1, sum(data.y) AS sum_2, "
+            "avg(data.x) AS avg_1, avg(data.y) AS avg_2 FROM data "
+            "WHERE data.group_id = summary.group_id)",
+        )
+
 
 class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = postgresql.dialect()