]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 7 Aug 2005 00:42:55 +0000 (00:42 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 7 Aug 2005 00:42:55 +0000 (00:42 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/sql.py
test/select.py

index a5e5a5b19f2b5a47af96ce22cbbb2fdfbecb5292..884588668612784f87fe3bcb47fcfdf5ed486cc2 100644 (file)
@@ -116,14 +116,14 @@ class ANSICompiler(sql.Compiled):
         self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
         
     def visit_binary(self, binary):
-        if isinstance(binary.right, sql.Select):
-            s = self.get_str(binary.left) + " " + str(binary.operator) + " (" + self.get_str(binary.right) + ")"
-        else:
-            s = self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right)
+        result = self.get_str(binary.left)
+        if binary.operator is not None:
+            result += " " + binary.operator
+        result += " " + self.get_str(binary.right)
         if binary.parens:
-           self.strings[binary] = "(" + s + ")"
-        else:
-            self.strings[binary] = s
+            result = "(" + result + ")"
+        
+        self.strings[binary] = result
         
     def visit_bindparam(self, bindparam):
         self.binds[bindparam.shortname] = bindparam
@@ -181,7 +181,11 @@ class ANSICompiler(sql.Compiled):
         for tup in select._clauses:
             text += " " + tup[0] + " " + self.get_str(tup[1])
 
-        self.strings[select] = text
+        if getattr(select, 'issubquery', False):
+            self.strings[select] = "(" + text + ")"
+        else:
+            self.strings[select] = text
+
         self.froms[select] = "(" + text + ")"
 
 
index a8b75b875dc1585628476a51fe06c79d1e1feab2..00333d4c8b58c543ce555a049a53e83bdbb12017 100644 (file)
@@ -103,12 +103,8 @@ def or_(*clauses):
 
 def exists(*args, **params):
     s = select(*args, **params)
-    return BinaryClause(TextClause("EXISTS"), s, '')
+    return BinaryClause(TextClause("EXISTS"), s, None)
 
-def in_(*args, **params):
-    s = select(*args, **params)
-    return BinaryClause(TextClause("IN"), s, '')
-    
 def union(*selects, **params):
     return _compound_select('UNION', *selects, **params)
 
@@ -121,7 +117,7 @@ def subquery(alias, *args, **params):
 def bindparam(key, value = None):
     return BindParamClause(key, value)
 
-def textclause(text):
+def text(text):
     return TextClause(text)
 
 def sequence():
@@ -142,6 +138,9 @@ def _compound_select(keyword, *selects, **params):
 
     return s
 
+def _is_literal(element):
+    return not isinstance(element, ClauseElement) and not isinstance(element, schema.SchemaItem)
+
 class ClauseVisitor(schema.SchemaVisitor):
     """builds upon SchemaVisitor to define the visiting of SQL statement elements in 
     addition to Schema elements."""
@@ -327,14 +326,13 @@ class CompoundClause(ClauseElement):
         return CompoundClause(self.operator, *clauses)
         
     def append(self, clause):
-        if type(clause) == str:
-            clause = TextClause(clause)
+        if _is_literal(clause):
+            clause = TextClause(str(clause))
         elif isinstance(clause, CompoundClause):
             clause.parens = True
-            
         self.clauses.append(clause)
         self.fromobj += clause._get_from_objects()
-        
+
     def accept_visitor(self, visitor):
         for c in self.clauses:
             c.accept_visitor(visitor)
@@ -364,8 +362,6 @@ class BinaryClause(ClauseElement):
     def __init__(self, left, right, operator):
         self.left = left
         self.right = right
-        if isinstance(right, Select):
-            right._set_from_objects([])
         self.operator = operator
         self.parens = False
 
@@ -391,7 +387,6 @@ class Selectable(FromClause):
     c = property(lambda self: self.columns)
 
     def accept_visitor(self, visitor):
-        print repr(self.__class__)
         raise NotImplementedError()
     
     def select(self, whereclauses = None, **params):
@@ -414,19 +409,16 @@ class Join(Selectable):
 
     def hash_key(self):
         return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter))
-        
-    def add_join(self, join):
-        pass
-        
+
     def select(self, whereclauses = None, **params):
         return select([self.left, self.right], and_(self.onclause, whereclauses), **params)
