]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Corrected a lot of mssql limit / offset issues. Also ensured that mssql uses the...
authorMichael Trier <mtrier@gmail.com>
Sat, 8 Nov 2008 04:43:35 +0000 (04:43 +0000)
committerMichael Trier <mtrier@gmail.com>
Sat, 8 Nov 2008 04:43:35 +0000 (04:43 +0000)
CHANGES
lib/sqlalchemy/databases/mssql.py
test/dialect/mssql.py
test/sql/query.py

diff --git a/CHANGES b/CHANGES
index 0a0946b35c3fd2c1f8a8c7eeac1dcd4b20c5e359..681e7ed525ea24a160a7a6f71bf32affdc7ba6ae 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -4,6 +4,16 @@
 =======
 CHANGES
 =======
+0.5.0rc4
+========
+- mssql
+    - Lots of cleanup and fixes to correct problems with
+      limit and offset.
+
+    - Correct situation where subqueries as part of a
+      binary expression need to be translated to use the
+      IN and NOT IN syntax.
+
 0.5.0rc3
 ========
 - features
index f86a95548258b5f2c270b54b4b680324a333ddce..3291098282c42b2ef1a6b876279da20cc4ef3ab9 100644 (file)
@@ -922,12 +922,15 @@ class MSSQLCompiler(compiler.DefaultCompiler):
 
     def get_select_precolumns(self, select):
         """ MS-SQL puts TOP, it's version of LIMIT here """
-        if not self.dialect.has_window_funcs:
+        if select._distinct or select._limit:
             s = select._distinct and "DISTINCT " or ""
+            
             if select._limit:
-                s += "TOP %s " % (select._limit,)
-            if select._offset:
-                raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
+                if not select._offset:
+                    s += "TOP %s " % (select._limit,)
+                else:
+                    if not self.dialect.has_window_funcs:
+                        raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
             return s
         return compiler.DefaultCompiler.get_select_precolumns(self, select)
 
@@ -938,13 +941,13 @@ class MSSQLCompiler(compiler.DefaultCompiler):
     def visit_select(self, select, **kwargs):
         """Look for ``LIMIT`` and OFFSET in a select statement, and if
         so tries to wrap it in a subquery with ``row_number()`` criterion.
+
         """
-        if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and (select._limit is not None or select._offset is not None):
+        if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and (select._offset is not None):
             # to use ROW_NUMBER(), an ORDER BY is required.
             orderby = self.process(select._order_by_clause)
             if not orderby:
-                orderby = list(select.oid_column.proxies)[0]
-                orderby = self.process(orderby)
+                raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')
 
             _offset = select._offset
             _limit = select._limit
@@ -952,12 +955,9 @@ class MSSQLCompiler(compiler.DefaultCompiler):
             select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()
 
             limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
-            if _offset is not None:
-                limitselect.append_whereclause("mssql_rn>=%d" % _offset)
-                if _limit is not None:
-                    limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
-            else:
-                limitselect.append_whereclause("mssql_rn<=%d" % _limit)
+            limitselect.append_whereclause("mssql_rn>%d" % _offset)
+            if _limit is not None:
+                limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
             return self.process(limitselect, iswrapper=True, **kwargs)
         else:
             return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
@@ -1003,10 +1003,17 @@ class MSSQLCompiler(compiler.DefaultCompiler):
 
     def visit_binary(self, binary, **kwargs):
         """Move bind parameters to the right-hand side of an operator, where possible."""
+
         if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
             and not isinstance(binary.right, expression._BindParamClause):
             return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
         else:
+            if (binary.operator in (operator.eq, operator.ne)) and (
+                (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._SelectBaseMixin)) or \
+                (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._SelectBaseMixin)) or \
+                 isinstance(binary.left, expression._SelectBaseMixin) or isinstance(binary.right, expression._SelectBaseMixin)):
+                op = binary.operator == operator.eq and "IN" or "NOT IN"
+                return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
             return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
 
     def label_select_column(self, select, column, asfrom):
