]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
render col name in on conflict set clause, not given key
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 13 May 2022 20:08:34 +0000 (16:08 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 May 2022 21:29:26 +0000 (17:29 -0400)
Fixed bug where the PostgreSQL :meth:`_postgresql.Insert.on_conflict`
method and the SQLite :meth:`_sqlite.Insert.on_conflict` method would both
fail to correctly accommodate a column with a separate ".key" when
specifying the column using its key name in the dictionary passed to
``set_``, as well as if the :attr:`_sqlite.Insert.excluded` or
:attr:`_postgresql.Insert.excluded` collection were used as the dictionary
directly.

Fixes: #8014
Change-Id: I67226aeedcb2c683e22405af64720cc1f990f274
(cherry picked from commit 927abc3b33f10464ed04db3d7a454faeb6e729f2)

doc/build/changelog/unreleased_14/8014.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
test/dialect/postgresql/test_compiler.py
test/dialect/test_sqlite.py

diff --git a/doc/build/changelog/unreleased_14/8014.rst b/doc/build/changelog/unreleased_14/8014.rst
new file mode 100644 (file)
index 0000000..331a957
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, sql, postgresql, sqlite
+    :tickets: 8014
+
+    Fixed bug where the PostgreSQL :meth:`_postgresql.Insert.on_conflict`
+    method and the SQLite :meth:`_sqlite.Insert.on_conflict` method would both
+    fail to correctly accommodate a column with a separate ".key" when
+    specifying the column using its key name in the dictionary passed to
+    ``set_``, as well as if the :attr:`_sqlite.Insert.excluded` or
+    :attr:`_postgresql.Insert.excluded` collection were used as the dictionary
+    directly.
index 7ba996a4a2dc153c7475b1458826f2e9d8fb443c..ad2bdf187753f65b4399e8b4b76592889cfde787 100644 (file)
@@ -2530,7 +2530,7 @@ class PGCompiler(compiler.SQLCompiler):
                     value.type = c.type
             value_text = self.process(value.self_group(), use_schema=False)
 
-            key_text = self.preparer.quote(col_key)
+            key_text = self.preparer.quote(c.name)
             action_set_ops.append("%s = %s" % (key_text, value_text))
 
         # check for names that don't match columns
index 49e4b5c19552b06a45583d3d7eec9788f6e165ec..0959d0417cf68a6a78207feac26485d832e92ba5 100644 (file)
@@ -1385,7 +1385,7 @@ class SQLiteCompiler(compiler.SQLCompiler):
                     value.type = c.type
             value_text = self.process(value.self_group(), use_schema=False)
 
-            key_text = self.preparer.quote(col_key)
+            key_text = self.preparer.quote(c.name)
             action_set_ops.append("%s = %s" % (key_text, value_text))
 
         # check for names that don't match columns
index 6bd2f2fa2be05901f83c0b96ea59cad02c754b58..d85ae9152fd4ee2be7fb3ae045fa2451f0892409 100644 (file)
@@ -2282,41 +2282,103 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
 
-class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
+class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL):
     __dialect__ = postgresql.dialect()
 
