def exists(*args, **params):
s = select(*args, **params)
- return BinaryClause(TextClause("EXISTS"), s, '')
+ return BinaryClause(TextClause("EXISTS"), s, None)
-def in_(*args, **params):
- s = select(*args, **params)
- return BinaryClause(TextClause("IN"), s, '')
-
def union(*selects, **params):
return _compound_select('UNION', *selects, **params)
def bindparam(key, value = None):
return BindParamClause(key, value)
-def textclause(text):
+def text(text):
return TextClause(text)
def sequence():
return s
+def _is_literal(element):
+ return not isinstance(element, ClauseElement) and not isinstance(element, schema.SchemaItem)
+
class ClauseVisitor(schema.SchemaVisitor):
"""builds upon SchemaVisitor to define the visiting of SQL statement elements in
addition to Schema elements."""
return CompoundClause(self.operator, *clauses)
def append(self, clause):
- if type(clause) == str:
- clause = TextClause(clause)
+ if _is_literal(clause):
+ clause = TextClause(str(clause))
elif isinstance(clause, CompoundClause):
clause.parens = True
-
self.clauses.append(clause)
self.fromobj += clause._get_from_objects()
-
+
def accept_visitor(self, visitor):
for c in self.clauses:
c.accept_visitor(visitor)
def __init__(self, left, right, operator):
self.left = left
self.right = right
- if isinstance(right, Select):
- right._set_from_objects([])
self.operator = operator
self.parens = False
c = property(lambda self: self.columns)
def accept_visitor(self, visitor):
- print repr(self.__class__)
raise NotImplementedError()
def select(self, whereclauses = None, **params):
def hash_key(self):
return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter))
-
- def add_join(self, join):
- pass
-
+
def select(self, whereclauses = None, **params):
return select([self.left, self.right], and_(self.onclause, whereclauses), **params)
-
+
def accept_visitor(self, visitor):
self.left.accept_visitor(visitor)
self.right.accept_visitor(visitor)
self.onclause.accept_visitor(visitor)
visitor.visit_join(self)
-
+
def _engine(self):
return self.left._engine() or self.right._engine()
m = {}
for x in self.onclause._get_from_objects():
m[x.id] = x
- result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()]
+ result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()]
for x in result:
m[x.id] = x
result = m.values()
return [self.column.table]
def _compare(self, operator, obj):
- if not isinstance(obj, ClauseElement) and not isinstance(obj, schema.Column):
+ if _is_literal(obj):
if self.column.table.name is None:
obj = BindParamClause(self.name, obj, shortname = self.name)
else:
def __gt__(self, other):
return self._compare('>', other)
- def __ge__(self, other):
+ def __ge__(self, other):
return self._compare('>=', other)
-
+
def like(self, other):
return self._compare('LIKE', other)
-
+
+ def in_(self, *other):
+ if _is_literal(other[0]):
+ return self._compare('IN', CompoundClause(',', other))
+ else:
+ return self._compare('IN', union(*other))
+
def startswith(self, other):
return self._compare('LIKE', str(other) + "%")
self.whereclause = whereclause
self.engine = engine
+ # indicates if this select statement is a subquery inside of a WHERE clause
+ # note this is different from a subquery inside the FROM list
+ self.issubquery = False
+
self._text = None
self._raw_columns = []
self._clauses = []
self.order_by(*order_by)
def append_column(self, column):
- if type(column) == str:
- column = ColumnClause(column, self)
+ if _is_literal(column):
+ column = ColumnClause(str(column), self)
self._raw_columns.append(column)
for f in column._get_from_objects():
self.froms.setdefault(f.id, f)
-
+
for co in column.columns:
if self.use_labels:
co._make_proxy(self, name = co.label)
def set_whereclause(self, whereclause):
if type(whereclause) == str:
self.whereclause = TextClause(whereclause)
-
- for f in self.whereclause._get_from_objects():
- self.froms.setdefault(f.id, f)
class CorrelatedVisitor(ClauseVisitor):
def visit_select(s, select):
for f in self.froms.keys():
select.clear_from(f)
+ select.issubquery = True
self.whereclause.accept_visitor(CorrelatedVisitor())
+
+ for f in self.whereclause._get_from_objects():
+ self.froms.setdefault(f.id, f)
+
def clear_from(self, id):
self.append_from(FromClause(from_name = None, from_key = id))
+
def append_from(self, fromclause):
if type(fromclause) == str:
fromclause = FromClause(from_name = fromclause)
return engine.compile(self, bindparams)
def accept_visitor(self, visitor):
-# for c in self._raw_columns:
-# c.accept_visitor(visitor)
for f in self.froms.values():
f.accept_visitor(visitor)
if self.whereclause is not None:
return None
- def _set_from_objects(self, obj):
- self._from_obj = obj
-
def _get_from_objects(self):
- return getattr(self, '_from_obj', [self])
+ if self.issubquery:
+ return []
+ else:
+ return [self]
class UpdateBase(ClauseElement):
for key in parameters.keys():
value = parameters[key]
if isinstance(value, Select):
- value.append_from(FromClause(from_key=self.table.id))
- elif not isinstance(value, schema.Column) and not isinstance(value, ClauseElement):
+ value.clear_from(self.table.id)
+ elif _is_literal(value):
try:
col = self.table.c[key]
parameters[key] = bindparam(col.name, value)
for c in self.table.columns:
if d.has_key(c):
value = d[c]
- if not isinstance(value, schema.Column) and not isinstance(value, ClauseElement):
+ if _is_literal(value):
value = bindparam(c.name, value)
values.append((c, value))
return values
Column('otherstuff', 5),
)
-class SelectTest(PersistTest):
+class SQLTest(PersistTest):
+ def runtest(self, clause, result, engine = None, params = None):
+ c = clause.compile(engine, params)
+ print "\n" + str(c) + repr(c.get_params())
+ cc = re.sub(r'\n', '', str(c))
+ self.assert_(cc == result)
+
+class SelectTest(SQLTest):
def testtext(self):
self.runtest(
- textclause("select * from foo where lala = bar") ,
+ text("select * from foo where lala = bar") ,
"select * from foo where lala = bar",
engine = db
)
-
+
def testtableselect(self):
self.runtest(table.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable")
self.runtest(select([table, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \
myothertable.othername FROM mytable, myothertable")
-
+
def testsubquery(self):
-
- s = select([table], table.c.name == 'jack')
+ s = select([table], table.c.name == 'jack')
self.runtest(
select(
[s],
myothertable.othername != :myothertable_othername AND EXISTS (select yay from foo where boo = lar)",
engine = oracle.engine(use_ansi = False))
-
-
def testbindparam(self):
- #return
self.runtest(select(
[table, table2],
and_(table.c.id == table2.c.id,
FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable.name = :mytablename"
)
+ def testcorrelatedsubquery(self):
+ self.runtest(
+ select([table], table.c.id == select([table2.c.id], table.c.name == table2.c.name)),
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid FROM myothertable WHERE mytable.name = myothertable.othername)"
+ )
+
+ self.runtest(
+ select([table], exists([1], table2.c.id == table.c.id)),
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)"
+ )
+
+ s = subquery('sq2', [table], exists([1], table2.c.id == table.c.id))
+ self.runtest(
+ select([s, table])
+ ,"SELECT sq2.myid, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)) sq2, mytable")
+
+ def testin(self):
+ self.runtest(select([table], table.c.id.in_(1, 2, 3)),
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (1, 2, 3)")
+ self.runtest(select([table], table.c.id.in_(select([table2.c.id]))),
+ "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid FROM myothertable)")
+
+class CRUDTest(SQLTest):
def testinsert(self):
# generic insert, will create bind params for all columns
self.runtest(insert(table), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)")
def testcorrelatedupdate(self):
# test against a straight text subquery
- u = update(table, values = {table.c.name : TextClause("select name from mytable where id=mytable.id")})
+ u = update(table, values = {table.c.name : text("select name from mytable where id=mytable.id")})
self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)")
# test against a regular constructed subquery
def testdelete(self):
self.runtest(delete(table, table.c.id == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")
-
- def runtest(self, clause, result, engine = None, params = None):
- c = clause.compile(engine, params)
- print "\n" + str(c) + repr(c.get_params())
- cc = re.sub(r'\n', '', str(c))
- self.assert_(cc == result)
if __name__ == "__main__":
unittest.main()