]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the full featureset of the SelectResults extension has been merged
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 Mar 2007 02:49:12 +0000 (02:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 Mar 2007 02:49:12 +0000 (02:49 +0000)
into a new set of methods available off of Query.  These methods
all provide "generative" behavior, whereby the Query is copied
and a new one returned with additional criterion added.
The new methods include:

  filter() - applies select criterion to the query
  filter_by() - applies "by"-style criterion to the query
  avg() - return the avg() function on the given column
  join() - join to a property (or across a list of properties)
  outerjoin() - like join() but uses LEFT OUTER JOIN
  limit()/offset() - apply LIMIT/OFFSET
  range-based access which applies limit/offset:
     session.query(Foo)[3:5]
  distinct() - apply DISTINCT
  list() - evaluate the criterion and return results

no incompatible changes have been made to Query's API and no methods
have been deprecated.  Existing methods like select(), select_by(),
get(), get_by() all execute the query at once and return results
like they always did.  join_to()/join_via() are still there although
the generative join()/outerjoin() methods are easier to use.

- the return value for multiple mappers used with instances() now returns
a cartesian product of the requested list of mappers, represented
as a list of tuples.  this corresponds to the documented behavior.
So that instances match up properly, the "uniquing" is disabled when
this feature is used.
- strings and columns can also be sent to the *args of instances() where
those exact result columns will be part of the result tuples.
- query() method is added by assignmapper.  this helps with
navigating to all the new generative methods on Query.

CHANGES
lib/sqlalchemy/ext/assignmapper.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
test/ext/selectresults.py
test/orm/alltests.py
test/orm/generative.py [new file with mode: 0644]
test/orm/inheritance5.py
test/orm/mapper.py

diff --git a/CHANGES b/CHANGES
index 96391716c076a815c4c6379184ad4454c4dd2717..35ffc81d64769e5b175ad11c58770e8efcc015c1 100644 (file)
--- a/CHANGES
+++ b/CHANGES
     - fixed use_alter flag on ForeignKeyConstraint [ticket:503]
     - fixed usage of 2.4-only "reversed" in topological.py [ticket:506]
 - orm:
+    - the full featureset of the SelectResults extension has been merged
+      into a new set of methods available off of Query.  These methods
+      all provide "generative" behavior, whereby the Query is copied
+      and a new one returned with additional criterion added.  
+      The new methods include:
+
+          filter() - applies select criterion to the query
+          filter_by() - applies "by"-style criterion to the query
+          avg() - return the avg() function on the given column
+          join() - join to a property (or across a list of properties)
+          outerjoin() - like join() but uses LEFT OUTER JOIN
+          limit()/offset() - apply LIMIT/OFFSET
+          range-based access which applies limit/offset:  
+             session.query(Foo)[3:5]
+          distinct() - apply DISTINCT
+          list() - evaluate the criterion and return results
+          
+      no incompatible changes have been made to Query's API and no methods
+      have been deprecated.  Existing methods like select(), select_by(),
+      get(), get_by() all execute the query at once and return results
+      like they always did.  join_to()/join_via() are still there although
+      the generative join()/outerjoin() methods are easier to use.
+      
+    - the return value for multiple mappers used with instances() now returns
+      a cartesian product of the requested list of mappers, represented
+      as a list of tuples.  this corresponds to the documented behavior.
+      So that instances match up properly, the "uniquing" is disabled when 
+      this feature is used.
+    - strings and columns can also be sent to the *args of instances() where
+      those exact result columns will be part of the result tuples.
     - a full select() construct can be passed to query.select() (which
       worked anyway), but also query.selectfirst(), query.selectone() which
       will be used as is (i.e. no query is compiled). works similarly to
@@ -46,7 +76,9 @@
 - extensions:
     - options() method on SelectResults now implemented "generatively"
       like the rest of the SelectResults methods [ticket:472]
-
+    - query() method is added by assignmapper.  this helps with 
+      navigating to all the new generative methods on Query.
+    
 0.3.5
 - sql:
     - the value of "case_sensitive" defaults to True now, regardless of the
index 178f150e5475c9f03b2706047b1034fd32fd16c4..aee96f06eaeca2b9830d780ade1fc0b516e69f02 100644 (file)
@@ -34,6 +34,7 @@ def assign_mapper(ctx, class_, *args, **kwargs):
         extension = ctx.mapper_extension
     m = mapper(class_, extension=extension, *args, **kwargs)
     class_.mapper = m
+    class_.query = classmethod(lambda cls: Query(class_, session=ctx.current))
     for name in ['get', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'get_by', 'join_to', 'join_via', 'count', 'count_by', 'options', 'instances']:
         monkeypatch_query_method(ctx, class_, name)
     for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'save', 'update', 'save_or_update']:
index e1fa56c650e0917102fa546fca8dc1b9a463c179..d28445be61bf6f2740406dd2610a11b996d4175f 100644 (file)
@@ -1695,6 +1695,9 @@ class _ExtensionCarrier(MapperExtension):
     def __init__(self):
         self.__elements = []
 
