From: Mike Bayer Date: Fri, 31 Mar 2006 02:25:59 +0000 (+0000) Subject: Jonas Borgström's fantastic SelectRsults patch that adds dynamic list argument suppor... X-Git-Tag: rel_0_1_6~50 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1a5e65c14f11ea2d88e2a00cea6cbd82f371e385;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Jonas Borgström's fantastic SelectRsults patch that adds dynamic list argument support to the mapper.select() methd. associated unit test tweaks and mapper integration. --- diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index a0de86df4f..f8faea8555 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -9,6 +9,7 @@ import sqlalchemy.sql as sql import sqlalchemy.schema as schema import sqlalchemy.engine as engine import sqlalchemy.util as util +import util as mapperutil import sync from sqlalchemy.exceptions import * import objectstore @@ -419,7 +420,7 @@ class Mapper(object): e.g. result = usermapper.select_by(user_name = 'fred') """ - return self.select_whereclause(self._by_clause(*args, **params)) + return mapperutil.SelectResults(self, self._by_clause(*args, **params)) def selectfirst_by(self, *args, **params): """works like select_by(), but only returns the first result by itself, or None if no @@ -428,7 +429,7 @@ class Mapper(object): def selectone_by(self, *args, **params): """works like selectfirst_by(), but throws an error if not exactly one result was returned.""" - ret = self.select_by(*args, **params) + ret = list(self.select_by(*args, **params)[0:2]) if len(ret) == 1: return ret[0] raise InvalidRequestError('Multiple rows returned for selectone_by') @@ -491,7 +492,7 @@ class Mapper(object): """works like select(), but only returns the first result by itself, or None if no objects returned.""" params['limit'] = 1 - ret = self.select(*args, **params) + ret = self.select_whereclause(*args, **params) if ret: return ret[0] else: @@ -499,7 +500,7 @@ class Mapper(object): def selectone(self, *args, **params): """works like selectfirst(), but throws an error if not exactly one result was returned.""" - ret = self.select(*args, **params) + ret = list(self.select(*args, **params)[0:2]) if len(ret) == 1: return ret[0] raise InvalidRequestError('Multiple rows returned for selectone') @@ -517,7 +518,7 @@ class Mapper(object): if arg is not None and isinstance(arg, sql.Selectable): return self.select_statement(arg, **kwargs) else: - return self.select_whereclause(arg, **kwargs) + return mapperutil.SelectResults(self, arg, ops=kwargs) def select_whereclause(self, whereclause=None, params=None, **kwargs): statement = self._compile(whereclause, **kwargs) diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index b18d4d54fe..206905aa42 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -616,7 +616,7 @@ class LazyLoader(PropertyLoader): order_by = self.secondary.default_order_by() else: order_by = False - result = self.mapper.select(self.lazywhere, order_by=order_by, params=params) + result = list(self.mapper.select(self.lazywhere, order_by=order_by, params=params)) else: result = [] if self.uselist: diff --git a/lib/sqlalchemy/mapping/util.py b/lib/sqlalchemy/mapping/util.py new file mode 100644 index 0000000000..74b6c7557f --- /dev/null +++ b/lib/sqlalchemy/mapping/util.py @@ -0,0 +1,73 @@ +from sqlalchemy.sql import and_, select, func + +class SelectResults(object): + def __init__(self, mapper, clause=None, ops={}): + self._mapper = mapper + self._clause = clause + self._ops = {} + self._ops.update(ops) + + def count(self): + return self._mapper.count(self._clause) + + def __len__(self): + return self.count() + + def min(self, col): + return select([func.min(col)], self._clause, **self._ops).scalar() + + def max(self, col): + return select([func.max(col)], self._clause, **self._ops).scalar() + + def sum(self, col): + return select([func.sum(col)], self._clause, **self._ops).scalar() + + def avg(self, col): + return select([func.avg(col)], self._clause, **self._ops).scalar() + + def clone(self): + return SelectResults(self._mapper, self._clause, self._ops.copy()) + + def filter(self, clause): + new = self.clone() + new._clause = and_(self._clause, clause) + return new + + def order_by(self, order_by): + new = self.clone() + new._ops['order_by'] = order_by + return new + + def limit(self, limit): + return self[:limit] + + def offset(self, offset): + return self[offset:] + + def list(self): + return list(self) + + def __getitem__(self, item): + if isinstance(item, slice): + start = item.start + stop = item.stop + if (isinstance(start, int) and start < 0) or \ + (isinstance(stop, int) and stop < 0): + return list(self)[item] + else: + res = self.clone() + if start is not None and stop is not None: + res._ops.update(dict(offset=start, limit=stop-start)) + elif start is None and stop is not None: + res._ops.update(dict(limit=stop)) + elif start is not None and stop is None: + res._ops.update(dict(offset=start)) + if item.step is not None: + return list(res)[None:None:item.step] + else: + return res + else: + return list(self[item:item+1])[0] + + def __iter__(self): + return iter(self._mapper.select_whereclause(self._clause, **self._ops)) \ No newline at end of file diff --git a/test/alltests.py b/test/alltests.py index d30f97287c..b266ebcb13 100644 --- a/test/alltests.py +++ b/test/alltests.py @@ -33,6 +33,7 @@ def suite(): # ORM selecting 'mapper', + 'selectresults', 'eagertest1', 'eagertest2', diff --git a/test/mapper.py b/test/mapper.py index 4a8edd0974..6bfc3f3b89 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -133,7 +133,7 @@ class MapperTest(MapperSuperTest): # object isnt refreshed yet, using dict to bypass trigger self.assert_(u.__dict__['user_name'] != 'jack') # do a select - m.select() + m.select().list() # test that it refreshed self.assert_(u.__dict__['user_name'] == 'jack') @@ -255,7 +255,7 @@ class MapperTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = True) )) - l = m.options(eagerload('addresses')).select() + l = m.options(eagerload('addresses')).select().list() def go(): self.assert_result(l, User, *user_address_result) @@ -266,7 +266,7 @@ class MapperTest(MapperSuperTest): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = False) )) - l = m.options(lazyload('addresses')).select() + l = m.options(lazyload('addresses')).select().list() def go(): self.assert_result(l, User, *user_address_result) self.assert_sql_count(db, go, 3) @@ -282,12 +282,12 @@ class MapperTest(MapperSuperTest): }) m2 = m.options(eagerload('orders.items.keywords')) - u = m.select() + u = m.select().list() def go(): print u[0].orders[1].items[0].keywords[1] self.assert_sql_count(db, go, 3) objectstore.clear() - u = m2.select() + u = m2.select().list() self.assert_sql_count(db, go, 2) class PropertyTest(MapperSuperTest): @@ -368,7 +368,7 @@ class DeferredTest(MapperSuperTest): self.assert_(o.description is None) def go(): - l = m.select() + l = m.select().list() o2 = l[2] print o2.description @@ -397,7 +397,7 @@ class DeferredTest(MapperSuperTest): }) def go(): - l = m.select() + l = m.select().list() o2 = l[2] print o2.opened, o2.description, o2.userident self.assert_sql(db, go, [ @@ -410,7 +410,7 @@ class DeferredTest(MapperSuperTest): m = mapper(Order, orders) m2 = m.options(defer('user_id')) def go(): - l = m2.select() + l = m2.select().list() print l[2].user_id self.assert_sql(db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}), @@ -419,7 +419,7 @@ class DeferredTest(MapperSuperTest): objectstore.clear() m3 = m2.options(undefer('user_id')) def go(): - l = m3.select() + l = m3.select().list() print l[3].user_id self.assert_sql(db, go, [ ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}), diff --git a/test/proxy_engine.py b/test/proxy_engine.py index 170e526d96..cd01272b5c 100644 --- a/test/proxy_engine.py +++ b/test/proxy_engine.py @@ -96,7 +96,7 @@ class ThreadProxyTest(PersistTest): try: trans = objectstore.begin() - all = User.select()[:] + all = User.select()[:].list() assert all == [] u = User() diff --git a/test/testbase.py b/test/testbase.py index f3dfac15bc..1578be5a0c 100644 --- a/test/testbase.py +++ b/test/testbase.py @@ -74,6 +74,7 @@ class AssertMixin(PersistTest): """given a list-based structure of keys/properties which represent information within an object structure, and a list of actual objects, asserts that the list of objects corresponds to the structure.""" def assert_result(self, result, class_, *objects): + result = list(result) if echo: print repr(result) self.assert_list(result, class_, objects)