]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
some clenaup on the "correlation" API on the _Select class
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Nov 2006 22:09:27 +0000 (22:09 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Nov 2006 22:09:27 +0000 (22:09 +0000)
lib/sqlalchemy/sql.py

index b5faf37fef75cff6d7ac30f434ae3b0f7f48610a..8605d5c0c5dd5fbb3be8fb867fa267327369ae24 100644 (file)
@@ -1381,7 +1381,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         self.keyword = keyword
         self.use_labels = kwargs.pop('use_labels', False)
         self.parens = kwargs.pop('parens', False)
-        self.correlate = kwargs.pop('correlate', False)
+        self.should_correlate = kwargs.pop('correlate', False)
         self.for_update = kwargs.pop('for_update', False)
         self.nowait = kwargs.pop('nowait', False)
         self.limit = kwargs.pop('limit', None)
@@ -1458,9 +1458,10 @@ class Select(_SelectBaseMixin, FromClause):
         # act like a single scalar column
         self.is_scalar = scalar
 
-        # indicates if this select statement, as a subquery, should correlate
-        # its FROM clause to that of an enclosing select statement
-        self.correlate = correlate
+        # indicates if this select statement, as a subquery, should automatically correlate
+        # its FROM clause to that of an enclosing select statement.
+        # note that the "correlate" method can be used to explicitly add a value to be correlated.
+        self.should_correlate = correlate
         
         # indicates if this select statement is a subquery inside another query
         self.is_subquery = False
@@ -1471,7 +1472,7 @@ class Select(_SelectBaseMixin, FromClause):
         
         self.distinct = distinct
         self._raw_columns = []
-        self._correlated = None
+        self.__correlated = {}
         self.__correlator = Select._CorrelatedVisitor(self, False)
         self.__wherecorrelator = Select._CorrelatedVisitor(self, True)
 
@@ -1508,10 +1509,9 @@ class Select(_SelectBaseMixin, FromClause):
             select.is_where = self.is_where
             select.is_subquery = True
             select.parens = True
-            if not select.correlate:
+            if not select.should_correlate:
                 return
-            if getattr(select, '_correlated', None) is None:
-                select._correlated = self.select._Select__froms
+            [select.correlate(x) for x in self.select._Select__froms]
                 
     def append_column(self, column):
         if _is_literal(column):
@@ -1548,9 +1548,13 @@ class Select(_SelectBaseMixin, FromClause):
             setattr(self, attribute, and_(getattr(self, attribute), condition))
         else:
             setattr(self, attribute, condition)
-
-    def clear_from(self, from_obj):
-        self.__froms[from_obj] = FromClause()
+    
+    def correlate(self, from_obj):
+        """given a FROM object, correlate this SELECT statement to it.  
+        
+        this basically means the given from object will not come out in this select statement's FROM 
+        clause when printed."""
+        self.__correlated[from_obj] = from_obj
         
     def append_from(self, fromclause):
         if type(fromclause) == str:
@@ -1569,7 +1573,7 @@ class Select(_SelectBaseMixin, FromClause):
         else:
             return None
     def _get_froms(self):
-        return [f for f in self.__froms.values() if f is not self and (self._correlated is None or not self._correlated.has_key(f))]
+        return [f for f in self.__froms.values() if f is not self and (f not in self.__correlated)]
     froms = property(lambda s: s._get_froms(), doc="""a list containing all elements of the FROM clause""")
 
     def accept_visitor(self, visitor):
@@ -1635,7 +1639,7 @@ class _UpdateBase(ClauseElement):
         for key in parameters.keys():
             value = parameters[key]
             if isinstance(value, Select):
-                value.clear_from(self.table)
+                value.correlate(self.table)
             elif _is_literal(value):
                 if _is_literal(key):
                     col = self.table.c[key]