From: Mike Bayer Date: Tue, 5 Jun 2007 20:12:33 +0000 (+0000) Subject: Query refactoring is complete. just needs filter_by([args], **kwargs) feature and... X-Git-Tag: rel_0_4_6~218 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2290d9cd6dd5fceb5cc6480297d26b18dff2fac7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Query refactoring is complete. just needs filter_by([args], **kwargs) feature and it should then comply with the 0.4 spec --- diff --git a/CHANGES b/CHANGES index 536f734547..615d150c9d 100644 --- a/CHANGES +++ b/CHANGES @@ -1,5 +1,12 @@ 0.4.0 - orm + - major interface pare-down for Query: all selectXXX methods + are deprecated. generative methods are now the standard + way to do things, i.e. filter(), filter_by(), all(), one(), + etc. Deprecated methods are docstring'ed with their + new replacements. + - query.list() replaced with query.all() + - removed ancient query.select_by_attributename() capability. - along with recent speedups to ResultProxy, total number of function calls significantly reduced for large loads. test/perf/masseagerload.py reports 0.4 as having the fewest number diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index caa5e4412a..0255a922d2 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -394,7 +394,8 @@ class ExtensionOption(MapperOption): self.ext = ext def process_query(self, query): - query.extension.append(self.ext) + query._extension = query._extension.copy() + query._extension.append(self.ext) class SynonymProperty(MapperProperty): def __init__(self, name, proxy=False): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index f00ee4203b..d4318049b4 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -7,7 +7,6 @@ from sqlalchemy import sql, util, exceptions, sql_util, logging, schema from sqlalchemy.orm import mapper, class_mapper, object_mapper from sqlalchemy.orm.interfaces import OperationContext, SynonymProperty -from sqlalchemy.orm.util import ExtensionCarrier __all__ = ['Query', 'QueryContext', 'SelectionContext'] @@ -19,54 +18,31 @@ class Query(object): self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name) else: self.mapper = class_or_mapper.compile() - self.with_options = [] self.select_mapper = self.mapper.get_select_mapper().compile() - self.lockmode = None - self.extension = self.mapper.extension.copy() + self._session = session + self._with_options = [] + self._lockmode = None + self._extension = self.mapper.extension.copy() self._entities = [] - self._order_by = False self._group_by = False self._distinct = False self._offset = None self._limit = None - self._statement = None self._params = {} self._criterion = None - self._col = None - self._func = None + self._column_aggregate = None self._joinpoint = self.mapper self._from_obj = [self.table] self._populate_existing = False self._version_check = False - def _clone(self): q = Query.__new__(Query) - q.mapper = self.mapper - q.select_mapper = self.select_mapper - q._order_by = self._order_by - q._distinct = self._distinct - q._entities = list(self._entities) - q.with_options = list(self.with_options) - q._session = self.session - q.lockmode = self.lockmode - q.extension = self.extension.copy() - q._offset = self._offset - q._limit = self._limit - q._group_by = self._group_by - q._from_obj = list(self._from_obj) - q._joinpoint = self._joinpoint - q._criterion = self._criterion - q._statement = self._statement - q._params = self._params.copy() - q._populate_existing = self._populate_existing - q._version_check = self._version_check - q._col = self._col - q._func = self._func + q.__dict__ = self.__dict__.copy() return q def _get_session(self): @@ -88,7 +64,7 @@ class Query(object): columns. """ - ret = self.extension.get(self, ident, **kwargs) + ret = self._extension.get(self, ident, **kwargs) if ret is not mapper.EXT_PASS: return ret key = self.mapper.identity_key(ident) @@ -105,7 +81,7 @@ class Query(object): columns. """ - ret = self.extension.load(self, ident, **kwargs) + ret = self._extension.load(self, ident, **kwargs) if ret is not mapper.EXT_PASS: return ret key = self.mapper.identity_key(ident) @@ -207,7 +183,7 @@ class Query(object): """ q = self._clone() - q._entities.append(entity) + q._entities = q._entities + [entity] return q def add_column(self, column): @@ -233,31 +209,32 @@ class Query(object): """ q = self._clone() - q._entities.append(column) + q._entities = q._entities + [column] return q def options(self, *args): """Return a new Query object, applying the given list of MapperOptions. """ + q = self._clone() - for opt in util.flatten_iterator(args): - q.with_options.append(opt) + opts = [o for o in util.flatten_iterator(args)] + q._with_options = q._with_options + opts + for opt in opts: opt.process_query(q) - for opt in util.flatten_iterator(self.with_options): - opt.process_query(self) return q def with_lockmode(self, mode): """Return a new Query object with the specified locking mode.""" q = self._clone() - q.lockmode = mode + q._lockmode = mode return q def params(self, **kwargs): """add values for bind parameters which may have been specified in filter().""" q = self._clone() + q._params = q._params.copy() q._params.update(kwargs) return q @@ -337,16 +314,14 @@ class Query(object): mapper = prop.mapper return (clause, mapper) - def _generative_col_aggregate(self, col, func): """apply the given aggregate function to the query and return the newly resulting ``Query``. """ - if self._col is not None or self._func is not None: + if self._column_aggregate is not None: raise exceptions.InvalidRequestError("Query already contains an aggregate column or function") q = self._clone() - q._col = col - q._func = func + q._column_aggregate = (col, func) return q def apply_min(self, col): @@ -414,7 +389,7 @@ class Query(object): if q._order_by is False: q._order_by = util.to_list(criterion) else: - q._order_by.extend(util.to_list(criterion)) + q._order_by = q._order_by + util.to_list(criterion) return q def group_by(self, criterion): @@ -424,7 +399,7 @@ class Query(object): if q._group_by is False: q._group_by = util.to_list(criterion) else: - q._group_by.extend(util.to_list(criterion)) + q._group_by = q._group_by + util.to_list(criterion) return q def join(self, prop): @@ -469,20 +444,6 @@ class Query(object): new._from_obj = list(new._from_obj) + util.to_list(from_obj) return new - def __getattr__(self, key): - if (key.startswith('select_by_')): - key = key[10:] - def foo(arg): - return self.select_by(**{key:arg}) - return foo - elif (key.startswith('get_by_')): - key = key[7:] - def foo(arg): - return self.get_by(**{key:arg}) - return foo - else: - raise AttributeError(key) - def __getitem__(self, item): if isinstance(item, slice): start = item.start @@ -537,10 +498,6 @@ class Query(object): """ return list(self) - def list(self): - """deprecated. use all()""" - - return list(self) def from_statement(self, statement): if isinstance(statement, basestring): @@ -554,24 +511,25 @@ class Query(object): This results in an execution of the underlying query. """ - if self._col is None or self._func is None: - ret = list(self[0:1]) - if len(ret) > 0: - return ret[0] - else: - return None - else: - return self._col_aggregate(self._col, self._func) - def scalar(self): - """deprecated. use first()""" - return self.first() + if self._column_aggregate is not None: + return self._col_aggregate(*self._column_aggregate) + + ret = list(self[0:1]) + if len(ret) > 0: + return ret[0] + else: + return None def one(self): """Return the first result of this ``Query``, raising an exception if more than one row exists. This results in an execution of the underlying query. """ + + if self._column_aggregate is not None: + return self._col_aggregate(*self._column_aggregate) + ret = list(self[0:2]) if len(ret) == 1: @@ -621,7 +579,7 @@ class Query(object): kwargs.setdefault('populate_existing', self._populate_existing) kwargs.setdefault('version_check', self._version_check) - context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs) + context = SelectionContext(self.select_mapper, session, self._extension, with_options=self._with_options, **kwargs) process = [] mappers_or_columns = tuple(self._entities) + mappers_or_columns @@ -667,7 +625,7 @@ class Query(object): def _get(self, key, ident=None, reload=False, lockmode=None): - lockmode = lockmode or self.lockmode + lockmode = lockmode or self._lockmode if not reload and not self.mapper.always_refresh and lockmode is None: try: return self.session._get(key) @@ -716,13 +674,25 @@ class Query(object): and other generative methods to establish modifiers. """ - if self._criterion: - if whereclause is not None: - whereclause = sql.and_(self._criterion, whereclause) - else: - whereclause = self._criterion - from_obj = kwargs.pop('from_obj', self._from_obj) - kwargs.setdefault('distinct', self._distinct) + q = self + if whereclause is not None: + q = q.filter(whereclause) + if params is not None: + q = q.params(**params) + q = q._legacy_select_kwargs(**kwargs) + return q._count() + + def _count(self): + """Apply this query's criterion to a SELECT COUNT statement. + + this is the purely generative version which will become + the public method in version 0.5. + """ + + whereclause = self._criterion + + context = QueryContext(self) + from_obj = context.from_obj alltables = [] for l in [sql_util.TableFinder(x) for x in from_obj]: @@ -730,18 +700,13 @@ class Query(object): if self.table not in alltables: from_obj.append(self.table) - if self._nestable(**kwargs): - s = sql.select([self.table], whereclause, from_obj=from_obj, **kwargs).alias('getcount').count() + if self._nestable(**context.select_args()): + s = sql.select([self.table], whereclause, from_obj=from_obj, **context.select_args()).alias('getcount').count() else: primary_key = self.primary_key_columns - s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs) - if params is None: - params = {} - else: - params = params.copy() - params.update(self._params) - return self.session.scalar(self.mapper, s, params=params) - + s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **context.select_args()) + return self.session.scalar(self.mapper, s, params=self._params) + def compile(self): """compiles and returns a SQL statement based on the criterion and conditions within this Query.""" @@ -856,6 +821,16 @@ class Query(object): # DEPRECATED LAND ! + def list(self): + """DEPRECATED. use all()""" + + return list(self) + + def scalar(self): + """DEPRECATED. use first()""" + + return self.first() + def _legacy_filter_by(self, *args, **kwargs): return self.filter(self._legacy_join_by(args, kwargs, start=self._joinpoint)) @@ -895,7 +870,7 @@ class Query(object): def get_by(self, *args, **params): """DEPRECATED. use query.filter(*args).filter_by(**params).first()""" - ret = self.extension.get_by(self, *args, **params) + ret = self._extension.get_by(self, *args, **params) if ret is not mapper.EXT_PASS: return ret @@ -904,7 +879,7 @@ class Query(object): def select_by(self, *args, **params): """DEPRECATED. use use query.filter(*args).filter_by(**params).all().""" - ret = self.extension.select_by(self, *args, **params) + ret = self._extension.select_by(self, *args, **params) if ret is not mapper.EXT_PASS: return ret @@ -934,7 +909,7 @@ class Query(object): def select(self, arg=None, **kwargs): """DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()""" - ret = self.extension.select(self, arg=arg, **kwargs) + ret = self._extension.select(self, arg=arg, **kwargs) if ret is not mapper.EXT_PASS: return ret return self._build_select(arg, **kwargs).all() @@ -1066,13 +1041,13 @@ class QueryContext(OperationContext): self.order_by = query._order_by self.group_by = query._group_by self.from_obj = query._from_obj - self.lockmode = query.lockmode + self.lockmode = query._lockmode self.distinct = query._distinct self.limit = query._limit self.offset = query._offset self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders]) self.statement = None - super(QueryContext, self).__init__(query.mapper, query.with_options) + super(QueryContext, self).__init__(query.mapper, query._with_options) def select_args(self): """Return a dictionary of attributes from this diff --git a/test/orm/generative.py b/test/orm/generative.py index 8882a6f5c2..d463928165 100644 --- a/test/orm/generative.py +++ b/test/orm/generative.py @@ -60,7 +60,8 @@ class GenerativeQueryTest(PersistTest): assert self.query.count() == 100 assert self.query.filter(foo.c.bar<30).min(foo.c.bar) == 0 assert self.query.filter(foo.c.bar<30).max(foo.c.bar) == 29 - assert self.query.filter(foo.c.bar<30).apply_max(foo.c.bar).scalar() == 29 + assert self.query.filter(foo.c.bar<30).apply_max(foo.c.bar).first() == 29 + assert self.query.filter(foo.c.bar<30).apply_max(foo.c.bar).one() == 29 @testbase.unsupported('mysql') def test_aggregate_1(self): @@ -77,7 +78,8 @@ class GenerativeQueryTest(PersistTest): @testbase.unsupported('postgres', 'mysql', 'firebird', 'mssql') def test_aggregate_3(self): - assert self.res.filter(foo.c.bar<30).apply_avg(foo.c.bar).scalar() == 14.5 + assert self.res.filter(foo.c.bar<30).apply_avg(foo.c.bar).first() == 14.5 + assert self.res.filter(foo.c.bar<30).apply_avg(foo.c.bar).one() == 14.5 def test_filter(self): assert self.query.count() == 100