From 2eaaa50b465197d497e0b437d37c339c01b4f3c8 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 7 Aug 2005 00:42:55 +0000 Subject: [PATCH] --- lib/sqlalchemy/ansisql.py | 20 ++++++---- lib/sqlalchemy/sql.py | 81 ++++++++++++++++++++------------------- test/select.py | 52 +++++++++++++++++-------- 3 files changed, 90 insertions(+), 63 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index a5e5a5b19f..8845886686 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -116,14 +116,14 @@ class ANSICompiler(sql.Compiled): self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ') def visit_binary(self, binary): - if isinstance(binary.right, sql.Select): - s = self.get_str(binary.left) + " " + str(binary.operator) + " (" + self.get_str(binary.right) + ")" - else: - s = self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) + result = self.get_str(binary.left) + if binary.operator is not None: + result += " " + binary.operator + result += " " + self.get_str(binary.right) if binary.parens: - self.strings[binary] = "(" + s + ")" - else: - self.strings[binary] = s + result = "(" + result + ")" + + self.strings[binary] = result def visit_bindparam(self, bindparam): self.binds[bindparam.shortname] = bindparam @@ -181,7 +181,11 @@ class ANSICompiler(sql.Compiled): for tup in select._clauses: text += " " + tup[0] + " " + self.get_str(tup[1]) - self.strings[select] = text + if getattr(select, 'issubquery', False): + self.strings[select] = "(" + text + ")" + else: + self.strings[select] = text + self.froms[select] = "(" + text + ")" diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index a8b75b875d..00333d4c8b 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -103,12 +103,8 @@ def or_(*clauses): def exists(*args, **params): s = select(*args, **params) - return BinaryClause(TextClause("EXISTS"), s, '') + return BinaryClause(TextClause("EXISTS"), s, None) -def in_(*args, **params): - s = select(*args, **params) - return BinaryClause(TextClause("IN"), s, '') - def union(*selects, **params): return _compound_select('UNION', *selects, **params) @@ -121,7 +117,7 @@ def subquery(alias, *args, **params): def bindparam(key, value = None): return BindParamClause(key, value) -def textclause(text): +def text(text): return TextClause(text) def sequence(): @@ -142,6 +138,9 @@ def _compound_select(keyword, *selects, **params): return s +def _is_literal(element): + return not isinstance(element, ClauseElement) and not isinstance(element, schema.SchemaItem) + class ClauseVisitor(schema.SchemaVisitor): """builds upon SchemaVisitor to define the visiting of SQL statement elements in addition to Schema elements.""" @@ -327,14 +326,13 @@ class CompoundClause(ClauseElement): return CompoundClause(self.operator, *clauses) def append(self, clause): - if type(clause) == str: - clause = TextClause(clause) + if _is_literal(clause): + clause = TextClause(str(clause)) elif isinstance(clause, CompoundClause): clause.parens = True - self.clauses.append(clause) self.fromobj += clause._get_from_objects() - + def accept_visitor(self, visitor): for c in self.clauses: c.accept_visitor(visitor) @@ -364,8 +362,6 @@ class BinaryClause(ClauseElement): def __init__(self, left, right, operator): self.left = left self.right = right - if isinstance(right, Select): - right._set_from_objects([]) self.operator = operator self.parens = False @@ -391,7 +387,6 @@ class Selectable(FromClause): c = property(lambda self: self.columns) def accept_visitor(self, visitor): - print repr(self.__class__) raise NotImplementedError() def select(self, whereclauses = None, **params): @@ -414,19 +409,16 @@ class Join(Selectable): def hash_key(self): return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter)) - - def add_join(self, join): - pass - + def select(self, whereclauses = None, **params): return select([self.left, self.right], and_(self.onclause, whereclauses), **params) - + def accept_visitor(self, visitor): self.left.accept_visitor(visitor) self.right.accept_visitor(visitor) self.onclause.accept_visitor(visitor) visitor.visit_join(self) - + def _engine(self): return self.left._engine() or self.right._engine() @@ -434,7 +426,7 @@ class Join(Selectable): m = {} for x in self.onclause._get_from_objects(): m[x.id] = x - result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()] + result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()] for x in result: m[x.id] = x result = m.values() @@ -493,7 +485,7 @@ class ColumnSelectable(Selectable): return [self.column.table] def _compare(self, operator, obj): - if not isinstance(obj, ClauseElement) and not isinstance(obj, schema.Column): + if _is_literal(obj): if self.column.table.name is None: obj = BindParamClause(self.name, obj, shortname = self.name) else: @@ -516,12 +508,18 @@ class ColumnSelectable(Selectable): def __gt__(self, other): return self._compare('>', other) - def __ge__(self, other): + def __ge__(self, other): return self._compare('>=', other) - + def like(self, other): return self._compare('LIKE', other) - + + def in_(self, *other): + if _is_literal(other[0]): + return self._compare('IN', CompoundClause(',', other)) + else: + return self._compare('IN', union(*other)) + def startswith(self, other): return self._compare('LIKE', str(other) + "%") @@ -578,6 +576,10 @@ class Select(Selectable): self.whereclause = whereclause self.engine = engine + # indicates if this select statement is a subquery inside of a WHERE clause + # note this is different from a subquery inside the FROM list + self.issubquery = False + self._text = None self._raw_columns = [] self._clauses = [] @@ -598,14 +600,14 @@ class Select(Selectable): self.order_by(*order_by) def append_column(self, column): - if type(column) == str: - column = ColumnClause(column, self) + if _is_literal(column): + column = ColumnClause(str(column), self) self._raw_columns.append(column) for f in column._get_from_objects(): self.froms.setdefault(f.id, f) - + for co in column.columns: if self.use_labels: co._make_proxy(self, name = co.label) @@ -615,18 +617,21 @@ class Select(Selectable): def set_whereclause(self, whereclause): if type(whereclause) == str: self.whereclause = TextClause(whereclause) - - for f in self.whereclause._get_from_objects(): - self.froms.setdefault(f.id, f) class CorrelatedVisitor(ClauseVisitor): def visit_select(s, select): for f in self.froms.keys(): select.clear_from(f) + select.issubquery = True self.whereclause.accept_visitor(CorrelatedVisitor()) + + for f in self.whereclause._get_from_objects(): + self.froms.setdefault(f.id, f) + def clear_from(self, id): self.append_from(FromClause(from_name = None, from_key = id)) + def append_from(self, fromclause): if type(fromclause) == str: fromclause = FromClause(from_name = fromclause) @@ -658,8 +663,6 @@ class Select(Selectable): return engine.compile(self, bindparams) def accept_visitor(self, visitor): -# for c in self._raw_columns: -# c.accept_visitor(visitor) for f in self.froms.values(): f.accept_visitor(visitor) if self.whereclause is not None: @@ -689,11 +692,11 @@ class Select(Selectable): return None - def _set_from_objects(self, obj): - self._from_obj = obj - def _get_from_objects(self): - return getattr(self, '_from_obj', [self]) + if self.issubquery: + return [] + else: + return [self] class UpdateBase(ClauseElement): @@ -709,8 +712,8 @@ class UpdateBase(ClauseElement): for key in parameters.keys(): value = parameters[key] if isinstance(value, Select): - value.append_from(FromClause(from_key=self.table.id)) - elif not isinstance(value, schema.Column) and not isinstance(value, ClauseElement): + value.clear_from(self.table.id) + elif _is_literal(value): try: col = self.table.c[key] parameters[key] = bindparam(col.name, value) @@ -747,7 +750,7 @@ class UpdateBase(ClauseElement): for c in self.table.columns: if d.has_key(c): value = d[c] - if not isinstance(value, schema.Column) and not isinstance(value, ClauseElement): + if _is_literal(value): value = bindparam(c.name, value) values.append((c, value)) return values diff --git a/test/select.py b/test/select.py index 2d3f23eb61..b43636333f 100644 --- a/test/select.py +++ b/test/select.py @@ -30,24 +30,30 @@ table3 = Table( Column('otherstuff', 5), ) -class SelectTest(PersistTest): +class SQLTest(PersistTest): + def runtest(self, clause, result, engine = None, params = None): + c = clause.compile(engine, params) + print "\n" + str(c) + repr(c.get_params()) + cc = re.sub(r'\n', '', str(c)) + self.assert_(cc == result) + +class SelectTest(SQLTest): def testtext(self): self.runtest( - textclause("select * from foo where lala = bar") , + text("select * from foo where lala = bar") , "select * from foo where lala = bar", engine = db ) - + def testtableselect(self): self.runtest(table.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable") self.runtest(select([table, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \ myothertable.othername FROM mytable, myothertable") - + def testsubquery(self): - - s = select([table], table.c.name == 'jack') + s = select([table], table.c.name == 'jack') self.runtest( select( [s], @@ -269,10 +275,7 @@ mytable.name = :mytable_name AND mytable.myid = :mytable_myid AND \ myothertable.othername != :myothertable_othername AND EXISTS (select yay from foo where boo = lar)", engine = oracle.engine(use_ansi = False)) - - def testbindparam(self): - #return self.runtest(select( [table, table2], and_(table.c.id == table2.c.id, @@ -283,7 +286,30 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable.name = :mytablename" ) + def testcorrelatedsubquery(self): + self.runtest( + select([table], table.c.id == select([table2.c.id], table.c.name == table2.c.name)), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid FROM myothertable WHERE mytable.name = myothertable.othername)" + ) + + self.runtest( + select([table], exists([1], table2.c.id == table.c.id)), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)" + ) + + s = subquery('sq2', [table], exists([1], table2.c.id == table.c.id)) + self.runtest( + select([s, table]) + ,"SELECT sq2.myid, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)) sq2, mytable") + + def testin(self): + self.runtest(select([table], table.c.id.in_(1, 2, 3)), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (1, 2, 3)") + self.runtest(select([table], table.c.id.in_(select([table2.c.id]))), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid FROM myothertable)") + +class CRUDTest(SQLTest): def testinsert(self): # generic insert, will create bind params for all columns self.runtest(insert(table), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)") @@ -315,7 +341,7 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable def testcorrelatedupdate(self): # test against a straight text subquery - u = update(table, values = {table.c.name : TextClause("select name from mytable where id=mytable.id")}) + u = update(table, values = {table.c.name : text("select name from mytable where id=mytable.id")}) self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)") # test against a regular constructed subquery @@ -326,12 +352,6 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable def testdelete(self): self.runtest(delete(table, table.c.id == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid") - - def runtest(self, clause, result, engine = None, params = None): - c = clause.compile(engine, params) - print "\n" + str(c) + repr(c.get_params()) - cc = re.sub(r'\n', '', str(c)) - self.assert_(cc == result) if __name__ == "__main__": unittest.main() -- 2.47.2