-    def setup_test(self):
-        self.table1 = table1 = table(
+    run_create_tables = None
+
+    @classmethod
+    def define_tables(cls, metadata):
+        cls.table1 = table1 = table(
             "mytable",
             column("myid", Integer),
             column("name", String(128)),
             column("description", String(128)),
         )
-        md = MetaData()
-        self.table_with_metadata = Table(
+        cls.table_with_metadata = Table(
             "mytable",
-            md,
+            metadata,
             Column("myid", Integer, primary_key=True),
             Column("name", String(128)),
             Column("description", String(128)),
         )
-        self.unique_constr = schema.UniqueConstraint(
+        cls.unique_constr = schema.UniqueConstraint(
             table1.c.name, name="uq_name"
         )
-        self.excl_constr = ExcludeConstraint(
+        cls.excl_constr = ExcludeConstraint(
             (table1.c.name, "="),
             (table1.c.description, "&&"),
             name="excl_thing",
         )
-        self.excl_constr_anon = ExcludeConstraint(
-            (self.table_with_metadata.c.name, "="),
-            (self.table_with_metadata.c.description, "&&"),
-            where=self.table_with_metadata.c.description != "foo",
+        cls.excl_constr_anon = ExcludeConstraint(
+            (cls.table_with_metadata.c.name, "="),
+            (cls.table_with_metadata.c.description, "&&"),
+            where=cls.table_with_metadata.c.description != "foo",
         )
-        self.goofy_index = Index(
+        cls.goofy_index = Index(
             "goofy_index", table1.c.name, postgresql_where=table1.c.name > "m"
         )
 
+        Table(
+            "users",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50)),
+        )
+
+        Table(
+            "users_w_key",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), key="name_keyed"),
+        )
+
+    @testing.combinations("control", "excluded", "dict")
+    def test_set_excluded(self, scenario):
+        """test #8014, sending all of .excluded to set"""
+
+        if scenario == "control":
+            users = self.tables.users
+
+            stmt = insert(users)
+            self.assert_compile(
+                stmt.on_conflict_do_update(
+                    constraint=users.primary_key, set_=stmt.excluded
+                ),
+                "INSERT INTO users (id, name) VALUES (%(id)s, %(name)s) ON "
+                "CONFLICT (id) DO UPDATE "
+                "SET id = excluded.id, name = excluded.name",
+            )
+        else:
+            users_w_key = self.tables.users_w_key
+
+            stmt = insert(users_w_key)
+
+            if scenario == "excluded":
+                self.assert_compile(
+                    stmt.on_conflict_do_update(
+                        constraint=users_w_key.primary_key, set_=stmt.excluded
+                    ),
+                    "INSERT INTO users_w_key (id, name) "
+                    "VALUES (%(id)s, %(name_keyed)s) ON "
+                    "CONFLICT (id) DO UPDATE "
+                    "SET id = excluded.id, name = excluded.name",
+                )
+            else:
+                self.assert_compile(
+                    stmt.on_conflict_do_update(
+                        constraint=users_w_key.primary_key,
+                        set_={
+                            "id": stmt.excluded.id,
+                            "name_keyed": stmt.excluded.name_keyed,
+                        },
+                    ),
+                    "INSERT INTO users_w_key (id, name) "
+                    "VALUES (%(id)s, %(name_keyed)s) ON "
+                    "CONFLICT (id) DO UPDATE "
+                    "SET id = excluded.id, name = excluded.name",
+                )
+
     def test_on_conflict_do_no_call_twice(self):
         users = self.table1
 
index 6230c7f94595f754ad0cb2f5a114048e81fa4d20..ff98fea149b8669ad9bd64c985d62bc9d2d25c87 100644 (file)
@@ -2754,7 +2754,7 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
 
-class OnConflictTest(fixtures.TablesTest):
+class OnConflictTest(AssertsCompiledSQL, fixtures.TablesTest):
 
     __only_on__ = ("sqlite >= 3.24.0",)
     __backend__ = True
@@ -2768,6 +2768,13 @@ class OnConflictTest(fixtures.TablesTest):
             Column("name", String(50)),
         )
 
+        Table(
+            "users_w_key",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String(50), key="name_keyed"),
+        )
+
         class SpecialType(sqltypes.TypeDecorator):
             impl = String
             cache_ok = True
@@ -2812,6 +2819,44 @@ class OnConflictTest(fixtures.TablesTest):
             ValueError, insert(self.tables.users).on_conflict_do_update
         )
 
+    @testing.combinations("control", "excluded", "dict")
+    def test_set_excluded(self, scenario):
+        """test #8014, sending all of .excluded to set"""
+
+        if scenario == "control":
+            users = self.tables.users
+
+            stmt = insert(users)
+            self.assert_compile(
+                stmt.on_conflict_do_update(set_=stmt.excluded),
+                "INSERT INTO users (id, name) VALUES (?, ?) ON CONFLICT  "
+                "DO UPDATE SET id = excluded.id, name = excluded.name",
+            )
+        else:
+            users_w_key = self.tables.users_w_key
+
+            stmt = insert(users_w_key)
+
+            if scenario == "excluded":
+                self.assert_compile(
+                    stmt.on_conflict_do_update(set_=stmt.excluded),
+                    "INSERT INTO users_w_key (id, name) VALUES (?, ?) "
+                    "ON CONFLICT  "
+                    "DO UPDATE SET id = excluded.id, name = excluded.name",
+                )
+            else:
+                self.assert_compile(
+                    stmt.on_conflict_do_update(
+                        set_={
+                            "id": stmt.excluded.id,
+                            "name_keyed": stmt.excluded.name_keyed,
+                        }
+                    ),
+                    "INSERT INTO users_w_key (id, name) VALUES (?, ?) "
+                    "ON CONFLICT  "
+                    "DO UPDATE SET id = excluded.id, name = excluded.name",
+                )
+
     def test_on_conflict_do_no_call_twice(self):
         users = self.tables.users