]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Jonas Borgström's fantastic SelectRsults patch that adds dynamic list argument suppor...
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 31 Mar 2006 02:25:59 +0000 (02:25 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 31 Mar 2006 02:25:59 +0000 (02:25 +0000)
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/mapping/util.py [new file with mode: 0644]
test/alltests.py
test/mapper.py
test/proxy_engine.py
test/testbase.py

index a0de86df4f8b12b03033afee0dc217d5a37e4424..f8faea8555880f84dc91313acfa8fb0d48b14ea9 100644 (file)
@@ -9,6 +9,7 @@ import sqlalchemy.sql as sql
 import sqlalchemy.schema as schema
 import sqlalchemy.engine as engine
 import sqlalchemy.util as util
+import util as mapperutil
 import sync
 from sqlalchemy.exceptions import *
 import objectstore
@@ -419,7 +420,7 @@ class Mapper(object):
         
         e.g.   result = usermapper.select_by(user_name = 'fred')
         """
-        return self.select_whereclause(self._by_clause(*args, **params))
+        return mapperutil.SelectResults(self, self._by_clause(*args, **params))
     
     def selectfirst_by(self, *args, **params):
         """works like select_by(), but only returns the first result by itself, or None if no 
@@ -428,7 +429,7 @@ class Mapper(object):
 
     def selectone_by(self, *args, **params):
         """works like selectfirst_by(), but throws an error if not exactly one result was returned."""
-        ret = self.select_by(*args, **params)
+        ret = list(self.select_by(*args, **params)[0:2])
         if len(ret) == 1:
             return ret[0]
         raise InvalidRequestError('Multiple rows returned for selectone_by')
@@ -491,7 +492,7 @@ class Mapper(object):
         """works like select(), but only returns the first result by itself, or None if no 
         objects returned."""
         params['limit'] = 1
-        ret = self.select(*args, **params)
+        ret = self.select_whereclause(*args, **params)
         if ret:
             return ret[0]
         else:
@@ -499,7 +500,7 @@ class Mapper(object):
             
     def selectone(self, *args, **params):
         """works like selectfirst(), but throws an error if not exactly one result was returned."""
-        ret = self.select(*args, **params)
+        ret = list(self.select(*args, **params)[0:2])
         if len(ret) == 1:
             return ret[0]
         raise InvalidRequestError('Multiple rows returned for selectone')
@@ -517,7 +518,7 @@ class Mapper(object):
         if arg is not None and isinstance(arg, sql.Selectable):
             return self.select_statement(arg, **kwargs)
         else:
-            return self.select_whereclause(arg, **kwargs)
+            return mapperutil.SelectResults(self, arg, ops=kwargs)
 
     def select_whereclause(self, whereclause=None, params=None, **kwargs):
         statement = self._compile(whereclause, **kwargs)
index b18d4d54fe6a8d02ceac8792164fec93a02b9a4c..206905aa42610b0632f89cc082f0e37aed0c5d7c 100644 (file)
@@ -616,7 +616,7 @@ class LazyLoader(PropertyLoader):
                     order_by = self.secondary.default_order_by()
                 else:
                     order_by = False
-                result = self.mapper.select(self.lazywhere, order_by=order_by, params=params)
+                result = list(self.mapper.select(self.lazywhere, order_by=order_by, params=params))
             else:
                 result = []
             if self.uselist:
diff --git a/lib/sqlalchemy/mapping/util.py b/lib/sqlalchemy/mapping/util.py
new file mode 100644 (file)
index 0000000..74b6c75
--- /dev/null
@@ -0,0 +1,73 @@
+from sqlalchemy.sql import and_, select, func
+
+class SelectResults(object):
+    def __init__(self, mapper, clause=None, ops={}):
+        self._mapper = mapper
+        self._clause = clause
+        self._ops = {}
+        self._ops.update(ops)
+
+    def count(self):
+        return self._mapper.count(self._clause)
+    
+    def __len__(self):
+        return self.count()
+        
+    def min(self, col):
+        return select([func.min(col)], self._clause, **self._ops).scalar()
+
+    def max(self, col):
+        return select([func.max(col)], self._clause, **self._ops).scalar()
+
+    def sum(self, col):
+        return select([func.sum(col)], self._clause, **self._ops).scalar()
+
+    def avg(self, col):
+        return select([func.avg(col)], self._clause, **self._ops).scalar()
+
+    def clone(self):
+        return SelectResults(self._mapper, self._clause, self._ops.copy())
+        
+    def filter(self, clause):
+        new = self.clone()
+        new._clause = and_(self._clause, clause)
+        return new
+
+    def order_by(self, order_by):
+        new = self.clone()
+        new._ops['order_by'] = order_by
+        return new
+
+    def limit(self, limit):
+        return self[:limit]
+
+    def offset(self, offset):
+        return self[offset:]
+
+    def list(self):
+        return list(self)
+        
+    def __getitem__(self, item):
+        if isinstance(item, slice):
+            start = item.start
+            stop = item.stop
+            if (isinstance(start, int) and start < 0) or \
+               (isinstance(stop, int) and stop < 0):
+                return list(self)[item]
+            else:
+                res = self.clone()
+                if start is not None and stop is not None:
+                    res._ops.update(dict(offset=start, limit=stop-start))
+                elif start is None and stop is not None:
+                    res._ops.update(dict(limit=stop))
+                elif start is not None and stop is None:
+                    res._ops.update(dict(offset=start))
+                if item.step is not None:
+                    return list(res)[None:None:item.step]
+                else:
+                    return res
+        else:
+            return list(self[item:item+1])[0]
+    
+    def __iter__(self):
+        return iter(self._mapper.select_whereclause(self._clause, **self._ops))
\ No newline at end of file
index d30f97287cd6697fa675e586a1749e555b560e33..b266ebcb138383b3fa7eadf3f0d78207dd0a1eb7 100644 (file)
@@ -33,6 +33,7 @@ def suite():
         
         # ORM selecting
         'mapper',
+        'selectresults',
         'eagertest1',
         'eagertest2',
         
index 4a8edd09747aec3fb7559940d65269b1c760768e..6bfc3f3b894742b36fec01e41368b7a9f47dc97c 100644 (file)
@@ -133,7 +133,7 @@ class MapperTest(MapperSuperTest):
         # object isnt refreshed yet, using dict to bypass trigger
         self.assert_(u.__dict__['user_name'] != 'jack')
         # do a select
-        m.select()
+        m.select().list()
         # test that it refreshed
         self.assert_(u.__dict__['user_name'] == 'jack')
         
@@ -255,7 +255,7 @@ class MapperTest(MapperSuperTest):
         m = mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy = True)
         ))
