]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added rudimentary support for limit and offset (with the hack version in oracle)
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 7 Dec 2005 01:37:55 +0000 (01:37 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 7 Dec 2005 01:37:55 +0000 (01:37 +0000)
fixed up order_by to support a list/scalar of columns or asc/desc
fixed up query.py unit test

12 files changed:
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py
test/dependency.py
test/mapper.py
test/query.py
test/tables.py

index 691106a710ae9f12465b4195b096ce18ae318931..79215dec7a40526e2e4b211b9f8d3d23b0e03316 100644 (file)
@@ -269,7 +269,12 @@ class ANSICompiler(sql.Compiled):
             t = self.get_str(select.having)
             if t:
                 text += " \nHAVING " + t
-                
+
+        if select.limit is not None or select.offset is not None:
+            # TODO: ok, so this is a simple limit/offset thing.
+            # need to make this DB neutral for mysql, oracle
+            text += self.limit_clause(select)
+            
         if getattr(select, 'issubquery', False):
             self.strings[select] = "(" + text + ")"
         else:
@@ -277,6 +282,14 @@ class ANSICompiler(sql.Compiled):
 
         self.froms[select] = "(" + text + ")"
 
+    def limit_clause(self, select):
+        if select.limit is not None:
+            return  " \n LIMIT " + str(select.limit)
+        if select.offset is not None:
+            if select.limit is None:
+                return " \n LIMIT -1"
+            return " OFFSET " + str(select.offset)
+
     def visit_table(self, table):
         self.froms[table] = table.fullname
         self.strings[table] = ""
index 9e3f677d119f38cea30f8889311b6e43c30966fe..1eef6facbae5d76296e66ffb201413323f68d64c 100644 (file)
@@ -175,7 +175,16 @@ class MySQLTableImpl(sql.TableImpl):
     rowid_column = property(lambda s: s._rowid_col())
 
 class MySQLCompiler(ansisql.ANSICompiler):
-    pass
+    def limit_clause(self, select):
+        text = ""
+        if select.limit is not None:
+            text +=  " \n LIMIT " + str(select.limit)
+        if select.offset is not None:
+            if select.limit is None:
+                # striaght from the MySQL docs, I kid you not
+                text += " \n LIMIT 18446744073709551615"
+            text += " OFFSET " + str(select.offset)
+        return text
         
 class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, first_pk=False):
index 58d533b1080a119dc56364cac9067585d6f53528..93670729df59d909f1f566769872cfa40a839a66 100644 (file)
@@ -195,6 +195,28 @@ class OracleCompiler(ansisql.ANSICompiler):
                 self.bindparams[c.key] = None
         return ansisql.ANSICompiler.visit_insert(self, insert)
 
+    def visit_select(self, select):
+        """looks for LIMIT and OFFSET in a select statement, and if so tries to wrap it in a 
+        subquery with rownum criterion."""
+        if getattr(select, '_oracle_visit', False):
+            ansisql.ANSICompiler.visit_select(self, select)
+            return
+        if select.limit is not None or select.offset is not None:
+            select._oracle_visit = True
+            limitselect = select.select()
+            if select.limit is not None:
+                limitselect.append_whereclause("rownum<%d" % select.limit)
+            if select.offset is not None:
+                limitselect.append_whereclause("rownum>%d" % select.offset)
+            limitselect.accept_visitor(self)
+            self.strings[select] = self.strings[limitselect]
+            self.froms[select] = self.froms[limitselect]
+        else:
+            ansisql.ANSICompiler.visit_select(self, select)
+            
+    def limit_clause(self, select):
+        return ""
+
 class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name
index 5a3d6388a3f2f3dc546e52308b2d038eb78f51d5..531e8af03cde42853b8aa14c9b5a8eebbb00fb1f 100644 (file)
@@ -240,6 +240,16 @@ class PGCompiler(ansisql.ANSICompiler):
             if c.sequence is not None and not c.sequence.optional:
                 self.bindparams[c.key] = None
         return ansisql.ANSICompiler.visit_insert(self, insert)
