]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 5 Nov 2005 03:05:33 +0000 (03:05 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 5 Nov 2005 03:05:33 +0000 (03:05 +0000)
lib/sqlalchemy/sql.py

index 736dc2f95acbb111f9257e02fc28993be8202fd7..506fb514fd3463e18f4a06524e3b3d907ad023a9 100644 (file)
@@ -727,14 +727,18 @@ class Select(Selectable):
         self._engine = engine
         self.rowid_column = None
         
-        # indicates if this select statement is a subquery inside of a WHERE clause
-        # note this is different from a subquery inside the FROM list
+        # indicates if this select statement is a subquery inside another query
         self.issubquery = False
+        # indicates if this select statement is a subquery as a criterion
+        # inside of a WHERE clause
+        self.is_where = False
         
         self._text = None
         self._raw_columns = []
         self._clauses = []
         self._correlated = None
+        self._correlator = Select.CorrelatedVisitor(self, False)
+        self._wherecorrelator = Select.CorrelatedVisitor(self, True)
         
         for c in columns:
             self.append_column(c)
@@ -752,14 +756,17 @@ class Select(Selectable):
             self.order_by(*order_by)
 
     class CorrelatedVisitor(ClauseVisitor):
-        def __init__(self, select):
+        """visits a clause, locates any Select clauses, and tells them that they should correlate their FROM list to that of their parent."""
+        def __init__(self, select, is_where):
             self.select = select
+            self.is_where = is_where
         def visit_select(self, select):
-            print "visit"
             if select is self.select:
                 return
+            select.is_where = self.is_where
             select.issubquery = True
-            select._correlated = self.select._froms
+            if select._correlated is None:
+                select._correlated = self.select._froms
 
     def append_column(self, column):
         if _is_literal(column):
@@ -767,11 +774,8 @@ class Select(Selectable):
 
         self._raw_columns.append(column)
 
-
         for f in column._get_from_objects():
-# TODO
-#            visitor = Select.CorrelatedVisitor(self)
-#            f.accept_visitor(visitor)
+            f.accept_visitor(self._correlator)
             if self.rowid_column is None and hasattr(f, 'rowid_column'):
                 self.rowid_column = f.rowid_column._make_proxy(self)
         column._process_from_dict(self._froms, False)
@@ -786,8 +790,7 @@ class Select(Selectable):
         if type(whereclause) == str:
             whereclause = TextClause(whereclause)
 
-        visitor = Select.CorrelatedVisitor(self)
-        whereclause.accept_visitor(visitor)
+        whereclause.accept_visitor(self._wherecorrelator)
         whereclause._process_from_dict(self._froms, False)
         
         if self.whereclause is not None:
@@ -802,9 +805,7 @@ class Select(Selectable):
         if type(fromclause) == str:
             fromclause = FromClause(from_name = fromclause)
 
-#        visitor = Select.CorrelatedVisitor(self)
-#        fromclause.accept_visitor(visitor)
-
+        fromclause.accept_visitor(self._correlator)
         fromclause._process_from_dict(self._froms, True)
         
     def append_clause(self, keyword, clause):
@@ -816,6 +817,7 @@ class Select(Selectable):
     def compile(self, engine = None, bindparams = None):
         if engine is None:
             engine = self.engine
+        print "ok, and engine is " + repr(self.engine)
         if engine is None:
             raise "no engine supplied, and no engine could be located within the clauses!"
 
@@ -849,12 +851,16 @@ class Select(Selectable):
         """tries to return a SQLEngine, either explicitly set in this object, or searched
         within the from clauses for one"""
         
-        if self._engine:
+        if self._engine is not None:
             return self._engine
         
-        for f in self.froms:
+        for f in self._froms.values():
+            print repr(self) + " looking in " + repr(f)
             e = f.engine
-            if e is not None:
+            print " and its " + repr(e)
+            if e is not None: 
+                self._engine = e
+                print "returning it !"
                 return e
             
         return None
@@ -862,7 +868,7 @@ class Select(Selectable):
     engine = property(lambda s: s._find_engine())
     
     def _get_from_objects(self):
-        if self.issubquery:
+        if self.is_where:
             return []
         else:
             return [self]