+    def __iter__(self):
+        return iter(self.__elements)
+        
     def insert(self, extension):
         """Insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
 
@@ -1766,7 +1769,7 @@ class ExtensionOption(MapperOption):
         self.ext = ext
 
     def process_query(self, query):
-        query._insert_extension(self.ext)
+        query.extension.append(self.ext)
 
 class ClassKey(object):
     """Key a class and an entity name to a mapper, via the mapper_registry."""
index 8df5628d15b46391205fcdb08a5cdefb76d9b4cd..6650954e1bbd5feaeb11236bba32cc090530c5b0 100644 (file)
@@ -21,7 +21,6 @@ class Query(object):
         self.with_options = with_options or []
         self.select_mapper = self.mapper.get_select_mapper().compile()
         self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh)
-        self.order_by = kwargs.pop('order_by', self.mapper.order_by)
         self.lockmode = lockmode
         self.extension = mapper._ExtensionCarrier()
         if extension is not None:
@@ -35,12 +34,40 @@ class Query(object):
                 _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True))
             self.mapper._get_clause = _get_clause
         self._get_clause = self.mapper._get_clause
-        for opt in util.flatten_iterator(self.with_options):
-            opt.process_query(self)
 
-    def _insert_extension(self, ext):
-        self.extension.insert(ext)
+        self._order_by = kwargs.pop('order_by', False)
+        self._distinct = kwargs.pop('distinct', False)
+        self._offset = kwargs.pop('offset', None)
+        self._limit = kwargs.pop('limit', None)
+        self._criterion = None
+        self._joinpoint = self.mapper
+        self._from_obj = [self.table]
 
+        for opt in util.flatten_iterator(self.with_options):
+            opt.process_query(self)
+        
+    def _clone(self):
+        q = Query.__new__(Query)
+        q.mapper = self.mapper
+        q.select_mapper = self.select_mapper
+        q._order_by = self._order_by
+        q._distinct = self._distinct
+        q.always_refresh = self.always_refresh
+        q.with_options = list(self.with_options)
+        q._session = self.session
+        q.is_polymorphic = self.is_polymorphic
+        q.lockmode = self.lockmode
+        q.extension = mapper._ExtensionCarrier()
+        for ext in self.extension:
+            q.extension.append(ext)
+        q._offset = self._offset
+        q._limit = self._limit
+        q._get_clause = self._get_clause
+        q._from_obj = list(self._from_obj)
+        q._joinpoint = self._joinpoint
+        q._criterion = self._criterion
+        return q
+    
     def _get_session(self):
         if self._session is None:
             return self.mapper.get_session()
@@ -90,20 +117,8 @@ class Query(object):
         """Return a single object instance based on the given
         key/value criterion.
 
-        This is either the first value in the result list, or None if
-        the list is empty.
-
-        The keys are mapped to property or column names mapped by this
-        mapper's Table, and the values are coerced into a ``WHERE``
-        clause separated by ``AND`` operators.  If the local
-        property/column names dont contain the key, a search will be
-        performed against this mapper's immediate list of relations as
-        well, forming the appropriate join conditions if a matching
-        property is located.
-
-        E.g.::
-
-          u = usermapper.get_by(user_name = 'fred')
+        The criterion is constructed in the same way as the
+        ``select_by()`` method.
         """
 
         ret = self.extension.get_by(self, *args, **params)
@@ -131,6 +146,11 @@ class Query(object):
         mapper's immediate list of relations as well, forming the
         appropriate join conditions if a matching property is located.
 
+        if the located property is a column-based property, the comparison
+        value should be a scalar with an appropriate type.  If the 
+        property is a relationship-bound property, the comparison value
+        should be an instance of the related class.
+
         E.g.::
 
           result = usermapper.select_by(user_name = 'fred')
@@ -145,61 +165,13 @@ class Query(object):
         """Return a ``ClauseElement`` representing the ``WHERE``
         clause that would normally be sent to ``select_whereclause()``
         by ``select_by()``.
-        """
 
-        return self._join_by(args, params)
-
-    def _join_by(self, args, params, start=None):
-        """Return a ``ClauseElement`` representing the ``WHERE``
-        clause that would normally be sent to ``select_whereclause()``
-        by ``select_by()``.
+        The criterion is constructed in the same way as the
+        ``select_by()`` method.
         """
 
-        clause = None
-        for arg in args:
-            if clause is None:
-                clause = arg
-            else:
-                clause &= arg
+        return self._join_by(args, params)
 
-        for key, value in params.iteritems():
-            (keys, prop) = self._locate_prop(key, start=start)
-            c = prop.compare(value) & self.join_via(keys)
-            if clause is None:
-                clause =  c
-            else:
-                clause &= c
-        return clause
-
-    def _locate_prop(self, key, start=None):
-        import properties
-        keys = []
-        seen = util.Set()
-        def search_for_prop(mapper_):
-            if mapper_ in seen:
-                return None
-            seen.add(mapper_)
-            if mapper_.props.has_key(key):
-                prop = mapper_.props[key]
-                if isinstance(prop, SynonymProperty):
-                    prop = mapper_.props[prop.name]
-                if isinstance(prop, properties.PropertyLoader):
-                    keys.insert(0, prop.key)
-                return prop
-            else:
-                for prop in mapper_.props.values():
-                    if not isinstance(prop, properties.PropertyLoader):
-                        continue
-                    x = search_for_prop(prop.mapper)
-                    if x:
-                        keys.insert(0, prop.key)
-                        return x
-                else:
-                    return None
-        p = search_for_prop(start or self.mapper)
-        if p is None:
-            raise exceptions.InvalidRequestError("Cant locate property named '%s'" % key)
-        return [keys, p]
 
     def join_to(self, key):
         """Given the key name of a property, will recursively descend
@@ -236,6 +208,9 @@ class Query(object):
         """Like ``select_by()``, but only return the first result by
         itself, or None if no objects returned.  Synonymous with
         ``get_by()``.
+
+        The criterion is constructed in the same way as the
+        ``select_by()`` method.
         """
 
         return self.get_by(*args, **params)
@@ -243,6 +218,9 @@ class Query(object):
     def selectone_by(self, *args, **params):
         """Like ``selectfirst_by()``, but throws an error if not
         exactly one result was returned.
+
+        The criterion is constructed in the same way as the
+        ``select_by()`` method.
         """
 
         ret = self.select_whereclause(self.join_by(*args, **params), limit=2)
@@ -326,15 +304,20 @@ class Query(object):
         """Given a ``WHERE`` criterion, create a ``SELECT COUNT``
         statement, execute and return the resulting count value.
         """
