]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
some cleanup, some method privating, some pep8, fixed up _col_aggregate and merged
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Apr 2008 16:25:47 +0000 (16:25 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Apr 2008 16:25:47 +0000 (16:25 +0000)
its functionality with _count()

lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/util.py

index a2312e3633d7abf7e15786897873fa4a0c5b658c..1ba703e2e7ffd20e660065cc3d75600a78977c3c 100644 (file)
@@ -48,7 +48,7 @@ class Query(object):
         self._params = {}
         self._yield_per = None
         self._criterion = None
-        self._joinable_tables = None
+        self.__joinable_tables = None
         self._having = None
         self._column_aggregate = None
         self._populate_existing = False
@@ -60,9 +60,9 @@ class Query(object):
         self._only_load_props = None
         self._refresh_instance = None
         
-        self._init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name))
+        self.__init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name))
 
-    def _init_mapper(self, mapper):
+    def __init_mapper(self, mapper):
         """populate all instance variables derived from this Query's mapper."""
         
         self.mapper = mapper
@@ -74,27 +74,27 @@ class Query(object):
         self._joinpoint = self.mapper
         self._entities.append(_PrimaryMapperEntity(self.mapper))
         if self.mapper.with_polymorphic:
-            self._set_with_polymorphic(*self.mapper.with_polymorphic)
+            self.__set_with_polymorphic(*self.mapper.with_polymorphic)
         else:
             self._with_polymorphic = []
 
-    def _generate_alias_ids(self):
+    def __generate_alias_ids(self):
         self._alias_ids = dict([
             (k, list(v)) for k, v in self._alias_ids.iteritems()
         ])
 
-    def _no_criterion(self, meth):
-        return self._conditional_clone(meth, [self._no_criterion_condition])
+    def __no_criterion(self, meth):
+        return self.__conditional_clone(meth, [self.__no_criterion_condition])
 
-    def _no_statement(self, meth):
-        return self._conditional_clone(meth, [self._no_statement_condition])
+    def __no_statement(self, meth):
+        return self.__conditional_clone(meth, [self.__no_statement_condition])
 
-    def _reset_all(self, mapper, meth):
-        q = self._conditional_clone(meth, [self._no_criterion_condition])
-        q._init_mapper(mapper, mapper)
+    def __reset_all(self, mapper, meth):
+        q = self.__conditional_clone(meth, [self.__no_criterion_condition])
+        q.__init_mapper(mapper, mapper)
         return q
 
-    def _set_select_from(self, from_obj):
+    def __set_select_from(self, from_obj):
         if isinstance(from_obj, expression._SelectBaseMixin):
             # alias SELECTs and unions
             from_obj = from_obj.alias()
@@ -108,12 +108,12 @@ class Query(object):
         else:
             self._aliases_head = self._aliases_tail = None
 
-    def _set_with_polymorphic(self, cls_or_mappers, selectable=None):
+    def __set_with_polymorphic(self, cls_or_mappers, selectable=None):
         mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable)
         self._with_polymorphic = mappers
-        self._set_select_from(from_obj)
+        self.__set_select_from(from_obj)
 
