From: Mike Bayer Date: Wed, 7 Dec 2005 01:37:55 +0000 (+0000) Subject: added rudimentary support for limit and offset (with the hack version in oracle) X-Git-Tag: rel_0_1_0~252 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=44abaa9e567e6c412a2ce499907c65d97aa865dc;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git added rudimentary support for limit and offset (with the hack version in oracle) fixed up order_by to support a list/scalar of columns or asc/desc fixed up query.py unit test --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 691106a710..79215dec7a 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -269,7 +269,12 @@ class ANSICompiler(sql.Compiled): t = self.get_str(select.having) if t: text += " \nHAVING " + t - + + if select.limit is not None or select.offset is not None: + # TODO: ok, so this is a simple limit/offset thing. + # need to make this DB neutral for mysql, oracle + text += self.limit_clause(select) + if getattr(select, 'issubquery', False): self.strings[select] = "(" + text + ")" else: @@ -277,6 +282,14 @@ class ANSICompiler(sql.Compiled): self.froms[select] = "(" + text + ")" + def limit_clause(self, select): + if select.limit is not None: + return " \n LIMIT " + str(select.limit) + if select.offset is not None: + if select.limit is None: + return " \n LIMIT -1" + return " OFFSET " + str(select.offset) + def visit_table(self, table): self.froms[table] = table.fullname self.strings[table] = "" diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 9e3f677d11..1eef6facba 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -175,7 +175,16 @@ class MySQLTableImpl(sql.TableImpl): rowid_column = property(lambda s: s._rowid_col()) class MySQLCompiler(ansisql.ANSICompiler): - pass + def limit_clause(self, select): + text = "" + if select.limit is not None: + text += " \n LIMIT " + str(select.limit) + if select.offset is not None: + if select.limit is None: + # striaght from the MySQL docs, I kid you not + text += " \n LIMIT 18446744073709551615" + text += " OFFSET " + str(select.offset) + return text class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, first_pk=False): diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 58d533b108..93670729df 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -195,6 +195,28 @@ class OracleCompiler(ansisql.ANSICompiler): self.bindparams[c.key] = None return ansisql.ANSICompiler.visit_insert(self, insert) + def visit_select(self, select): + """looks for LIMIT and OFFSET in a select statement, and if so tries to wrap it in a + subquery with rownum criterion.""" + if getattr(select, '_oracle_visit', False): + ansisql.ANSICompiler.visit_select(self, select) + return + if select.limit is not None or select.offset is not None: + select._oracle_visit = True + limitselect = select.select() + if select.limit is not None: + limitselect.append_whereclause("rownum<%d" % select.limit) + if select.offset is not None: + limitselect.append_whereclause("rownum>%d" % select.offset) + limitselect.accept_visitor(self) + self.strings[select] = self.strings[limitselect] + self.froms[select] = self.froms[limitselect] + else: + ansisql.ANSICompiler.visit_select(self, select) + + def limit_clause(self, select): + return "" + class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 5a3d6388a3..531e8af03c 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -240,6 +240,16 @@ class PGCompiler(ansisql.ANSICompiler): if c.sequence is not None and not c.sequence.optional: self.bindparams[c.key] = None return ansisql.ANSICompiler.visit_insert(self, insert) + + def limit_clause(self, select): + text = "" + if select.limit is not None: + text += " \n LIMIT " + str(select.limit) + if select.offset is not None: + if select.limit is None: + text += " \n LIMIT ALL" + text += " OFFSET " + str(select.offset) + return text class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 23a9d04830..fc85d92aca 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -188,7 +188,17 @@ class SQLiteCompiler(ansisql.ANSICompiler): def __init__(self, *args, **params): params.setdefault('paramstyle', 'named') ansisql.ANSICompiler.__init__(self, *args, **params) + def limit_clause(self, select): + text = "" + if select.limit is not None: + text += " \n LIMIT " + str(select.limit) + if select.offset is not None: + if select.limit is None: + text += " \n LIMIT -1" + text += " OFFSET " + str(select.offset) + return text + class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, override_pk=False, **kwargs): colspec = column.name + " " + column.type.get_col_spec() diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index d1fd2315ed..3429555d3b 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -541,8 +541,8 @@ class EagerLoader(PropertyLoader): # if this eagermapper is to select using an "alias" to isolate it from other # eager mappers against the same table, we have to redefine our secondary - # or primary join condition to reference the aliased table. else - # we set up the target clause objects as what they are defined in the + # or primary join condition to reference the aliased table (and the order_by). + # else we set up the target clause objects as what they are defined in the # superclass. if self.selectalias is not None: self.eagertarget = self.target.alias(self.selectalias) @@ -554,11 +554,21 @@ class EagerLoader(PropertyLoader): else: self.eagerprimary = self.primaryjoin.copy_container() self.eagerprimary.accept_visitor(aliasizer) + if self.order_by is not None: + self.eager_order_by = [o.copy_container() for o in self.order_by] + for i in range(0, len(self.eager_order_by)): + if isinstance(self.eager_order_by[i], schema.Column): + self.eager_order_by[i] = self.eagertarget._get_col_by_original(self.eager_order_by[i]) + else: + self.eager_order_by[i].accept_visitor(aliasizer) + else: + self.eager_order_by = None else: self.eagertarget = self.target self.eagerprimary = self.primaryjoin self.eagersecondary = self.secondaryjoin - + self.eager_order_by = self.order_by + def setup(self, key, statement, recursion_stack = None, **options): """add a left outer join to the statement thats being constructed""" @@ -588,8 +598,8 @@ class EagerLoader(PropertyLoader): if self.order_by is None: statement.order_by(self.eagertarget.rowid_column) - if self.order_by is not None: - statement.order_by(*[self.eagertarget._get_col_by_original(c) for c in self.order_by]) + if self.eager_order_by is not None: + statement.order_by(*self.eager_order_by) statement.append_from(statement._outerjoin) statement.append_column(self.eagertarget) @@ -691,12 +701,18 @@ class Aliasizer(sql.ClauseVisitor): aliasname = table.name + "_" + hex(random.randint(0, 65535))[2:] return self.aliases.setdefault(table, sql.alias(table, aliasname)) + def visit_compound(self, compound): + for i in range(0, len(compound.clauses)): + if isinstance(compound.clauses[i], schema.Column) and self.tables.has_key(compound.clauses[i].table): + compound.clauses[i] = self.get_alias(compound.clauses[i].table)._get_col_by_original(compound.clauses[i]) + self.match = True + def visit_binary(self, binary): if isinstance(binary.left, schema.Column) and self.tables.has_key(binary.left.table): - binary.left = self.get_alias(binary.left.table).c[binary.left.name] + binary.left = self.get_alias(binary.left.table)._get_col_by_original(binary.left) self.match = True if isinstance(binary.right, schema.Column) and self.tables.has_key(binary.right.table): - binary.right = self.get_alias(binary.right.table).c[binary.right.name] + binary.right = self.get_alias(binary.right.table)._get_col_by_original(binary.right) self.match = True class BinaryVisitor(sql.ClauseVisitor): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 71cbacdac0..10e8f3e74c 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -552,7 +552,7 @@ class CompoundClause(ClauseList): f += c._get_from_objects() return f def hash_key(self): - return string.join([c.hash_key() for c in self.clauses], self.operator) + return string.join([c.hash_key() for c in self.clauses], self.operator or " ") class Function(ClauseList, CompareMixin): """describes a SQL function. extends ClauseList to provide comparison operators.""" @@ -948,7 +948,7 @@ class CompoundSelect(Selectable, TailClauseMixin): class Select(Selectable, TailClauseMixin): """finally, represents a SELECT statement, with appendable clauses, as well as the ability to execute itself and return a result set.""" - def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None): + def __init__(self, columns=None, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None, limit=None, offset=None): self._columns = util.OrderedProperties() self._froms = util.OrderedDict() self.use_labels = use_labels @@ -958,6 +958,8 @@ class Select(Selectable, TailClauseMixin): self.having = None self._engine = engine self.rowid_column = None + self.limit = limit + self.offset = offset # indicates if this select statement is a subquery inside another query self.issubquery = False diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index c247c86edf..d280368f69 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -23,6 +23,8 @@ def to_list(x): return None if not isinstance(x, list) and not isinstance(x, tuple): return [x] + else: + return x class OrderedProperties(object): """an object that maintains the order in which attributes are set upon it. diff --git a/test/dependency.py b/test/dependency.py index 864e273b39..159a1864a8 100644 --- a/test/dependency.py +++ b/test/dependency.py @@ -1,7 +1,9 @@ from testbase import PersistTest -import sqlalchemy.util as util +import sqlalchemy.mapping.topological as topological import unittest, sys, os +class DependencySorter(topological.QueueDependencySorter):pass + class thingy(object): def __init__(self, name): self.name = name @@ -31,7 +33,7 @@ class DependencySortTest(PersistTest): (node4, subnode3), (node4, subnode4) ] - head = util.DependencySorter(tuples, []).sort() + head = DependencySorter(tuples, []).sort() print "\n" + str(head) def testsort2(self): @@ -49,7 +51,7 @@ class DependencySortTest(PersistTest): (node5, node6), (node6, node2) ] - head = util.DependencySorter(tuples, [node7]).sort() + head = DependencySorter(tuples, [node7]).sort() print "\n" + str(head) def testsort3(self): @@ -62,9 +64,9 @@ class DependencySortTest(PersistTest): (node3, node2), (node1,node3) ] - head1 = util.DependencySorter(tuples, [node1, node2, node3]).sort() - head2 = util.DependencySorter(tuples, [node3, node1, node2]).sort() - head3 = util.DependencySorter(tuples, [node3, node2, node1]).sort() + head1 = DependencySorter(tuples, [node1, node2, node3]).sort() + head2 = DependencySorter(tuples, [node3, node1, node2]).sort() + head3 = DependencySorter(tuples, [node3, node2, node1]).sort() # TODO: figure out a "node == node2" function #self.assert_(str(head1) == str(head2) == str(head3)) @@ -83,7 +85,7 @@ class DependencySortTest(PersistTest): (node1, node3), (node3, node2) ] - head = util.DependencySorter(tuples, []).sort() + head = DependencySorter(tuples, []).sort() print "\n" + str(head) def testsort5(self): @@ -111,7 +113,7 @@ class DependencySortTest(PersistTest): node3, node4 ] - head = util.DependencySorter(tuples, allitems).sort() + head = DependencySorter(tuples, allitems).sort() print "\n" + str(head) diff --git a/test/mapper.py b/test/mapper.py index a0b34d1720..81836f34b9 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -57,7 +57,7 @@ class MapperTest(MapperSuperTest): self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, - {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}])}, + {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}])}, {'user_id' : 9, 'addresses' : (Address, [])} ) @@ -69,7 +69,7 @@ class MapperTest(MapperSuperTest): l = m.options(lazyload('addresses')).select() self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, - {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}])}, + {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}])}, {'user_id' : 9, 'addresses' : (Address, [])} ) @@ -127,16 +127,27 @@ class LazyTest(MapperSuperTest): addresses = relation(m, lazy = True, order_by=addresses.c.email_address), )) l = m.select() - dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"), - dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"), - dict(address_id = 3, user_id = 8, email_address = "ed@lala.com") self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, - {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])}, + {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])}, {'user_id' : 9, 'addresses' : (Address, [])} ) + def testorderby_desc(self): + m = mapper(Address, addresses) + + m = mapper(User, users, properties = dict( + addresses = relation(m, lazy = True, order_by=[desc(addresses.c.email_address)]), + )) + l = m.select() + + self.assert_result(l, User, + {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, + {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@wood.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@bettyboop.com'}])}, + {'user_id' : 9, 'addresses' : (Address, [])}, + ) + def testonetoone(self): m = mapper(User, users, properties = dict( address = relation(Address, addresses, lazy = True, uselist = False) @@ -173,7 +184,7 @@ class LazyTest(MapperSuperTest): 'closed_orders' : (Order, [{'order_id' : 1},{'order_id' : 5},]) }, {'user_id' : 8, - 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}]), + 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]), 'open_orders' : (Order, []), 'closed_orders' : (Order, []) }, @@ -217,7 +228,7 @@ class EagerTest(MapperSuperTest): l = m.select() self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, - {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}])}, + {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3},{'address_id' : 4}])}, {'user_id' : 9, 'addresses' : (Address, [])} ) @@ -228,15 +239,25 @@ class EagerTest(MapperSuperTest): addresses = relation(m, lazy = False, order_by=addresses.c.email_address), )) l = m.select() - dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"), - dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"), - dict(address_id = 3, user_id = 8, email_address = "ed@lala.com") - self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, - {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])}, + {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@bettyboop.com'}, {'email_address':'ed@lala.com'}, {'email_address':'ed@wood.com'}])}, {'user_id' : 9, 'addresses' : (Address, [])} ) + + def testorderby_desc(self): + m = mapper(Address, addresses) + + m = mapper(User, users, properties = dict( + addresses = relation(m, lazy = False, selectalias='lala', order_by=[desc(addresses.c.email_address)]), + )) + l = m.select() + + self.assert_result(l, User, + {'user_id' : 7, 'addresses' : (Address, [{'email_address' : 'jack@bean.com'}])}, + {'user_id' : 8, 'addresses' : (Address, [{'email_address':'ed@wood.com'},{'email_address':'ed@lala.com'}, {'email_address':'ed@bettyboop.com'}, ])}, + {'user_id' : 9, 'addresses' : (Address, [])}, + ) def testonetoone(self): m = mapper(User, users, properties = dict( @@ -268,7 +289,7 @@ class EagerTest(MapperSuperTest): )) l = m.select(and_(addresses.c.email_address == 'ed@lala.com', addresses.c.user_id==users.c.user_id)) self.assert_result(l, User, - {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2, 'email_address':'ed@wood.com'}, {'address_id':3, 'email_address':'ed@lala.com'}])}, + {'user_id' : 8, 'addresses' : (Address, [{'address_id' : 2, 'email_address':'ed@wood.com'}, {'address_id':3, 'email_address':'ed@bettyboop.com'}, {'address_id':4, 'email_address':'ed@lala.com'}])}, ) @@ -297,7 +318,7 @@ class EagerTest(MapperSuperTest): 'orders' : (Order, [{'order_id' : 1}, {'order_id' : 3},{'order_id' : 5},]) }, {'user_id' : 8, - 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}]), + 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]), 'orders' : (Order, []) }, {'user_id' : 9, @@ -323,7 +344,7 @@ class EagerTest(MapperSuperTest): 'closed_orders' : (Order, [{'order_id' : 1},{'order_id' : 5},]) }, {'user_id' : 8, - 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}]), + 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]), 'open_orders' : (Order, []), 'closed_orders' : (Order, []) }, @@ -356,7 +377,7 @@ class EagerTest(MapperSuperTest): ]) }, {'user_id' : 8, - 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}]), + 'addresses' : (Address, [{'address_id' : 2}, {'address_id' : 3}, {'address_id' : 4}]), 'orders' : (Order, []) }, {'user_id' : 9, diff --git a/test/query.py b/test/query.py index 76fbcc0785..8fc9694f45 100644 --- a/test/query.py +++ b/test/query.py @@ -10,22 +10,29 @@ from sqlalchemy import * class QueryTest(PersistTest): - def setUp(self): - self.users = Table('query_users', db, + def setUpAll(self): + global users + users = Table('query_users', db, Column('user_id', INT, primary_key = True), Column('user_name', VARCHAR(20)), redefine = True ) - self.users.create() - + users.create() + + def setUp(self): + self.users = users + def tearDown(self): + self.users.delete().execute() + + def tearDownAll(self): + global users + users.drop() def testinsert(self): - c = db.connection() self.users.insert().execute(user_id = 7, user_name = 'jack') print repr(self.users.select().execute().fetchall()) def testupdate(self): - c = db.connection() self.users.insert().execute(user_id = 7, user_name = 'jack') print repr(self.users.select().execute().fetchall()) @@ -69,9 +76,20 @@ class QueryTest(PersistTest): db.transaction(dostuff) print repr(self.users.select().execute().fetchall()) - - def tearDown(self): - self.users.drop() + def testselectlimit(self): + self.users.insert().execute(user_id=1, user_name='john') + self.users.insert().execute(user_id=2, user_name='jack') + self.users.insert().execute(user_id=3, user_name='ed') + self.users.insert().execute(user_id=4, user_name='wendy') + self.users.insert().execute(user_id=5, user_name='laura') + self.users.insert().execute(user_id=6, user_name='ralph') + self.users.insert().execute(user_id=7, user_name='fido') + r = self.users.select(limit=3).execute().fetchall() + self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')]) + r = self.users.select(limit=3, offset=2).execute().fetchall() + self.assert_(r==[(3, 'ed'), (4, 'wendy'), (5, 'laura')]) + r = self.users.select(offset=5).execute().fetchall() + self.assert_(r==[(6, 'ralph'), (7, 'fido')]) if __name__ == "__main__": - unittest.main() + testbase.main() diff --git a/test/tables.py b/test/tables.py index fa3c9150ae..8bb0587e57 100644 --- a/test/tables.py +++ b/test/tables.py @@ -84,7 +84,8 @@ def data(): addresses.insert().execute( dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"), dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"), - dict(address_id = 3, user_id = 8, email_address = "ed@lala.com") + dict(address_id = 3, user_id = 8, email_address = "ed@bettyboop.com"), + dict(address_id = 4, user_id = 8, email_address = "ed@lala.com") ) orders.insert().execute( dict(order_id = 1, user_id = 7, description = 'order 1', isopen=0),