-    
+
     def accept_visitor(self, visitor):
         self.left.accept_visitor(visitor)
         self.right.accept_visitor(visitor)
         self.onclause.accept_visitor(visitor)
         visitor.visit_join(self)
-            
+
     def _engine(self):
         return self.left._engine() or self.right._engine()
         
@@ -434,7 +426,7 @@ class Join(Selectable):
         m = {}
         for x in self.onclause._get_from_objects():
             m[x.id] = x
-        result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()] 
+        result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()]
         for x in result:
             m[x.id] = x
         result = m.values()
@@ -493,7 +485,7 @@ class ColumnSelectable(Selectable):
         return [self.column.table]
     
     def _compare(self, operator, obj):
-        if not isinstance(obj, ClauseElement) and not isinstance(obj, schema.Column):
+        if _is_literal(obj):
             if self.column.table.name is None:
                 obj = BindParamClause(self.name, obj, shortname = self.name)
             else:
@@ -516,12 +508,18 @@ class ColumnSelectable(Selectable):
     def __gt__(self, other):
         return self._compare('>', other)
 
-    def __ge__(self, other):    
+    def __ge__(self, other):
         return self._compare('>=', other)
-        
+
     def like(self, other):
         return self._compare('LIKE', other)
-    
+
+    def in_(self, *other):
+        if _is_literal(other[0]):
+            return self._compare('IN', CompoundClause(',', other))
+        else:
+            return self._compare('IN', union(*other))
+
     def startswith(self, other):
         return self._compare('LIKE', str(other) + "%")
     
@@ -578,6 +576,10 @@ class Select(Selectable):
         self.whereclause = whereclause
         self.engine = engine
         
+        # indicates if this select statement is a subquery inside of a WHERE clause
+        # note this is different from a subquery inside the FROM list
+        self.issubquery = False
+        
         self._text = None
         self._raw_columns = []
         self._clauses = []
@@ -598,14 +600,14 @@ class Select(Selectable):
             self.order_by(*order_by)
 
     def append_column(self, column):
-        if type(column) == str:
-            column = ColumnClause(column, self)
+        if _is_literal(column):
+            column = ColumnClause(str(column), self)
 
         self._raw_columns.append(column)
 
         for f in column._get_from_objects():
             self.froms.setdefault(f.id, f)
-                
+
         for co in column.columns:
             if self.use_labels:
                 co._make_proxy(self, name = co.label)
@@ -615,18 +617,21 @@ class Select(Selectable):
     def set_whereclause(self, whereclause):
         if type(whereclause) == str:
             self.whereclause = TextClause(whereclause)
-            
-        for f in self.whereclause._get_from_objects():
-            self.froms.setdefault(f.id, f)
 
         class CorrelatedVisitor(ClauseVisitor):
             def visit_select(s, select):
                 for f in self.froms.keys():
                     select.clear_from(f)
+                    select.issubquery = True
         self.whereclause.accept_visitor(CorrelatedVisitor())
+
+        for f in self.whereclause._get_from_objects():
+            self.froms.setdefault(f.id, f)
+
    
     def clear_from(self, id):
         self.append_from(FromClause(from_name = None, from_key = id))
+        
     def append_from(self, fromclause):
         if type(fromclause) == str:
             fromclause = FromClause(from_name = fromclause)
@@ -658,8 +663,6 @@ class Select(Selectable):
         return engine.compile(self, bindparams)
 
     def accept_visitor(self, visitor):
-#        for c in self._raw_columns:
-#            c.accept_visitor(visitor)
         for f in self.froms.values():
             f.accept_visitor(visitor)
         if self.whereclause is not None:
@@ -689,11 +692,11 @@ class Select(Selectable):
             
         return None
 
-    def _set_from_objects(self, obj):
-        self._from_obj = obj
-        
     def _get_from_objects(self):
-        return getattr(self, '_from_obj', [self])
+        if self.issubquery:
+            return []
+        else:
+            return [self]
 
 
 class UpdateBase(ClauseElement):
@@ -709,8 +712,8 @@ class UpdateBase(ClauseElement):
         for key in parameters.keys():
             value = parameters[key]
             if isinstance(value, Select):