-    def _no_criterion_condition(self, q, meth):
+    def __no_criterion_condition(self, q, meth):
         if q._criterion or q._statement:
             util.warn(
                 ("Query.%s() being called on a Query with existing criterion; "
@@ -125,10 +125,10 @@ class Query(object):
         q._aliases_tail = q._aliases_head
         q.table = q._from_obj = q.mapper.mapped_table
         if q.mapper.with_polymorphic:
-            q._set_with_polymorphic(*q.mapper.with_polymorphic)
+            q.__set_with_polymorphic(*q.mapper.with_polymorphic)
 
-    def _no_entities(self, meth):
-        q = self._no_statement(meth)
+    def __no_entities(self, meth):
+        q = self.__no_statement(meth)
         if len(q._entities) > 1 and not isinstance(q._entities[0], _PrimaryMapperEntity):
             raise exceptions.InvalidRequestError(
                 ("Query.%s() being called on a Query with existing  "
@@ -136,19 +136,30 @@ class Query(object):
         q._entities = []
         return q
 
-    def _no_statement_condition(self, q, meth):
+    def __no_statement_condition(self, q, meth):
         if q._statement:
             raise exceptions.InvalidRequestError(
                 ("Query.%s() being called on a Query with an existing full "
                  "statement - can't apply criterion.") % meth)
 
-    def _conditional_clone(self, methname=None, conditions=None):
+    def __conditional_clone(self, methname=None, conditions=None):
         q = self._clone()
         if conditions:
             for condition in conditions:
                 condition(q, methname)
         return q
-        
+
+    def __get_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_instance=None):
+        if populate_existing:
+            self._populate_existing = populate_existing
+        if version_check:
+            self._version_check = version_check
+        if refresh_instance:
+            self._refresh_instance = refresh_instance
+        if only_load_props:
+            self._only_load_props = util.Set(only_load_props)
+        return self
+
     def _clone(self):
         q = Query.__new__(Query)
         q.__dict__ = self.__dict__.copy()
@@ -172,6 +183,13 @@ class Query(object):
     whereclause = property(whereclause)
 
     def _with_current_path(self, path):
+        """indicate that this query applies to objects loaded within a certain path.
+        
+        Used by deferred loaders (see strategies.py) which transfer query 
+        options from an originating query to a newly generated query intended
+        for the deferred load.
+        
+        """
         q = self._clone()
         q._current_path = path
         return q
@@ -201,9 +219,9 @@ class Query(object):
         clause which will usually lead to incorrect results.
 
         """
-        q = self._no_criterion('with_polymorphic')
+        q = self.__no_criterion('with_polymorphic')
 
-        q._set_with_polymorphic(cls_or_mappers, selectable=selectable)
+        q.__set_with_polymorphic(cls_or_mappers, selectable=selectable)
 
         return q
     
@@ -290,6 +308,14 @@ class Query(object):
     query_from_parent = classmethod(query_from_parent)
 
     def autoflush(self, setting):
+        """Return a Query with a specific 'autoflush' setting.
+
+        Note that a Session with autoflush=False will
+        not autoflush, even if this flag is set to True at the 
+        Query level.  Therefore this flag is usually used only
+        to disable autoflush for a specific Query.
+        
+        """
         q = self._clone()
         q._autoflush = setting
         return q
@@ -303,9 +329,10 @@ class Query(object):
         All changes present on entities which are already present in the
         session will be reset and the entities will all be marked "clean".
 
-        This is essentially the en-masse version of load().
+        An alternative to populate_existing() is to expire the Session
+        fully using session.expire_all().
+        
         """
-
         q = self._clone()
         q._populate_existing = True
         return q
@@ -323,8 +350,8 @@ class Query(object):
 
         currently, this method only works with immediate parent relationships, but in the
         future may be enhanced to work across a chain of parent mappers.
-        """
 
+        """
         from sqlalchemy.orm import properties
         mapper = object_mapper(instance)
         if property is None:
@@ -360,6 +387,7 @@ class Query(object):
             id
                 a string ID matching that given to query.join() or query.outerjoin(); rows will be
                 selected from the aliased join created via those methods.
+
         """
         q = self._clone()
 
@@ -375,8 +403,8 @@ class Query(object):
         """return a Query that selects from this Query's SELECT statement.
         
         The API for this method hasn't been decided yet and is subject to change.
+
         """
-        
         q = self._clone()
         q._eager_loaders = util.Set()
         fromclause = q.compile()
@@ -386,9 +414,9 @@ class Query(object):
         """Turn this query into a 'columns only' query.
         
         The API for this method hasn't been decided yet and is subject to change.
-        """
 
-        q = self._no_entities('_values')
+        """
+        q = self.__no_entities('_values')
         q._only_load_props = q._eager_loaders = util.Set()
 
         for column in columns:
@@ -412,8 +440,8 @@ class Query(object):
 
         column
           a string column name or sql.ColumnElement to be added to the results.
-        """
 
+        """
         q = self._clone()
         q._entities = q._entities + [self._add_column(column, id)]
         return q
@@ -430,8 +458,8 @@ class Query(object):
     def options(self, *args):
         """Return a new Query object, applying the given list of
         MapperOptions.
-        """
 
+        """
         return self._options(False, *args)
 
     def _conditional_options(self, *args):
@@ -492,7 +520,7 @@ class Query(object):
         if self._aliases_tail:
             criterion = self._aliases_tail.adapt_clause(criterion)
 
-        q = self._no_statement("filter")
+        q = self.__no_statement("filter")
         if q._criterion is not None:
             q._criterion = q._criterion & criterion
         else:
@@ -507,22 +535,6 @@ class Query(object):
 
         return self.filter(sql.and_(*clauses))
 
-    def _col_aggregate(self, col, func):
-        """Execute ``func()`` function against the given column.
-
-        For performance, only use subselect if `order_by` attribute is set.
-
-        """
-        ops = {'distinct':self._distinct, 'order_by':self._order_by or None, 'from_obj':self._from_obj}
-
-        if self._autoflush and not self._populate_existing:
-            self.session._autoflush()
-
-        if self._order_by is not False:
-            s1 = sql.select([col], self._criterion, **ops).alias('u')
-            return self.session.execute(sql.select([func(s1.corresponding_column(col))]), mapper=self.mapper).scalar()
-        else:
-            return self.session.execute(sql.select([func(col)], self._criterion, **ops), mapper=self.mapper).scalar()
 
     def min(self, col):
         """Execute the SQL ``min()`` function against the given column."""
@@ -547,7 +559,7 @@ class Query(object):
     def order_by(self, *criterion):
         """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
 
-        q = self._no_statement("order_by")
+        q = self.__no_statement("order_by")
 
         if self._aliases_tail:
             criterion = [expression._literal_as_text(o) for o in criterion]
@@ -563,7 +575,7 @@ class Query(object):
     def group_by(self, *criterion):
         """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``"""
 
-        q = self._no_statement("group_by")
+        q = self.__no_statement("group_by")
         if q._group_by is False:
             q._group_by = criterion
         else:
@@ -583,7 +595,7 @@ class Query(object):
         if self._aliases_tail:
             criterion = self._aliases_tail.adapt_clause(criterion)
 
-        q = self._no_statement("having")
+        q = self.__no_statement("having")
         if q._having is not None:
             q._having = q._having & criterion
         else:
@@ -607,8 +619,8 @@ class Query(object):
             session.query(Company).join(['employees', 'tasks'])
             session.query(Houses).join([Colonials.rooms, Room.closets])
             session.query(Company).join([('employees', people.join(engineers)), Engineer.computers])
-        """
 
+        """
         return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint)
 
     def outerjoin(self, prop, id=None, aliased=False, from_joinpoint=False):
@@ -630,17 +642,16 @@ class Query(object):
             session.query(Company).join([('employees', people.join(engineers)), Engineer.computers])
 
         """
-
         return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint)
     
     def _join(self, prop, id, outerjoin, aliased, from_joinpoint):
         (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased)
         # TODO: improve the generative check here to look for primary mapped entity, etc.
-        q = self._no_statement("join")
+        q = self.__no_statement("join")
         q._from_obj = clause
         q._joinpoint = mapper
         q._aliases = aliases
-        q._generate_alias_ids()
+        q.__generate_alias_ids()
         
         if aliases:
             q._aliases_tail = aliases
@@ -660,16 +671,16 @@ class Query(object):
         return q
 
     def _get_joinable_tables(self):
-        if not self._joinable_tables or self._joinable_tables[0] is not self._from_obj:
+        if not self.__joinable_tables or self.__joinable_tables[0] is not self._from_obj:
             currenttables = [self._from_obj]
             def visit_join(join):
                 currenttables.append(join.left)
                 currenttables.append(join.right)
             visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False})
-            self._joinable_tables = (self._from_obj, currenttables)
+            self.__joinable_tables = (self._from_obj, currenttables)
             return currenttables
         else:
-            return self._joinable_tables[1]
+            return self.__joinable_tables[1]
 
     def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True):
         if start is None:
@@ -771,9 +782,9 @@ class Query(object):
 
         Note that each call to join() or outerjoin() also starts from
         the root.
-        """
 
-        q = self._no_statement("reset_joinpoint")
+        """
+        q = self.__no_statement("reset_joinpoint")
         q._joinpoint = q.mapper
         if q.table not in q._get_joinable_tables():
             q._aliases_head = q._aliases_tail = mapperutil.AliasedClauses(q._from_obj, equivalents=q.mapper._equivalent_columns)
@@ -788,14 +799,14 @@ class Query(object):
 
 
         `from_obj` is a single table or selectable.
-        """
 
-        new = self._no_criterion('select_from')
+        """
+        new = self.__no_criterion('select_from')
         if isinstance(from_obj, (tuple, list)):
             util.warn_deprecated("select_from() now accepts a single Selectable as its argument, which replaces any existing FROM criterion.")
             from_obj = from_obj[-1]
 
-        new._set_select_from(from_obj)
+        new.__set_select_from(from_obj)
         return new
     
     def __getitem__(self, item):
@@ -824,24 +835,25 @@ class Query(object):
 
     def limit(self, limit):
         """Apply a ``LIMIT`` to the query and return the newly resulting
+
         ``Query``.
-        """
 
+        """
         return self[:limit]
 
     def offset(self, offset):
         """Apply an ``OFFSET`` to the query and return the newly resulting
         ``Query``.
-        """
 
+        """
         return self[offset:]
 
     def distinct(self):
         """Apply a ``DISTINCT`` to the query and return the newly resulting
         ``Query``.
-        """
 
-        new = self._no_statement("distinct")
+        """
+        new = self.__no_statement("distinct")
         new._distinct = True
         return new
 
@@ -849,6 +861,7 @@ class Query(object):
         """Return the results represented by this ``Query`` as a list.
 
         This results in an execution of the underlying query.
+
         """
         return list(self)
 
@@ -866,10 +879,9 @@ class Query(object):
         Also see the ``instances()`` method.
 
         """
-
         if isinstance(statement, basestring):
             statement = sql.text(statement)
-        q = self._no_criterion('from_statement')
+        q = self.__no_criterion('from_statement')
         q._statement = statement
         return q
 
@@ -877,8 +889,8 @@ class Query(object):
         """Return the first result of this ``Query`` or None if the result doesn't contain any row.
 
         This results in an execution of the underlying query.
-        """
 
+        """
         if self._column_aggregate is not None:
             return self._col_aggregate(*self._column_aggregate)
 
@@ -892,8 +904,8 @@ class Query(object):
         """Return the first result of this ``Query``, raising an exception if more than one row exists.
 
         This results in an execution of the underlying query.
-        """
 
+        """
         if self._column_aggregate is not None:
             return self._col_aggregate(*self._column_aggregate)
 
@@ -989,10 +1001,10 @@ class Query(object):
         
         # dont use 'polymorphic' mapper if we are refreshing an instance
         if refresh_instance and q.mapper is not q.mapper:
-            q = q._reset_all(q.mapper, '_get')
+            q = q.__reset_all(q.mapper, '_get')
 
         if ident is not None:
-            q = q._no_criterion('get')
+            q = q.__no_criterion('get')
             params = {}
             (_get_clause, _get_params) = q.mapper._get_clause
             q = q.filter(_get_clause)
@@ -1005,7 +1017,7 @@ class Query(object):
 
         if lockmode is not None:
             q = q.with_lockmode(lockmode)
-        q = q._select_context_options(populate_existing=bool(refresh_instance), version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance)
+        q = q.__get_options(populate_existing=bool(refresh_instance), version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance)
         q._order_by = None
         try:
             # call using all() to avoid LIMIT compilation complexity
@@ -1027,8 +1039,8 @@ class Query(object):
 
         the whereclause, params and \**kwargs arguments are deprecated.  use filter()
         and other generative methods to establish modifiers.
-        """
 
+        """
         q = self
         if whereclause is not None:
             q = q.filter(whereclause)
@@ -1042,26 +1054,34 @@ class Query(object):
 
         this is the purely generative version which will become
         the public method in version 0.5.
+
         """
+        return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self.mapper.primary_key))
 