+        if self._criterion:
+            if whereclause is not None:
+                whereclause = sql.and_(self._criterion, whereclause)
+            else:
+                whereclause = self._criterion
+        from_obj = kwargs.pop('from_obj', self._from_obj)
+        kwargs.setdefault('distinct', self._distinct)
 
-        from_obj = kwargs.pop('from_obj', [])
         alltables = []
         for l in [sql_util.TableFinder(x) for x in from_obj]:
             alltables += l
-
+        
         if self.table not in alltables:
             from_obj.append(self.table)
-
         if self._nestable(**kwargs):
             s = sql.select([self.table], whereclause, **kwargs).alias('getcount').count()
         else:
@@ -361,14 +344,201 @@ class Query(object):
         """Return a new Query object, applying the given list of
         MapperOptions.
         """
-
-        return Query(self.mapper, self._session, with_options=args)
+        q = self._clone()
+        for opt in util.flatten_iterator(args):
+            q.with_options.append(opt)
+            opt.process_query(q)
+        for opt in util.flatten_iterator(self.with_options):
+            opt.process_query(self)
+        return q
 
     def with_lockmode(self, mode):
         """Return a new Query object with the specified locking mode."""
+        q = self._clone()
+        q.lockmode = mode
+        return q
+    
+    def filter(self, criterion):
+        """apply the given filtering criterion to the query and return the newly resulting ``Query``
+        
+        the criterion is any sql.ClauseElement applicable to the WHERE clause of a select.
+        """
+        q = self._clone()
+        if q._criterion is not None:
+            q._criterion = q._criterion & criterion
+        else:
+            q._criterion = criterion
+        return q
+
+    def filter_by(self, *args, **kwargs):
+        """apply the given filtering criterion to the query and return the newly resulting ``Query``
+
+        The criterion is constructed in the same way as the
+        ``select_by()`` method.
+        """
+        return self.filter(self._join_by(args, kwargs, start=self._joinpoint))
+
+    def _join_to(self, prop, outerjoin=False):
+        if isinstance(prop, list):
+            mapper = self._joinpoint
+            keys = []
+            for key in prop:
+                p = mapper.props[key]
+                keys.append(key)
+                mapper = p.mapper
+        else:
+            [keys,p] = self._locate_prop(prop, start=self._joinpoint)
+        clause = self._from_obj[-1]
+        mapper = self._joinpoint
+        for key in keys:
+            prop = mapper.props[key]
+            if outerjoin:
+                clause = clause.outerjoin(prop.select_table, prop.get_join(mapper))
+            else:
+                clause = clause.join(prop.select_table, prop.get_join(mapper))
+            mapper = prop.mapper
+        return (clause, mapper)
+
+    def _join_by(self, args, params, start=None):
+        """Return a ``ClauseElement`` representing the ``WHERE``
+        clause that would normally be sent to ``select_whereclause()``
+        by ``select_by()``.
 
-        return Query(self.mapper, self._session, lockmode=mode)
+        The criterion is constructed in the same way as the
+        ``select_by()`` method.
+        """
 
