From: Mike Bayer Date: Mon, 4 Jun 2007 23:50:22 +0000 (+0000) Subject: refactoring step 2. all deprecated functions now express their functionality X-Git-Tag: rel_0_4_6~222 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6abe864cebab652048e58cb24edd736e17922b9b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git refactoring step 2. all deprecated functions now express their functionality in terms of generative behavior. also the thing will run like crap right now until the next refactor stage...stay tuned --- diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 31ba414d4a..c894f3767f 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -39,11 +39,15 @@ class Query(object): self._distinct = kwargs.pop('distinct', False) self._offset = kwargs.pop('offset', None) self._limit = kwargs.pop('limit', None) + self._statement = None + self._params = {} self._criterion = None self._col = None self._func = None self._joinpoint = self.mapper self._from_obj = [self.table] + self._populate_existing = False + self._version_check = False for opt in util.flatten_iterator(self.with_options): opt.process_query(self) @@ -68,6 +72,10 @@ class Query(object): q._from_obj = list(self._from_obj) q._joinpoint = self._joinpoint q._criterion = self._criterion + q._statement = self._statement + q._params = self._params.copy() + q._populate_existing = self._populate_existing + q._version_check = self._version_check q._col = self._col q._func = self._func return q @@ -143,6 +151,11 @@ class Query(object): else: primary_key = self.primary_key_columns s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs) + if params is None: + params = {} + else: + params = params.copy() + params.update(self._params) return self.session.scalar(self.mapper, s, params=params) def _with_lazy_criterion(cls, instance, prop, reverse=False): @@ -284,6 +297,13 @@ class Query(object): q.lockmode = mode return q + def params(self, **kwargs): + """add values for bind parameters which may have been specified in filter().""" + + q = self._clone() + q._params.update(kwargs) + return q + def filter(self, criterion): """apply the given filtering criterion to the query and return the newly resulting ``Query`` @@ -565,13 +585,24 @@ class Query(object): return list(self) + def from_statement(self, statement): + if isinstance(statement, basestring): + statement = sql.text(statement) + q = self._clone() + q._statement = statement + return q + def first(self): """Return the first result of this ``Query``. This results in an execution of the underlying query. """ if self._col is None or self._func is None: - return self[0] + ret = list(self[0:1]) + if len(ret) > 0: + return ret[0] + else: + return None else: return self._col_aggregate(self._col, self._func) @@ -594,7 +625,13 @@ class Query(object): raise exceptions.InvalidRequestError('Multiple rows returned for one()') def __iter__(self): - return iter(self.select_whereclause()) + statement = self.compile() + statement.use_labels = True + result = self.session.execute(self.mapper, statement, params=self._params) + try: + return iter(self.instances(result)) + finally: + result.close() def instances(self, cursor, *mappers_or_columns, **kwargs): @@ -624,6 +661,9 @@ class Query(object): session = self.session + kwargs.setdefault('populate_existing', self._populate_existing) + kwargs.setdefault('version_check', self._version_check) + context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs) process = [] @@ -685,8 +725,12 @@ class Query(object): for i, primary_key in enumerate(self.primary_key_columns): params[primary_key._label] = ident[i] try: - statement = self.compile(self._get_clause, lockmode=lockmode) - return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0] + q = self + if lockmode is not None: + q = q.with_lockmode(lockmode) + q = q.filter(self._get_clause) + q = q.params(**params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None)) + return q.first() except IndexError: return None @@ -708,13 +752,14 @@ class Query(object): return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False)) - def compile(self, whereclause = None, **kwargs): - """Given a WHERE criterion, produce a ClauseElement-based - statement suitable for usage in the execute() method. - """ - - if self._criterion: - whereclause = sql.and_(self._criterion, whereclause) + def compile(self): + """compiles and returns a SQL statement based on the criterion and conditions within this Query.""" + + if self._statement: + self._statement.use_labels = True + return self._statement + + whereclause = self._criterion if whereclause is not None and self.is_polymorphic: # adapt the given WHERECLAUSE to adjust instances of this query's mapped @@ -732,9 +777,7 @@ class Query(object): # get/create query context. get the ultimate compile arguments # from there - context = kwargs.pop('query_context', None) - if context is None: - context = QueryContext(self, kwargs) + context = QueryContext(self) order_by = context.order_by group_by = context.group_by from_obj = context.from_obj @@ -790,10 +833,12 @@ class Query(object): statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args()) if order_by: statement.order_by(*util.to_list(order_by)) + # for a DISTINCT query, you need the columns explicitly specified in order # to use it in "order_by". ensure they are in the column criterion (particularly oid). # TODO: this should be done at the SQL level not the mapper level - if kwargs.get('distinct', False) and order_by: + # TODO: need test coverage for this + if context.distinct and order_by: [statement.append_column(c) for c in util.to_list(order_by)] context.statement = statement @@ -829,59 +874,33 @@ class Query(object): return self.count(self.join_by(*args, **params)) - def selectfirst(self, arg=None, **kwargs): - """DEPRECATED. use query.filter(whereclause).first()""" - - if isinstance(arg, sql.FromClause) and arg.supports_execution(): - ret = self.select_statement(arg, **kwargs) - else: - kwargs['limit'] = 1 - ret = self.select_whereclause(whereclause=arg, **kwargs) - if ret: - return ret[0] - else: - return None - - def selectone(self, arg=None, **kwargs): - """DEPRECATED. use query.filter(whereclause).one()""" - - if isinstance(arg, sql.FromClause) and arg.supports_execution(): - ret = self.select_statement(arg, **kwargs) - else: - kwargs['limit'] = 2 - ret = self.select_whereclause(whereclause=arg, **kwargs) - if len(ret) == 1: - return ret[0] - elif len(ret) == 0: - raise exceptions.InvalidRequestError('No rows returned for selectone_by') - else: - raise exceptions.InvalidRequestError('Multiple rows returned for selectone') - - def select(self, arg=None, **kwargs): - """DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()""" - - ret = self.extension.select(self, arg=arg, **kwargs) - if ret is not mapper.EXT_PASS: - return ret - if isinstance(arg, sql.FromClause) and arg.supports_execution(): - return self.select_statement(arg, **kwargs) - else: - return self.select_whereclause(whereclause=arg, **kwargs) def select_whereclause(self, whereclause=None, params=None, **kwargs): """DEPRECATED. use query.filter(whereclause).all()""" - statement = self.compile(whereclause, **kwargs) - return self._select_statement(statement, params=params) - - def execute(self, clauseelement, params=None, *args, **kwargs): - """DEPRECATED. use query.select_from()""" + q = self.filter(whereclause)._legacy_select_kwargs(**kwargs) + if params is not None: + q = q.params(**params) + return list(q) + + def _legacy_select_kwargs(self, **kwargs): + q = self + if "order_by" in kwargs and kwargs['order_by']: + q = q.order_by(kwargs['order_by']) + if "group_by" in kwargs: + q = q.group_by(kwargs['group_by']) + if "from_obj" in kwargs: + q = q.select_from(kwargs['from_obj']) + if "lockmode" in kwargs: + q = q.with_lockmode(kwargs['lockmode']) + if "distinct" in kwargs: + q = q.distinct() + if "limit" in kwargs: + q = q.limit(kwargs['limit']) + if "offset" in kwargs: + q = q.offset(kwargs['offset']) + return q - result = self.session.execute(self.mapper, clauseelement, params=params) - try: - return self.instances(result, **kwargs) - finally: - result.close() def get_by(self, *args, **params): """DEPRECATED. use query.filter(*args).filter_by(**params).first()""" @@ -889,42 +908,76 @@ class Query(object): ret = self.extension.get_by(self, *args, **params) if ret is not mapper.EXT_PASS: return ret - x = self.select_whereclause(self.join_by(*args, **params), limit=1) - if x: - return x[0] - else: - return None + + return self._legacy_filter_by(*args, **params).first() def select_by(self, *args, **params): - """DEPRECATED. use use query.filter(*args).filter_by(**params).list().""" + """DEPRECATED. use use query.filter(*args).filter_by(**params).all().""" ret = self.extension.select_by(self, *args, **params) if ret is not mapper.EXT_PASS: return ret - return self.select_whereclause(self.join_by(*args, **params)) + + return self._legacy_filter_by(*args, **params).list() def join_by(self, *args, **params): """DEPRECATED. use join() to construct joins based on attribute names.""" return self._legacy_join_by(args, params, start=self._joinpoint) + def _build_select(self, arg=None, params=None, **kwargs): + if isinstance(arg, sql.FromClause) and arg.supports_execution(): + return self.from_statement(arg) + else: + return self.filter(arg)._legacy_select_kwargs(**kwargs) + + def selectfirst(self, arg=None, **kwargs): + """DEPRECATED. use query.filter(whereclause).first()""" + + return self._build_select(arg, **kwargs).first() + + def selectone(self, arg=None, **kwargs): + """DEPRECATED. use query.filter(whereclause).one()""" + + return self._build_select(arg, **kwargs).one() + + def select(self, arg=None, **kwargs): + """DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()""" + + ret = self.extension.select(self, arg=arg, **kwargs) + if ret is not mapper.EXT_PASS: + return ret + return self._build_select(arg, **kwargs).all() + + def execute(self, clauseelement, params=None, *args, **kwargs): + """DEPRECATED. use query.from_statement().all()""" + + return self._select_statement(statement, params, **kwargs) + def select_statement(self, statement, **params): """DEPRECATED. Use query.from_statement(statement)""" - - return self._select_statement(statement, params=params) + + return self._select_statement(statement, params) def select_text(self, text, **params): """DEPRECATED. Use query.from_statement(statement)""" - t = sql.text(text) - return self.execute(t, params=params) + return self._select_statement(statement, params) def _select_statement(self, statement, params=None, **kwargs): - statement.use_labels = True - if params is None: - params = {} - return self.execute(statement, params=params, **kwargs) - + q = self.from_statement(statement) + if params is not None: + q = q.params(**params) + q._select_context_options(**kwargs) + return list(q) + + def _select_context_options(self, populate_existing=None, version_check=None): + if populate_existing is not None: + self._populate_existing = populate_existing + if version_check is not None: + self._version_check = version_check + return self + def join_to(self, key): """DEPRECATED. use join() to create joins based on property names.""" @@ -1001,18 +1054,12 @@ class Query(object): def selectfirst_by(self, *args, **params): """DEPRECATED. Use query.filter(*args).filter_by(**kwargs).first()""" - return self.get_by(*args, **params) + return self._legacy_filter_by(*args, **params).first() def selectone_by(self, *args, **params): """DEPRECATED. Use query.filter(*args).filter_by(**kwargs).one()""" - ret = self.select_whereclause(self.join_by(*args, **params), limit=2) - if len(ret) == 1: - return ret[0] - elif len(ret) == 0: - raise exceptions.InvalidRequestError('No rows returned for selectone_by') - else: - raise exceptions.InvalidRequestError('Multiple rows returned for selectone_by') + return self._legacy_filter_by(*args, **params).one() @@ -1024,18 +1071,18 @@ class QueryContext(OperationContext): in a query construction. """ - def __init__(self, query, kwargs): + def __init__(self, query): self.query = query - self.order_by = kwargs.pop('order_by', query._order_by) - self.group_by = kwargs.pop('group_by', query._group_by) - self.from_obj = kwargs.pop('from_obj', query._from_obj) - self.lockmode = kwargs.pop('lockmode', query.lockmode) - self.distinct = kwargs.pop('distinct', query._distinct) - self.limit = kwargs.pop('limit', query._limit) - self.offset = kwargs.pop('offset', query._offset) + self.order_by = query._order_by + self.group_by = query._group_by + self.from_obj = query._from_obj + self.lockmode = query.lockmode + self.distinct = query._distinct + self.limit = query._limit + self.offset = query._offset self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders]) self.statement = None - super(QueryContext, self).__init__(query.mapper, query.with_options, **kwargs) + super(QueryContext, self).__init__(query.mapper, query.with_options) def select_args(self): """Return a dictionary of attributes from this diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 558fa62809..e754945bb8 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1229,7 +1229,7 @@ class EagerTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = False) )) - s = session.query(m).compile(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.user_id)) + s = session.query(m).filter(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.user_id)).compile() c = s.compile() self.echo("\n" + str(c) + repr(c.get_params())) diff --git a/test/orm/query.py b/test/orm/query.py index 8d3f5e67d4..fbee5e88c8 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -189,6 +189,12 @@ class GetTest(QueryTest): class LocalFoo(Base):pass mapper(LocalFoo, table) assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring) + +class SliceTest(QueryTest): + def test_first(self): + assert create_session().query(User).first() == User(id=7) + + assert create_session().query(User).filter(users.c.id==27).first() is None class FilterTest(QueryTest): def test_basic(self):