]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Check for column expr in Oracle RETURNING check
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Jan 2021 22:05:46 +0000 (17:05 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Jan 2021 22:21:04 +0000 (17:21 -0500)
Fixed regression in Oracle dialect introduced by :ticket:`4894` in
SQLAlchemy 1.3.11 where use of a SQL expression in RETURNING for an UPDATE
would fail to compile, due to a check for "server_default" when an
arbitrary SQL expression is not a column.

Fixes: #5813
Change-Id: I1977bb49bc971399195015ae45e761f774f4008d
(cherry picked from commit ea467fccbe4337929b91e0daec12b8672fa7907c)

doc/build/changelog/unreleased_13/5813.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/base.py
test/sql/test_returning.py

diff --git a/doc/build/changelog/unreleased_13/5813.rst b/doc/build/changelog/unreleased_13/5813.rst
new file mode 100644 (file)
index 0000000..d6483a2
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, oracle
+    :tickets: 5813
+
+    Fixed regression in Oracle dialect introduced by :ticket:`4894` in
+    SQLAlchemy 1.3.11 where use of a SQL expression in RETURNING for an UPDATE
+    would fail to compile, due to a check for "server_default" when an
+    arbitrary SQL expression is not a column.
+
index a072376a38ae6a45115565125200066dedb1d60c..c476554bd175266446008c33c9190abe6522c66e 100644 (file)
@@ -976,6 +976,7 @@ class OracleCompiler(compiler.SQLCompiler):
         ):
             if (
                 self.isupdate
+                and isinstance(column, sa_schema.Column)
                 and isinstance(column.server_default, Computed)
                 and not self.dialect._supports_update_returning_computed_cols
             ):
index 8a3cc64933561adb7a4e455fcd916f2ba87452d7..2dbea604f2ba72621532e5ffd7898b577b2fa376 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy import select
 from sqlalchemy import Sequence
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import type_coerce
 from sqlalchemy import update
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -96,14 +97,14 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
         )
 
 
-class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
+class ReturningTest(fixtures.TablesTest, AssertsExecutionResults):
     __requires__ = ("returning",)
     __backend__ = True
 
-    def setup(self):
-        meta = MetaData(testing.db)
-        global table, GoofyType
+    run_create_tables = "each"
 
+    @classmethod
+    def define_tables(cls, metadata):
         class GoofyType(TypeDecorator):
             impl = String
 
@@ -117,26 +118,25 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
                     return None
                 return value + "BAR"
 
-        table = Table(
+        cls.GoofyType = GoofyType
+
+        Table(
             "tables",
-            meta,
+            metadata,
             Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
             ),
             Column("persons", Integer),
             Column("full", Boolean),
             Column("goofy", GoofyType(50)),
+            Column("strval", String(50)),
         )
-        table.create(checkfirst=True)
 