+        clause = None
+        for arg in args:
+            if clause is None:
+                clause = arg
+            else:
+                clause &= arg
+
+        for key, value in params.iteritems():
+            (keys, prop) = self._locate_prop(key, start=start)
+            c = prop.compare(value) & self.join_via(keys)
+            if clause is None:
+                clause =  c
+            else:
+                clause &= c
+        return clause
+
+    def _locate_prop(self, key, start=None):
+        import properties
+        keys = []
+        seen = util.Set()
+        def search_for_prop(mapper_):
+            if mapper_ in seen:
+                return None
+            seen.add(mapper_)
+            if mapper_.props.has_key(key):
+                prop = mapper_.props[key]
+                if isinstance(prop, SynonymProperty):
+                    prop = mapper_.props[prop.name]
+                if isinstance(prop, properties.PropertyLoader):
+                    keys.insert(0, prop.key)
+                return prop
+            else:
+                for prop in mapper_.props.values():
+                    if not isinstance(prop, properties.PropertyLoader):
+                        continue
+                    x = search_for_prop(prop.mapper)
+                    if x:
+                        keys.insert(0, prop.key)
+                        return x
+                else:
+                    return None
+        p = search_for_prop(start or self.mapper)
+        if p is None:
+            raise exceptions.InvalidRequestError("Cant locate property named '%s'" % key)
+        return [keys, p]
+
+    def _col_aggregate(self, col, func):
+        """Execute ``func()`` function against the given column.
+
+        For performance, only use subselect if `order_by` attribute is set.
+        """
+
+        ops = {'distinct':self._distinct, 'order_by':self._order_by, 'from_obj':self._from_obj}
+
+        if self._order_by is not False:
+            s1 = sql.select([col], self._criterion, **ops).alias('u')
+            return sql.select([func(s1.corresponding_column(col))]).scalar()
+        else:
+            return sql.select([func(col)], self._criterion, **ops).scalar()
+
+    def min(self, col):
+        """Execute the SQL ``min()`` function against the given column."""
+
+        return self._col_aggregate(col, sql.func.min)
+
+    def max(self, col):
+        """Execute the SQL ``max()`` function against the given column."""
+
+        return self._col_aggregate(col, sql.func.max)
+
+    def sum(self, col):
+        """Execute the SQL ``sum``() function against the given column."""
+
+        return self._col_aggregate(col, sql.func.sum)
+
+    def avg(self, col):
+        """Execute the SQL ``avg()`` function against the given column."""
+
+        return self._col_aggregate(col, sql.func.avg)
+    
+    def order_by(self, criterion):
+        """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
+
+        q = self._clone()
+        if q._order_by is False:    
+            q._order_by = util.to_list(criterion)
+        else:
+            q._order_by.extend(util.to_list(criterion))
+        return q
+    
+    def join(self, prop):
+        """create a join of this ``Query`` object's criterion
+        to a relationship and return the newly resulting ``Query``.
+        
+        'prop' may be a string property name in which it is located
+        in the same manner as keyword arguments in ``select_by``, or
+        it may be a list of strings in which case the property is located
+        by direct traversal of each keyname (i.e. like join_via()).
+        """
+        
+        q = self._clone()
+        (clause, mapper) = self._join_to(prop, outerjoin=False)
+        q._from_obj = [clause]
+        q._joinpoint = mapper
+        return q
+
+    def outerjoin(self, prop):
+        """create a left outer join of this ``Query`` object's criterion
+        to a relationship and return the newly resulting ``Query``.
+        
+        'prop' may be a string property name in which it is located
+        in the same manner as keyword arguments in ``select_by``, or
+        it may be a list of strings in which case the property is located
+        by direct traversal of each keyname (i.e. like join_via()).
+        """
+        q = self._clone()
+        (clause, mapper) = self._join_to(prop, outerjoin=True)
+        q._from_obj = [clause]
+        q._joinpoint = mapper
+        return q
+
+    def select_from(self, from_obj):
+        """Set the `from_obj` parameter of the query.
+
+        `from_obj` is a list of one or more tables.
+        """
+
+        new = self._clone()
+        new._from_obj = from_obj
+        return new
+        
     def __getattr__(self, key):
         if (key.startswith('select_by_')):
             key = key[10:]
@@ -383,6 +553,57 @@ class Query(object):
         else:
             raise AttributeError(key)
 
+    def __getitem__(self, item):
+        if isinstance(item, slice):
+            start = item.start
+            stop = item.stop
+            if (isinstance(start, int) and start < 0) or \
+               (isinstance(stop, int) and stop < 0):
+                return list(self)[item]
+            else:
+                res = self._clone()
+                if start is not None and stop is not None:
+                    res._offset = (self._offset or 0)+ start
+                    res._limit = stop-start
+                elif start is None and stop is not None:
+                    res._limit = stop
+                elif start is not None and stop is None:
+                    res._offset = (self._offset or 0) + start
+                if item.step is not None:
+                    return list(res)[None:None:item.step]
+                else:
+                    return res
+        else:
+            return list(self[item:item+1])[0]
+
+    def limit(self, limit):
+        """Apply a ``LIMIT`` to the query."""
+
+        return self[:limit]
+
+    def offset(self, offset):
+        """Apply an ``OFFSET`` to the query."""
+
+        return self[offset:]
+
+    def distinct(self):
+        """Apply a ``DISTINCT`` to the query."""
+
+        new = self._clone()
+        new._distinct = True
+        return new
+
+    def list(self):
+        """Return the results represented by this ``Query`` as a list.
+
+        This results in an execution of the underlying query.
+        """
+
+        return list(self)
+    
+    def __iter__(self):
+        return iter(self.select_whereclause())
+
     def execute(self, clauseelement, params=None, *args, **kwargs):
         """Execute the given ClauseElement-based statement against
         this Query's session/mapper, return the resulting list of
@@ -400,9 +621,27 @@ class Query(object):
         finally:
             result.close()
 
-    def instances(self, cursor, *mappers, **kwargs):
+    def instances(self, cursor, *mappers_or_columns, **kwargs):
         """Return a list of mapped instances corresponding to the rows
         in a given *cursor* (i.e. ``ResultProxy``).
+        
+        *mappers_or_columns is an optional list containing one or more of
+        classes, mappers, strings or sql.ColumnElements which will be
+        applied to each row and added horizontally to the result set,
+        which becomes a list of tuples. The first element in each tuple
+        is the usual result based on the mapper represented by this
+        ``Query``. Each additional element in the tuple corresponds to an
+        entry in the *mappers_or_columns list.
+        
+        For each element in *mappers_or_columns, if the element is 
+        a mapper or mapped class, an additional class instance will be 
+        present in the tuple.  If the element is a string or sql.ColumnElement, 
+        the corresponding result column from each row will be present in the tuple.
+        
+        Note that when *mappers_or_columns is present, "uniquing" for the result set
+        is *disabled*, so that the resulting tuples contain entities as they actually
+        correspond.  this indicates that multiple results may be present if this 
+        option is used.
         """
 
         self.__log_debug("instances()")
@@ -411,25 +650,36 @@ class Query(object):
 
         context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs)
 
-        result = util.UniqueAppender([])
-        if mappers:
-            otherresults = []
-            for m in mappers:
-                otherresults.append(util.UniqueAppender([]))
-
+        process = []
+        if mappers_or_columns:
+            for m in mappers_or_columns:
+                if isinstance(m, type):
+                    m = mapper.class_mapper(m)
+                if isinstance(m, mapper.Mapper):
+                    appender = []
+                    def proc(context, row):
+                        m._instance(context, row, appender)
+                    process.append((proc, appender))
+                elif isinstance(m, sql.ColumnElement) or isinstance(m, basestring):
+                    res = []
+                    def proc(context, row):
+                        res.append(row[m])
+                    process.append((proc, res))
+            result = []
+        else:
+            result = util.UniqueAppender([])
+                    
         for row in cursor.fetchall():
             self.select_mapper._instance(context, row, result)
-            i = 0
-            for m in mappers:
-                m._instance(context, row, otherresults[i])
-                i+=1
+            for proc in process:
+                proc[0](context, row)
 
         # store new stuff in the identity map
         for value in context.identity_map.values():
             session._register_persistent(value)
 
-        if mappers:
-            return [result.data] + [o.data for o in otherresults]
+        if mappers_or_columns:
+            return zip(*([result] + [o[1] for o in process]))
         else:
             return result.data
 
