From b2e04755cc5f382596fb174c7381f60dd2972d94 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 10 Mar 2007 02:49:12 +0000 Subject: [PATCH] - the full featureset of the SelectResults extension has been merged into a new set of methods available off of Query. These methods all provide "generative" behavior, whereby the Query is copied and a new one returned with additional criterion added. The new methods include: filter() - applies select criterion to the query filter_by() - applies "by"-style criterion to the query avg() - return the avg() function on the given column join() - join to a property (or across a list of properties) outerjoin() - like join() but uses LEFT OUTER JOIN limit()/offset() - apply LIMIT/OFFSET range-based access which applies limit/offset: session.query(Foo)[3:5] distinct() - apply DISTINCT list() - evaluate the criterion and return results no incompatible changes have been made to Query's API and no methods have been deprecated. Existing methods like select(), select_by(), get(), get_by() all execute the query at once and return results like they always did. join_to()/join_via() are still there although the generative join()/outerjoin() methods are easier to use. - the return value for multiple mappers used with instances() now returns a cartesian product of the requested list of mappers, represented as a list of tuples. this corresponds to the documented behavior. So that instances match up properly, the "uniquing" is disabled when this feature is used. - strings and columns can also be sent to the *args of instances() where those exact result columns will be part of the result tuples. - query() method is added by assignmapper. this helps with navigating to all the new generative methods on Query. --- CHANGES | 34 ++- lib/sqlalchemy/ext/assignmapper.py | 1 + lib/sqlalchemy/orm/mapper.py | 5 +- lib/sqlalchemy/orm/query.py | 445 ++++++++++++++++++++++------- test/ext/selectresults.py | 10 +- test/orm/alltests.py | 1 + test/orm/generative.py | 229 +++++++++++++++ test/orm/inheritance5.py | 11 +- test/orm/mapper.py | 204 +++++++------ 9 files changed, 747 insertions(+), 193 deletions(-) create mode 100644 test/orm/generative.py diff --git a/CHANGES b/CHANGES index 96391716c0..35ffc81d64 100644 --- a/CHANGES +++ b/CHANGES @@ -26,6 +26,36 @@ - fixed use_alter flag on ForeignKeyConstraint [ticket:503] - fixed usage of 2.4-only "reversed" in topological.py [ticket:506] - orm: + - the full featureset of the SelectResults extension has been merged + into a new set of methods available off of Query. These methods + all provide "generative" behavior, whereby the Query is copied + and a new one returned with additional criterion added. + The new methods include: + + filter() - applies select criterion to the query + filter_by() - applies "by"-style criterion to the query + avg() - return the avg() function on the given column + join() - join to a property (or across a list of properties) + outerjoin() - like join() but uses LEFT OUTER JOIN + limit()/offset() - apply LIMIT/OFFSET + range-based access which applies limit/offset: + session.query(Foo)[3:5] + distinct() - apply DISTINCT + list() - evaluate the criterion and return results + + no incompatible changes have been made to Query's API and no methods + have been deprecated. Existing methods like select(), select_by(), + get(), get_by() all execute the query at once and return results + like they always did. join_to()/join_via() are still there although + the generative join()/outerjoin() methods are easier to use. + + - the return value for multiple mappers used with instances() now returns + a cartesian product of the requested list of mappers, represented + as a list of tuples. this corresponds to the documented behavior. + So that instances match up properly, the "uniquing" is disabled when + this feature is used. + - strings and columns can also be sent to the *args of instances() where + those exact result columns will be part of the result tuples. - a full select() construct can be passed to query.select() (which worked anyway), but also query.selectfirst(), query.selectone() which will be used as is (i.e. no query is compiled). works similarly to @@ -46,7 +76,9 @@ - extensions: - options() method on SelectResults now implemented "generatively" like the rest of the SelectResults methods [ticket:472] - + - query() method is added by assignmapper. this helps with + navigating to all the new generative methods on Query. + 0.3.5 - sql: - the value of "case_sensitive" defaults to True now, regardless of the diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py index 178f150e54..aee96f06ea 100644 --- a/lib/sqlalchemy/ext/assignmapper.py +++ b/lib/sqlalchemy/ext/assignmapper.py @@ -34,6 +34,7 @@ def assign_mapper(ctx, class_, *args, **kwargs): extension = ctx.mapper_extension m = mapper(class_, extension=extension, *args, **kwargs) class_.mapper = m + class_.query = classmethod(lambda cls: Query(class_, session=ctx.current)) for name in ['get', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'get_by', 'join_to', 'join_via', 'count', 'count_by', 'options', 'instances']: monkeypatch_query_method(ctx, class_, name) for name in ['flush', 'delete', 'expire', 'refresh', 'expunge', 'merge', 'save', 'update', 'save_or_update']: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e1fa56c650..d28445be61 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1695,6 +1695,9 @@ class _ExtensionCarrier(MapperExtension): def __init__(self): self.__elements = [] + def __iter__(self): + return iter(self.__elements) + def insert(self, extension): """Insert a MapperExtension at the beginning of this ExtensionCarrier's list.""" @@ -1766,7 +1769,7 @@ class ExtensionOption(MapperOption): self.ext = ext def process_query(self, query): - query._insert_extension(self.ext) + query.extension.append(self.ext) class ClassKey(object): """Key a class and an entity name to a mapper, via the mapper_registry.""" diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 8df5628d15..6650954e1b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -21,7 +21,6 @@ class Query(object): self.with_options = with_options or [] self.select_mapper = self.mapper.get_select_mapper().compile() self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh) - self.order_by = kwargs.pop('order_by', self.mapper.order_by) self.lockmode = lockmode self.extension = mapper._ExtensionCarrier() if extension is not None: @@ -35,12 +34,40 @@ class Query(object): _get_clause.clauses.append(primary_key == sql.bindparam(primary_key._label, type=primary_key.type, unique=True)) self.mapper._get_clause = _get_clause self._get_clause = self.mapper._get_clause - for opt in util.flatten_iterator(self.with_options): - opt.process_query(self) - def _insert_extension(self, ext): - self.extension.insert(ext) + self._order_by = kwargs.pop('order_by', False) + self._distinct = kwargs.pop('distinct', False) + self._offset = kwargs.pop('offset', None) + self._limit = kwargs.pop('limit', None) + self._criterion = None + self._joinpoint = self.mapper + self._from_obj = [self.table] + for opt in util.flatten_iterator(self.with_options): + opt.process_query(self) + + def _clone(self): + q = Query.__new__(Query) + q.mapper = self.mapper + q.select_mapper = self.select_mapper + q._order_by = self._order_by + q._distinct = self._distinct + q.always_refresh = self.always_refresh + q.with_options = list(self.with_options) + q._session = self.session + q.is_polymorphic = self.is_polymorphic + q.lockmode = self.lockmode + q.extension = mapper._ExtensionCarrier() + for ext in self.extension: + q.extension.append(ext) + q._offset = self._offset + q._limit = self._limit + q._get_clause = self._get_clause + q._from_obj = list(self._from_obj) + q._joinpoint = self._joinpoint + q._criterion = self._criterion + return q + def _get_session(self): if self._session is None: return self.mapper.get_session() @@ -90,20 +117,8 @@ class Query(object): """Return a single object instance based on the given key/value criterion. - This is either the first value in the result list, or None if - the list is empty. - - The keys are mapped to property or column names mapped by this - mapper's Table, and the values are coerced into a ``WHERE`` - clause separated by ``AND`` operators. If the local - property/column names dont contain the key, a search will be - performed against this mapper's immediate list of relations as - well, forming the appropriate join conditions if a matching - property is located. - - E.g.:: - - u = usermapper.get_by(user_name = 'fred') + The criterion is constructed in the same way as the + ``select_by()`` method. """ ret = self.extension.get_by(self, *args, **params) @@ -131,6 +146,11 @@ class Query(object): mapper's immediate list of relations as well, forming the appropriate join conditions if a matching property is located. + if the located property is a column-based property, the comparison + value should be a scalar with an appropriate type. If the + property is a relationship-bound property, the comparison value + should be an instance of the related class. + E.g.:: result = usermapper.select_by(user_name = 'fred') @@ -145,61 +165,13 @@ class Query(object): """Return a ``ClauseElement`` representing the ``WHERE`` clause that would normally be sent to ``select_whereclause()`` by ``select_by()``. - """ - return self._join_by(args, params) - - def _join_by(self, args, params, start=None): - """Return a ``ClauseElement`` representing the ``WHERE`` - clause that would normally be sent to ``select_whereclause()`` - by ``select_by()``. + The criterion is constructed in the same way as the + ``select_by()`` method. """ - clause = None - for arg in args: - if clause is None: - clause = arg - else: - clause &= arg + return self._join_by(args, params) - for key, value in params.iteritems(): - (keys, prop) = self._locate_prop(key, start=start) - c = prop.compare(value) & self.join_via(keys) - if clause is None: - clause = c - else: - clause &= c - return clause - - def _locate_prop(self, key, start=None): - import properties - keys = [] - seen = util.Set() - def search_for_prop(mapper_): - if mapper_ in seen: - return None - seen.add(mapper_) - if mapper_.props.has_key(key): - prop = mapper_.props[key] - if isinstance(prop, SynonymProperty): - prop = mapper_.props[prop.name] - if isinstance(prop, properties.PropertyLoader): - keys.insert(0, prop.key) - return prop - else: - for prop in mapper_.props.values(): - if not isinstance(prop, properties.PropertyLoader): - continue - x = search_for_prop(prop.mapper) - if x: - keys.insert(0, prop.key) - return x - else: - return None - p = search_for_prop(start or self.mapper) - if p is None: - raise exceptions.InvalidRequestError("Cant locate property named '%s'" % key) - return [keys, p] def join_to(self, key): """Given the key name of a property, will recursively descend @@ -236,6 +208,9 @@ class Query(object): """Like ``select_by()``, but only return the first result by itself, or None if no objects returned. Synonymous with ``get_by()``. + + The criterion is constructed in the same way as the + ``select_by()`` method. """ return self.get_by(*args, **params) @@ -243,6 +218,9 @@ class Query(object): def selectone_by(self, *args, **params): """Like ``selectfirst_by()``, but throws an error if not exactly one result was returned. + + The criterion is constructed in the same way as the + ``select_by()`` method. """ ret = self.select_whereclause(self.join_by(*args, **params), limit=2) @@ -326,15 +304,20 @@ class Query(object): """Given a ``WHERE`` criterion, create a ``SELECT COUNT`` statement, execute and return the resulting count value. """ + if self._criterion: + if whereclause is not None: + whereclause = sql.and_(self._criterion, whereclause) + else: + whereclause = self._criterion + from_obj = kwargs.pop('from_obj', self._from_obj) + kwargs.setdefault('distinct', self._distinct) - from_obj = kwargs.pop('from_obj', []) alltables = [] for l in [sql_util.TableFinder(x) for x in from_obj]: alltables += l - + if self.table not in alltables: from_obj.append(self.table) - if self._nestable(**kwargs): s = sql.select([self.table], whereclause, **kwargs).alias('getcount').count() else: @@ -361,14 +344,201 @@ class Query(object): """Return a new Query object, applying the given list of MapperOptions. """ - - return Query(self.mapper, self._session, with_options=args) + q = self._clone() + for opt in util.flatten_iterator(args): + q.with_options.append(opt) + opt.process_query(q) + for opt in util.flatten_iterator(self.with_options): + opt.process_query(self) + return q def with_lockmode(self, mode): """Return a new Query object with the specified locking mode.""" + q = self._clone() + q.lockmode = mode + return q + + def filter(self, criterion): + """apply the given filtering criterion to the query and return the newly resulting ``Query`` + + the criterion is any sql.ClauseElement applicable to the WHERE clause of a select. + """ + q = self._clone() + if q._criterion is not None: + q._criterion = q._criterion & criterion + else: + q._criterion = criterion + return q + + def filter_by(self, *args, **kwargs): + """apply the given filtering criterion to the query and return the newly resulting ``Query`` + + The criterion is constructed in the same way as the + ``select_by()`` method. + """ + return self.filter(self._join_by(args, kwargs, start=self._joinpoint)) + + def _join_to(self, prop, outerjoin=False): + if isinstance(prop, list): + mapper = self._joinpoint + keys = [] + for key in prop: + p = mapper.props[key] + keys.append(key) + mapper = p.mapper + else: + [keys,p] = self._locate_prop(prop, start=self._joinpoint) + clause = self._from_obj[-1] + mapper = self._joinpoint + for key in keys: + prop = mapper.props[key] + if outerjoin: + clause = clause.outerjoin(prop.select_table, prop.get_join(mapper)) + else: + clause = clause.join(prop.select_table, prop.get_join(mapper)) + mapper = prop.mapper + return (clause, mapper) + + def _join_by(self, args, params, start=None): + """Return a ``ClauseElement`` representing the ``WHERE`` + clause that would normally be sent to ``select_whereclause()`` + by ``select_by()``. - return Query(self.mapper, self._session, lockmode=mode) + The criterion is constructed in the same way as the + ``select_by()`` method. + """ + clause = None + for arg in args: + if clause is None: + clause = arg + else: + clause &= arg + + for key, value in params.iteritems(): + (keys, prop) = self._locate_prop(key, start=start) + c = prop.compare(value) & self.join_via(keys) + if clause is None: + clause = c + else: + clause &= c + return clause + + def _locate_prop(self, key, start=None): + import properties + keys = [] + seen = util.Set() + def search_for_prop(mapper_): + if mapper_ in seen: + return None + seen.add(mapper_) + if mapper_.props.has_key(key): + prop = mapper_.props[key] + if isinstance(prop, SynonymProperty): + prop = mapper_.props[prop.name] + if isinstance(prop, properties.PropertyLoader): + keys.insert(0, prop.key) + return prop + else: + for prop in mapper_.props.values(): + if not isinstance(prop, properties.PropertyLoader): + continue + x = search_for_prop(prop.mapper) + if x: + keys.insert(0, prop.key) + return x + else: + return None + p = search_for_prop(start or self.mapper) + if p is None: + raise exceptions.InvalidRequestError("Cant locate property named '%s'" % key) + return [keys, p] + + def _col_aggregate(self, col, func): + """Execute ``func()`` function against the given column. + + For performance, only use subselect if `order_by` attribute is set. + """ + + ops = {'distinct':self._distinct, 'order_by':self._order_by, 'from_obj':self._from_obj} + + if self._order_by is not False: + s1 = sql.select([col], self._criterion, **ops).alias('u') + return sql.select([func(s1.corresponding_column(col))]).scalar() + else: + return sql.select([func(col)], self._criterion, **ops).scalar() + + def min(self, col): + """Execute the SQL ``min()`` function against the given column.""" + + return self._col_aggregate(col, sql.func.min) + + def max(self, col): + """Execute the SQL ``max()`` function against the given column.""" + + return self._col_aggregate(col, sql.func.max) + + def sum(self, col): + """Execute the SQL ``sum``() function against the given column.""" + + return self._col_aggregate(col, sql.func.sum) + + def avg(self, col): + """Execute the SQL ``avg()`` function against the given column.""" + + return self._col_aggregate(col, sql.func.avg) + + def order_by(self, criterion): + """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``""" + + q = self._clone() + if q._order_by is False: + q._order_by = util.to_list(criterion) + else: + q._order_by.extend(util.to_list(criterion)) + return q + + def join(self, prop): + """create a join of this ``Query`` object's criterion + to a relationship and return the newly resulting ``Query``. + + 'prop' may be a string property name in which it is located + in the same manner as keyword arguments in ``select_by``, or + it may be a list of strings in which case the property is located + by direct traversal of each keyname (i.e. like join_via()). + """ + + q = self._clone() + (clause, mapper) = self._join_to(prop, outerjoin=False) + q._from_obj = [clause] + q._joinpoint = mapper + return q + + def outerjoin(self, prop): + """create a left outer join of this ``Query`` object's criterion + to a relationship and return the newly resulting ``Query``. + + 'prop' may be a string property name in which it is located + in the same manner as keyword arguments in ``select_by``, or + it may be a list of strings in which case the property is located + by direct traversal of each keyname (i.e. like join_via()). + """ + q = self._clone() + (clause, mapper) = self._join_to(prop, outerjoin=True) + q._from_obj = [clause] + q._joinpoint = mapper + return q + + def select_from(self, from_obj): + """Set the `from_obj` parameter of the query. + + `from_obj` is a list of one or more tables. + """ + + new = self._clone() + new._from_obj = from_obj + return new + def __getattr__(self, key): if (key.startswith('select_by_')): key = key[10:] @@ -383,6 +553,57 @@ class Query(object): else: raise AttributeError(key) + def __getitem__(self, item): + if isinstance(item, slice): + start = item.start + stop = item.stop + if (isinstance(start, int) and start < 0) or \ + (isinstance(stop, int) and stop < 0): + return list(self)[item] + else: + res = self._clone() + if start is not None and stop is not None: + res._offset = (self._offset or 0)+ start + res._limit = stop-start + elif start is None and stop is not None: + res._limit = stop + elif start is not None and stop is None: + res._offset = (self._offset or 0) + start + if item.step is not None: + return list(res)[None:None:item.step] + else: + return res + else: + return list(self[item:item+1])[0] + + def limit(self, limit): + """Apply a ``LIMIT`` to the query.""" + + return self[:limit] + + def offset(self, offset): + """Apply an ``OFFSET`` to the query.""" + + return self[offset:] + + def distinct(self): + """Apply a ``DISTINCT`` to the query.""" + + new = self._clone() + new._distinct = True + return new + + def list(self): + """Return the results represented by this ``Query`` as a list. + + This results in an execution of the underlying query. + """ + + return list(self) + + def __iter__(self): + return iter(self.select_whereclause()) + def execute(self, clauseelement, params=None, *args, **kwargs): """Execute the given ClauseElement-based statement against this Query's session/mapper, return the resulting list of @@ -400,9 +621,27 @@ class Query(object): finally: result.close() - def instances(self, cursor, *mappers, **kwargs): + def instances(self, cursor, *mappers_or_columns, **kwargs): """Return a list of mapped instances corresponding to the rows in a given *cursor* (i.e. ``ResultProxy``). + + *mappers_or_columns is an optional list containing one or more of + classes, mappers, strings or sql.ColumnElements which will be + applied to each row and added horizontally to the result set, + which becomes a list of tuples. The first element in each tuple + is the usual result based on the mapper represented by this + ``Query``. Each additional element in the tuple corresponds to an + entry in the *mappers_or_columns list. + + For each element in *mappers_or_columns, if the element is + a mapper or mapped class, an additional class instance will be + present in the tuple. If the element is a string or sql.ColumnElement, + the corresponding result column from each row will be present in the tuple. + + Note that when *mappers_or_columns is present, "uniquing" for the result set + is *disabled*, so that the resulting tuples contain entities as they actually + correspond. this indicates that multiple results may be present if this + option is used. """ self.__log_debug("instances()") @@ -411,25 +650,36 @@ class Query(object): context = SelectionContext(self.select_mapper, session, self.extension, with_options=self.with_options, **kwargs) - result = util.UniqueAppender([]) - if mappers: - otherresults = [] - for m in mappers: - otherresults.append(util.UniqueAppender([])) - + process = [] + if mappers_or_columns: + for m in mappers_or_columns: + if isinstance(m, type): + m = mapper.class_mapper(m) + if isinstance(m, mapper.Mapper): + appender = [] + def proc(context, row): + m._instance(context, row, appender) + process.append((proc, appender)) + elif isinstance(m, sql.ColumnElement) or isinstance(m, basestring): + res = [] + def proc(context, row): + res.append(row[m]) + process.append((proc, res)) + result = [] + else: + result = util.UniqueAppender([]) + for row in cursor.fetchall(): self.select_mapper._instance(context, row, result) - i = 0 - for m in mappers: - m._instance(context, row, otherresults[i]) - i+=1 + for proc in process: + proc[0](context, row) # store new stuff in the identity map for value in context.identity_map.values(): session._register_persistent(value) - if mappers: - return [result.data] + [o.data for o in otherresults] + if mappers_or_columns: + return zip(*([result] + [o[1] for o in process])) else: return result.data @@ -491,11 +741,14 @@ class Query(object): statement suitable for usage in the execute() method. """ + if self._criterion: + whereclause = sql.and_(self._criterion, whereclause) + if whereclause is not None and self.is_polymorphic: # adapt the given WHERECLAUSE to adjust instances of this query's mapped table to be that of our select_table, # which may be the "polymorphic" selectable used by our mapper. whereclause.accept_visitor(sql_util.ClauseAdapter(self.table)) - + context = kwargs.pop('query_context', None) if context is None: context = QueryContext(self, kwargs) @@ -506,7 +759,7 @@ class Query(object): limit = context.limit offset = context.offset if order_by is False: - order_by = self.order_by + order_by = self.mapper.order_by if order_by is False: if self.table.default_order_by() is not None: order_by = self.table.default_order_by() @@ -579,12 +832,12 @@ class QueryContext(OperationContext): def __init__(self, query, kwargs): self.query = query - self.order_by = kwargs.pop('order_by', False) - self.from_obj = kwargs.pop('from_obj', []) + self.order_by = kwargs.pop('order_by', query._order_by) + self.from_obj = kwargs.pop('from_obj', query._from_obj) self.lockmode = kwargs.pop('lockmode', query.lockmode) - self.distinct = kwargs.pop('distinct', False) - self.limit = kwargs.pop('limit', None) - self.offset = kwargs.pop('offset', None) + self.distinct = kwargs.pop('distinct', query._distinct) + self.limit = kwargs.pop('limit', query._limit) + self.offset = kwargs.pop('offset', query._offset) self.eager_loaders = util.Set([x for x in query.mapper._eager_loaders]) self.statement = None super(QueryContext, self).__init__(query.mapper, query.with_options, **kwargs) diff --git a/test/ext/selectresults.py b/test/ext/selectresults.py index ca052ad3b2..88476c9cc0 100644 --- a/test/ext/selectresults.py +++ b/test/ext/selectresults.py @@ -12,14 +12,15 @@ class Foo(object): class SelectResultsTest(PersistTest): def setUpAll(self): self.install_threadlocal() - global foo - foo = Table('foo', testbase.db, + global foo, metadata + metadata = BoundMetaData(testbase.db) + foo = Table('foo', metadata, Column('id', Integer, Sequence('foo_id_seq'), primary_key=True), Column('bar', Integer), Column('range', Integer)) assign_mapper(Foo, foo, extension=SelectResultsExt()) - foo.create() + metadata.create_all() for i in range(100): Foo(bar=i, range=i%10) objectstore.flush() @@ -30,8 +31,7 @@ class SelectResultsTest(PersistTest): self.res = self.query.select() def tearDownAll(self): - global foo - foo.drop() + metadata.drop_all() self.uninstall_threadlocal() clear_mappers() diff --git a/test/orm/alltests.py b/test/orm/alltests.py index c52902ac74..c30bc14064 100644 --- a/test/orm/alltests.py +++ b/test/orm/alltests.py @@ -5,6 +5,7 @@ def suite(): modules_to_test = ( 'orm.attributes', 'orm.mapper', + 'orm.generative', 'orm.lazytest1', 'orm.eagertest1', 'orm.eagertest2', diff --git a/test/orm/generative.py b/test/orm/generative.py new file mode 100644 index 0000000000..37ce1dcc9b --- /dev/null +++ b/test/orm/generative.py @@ -0,0 +1,229 @@ +from testbase import PersistTest, AssertMixin +import testbase +import tables + +from sqlalchemy import * + + +class Foo(object): + pass + +class GenerativeQueryTest(PersistTest): + def setUpAll(self): + self.install_threadlocal() + global foo, metadata + metadata = BoundMetaData(testbase.db) + foo = Table('foo', metadata, + Column('id', Integer, Sequence('foo_id_seq'), primary_key=True), + Column('bar', Integer), + Column('range', Integer)) + + assign_mapper(Foo, foo) + metadata.create_all() + for i in range(100): + Foo(bar=i, range=i%10) + objectstore.flush() + + def setUp(self): + self.query = Foo.query() + self.orig = self.query.select_whereclause() + self.res = self.query + + def tearDownAll(self): + metadata.drop_all() + self.uninstall_threadlocal() + clear_mappers() + + def test_selectby(self): + res = self.query.filter_by(range=5) + assert res.order_by([Foo.c.bar])[0].bar == 5 + assert res.order_by([desc(Foo.c.bar)])[0].bar == 95 + + def test_slice(self): + assert self.query[1] == self.orig[1] + assert list(self.query[10:20]) == self.orig[10:20] + assert list(self.query[10:]) == self.orig[10:] + assert list(self.query[:10]) == self.orig[:10] + assert list(self.query[:10]) == self.orig[:10] + assert list(self.query[10:40:3]) == self.orig[10:40:3] + assert list(self.query[-5:]) == self.orig[-5:] + assert self.query[10:20][5] == self.orig[10:20][5] + + def test_aggregate(self): + assert self.query.count() == 100 + assert self.query.filter(foo.c.bar<30).min(foo.c.bar) == 0 + assert self.query.filter(foo.c.bar<30).max(foo.c.bar) == 29 + + @testbase.unsupported('mysql') + def test_aggregate_1(self): + # this one fails in mysql as the result comes back as a string + assert self.query.filter(foo.c.bar<30).sum(foo.c.bar) == 435 + + @testbase.unsupported('postgres', 'mysql', 'firebird') + def test_aggregate_2(self): + # this one fails with postgres, the floating point comparison fails + assert self.query.filter(foo.c.bar<30).avg(foo.c.bar) == 14.5 + + def test_filter(self): + assert self.query.count() == 100 + assert self.query.filter(Foo.c.bar < 30).count() == 30 + res2 = self.query.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10) + assert res2.count() == 19 + + def test_options(self): + class ext1(MapperExtension): + def populate_instance(self, mapper, selectcontext, row, instance, identitykey, isnew): + instance.TEST = "hello world" + return EXT_PASS + objectstore.clear() + assert self.res.options(extension(ext1()))[0].TEST == "hello world" + + def test_order_by(self): + assert self.res.order_by([Foo.c.bar])[0].bar == 0 + assert self.res.order_by([desc(Foo.c.bar)])[0].bar == 99 + + def test_offset(self): + assert list(self.res.order_by([Foo.c.bar]).offset(10))[0].bar == 10 + + def test_offset(self): + assert len(list(self.res.limit(10))) == 10 + +class Obj1(object): + pass +class Obj2(object): + pass + +class GenerativeTest2(PersistTest): + def setUpAll(self): + self.install_threadlocal() + global metadata, table1, table2 + metadata = BoundMetaData(testbase.db) + table1 = Table('Table1', metadata, + Column('id', Integer, primary_key=True), + ) + table2 = Table('Table2', metadata, + Column('t1id', Integer, ForeignKey("Table1.id"), primary_key=True), + Column('num', Integer, primary_key=True), + ) + assign_mapper(Obj1, table1) + assign_mapper(Obj2, table2) + metadata.create_all() + table1.insert().execute({'id':1},{'id':2},{'id':3},{'id':4}) + table2.insert().execute({'num':1,'t1id':1},{'num':2,'t1id':1},{'num':3,'t1id':1},\ +{'num':4,'t1id':2},{'num':5,'t1id':2},{'num':6,'t1id':3}) + + def setUp(self): + self.query = Query(Obj1) + #self.orig = self.query.select_whereclause() + #self.res = self.query.select() + + def tearDownAll(self): + metadata.drop_all() + self.uninstall_threadlocal() + clear_mappers() + + def test_distinctcount(self): + res = self.query + assert res.count() == 4 + res = self.query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1)) + assert res.count() == 3 + res = self.query.filter(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1)).distinct() + self.assertEqual(res.count(), 1) + +class RelationsTest(AssertMixin): + def setUpAll(self): + tables.create() + tables.data() + def tearDownAll(self): + tables.drop() + def tearDown(self): + clear_mappers() + def test_jointo(self): + """test the join and outerjoin functions on Query""" + mapper(tables.User, tables.users, properties={ + 'orders':relation(mapper(tables.Order, tables.orders, properties={ + 'items':relation(mapper(tables.Item, tables.orderitems)) + })) + }) + session = create_session() + query = session.query(tables.User) + x = query.join('orders').join('items').filter(tables.Item.c.item_id==2) + print x.compile() + self.assert_result(list(x), tables.User, tables.user_result[2]) + def test_outerjointo(self): + """test the join and outerjoin functions on Query""" + mapper(tables.User, tables.users, properties={ + 'orders':relation(mapper(tables.Order, tables.orders, properties={ + 'items':relation(mapper(tables.Item, tables.orderitems)) + })) + }) + session = create_session() + query = session.query(tables.User) + x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)) + print x.compile() + self.assert_result(list(x), tables.User, *tables.user_result[1:3]) + def test_outerjointo_count(self): + """test the join and outerjoin functions on Query""" + mapper(tables.User, tables.users, properties={ + 'orders':relation(mapper(tables.Order, tables.orders, properties={ + 'items':relation(mapper(tables.Item, tables.orderitems)) + })) + }) + session = create_session() + query = session.query(tables.User) + x = query.outerjoin('orders').outerjoin('items').filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count() + assert x==2 + def test_from(self): + mapper(tables.User, tables.users, properties={ + 'orders':relation(mapper(tables.Order, tables.orders, properties={ + 'items':relation(mapper(tables.Item, tables.orderitems)) + })) + }) + session = create_session() + query = session.query(tables.User) + x = query.select_from([tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)]).\ + filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)) + print x.compile() + self.assert_result(list(x), tables.User, *tables.user_result[1:3]) + + +class CaseSensitiveTest(PersistTest): + def setUpAll(self): + self.install_threadlocal() + global metadata, table1, table2 + metadata = BoundMetaData(testbase.db) + table1 = Table('Table1', metadata, + Column('ID', Integer, primary_key=True), + ) + table2 = Table('Table2', metadata, + Column('T1ID', Integer, ForeignKey("Table1.ID"), primary_key=True), + Column('NUM', Integer, primary_key=True), + ) + assign_mapper(Obj1, table1) + assign_mapper(Obj2, table2) + metadata.create_all() + table1.insert().execute({'ID':1},{'ID':2},{'ID':3},{'ID':4}) + table2.insert().execute({'NUM':1,'T1ID':1},{'NUM':2,'T1ID':1},{'NUM':3,'T1ID':1},\ +{'NUM':4,'T1ID':2},{'NUM':5,'T1ID':2},{'NUM':6,'T1ID':3}) + + def setUp(self): + self.query = Query(Obj1) + #self.orig = self.query.select_whereclause() + #self.res = self.query.select() + + def tearDownAll(self): + metadata.drop_all() + self.uninstall_threadlocal() + clear_mappers() + + def test_distinctcount(self): + res = self.query + assert res.count() == 4 + res = self.query.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1)) + assert res.count() == 3 + res = self.query.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1)).distinct() + self.assertEqual(res.count(), 1) + + +if __name__ == "__main__": + testbase.main() diff --git a/test/orm/inheritance5.py b/test/orm/inheritance5.py index 640a8d70fe..0ae4260dce 100644 --- a/test/orm/inheritance5.py +++ b/test/orm/inheritance5.py @@ -1,6 +1,5 @@ from sqlalchemy import * import testbase -from sqlalchemy.ext.selectresults import SelectResults class AttrSettable(object): def __init__(self, **kwargs): @@ -374,8 +373,8 @@ class RelationTest4(testbase.ORMTest): assert str(car1.employee) == "Engineer E4, status X" session.clear() - s = SelectResults(session.query(Car)) - c = s.join_to("employee").select(employee_join.c.name=="E4")[0] + s = session.query(Car) + c = s.join("employee").select(employee_join.c.name=="E4")[0] assert c.car_id==car1.car_id class RelationTest5(testbase.ORMTest): @@ -580,7 +579,7 @@ class RelationTest7(testbase.ORMTest): for p in r: assert p.car_id == p.car.car_id -class SelectResultsTest(testbase.AssertMixin): +class GenerativeTest(testbase.AssertMixin): def setUpAll(self): # cars---owned by--- people (abstract) --- has a --- status # | ^ ^ | @@ -693,9 +692,9 @@ class SelectResultsTest(testbase.AssertMixin): # test these twice because theres caching involved for x in range(0, 2): - r = SelectResults(session.query(Person)).select_by(people.c.name.like('%2')).join_to('status').select_by(name="active") + r = session.query(Person).filter_by(people.c.name.like('%2')).join('status').filter_by(name="active") assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]" - r = SelectResults(session.query(Engineer)).join_to('status').select(people.c.name.in_('E2', 'E3', 'E4', 'M4', 'M2', 'M1') & (status.c.name=="active")) + r = session.query(Engineer).join('status').filter(people.c.name.in_('E2', 'E3', 'E4', 'M4', 'M2', 'M1') & (status.c.name=="active")) assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]" class MultiLevelTest(testbase.ORMTest): diff --git a/test/orm/mapper.py b/test/orm/mapper.py index d8c23e103a..261fcc1163 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -296,16 +296,16 @@ class MapperTest(MapperSuperTest): sess = create_session() q = sess.query(m) - l = q.select((orderitems.c.item_name=='item 4') & q.join_via(['orders', 'items'])) + l = q.filter(orderitems.c.item_name=='item 4').join(['orders', 'items']).list() self.assert_result(l, User, user_result[0]) l = q.select_by(item_name='item 4') self.assert_result(l, User, user_result[0]) - l = q.select((orderitems.c.item_name=='item 4') & q.join_to('item_name')) + l = q.filter(orderitems.c.item_name=='item 4').join('item_name').list() self.assert_result(l, User, user_result[0]) - l = q.select((orderitems.c.item_name=='item 4') & q.join_to('items')) + l = q.filter(orderitems.c.item_name=='item 4').join('items').list() self.assert_result(l, User, user_result[0]) # test comparing to an object instance @@ -587,15 +587,29 @@ class MapperTest(MapperSuperTest): }) sess = create_session() - q2 = sess.query(User).options(eagerload('orders.items.keywords')) + + # eagerload nothing. u = sess.query(User).select() def go(): print u[0].orders[1].items[0].keywords[1] self.assert_sql_count(db, go, 3) sess.clear() + + print "-------MARK----------" + # eagerload orders, orders.items, orders.items.keywords + q2 = sess.query(User).options(eagerload('orders'), eagerload('orders.items'), eagerload('orders.items.keywords')) u = q2.select() print "-------MARK2----------" + self.assert_sql_count(db, go, 0) + + sess.clear() + + # eagerload "keywords" on items. it will lazy load "orders", then lazy load + # the "items" on the order, but on "items" it will eager load the "keywords" + print "-------MARK3----------" + q3 = sess.query(User).options(eagerload('orders.items.keywords')) + u = q3.select() self.assert_sql_count(db, go, 2) class InheritanceTest(MapperSuperTest): @@ -867,7 +881,7 @@ class LazyTest(MapperSuperTest): addresses = relation(mapper(Address, addresses, extension=ctx.mapper_extension), lazy=True) ), extension=ctx.mapper_extension) q = ctx.current.query(m) - u = q.selectfirst(users.c.user_id == 7) + u = q.filter(users.c.user_id == 7).selectfirst() ctx.current.expunge(u) self.assert_result([u], User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, @@ -1067,85 +1081,6 @@ class EagerTest(MapperSuperTest): {'user_id' : 9, 'addresses' : (Address, [])} ) - def testcustomfromalias(self): - mapper(User, users, properties={ - 'addresses':relation(Address, lazy=True) - }) - mapper(Address, addresses) - query = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True) - q = create_session().query(User) - - def go(): - l = q.options(contains_alias('ulist'), contains_eager('addresses')).instances(query.execute()) - self.assert_result(l, User, *user_address_result) - self.assert_sql_count(testbase.db, go, 1) - - def testcustomeagerquery(self): - mapper(User, users, properties={ - # setting lazy=True - the contains_eager() option below - # should imply eagerload() - 'addresses':relation(Address, lazy=True) - }) - mapper(Address, addresses) - - selectquery = users.outerjoin(addresses).select(use_labels=True) - q = create_session().query(User) - - def go(): - l = q.options(contains_eager('addresses')).instances(selectquery.execute()) - self.assert_result(l, User, *user_address_result) - self.assert_sql_count(testbase.db, go, 1) - - def testcustomeagerwithstringalias(self): - mapper(User, users, properties={ - 'addresses':relation(Address, lazy=False) - }) - mapper(Address, addresses) - - adalias = addresses.alias('adalias') - selectquery = users.outerjoin(adalias).select(use_labels=True) - q = create_session().query(User) - - def go(): - l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute()) - self.assert_result(l, User, *user_address_result) - self.assert_sql_count(testbase.db, go, 1) - - def testcustomeagerwithalias(self): - mapper(User, users, properties={ - 'addresses':relation(Address, lazy=False) - }) - mapper(Address, addresses) - - adalias = addresses.alias('adalias') - selectquery = users.outerjoin(adalias).select(use_labels=True) - q = create_session().query(User) - - def go(): - l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute()) - self.assert_result(l, User, *user_address_result) - self.assert_sql_count(testbase.db, go, 1) - - def testcustomeagerwithdecorator(self): - mapper(User, users, properties={ - 'addresses':relation(Address, lazy=False) - }) - mapper(Address, addresses) - - adalias = addresses.alias('adalias') - selectquery = users.outerjoin(adalias).select(use_labels=True) - def decorate(row): - d = {} - for c in addresses.columns: - d[c] = row[adalias.corresponding_column(c)] - return d - - q = create_session().query(User) - - def go(): - l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute()) - self.assert_result(l, User, *user_address_result) - self.assert_sql_count(testbase.db, go, 1) def testorderby_desc(self): m = mapper(Address, addresses) @@ -1441,7 +1376,108 @@ class EagerTest(MapperSuperTest): {'item_id':5, 'item_name':'item 5'} ])}, ) + +class InstancesTest(MapperSuperTest): + def testcustomfromalias(self): + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=True) + }) + mapper(Address, addresses) + query = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist').outerjoin(addresses).select(use_labels=True) + q = create_session().query(User) + + def go(): + l = q.options(contains_alias('ulist'), contains_eager('addresses')).instances(query.execute()) + self.assert_result(l, User, *user_address_result) + self.assert_sql_count(testbase.db, go, 1) + + def testcustomeagerquery(self): + mapper(User, users, properties={ + # setting lazy=True - the contains_eager() option below + # should imply eagerload() + 'addresses':relation(Address, lazy=True) + }) + mapper(Address, addresses) + + selectquery = users.outerjoin(addresses).select(use_labels=True) + q = create_session().query(User) + + def go(): + l = q.options(contains_eager('addresses')).instances(selectquery.execute()) + self.assert_result(l, User, *user_address_result) + self.assert_sql_count(testbase.db, go, 1) + + def testcustomeagerwithstringalias(self): + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=False) + }) + mapper(Address, addresses) + + adalias = addresses.alias('adalias') + selectquery = users.outerjoin(adalias).select(use_labels=True) + q = create_session().query(User) + + def go(): + l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute()) + self.assert_result(l, User, *user_address_result) + self.assert_sql_count(testbase.db, go, 1) + + def testcustomeagerwithalias(self): + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=False) + }) + mapper(Address, addresses) + + adalias = addresses.alias('adalias') + selectquery = users.outerjoin(adalias).select(use_labels=True) + q = create_session().query(User) + + def go(): + l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute()) + self.assert_result(l, User, *user_address_result) + self.assert_sql_count(testbase.db, go, 1) + + def testcustomeagerwithdecorator(self): + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=False) + }) + mapper(Address, addresses) + + adalias = addresses.alias('adalias') + selectquery = users.outerjoin(adalias).select(use_labels=True) + def decorate(row): + d = {} + for c in addresses.columns: + d[c] = row[adalias.corresponding_column(c)] + return d + + q = create_session().query(User) + + def go(): + l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute()) + self.assert_result(l, User, *user_address_result) + self.assert_sql_count(testbase.db, go, 1) + + def testmultiplemappers(self): + mapper(User, users, properties={ + 'addresses':relation(Address, lazy=False) + }) + mapper(Address, addresses) + + selectquery = users.outerjoin(addresses).select(use_labels=True) + q = create_session().query(User) + l = q.instances(selectquery.execute(), Address) + # note the result is a cartesian product + assert repr(l) == "[(User(user_id=7,user_name=u'jack'), Address(address_id=1,user_id=7,email_address=u'jack@bean.com')), (User(user_id=8,user_name=u'ed'), Address(address_id=2,user_id=8,email_address=u'ed@wood.com')), (User(user_id=8,user_name=u'ed'), Address(address_id=3,user_id=8,email_address=u'ed@bettyboop.com')), (User(user_id=8,user_name=u'ed'), Address(address_id=4,user_id=8,email_address=u'ed@lala.com'))]" + # check identity map still in effect even though dupe results + assert l[1][0] is l[2][0] + def testmapperspluscolumn(self): + mapper(User, users) + s = select([users, func.count(addresses.c.address_id).label('count')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c]) + q = create_session().query(User) + l = q.instances(s.execute(), "count") + assert repr(l) == "[(User(user_id=7,user_name=u'jack'), 1), (User(user_id=8,user_name=u'ed'), 3), (User(user_id=9,user_name=u'fred'), 0)]" if __name__ == "__main__": testbase.main() -- 2.47.2