-    def teardown(self):
-        table.drop()
-
-    def test_column_targeting(self):
-        result = (
-            table.insert()
-            .returning(table.c.id, table.c.full)
-            .execute({"persons": 1, "full": False})
+    def test_column_targeting(self, connection):
+        table = self.tables.tables
+        result = connection.execute(
+            table.insert().returning(table.c.id, table.c.full),
+            {"persons": 1, "full": False},
         )
 
         row = result.first()
@@ -144,11 +144,10 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
         assert row[table.c.full] == row["full"]
         assert row["full"] is False
 
-        result = (
+        result = connection.execute(
             table.insert()
             .values(persons=5, full=True, goofy="somegoofy")
             .returning(table.c.persons, table.c.full, table.c.goofy)
-            .execute()
         )
         row = result.first()
         assert row[table.c.persons] == row["persons"] == 5
@@ -158,12 +157,12 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(row["goofy"], "FOOsomegoofyBAR")
 
     @testing.fails_on("firebird", "fb can't handle returning x AS y")
-    def test_labeling(self):
-        result = (
+    def test_labeling(self, connection):
+        table = self.tables.tables
+        result = connection.execute(
             table.insert()
             .values(persons=6)
             .returning(table.c.persons.label("lala"))
-            .execute()
         )
         row = result.first()
         assert row["lala"] == 6
@@ -171,54 +170,135 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
     @testing.fails_on(
         "firebird", "fb/kintersbasdb can't handle the bind params"
     )
-    @testing.fails_on("oracle+zxjdbc", "JDBC driver bug")
-    def test_anon_expressions(self):
-        result = (
+    def test_anon_expressions(self, connection):
+        table = self.tables.tables
+        GoofyType = self.GoofyType
+        result = connection.execute(
             table.insert()
             .values(goofy="someOTHERgoofy")
             .returning(func.lower(table.c.goofy, type_=GoofyType))
-            .execute()
         )
         row = result.first()
         eq_(row[0], "foosomeothergoofyBAR")
 
-        result = (
-            table.insert()
-            .values(persons=12)
-            .returning(table.c.persons + 18)
-            .execute()
+        result = connection.execute(
+            table.insert().values(persons=12).returning(table.c.persons + 18)
         )
         row = result.first()
         eq_(row[0], 30)
 
-    def test_update_returning(self):
-        table.insert().execute(
-            [{"persons": 5, "full": False}, {"persons": 3, "full": False}]
+    def test_update_returning(self, connection):
+        table = self.tables.tables
+        connection.execute(
+            table.insert(),
+            [{"persons": 5, "full": False}, {"persons": 3, "full": False}],
         )
 
-        result = (
-            table.update(table.c.persons > 4, dict(full=True))
-            .returning(table.c.id)
-            .execute()
+        result = connection.execute(
+            table.update(table.c.persons > 4, dict(full=True)).returning(
+                table.c.id
+            )
         )
         eq_(result.fetchall(), [(1,)])
 
-        result2 = (
-            select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+        result2 = connection.execute(
+            select([table.c.id, table.c.full]).order_by(table.c.id)
         )
         eq_(result2.fetchall(), [(1, True), (2, False)])
 
-    def test_insert_returning(self):
-        result = (
-            table.insert()
-            .returning(table.c.id)
-            .execute({"persons": 1, "full": False})
+    def test_insert_returning(self, connection):
+        table = self.tables.tables
+        result = connection.execute(
+            table.insert().returning(table.c.id), {"persons": 1, "full": False}
         )
 
         eq_(result.fetchall(), [(1,)])
 
+    @testing.fails_on(
+        "mssql",
+        "driver has unknown issue with string concatenation "
+        "in INSERT RETURNING",
+    )
+    def test_insert_returning_w_expression_one(self, connection):
+        table = self.tables.tables
+        result = connection.execute(
+            table.insert().returning(table.c.strval + "hi"),
+            {"persons": 5, "full": False, "strval": "str1"},
+        )
+
+        eq_(result.fetchall(), [("str1hi",)])
+
+        result2 = connection.execute(
+            select([table.c.id, table.c.strval]).order_by(table.c.id)
+        )
+        eq_(result2.fetchall(), [(1, "str1")])
+
+    def test_insert_returning_w_type_coerce_expression(self, connection):
+        table = self.tables.tables
+        result = connection.execute(
+            table.insert().returning(type_coerce(table.c.goofy, String)),
+            {"persons": 5, "goofy": "somegoofy"},
+        )
+
+        eq_(result.fetchall(), [("FOOsomegoofy",)])
+
+        result2 = connection.execute(
+            select([table.c.id, table.c.goofy]).order_by(table.c.id)
+        )
+        eq_(result2.fetchall(), [(1, "FOOsomegoofyBAR")])
+
+    def test_update_returning_w_expression_one(self, connection):
+        table = self.tables.tables
+        connection.execute(
+            table.insert(),
+            [
+                {"persons": 5, "full": False, "strval": "str1"},
+                {"persons": 3, "full": False, "strval": "str2"},
+            ],
+        )
+
+        result = connection.execute(
+            table.update()
+            .where(table.c.persons > 4)
+            .values(full=True)
+            .returning(table.c.strval + "hi")
+        )
+        eq_(result.fetchall(), [("str1hi",)])
+
+        result2 = connection.execute(
+            select([table.c.id, table.c.strval]).order_by(table.c.id)
+        )
+        eq_(result2.fetchall(), [(1, "str1"), (2, "str2")])
+
+    def test_update_returning_w_type_coerce_expression(self, connection):
+        table = self.tables.tables
+        connection.execute(
+            table.insert(),
+            [
+                {"persons": 5, "goofy": "somegoofy1"},
+                {"persons": 3, "goofy": "somegoofy2"},
+            ],
+        )
+
+        result = connection.execute(
+            table.update()
+            .where(table.c.persons > 4)
+            .values(goofy="newgoofy")
+            .returning(type_coerce(table.c.goofy, String))
+        )
+        eq_(result.fetchall(), [("FOOnewgoofy",)])
+
+        result2 = connection.execute(
+            select([table.c.id, table.c.goofy]).order_by(table.c.id)
+        )
+        eq_(
+            result2.fetchall(),
+            [(1, "FOOnewgoofyBAR"), (2, "FOOsomegoofy2BAR")],
+        )
+
     @testing.requires.multivalues_inserts
-    def test_multirow_returning(self):
+    def test_multirow_returning(self, connection):
+        table = self.tables.tables
         ins = (
             table.insert()
             .returning(table.c.id, table.c.persons)
@@ -230,11 +310,12 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
                 ]
             )
         )
-        result = testing.db.execute(ins)
+        result = connection.execute(ins)
         eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)])
 
-    def test_no_ipk_on_returning(self):
-        result = testing.db.execute(
+    def test_no_ipk_on_returning(self, connection):
+        table = self.tables.tables
+        result = connection.execute(
             table.insert().returning(table.c.id), {"persons": 1, "full": False}
         )
         assert_raises_message(
@@ -246,30 +327,32 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
         )
 
     @testing.fails_on_everything_except("postgresql", "firebird")
-    def test_literal_returning(self):
+    def test_literal_returning(self, connection):
         if testing.against("postgresql"):
             literal_true = "true"
         else:
             literal_true = "1"
 
-        result4 = testing.db.execute(
+        result4 = connection.execute(
             'insert into tables (id, persons, "full") '
             "values (5, 10, %s) returning persons" % literal_true
         )
         eq_([dict(row) for row in result4], [{"persons": 10}])
 
-    def test_delete_returning(self):
-        table.insert().execute(
-            [{"persons": 5, "full": False}, {"persons": 3, "full": False}]
+    def test_delete_returning(self, connection):
+        table = self.tables.tables
+        connection.execute(
+            table.insert(),
+            [{"persons": 5, "full": False}, {"persons": 3, "full": False}],
         )
 
-        result = (
-            table.delete(table.c.persons > 4).returning(table.c.id).execute()
+        result = connection.execute(
+            table.delete(table.c.persons > 4).returning(table.c.id)
         )
         eq_(result.fetchall(), [(1,)])
 
-        result2 = (
-            select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+        result2 = connection.execute(
+            select([table.c.id, table.c.full]).order_by(table.c.id)
         )
         eq_(result2.fetchall(), [(2, False)])