]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Query refactoring is complete. just needs filter_by([args], **kwargs) feature and...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Jun 2007 20:12:33 +0000 (20:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Jun 2007 20:12:33 +0000 (20:12 +0000)
comply with the 0.4 spec

CHANGES
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/query.py
test/orm/generative.py

diff --git a/CHANGES b/CHANGES
index 536f7345472d09e8d7b3cfc1a7717433bb4524ec..615d150c9daf396cbd79b3aa22c1c8512678a2ec 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,5 +1,12 @@
 0.4.0
 - orm
+    - major interface pare-down for Query:  all selectXXX methods
+      are deprecated.  generative methods are now the standard
+      way to do things, i.e. filter(), filter_by(), all(), one(),
+      etc.  Deprecated methods are docstring'ed with their 
+      new replacements.
+    - query.list() replaced with query.all()
+    - removed ancient query.select_by_attributename() capability.
     - along with recent speedups to ResultProxy, total number of
       function calls significantly reduced for large loads.
       test/perf/masseagerload.py reports 0.4 as having the fewest number
index caa5e4412a3588455209c5ad07fac9623dba7cba..0255a922d21a7ddbcbde7c8c29abe05433163728 100644 (file)
@@ -394,7 +394,8 @@ class ExtensionOption(MapperOption):
         self.ext = ext
 
     def process_query(self, query):
-        query.extension.append(self.ext)
+        query._extension = query._extension.copy()
+        query._extension.append(self.ext)
 
 class SynonymProperty(MapperProperty):
     def __init__(self, name, proxy=False):
index f00ee4203bbcd7f2155ca685db3df88388f9a3e3..d4318049b42afa7b8a544329b3d7d1f3df04dc14 100644 (file)
@@ -7,7 +7,6 @@
 from sqlalchemy import sql, util, exceptions, sql_util, logging, schema
 from sqlalchemy.orm import mapper, class_mapper, object_mapper
 from sqlalchemy.orm.interfaces import OperationContext, SynonymProperty
-from sqlalchemy.orm.util import ExtensionCarrier
 
 __all__ = ['Query', 'QueryContext', 'SelectionContext']
 
@@ -19,54 +18,31 @@ class Query(object):
             self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name)
         else:
             self.mapper = class_or_mapper.compile()
-        self.with_options = []
         self.select_mapper = self.mapper.get_select_mapper().compile()
-        self.lockmode = None
-        self.extension = self.mapper.extension.copy()
+        
         self._session = session
             
+        self._with_options = []
+        self._lockmode = None
+        self._extension = self.mapper.extension.copy()
         self._entities = []
-
         self._order_by = False
         self._group_by = False
         self._distinct = False
         self._offset = None
         self._limit = None
-
         self._statement = None
         self._params = {}
         self._criterion = None
-        self._col = None
-        self._func = None
+        self._column_aggregate = None
         self._joinpoint = self.mapper
         self._from_obj = [self.table]
         self._populate_existing = False
         self._version_check = False
-
         
     def _clone(self):
         q = Query.__new__(Query)
-        q.mapper = self.mapper
-        q.select_mapper = self.select_mapper
-        q._order_by = self._order_by
-        q._distinct = self._distinct
-        q._entities = list(self._entities)
-        q.with_options = list(self.with_options)
-        q._session = self.session
-        q.lockmode = self.lockmode
-        q.extension = self.extension.copy()
-        q._offset = self._offset
-        q._limit = self._limit
-        q._group_by = self._group_by
-        q._from_obj = list(self._from_obj)
-        q._joinpoint = self._joinpoint
-        q._criterion = self._criterion
-        q._statement = self._statement
-        q._params = self._params.copy()
-        q._populate_existing = self._populate_existing
-        q._version_check = self._version_check
-        q._col = self._col
-        q._func = self._func
+        q.__dict__ = self.__dict__.copy()
         return q
     
     def _get_session(self):
@@ -88,7 +64,7 @@ class Query(object):
         columns.
         """
 
-        ret = self.extension.get(self, ident, **kwargs)
+        ret = self._extension.get(self, ident, **kwargs)
         if ret is not mapper.EXT_PASS:
             return ret
         key = self.mapper.identity_key(ident)
@@ -105,7 +81,7 @@ class Query(object):
         columns.
         """
 
-        ret = self.extension.load(self, ident, **kwargs)
+        ret = self._extension.load(self, ident, **kwargs)
         if ret is not mapper.EXT_PASS:
             return ret
         key = self.mapper.identity_key(ident)
@@ -207,7 +183,7 @@ class Query(object):
                 
         """
         q = self._clone()
-        q._entities.append(entity)
+        q._entities = q._entities + [entity]
         return q
         
     def add_column(self, column):
@@ -233,31 +209,32 @@ class Query(object):
         """
         
         q = self._clone()
-        q._entities.append(column)
+        q._entities = q._entities + [column]
         return q
         
     def options(self, *args):
         """Return a new Query object, applying the given list of
         MapperOptions.
         """