-                value.append_from(FromClause(from_key=self.table.id))
-            elif not isinstance(value, schema.Column) and not isinstance(value, ClauseElement):
+                value.clear_from(self.table.id)
+            elif _is_literal(value):
                 try:
                     col = self.table.c[key]
                     parameters[key] = bindparam(col.name, value)
@@ -747,7 +750,7 @@ class UpdateBase(ClauseElement):
         for c in self.table.columns:
             if d.has_key(c):
                 value = d[c]
-                if not isinstance(value, schema.Column) and not isinstance(value, ClauseElement):
+                if _is_literal(value):
                     value = bindparam(c.name, value)
                 values.append((c, value))
         return values
index 2d3f23eb6199370edd8db8065b06a844243154f0..b43636333f8637b5da9b0e0749755d35c603c080 100644 (file)
@@ -30,24 +30,30 @@ table3 = Table(
     Column('otherstuff', 5),
 )
 
-class SelectTest(PersistTest):
+class SQLTest(PersistTest):
+    def runtest(self, clause, result, engine = None, params = None):
+        c = clause.compile(engine, params)
+        print "\n" + str(c) + repr(c.get_params())
+        cc = re.sub(r'\n', '', str(c))
+        self.assert_(cc == result)
+
+class SelectTest(SQLTest):
 
     def testtext(self):
         self.runtest(
-            textclause("select * from foo where lala = bar") ,
+            text("select * from foo where lala = bar") ,
             "select * from foo where lala = bar",
             engine = db
         )
-    
+
     def testtableselect(self):
         self.runtest(table.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable")
 
         self.runtest(select([table, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \
 myothertable.othername FROM mytable, myothertable")
-        
+
     def testsubquery(self):
-    
-        s = select([table], table.c.name == 'jack')    
+        s = select([table], table.c.name == 'jack')
         self.runtest(
             select(
                 [s],
@@ -269,10 +275,7 @@ mytable.name = :mytable_name AND mytable.myid = :mytable_myid AND \
 myothertable.othername != :myothertable_othername AND EXISTS (select yay from foo where boo = lar)",
             engine = oracle.engine(use_ansi = False))
 
-
-
     def testbindparam(self):
-        #return
         self.runtest(select(
                     [table, table2],
                     and_(table.c.id == table2.c.id,
@@ -283,7 +286,30 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo
 FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable.name = :mytablename"
                 )
 
+    def testcorrelatedsubquery(self):
+        self.runtest(
+            select([table], table.c.id == select([table2.c.id], table.c.name == table2.c.name)),
+            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid FROM myothertable WHERE mytable.name = myothertable.othername)"
+        )
+
+        self.runtest(
+            select([table], exists([1], table2.c.id == table.c.id)),
+            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)"
+        )
+
+        s = subquery('sq2', [table], exists([1], table2.c.id == table.c.id))
+        self.runtest(
+            select([s, table])
+        ,"SELECT sq2.myid, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)) sq2, mytable")
+
+    def testin(self):
+        self.runtest(select([table], table.c.id.in_(1, 2, 3)),
+        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (1, 2, 3)")
 
+        self.runtest(select([table], table.c.id.in_(select([table2.c.id]))),
+        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid FROM myothertable)")
+    
+class CRUDTest(SQLTest):
     def testinsert(self):
         # generic insert, will create bind params for all columns
         self.runtest(insert(table), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)")
@@ -315,7 +341,7 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
 
     def testcorrelatedupdate(self):
         # test against a straight text subquery
-        u = update(table, values = {table.c.name : TextClause("select name from mytable where id=mytable.id")})
+        u = update(table, values = {table.c.name : text("select name from mytable where id=mytable.id")})
         self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)")
         
         # test against a regular constructed subquery
@@ -326,12 +352,6 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
     def testdelete(self):
         self.runtest(delete(table, table.c.id == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")
         
-        
-    def runtest(self, clause, result, engine = None, params = None):
-        c = clause.compile(engine, params)
-        print "\n" + str(c) + repr(c.get_params())
-        cc = re.sub(r'\n', '', str(c))
-        self.assert_(cc == result)
 
 if __name__ == "__main__":
     unittest.main()