]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
made order_by/group_by construction a little more simplisitc
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Apr 2006 05:00:53 +0000 (05:00 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Apr 2006 05:00:53 +0000 (05:00 +0000)
fix to mapper extension
CompoundSelect can export all columns now, not sure if theres any advantage there

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/sql.py

index adfa36918bdbca780653e0758bca86226a5b9b01..dfc15a3832ff3db9ffe39bdcc4227c8d422540da 100644 (file)
@@ -233,8 +233,12 @@ class ANSICompiler(sql.Compiled):
         
     def visit_compound_select(self, cs):
         text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
-        for tup in cs.clauses:
-            text += " " + tup[0] + " " + self.get_str(tup[1])
+        group_by = self.get_str(cs.group_by_clause)
+        if group_by:
+            text += " GROUP BY " + group_by
+        order_by = self.get_str(cs.order_by_clause)
+        if order_by:
+            text += " ORDER BY " + order_by
         if cs.parens:
             self.strings[cs] = "(" + text + ")"
         else:
@@ -361,10 +365,13 @@ class ANSICompiler(sql.Compiled):
             if t:
                 text += " \nWHERE " + t
 
-        for tup in select.clauses:
-            ss = self.get_str(tup[1])
-            if ss:
-                text += " " + tup[0] + " " + ss
+        group_by = self.get_str(select.group_by_clause)
+        if group_by:
+            text += " GROUP BY " + group_by
+
+        order_by = self.get_str(select.order_by_clause)
+        if order_by:
+            text += " ORDER BY " + order_by
 
         if select.having is not None:
             t = self.get_str(select.having)
index 4736a2dda1f62f10b6cc17e85cab1c806752816a..c673cb9617e9a2ae78d139caf94d08699a4ad87c 100644 (file)
@@ -276,9 +276,8 @@ class OracleCompiler(ansisql.ANSICompiler):
             return
         if select.limit is not None or select.offset is not None:
             select._oracle_visit = True
-            if hasattr(select, "order_by_clause"):
-                orderby = self.strings[select.order_by_clause]
-            else:
+            orderby = self.strings[select.order_by_clause]
+            if not orderby:
                 # to use ROW_NUMBER(), an ORDER BY is required.  so here we dig in
                 # as best we can to find some column we can order by
                 # TODO: try to get "oid_column" to be used here
index a9f7df63c6f418a7629f8a849d852a449cd342d4..d6b1fb015e578c905f319b2476d301df1f4561f8 100644 (file)
@@ -1062,7 +1062,7 @@ class MapperExtension(object):
         if self.next is None:
             return EXT_PASS
         else:
-            return self.next.populate_instance(row, imap, result, instance, isnew)
+            return self.next.populate_instance(mapper, instance, row, identitykey, imap, isnew)
     def before_insert(self, mapper, instance):
         """called before an object instance is INSERTed into its table.
         
index 52b75b1eeb6f305becf826e97c45d7439bd950ae..45a95795fe3e6c18f141439ced810ccd262dff05 100644 (file)
@@ -733,8 +733,6 @@ class ClauseList(ClauseElement):
         self.clauses.append(clause)
     def accept_visitor(self, visitor):
         for c in self.clauses:
-            if c is None:
-                raise "oh weird" + repr(self.clauses)
             c.accept_visitor(visitor)
         visitor.visit_clauselist(self)
     def _get_from_objects(self):
@@ -1141,31 +1139,19 @@ class TableClause(FromClause):
 class SelectBaseMixin(object):
     """base class for Select and CompoundSelects"""
     def order_by(self, *clauses):
-        self._append_clause('order_by_clause', "ORDER BY", *clauses)
+        if clauses[0] is None:
+            self.order_by_clause = ClauseList()
+        elif getattr(self, 'order_by_clause', None):
+            self.order_by_clause = ClauseList(*(list(clauses)+list(self.order_by_clause.clauses)))
+        else:
+            self.order_by_clause = ClauseList(*clauses)
     def group_by(self, *clauses):
-        self._append_clause('group_by_clause', "GROUP BY", *clauses)
-    def _append_clause(self, attribute, prefix, *clauses):
-        if len(clauses) == 1 and clauses[0] is None:
-            try:
-                delattr(self, attribute)
-            except AttributeError:
-                pass
-            return
-        if not hasattr(self, attribute):
-            l = ClauseList(*clauses)
-            setattr(self, attribute, l)
+        if clauses[0] is None:
+            self.group_by_clause = ClauseList()
+        elif getattr(self, 'group_by_clause', None):
+            self.group_by_clause = ClauseList(*(list(clauses)+list(self.group_by_clause.clauses)))
         else:
-            getattr(self, attribute).clauses  += clauses
-    def _get_clauses(self):
-        # TODO: this is a little stupid.  make ORDER BY/GROUP BY keywords handled by 
-        # the compiler, make group_by_clause/order_by_clause regular attributes
-        x =[]
-        if getattr(self, 'group_by_clause', None):
-            x.append(("GROUP BY", self.group_by_clause))
-        if getattr(self, 'order_by_clause', None):
-            x.append(("ORDER BY", self.order_by_clause))
-        return x
-    clauses = property(_get_clauses)
+            self.group_by_clause = ClauseList(*clauses)
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
     def _get_from_objects(self):
@@ -1186,23 +1172,23 @@ class CompoundSelect(SelectBaseMixin, FromClause):
         for s in self.selects:
             s.group_by(None)
             s.order_by(None)
-        group_by = kwargs.get('group_by', None)
-        if group_by:
-            self.group_by(*group_by)
-        order_by = kwargs.get('order_by', None)
-        if order_by:
-            self.order_by(*order_by)
+        self.group_by(*kwargs.get('group_by', [None]))
+        self.order_by(*kwargs.get('order_by', [None]))
+
     def _exportable_columns(self):
-        return self.selects[0].columns
+        for s in self.selects:
+            for c in s.c:
+                yield c
+
     def _proxy_column(self, column):
-        self._columns[column.key] = column
-        if column.primary_key:
-            self._primary_key.append(column)
-        if column.foreign_key:
-            self._foreign_keys.append(column)
+        if self.use_labels:
+            return column._make_proxy(self, name=column._label)
+        else:
+            return column._make_proxy(self, name=column.name)
+        
     def accept_visitor(self, visitor):
-        for tup in self.clauses:
-            tup[1].accept_visitor(visitor)
+        self.order_by_clause.accept_visitor(visitor)
+        self.group_by_clause.accept_visitor(visitor)
         for s in self.selects:
             s.accept_visitor(visitor)
         visitor.visit_compound_select(self)
@@ -1251,6 +1237,9 @@ class Select(SelectBaseMixin, FromClause):
         self._correlated = None
         self._correlator = Select.CorrelatedVisitor(self, False)
         self._wherecorrelator = Select.CorrelatedVisitor(self, True)
+
+        self.group_by(*(group_by or [None]))
+        self.order_by(*(order_by or [None]))
         
         if columns is not None:
             for c in columns:
@@ -1263,11 +1252,7 @@ class Select(SelectBaseMixin, FromClause):
             
         for f in from_obj:
             self.append_from(f)
-
-        if group_by:
-            self.group_by(*group_by)
-        if order_by:
-            self.order_by(*order_by)
+        
             
     class CorrelatedVisitor(ClauseVisitor):
         """visits a clause, locates any Select clauses, and tells them that they should
@@ -1355,8 +1340,8 @@ class Select(SelectBaseMixin, FromClause):
             self.whereclause.accept_visitor(visitor)
         if self.having is not None:
             self.having.accept_visitor(visitor)
-        for tup in self.clauses:
-            tup[1].accept_visitor(visitor)
+        self.order_by_clause.accept_visitor(visitor)
+        self.group_by_clause.accept_visitor(visitor)
         visitor.visit_select(self)
     
     def union(self, other, **kwargs):