+        
         q = self._clone()
-        for opt in util.flatten_iterator(args):
-            q.with_options.append(opt)
+        opts = [o for o in util.flatten_iterator(args)]
+        q._with_options = q._with_options + opts
+        for opt in opts:
             opt.process_query(q)
-        for opt in util.flatten_iterator(self.with_options):
-            opt.process_query(self)
         return q
 
     def with_lockmode(self, mode):
         """Return a new Query object with the specified locking mode."""
         q = self._clone()
-        q.lockmode = mode
+        q._lockmode = mode
         return q
     
     def params(self, **kwargs):
         """add values for bind parameters which may have been specified in filter()."""
         
         q = self._clone()
+        q._params = q._params.copy()
         q._params.update(kwargs)
         return q
         
@@ -337,16 +314,14 @@ class Query(object):
             mapper = prop.mapper
         return (clause, mapper)
 
-
     def _generative_col_aggregate(self, col, func):
         """apply the given aggregate function to the query and return the newly
         resulting ``Query``.
         """
-        if self._col is not None or self._func is not None:
+        if self._column_aggregate is not None:
             raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
         q = self._clone()
-        q._col = col
-        q._func = func
+        q._column_aggregate = (col, func)
         return q
 
     def apply_min(self, col):
@@ -414,7 +389,7 @@ class Query(object):
         if q._order_by is False:    
             q._order_by = util.to_list(criterion)
         else:
-            q._order_by.extend(util.to_list(criterion))
+            q._order_by = q._order_by + util.to_list(criterion)
         return q
 
     def group_by(self, criterion):
@@ -424,7 +399,7 @@ class Query(object):
         if q._group_by is False:    
             q._group_by = util.to_list(criterion)
         else:
-            q._group_by.extend(util.to_list(criterion))
+            q._group_by = q._group_by + util.to_list(criterion)
         return q
 
     def join(self, prop):
@@ -469,20 +444,6 @@ class Query(object):
         new._from_obj = list(new._from_obj) + util.to_list(from_obj)
         return new
         
-    def __getattr__(self, key):
-        if (key.startswith('select_by_')):
-            key = key[10:]
-            def foo(arg):
-                return self.select_by(**{key:arg})
-            return foo
-        elif (key.startswith('get_by_')):
-            key = key[7:]
-            def foo(arg):
-                return self.get_by(**{key:arg})
-            return foo
-        else:
-            raise AttributeError(key)
-
     def __getitem__(self, item):
         if isinstance(item, slice):
             start = item.start
@@ -537,10 +498,6 @@ class Query(object):
         """
         return list(self)
         
-    def list(self):
-        """deprecated.  use all()"""
-
-        return list(self)
     
     def from_statement(self, statement):
         if isinstance(statement, basestring):
@@ -554,24 +511,25 @@ class Query(object):
 
         This results in an execution of the underlying query.
         """
-        if self._col is None or self._func is None: 
-            ret = list(self[0:1])
-            if len(ret) > 0:
-                return ret[0]
-            else:
-                return None
-        else:
-            return self._col_aggregate(self._col, self._func)
 
-    def scalar(self):
-        """deprecated.  use first()"""
-        return self.first()
+        if self._column_aggregate is not None: 
+            return self._col_aggregate(*self._column_aggregate)
+        
+        ret = list(self[0:1])
+        if len(ret) > 0:
+            return ret[0]
+        else:
+            return None
 
     def one(self):
         """Return the first result of this ``Query``, raising an exception if more than one row exists.
 
         This results in an execution of the underlying query.
         """
+
+        if self._column_aggregate is not None: 
+            return self._col_aggregate(*self._column_aggregate)
+
         ret = list(self[0:2])
         
         if len(ret) == 1:
@@ -621,7 +579,7 @@ class Query(object):
         kwargs.setdefault('populate_existing', self._populate_existing)
         kwargs.setdefault('version_check', self._version_check)
         
-        context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs)
+        context = SelectionContext(self.select_mapper, session, self._extension, with_options=self._with_options, **kwargs)
 
         process = []
         mappers_or_columns = tuple(self._entities) + mappers_or_columns
@@ -667,7 +625,7 @@ class Query(object):
 
 
     def _get(self, key, ident=None, reload=False, lockmode=None):
-        lockmode = lockmode or self.lockmode
+        lockmode = lockmode or self._lockmode
         if not reload and not self.mapper.always_refresh and lockmode is None:
             try:
                 return self.session._get(key)
