]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure all visit_sequence accepts **kw args
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 4 Apr 2018 17:36:28 +0000 (13:36 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 4 Apr 2018 20:46:16 +0000 (16:46 -0400)
Fixed issue where the compilation of an INSERT statement with the
"literal_binds" option that also uses an explicit sequence and "inline"
generation, as on Postgresql and Oracle, would fail to accommodate the
extra keyword argument within the sequence processing routine.

Change-Id: Ibdab7d340aea7429a210c9535ccf1a3e85f074fb
Fixes: #4231
doc/build/changelog/unreleased_12/4231.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/suite/test_sequence.py
test/sql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_12/4231.rst b/doc/build/changelog/unreleased_12/4231.rst
new file mode 100644 (file)
index 0000000..47e70ef
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 4231
+    :versions: 1.3.0b1
+
+    Fixed issue where the compilation of an INSERT statement with the
+    "literal_binds" option that also uses an explicit sequence and "inline"
+    generation, as on Postgresql and Oracle, would fail to accommodate the
+    extra keyword argument within the sequence processing routine.
index 335163f150c82c09f61b337c8fdcfe8715331b9e..7b470c1899758ba412112d0b375aceddd6a29082 100644 (file)
@@ -291,7 +291,7 @@ class FBCompiler(sql.compiler.SQLCompiler):
     def default_from(self):
         return " FROM rdb$database"
 
-    def visit_sequence(self, seq):
+    def visit_sequence(self, seq, **kw):
         return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
 
     def get_select_precolumns(self, select, **kw):
index 3970a181c3168d4d616a18436ea54cfde3ab391b..44ab9e3bbd5df73d4a662a6bf7aeb639107b0cad 100644 (file)
@@ -767,7 +767,7 @@ class OracleCompiler(compiler.SQLCompiler):
     def visit_outer_join_column(self, vc, **kw):
         return self.process(vc.column, **kw) + "(+)"
 
-    def visit_sequence(self, seq):
+    def visit_sequence(self, seq, **kw):
         return (self.dialect.identifier_preparer.format_sequence(seq) +
                 ".nextval")
 
index c5b0db6ce5606a66182c25ea96c5db6acee558a6..0160239b753d7c08bde15de5daffcfe43662fd7b 100644 (file)
@@ -1489,7 +1489,7 @@ class PGCompiler(compiler.SQLCompiler):
             value = value.replace('\\', '\\\\')
         return value
 
-    def visit_sequence(self, seq):
+    def visit_sequence(self, seq, **kw):
         return "nextval('%s')" % self.preparer.format_sequence(seq)
 
     def limit_clause(self, select, **kw):
index 6c7e6145d0792b9279e041d2a9d51025b12256f5..a442c65fd6d0fffe6eb878c167cdc55b7f502282 100644 (file)
@@ -934,7 +934,7 @@ class SQLCompiler(Compiled):
     def visit_next_value_func(self, next_value, **kw):
         return self.visit_sequence(next_value.sequence)
 
-    def visit_sequence(self, sequence):
+    def visit_sequence(self, sequence, **kw):
         raise NotImplementedError(
             "Dialect '%s' does not support sequence increments." %
             self.dialect.name
index b2d52f27cce4c54b894bc8bc00d194344fc1b31f..f1c00de6b09ef34a0f993fe921f55b39d99794c5 100644 (file)
@@ -3,7 +3,7 @@ from ..config import requirements
 from ..assertions import eq_
 from ... import testing
 
-from ... import Integer, String, Sequence, schema
+from ... import Integer, String, Sequence, schema, MetaData
 
 from ..schema import Table, Column
 
@@ -71,6 +71,28 @@ class SequenceTest(fixtures.TablesTest):
         )
 
 
+class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
+    __requires__ = ('sequences',)
+    __backend__ = True
+
+    def test_literal_binds_inline_compile(self):
+        table = Table(
+            'x', MetaData(),
+            Column('y', Integer, Sequence('y_seq')),
+            Column('q', Integer))
+
+        stmt = table.insert().values(q=5)
+
+        seq_nextval = testing.db.dialect.statement_compiler(
+            statement=None, dialect=testing.db.dialect).visit_sequence(
+            Sequence("y_seq"))
+        self.assert_compile(
+            stmt,
+            "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval, ),
+            literal_binds=True,
+            dialect=testing.db.dialect)
+
+
 class HasSequenceTest(fixtures.TestBase):
     __requires__ = 'sequences',
     __backend__ = True
index 25eb2b24b6241b428ea64d73aa8f012d2b165536..0ef19e0cb597c7d405b66b724cf0227df250102e 100644 (file)
@@ -19,7 +19,7 @@ from sqlalchemy import Integer, String, MetaData, Table, Column, select, \
     literal, and_, null, type_coerce, alias, or_, literal_column,\
     Float, TIMESTAMP, Numeric, Date, Text, union, except_,\
     intersect, union_all, Boolean, distinct, join, outerjoin, asc, desc,\
-    over, subquery, case, true, CheckConstraint
+    over, subquery, case, true, CheckConstraint, Sequence
 import decimal
 from sqlalchemy.util import u
 from sqlalchemy import exc, sql, util, types, schema
@@ -2955,6 +2955,19 @@ class CRUDTest(fixtures.TestBase, AssertsCompiledSQL):
             "INSERT INTO mytable (myid, name) VALUES (3, 'jack')",
             literal_binds=True)
 
+    def test_insert_literal_binds_sequence_notimplemented(self):
+        table = Table('x', MetaData(), Column('y', Integer, Sequence('y_seq')))
+        dialect = default.DefaultDialect()
+        dialect.supports_sequences = True
+
+        stmt = table.insert().values(myid=3, name='jack')
+
+        assert_raises(
+            NotImplementedError,
+            stmt.compile,
+            compile_kwargs=dict(literal_binds=True), dialect=dialect
+        )
+
     def test_update_literal_binds(self):
         stmt = table1.update().values(name='jack').\
             where(table1.c.name == 'jill')