@@ -491,11 +741,14 @@ class Query(object):
         statement suitable for usage in the execute() method.
         """
 
+        if self._criterion:
+            whereclause = sql.and_(self._criterion, whereclause)
+
         if whereclause is not None and self.is_polymorphic:
             # adapt the given WHERECLAUSE to adjust instances of this query's mapped table to be that of our select_table,
             # which may be the "polymorphic" selectable used by our mapper.
             whereclause.accept_visitor(sql_util.ClauseAdapter(self.table))
-
+            
         context = kwargs.pop('query_context', None)
         if context is None:
             context = QueryContext(self, kwargs)
@@ -506,7 +759,7 @@ class Query(object):
         limit = context.limit
         offset = context.offset
         if order_by is False:
-            order_by = self.order_by
+            order_by = self.mapper.order_by
         if order_by is False:
             if self.table.default_order_by() is not None:
                 order_by = self.table.default_order_by()
@@ -579,12 +832,12 @@ class QueryContext(OperationContext):
 
     def __init__(self, query, kwargs):
         self.query = query
-        self.order_by = kwargs.pop('order_by', False)
-        self.from_obj = kwargs.pop('from_obj', [])
+        self.order_by = kwargs.pop('order_by', query._order_by)
+        self.from_obj = kwargs.pop('from_obj', query._from_obj)
         self.lockmode = kwargs.pop('lockmode', query.lockmode)
-        self.distinct = kwargs.pop('distinct', False)
-        self.limit = kwargs.pop('limit', None)
-        self.offset = kwargs.pop('offset', None)
+        self.distinct = kwargs.pop('distinct', query._distinct)
+        self.limit = kwargs.pop('limit', query._limit)
+        self.offset = kwargs.pop('offset', query._offset)
         self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders])
         self.statement = None
         super(QueryContext, self).__init__(query.mapper, query.with_options, **kwargs)
index ca052ad3b250bfd06b3930159321030383731d2c..88476c9cc0b96a59428070e02b6458c0192a9543 100644 (file)
@@ -12,14 +12,15 @@ class Foo(object):
 class SelectResultsTest(PersistTest):
     def setUpAll(self):
         self.install_threadlocal()
-        global foo
-        foo = Table('foo', testbase.db,
+        global foo, metadata
+        metadata = BoundMetaData(testbase.db)
+        foo = Table('foo', metadata,
                     Column('id', Integer, Sequence('foo_id_seq'), primary_key=True),
                     Column('bar', Integer),
                     Column('range', Integer))
         
         assign_mapper(Foo, foo, extension=SelectResultsExt())
-        foo.create()
+        metadata.create_all()
         for i in range(100):
             Foo(bar=i, range=i%10)
         objectstore.flush()
@@ -30,8 +31,7 @@ class SelectResultsTest(PersistTest):
         self.res = self.query.select()
         
     def tearDownAll(self):
-        global foo
-        foo.drop()
+        metadata.drop_all()
         self.uninstall_threadlocal()
         clear_mappers()
     
index c52902ac741b7c4a3c269e83c903c718ac603132..c30bc140641e879e32aafcbca1facef2269c3d18 100644 (file)
@@ -5,6 +5,7 @@ def suite():
     modules_to_test = (
        'orm.attributes',
         'orm.mapper',
+        'orm.generative',
         'orm.lazytest1',
         'orm.eagertest1',
         'orm.eagertest2',
diff --git a/test/orm/generative.py b/test/orm/generative.py
new file mode 100644 (file)
index 0000000..37ce1dc
--- /dev/null
@@ -0,0 +1,229 @@
+from testbase import PersistTest, AssertMixin
+import testbase
+import tables
+
+from sqlalchemy import *
+
+
+class Foo(object):
+    pass
+
+class GenerativeQueryTest(PersistTest):
+    def setUpAll(self):
+        self.install_threadlocal()
+        global foo, metadata
+        metadata = BoundMetaData(testbase.db)
+        foo = Table('foo', metadata,
+                    Column('id', Integer, Sequence('foo_id_seq'), primary_key=True),
+                    Column('bar', Integer),
+                    Column('range', Integer))
+        
+        assign_mapper(Foo, foo)
+        metadata.create_all()
+        for i in range(100):
+            Foo(bar=i, range=i%10)
+        objectstore.flush()
+    
+    def setUp(self):
+        self.query = Foo.query()
+        self.orig = self.query.select_whereclause()
+        self.res = self.query
+        
+    def tearDownAll(self):
+        metadata.drop_all()
+        self.uninstall_threadlocal()
+        clear_mappers()
+    
+    def test_selectby(self):
+        res = self.query.filter_by(range=5)
+        assert res.order_by([Foo.c.bar])[0].bar == 5
+        assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
+        
+    def test_slice(self):
+        assert self.query[1] == self.orig[1]
+        assert list(self.query[10:20]) == self.orig[10:20]
+        assert list(self.query[10:]) == self.orig[10:]
+        assert list(self.query[:10]) == self.orig[:10]
+        assert list(self.query[:10]) == self.orig[:10]
+        assert list(self.query[10:40:3]) == self.orig[10:40:3]
+        assert list(self.query[-5:]) == self.orig[-5:]
+        assert self.query[10:20][5] == self.orig[10:20][5]
+
+    def test_aggregate(self):
+        assert self.query.count() == 100
+        assert self.query.filter(foo.c.bar<30).min(foo.c.bar) == 0
+        assert self.query.filter(foo.c.bar<30).max(foo.c.bar) == 29
+
+    @testbase.unsupported('mysql')
+    def test_aggregate_1(self):
+        # this one fails in mysql as the result comes back as a string
+        assert self.query.filter(foo.c.bar<30).sum(foo.c.bar) == 435
+
+    @testbase.unsupported('postgres', 'mysql', 'firebird')
+    def test_aggregate_2(self):
+        # this one fails with postgres, the floating point comparison fails
+        assert self.query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5
+
+    def test_filter(self):
+        assert self.query.count() == 100
+        assert self.query.filter(Foo.c.bar < 30).count() == 30
+        res2 = self.query.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10)
+        assert res2.count() == 19
+    
+    def test_options(self):
+        class ext1(MapperExtension):
+            def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew):
+                instance.TEST = "hello world"
+                return EXT_PASS
+        objectstore.clear()
+        assert self.res.options(extension(ext1()))[0].TEST == "hello world"
+        
+    def test_order_by(self):
+        assert self.res.order_by([Foo.c.bar])[0].bar == 0
+        assert self.res.order_by([desc(Foo.c.bar)])[0].bar == 99
+
+    def test_offset(self):
+        assert list(self.res.order_by([Foo.c.bar]).offset(10))[0].bar == 10
+        
+    def test_offset(self):
+        assert len(list(self.res.limit(10))) == 10
+
+class Obj1(object):
+    pass
+class Obj2(object):
+    pass
+
+class GenerativeTest2(PersistTest):
+    def setUpAll(self):
+        self.install_threadlocal()
+        global metadata, table1, table2
+        metadata = BoundMetaData(testbase.db)
+        table1 = Table('Table1', metadata,
+            Column('id', Integer, primary_key=True),
+            )
+        table2 = Table('Table2', metadata,
+            Column('t1id', Integer, ForeignKey("Table1.id"), primary_key=True),
+            Column('num', Integer, primary_key=True),
+            )
+        assign_mapper(Obj1, table1)
+        assign_mapper(Obj2, table2)
+        metadata.create_all()
+        table1.insert().execute({'id':1},{'id':2},{'id':3},{'id':4})
+        table2.insert().execute({'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\
+{'num':4,'t1id':2},{'num':5,'t1id':2},{'num':6,'t1id':3})
+
+    def setUp(self):
+        self.query = Query(Obj1)
+        #self.orig = self.query.select_whereclause()
+        #self.res = self.query.select()
+
+    def tearDownAll(self):
+        metadata.drop_all()
+        self.uninstall_threadlocal()
+        clear_mappers()
+
+    def test_distinctcount(self):
+        res = self.query
+        assert res.count() == 4
+        res = self.query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1))
+        assert res.count() == 3
+        res = self.query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1)).distinct()
+        self.assertEqual(res.count(), 1)
+
+class RelationsTest(AssertMixin):
+    def setUpAll(self):
+        tables.create()
+        tables.data()
+    def tearDownAll(self):
+        tables.drop()
+    def tearDown(self):
+        clear_mappers()
+    def test_jointo(self):
+        """test the join and outerjoin functions on Query"""
+        mapper(tables.User, tables.users, properties={
+            'orders':relation(mapper(tables.Order, tables.orders, properties={
+                'items':relation(mapper(tables.Item, tables.orderitems))
+            }))
+        })
+        session = create_session()
+        query = session.query(tables.User)
+        x = query.join('orders').join('items').filter(tables.Item.c.item_id==2)
+        print x.compile()
+        self.assert_result(list(x), tables.User, tables.user_result[2])
+    def test_outerjointo(self):
+        """test the join and outerjoin functions on Query"""
+        mapper(tables.User, tables.users, properties={
+            'orders':relation(mapper(tables.Order, tables.orders, properties={
+                'items':relation(mapper(tables.Item, tables.orderitems))
+            }))
+        })
+        session = create_session()
+        query = session.query(tables.User)
+        x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+        print x.compile()
+        self.assert_result(list(x), tables.User, *tables.user_result[1:3])
+    def test_outerjointo_count(self):
+        """test the join and outerjoin functions on Query"""
+        mapper(tables.User, tables.users, properties={
+            'orders':relation(mapper(tables.Order, tables.orders, properties={
+                'items':relation(mapper(tables.Item, tables.orderitems))
+            }))
+        })
+        session = create_session()
+        query = session.query(tables.User)
+        x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
+        assert x==2
+    def test_from(self):
+        mapper(tables.User, tables.users, properties={
+            'orders':relation(mapper(tables.Order, tables.orders, properties={
+                'items':relation(mapper(tables.Item, tables.orderitems))
+            }))
+        })
+        session = create_session()
+        query = session.query(tables.User)
+        x = query.select_from([tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)]).\
+            filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+        print x.compile()
+        self.assert_result(list(x), tables.User, *tables.user_result[1:3])
+        
+
+class CaseSensitiveTest(PersistTest):
+    def setUpAll(self):
+        self.install_threadlocal()
+        global metadata, table1, table2
+        metadata = BoundMetaData(testbase.db)
+        table1 = Table('Table1', metadata,
+            Column('ID', Integer, primary_key=True),
+            )
+        table2 = Table('Table2', metadata,
+            Column('T1ID', Integer, ForeignKey("Table1.ID"), primary_key=True),
+            Column('NUM', Integer, primary_key=True),
+            )
+        assign_mapper(Obj1, table1)
+        assign_mapper(Obj2, table2)
+        metadata.create_all()
+        table1.insert().execute({'ID':1},{'ID':2},{'ID':3},{'ID':4})
+        table2.insert().execute({'NUM':1,'T1ID':1},{'NUM':2,'T1ID':1},{'NUM':3,'T1ID':1},\
+{'NUM':4,'T1ID':2},{'NUM':5,'T1ID':2},{'NUM':6,'T1ID':3})
+
+    def setUp(self):
+        self.query = Query(Obj1)
+        #self.orig = self.query.select_whereclause()
+        #self.res = self.query.select()
+
+    def tearDownAll(self):
+        metadata.drop_all()
+        self.uninstall_threadlocal()
+        clear_mappers()
+        
+    def test_distinctcount(self):
+        res = self.query
+        assert res.count() == 4
+        res = self.query.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1))
+        assert res.count() == 3
+        res = self.query.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1)).distinct()
+        self.assertEqual(res.count(), 1)
+
+
+if __name__ == "__main__":
+    testbase.main()        
index 640a8d70fefe0bc2f306f2d442900188e8ae8985..0ae4260dcefd65ceb97ad3b86e3cd709e3a5071f 100644 (file)
@@ -1,6 +1,5 @@
 from sqlalchemy import *
 import testbase
-from sqlalchemy.ext.selectresults import SelectResults
 
 class AttrSettable(object):
     def __init__(self, **kwargs):
@@ -374,8 +373,8 @@ class RelationTest4(testbase.ORMTest):
         assert str(car1.employee) == "Engineer E4, status X"
 
         session.clear()
-        s = SelectResults(session.query(Car))
-        c = s.join_to("employee").select(employee_join.c.name=="E4")[0]
+        s = session.query(Car)
+        c = s.join("employee").select(employee_join.c.name=="E4")[0]
         assert c.car_id==car1.car_id
 
 class RelationTest5(testbase.ORMTest):
@@ -580,7 +579,7 @@ class RelationTest7(testbase.ORMTest):
         for p in r:
             assert p.car_id == p.car.car_id
     
-class SelectResultsTest(testbase.AssertMixin):
+class GenerativeTest(testbase.AssertMixin):
     def setUpAll(self):
         #  cars---owned by---  people (abstract) --- has a --- status
         #   |                  ^    ^                            |
@@ -693,9 +692,9 @@ class SelectResultsTest(testbase.AssertMixin):
 
         # test these twice because theres caching involved
         for x in range(0, 2):
-            r = SelectResults(session.query(Person)).select_by(people.c.name.like('%2')).join_to('status').select_by(name="active")
+            r = session.query(Person).filter_by(people.c.name.like('%2')).join('status').filter_by(name="active")
             assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]"
-            r = SelectResults(session.query(Engineer)).join_to('status').select(people.c.name.in_('E2', 'E3', 'E4', 'M4', 'M2', 'M1') & (status.c.name=="active"))
+            r = session.query(Engineer).join('status').filter(people.c.name.in_('E2', 'E3', 'E4', 'M4', 'M2', 'M1') & (status.c.name=="active"))
             assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]"
         
 class MultiLevelTest(testbase.ORMTest):
index d8c23e103a2f71c0ee7bdd38a747b56d5ae12ffc..261fcc1163ba0bd8ad3a292841451459e8facf23 100644 (file)
@@ -296,16 +296,16 @@ class MapperTest(MapperSuperTest):
         sess = create_session()
         q = sess.query(m)
 
-        l = q.select((orderitems.c.item_name=='item 4') & q.join_via(['orders', 'items']))
+        l = q.filter(orderitems.c.item_name=='item 4').join(['orders', 'items']).list()
         self.assert_result(l, User, user_result[0])
         
         l = q.select_by(item_name='item 4')
         self.assert_result(l, User, user_result[0])
 
-        l = q.select((orderitems.c.item_name=='item 4') & q.join_to('item_name'))
+        l = q.filter(orderitems.c.item_name=='item 4').join('item_name').list()
         self.assert_result(l, User, user_result[0])
 
-        l = q.select((orderitems.c.item_name=='item 4') & q.join_to('items'))
+        l = q.filter(orderitems.c.item_name=='item 4').join('items').list()
         self.assert_result(l, User, user_result[0])
 
         # test comparing to an object instance
@@ -587,15 +587,29 @@ class MapperTest(MapperSuperTest):
             })
             
         sess = create_session()
-        q2 = sess.query(User).options(eagerload('orders.items.keywords'))
+        
+        # eagerload nothing.
         u = sess.query(User).select()
         def go():
             print u[0].orders[1].items[0].keywords[1]
         self.assert_sql_count(db, go, 3)
         sess.clear()
+        
+        
         print "-------MARK----------"
+        # eagerload orders, orders.items, orders.items.keywords
+        q2 = sess.query(User).options(eagerload('orders'), eagerload('orders.items'), eagerload('orders.items.keywords'))
         u = q2.select()
         print "-------MARK2----------"
+        self.assert_sql_count(db, go, 0)
+
+        sess.clear()
+        
+        # eagerload "keywords" on items.  it will lazy load "orders", then lazy load
+        # the "items" on the order, but on "items" it will eager load the "keywords"
+        print "-------MARK3----------"
+        q3 = sess.query(User).options(eagerload('orders.items.keywords'))
+        u = q3.select()
         self.assert_sql_count(db, go, 2)
         
 class InheritanceTest(MapperSuperTest):
@@ -867,7 +881,7 @@ class LazyTest(MapperSuperTest):
             addresses = relation(mapper(Address, addresses, extension=ctx.mapper_extension), lazy=True)
         ), extension=ctx.mapper_extension)
         q = ctx.current.query(m)
-        u = q.selectfirst(users.c.user_id == 7)
+        u = q.filter(users.c.user_id == 7).selectfirst()
         ctx.current.expunge(u)
         self.assert_result([u], User,
             {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
@@ -1067,85 +1081,6 @@ class EagerTest(MapperSuperTest):
             {'user_id' : 9, 'addresses' : (Address, [])}
             )
 
-    def testcustomfromalias(self):
-        mapper(User, users, properties={
-            'addresses':relation(Address, lazy=True)
-        })
-        mapper(Address, addresses)
-        query = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True)
-        q = create_session().query(User)
-        
-        def go():
-            l = q.options(contains_alias('ulist'), contains_eager('addresses')).instances(query.execute())
-            self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(testbase.db, go, 1)
-        
-    def testcustomeagerquery(self):
-        mapper(User, users, properties={
-            # setting lazy=True - the contains_eager() option below
-            # should imply eagerload()
-            'addresses':relation(Address, lazy=True)
-        })
-        mapper(Address, addresses)
-        
-        selectquery = users.outerjoin(addresses).select(use_labels=True)
-        q = create_session().query(User)
-        
-        def go():
-            l = q.options(contains_eager('addresses')).instances(selectquery.execute())
-            self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(testbase.db, go, 1)
-
-    def testcustomeagerwithstringalias(self):
-        mapper(User, users, properties={
-            'addresses':relation(Address, lazy=False)
-        })
-        mapper(Address, addresses)
-
-        adalias = addresses.alias('adalias')
-        selectquery = users.outerjoin(adalias).select(use_labels=True)
-        q = create_session().query(User)
-
-        def go():
-            l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute())
-            self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(testbase.db, go, 1)
-
-    def testcustomeagerwithalias(self):
-        mapper(User, users, properties={
-            'addresses':relation(Address, lazy=False)
-        })
-        mapper(Address, addresses)
-
-        adalias = addresses.alias('adalias')
-        selectquery = users.outerjoin(adalias).select(use_labels=True)
-        q = create_session().query(User)
-
-        def go():
-            l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute())
-            self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(testbase.db, go, 1)
-
-    def testcustomeagerwithdecorator(self):
-        mapper(User, users, properties={
-            'addresses':relation(Address, lazy=False)
-        })
-        mapper(Address, addresses)
-
-        adalias = addresses.alias('adalias')
-        selectquery = users.outerjoin(adalias).select(use_labels=True)
-        def decorate(row):
-            d = {}
-            for c in addresses.columns:
-                d[c] = row[adalias.corresponding_column(c)]
-            return d
-            
-        q = create_session().query(User)
-
-        def go():
-            l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute())
-            self.assert_result(l, User, *user_address_result)
-        self.assert_sql_count(testbase.db, go, 1)
         
     def testorderby_desc(self):
         m = mapper(Address, addresses)
@@ -1441,7 +1376,108 @@ class EagerTest(MapperSuperTest):
                 {'item_id':5, 'item_name':'item 5'}
                ])},
         )
+
+class InstancesTest(MapperSuperTest):
+    def testcustomfromalias(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, lazy=True)
+        })
+        mapper(Address, addresses)
+        query = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True)
+        q = create_session().query(User)
+        
+        def go():
+            l = q.options(contains_alias('ulist'), contains_eager('addresses')).instances(query.execute())
+            self.assert_result(l, User, *user_address_result)
+        self.assert_sql_count(testbase.db, go, 1)
+        
+    def testcustomeagerquery(self):
+        mapper(User, users, properties={
+            # setting lazy=True - the contains_eager() option below
+            # should imply eagerload()
+            'addresses':relation(Address, lazy=True)
+        })
+        mapper(Address, addresses)
+        
+        selectquery = users.outerjoin(addresses).select(use_labels=True)
+        q = create_session().query(User)
+        
+        def go():
+            l = q.options(contains_eager('addresses')).instances(selectquery.execute())
+            self.assert_result(l, User, *user_address_result)
+        self.assert_sql_count(testbase.db, go, 1)
+
+    def testcustomeagerwithstringalias(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, lazy=False)
+        })
+        mapper(Address, addresses)
+
+        adalias = addresses.alias('adalias')
+        selectquery = users.outerjoin(adalias).select(use_labels=True)
+        q = create_session().query(User)
+
+        def go():
+            l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute())
+            self.assert_result(l, User, *user_address_result)
+        self.assert_sql_count(testbase.db, go, 1)
+
+    def testcustomeagerwithalias(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, lazy=False)
+        })
+        mapper(Address, addresses)
+
+        adalias = addresses.alias('adalias')
+        selectquery = users.outerjoin(adalias).select(use_labels=True)
+        q = create_session().query(User)
+
+        def go():
+            l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute())
+            self.assert_result(l, User, *user_address_result)
+        self.assert_sql_count(testbase.db, go, 1)
+
+    def testcustomeagerwithdecorator(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, lazy=False)
+        })
+        mapper(Address, addresses)
+
+        adalias = addresses.alias('adalias')
+        selectquery = users.outerjoin(adalias).select(use_labels=True)
+        def decorate(row):
+            d = {}
+            for c in addresses.columns:
+                d[c] = row[adalias.corresponding_column(c)]
+            return d
+            
+        q = create_session().query(User)
+
+        def go():
+            l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute())
+            self.assert_result(l, User, *user_address_result)
+        self.assert_sql_count(testbase.db, go, 1)
+    
+    def testmultiplemappers(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, lazy=False)
+        })
+        mapper(Address, addresses)
+
+        selectquery = users.outerjoin(addresses).select(use_labels=True)
+        q = create_session().query(User)
+        l = q.instances(selectquery.execute(), Address)
+        # note the result is a cartesian product
+        assert repr(l) == "[(User(user_id=7,user_name=u'jack'), Address(address_id=1,user_id=7,email_address=u'jack@bean.com')), (User(user_id=8,user_name=u'ed'), Address(address_id=2,user_id=8,email_address=u'ed@wood.com')), (User(user_id=8,user_name=u'ed'), Address(address_id=3,user_id=8,email_address=u'ed@bettyboop.com')), (User(user_id=8,user_name=u'ed'), Address(address_id=4,user_id=8,email_address=u'ed@lala.com'))]"
         
+        # check identity map still in effect even though dupe results
+        assert l[1][0] is l[2][0]
         
+    def testmapperspluscolumn(self):
+        mapper(User, users)
+        s = select([users, func.count(addresses.c.address_id).label('count')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c])
+        q = create_session().query(User)
+        l = q.instances(s.execute(), "count")
+        assert repr(l) == "[(User(user_id=7,user_name=u'jack'), 1), (User(user_id=8,user_name=u'ed'), 3), (User(user_id=9,user_name=u'fred'), 0)]"
 if __name__ == "__main__":    
     testbase.main()