0.7.9
=====
- sql
+ - [bug] Fixed CTE bug whereby positional
+ bound parameters present in the CTEs themselves
+ would corrupt the overall ordering of
+ bound parameters. This primarily
+ affected SQL Server as the platform with
+ positional binds + CTE support.
+ [ticket:2521]
+
- [bug] quoting is applied to the column names
inside the WITH RECURSIVE clause of a
common table expression according to the
return dbapi.ROWID
class OracleCompiler_cx_oracle(OracleCompiler):
- def bindparam_string(self, name):
+ def bindparam_string(self, name, **kw):
if self.preparer._bindparam_requires_quotes(name):
quoted_name = '"%s"' % name
self._quoted_bind_names[name] = quoted_name
- return OracleCompiler.bindparam_string(self, quoted_name)
+ return OracleCompiler.bindparam_string(self, quoted_name, **kw)
else:
- return OracleCompiler.bindparam_string(self, name)
+ return OracleCompiler.bindparam_string(self, name, **kw)
class OracleExecutionContext_cx_oracle(OracleExecutionContext):
class SybaseSQLCompiler_pysybase(SybaseSQLCompiler):
- def bindparam_string(self, name):
+ def bindparam_string(self, name, **kw):
return "@" + name
class SybaseDialect_pysybase(SybaseDialect):
visitors
from sqlalchemy.sql import expression as sql
import decimal
+import itertools
RESERVED_WORDS = set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
'pyformat':"%%(%(name)s)s",
'qmark':"?",
'format':"%%s",
- 'numeric':":%(position)s",
+ 'numeric':":[_POSITION]",
'named':":%(name)s"
}
# column targeting
self.result_map = {}
- # collect CTEs to tack on top of a SELECT
- self.ctes = util.OrderedDict()
- self.ctes_recursive = False
-
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
self.positiontup = []
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
+ # collect CTEs to tack on top of a SELECT
+ self.ctes = util.OrderedDict()
+ self.ctes_recursive = False
+ if self.positional:
+ self.cte_positional = []
+
# an IdentifierPreparer that formats the quoting of identifiers
self.preparer = dialect.identifier_preparer
self.label_length = dialect.label_length \
self.truncated_names = {}
engine.Compiled.__init__(self, dialect, statement, **kwargs)
+ if self.positional and dialect.paramstyle == 'numeric':
+ self._apply_numbered_params()
+ def _apply_numbered_params(self):
+ poscount = itertools.count(1)
+ self.string = re.sub(
+ r'\[_POSITION\]',
+ lambda m:str(next(poscount)),
+ self.string)
@util.memoized_property
def _bind_processors(self):
if name in textclause.bindparams:
return self.process(textclause.bindparams[name])
else:
- return self.bindparam_string(name)
+ return self.bindparam_string(name, **kwargs)
# un-escape any \:params
return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
self.binds[bindparam.key] = self.binds[name] = bindparam
- return self.bindparam_string(name)
+ return self.bindparam_string(name, **kwargs)
def render_literal_bindparam(self, bindparam, **kw):
value = bindparam.value
self.anon_map[derived] = anonymous_counter + 1
return derived + "_" + str(anonymous_counter)
- def bindparam_string(self, name):
+ def bindparam_string(self, name, positional_names=None, **kw):
if self.positional:
- self.positiontup.append(name)
- return self.bindtemplate % {
- 'name':name, 'position':len(self.positiontup)}
- else:
- return self.bindtemplate % {'name':name}
+ if positional_names is not None:
+ positional_names.append(name)
+ else:
+ self.positiontup.append(name)
+ return self.bindtemplate % {'name':name}
def visit_cte(self, cte, asfrom=False, ashint=False,
fromhints=None, **kwargs):
+ if self.positional:
+ kwargs['positional_names'] = self.cte_positional
+
if isinstance(cte.name, sql._truncated_label):
cte_name = self._truncated_identifier("alias", cte.name)
else:
def visit_select(self, select, asfrom=False, parens=True,
iswrapper=False, fromhints=None,
- compound_index=1, **kwargs):
+ compound_index=1,
+ positional_names=None, **kwargs):
entry = self.stack and self.stack[-1] or {}
: iswrapper})
if compound_index==1 and not entry or entry.get('iswrapper', False):
- column_clause_args = {'result_map':self.result_map}
+ column_clause_args = {'result_map':self.result_map,
+ 'positional_names':positional_names}
else:
- column_clause_args = {}
+ column_clause_args = {'positional_names':positional_names}
# the actual list of columns to print in the SELECT column list.
inner_columns = [
return text
def _render_cte_clause(self):
+ if self.positional:
+ self.positiontup = self.cte_positional + self.positiontup
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
cte_text += ", \n".join(
[txt for txt in self.ctes.values()]
class AssertsCompiledSQL(object):
def assert_compile(self, clause, result, params=None,
checkparams=None, dialect=None,
+ checkpositional=None,
use_default_dialect=False,
allow_dialect_select=False):
if checkparams is not None:
eq_(c.construct_params(params), checkparams)
+ if checkpositional is not None:
+ p = c.construct_params(params)
+ eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
class ComparesTables(object):
def assert_tables_equal(self, table, reflected_table, strict_types=False):
"SELECT x + foo() OVER () AS anon_1"
)
- def test_cte_nonrecursive(self):
- orders = table('orders',
- column('region'),
- column('amount'),
- column('product'),
- column('quantity')
- )
-
- regional_sales = select([
- orders.c.region,
- func.sum(orders.c.amount).label('total_sales')
- ]).group_by(orders.c.region).cte("regional_sales")
-
- top_regions = select([regional_sales.c.region]).\
- where(
- regional_sales.c.total_sales >
- select([
- func.sum(regional_sales.c.total_sales)/10
- ])
- ).cte("top_regions")
-
- s = select([
- orders.c.region,
- orders.c.product,
- func.sum(orders.c.quantity).label("product_units"),
- func.sum(orders.c.amount).label("product_sales")
- ]).where(orders.c.region.in_(
- select([top_regions.c.region])
- )).group_by(orders.c.region, orders.c.product)
-
- # needs to render regional_sales first as top_regions
- # refers to it
- self.assert_compile(
- s,
- "WITH regional_sales AS (SELECT orders.region AS region, "
- "sum(orders.amount) AS total_sales FROM orders "
- "GROUP BY orders.region), "
- "top_regions AS (SELECT "
- "regional_sales.region AS region FROM regional_sales "
- "WHERE regional_sales.total_sales > "
- "(SELECT sum(regional_sales.total_sales) / :sum_1 AS "
- "anon_1 FROM regional_sales)) "
- "SELECT orders.region, orders.product, "
- "sum(orders.quantity) AS product_units, "
- "sum(orders.amount) AS product_sales "
- "FROM orders WHERE orders.region "
- "IN (SELECT top_regions.region FROM top_regions) "
- "GROUP BY orders.region, orders.product"
- )
-
- def test_cte_recursive(self):
- parts = table('parts',
- column('part'),
- column('sub_part'),
- column('quantity'),
- )
-
- included_parts = select([
- parts.c.sub_part,
- parts.c.part,
- parts.c.quantity]).\
- where(parts.c.part=='our part').\
- cte(recursive=True)
-
- incl_alias = included_parts.alias()
- parts_alias = parts.alias()
- included_parts = included_parts.union(
- select([
- parts_alias.c.part,
- parts_alias.c.sub_part,
- parts_alias.c.quantity]).\
- where(parts_alias.c.part==incl_alias.c.sub_part)
- )
-
- s = select([
- included_parts.c.sub_part,
- func.sum(included_parts.c.quantity).label('total_quantity')]).\
- select_from(included_parts.join(
- parts,included_parts.c.part==parts.c.part)).\
- group_by(included_parts.c.sub_part)
- self.assert_compile(s,
- "WITH RECURSIVE anon_1(sub_part, part, quantity) "
- "AS (SELECT parts.sub_part AS sub_part, parts.part "
- "AS part, parts.quantity AS quantity FROM parts "
- "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
- "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
- "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
- "WHERE parts_1.part = anon_2.sub_part) "
- "SELECT anon_1.sub_part, "
- "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
- "JOIN parts ON anon_1.part = parts.part "
- "GROUP BY anon_1.sub_part"
- )
-
- # quick check that the "WITH RECURSIVE" varies per
- # dialect
- self.assert_compile(s,
- "WITH anon_1(sub_part, part, quantity) "
- "AS (SELECT parts.sub_part AS sub_part, parts.part "
- "AS part, parts.quantity AS quantity FROM parts "
- "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
- "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
- "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
- "WHERE parts_1.part = anon_2.sub_part) "
- "SELECT anon_1.sub_part, "
- "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
- "JOIN parts ON anon_1.part = parts.part "
- "GROUP BY anon_1.sub_part",
- dialect=mssql.dialect()
- )
-
- def test_cte_union(self):
- orders = table('orders',
- column('region'),
- column('amount'),
- )
-
- regional_sales = select([
- orders.c.region,
- orders.c.amount
- ]).cte("regional_sales")
-
- s = select([regional_sales.c.region]).\
- where(
- regional_sales.c.amount > 500
- )
-
- self.assert_compile(s,
- "WITH regional_sales AS "
- "(SELECT orders.region AS region, "
- "orders.amount AS amount FROM orders) "
- "SELECT regional_sales.region "
- "FROM regional_sales WHERE "
- "regional_sales.amount > :amount_1")
-
- s = s.union_all(
- select([regional_sales.c.region]).\
- where(
- regional_sales.c.amount < 300
- )
- )
- self.assert_compile(s,
- "WITH regional_sales AS "
- "(SELECT orders.region AS region, "
- "orders.amount AS amount FROM orders) "
- "SELECT regional_sales.region FROM regional_sales "
- "WHERE regional_sales.amount > :amount_1 "
- "UNION ALL SELECT regional_sales.region "
- "FROM regional_sales WHERE "
- "regional_sales.amount < :amount_2")
-
- def test_cte_reserved_quote(self):
- orders = table('orders',
- column('order'),
- )
- s = select([orders.c.order]).cte("regional_sales", recursive=True)
- s = select([s.c.order])
- self.assert_compile(s,
- 'WITH RECURSIVE regional_sales("order") AS '
- '(SELECT orders."order" AS "order" '
- "FROM orders)"
- ' SELECT regional_sales."order" '
- "FROM regional_sales"
- )
def test_date_between(self):
import datetime
--- /dev/null
+from test.lib import fixtures
+from test.lib.testing import AssertsCompiledSQL
+from sqlalchemy.sql import table, column, select, func, literal
+from sqlalchemy.dialects import mssql
+from sqlalchemy.engine import default
+
+class CTETest(fixtures.TestBase, AssertsCompiledSQL):
+
+ __dialect__ = 'default'
+
+ def test_nonrecursive(self):
+ orders = table('orders',
+ column('region'),
+ column('amount'),
+ column('product'),
+ column('quantity')
+ )
+
+ regional_sales = select([
+ orders.c.region,
+ func.sum(orders.c.amount).label('total_sales')
+ ]).group_by(orders.c.region).cte("regional_sales")
+
+ top_regions = select([regional_sales.c.region]).\
+ where(
+ regional_sales.c.total_sales >
+ select([
+ func.sum(regional_sales.c.total_sales)/10
+ ])
+ ).cte("top_regions")
+
+ s = select([
+ orders.c.region,
+ orders.c.product,
+ func.sum(orders.c.quantity).label("product_units"),
+ func.sum(orders.c.amount).label("product_sales")
+ ]).where(orders.c.region.in_(
+ select([top_regions.c.region])
+ )).group_by(orders.c.region, orders.c.product)
+
+ # needs to render regional_sales first as top_regions
+ # refers to it
+ self.assert_compile(
+ s,
+ "WITH regional_sales AS (SELECT orders.region AS region, "
+ "sum(orders.amount) AS total_sales FROM orders "
+ "GROUP BY orders.region), "
+ "top_regions AS (SELECT "
+ "regional_sales.region AS region FROM regional_sales "
+ "WHERE regional_sales.total_sales > "
+ "(SELECT sum(regional_sales.total_sales) / :sum_1 AS "
+ "anon_1 FROM regional_sales)) "
+ "SELECT orders.region, orders.product, "
+ "sum(orders.quantity) AS product_units, "
+ "sum(orders.amount) AS product_sales "
+ "FROM orders WHERE orders.region "
+ "IN (SELECT top_regions.region FROM top_regions) "
+ "GROUP BY orders.region, orders.product"
+ )
+
+ def test_recursive(self):
+ parts = table('parts',
+ column('part'),
+ column('sub_part'),
+ column('quantity'),
+ )
+
+ included_parts = select([
+ parts.c.sub_part,
+ parts.c.part,
+ parts.c.quantity]).\
+ where(parts.c.part=='our part').\
+ cte(recursive=True)
+
+ incl_alias = included_parts.alias()
+ parts_alias = parts.alias()
+ included_parts = included_parts.union(
+ select([
+ parts_alias.c.part,
+ parts_alias.c.sub_part,
+ parts_alias.c.quantity]).\
+ where(parts_alias.c.part==incl_alias.c.sub_part)
+ )
+
+ s = select([
+ included_parts.c.sub_part,
+ func.sum(included_parts.c.quantity).label('total_quantity')]).\
+ select_from(included_parts.join(
+ parts,included_parts.c.part==parts.c.part)).\
+ group_by(included_parts.c.sub_part)
+ self.assert_compile(s,
+ "WITH RECURSIVE anon_1(sub_part, part, quantity) "
+ "AS (SELECT parts.sub_part AS sub_part, parts.part "
+ "AS part, parts.quantity AS quantity FROM parts "
+ "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
+ "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
+ "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
+ "WHERE parts_1.part = anon_2.sub_part) "
+ "SELECT anon_1.sub_part, "
+ "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
+ "JOIN parts ON anon_1.part = parts.part "
+ "GROUP BY anon_1.sub_part"
+ )
+
+ # quick check that the "WITH RECURSIVE" varies per
+ # dialect
+ self.assert_compile(s,
+ "WITH anon_1(sub_part, part, quantity) "
+ "AS (SELECT parts.sub_part AS sub_part, parts.part "
+ "AS part, parts.quantity AS quantity FROM parts "
+ "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
+ "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
+ "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
+ "WHERE parts_1.part = anon_2.sub_part) "
+ "SELECT anon_1.sub_part, "
+ "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
+ "JOIN parts ON anon_1.part = parts.part "
+ "GROUP BY anon_1.sub_part",
+ dialect=mssql.dialect()
+ )
+
+ def test_union(self):
+ orders = table('orders',
+ column('region'),
+ column('amount'),
+ )
+
+ regional_sales = select([
+ orders.c.region,
+ orders.c.amount
+ ]).cte("regional_sales")
+
+ s = select([regional_sales.c.region]).\
+ where(
+ regional_sales.c.amount > 500
+ )
+
+ self.assert_compile(s,
+ "WITH regional_sales AS "
+ "(SELECT orders.region AS region, "
+ "orders.amount AS amount FROM orders) "
+ "SELECT regional_sales.region "
+ "FROM regional_sales WHERE "
+ "regional_sales.amount > :amount_1")
+
+ s = s.union_all(
+ select([regional_sales.c.region]).\
+ where(
+ regional_sales.c.amount < 300
+ )
+ )
+ self.assert_compile(s,
+ "WITH regional_sales AS "
+ "(SELECT orders.region AS region, "
+ "orders.amount AS amount FROM orders) "
+ "SELECT regional_sales.region FROM regional_sales "
+ "WHERE regional_sales.amount > :amount_1 "
+ "UNION ALL SELECT regional_sales.region "
+ "FROM regional_sales WHERE "
+ "regional_sales.amount < :amount_2")
+
+ def test_reserved_quote(self):
+ orders = table('orders',
+ column('order'),
+ )
+ s = select([orders.c.order]).cte("regional_sales", recursive=True)
+ s = select([s.c.order])
+ self.assert_compile(s,
+ 'WITH RECURSIVE regional_sales("order") AS '
+ '(SELECT orders."order" AS "order" '
+ "FROM orders)"
+ ' SELECT regional_sales."order" '
+ "FROM regional_sales"
+ )
+
+ def test_positional_binds(self):
+ orders = table('orders',
+ column('order'),
+ )
+ s = select([orders.c.order, literal("x")]).cte("regional_sales")
+ s = select([s.c.order, literal("y")])
+ dialect = default.DefaultDialect()
+ dialect.positional = True
+ dialect.paramstyle = 'numeric'
+ self.assert_compile(s,
+ 'WITH regional_sales AS (SELECT orders."order" '
+ 'AS "order", :1 AS anon_2 FROM orders) SELECT '
+ 'regional_sales."order", :2 AS anon_1 FROM regional_sales',
+ checkpositional=('x', 'y'),
+ dialect=dialect
+ )
+
+ self.assert_compile(s.union(s),
+ 'WITH regional_sales AS (SELECT orders."order" '
+ 'AS "order", :1 AS anon_2 FROM orders) SELECT '
+ 'regional_sales."order", :2 AS anon_1 FROM regional_sales '
+ 'UNION SELECT regional_sales."order", :3 AS anon_1 '
+ 'FROM regional_sales',
+ checkpositional=('x', 'y', 'y'),
+ dialect=dialect
+ )
+
+ s = select([orders.c.order]).\
+ where(orders.c.order=='x').cte("regional_sales")
+ s = select([s.c.order]).where(s.c.order=="y")
+ self.assert_compile(s,
+ 'WITH regional_sales AS (SELECT orders."order" AS '
+ '"order" FROM orders WHERE orders."order" = :1) '
+ 'SELECT regional_sales."order" FROM regional_sales '
+ 'WHERE regional_sales."order" = :2',
+ checkpositional=('x', 'y'),
+ dialect=dialect
+ )