From 6abe864cebab652048e58cb24edd736e17922b9b Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 4 Jun 2007 23:50:22 +0000 Subject: [PATCH] refactoring step 2. all deprecated functions now express their functionality in terms of generative behavior. also the thing will run like crap right now until the next refactor stage...stay tuned --- lib/sqlalchemy/orm/query.py | 239 +++++++++++++++++++++--------------- test/orm/mapper.py | 2 +- test/orm/query.py | 6 + 3 files changed, 150 insertions(+), 97 deletions(-) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 31ba414d4a..c894f3767f 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -39,11 +39,15 @@ class Query(object): self._distinct = kwargs.pop('distinct', False) self._offset = kwargs.pop('offset', None) self._limit = kwargs.pop('limit', None) + self._statement = None + self._params = {} self._criterion = None self._col = None self._func = None self._joinpoint = self.mapper self._from_obj = [self.table] + self._populate_existing = False + self._version_check = False for opt in util.flatten_iterator(self.with_options): opt.process_query(self) @@ -68,6 +72,10 @@ class Query(object): q._from_obj = list(self._from_obj) q._joinpoint = self._joinpoint q._criterion = self._criterion + q._statement = self._statement + q._params = self._params.copy() + q._populate_existing = self._populate_existing + q._version_check = self._version_check q._col = self._col q._func = self._func return q @@ -143,6 +151,11 @@ class Query(object): else: primary_key = self.primary_key_columns s = sql.select([sql.func.count(list(primary_key)[0])], whereclause, from_obj=from_obj, **kwargs) + if params is None: + params = {} + else: + params = params.copy() + params.update(self._params) return self.session.scalar(self.mapper, s, params=params) def _with_lazy_criterion(cls, instance, prop, reverse=False): @@ -284,6 +297,13 @@ class Query(object): q.lockmode = mode return q + def params(self, **kwargs): + """add values for bind parameters which may have been specified in filter().""" + + q = self._clone() + q._params.update(kwargs) + return q + def filter(self, criterion): """apply the given filtering criterion to the query and return the newly resulting ``Query`` @@ -565,13 +585,24 @@ class Query(object): return list(self) + def from_statement(self, statement): + if isinstance(statement, basestring): + statement = sql.text(statement) + q = self._clone() + q._statement = statement + return q + def first(self): """Return the first result of this ``Query``. This results in an execution of the underlying query. """ if self._col is None or self._func is None: - return self[0] + ret = list(self[0:1]) + if len(ret) > 0: + return ret[0] + else: + return None else: return self._col_aggregate(self._col, self._func) @@ -594,7 +625,13 @@ class Query(object): raise exceptions.InvalidRequestError('Multiple rows returned for one()') def __iter__(self): - return iter(self.select_whereclause()) + statement = self.compile() + statement.use_labels = True + result = self.session.execute(self.mapper, statement, params=self._params) + try: + return iter(self.instances(result)) + finally: + result.close() def instances(self, cursor, *mappers_or_columns, **kwargs): @@ -624,6 +661,9 @@ class Query(object): session = self.session + kwargs.setdefault('populate_existing', self._populate_existing) + kwargs.setdefault('version_check', self._version_check) + context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs) process = [] @@ -685,8 +725,12 @@ class Query(object): for i, primary_key in enumerate(self.primary_key_columns): params[primary_key._label] = ident[i] try: - statement = self.compile(self._get_clause, lockmode=lockmode) - return self._select_statement(statement, params=params, populate_existing=reload, version_check=(lockmode is not None))[0] + q = self + if lockmode is not None: + q = q.with_lockmode(lockmode) + q = q.filter(self._get_clause) + q = q.params(**params)._select_context_options(populate_existing=reload, version_check=(lockmode is not None)) + return q.first() except IndexError: return None @@ -708,13 +752,14 @@ class Query(object): return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False)) - def compile(self, whereclause = None, **kwargs): - """Given a WHERE criterion, produce a ClauseElement-based - statement suitable for usage in the execute() method. - """ - - if self._criterion: - whereclause = sql.and_(self._criterion, whereclause) + def compile(self): + """compiles and returns a SQL statement based on the criterion and conditions within this Query.""" + + if self._statement: + self._statement.use_labels = True + return self._statement + + whereclause = self._criterion if whereclause is not None and self.is_polymorphic: # adapt the given WHERECLAUSE to adjust instances of this query's mapped @@ -732,9 +777,7 @@ class Query(object): # get/create query context. get the ultimate compile arguments # from there - context = kwargs.pop('query_context', None) - if context is None: - context = QueryContext(self, kwargs) + context = QueryContext(self) order_by = context.order_by group_by = context.group_by from_obj = context.from_obj @@ -790,10 +833,12 @@ class Query(object): statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **context.select_args()) if order_by: statement.order_by(*util.to_list(order_by)) + # for a DISTINCT query, you need the columns explicitly specified in order # to use it in "order_by". ensure they are in the column criterion (particularly oid). # TODO: this should be done at the SQL level not the mapper level - if kwargs.get('distinct', False) and order_by: + # TODO: need test coverage for this + if context.distinct and order_by: [statement.append_column(c) for c in util.to_list(order_by)] context.statement = statement @@ -829,59 +874,33 @@ class Query(object): return self.count(self.join_by(*args, **params)) - def selectfirst(self, arg=None, **kwargs): - """DEPRECATED. use query.filter(whereclause).first()""" - - if isinstance(arg, sql.FromClause) and arg.supports_execution(): - ret = self.select_statement(arg, **kwargs) - else: - kwargs['limit'] = 1 - ret = self.select_whereclause(whereclause=arg, **kwargs) - if ret: - return ret[0] - else: - return None - - def selectone(self, arg=None, **kwargs): - """DEPRECATED. use query.filter(whereclause).one()""" - - if isinstance(arg, sql.FromClause) and arg.supports_execution(): - ret = self.select_statement(arg, **kwargs) - else: - kwargs['limit'] = 2 - ret = self.select_whereclause(whereclause=arg, **kwargs) - if len(ret) == 1: - return ret[0] - elif len(ret) == 0: - raise exceptions.InvalidRequestError('No rows returned for selectone_by') - else: - raise exceptions.InvalidRequestError('Multiple rows returned for selectone') - - def select(self, arg=None, **kwargs): - """DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()""" - - ret = self.extension.select(self, arg=arg, **kwargs) - if ret is not mapper.EXT_PASS: - return ret - if isinstance(arg, sql.FromClause) and arg.supports_execution(): - return self.select_statement(arg, **kwargs) - else: - return self.select_whereclause(whereclause=arg, **kwargs) def select_whereclause(self, whereclause=None, params=None, **kwargs): """DEPRECATED. use query.filter(whereclause).all()""" - statement = self.compile(whereclause, **kwargs) - return self._select_statement(statement, params=params) - - def execute(self, clauseelement, params=None, *args, **kwargs): - """DEPRECATED. use query.select_from()""" + q = self.filter(whereclause)._legacy_select_kwargs(**kwargs) + if params is not None: + q = q.params(**params) + return list(q) + + def _legacy_select_kwargs(self, **kwargs): + q = self + if "order_by" in kwargs and kwargs['order_by']: + q = q.order_by(kwargs['order_by']) + if "group_by" in kwargs: + q = q.group_by(kwargs['group_by']) + if "from_obj" in kwargs: + q = q.select_from(kwargs['from_obj']) + if "lockmode" in kwargs: + q = q.with_lockmode(kwargs['lockmode']) + if "distinct" in kwargs: + q = q.distinct() + if "limit" in kwargs: + q = q.limit(kwargs['limit']) + if "offset" in kwargs: + q = q.offset(kwargs['offset']) + return q - result = self.session.execute(self.mapper, clauseelement, params=params) - try: - return self.instances(result, **kwargs) - finally: - result.close() def get_by(self, *args, **params): """DEPRECATED. use query.filter(*args).filter_by(**params).first()""" @@ -889,42 +908,76 @@ class Query(object): ret = self.extension.get_by(self, *args, **params) if ret is not mapper.EXT_PASS: return ret - x = self.select_whereclause(self.join_by(*args, **params), limit=1) - if x: - return x[0] - else: - return None + + return self._legacy_filter_by(*args, **params).first() def select_by(self, *args, **params): - """DEPRECATED. use use query.filter(*args).filter_by(**params).list().""" + """DEPRECATED. use use query.filter(*args).filter_by(**params).all().""" ret = self.extension.select_by(self, *args, **params) if ret is not mapper.EXT_PASS: return ret - return self.select_whereclause(self.join_by(*args, **params)) + + return self._legacy_filter_by(*args, **params).list() def join_by(self, *args, **params): """DEPRECATED. use join() to construct joins based on attribute names.""" return self._legacy_join_by(args, params, start=self._joinpoint) + def _build_select(self, arg=None, params=None, **kwargs): + if isinstance(arg, sql.FromClause) and arg.supports_execution(): + return self.from_statement(arg) + else: + return self.filter(arg)._legacy_select_kwargs(**kwargs) + + def selectfirst(self, arg=None, **kwargs): + """DEPRECATED. use query.filter(whereclause).first()""" + + return self._build_select(arg, **kwargs).first() + + def selectone(self, arg=None, **kwargs): + """DEPRECATED. use query.filter(whereclause).one()""" + + return self._build_select(arg, **kwargs).one() + + def select(self, arg=None, **kwargs): + """DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()""" + + ret = self.extension.select(self, arg=arg, **kwargs) + if ret is not mapper.EXT_PASS: + return ret + return self._build_select(arg, **kwargs).all() + + def execute(self, clauseelement, params=None, *args, **kwargs): + """DEPRECATED. use query.from_statement().all()""" + + return self._select_statement(statement, params, **kwargs) + def select_statement(self, statement, **params): """DEPRECATED. Use query.from_statement(statement)""" - - return self._select_statement(statement, params=params) + + return self._select_statement(statement, params) def select_text(self, text, **params): """DEPRECATED. Use query.from_statement(statement)""" - t = sql.text(text) - return self.execute(t, params=params) + return self._select_statement(statement, params) def _select_statement(self, statement, params=None, **kwargs): - statement.use_labels = True - if params is None: - params = {} - return self.execute(statement, params=params, **kwargs) - + q = self.from_statement(statement) + if params is not None: + q = q.params(**params) + q._select_context_options(**kwargs) + return list(q) + + def _select_context_options(self, populate_existing=None, version_check=None): + if populate_existing is not None: + self._populate_existing = populate_existing + if version_check is not None: + self._version_check = version_check + return self + def join_to(self, key): """DEPRECATED. use join() to create joins based on property names.""" @@ -1001,18 +1054,12 @@ class Query(object): def selectfirst_by(self, *args, **params): """DEPRECATED. Use query.filter(*args).filter_by(**kwargs).first()""" - return self.get_by(*args, **params) + return self._legacy_filter_by(*args, **params).first() def selectone_by(self, *args, **params): """DEPRECATED. Use query.filter(*args).filter_by(**kwargs).one()""" - ret = self.select_whereclause(self.join_by(*args, **params), limit=2) - if len(ret) == 1: - return ret[0] - elif len(ret) == 0: - raise exceptions.InvalidRequestError('No rows returned for selectone_by') - else: - raise exceptions.InvalidRequestError('Multiple rows returned for selectone_by') + return self._legacy_filter_by(*args, **params).one() @@ -1024,18 +1071,18 @@ class QueryContext(OperationContext): in a query construction. """ - def __init__(self, query, kwargs): + def __init__(self, query): self.query = query - self.order_by = kwargs.pop('order_by', query._order_by) - self.group_by = kwargs.pop('group_by', query._group_by) - self.from_obj = kwargs.pop('from_obj', query._from_obj) - self.lockmode = kwargs.pop('lockmode', query.lockmode) - self.distinct = kwargs.pop('distinct', query._distinct) - self.limit = kwargs.pop('limit', query._limit) - self.offset = kwargs.pop('offset', query._offset) + self.order_by = query._order_by + self.group_by = query._group_by + self.from_obj = query._from_obj + self.lockmode = query.lockmode + self.distinct = query._distinct + self.limit = query._limit + self.offset = query._offset self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders]) self.statement = None - super(QueryContext, self).__init__(query.mapper, query.with_options, **kwargs) + super(QueryContext, self).__init__(query.mapper, query.with_options) def select_args(self): """Return a dictionary of attributes from this diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 558fa62809..e754945bb8 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1229,7 +1229,7 @@ class EagerTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = False) )) - s = session.query(m).compile(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.user_id)) + s = session.query(m).filter(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.user_id)).compile() c = s.compile() self.echo("\n" + str(c) + repr(c.get_params())) diff --git a/test/orm/query.py b/test/orm/query.py index 8d3f5e67d4..fbee5e88c8 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -189,6 +189,12 @@ class GetTest(QueryTest): class LocalFoo(Base):pass mapper(LocalFoo, table) assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring) + +class SliceTest(QueryTest): + def test_first(self): + assert create_session().query(User).first() == User(id=7) + + assert create_session().query(User).filter(users.c.id==27).first() is None class FilterTest(QueryTest): def test_basic(self): -- 2.47.3