From 6e477b750eb432e91f933abae3aff3cb58b27362 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 12 Aug 2007 22:01:30 +0000 Subject: [PATCH] - got is_subquery() working in the case of compound selects, test for ms-sql --- lib/sqlalchemy/ansisql.py | 18 +++++++++++++----- test/dialect/mssql.py | 26 ++++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 8c7e6bb1cf..4d50b6a25a 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -225,10 +225,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): if stack: self.stack.append(stack) try: - x = self.traverse_single(obj, **kwargs) - if x is None: - raise "hi " + repr(obj) - return x + return self.traverse_single(obj, **kwargs) finally: if stack: self.stack.pop(-1) @@ -383,13 +380,24 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr) def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): + stack_entry = {'select':cs} + + if asfrom: + stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True + elif self.stack and self.stack[-1].get('select'): + stack_entry['is_subquery'] = True + self.stack.append(stack_entry) + text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ") group_by = self.process(cs._group_by_clause, asfrom=asfrom) if group_by: text += " GROUP BY " + group_by + text += self.order_by_clause(cs) text += (cs._limit or cs._offset) and self.limit_clause(cs) or "" + self.stack.pop(-1) + if asfrom and parens: return "(" + text + ")" else: @@ -595,7 +603,7 @@ class ANSICompiler(engine.Compiled, sql.ClauseVisitor): t = self.process(select._having) if t: text += " \nHAVING " + t - + text += self.order_by_clause(select) text += (select._limit or select._offset) and self.limit_clause(select) or "" text += self.for_update_clause(select) diff --git a/test/dialect/mssql.py b/test/dialect/mssql.py index ec27291977..309a65bee2 100755 --- a/test/dialect/mssql.py +++ b/test/dialect/mssql.py @@ -24,11 +24,33 @@ class CompileTest(AssertMixin): def test_update(self): t = table('sometable', column('somecolumn')) - self._test(t.update(t.c.somecolumn==7), "UPDATE sometable SET somecolumn=:somecolumn WHERE sometable.somecolumn = :sometable_somecolumn", somecolumn=10) + self._test(t.update(t.c.somecolumn==7), "UPDATE sometable SET somecolumn=:somecolumn WHERE sometable.somecolumn = :sometable_somecolumn", somecolumn=10) def test_count(self): t = table('sometable', column('somecolumn')) - self._test(t.count(), "SELECT count(sometable.somecolumn) AS tbl_row_count FROM sometable") + self._test(t.count(), "SELECT count(sometable.somecolumn) AS tbl_row_count FROM sometable") + def test_union(self): + t1 = table('t1', + column('col1'), + column('col2'), + column('col3'), + column('col4') + ) + t2 = table('t2', + column('col1'), + column('col2'), + column('col3'), + column('col4')) + + (s1, s2) = ( + select([t1.c.col3.label('col3'), t1.c.col4.label('col4')], t1.c.col2.in_("t1col2r1", "t1col2r2")), + select([t2.c.col3.label('col3'), t2.c.col4.label('col4')], t2.c.col2.in_("t2col2r2", "t2col2r3")) + ) + u = union(s1, s2, order_by=['col3', 'col4']) + self._test(u, "SELECT t1.col3 AS col3, t1.col4 AS col4 FROM t1 WHERE t1.col2 IN (:t1_col2, :t1_col2_1) UNION SELECT t2.col3 AS col3, t2.col4 AS col4 FROM t2 WHERE t2.col2 IN (:t2_col2, :t2_col2_1) ORDER BY col3, col4") + + self._test(u.alias('bar').select(), "SELECT bar.col3, bar.col4 FROM (SELECT t1.col3 AS col3, t1.col4 AS col4 FROM t1 WHERE t1.col2 IN (:t1_col2, :t1_col2_1) UNION SELECT t2.col3 AS col3, t2.col4 AS col4 FROM t2 WHERE t2.col2 IN (:t2_col2, :t2_col2_1)) AS bar") + if __name__ == "__main__": testbase.main() -- 2.47.3