From 07d7c4905d65b7f28c1ffcbd33f81ee52c9fd847 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 24 Oct 2012 15:37:06 -0400 Subject: [PATCH] Fixed bug where keyword arguments passed to :meth:`.Compiler.process` wouldn't get propagated to the column expressions present in the columns clause of a SELECT statement. In particular this would come up when used by custom compilation schemes that relied upon special flags. [ticket:2593] --- doc/build/changelog/changelog_08.rst | 11 +++++++ lib/sqlalchemy/engine/interfaces.py | 8 +++-- lib/sqlalchemy/sql/compiler.py | 14 ++++++-- test/sql/test_compiler.py | 49 +++++++++++++++++++++++++++- 4 files changed, 76 insertions(+), 6 deletions(-) diff --git a/doc/build/changelog/changelog_08.rst b/doc/build/changelog/changelog_08.rst index c741cdda99..2efcee98ea 100644 --- a/doc/build/changelog/changelog_08.rst +++ b/doc/build/changelog/changelog_08.rst @@ -8,6 +8,17 @@ :version: 0.8.0b1 :released: + .. change:: + :tags: sql, bug + :tickets: 2593 + + Fixed bug where keyword arguments passed to + :meth:`.Compiler.process` wouldn't get propagated + to the column expressions present in the columns + clause of a SELECT statement. In particular this would + come up when used by custom compilation schemes that + relied upon special flags. + .. change:: :tags: sql, feature diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index e9e0da4360..c601201660 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -673,7 +673,8 @@ class Compiled(object): defaults. """ - def __init__(self, dialect, statement, bind=None): + def __init__(self, dialect, statement, bind=None, + compile_kwargs=util.immutabledict()): """Construct a new ``Compiled`` object. :param dialect: ``Dialect`` to compile against. @@ -682,6 +683,9 @@ class Compiled(object): :param bind: Optional Engine or Connection to compile this statement against. + + :param compile_kwargs: additional kwargs that will be + passed to the initial call to :meth:`.Compiled.process`. """ self.dialect = dialect @@ -689,7 +693,7 @@ class Compiled(object): if statement is not None: self.statement = statement self.can_execute = statement.supports_execution - self.string = self.process(self.statement) + self.string = self.process(self.statement, **compile_kwargs) @util.deprecated("0.7", ":class:`.Compiled` objects now compile " "within the constructor.") diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6da51c31ce..0847335c26 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1044,9 +1044,12 @@ class SQLCompiler(engine.Compiled): else: result_expr = col_expr + column_clause_args.update( + within_columns_clause=within_columns_clause, + add_to_result_map=add_to_result_map + ) return result_expr._compiler_dispatch( - self, within_columns_clause=within_columns_clause, - add_to_result_map=add_to_result_map, + self, **column_clause_args ) @@ -1098,7 +1101,12 @@ class SQLCompiler(engine.Compiled): self.stack.append({'from': correlate_froms, 'iswrapper': iswrapper}) - column_clause_args = {'positional_names': positional_names} + column_clause_args = kwargs.copy() + column_clause_args.update({ + 'positional_names': positional_names, + 'within_label_clause': False, + 'within_columns_clause': False + }) # the actual list of columns to print in the SELECT column list. inner_columns = [ diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index bb819472ac..50b425a019 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -18,7 +18,7 @@ from sqlalchemy import Integer, String, MetaData, Table, Column, select, \ insert, literal, and_, null, type_coerce, alias, or_, literal_column,\ Float, TIMESTAMP, Numeric, Date, Text, collate, union, except_,\ intersect, union_all, Boolean, distinct, join, outerjoin, asc, desc,\ - over, subquery + over, subquery, case import decimal from sqlalchemy import exc, sql, util, types, schema from sqlalchemy.sql import table, column, label @@ -2437,6 +2437,53 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): ) +class KwargPropagationTest(fixtures.TestBase): + + @classmethod + def setup_class(cls): + from sqlalchemy.sql.expression import ColumnClause, TableClause + class CatchCol(ColumnClause): + pass + + class CatchTable(TableClause): + pass + + cls.column = CatchCol("x") + cls.table = CatchTable("y") + cls.criterion = cls.column == CatchCol('y') + + @compiles(CatchCol) + def compile_col(element, compiler, **kw): + assert "canary" in kw + return compiler.visit_column(element) + + @compiles(CatchTable) + def compile_table(element, compiler, **kw): + assert "canary" in kw + return compiler.visit_table(element) + + def _do_test(self, element): + d = default.DefaultDialect() + d.statement_compiler(d, element, + compile_kwargs={"canary": True}) + + def test_binary(self): + self._do_test(self.column == 5) + + def test_select(self): + s = select([self.column]).select_from(self.table).\ + where(self.column == self.criterion).\ + order_by(self.column) + self._do_test(s) + + def test_case(self): + c = case([(self.criterion, self.column)], else_=self.column) + self._do_test(c) + + def test_cast(self): + c = cast(self.column, Integer) + self._do_test(c) + class CRUDTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' -- 2.47.3