+
+    def limit_clause(self, select):
+        text = ""
+        if select.limit is not None:
+            text +=  " \n LIMIT " + str(select.limit)
+        if select.offset is not None:
+            if select.limit is None:
+                text += " \n LIMIT ALL"
+            text += " OFFSET " + str(select.offset)
+        return text
         
 class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, **kwargs):
index 23a9d048300c813cff0a473afae06df2a9de5b76..fc85d92acac2518095b4ba6cc25f729931b56dc3 100644 (file)
@@ -188,7 +188,17 @@ class SQLiteCompiler(ansisql.ANSICompiler):
     def __init__(self, *args, **params):
         params.setdefault('paramstyle', 'named')
         ansisql.ANSICompiler.__init__(self, *args, **params)
+    def limit_clause(self, select):
+        text = ""
+        if select.limit is not None:
+            text +=  " \n LIMIT " + str(select.limit)
+        if select.offset is not None:
+            if select.limit is None:
+                text += " \n LIMIT -1"
+            text += " OFFSET " + str(select.offset)
+        return text
 
+        
 class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, override_pk=False, **kwargs):
         colspec = column.name + " " + column.type.get_col_spec()
index d1fd2315ed787235799aa91f62fd487a1e66815f..3429555d3b6961cf9582f7827827c33ce77f8538 100644 (file)
@@ -541,8 +541,8 @@ class EagerLoader(PropertyLoader):
     
         # if this eagermapper is to select using an "alias" to isolate it from other
         # eager mappers against the same table, we have to redefine our secondary
-        # or primary join condition to reference the aliased table.  else
-        # we set up the target clause objects as what they are defined in the 
+        # or primary join condition to reference the aliased table (and the order_by).  
+        # else we set up the target clause objects as what they are defined in the 
         # superclass.
         if self.selectalias is not None:
             self.eagertarget = self.target.alias(self.selectalias)
@@ -554,11 +554,21 @@ class EagerLoader(PropertyLoader):
             else:
                 self.eagerprimary = self.primaryjoin.copy_container()
                 self.eagerprimary.accept_visitor(aliasizer)
+            if self.order_by is not None:
+                self.eager_order_by = [o.copy_container() for o in self.order_by]
+                for i in range(0, len(self.eager_order_by)):
+                    if isinstance(self.eager_order_by[i], schema.Column):
+                        self.eager_order_by[i] = self.eagertarget._get_col_by_original(self.eager_order_by[i])
+                    else:
+                        self.eager_order_by[i].accept_visitor(aliasizer)
+            else:
+                self.eager_order_by = None
         else:
             self.eagertarget = self.target
             self.eagerprimary = self.primaryjoin
             self.eagersecondary = self.secondaryjoin
-
+            self.eager_order_by = self.order_by
+            
     def setup(self, key, statement, recursion_stack = None, **options):
         """add a left outer join to the statement thats being constructed"""
 
@@ -588,8 +598,8 @@ class EagerLoader(PropertyLoader):
             if self.order_by is None:
                 statement.order_by(self.eagertarget.rowid_column)
 
-        if self.order_by is not None:
-            statement.order_by(*[self.eagertarget._get_col_by_original(c) for c in self.order_by])
+        if self.eager_order_by is not None:
+            statement.order_by(*self.eager_order_by)
             
         statement.append_from(statement._outerjoin)
         statement.append_column(self.eagertarget)
@@ -691,12 +701,18 @@ class Aliasizer(sql.ClauseVisitor):
             aliasname = table.name + "_" + hex(random.randint(0, 65535))[2:]
             return self.aliases.setdefault(table, sql.alias(table, aliasname))
 
+    def visit_compound(self, compound):
+        for i in range(0, len(compound.clauses)):
+            if isinstance(compound.clauses[i], schema.Column) and self.tables.has_key(compound.clauses[i].table):
+                compound.clauses[i] = self.get_alias(compound.clauses[i].table)._get_col_by_original(compound.clauses[i])
+                self.match = True
+
     def visit_binary(self, binary):
         if isinstance(binary.left, schema.Column) and self.tables.has_key(binary.left.table):
-            binary.left = self.get_alias(binary.left.table).c[binary.left.name]
+            binary.left = self.get_alias(binary.left.table)._get_col_by_original(binary.left)
             self.match = True
         if isinstance(binary.right, schema.Column) and self.tables.has_key(binary.right.table):