@@ -716,13 +674,25 @@ class Query(object):
         and other generative methods to establish modifiers.
         """
 
-        if self._criterion:
-            if whereclause is not None:
-                whereclause = sql.and_(self._criterion, whereclause)
-            else:
-                whereclause = self._criterion
-        from_obj = kwargs.pop('from_obj', self._from_obj)
-        kwargs.setdefault('distinct', self._distinct)
+        q = self
+        if whereclause is not None:
+            q = q.filter(whereclause)
+        if params is not None:
+            q = q.params(**params)
+        q = q._legacy_select_kwargs(**kwargs)
+        return q._count()
+
+    def _count(self):
+        """Apply this query's criterion to a SELECT COUNT statement.
+        
+        this is the purely generative version which will become 
+        the public method in version 0.5.
+        """
+
+        whereclause = self._criterion
+
+        context = QueryContext(self)
+        from_obj = context.from_obj
 
         alltables = []
         for l in [sql_util.TableFinder(x) for x in from_obj]:
@@ -730,18 +700,13 @@ class Query(object):
 
         if self.table not in alltables:
             from_obj.append(self.table)
-        if self._nestable(**kwargs):
-            s = sql.select([self.table], whereclause, from_obj=from_obj, **kwargs).alias('getcount').count()
+        if self._nestable(**context.select_args()):
+            s = sql.select([self.table], whereclause, from_obj=from_obj, **context.select_args()).alias('getcount').count()
         else:
             primary_key = self.primary_key_columns
-            s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs)
-        if params is None:
-            params = {}
-        else:
-            params = params.copy()
-        params.update(self._params)
-        return self.session.scalar(self.mapper, s, params=params)
-
+            s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **context.select_args())
+        return self.session.scalar(self.mapper, s, params=self._params)
+        
     def compile(self):
         """compiles and returns a SQL statement based on the criterion and conditions within this Query."""
         
@@ -856,6 +821,16 @@ class Query(object):
 
     # DEPRECATED LAND !
 
+    def list(self):
+        """DEPRECATED.  use all()"""
+
+        return list(self)
+
+    def scalar(self):
+        """DEPRECATED.  use first()"""
+
+        return self.first()
+
     def _legacy_filter_by(self, *args, **kwargs):
         return self.filter(self._legacy_join_by(args, kwargs, start=self._joinpoint))
 
@@ -895,7 +870,7 @@ class Query(object):
     def get_by(self, *args, **params):
         """DEPRECATED.  use query.filter(*args).filter_by(**params).first()"""
 
-        ret = self.extension.get_by(self, *args, **params)
+        ret = self._extension.get_by(self, *args, **params)
         if ret is not mapper.EXT_PASS:
             return ret
 
@@ -904,7 +879,7 @@ class Query(object):
     def select_by(self, *args, **params):
         """DEPRECATED. use use query.filter(*args).filter_by(**params).all()."""
 
-        ret = self.extension.select_by(self, *args, **params)
+        ret = self._extension.select_by(self, *args, **params)
         if ret is not mapper.EXT_PASS:
             return ret
 
@@ -934,7 +909,7 @@ class Query(object):
     def select(self, arg=None, **kwargs):
         """DEPRECATED.  use query.filter(whereclause).all(), or query.from_statement(statement).all()"""
 
-        ret = self.extension.select(self, arg=arg, **kwargs)
+        ret = self._extension.select(self, arg=arg, **kwargs)
         if ret is not mapper.EXT_PASS:
             return ret
         return self._build_select(arg, **kwargs).all()
@@ -1066,13 +1041,13 @@ class QueryContext(OperationContext):
         self.order_by = query._order_by
         self.group_by = query._group_by
         self.from_obj = query._from_obj
-        self.lockmode = query.lockmode
+        self.lockmode = query._lockmode
         self.distinct = query._distinct
         self.limit = query._limit
         self.offset = query._offset
         self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders])
         self.statement = None
-        super(QueryContext, self).__init__(query.mapper, query.with_options)
+        super(QueryContext, self).__init__(query.mapper, query._with_options)
 
     def select_args(self):
         """Return a dictionary of attributes from this
index 8882a6f5c2dde480a17f793dd5a3564a1c4c792b..d4639281651f653c72f05d23a7ff93a4fa90b738 100644 (file)
@@ -60,7 +60,8 @@ class GenerativeQueryTest(PersistTest):
         assert self.query.count() == 100
         assert self.query.filter(foo.c.bar<30).min(foo.c.bar) == 0
         assert self.query.filter(foo.c.bar<30).max(foo.c.bar) == 29
-        assert self.query.filter(foo.c.bar<30).apply_max(foo.c.bar).scalar() == 29
+        assert self.query.filter(foo.c.bar<30).apply_max(foo.c.bar).first() == 29
+        assert self.query.filter(foo.c.bar<30).apply_max(foo.c.bar).one() == 29
 
     @testbase.unsupported('mysql')
     def test_aggregate_1(self):
@@ -77,7 +78,8 @@ class GenerativeQueryTest(PersistTest):
 
     @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql')
     def test_aggregate_3(self):
-        assert self.res.filter(foo.c.bar<30).apply_avg(foo.c.bar).scalar() == 14.5
+        assert self.res.filter(foo.c.bar<30).apply_avg(foo.c.bar).first() == 14.5
+        assert self.res.filter(foo.c.bar<30).apply_avg(foo.c.bar).one() == 14.5
         
     def test_filter(self):
         assert self.query.count() == 100