]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
more paring down...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Jun 2007 00:50:22 +0000 (00:50 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Jun 2007 00:50:22 +0000 (00:50 +0000)
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py

index c894f3767fd3c0c06fb560fec36976af67896e8c..f00ee4203bbcd7f2155ca685db3df88388f9a3e3 100644 (file)
@@ -14,31 +14,25 @@ __all__ = ['Query', 'QueryContext', 'SelectionContext']
 class Query(object):
     """Encapsulates the object-fetching operations provided by Mappers."""
 
-    def __init__(self, class_or_mapper, session=None, entity_name=None, lockmode=None, with_options=None, extension=None, **kwargs):
+    def __init__(self, class_or_mapper, session=None, entity_name=None):
         if isinstance(class_or_mapper, type):
             self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name)
         else:
             self.mapper = class_or_mapper.compile()
-        self.with_options = with_options or []
+        self.with_options = []
         self.select_mapper = self.mapper.get_select_mapper().compile()
-        self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh)
-        self.lockmode = lockmode
-        self.extension = ExtensionCarrier()
-        if extension is not None:
-            self.extension.append(extension)
-        self.extension.append(self.mapper.extension)
-        self.is_polymorphic = self.mapper is not self.select_mapper
+        self.lockmode = None
+        self.extension = self.mapper.extension.copy()
         self._session = session
             
         self._entities = []
 
-        self._get_clause = self.select_mapper._get_clause
+        self._order_by = False
+        self._group_by = False
+        self._distinct = False
+        self._offset = None
+        self._limit = None
 
-        self._order_by = kwargs.pop('order_by', False)
-        self._group_by = kwargs.pop('group_by', False)
-        self._distinct = kwargs.pop('distinct', False)
-        self._offset = kwargs.pop('offset', None)
-        self._limit = kwargs.pop('limit', None)
         self._statement = None
         self._params = {}
         self._criterion = None
@@ -49,8 +43,6 @@ class Query(object):
         self._populate_existing = False
         self._version_check = False
 
-        for opt in util.flatten_iterator(self.with_options):
-            opt.process_query(self)
         
     def _clone(self):
         q = Query.__new__(Query)
@@ -59,16 +51,13 @@ class Query(object):
         q._order_by = self._order_by
         q._distinct = self._distinct
         q._entities = list(self._entities)
-        q.always_refresh = self.always_refresh
         q.with_options = list(self.with_options)
         q._session = self.session
-        q.is_polymorphic = self.is_polymorphic
         q.lockmode = self.lockmode
         q.extension = self.extension.copy()
         q._offset = self._offset
         q._limit = self._limit
         q._group_by = self._group_by
-        q._get_clause = self._get_clause
         q._from_obj = list(self._from_obj)
         q._joinpoint = self._joinpoint
         q._criterion = self._criterion
@@ -125,38 +114,6 @@ class Query(object):
             raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
         return instance
 