-            binary.right = self.get_alias(binary.right.table).c[binary.right.name]
+            binary.right = self.get_alias(binary.right.table)._get_col_by_original(binary.right)
             self.match = True
 
 class BinaryVisitor(sql.ClauseVisitor):
index 71cbacdac0f0bb78294f8c63a080f22c2c608e07..10e8f3e74c07392e9c7c609fa67554551f0ed437 100644 (file)
@@ -552,7 +552,7 @@ class CompoundClause(ClauseList):
             f += c._get_from_objects()
         return f
     def hash_key(self):
-        return string.join([c.hash_key() for c in self.clauses], self.operator)
+        return string.join([c.hash_key() for c in self.clauses], self.operator or " ")
 
 class Function(ClauseList, CompareMixin):
     """describes a SQL function. extends ClauseList to provide comparison operators."""
@@ -948,7 +948,7 @@ class CompoundSelect(Selectable, TailClauseMixin):
 class Select(Selectable, TailClauseMixin):
     """finally, represents a SELECT statement, with appendable clauses, as well as 
     the ability to execute itself and return a result set."""
-    def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None):
+    def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None, limit=None, offset=None):
         self._columns = util.OrderedProperties()
         self._froms = util.OrderedDict()
         self.use_labels = use_labels
@@ -958,6 +958,8 @@ class Select(Selectable, TailClauseMixin):
         self.having = None
         self._engine = engine
         self.rowid_column = None
+        self.limit = limit
+        self.offset = offset
         
         # indicates if this select statement is a subquery inside another query
         self.issubquery = False
index c247c86edfbd46060f52b55fb8fc2c7230f9279a..d280368f69101ad3d028d181d44d29c0237a36bc 100644 (file)
@@ -23,6 +23,8 @@ def to_list(x):
         return None
     if not isinstance(x, list) and not isinstance(x, tuple):
         return [x]
