From 8924a0e4fe3f6dfe3f90413deb8bee8cd776cafb Mon Sep 17 00:00:00 2001 From: Michael Trier Date: Sat, 8 Nov 2008 04:43:35 +0000 Subject: [PATCH] Corrected a lot of mssql limit / offset issues. Also ensured that mssql uses the IN / NOT IN syntax when using a binary expression with a subquery. --- CHANGES | 10 ++++++++++ lib/sqlalchemy/databases/mssql.py | 33 +++++++++++++++++++------------ test/dialect/mssql.py | 32 ++++++++++-------------------- test/sql/query.py | 10 ++++++---- 4 files changed, 46 insertions(+), 39 deletions(-) diff --git a/CHANGES b/CHANGES index 0a0946b35c..681e7ed525 100644 --- a/CHANGES +++ b/CHANGES @@ -4,6 +4,16 @@ ======= CHANGES ======= +0.5.0rc4 +======== +- mssql + - Lots of cleanup and fixes to correct problems with + limit and offset. + + - Correct situation where subqueries as part of a + binary expression need to be translated to use the + IN and NOT IN syntax. + 0.5.0rc3 ======== - features diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index f86a955482..3291098282 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -922,12 +922,15 @@ class MSSQLCompiler(compiler.DefaultCompiler): def get_select_precolumns(self, select): """ MS-SQL puts TOP, it's version of LIMIT here """ - if not self.dialect.has_window_funcs: + if select._distinct or select._limit: s = select._distinct and "DISTINCT " or "" + if select._limit: - s += "TOP %s " % (select._limit,) - if select._offset: - raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset') + if not select._offset: + s += "TOP %s " % (select._limit,) + else: + if not self.dialect.has_window_funcs: + raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset') return s return compiler.DefaultCompiler.get_select_precolumns(self, select) @@ -938,13 +941,13 @@ class MSSQLCompiler(compiler.DefaultCompiler): def visit_select(self, select, **kwargs): """Look for ``LIMIT`` and OFFSET in a select statement, and if so tries to wrap it in a subquery with ``row_number()`` criterion. + """ - if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and (select._limit is not None or select._offset is not None): + if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and (select._offset is not None): # to use ROW_NUMBER(), an ORDER BY is required. orderby = self.process(select._order_by_clause) if not orderby: - orderby = list(select.oid_column.proxies)[0] - orderby = self.process(orderby) + raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.') _offset = select._offset _limit = select._limit @@ -952,12 +955,9 @@ class MSSQLCompiler(compiler.DefaultCompiler): select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias() limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) - if _offset is not None: - limitselect.append_whereclause("mssql_rn>=%d" % _offset) - if _limit is not None: - limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) - else: - limitselect.append_whereclause("mssql_rn<=%d" % _limit) + limitselect.append_whereclause("mssql_rn>%d" % _offset) + if _limit is not None: + limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) return self.process(limitselect, iswrapper=True, **kwargs) else: return compiler.DefaultCompiler.visit_select(self, select, **kwargs) @@ -1003,10 +1003,17 @@ class MSSQLCompiler(compiler.DefaultCompiler): def visit_binary(self, binary, **kwargs): """Move bind parameters to the right-hand side of an operator, where possible.""" + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \ and not isinstance(binary.right, expression._BindParamClause): return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs) else: + if (binary.operator in (operator.eq, operator.ne)) and ( + (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._SelectBaseMixin)) or \ + (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._SelectBaseMixin)) or \ + isinstance(binary.left, expression._SelectBaseMixin) or isinstance(binary.right, expression._SelectBaseMixin)): + op = binary.operator == operator.eq and "IN" or "NOT IN" + return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) def label_select_column(self, select, column, asfrom): diff --git a/test/dialect/mssql.py b/test/dialect/mssql.py index 4708cc28c4..26fc752430 100755 --- a/test/dialect/mssql.py +++ b/test/dialect/mssql.py @@ -20,6 +20,16 @@ class CompileTest(TestBase, AssertsCompiledSQL): t = table('sometable', column('somecolumn')) self.assert_compile(t.update(t.c.somecolumn==7), "UPDATE sometable SET somecolumn=:somecolumn WHERE sometable.somecolumn = :somecolumn_1", dict(somecolumn=10)) + def test_in_with_subqueries(self): + """Test that when using subqueries in a binary expression + the == and != are changed to IN and NOT IN respectively. + + """ + + t = table('sometable', column('somecolumn')) + self.assert_compile(t.select().where(t.c.somecolumn==t.select()), "SELECT sometable.somecolumn FROM sometable WHERE sometable.somecolumn IN (SELECT sometable.somecolumn FROM sometable)") + self.assert_compile(t.select().where(t.c.somecolumn!=t.select()), "SELECT sometable.somecolumn FROM sometable WHERE sometable.somecolumn NOT IN (SELECT sometable.somecolumn FROM sometable)") + def test_count(self): t = table('sometable', column('somecolumn')) self.assert_compile(t.count(), "SELECT count(sometable.somecolumn) AS tbl_row_count FROM sometable") @@ -197,28 +207,6 @@ class QueryTest(TestBase): finally: table.drop() - def test_select_limit_nooffset(self): - metadata = MetaData(testing.db) - - users = Table('query_users', metadata, - Column('user_id', INT, primary_key = True), - Column('user_name', VARCHAR(20)), - ) - addresses = Table('query_addresses', metadata, - Column('address_id', Integer, primary_key=True), - Column('user_id', Integer, ForeignKey('query_users.user_id')), - Column('address', String(30))) - metadata.create_all() - - try: - try: - r = users.select(limit=3, offset=2, - order_by=[users.c.user_id]).execute().fetchall() - assert False # InvalidRequestError should have been raised - except exc.InvalidRequestError: - pass - finally: - metadata.drop_all() class Foo(object): def __init__(self, **kw): diff --git a/test/sql/query.py b/test/sql/query.py index 3118aef646..ac11b44522 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -648,9 +648,10 @@ class LimitTest(TestBase): r = users.select(limit=3, order_by=[users.c.user_id]).execute().fetchall() self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r)) - @testing.crashes('mssql', 'FIXME: guessing') @testing.fails_on('maxdb') def test_select_limit_offset(self): + """Test the interaction between limit and offset""" + r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall() self.assert_(r==[(3, 'ed'), (4, 'wendy'), (5, 'laura')]) r = users.select(offset=5, order_by=[users.c.user_id]).execute().fetchall() @@ -659,14 +660,15 @@ class LimitTest(TestBase): def test_select_distinct_limit(self): """Test the interaction between limit and distinct""" - r = sorted([x[0] for x in select([addresses.c.address]).distinct().limit(3).execute().fetchall()]) + r = sorted([x[0] for x in select([addresses.c.address]).distinct().limit(3).order_by(addresses.c.address).execute().fetchall()]) self.assert_(len(r) == 3, repr(r)) self.assert_(r[0] != r[1] and r[1] != r[2], repr(r)) + @testing.fails_on('mssql') def test_select_distinct_offset(self): - """Test the interaction between limit and offset""" + """Test the interaction between distinct and offset""" - r = sorted([x[0] for x in select([addresses.c.address]).distinct().offset(1).execute().fetchall()]) + r = sorted([x[0] for x in select([addresses.c.address]).distinct().offset(1).order_by(addresses.c.address).execute().fetchall()]) self.assert_(len(r) == 4, repr(r)) self.assert_(r[0] != r[1] and r[1] != r[2] and r[2] != [3], repr(r)) -- 2.47.3