From b4eb29253cb29a069973503f36d1103d4a18311c Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 4 Apr 2018 13:36:28 -0400 Subject: [PATCH] Ensure all visit_sequence accepts **kw args Fixed issue where the compilation of an INSERT statement with the "literal_binds" option that also uses an explicit sequence and "inline" generation, as on Postgresql and Oracle, would fail to accommodate the extra keyword argument within the sequence processing routine. Change-Id: Ibdab7d340aea7429a210c9535ccf1a3e85f074fb Fixes: #4231 --- doc/build/changelog/unreleased_12/4231.rst | 9 +++++++ lib/sqlalchemy/dialects/firebird/base.py | 2 +- lib/sqlalchemy/dialects/oracle/base.py | 2 +- lib/sqlalchemy/dialects/postgresql/base.py | 2 +- lib/sqlalchemy/sql/compiler.py | 2 +- lib/sqlalchemy/testing/suite/test_sequence.py | 24 ++++++++++++++++++- test/sql/test_compiler.py | 15 +++++++++++- 7 files changed, 50 insertions(+), 6 deletions(-) create mode 100644 doc/build/changelog/unreleased_12/4231.rst diff --git a/doc/build/changelog/unreleased_12/4231.rst b/doc/build/changelog/unreleased_12/4231.rst new file mode 100644 index 0000000000..47e70ef024 --- /dev/null +++ b/doc/build/changelog/unreleased_12/4231.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, sql + :tickets: 4231 + :versions: 1.3.0b1 + + Fixed issue where the compilation of an INSERT statement with the + "literal_binds" option that also uses an explicit sequence and "inline" + generation, as on Postgresql and Oracle, would fail to accommodate the + extra keyword argument within the sequence processing routine. diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 335163f150..7b470c1899 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -291,7 +291,7 @@ class FBCompiler(sql.compiler.SQLCompiler): def default_from(self): return " FROM rdb$database" - def visit_sequence(self, seq): + def visit_sequence(self, seq, **kw): return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) def get_select_precolumns(self, select, **kw): diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 3970a181c3..44ab9e3bbd 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -767,7 +767,7 @@ class OracleCompiler(compiler.SQLCompiler): def visit_outer_join_column(self, vc, **kw): return self.process(vc.column, **kw) + "(+)" - def visit_sequence(self, seq): + def visit_sequence(self, seq, **kw): return (self.dialect.identifier_preparer.format_sequence(seq) + ".nextval") diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index c5b0db6ce5..0160239b75 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1489,7 +1489,7 @@ class PGCompiler(compiler.SQLCompiler): value = value.replace('\\', '\\\\') return value - def visit_sequence(self, seq): + def visit_sequence(self, seq, **kw): return "nextval('%s')" % self.preparer.format_sequence(seq) def limit_clause(self, select, **kw): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6c7e6145d0..a442c65fd6 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -934,7 +934,7 @@ class SQLCompiler(Compiled): def visit_next_value_func(self, next_value, **kw): return self.visit_sequence(next_value.sequence) - def visit_sequence(self, sequence): + def visit_sequence(self, sequence, **kw): raise NotImplementedError( "Dialect '%s' does not support sequence increments." % self.dialect.name diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index b2d52f27cc..f1c00de6b0 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -3,7 +3,7 @@ from ..config import requirements from ..assertions import eq_ from ... import testing -from ... import Integer, String, Sequence, schema +from ... import Integer, String, Sequence, schema, MetaData from ..schema import Table, Column @@ -71,6 +71,28 @@ class SequenceTest(fixtures.TablesTest): ) +class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase): + __requires__ = ('sequences',) + __backend__ = True + + def test_literal_binds_inline_compile(self): + table = Table( + 'x', MetaData(), + Column('y', Integer, Sequence('y_seq')), + Column('q', Integer)) + + stmt = table.insert().values(q=5) + + seq_nextval = testing.db.dialect.statement_compiler( + statement=None, dialect=testing.db.dialect).visit_sequence( + Sequence("y_seq")) + self.assert_compile( + stmt, + "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval, ), + literal_binds=True, + dialect=testing.db.dialect) + + class HasSequenceTest(fixtures.TestBase): __requires__ = 'sequences', __backend__ = True diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 25eb2b24b6..0ef19e0cb5 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -19,7 +19,7 @@ from sqlalchemy import Integer, String, MetaData, Table, Column, select, \ literal, and_, null, type_coerce, alias, or_, literal_column,\ Float, TIMESTAMP, Numeric, Date, Text, union, except_,\ intersect, union_all, Boolean, distinct, join, outerjoin, asc, desc,\ - over, subquery, case, true, CheckConstraint + over, subquery, case, true, CheckConstraint, Sequence import decimal from sqlalchemy.util import u from sqlalchemy import exc, sql, util, types, schema @@ -2955,6 +2955,19 @@ class CRUDTest(fixtures.TestBase, AssertsCompiledSQL): "INSERT INTO mytable (myid, name) VALUES (3, 'jack')", literal_binds=True) + def test_insert_literal_binds_sequence_notimplemented(self): + table = Table('x', MetaData(), Column('y', Integer, Sequence('y_seq'))) + dialect = default.DefaultDialect() + dialect.supports_sequences = True + + stmt = table.insert().values(myid=3, name='jack') + + assert_raises( + NotImplementedError, + stmt.compile, + compile_kwargs=dict(literal_binds=True), dialect=dialect + ) + def test_update_literal_binds(self): stmt = table1.update().values(name='jack').\ where(table1.c.name == 'jill') -- 2.47.2