+    else:
+        return x
         
 class OrderedProperties(object):
     """an object that maintains the order in which attributes are set upon it.
index 864e273b39e3e8cc1b9c6b12c6c39350192aa2fd..159a1864a847b2785991d8a1a59a612dbb58eefb 100644 (file)
@@ -1,7 +1,9 @@
 from testbase import PersistTest
-import sqlalchemy.util as util
+import sqlalchemy.mapping.topological as topological
 import unittest, sys, os
 
+class DependencySorter(topological.QueueDependencySorter):pass
+    
 class thingy(object):
     def __init__(self, name):
         self.name = name
@@ -31,7 +33,7 @@ class DependencySortTest(PersistTest):
             (node4, subnode3),
             (node4, subnode4)
         ]
-        head = util.DependencySorter(tuples, []).sort()
+        head = DependencySorter(tuples, []).sort()
         print "\n" + str(head)
 
     def testsort2(self):
@@ -49,7 +51,7 @@ class DependencySortTest(PersistTest):
             (node5, node6),
             (node6, node2)
         ]
-        head = util.DependencySorter(tuples, [node7]).sort()
+        head = DependencySorter(tuples, [node7]).sort()
         print "\n" + str(head)
 
     def testsort3(self):
@@ -62,9 +64,9 @@ class DependencySortTest(PersistTest):
             (node3, node2),
             (node1,node3)
         ]
-        head1 = util.DependencySorter(tuples, [node1, node2, node3]).sort()
-        head2 = util.DependencySorter(tuples, [node3, node1, node2]).sort()
-        head3 = util.DependencySorter(tuples, [node3, node2, node1]).sort()
+        head1 = DependencySorter(tuples, [node1, node2, node3]).sort()
+        head2 = DependencySorter(tuples, [node3, node1, node2]).sort()
+        head3 = DependencySorter(tuples, [node3, node2, node1]).sort()
         
         # TODO: figure out a "node == node2" function
         #self.assert_(str(head1) == str(head2) == str(head3))
@@ -83,7 +85,7 @@ class DependencySortTest(PersistTest):
             (node1, node3),
             (node3, node2)
         ]
-        head = util.DependencySorter(tuples, []).sort()
+        head = DependencySorter(tuples, []).sort()
         print "\n" + str(head)
 
     def testsort5(self):
@@ -111,7 +113,7 @@ class DependencySortTest(PersistTest):
             node3,
             node4
         ]
-        head = util.DependencySorter(tuples, allitems).sort()
+        head = DependencySorter(tuples, allitems).sort()
         print "\n" + str(head)
         
 
index a0b34d172069f913c2930add7834ca38ea14b752..81836f34b9ba7527159e556bbf3a8f5b26d115a2 100644 (file)
@@ -57,7 +57,7 @@ class MapperTest(MapperSuperTest):
 
         self.assert_result(l, User,
             {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
-            {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}])},
+            {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}])},
             {'user_id' : 9, 'addresses' : (Address, [])}
             )
 
@@ -69,7 +69,7 @@ class MapperTest(MapperSuperTest):
         l = m.options(lazyload('addresses')).select()
         self.assert_result(l, User,
             {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
-            {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}])},
+            {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}])},
             {'user_id' : 9, 'addresses' : (Address, [])}
             )
 
@@ -127,16 +127,27 @@ class LazyTest(MapperSuperTest):
             addresses = relation(m, lazy = True, order_by=addresses.c.email_address),
         ))
         l = m.select()
-        dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"),
-        dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"),
-        dict(address_id = 3, user_id = 8, email_address = "ed@lala.com")
 
         self.assert_result(l, User,
             {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
-            {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])},
+            {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])},
             {'user_id' : 9, 'addresses' : (Address, [])}
             )
 
+    def testorderby_desc(self):
+        m = mapper(Address, addresses)
+
+        m = mapper(User, users, properties = dict(
+            addresses = relation(m, lazy = True, order_by=[desc(addresses.c.email_address)]),
+        ))
+        l = m.select()
+
+        self.assert_result(l, User,
+            {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
+            {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@wood.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@bettyboop.com'}])},
+            {'user_id' : 9, 'addresses' : (Address, [])},
+            )
+
     def testonetoone(self):
         m = mapper(User, users, properties = dict(
             address = relation(Address, addresses, lazy = True, uselist = False)
@@ -173,7 +184,7 @@ class LazyTest(MapperSuperTest):
                 'closed_orders' : (Order, [{'order_id' : 1},{'order_id' : 5},])
             },
             {'user_id' : 8, 
-                'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}]),
+                'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]),
                 'open_orders' : (Order, []),
                 'closed_orders' : (Order, [])
             },
@@ -217,7 +228,7 @@ class EagerTest(MapperSuperTest):
         l = m.select()
         self.assert_result(l, User,
             {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])},
-            {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}])},
+            {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3},{'address_id' : 4}])},
             {'user_id' : 9, 'addresses' : (Address, [])}
             )
 
@@ -228,15 +239,25 @@ class EagerTest(MapperSuperTest):
             addresses = relation(m, lazy = False, order_by=addresses.c.email_address),
         ))
         l = m.select()
-        dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"),
-        dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"),
-        dict(address_id = 3, user_id = 8, email_address = "ed@lala.com")
-        
         self.assert_result(l, User,
             {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
-            {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])},
+            {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])},
             {'user_id' : 9, 'addresses' : (Address, [])}
             )
+
+    def testorderby_desc(self):
+        m = mapper(Address, addresses)
+
+        m = mapper(User, users, properties = dict(
+            addresses = relation(m, lazy = False, selectalias='lala', order_by=[desc(addresses.c.email_address)]),
+        ))
+        l = m.select()
+
+        self.assert_result(l, User,
+            {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])},
+            {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@wood.com'},{'email_address':'ed@lala.com'},  {'email_address':'ed@bettyboop.com'}, ])},
+            {'user_id' : 9, 'addresses' : (Address, [])},
+            )
         
     def testonetoone(self):
         m = mapper(User, users, properties = dict(
@@ -268,7 +289,7 @@ class EagerTest(MapperSuperTest):
         ))
         l = m.select(and_(addresses.c.email_address == 'ed@lala.com', addresses.c.user_id==users.c.user_id))
         self.assert_result(l, User,
-            {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2, 'email_address':'ed@wood.com'}, {'address_id':3, 'email_address':'ed@lala.com'}])},
+            {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2, 'email_address':'ed@wood.com'}, {'address_id':3, 'email_address':'ed@bettyboop.com'}, {'address_id':4, 'email_address':'ed@lala.com'}])},
         )
         
 
@@ -297,7 +318,7 @@ class EagerTest(MapperSuperTest):
                 'orders' : (Order, [{'order_id' : 1}, {'order_id' : 3},{'order_id' : 5},])
             },
             {'user_id' : 8, 
-                'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}]),
+                'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]),
                 'orders' : (Order, [])
             },
             {'user_id' : 9, 
@@ -323,7 +344,7 @@ class EagerTest(MapperSuperTest):
                 'closed_orders' : (Order, [{'order_id' : 1},{'order_id' : 5},])
             },
             {'user_id' : 8, 
-                'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}]),
+                'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]),
                 'open_orders' : (Order, []),
                 'closed_orders' : (Order, [])
             },
@@ -356,7 +377,7 @@ class EagerTest(MapperSuperTest):
                     ])
             },
             {'user_id' : 8, 
-                'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}]),
+                'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]),
                 'orders' : (Order, [])
             },
             {'user_id' : 9, 
index 76fbcc07852c5943bbc39a2a9dd16bdb013567e5..8fc9694f457225ccd42fbc618bad898dcfd2db5e 100644 (file)
@@ -10,22 +10,29 @@ from sqlalchemy import *
 
 class QueryTest(PersistTest):
     
-    def setUp(self):
-        self.users = Table('query_users', db,
+    def setUpAll(self):
+        global users
+        users = Table('query_users', db,
             Column('user_id', INT, primary_key = True),
             Column('user_name', VARCHAR(20)),
             redefine = True
         )
-        self.users.create()
-        
+        users.create()
+    
+    def setUp(self):
+        self.users = users
+    def tearDown(self):
+        self.users.delete().execute()
+    
+    def tearDownAll(self):
+        global users
+        users.drop()
         
     def testinsert(self):
-        c = db.connection()
         self.users.insert().execute(user_id = 7, user_name = 'jack')
         print repr(self.users.select().execute().fetchall())
         
     def testupdate(self):
-        c = db.connection()
 
         self.users.insert().execute(user_id = 7, user_name = 'jack')
         print repr(self.users.select().execute().fetchall())
@@ -69,9 +76,20 @@ class QueryTest(PersistTest):
         db.transaction(dostuff)
         print repr(self.users.select().execute().fetchall())    
 
-
-    def tearDown(self):
-        self.users.drop()
+    def testselectlimit(self):
+        self.users.insert().execute(user_id=1, user_name='john')
+        self.users.insert().execute(user_id=2, user_name='jack')
+        self.users.insert().execute(user_id=3, user_name='ed')
+        self.users.insert().execute(user_id=4, user_name='wendy')
+        self.users.insert().execute(user_id=5, user_name='laura')
+        self.users.insert().execute(user_id=6, user_name='ralph')
+        self.users.insert().execute(user_id=7, user_name='fido')
+        r = self.users.select(limit=3).execute().fetchall()
+        self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')])
+        r = self.users.select(limit=3, offset=2).execute().fetchall()
+        self.assert_(r==[(3, 'ed'), (4, 'wendy'), (5, 'laura')])
+        r = self.users.select(offset=5).execute().fetchall()
+        self.assert_(r==[(6, 'ralph'), (7, 'fido')])
         
 if __name__ == "__main__":
-    unittest.main()        
+    testbase.main()        
index fa3c9150ae7caa23736b02f733b35be381be85e8..8bb0587e570a5cef7826489861efc9dcdc22692d 100644 (file)
@@ -84,7 +84,8 @@ def data():
     addresses.insert().execute(
         dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"),
         dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"),
-        dict(address_id = 3, user_id = 8, email_address = "ed@lala.com")
+        dict(address_id = 3, user_id = 8, email_address = "ed@bettyboop.com"),
+        dict(address_id = 4, user_id = 8, email_address = "ed@lala.com")
     )
     orders.insert().execute(
         dict(order_id = 1, user_id = 7, description = 'order 1', isopen=0),