From f9cbf22a3873005f946d30653146e87e4184f0e8 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 25 Jun 2012 12:51:23 -0400 Subject: [PATCH] - move cte tests into their own test/sql/test_cte.py - rework bindtemplate system of "numbered" params by applying the numbers last, as we now need to generate these out of order in some cases - add positional assertion to assert_compile - add new cte_positional collection to track bindparams generated within cte visits; splice this onto the beginning of self.positiontup at cte render time, [ticket:2521] --- CHANGES | 8 + lib/sqlalchemy/dialects/oracle/cx_oracle.py | 6 +- lib/sqlalchemy/dialects/sybase/pysybase.py | 2 +- lib/sqlalchemy/sql/compiler.py | 50 +++-- test/lib/testing.py | 4 + test/sql/test_compiler.py | 164 --------------- test/sql/test_cte.py | 213 ++++++++++++++++++++ 7 files changed, 263 insertions(+), 184 deletions(-) create mode 100644 test/sql/test_cte.py diff --git a/CHANGES b/CHANGES index 6ab25ac748..3bb42a3361 100644 --- a/CHANGES +++ b/CHANGES @@ -6,6 +6,14 @@ CHANGES 0.7.9 ===== - sql + - [bug] Fixed CTE bug whereby positional + bound parameters present in the CTEs themselves + would corrupt the overall ordering of + bound parameters. This primarily + affected SQL Server as the platform with + positional binds + CTE support. + [ticket:2521] + - [bug] quoting is applied to the column names inside the WITH RECURSIVE clause of a common table expression according to the diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 828ea87fcf..06b27b7104 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -300,13 +300,13 @@ class _OracleRowid(oracle.ROWID): return dbapi.ROWID class OracleCompiler_cx_oracle(OracleCompiler): - def bindparam_string(self, name): + def bindparam_string(self, name, **kw): if self.preparer._bindparam_requires_quotes(name): quoted_name = '"%s"' % name self._quoted_bind_names[name] = quoted_name - return OracleCompiler.bindparam_string(self, quoted_name) + return OracleCompiler.bindparam_string(self, quoted_name, **kw) else: - return OracleCompiler.bindparam_string(self, name) + return OracleCompiler.bindparam_string(self, name, **kw) class OracleExecutionContext_cx_oracle(OracleExecutionContext): diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py index dd7b272513..e3bfae06cd 100644 --- a/lib/sqlalchemy/dialects/sybase/pysybase.py +++ b/lib/sqlalchemy/dialects/sybase/pysybase.py @@ -52,7 +52,7 @@ class SybaseExecutionContext_pysybase(SybaseExecutionContext): class SybaseSQLCompiler_pysybase(SybaseSQLCompiler): - def bindparam_string(self, name): + def bindparam_string(self, name, **kw): return "@" + name class SybaseDialect_pysybase(SybaseDialect): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 0ed4be1ba5..bc81da22c5 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -29,6 +29,7 @@ from sqlalchemy.sql import operators, functions, util as sql_util, \ visitors from sqlalchemy.sql import expression as sql import decimal +import itertools RESERVED_WORDS = set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', @@ -59,7 +60,7 @@ BIND_TEMPLATES = { 'pyformat':"%%(%(name)s)s", 'qmark':"?", 'format':"%%s", - 'numeric':":%(position)s", + 'numeric':":[_POSITION]", 'named':":%(name)s" } @@ -252,16 +253,18 @@ class SQLCompiler(engine.Compiled): # column targeting self.result_map = {} - # collect CTEs to tack on top of a SELECT - self.ctes = util.OrderedDict() - self.ctes_recursive = False - # true if the paramstyle is positional self.positional = dialect.positional if self.positional: self.positiontup = [] self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + # collect CTEs to tack on top of a SELECT + self.ctes = util.OrderedDict() + self.ctes_recursive = False + if self.positional: + self.cte_positional = [] + # an IdentifierPreparer that formats the quoting of identifiers self.preparer = dialect.identifier_preparer self.label_length = dialect.label_length \ @@ -276,7 +279,15 @@ class SQLCompiler(engine.Compiled): self.truncated_names = {} engine.Compiled.__init__(self, dialect, statement, **kwargs) + if self.positional and dialect.paramstyle == 'numeric': + self._apply_numbered_params() + def _apply_numbered_params(self): + poscount = itertools.count(1) + self.string = re.sub( + r'\[_POSITION\]', + lambda m:str(next(poscount)), + self.string) @util.memoized_property def _bind_processors(self): @@ -448,7 +459,7 @@ class SQLCompiler(engine.Compiled): if name in textclause.bindparams: return self.process(textclause.bindparams[name]) else: - return self.bindparam_string(name) + return self.bindparam_string(name, **kwargs) # un-escape any \:params return BIND_PARAMS_ESC.sub(lambda m: m.group(1), @@ -680,7 +691,7 @@ class SQLCompiler(engine.Compiled): self.binds[bindparam.key] = self.binds[name] = bindparam - return self.bindparam_string(name) + return self.bindparam_string(name, **kwargs) def render_literal_bindparam(self, bindparam, **kw): value = bindparam.value @@ -750,16 +761,19 @@ class SQLCompiler(engine.Compiled): self.anon_map[derived] = anonymous_counter + 1 return derived + "_" + str(anonymous_counter) - def bindparam_string(self, name): + def bindparam_string(self, name, positional_names=None, **kw): if self.positional: - self.positiontup.append(name) - return self.bindtemplate % { - 'name':name, 'position':len(self.positiontup)} - else: - return self.bindtemplate % {'name':name} + if positional_names is not None: + positional_names.append(name) + else: + self.positiontup.append(name) + return self.bindtemplate % {'name':name} def visit_cte(self, cte, asfrom=False, ashint=False, fromhints=None, **kwargs): + if self.positional: + kwargs['positional_names'] = self.cte_positional + if isinstance(cte.name, sql._truncated_label): cte_name = self._truncated_identifier("alias", cte.name) else: @@ -867,7 +881,8 @@ class SQLCompiler(engine.Compiled): def visit_select(self, select, asfrom=False, parens=True, iswrapper=False, fromhints=None, - compound_index=1, **kwargs): + compound_index=1, + positional_names=None, **kwargs): entry = self.stack and self.stack[-1] or {} @@ -886,9 +901,10 @@ class SQLCompiler(engine.Compiled): : iswrapper}) if compound_index==1 and not entry or entry.get('iswrapper', False): - column_clause_args = {'result_map':self.result_map} + column_clause_args = {'result_map':self.result_map, + 'positional_names':positional_names} else: - column_clause_args = {} + column_clause_args = {'positional_names':positional_names} # the actual list of columns to print in the SELECT column list. inner_columns = [ @@ -975,6 +991,8 @@ class SQLCompiler(engine.Compiled): return text def _render_cte_clause(self): + if self.positional: + self.positiontup = self.cte_positional + self.positiontup cte_text = self.get_cte_preamble(self.ctes_recursive) + " " cte_text += ", \n".join( [txt for txt in self.ctes.values()] diff --git a/test/lib/testing.py b/test/lib/testing.py index cea11095b3..4dae400701 100644 --- a/test/lib/testing.py +++ b/test/lib/testing.py @@ -578,6 +578,7 @@ class adict(dict): class AssertsCompiledSQL(object): def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, + checkpositional=None, use_default_dialect=False, allow_dialect_select=False): @@ -613,6 +614,9 @@ class AssertsCompiledSQL(object): if checkparams is not None: eq_(c.construct_params(params), checkparams) + if checkpositional is not None: + p = c.construct_params(params) + eq_(tuple([p[x] for x in c.positiontup]), checkpositional) class ComparesTables(object): def assert_tables_equal(self, table, reflected_table, strict_types=False): diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index f58b006dff..6980c7974c 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -2339,170 +2339,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT x + foo() OVER () AS anon_1" ) - def test_cte_nonrecursive(self): - orders = table('orders', - column('region'), - column('amount'), - column('product'), - column('quantity') - ) - - regional_sales = select([ - orders.c.region, - func.sum(orders.c.amount).label('total_sales') - ]).group_by(orders.c.region).cte("regional_sales") - - top_regions = select([regional_sales.c.region]).\ - where( - regional_sales.c.total_sales > - select([ - func.sum(regional_sales.c.total_sales)/10 - ]) - ).cte("top_regions") - - s = select([ - orders.c.region, - orders.c.product, - func.sum(orders.c.quantity).label("product_units"), - func.sum(orders.c.amount).label("product_sales") - ]).where(orders.c.region.in_( - select([top_regions.c.region]) - )).group_by(orders.c.region, orders.c.product) - - # needs to render regional_sales first as top_regions - # refers to it - self.assert_compile( - s, - "WITH regional_sales AS (SELECT orders.region AS region, " - "sum(orders.amount) AS total_sales FROM orders " - "GROUP BY orders.region), " - "top_regions AS (SELECT " - "regional_sales.region AS region FROM regional_sales " - "WHERE regional_sales.total_sales > " - "(SELECT sum(regional_sales.total_sales) / :sum_1 AS " - "anon_1 FROM regional_sales)) " - "SELECT orders.region, orders.product, " - "sum(orders.quantity) AS product_units, " - "sum(orders.amount) AS product_sales " - "FROM orders WHERE orders.region " - "IN (SELECT top_regions.region FROM top_regions) " - "GROUP BY orders.region, orders.product" - ) - - def test_cte_recursive(self): - parts = table('parts', - column('part'), - column('sub_part'), - column('quantity'), - ) - - included_parts = select([ - parts.c.sub_part, - parts.c.part, - parts.c.quantity]).\ - where(parts.c.part=='our part').\ - cte(recursive=True) - - incl_alias = included_parts.alias() - parts_alias = parts.alias() - included_parts = included_parts.union( - select([ - parts_alias.c.part, - parts_alias.c.sub_part, - parts_alias.c.quantity]).\ - where(parts_alias.c.part==incl_alias.c.sub_part) - ) - - s = select([ - included_parts.c.sub_part, - func.sum(included_parts.c.quantity).label('total_quantity')]).\ - select_from(included_parts.join( - parts,included_parts.c.part==parts.c.part)).\ - group_by(included_parts.c.sub_part) - self.assert_compile(s, - "WITH RECURSIVE anon_1(sub_part, part, quantity) " - "AS (SELECT parts.sub_part AS sub_part, parts.part " - "AS part, parts.quantity AS quantity FROM parts " - "WHERE parts.part = :part_1 UNION SELECT parts_1.part " - "AS part, parts_1.sub_part AS sub_part, parts_1.quantity " - "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 " - "WHERE parts_1.part = anon_2.sub_part) " - "SELECT anon_1.sub_part, " - "sum(anon_1.quantity) AS total_quantity FROM anon_1 " - "JOIN parts ON anon_1.part = parts.part " - "GROUP BY anon_1.sub_part" - ) - - # quick check that the "WITH RECURSIVE" varies per - # dialect - self.assert_compile(s, - "WITH anon_1(sub_part, part, quantity) " - "AS (SELECT parts.sub_part AS sub_part, parts.part " - "AS part, parts.quantity AS quantity FROM parts " - "WHERE parts.part = :part_1 UNION SELECT parts_1.part " - "AS part, parts_1.sub_part AS sub_part, parts_1.quantity " - "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 " - "WHERE parts_1.part = anon_2.sub_part) " - "SELECT anon_1.sub_part, " - "sum(anon_1.quantity) AS total_quantity FROM anon_1 " - "JOIN parts ON anon_1.part = parts.part " - "GROUP BY anon_1.sub_part", - dialect=mssql.dialect() - ) - - def test_cte_union(self): - orders = table('orders', - column('region'), - column('amount'), - ) - - regional_sales = select([ - orders.c.region, - orders.c.amount - ]).cte("regional_sales") - - s = select([regional_sales.c.region]).\ - where( - regional_sales.c.amount > 500 - ) - - self.assert_compile(s, - "WITH regional_sales AS " - "(SELECT orders.region AS region, " - "orders.amount AS amount FROM orders) " - "SELECT regional_sales.region " - "FROM regional_sales WHERE " - "regional_sales.amount > :amount_1") - - s = s.union_all( - select([regional_sales.c.region]).\ - where( - regional_sales.c.amount < 300 - ) - ) - self.assert_compile(s, - "WITH regional_sales AS " - "(SELECT orders.region AS region, " - "orders.amount AS amount FROM orders) " - "SELECT regional_sales.region FROM regional_sales " - "WHERE regional_sales.amount > :amount_1 " - "UNION ALL SELECT regional_sales.region " - "FROM regional_sales WHERE " - "regional_sales.amount < :amount_2") - - def test_cte_reserved_quote(self): - orders = table('orders', - column('order'), - ) - s = select([orders.c.order]).cte("regional_sales", recursive=True) - s = select([s.c.order]) - self.assert_compile(s, - 'WITH RECURSIVE regional_sales("order") AS ' - '(SELECT orders."order" AS "order" ' - "FROM orders)" - ' SELECT regional_sales."order" ' - "FROM regional_sales" - ) def test_date_between(self): import datetime diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py new file mode 100644 index 0000000000..36f992a86c --- /dev/null +++ b/test/sql/test_cte.py @@ -0,0 +1,213 @@ +from test.lib import fixtures +from test.lib.testing import AssertsCompiledSQL +from sqlalchemy.sql import table, column, select, func, literal +from sqlalchemy.dialects import mssql +from sqlalchemy.engine import default + +class CTETest(fixtures.TestBase, AssertsCompiledSQL): + + __dialect__ = 'default' + + def test_nonrecursive(self): + orders = table('orders', + column('region'), + column('amount'), + column('product'), + column('quantity') + ) + + regional_sales = select([ + orders.c.region, + func.sum(orders.c.amount).label('total_sales') + ]).group_by(orders.c.region).cte("regional_sales") + + top_regions = select([regional_sales.c.region]).\ + where( + regional_sales.c.total_sales > + select([ + func.sum(regional_sales.c.total_sales)/10 + ]) + ).cte("top_regions") + + s = select([ + orders.c.region, + orders.c.product, + func.sum(orders.c.quantity).label("product_units"), + func.sum(orders.c.amount).label("product_sales") + ]).where(orders.c.region.in_( + select([top_regions.c.region]) + )).group_by(orders.c.region, orders.c.product) + + # needs to render regional_sales first as top_regions + # refers to it + self.assert_compile( + s, + "WITH regional_sales AS (SELECT orders.region AS region, " + "sum(orders.amount) AS total_sales FROM orders " + "GROUP BY orders.region), " + "top_regions AS (SELECT " + "regional_sales.region AS region FROM regional_sales " + "WHERE regional_sales.total_sales > " + "(SELECT sum(regional_sales.total_sales) / :sum_1 AS " + "anon_1 FROM regional_sales)) " + "SELECT orders.region, orders.product, " + "sum(orders.quantity) AS product_units, " + "sum(orders.amount) AS product_sales " + "FROM orders WHERE orders.region " + "IN (SELECT top_regions.region FROM top_regions) " + "GROUP BY orders.region, orders.product" + ) + + def test_recursive(self): + parts = table('parts', + column('part'), + column('sub_part'), + column('quantity'), + ) + + included_parts = select([ + parts.c.sub_part, + parts.c.part, + parts.c.quantity]).\ + where(parts.c.part=='our part').\ + cte(recursive=True) + + incl_alias = included_parts.alias() + parts_alias = parts.alias() + included_parts = included_parts.union( + select([ + parts_alias.c.part, + parts_alias.c.sub_part, + parts_alias.c.quantity]).\ + where(parts_alias.c.part==incl_alias.c.sub_part) + ) + + s = select([ + included_parts.c.sub_part, + func.sum(included_parts.c.quantity).label('total_quantity')]).\ + select_from(included_parts.join( + parts,included_parts.c.part==parts.c.part)).\ + group_by(included_parts.c.sub_part) + self.assert_compile(s, + "WITH RECURSIVE anon_1(sub_part, part, quantity) " + "AS (SELECT parts.sub_part AS sub_part, parts.part " + "AS part, parts.quantity AS quantity FROM parts " + "WHERE parts.part = :part_1 UNION SELECT parts_1.part " + "AS part, parts_1.sub_part AS sub_part, parts_1.quantity " + "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 " + "WHERE parts_1.part = anon_2.sub_part) " + "SELECT anon_1.sub_part, " + "sum(anon_1.quantity) AS total_quantity FROM anon_1 " + "JOIN parts ON anon_1.part = parts.part " + "GROUP BY anon_1.sub_part" + ) + + # quick check that the "WITH RECURSIVE" varies per + # dialect + self.assert_compile(s, + "WITH anon_1(sub_part, part, quantity) " + "AS (SELECT parts.sub_part AS sub_part, parts.part " + "AS part, parts.quantity AS quantity FROM parts " + "WHERE parts.part = :part_1 UNION SELECT parts_1.part " + "AS part, parts_1.sub_part AS sub_part, parts_1.quantity " + "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 " + "WHERE parts_1.part = anon_2.sub_part) " + "SELECT anon_1.sub_part, " + "sum(anon_1.quantity) AS total_quantity FROM anon_1 " + "JOIN parts ON anon_1.part = parts.part " + "GROUP BY anon_1.sub_part", + dialect=mssql.dialect() + ) + + def test_union(self): + orders = table('orders', + column('region'), + column('amount'), + ) + + regional_sales = select([ + orders.c.region, + orders.c.amount + ]).cte("regional_sales") + + s = select([regional_sales.c.region]).\ + where( + regional_sales.c.amount > 500 + ) + + self.assert_compile(s, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT regional_sales.region " + "FROM regional_sales WHERE " + "regional_sales.amount > :amount_1") + + s = s.union_all( + select([regional_sales.c.region]).\ + where( + regional_sales.c.amount < 300 + ) + ) + self.assert_compile(s, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT regional_sales.region FROM regional_sales " + "WHERE regional_sales.amount > :amount_1 " + "UNION ALL SELECT regional_sales.region " + "FROM regional_sales WHERE " + "regional_sales.amount < :amount_2") + + def test_reserved_quote(self): + orders = table('orders', + column('order'), + ) + s = select([orders.c.order]).cte("regional_sales", recursive=True) + s = select([s.c.order]) + self.assert_compile(s, + 'WITH RECURSIVE regional_sales("order") AS ' + '(SELECT orders."order" AS "order" ' + "FROM orders)" + ' SELECT regional_sales."order" ' + "FROM regional_sales" + ) + + def test_positional_binds(self): + orders = table('orders', + column('order'), + ) + s = select([orders.c.order, literal("x")]).cte("regional_sales") + s = select([s.c.order, literal("y")]) + dialect = default.DefaultDialect() + dialect.positional = True + dialect.paramstyle = 'numeric' + self.assert_compile(s, + 'WITH regional_sales AS (SELECT orders."order" ' + 'AS "order", :1 AS anon_2 FROM orders) SELECT ' + 'regional_sales."order", :2 AS anon_1 FROM regional_sales', + checkpositional=('x', 'y'), + dialect=dialect + ) + + self.assert_compile(s.union(s), + 'WITH regional_sales AS (SELECT orders."order" ' + 'AS "order", :1 AS anon_2 FROM orders) SELECT ' + 'regional_sales."order", :2 AS anon_1 FROM regional_sales ' + 'UNION SELECT regional_sales."order", :3 AS anon_1 ' + 'FROM regional_sales', + checkpositional=('x', 'y', 'y'), + dialect=dialect + ) + + s = select([orders.c.order]).\ + where(orders.c.order=='x').cte("regional_sales") + s = select([s.c.order]).where(s.c.order=="y") + self.assert_compile(s, + 'WITH regional_sales AS (SELECT orders."order" AS ' + '"order" FROM orders WHERE orders."order" = :1) ' + 'SELECT regional_sales."order" FROM regional_sales ' + 'WHERE regional_sales."order" = :2', + checkpositional=('x', 'y'), + dialect=dialect + ) -- 2.47.2