]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure RETURNING renders in stringify w/ no server version
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Aug 2022 18:08:32 +0000 (14:08 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Aug 2022 00:47:27 +0000 (20:47 -0400)
just in my own testing, if I say insert().return_defaults()
and stringify, I should see it, so make sure all the dialects
default to "insert_returning" etc. , with downgrade on
server version check.

Change-Id: Id64e78fcb03c48b5dcb0feb21cb9cc495edd15e9

lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/default.py
test/dialect/mssql/test_compiler.py
test/dialect/mysql/test_compiler.py
test/dialect/mysql/test_reflection.py
test/dialect/oracle/test_compiler.py
test/dialect/postgresql/test_compiler.py
test/dialect/test_sqlite.py
test/sql/test_compiler.py
test/sql/test_insert.py

index 596ca34f2c332b01e196562c11bca7c826f25bf8..9e7ba4646b63902f24833765385aa372daf7892d 100644 (file)
@@ -2568,6 +2568,12 @@ class MySQLDialect(default.DefaultDialect):
             # this would have been set by the default dialect already,
             # so set it again
             self.identifier_preparer = self.preparer(self)
+
+            # this will be updated on first connect in initialize()
+            # if using older mariadb version
+            self.delete_returning = True
+            self.insert_returning = True
+
         self.is_mariadb = is_mariadb
 
     def do_begin_twophase(self, connection, xid):
index ce688741f414a24c22128f64d2eff94f6d077062..222f3a13793d8f9ac18c4c635fffd0867469ffa2 100644 (file)
@@ -1930,6 +1930,9 @@ class SQLiteDialect(default.DefaultDialect):
     tuple_in_values = True
     supports_statement_cache = True
     insert_null_pk_still_autoincrements = True
+    insert_returning = True
+    update_returning = True
+    delete_returning = True
 
     default_paramstyle = "qmark"
     execution_ctx_cls = SQLiteExecutionContext
@@ -2037,10 +2040,10 @@ class SQLiteDialect(default.DefaultDialect):
                 14,
             )
 
