from sqlalchemy import sql, util, exceptions, sql_util, logging, schema
from sqlalchemy.orm import mapper, class_mapper, object_mapper
from sqlalchemy.orm.interfaces import OperationContext, SynonymProperty
-from sqlalchemy.orm.util import ExtensionCarrier
__all__ = ['Query', 'QueryContext', 'SelectionContext']
self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name)
else:
self.mapper = class_or_mapper.compile()
- self.with_options = []
self.select_mapper = self.mapper.get_select_mapper().compile()
- self.lockmode = None
- self.extension = self.mapper.extension.copy()
+
self._session = session
+ self._with_options = []
+ self._lockmode = None
+ self._extension = self.mapper.extension.copy()
self._entities = []
-
self._order_by = False
self._group_by = False
self._distinct = False
self._offset = None
self._limit = None
-
self._statement = None
self._params = {}
self._criterion = None
- self._col = None
- self._func = None
+ self._column_aggregate = None
self._joinpoint = self.mapper
self._from_obj = [self.table]
self._populate_existing = False
self._version_check = False
-
def _clone(self):
q = Query.__new__(Query)
- q.mapper = self.mapper
- q.select_mapper = self.select_mapper
- q._order_by = self._order_by
- q._distinct = self._distinct
- q._entities = list(self._entities)
- q.with_options = list(self.with_options)
- q._session = self.session
- q.lockmode = self.lockmode
- q.extension = self.extension.copy()
- q._offset = self._offset
- q._limit = self._limit
- q._group_by = self._group_by
- q._from_obj = list(self._from_obj)
- q._joinpoint = self._joinpoint
- q._criterion = self._criterion
- q._statement = self._statement
- q._params = self._params.copy()
- q._populate_existing = self._populate_existing
- q._version_check = self._version_check
- q._col = self._col
- q._func = self._func
+ q.__dict__ = self.__dict__.copy()
return q
def _get_session(self):
columns.
"""
- ret = self.extension.get(self, ident, **kwargs)
+ ret = self._extension.get(self, ident, **kwargs)
if ret is not mapper.EXT_PASS:
return ret
key = self.mapper.identity_key(ident)
columns.
"""
- ret = self.extension.load(self, ident, **kwargs)
+ ret = self._extension.load(self, ident, **kwargs)
if ret is not mapper.EXT_PASS:
return ret
key = self.mapper.identity_key(ident)
"""
q = self._clone()
- q._entities.append(entity)
+ q._entities = q._entities + [entity]
return q
def add_column(self, column):
"""
q = self._clone()
- q._entities.append(column)
+ q._entities = q._entities + [column]
return q
def options(self, *args):
"""Return a new Query object, applying the given list of
MapperOptions.
"""
+
q = self._clone()
- for opt in util.flatten_iterator(args):
- q.with_options.append(opt)
+ opts = [o for o in util.flatten_iterator(args)]
+ q._with_options = q._with_options + opts
+ for opt in opts:
opt.process_query(q)
- for opt in util.flatten_iterator(self.with_options):
- opt.process_query(self)
return q
def with_lockmode(self, mode):
"""Return a new Query object with the specified locking mode."""
q = self._clone()
- q.lockmode = mode
+ q._lockmode = mode
return q
def params(self, **kwargs):
"""add values for bind parameters which may have been specified in filter()."""
q = self._clone()
+ q._params = q._params.copy()
q._params.update(kwargs)
return q
mapper = prop.mapper
return (clause, mapper)
-
def _generative_col_aggregate(self, col, func):
"""apply the given aggregate function to the query and return the newly
resulting ``Query``.
"""
- if self._col is not None or self._func is not None:
+ if self._column_aggregate is not None:
raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
q = self._clone()
- q._col = col
- q._func = func
+ q._column_aggregate = (col, func)
return q
def apply_min(self, col):
if q._order_by is False:
q._order_by = util.to_list(criterion)
else:
- q._order_by.extend(util.to_list(criterion))
+ q._order_by = q._order_by + util.to_list(criterion)
return q
def group_by(self, criterion):
if q._group_by is False:
q._group_by = util.to_list(criterion)
else:
- q._group_by.extend(util.to_list(criterion))
+ q._group_by = q._group_by + util.to_list(criterion)
return q
def join(self, prop):
new._from_obj = list(new._from_obj) + util.to_list(from_obj)
return new
- def __getattr__(self, key):
- if (key.startswith('select_by_')):
- key = key[10:]
- def foo(arg):
- return self.select_by(**{key:arg})
- return foo
- elif (key.startswith('get_by_')):
- key = key[7:]
- def foo(arg):
- return self.get_by(**{key:arg})
- return foo
- else:
- raise AttributeError(key)
-
def __getitem__(self, item):
if isinstance(item, slice):
start = item.start
"""
return list(self)
- def list(self):
- """deprecated. use all()"""
-
- return list(self)
def from_statement(self, statement):
if isinstance(statement, basestring):
This results in an execution of the underlying query.
"""
- if self._col is None or self._func is None:
- ret = list(self[0:1])
- if len(ret) > 0:
- return ret[0]
- else:
- return None
- else:
- return self._col_aggregate(self._col, self._func)
- def scalar(self):
- """deprecated. use first()"""
- return self.first()
+ if self._column_aggregate is not None:
+ return self._col_aggregate(*self._column_aggregate)
+
+ ret = list(self[0:1])
+ if len(ret) > 0:
+ return ret[0]
+ else:
+ return None
def one(self):
"""Return the first result of this ``Query``, raising an exception if more than one row exists.
This results in an execution of the underlying query.
"""
+
+ if self._column_aggregate is not None:
+ return self._col_aggregate(*self._column_aggregate)
+
ret = list(self[0:2])
if len(ret) == 1:
kwargs.setdefault('populate_existing', self._populate_existing)
kwargs.setdefault('version_check', self._version_check)
- context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs)
+ context = SelectionContext(self.select_mapper, session, self._extension, with_options=self._with_options, **kwargs)
process = []
mappers_or_columns = tuple(self._entities) + mappers_or_columns
def _get(self, key, ident=None, reload=False, lockmode=None):
- lockmode = lockmode or self.lockmode
+ lockmode = lockmode or self._lockmode
if not reload and not self.mapper.always_refresh and lockmode is None:
try:
return self.session._get(key)
and other generative methods to establish modifiers.
"""
- if self._criterion:
- if whereclause is not None:
- whereclause = sql.and_(self._criterion, whereclause)
- else:
- whereclause = self._criterion
- from_obj = kwargs.pop('from_obj', self._from_obj)
- kwargs.setdefault('distinct', self._distinct)
+ q = self
+ if whereclause is not None:
+ q = q.filter(whereclause)
+ if params is not None:
+ q = q.params(**params)
+ q = q._legacy_select_kwargs(**kwargs)
+ return q._count()
+
+ def _count(self):
+ """Apply this query's criterion to a SELECT COUNT statement.
+
+ this is the purely generative version which will become
+ the public method in version 0.5.
+ """
+
+ whereclause = self._criterion
+
+ context = QueryContext(self)
+ from_obj = context.from_obj
alltables = []
for l in [sql_util.TableFinder(x) for x in from_obj]:
if self.table not in alltables:
from_obj.append(self.table)
- if self._nestable(**kwargs):
- s = sql.select([self.table], whereclause, from_obj=from_obj, **kwargs).alias('getcount').count()
+ if self._nestable(**context.select_args()):
+ s = sql.select([self.table], whereclause, from_obj=from_obj, **context.select_args()).alias('getcount').count()
else:
primary_key = self.primary_key_columns
- s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs)
- if params is None:
- params = {}
- else:
- params = params.copy()
- params.update(self._params)
- return self.session.scalar(self.mapper, s, params=params)
-
+ s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **context.select_args())
+ return self.session.scalar(self.mapper, s, params=self._params)
+
def compile(self):
"""compiles and returns a SQL statement based on the criterion and conditions within this Query."""
# DEPRECATED LAND !
+ def list(self):
+ """DEPRECATED. use all()"""
+
+ return list(self)
+
+ def scalar(self):
+ """DEPRECATED. use first()"""
+
+ return self.first()
+
def _legacy_filter_by(self, *args, **kwargs):
return self.filter(self._legacy_join_by(args, kwargs, start=self._joinpoint))
def get_by(self, *args, **params):
"""DEPRECATED. use query.filter(*args).filter_by(**params).first()"""
- ret = self.extension.get_by(self, *args, **params)
+ ret = self._extension.get_by(self, *args, **params)
if ret is not mapper.EXT_PASS:
return ret
def select_by(self, *args, **params):
"""DEPRECATED. use use query.filter(*args).filter_by(**params).all()."""
- ret = self.extension.select_by(self, *args, **params)
+ ret = self._extension.select_by(self, *args, **params)
if ret is not mapper.EXT_PASS:
return ret
def select(self, arg=None, **kwargs):
"""DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()"""
- ret = self.extension.select(self, arg=arg, **kwargs)
+ ret = self._extension.select(self, arg=arg, **kwargs)
if ret is not mapper.EXT_PASS:
return ret
return self._build_select(arg, **kwargs).all()
self.order_by = query._order_by
self.group_by = query._group_by
self.from_obj = query._from_obj
- self.lockmode = query.lockmode
+ self.lockmode = query._lockmode
self.distinct = query._distinct
self.limit = query._limit
self.offset = query._offset
self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders])
self.statement = None
- super(QueryContext, self).__init__(query.mapper, query.with_options)
+ super(QueryContext, self).__init__(query.mapper, query._with_options)
def select_args(self):
"""Return a dictionary of attributes from this