-        l = m.options(eagerload('addresses')).select()
+        l = m.options(eagerload('addresses')).select().list()
 
         def go():
             self.assert_result(l, User, *user_address_result)
@@ -266,7 +266,7 @@ class MapperTest(MapperSuperTest):
         m = mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy = False)
         ))
-        l = m.options(lazyload('addresses')).select()
+        l = m.options(lazyload('addresses')).select().list()
         def go():
             self.assert_result(l, User, *user_address_result)
         self.assert_sql_count(db, go, 3)
@@ -282,12 +282,12 @@ class MapperTest(MapperSuperTest):
             })
             
         m2 = m.options(eagerload('orders.items.keywords'))
-        u = m.select()
+        u = m.select().list()
         def go():
             print u[0].orders[1].items[0].keywords[1]
         self.assert_sql_count(db, go, 3)
         objectstore.clear()
-        u = m2.select()
+        u = m2.select().list()
         self.assert_sql_count(db, go, 2)
         
 class PropertyTest(MapperSuperTest):
@@ -368,7 +368,7 @@ class DeferredTest(MapperSuperTest):
         self.assert_(o.description is None)
         
         def go():
-            l = m.select()
+            l = m.select().list()
             o2 = l[2]
             print o2.description
 
@@ -397,7 +397,7 @@ class DeferredTest(MapperSuperTest):
         })
 
         def go():
-            l = m.select()
+            l = m.select().list()
             o2 = l[2]
             print o2.opened, o2.description, o2.userident
         self.assert_sql(db, go, [
@@ -410,7 +410,7 @@ class DeferredTest(MapperSuperTest):
         m = mapper(Order, orders)
         m2 = m.options(defer('user_id'))
         def go():
-            l = m2.select()
+            l = m2.select().list()
             print l[2].user_id
         self.assert_sql(db, go, [
             ("SELECT orders.order_id AS orders_order_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}),
@@ -419,7 +419,7 @@ class DeferredTest(MapperSuperTest):
         objectstore.clear()
         m3 = m2.options(undefer('user_id'))
         def go():
-            l = m3.select()
+            l = m3.select().list()
             print l[3].user_id
         self.assert_sql(db, go, [
             ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY orders.%s" % orders.default_order_by()[0].key, {}),
index 170e526d96d7af107209f4f69a69fa51410422c9..cd01272b5c35c73595496bb641b0a614859fc642 100644 (file)
@@ -96,7 +96,7 @@ class ThreadProxyTest(PersistTest):
                     try:
                         trans  = objectstore.begin()
 
-                        all = User.select()[:]
+                        all = User.select()[:].list()
                         assert all == []
 
                         u = User()
index f3dfac15bca772572a5d8484e913a33189e40c45..1578be5a0c7505fa35340d5f20e0d27e75671794 100644 (file)
@@ -74,6 +74,7 @@ class AssertMixin(PersistTest):
     """given a list-based structure of keys/properties which represent information within an object structure, and
     a list of actual objects, asserts that the list of objects corresponds to the structure."""
     def assert_result(self, result, class_, *objects):
+        result = list(result)
         if echo:
             print repr(result)
         self.assert_list(result, class_, objects)