-            if self.dbapi.sqlite_version_info >= (3, 35):
+            if self.dbapi.sqlite_version_info < (3, 35):
                 self.update_returning = (
                     self.delete_returning
-                ) = self.insert_returning = True
+                ) = self.insert_returning = False
 
     _isolation_lookup = util.immutabledict(
         {"READ UNCOMMITTED": 1, "SERIALIZABLE": 0}
index cab96eac115b564c6176b0cc9355fb175c71f0fb..4b312dcebc483a27418a89410213a1d8dca1627f 100644 (file)
@@ -903,6 +903,10 @@ class StrCompileDialect(DefaultDialect):
     type_compiler_cls = compiler.StrSQLTypeCompiler
     preparer = compiler.IdentifierPreparer
 
+    insert_returning = True
+    update_returning = True
+    delete_returning = True
+
     supports_statement_cache = True
 
     supports_identity_columns = True
index 74722e9496eefb4d24b468a42cd7fe8fe188b556..8605ea9c0526c4e3cb7b3ddae7187854642312d0 100644 (file)
@@ -53,6 +53,22 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(sql.false(), "0")
         self.assert_compile(sql.true(), "1")
 
+    def test_plain_stringify_returning(self):
+        t = Table(
+            "t",
+            MetaData(),
+            Column("myid", Integer, primary_key=True),
+            Column("name", String, server_default="some str"),
+            Column("description", String, default=func.lower("hi")),
+        )
+        stmt = t.insert().values().return_defaults()
+        eq_ignore_whitespace(
+            str(stmt.compile(dialect=mssql.dialect())),
+            "INSERT INTO t (description) "
+            "OUTPUT inserted.myid, inserted.name, inserted.description "
+            "VALUES (lower(:lower_1))",
+        )
+
     @testing.combinations(
         ("plain", "sometable", "sometable"),
         ("matched_square_brackets", "colo[u]r", "[colo[u]]r]"),
index 3fb52416ec9a051b9a8d3fdc73b5266418792434..9d2c43bfead32a7f367b42ac221365460f8e628a 100644 (file)
@@ -61,6 +61,7 @@ from sqlalchemy.sql.expression import literal_column
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import eq_ignore_whitespace
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
@@ -146,6 +147,25 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
             dialect=dialect,
         )
 
+    def test_plain_stringify_returning(self):
+        t = Table(
+            "t",
+            MetaData(),
+            Column("myid", Integer, primary_key=True),
+            Column("name", String, server_default="some str"),
+            Column("description", String, default=func.lower("hi")),
+        )
+        stmt = t.insert().values().return_defaults()
+        eq_ignore_whitespace(
+            str(stmt.compile(dialect=mysql.dialect(is_mariadb=True))),
+            "INSERT INTO t (description) VALUES (lower(%s)) "
+            "RETURNING t.myid, t.name, t.description",
+        )
+        eq_ignore_whitespace(
+            str(stmt.compile(dialect=mysql.dialect())),
+            "INSERT INTO t (description) VALUES (lower(%s))",
+        )
+
     def test_create_index_simple(self):
         m = MetaData()
         tbl = Table("testtbl", m, Column("data", String(255)))
index 846001347c4c264766c1bf7ea6ff96be499f5301..0a23282bf3b2df25a9824826c3dd4eb7f8eb978f 100644 (file)
@@ -369,7 +369,10 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
 
             assert reflected.comment == comment
             assert reflected.kwargs["mysql_comment"] == comment
-            assert reflected.kwargs["mysql_default charset"] == "utf8"
+            assert reflected.kwargs["mysql_default charset"] in (
+                "utf8",
+                "utf8mb3",
+            )
             assert reflected.kwargs["mysql_avg_row_length"] == "3"
             assert reflected.kwargs["mysql_connection"] == "fish"
 
index 45a83ed77a181058c28b1dac8e1d0873f140b168..96969b459442e293ef128f08a29e2898d3594cad 100644 (file)
@@ -57,6 +57,22 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(sql.false(), "0")
         self.assert_compile(sql.true(), "1")
 
+    def test_plain_stringify_returning(self):
+        t = Table(
+            "t",
+            MetaData(),
+            Column("myid", Integer, primary_key=True),
+            Column("name", String, server_default="some str"),
+            Column("description", String, default=func.lower("hi")),
+        )
+        stmt = t.insert().values().return_defaults()
+        eq_ignore_whitespace(
+            str(stmt.compile(dialect=oracle.OracleDialect())),
+            "INSERT INTO t (description) VALUES (lower(:lower_1)) "
+            "RETURNING t.myid, t.name, t.description "
+            "INTO :ret_0, :ret_1, :ret_2",
+        )
+
     def test_owner(self):
         meta = MetaData()
         parent = Table(
index 9be76130d5ed7dc6327e2465939aa0c7e999f543..5e5c4f9bdb6b0db7398015a61a56f27eac0b8aa2 100644 (file)
@@ -56,6 +56,7 @@ from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import assert_raises
 from sqlalchemy.testing.assertions import assert_raises_message
 from sqlalchemy.testing.assertions import AssertsCompiledSQL
+from sqlalchemy.testing.assertions import eq_ignore_whitespace
 from sqlalchemy.testing.assertions import expect_warnings
 from sqlalchemy.testing.assertions import is_
 from sqlalchemy.util import OrderedDict
@@ -101,6 +102,21 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
     __dialect__ = postgresql.dialect()
 
+    def test_plain_stringify_returning(self):
+        t = Table(
+            "t",
+            MetaData(),
+            Column("myid", Integer, primary_key=True),
+            Column("name", String, server_default="some str"),
+            Column("description", String, default=func.lower("hi")),
+        )
+        stmt = t.insert().values().return_defaults()
+        eq_ignore_whitespace(
+            str(stmt.compile(dialect=postgresql.dialect())),
+            "INSERT INTO t (description) VALUES (lower(%(lower_1)s)) "
+            "RETURNING t.myid, t.name, t.description",
+        )
+
     def test_update_returning(self):
         dialect = postgresql.dialect()
         table1 = table(
index 286c6bcf8c455a84eff1311d16cc711e28637782..643a56c1b620cf6a0f8b6690704ea4ff32cc7f6f 100644 (file)
@@ -51,6 +51,7 @@ from sqlalchemy.testing import combinations
 from sqlalchemy.testing import config
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import eq_ignore_whitespace
 from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
@@ -972,6 +973,21 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
                 "INTEGER) AS anon_1 FROM t" % subst,
             )
 
+    def test_plain_stringify_returning(self):
+        t = Table(
+            "t",
+            MetaData(),
+            Column("myid", Integer, primary_key=True),
+            Column("name", String, server_default="some str"),
+            Column("description", String, default=func.lower("hi")),
+        )
+        stmt = t.insert().values().return_defaults()
+        eq_ignore_whitespace(
+            str(stmt.compile(dialect=sqlite.SQLiteDialect())),
+            "INSERT INTO t (description) VALUES (lower(?)) "
+            "RETURNING myid, name, description",
+        )
+
     def test_true_false(self):
         self.assert_compile(sql.false(), "0")
         self.assert_compile(sql.true(), "1")
index 930f32b7bf43cf8aff6d2b3d3655f2886f0e8342..1d3d173265b17319770779f9159b5a156eb2c2e9 100644 (file)
@@ -4914,6 +4914,21 @@ class StringifySpecialTest(fixtures.TestBase):
         stmt = table1.insert().values()
         eq_ignore_whitespace(str(stmt), "INSERT INTO mytable () VALUES ()")
 
+    def test_insert_return_defaults(self):
+        t = Table(
+            "t",
+            MetaData(),
+            Column("myid", Integer, primary_key=True),
+            Column("name", String, server_default="some str"),
+            Column("description", String, default=func.lower("hi")),
+        )
+        stmt = t.insert().values().return_defaults()
+        eq_ignore_whitespace(
+            str(stmt),
+            "INSERT INTO t (description) VALUES (lower(:lower_1)) "
+            "RETURNING t.myid, t.name, t.description",
+        )
+
     def test_multirow_insert(self):
         stmt = table1.insert().values([{"myid": 1}, {"myid": 2}])
         eq_ignore_whitespace(
index 808b047a2837b20f03c0752370a43eebc9aaf818..071f595f394d5b100a5cc404688101ddabfe68c7 100644 (file)
@@ -1542,7 +1542,9 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
         )
 
         stmt = table.insert().return_defaults().values(id=func.foobar())
-        compiled = stmt.compile(dialect=sqlite.dialect(), column_keys=["data"])
+        dialect = sqlite.dialect()
+        dialect.insert_returning = False
+        compiled = stmt.compile(dialect=dialect, column_keys=["data"])
         eq_(compiled.postfetch, [])
         eq_(compiled.implicit_returning, [])
 
@@ -1551,7 +1553,7 @@ class MultirowTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             "INSERT INTO sometable (id, data) VALUES " "(foobar(), ?)",
             checkparams={"data": "foo"},
             params={"data": "foo"},
-            dialect=sqlite.dialect(),
+            dialect=dialect,
         )
 
     def test_sql_expression_pk_autoinc_returning(self):