-    def count(self, whereclause=None, params=None, **kwargs):
-        """Apply this query's criterion to a SELECT COUNT statement.
-        
-        the whereclause, params and **kwargs arguments are deprecated.  use filter()
-        and other generative methods to establish modifiers.
-        """
-        
-        if self._criterion:
-            if whereclause is not None:
-                whereclause = sql.and_(self._criterion, whereclause)
-            else:
-                whereclause = self._criterion
-        from_obj = kwargs.pop('from_obj', self._from_obj)
-        kwargs.setdefault('distinct', self._distinct)
-
-        alltables = []
-        for l in [sql_util.TableFinder(x) for x in from_obj]:
-            alltables += l
-        
-        if self.table not in alltables:
-            from_obj.append(self.table)
-        if self._nestable(**kwargs):
-            s = sql.select([self.table], whereclause, from_obj=from_obj, **kwargs).alias('getcount').count()
-        else:
-            primary_key = self.primary_key_columns
-            s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs)
-        if params is None:
-            params = {}
-        else:
-            params = params.copy()
-        params.update(self._params)
-        return self.session.scalar(self.mapper, s, params=params)
 
     def _with_lazy_criterion(cls, instance, prop, reverse=False):
         """extract query criterion from a LazyLoader strategy given a Mapper, 
@@ -279,7 +236,7 @@ class Query(object):
         q._entities.append(column)
         return q
         
-    def options(self, *args, **kwargs):
+    def options(self, *args):
         """Return a new Query object, applying the given list of
         MapperOptions.
         """
@@ -711,7 +668,7 @@ class Query(object):
 
     def _get(self, key, ident=None, reload=False, lockmode=None):
         lockmode = lockmode or self.lockmode
-        if not reload and not self.always_refresh and lockmode is None:
+        if not reload and not self.mapper.always_refresh and lockmode is None:
             try:
                 return self.session._get(key)
             except KeyError:
@@ -728,7 +685,7 @@ class Query(object):
             q = self
             if lockmode is not None:
                 q = q.with_lockmode(lockmode)
-            q = q.filter(self._get_clause)
+            q = q.filter(self.select_mapper._get_clause)
             q = q.params(**params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None))
             return q.first()
         except IndexError:
@@ -752,6 +709,39 @@ class Query(object):
 
         return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False))
 
+    def count(self, whereclause=None, params=None, **kwargs):
+        """Apply this query's criterion to a SELECT COUNT statement.
+
+        the whereclause, params and **kwargs arguments are deprecated.  use filter()
+        and other generative methods to establish modifiers.
+        """
+
+        if self._criterion:
+            if whereclause is not None:
+                whereclause = sql.and_(self._criterion, whereclause)
+            else:
+                whereclause = self._criterion
+        from_obj = kwargs.pop('from_obj', self._from_obj)
+        kwargs.setdefault('distinct', self._distinct)
+
+        alltables = []
+        for l in [sql_util.TableFinder(x) for x in from_obj]:
+            alltables += l
+
+        if self.table not in alltables:
+            from_obj.append(self.table)
+        if self._nestable(**kwargs):
+            s = sql.select([self.table], whereclause, from_obj=from_obj, **kwargs).alias('getcount').count()
+        else:
+            primary_key = self.primary_key_columns
+            s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs)
+        if params is None:
+            params = {}
+        else:
+            params = params.copy()
+        params.update(self._params)
+        return self.session.scalar(self.mapper, s, params=params)
+
     def compile(self):
         """compiles and returns a SQL statement based on the criterion and conditions within this Query."""
         
@@ -761,7 +751,7 @@ class Query(object):
         
         whereclause = self._criterion
 
-        if whereclause is not None and self.is_polymorphic:
+        if whereclause is not None and (self.mapper is not self.select_mapper):
             # adapt the given WHERECLAUSE to adjust instances of this query's mapped 
             # table to be that of our select_table,
             # which may be the "polymorphic" selectable used by our mapper.
index a3184d62b5170de212d0b81a7ed5be74b3793089..14178c1d85590ae4be9ea94c091ce9bbf0b610db 100644 (file)
@@ -238,7 +238,7 @@ class LazyLoader(AbstractRelationLoader):
         # determine if our "lazywhere" clause is the same as the mapper's
         # get() clause.  then we can just use mapper.get()
         from sqlalchemy.orm import query
-        self.use_get = not self.uselist and query.Query(self.mapper)._get_clause.compare(self.lazywhere)
+        self.use_get = not self.uselist and self.mapper._get_clause.compare(self.lazywhere)
         if self.use_get:
             self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads")
 
@@ -283,27 +283,31 @@ class LazyLoader(AbstractRelationLoader):
 
             # if we have a simple straight-primary key load, use mapper.get()
             # to possibly save a DB round trip
+            q = session.query(self.mapper)
             if self.use_get:
                 ident = []
+                # TODO: when options are added to allow switching between union-based and non-union
+                # based polymorphic loads on a per-query basis, this code needs to switch between "mapper" and "select_mapper",
+                # probably via the query's own "mapper" property, and also use one of two "lazy" clauses,
+                # one against the "union" the other not
                 for primary_key in self.select_mapper.pks_by_table[self.select_mapper.mapped_table]:
                     bind = self.lazyreverse[primary_key]
                     ident.append(params[bind.key])
-                return session.query(self.mapper).get(ident)
+                return q.get(ident)
             elif self.order_by is not False:
-                order_by = self.order_by
+                q = q.order_by(self.order_by)
             elif self.secondary is not None and self.secondary.default_order_by() is not None:
-                order_by = self.secondary.default_order_by()
-            else:
-                order_by = False
-            result = session.query(self.mapper, with_options=options).select_whereclause(self.lazywhere, order_by=order_by, params=params)
+                q = q.order_by(self.secondary.default_order_by())
+
+            if options:
+                q = q.options(*options)
+            q = q.filter(self.lazywhere).params(**params)
 
             if self.uselist:
-                return result
+                return q.all()
             else:
-                if len(result):
-                    return result[0]
-                else:
-                    return None
+                return q.first()
+
         return lazyload