index 4708cc28c4cfcc93bc2f0775d605adb00f848f18..26fc752430dc2aafe14b3d381f4f42e2abac9ccc 100755 (executable)
@@ -20,6 +20,16 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         t = table('sometable', column('somecolumn'))
         self.assert_compile(t.update(t.c.somecolumn==7), "UPDATE sometable SET somecolumn=:somecolumn WHERE sometable.somecolumn = :somecolumn_1", dict(somecolumn=10))
 
+    def test_in_with_subqueries(self):
+        """Test that when using subqueries in a binary expression
+        the == and != are changed to IN and NOT IN respectively.
+
+        """
+
+        t = table('sometable', column('somecolumn'))
+        self.assert_compile(t.select().where(t.c.somecolumn==t.select()), "SELECT sometable.somecolumn FROM sometable WHERE sometable.somecolumn IN (SELECT sometable.somecolumn FROM sometable)")
+        self.assert_compile(t.select().where(t.c.somecolumn!=t.select()), "SELECT sometable.somecolumn FROM sometable WHERE sometable.somecolumn NOT IN (SELECT sometable.somecolumn FROM sometable)")
+
     def test_count(self):
         t = table('sometable', column('somecolumn'))
         self.assert_compile(t.count(), "SELECT count(sometable.somecolumn) AS tbl_row_count FROM sometable")
@@ -197,28 +207,6 @@ class QueryTest(TestBase):
         finally:
             table.drop()
 
-    def test_select_limit_nooffset(self):
-        metadata = MetaData(testing.db)
-
-        users = Table('query_users', metadata,
-            Column('user_id', INT, primary_key = True),
-            Column('user_name', VARCHAR(20)),
-        )
-        addresses = Table('query_addresses', metadata,
-            Column('address_id', Integer, primary_key=True),
-            Column('user_id', Integer, ForeignKey('query_users.user_id')),
-            Column('address', String(30)))
-        metadata.create_all()
-
-        try:
-            try:
-                r = users.select(limit=3, offset=2,
-                                 order_by=[users.c.user_id]).execute().fetchall()
-                assert False # InvalidRequestError should have been raised
-            except exc.InvalidRequestError:
-                pass
-        finally:
-            metadata.drop_all()
 
 class Foo(object):
     def __init__(self, **kw):
index 3118aef6463015db94489ab229f9f803a61aa670..ac11b445228dac9a6eb8687defb55970f55e8152 100644 (file)
@@ -648,9 +648,10 @@ class LimitTest(TestBase):
         r = users.select(limit=3, order_by=[users.c.user_id]).execute().fetchall()
         self.assert_(r == [(1, 'john'), (2, 'jack'), (3, 'ed')], repr(r))
 
-    @testing.crashes('mssql', 'FIXME: guessing')
     @testing.fails_on('maxdb')
     def test_select_limit_offset(self):
+        """Test the interaction between limit and offset"""
+
         r = users.select(limit=3, offset=2, order_by=[users.c.user_id]).execute().fetchall()
         self.assert_(r==[(3, 'ed'), (4, 'wendy'), (5, 'laura')])
         r = users.select(offset=5, order_by=[users.c.user_id]).execute().fetchall()
@@ -659,14 +660,15 @@ class LimitTest(TestBase):
     def test_select_distinct_limit(self):
         """Test the interaction between limit and distinct"""
 
-        r = sorted([x[0] for x in select([addresses.c.address]).distinct().limit(3).execute().fetchall()])
+        r = sorted([x[0] for x in select([addresses.c.address]).distinct().limit(3).order_by(addresses.c.address).execute().fetchall()])
         self.assert_(len(r) == 3, repr(r))
         self.assert_(r[0] != r[1] and r[1] != r[2], repr(r))
 
+    @testing.fails_on('mssql')
     def test_select_distinct_offset(self):
-        """Test the interaction between limit and offset"""
+        """Test the interaction between distinct and offset"""
 
-        r = sorted([x[0] for x in select([addresses.c.address]).distinct().offset(1).execute().fetchall()])
+        r = sorted([x[0] for x in select([addresses.c.address]).distinct().offset(1).order_by(addresses.c.address).execute().fetchall()])
         self.assert_(len(r) == 4, repr(r))
         self.assert_(r[0] != r[1] and r[1] != r[2] and r[2] != [3], repr(r))