+    def _col_aggregate(self, col, func, nested_cols=None):
         whereclause = self._criterion
-
+        
         context = QueryContext(self)
         from_obj = self._from_obj
 
         if self._should_nest_selectable:
-            s = sql.select([self.table], whereclause, from_obj=from_obj, **self._select_args).alias('getcount').count()
+            if not nested_cols:
+                nested_cols = [col]
+            s = sql.select(nested_cols, whereclause, from_obj=from_obj, **self._select_args)
+            s = s.alias()
+            s = sql.select([func(s.corresponding_column(col) or col)]).select_from(s)
         else:
-            primary_key = self.mapper.primary_key
-            s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **self._select_args)
+            s = sql.select([func(col)], whereclause, from_obj=from_obj, **self._select_args)
+            
         if self._autoflush and not self._populate_existing:
             self.session._autoflush()
         return self.session.scalar(s, params=self._params, mapper=self.mapper)
 
     def compile(self):
         """compiles and returns a SQL statement based on the criterion and conditions within this Query."""
+
         return self._compile_context().statement
-    
+
     def _compile_context(self):
 
         context = QueryContext(self)
@@ -1176,7 +1196,7 @@ class Query(object):
         """
         if self._column_aggregate is not None:
             raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
-        q = self._no_statement("aggregate")
+        q = self.__no_statement("aggregate")
         q._column_aggregate = (col, func)
         return q
 
@@ -1331,20 +1351,9 @@ class Query(object):
         q = self.from_statement(statement)
         if params is not None:
             q = q.params(params)
-        q._select_context_options(**kwargs)
+        q.__get_options(**kwargs)
         return list(q)
 
-    def _select_context_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_instance=None): #pragma: no cover
-        if populate_existing:
-            self._populate_existing = populate_existing
-        if version_check:
-            self._version_check = version_check
-        if refresh_instance:
-            self._refresh_instance = refresh_instance
-        if only_load_props:
-            self._only_load_props = util.Set(only_load_props)
-        return self
-
     def join_to(self, key): #pragma: no cover
         """DEPRECATED. use join() to create joins based on property names."""
 
@@ -1603,11 +1612,11 @@ class QueryContext(object):
         self.options = query._with_options
         self.attributes = query._attributes.copy()
 
-    def exec_with_path(self, mapper, propkey, func, *args, **kwargs):
+    def exec_with_path(self, mapper, propkey, fn, *args, **kwargs):
         oldpath = self.path
         self.path += (mapper.base_mapper, propkey)
         try:
-            return func(*args, **kwargs)
+            return fn(*args, **kwargs)
         finally:
             self.path = oldpath
 
index 6975f10f8542d6064bd345411cf4440c43e3d80e..97ed2a1923b5af55c9392e44f0c9b8baf5237184 100644 (file)
@@ -194,22 +194,18 @@ class PropertyAliasedClauses(AliasedClauses):
         
         if prop.secondary:
             self.secondary = prop.secondary.alias()
+            primary_aliasizer = sql_util.ClauseAdapter(self.secondary)
+            secondary_aliasizer = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).chain(sql_util.ClauseAdapter(self.secondary))
+
             if parentclauses is not None:
-                primary_aliasizer = sql_util.ClauseAdapter(self.secondary).chain(sql_util.ClauseAdapter(parentclauses.alias, equivalents=parentclauses.equivalents))
-                secondary_aliasizer = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).chain(sql_util.ClauseAdapter(self.secondary))
+                primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, equivalents=parentclauses.equivalents))
 
-            else:
-                primary_aliasizer = sql_util.ClauseAdapter(self.secondary)
-                secondary_aliasizer = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).chain(sql_util.ClauseAdapter(self.secondary))
-                
             self.secondaryjoin = secondary_aliasizer.traverse(secondaryjoin, clone=True)
             self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True)
         else:
+            primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side, equivalents=self.equivalents)
             if parentclauses is not None: 
-                primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side, equivalents=self.equivalents)
                 primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side, equivalents=parentclauses.equivalents))
-            else:
-                primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side, equivalents=self.equivalents)
             
             self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True)
             self.secondary = None