]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- apply optimizations to alternate row proxies, [ticket:1787]
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Feb 2011 23:33:49 +0000 (18:33 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Feb 2011 23:33:49 +0000 (18:33 -0500)
- add check to fetchmany() for None, don't send argument if not present,
helps DBAPIs which don't accept "None" for default (ie. pysqlite, maybe others)
- add tests to test_execute to provide 100% coverage for the three alternate
result proxy classes

lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/engine/base.py
test/engine/test_execute.py

index 04f3aab9548cc459ece4e13d6abcf74cb80c27c6..bc1c877037e0d9455aaec604989097308c1c611b 100644 (file)
@@ -127,6 +127,7 @@ from sqlalchemy.engine import base
 from sqlalchemy import types as sqltypes, util, exc, processors
 from datetime import datetime
 import random
+import collections
 from sqlalchemy.util.compat import decimal
 import re
 
@@ -423,8 +424,8 @@ class ReturningResultProxy(base.FullyBufferedResultProxy):
         return ret
 
     def _buffer_rows(self):
-        return [tuple(self._returning_params["ret_%d" % i] 
-                    for i, c in enumerate(self._returning_params))]
+        return collections.deque([tuple(self._returning_params["ret_%d" % i] 
+                    for i, c in enumerate(self._returning_params))])
 
 class OracleDialect_cx_oracle(OracleDialect):
     execution_ctx_cls = OracleExecutionContext_cx_oracle
index ccce58b9c219b43f570592bf2513ddb4d9754e7d..b78a305374732262031328ba2af26f6b3dfe29c7 100644 (file)
@@ -26,6 +26,7 @@ from sqlalchemy import exc, schema, util, types, log, interfaces, \
     event, events
 from sqlalchemy.sql import expression
 from sqlalchemy import processors
+import collections
 
 class Dialect(object):
     """Define the behavior of a specific database and DB-API combination.
@@ -2603,7 +2604,10 @@ class ResultProxy(object):
 
     def _fetchmany_impl(self, size=None):
         try:
-            return self.cursor.fetchmany(size)
+            if size is None:
+                return self.cursor.fetchmany()
+            else:
+                return self.cursor.fetchmany(size)
         except AttributeError:
             self._non_result()
 
@@ -2756,24 +2760,29 @@ class BufferedRowResultProxy(ResultProxy):
         5 : 10,
         10 : 20,
         20 : 50,
-        50 : 100
+        50 : 100,
+        100 : 250,
+        250 : 500,
+        500 : 1000
     }
 
     def __buffer_rows(self):
         size = getattr(self, '_bufsize', 1)
-        self.__rowbuffer = self.cursor.fetchmany(size)
+        self.__rowbuffer = collections.deque(self.cursor.fetchmany(size))
         self._bufsize = self.size_growth.get(size, size)
 
     def _fetchone_impl(self):
         if self.closed:
             return None
-        if len(self.__rowbuffer) == 0:
+        if not self.__rowbuffer:
             self.__buffer_rows()
-            if len(self.__rowbuffer) == 0:
+            if not self.__rowbuffer:
                 return None
-        return self.__rowbuffer.pop(0)
+        return self.__rowbuffer.popleft()
 
     def _fetchmany_impl(self, size=None):
+        if size is None:
+            return self._fetchall_impl()
         result = []
         for x in range(0, size):
             row = self._fetchone_impl()
@@ -2783,8 +2792,9 @@ class BufferedRowResultProxy(ResultProxy):
         return result
 
     def _fetchall_impl(self):
-        ret = self.__rowbuffer + list(self.cursor.fetchall())
-        self.__rowbuffer[:] = []
+        self.__rowbuffer.extend(self.cursor.fetchall())
+        ret = self.__rowbuffer
+        self.__rowbuffer = collections.deque()
         return ret
 
 class FullyBufferedResultProxy(ResultProxy):
@@ -2800,15 +2810,17 @@ class FullyBufferedResultProxy(ResultProxy):
         self.__rowbuffer = self._buffer_rows()
 
     def _buffer_rows(self):
-        return self.cursor.fetchall()
+        return collections.deque(self.cursor.fetchall())
 
     def _fetchone_impl(self):
         if self.__rowbuffer:
-            return self.__rowbuffer.pop(0)
+            return self.__rowbuffer.popleft()
         else:
             return None
 
     def _fetchmany_impl(self, size=None):
+        if size is None:
+            return self._fetchall_impl()
         result = []
         for x in range(0, size):
             row = self._fetchone_impl()
@@ -2819,7 +2831,7 @@ class FullyBufferedResultProxy(ResultProxy):
 
     def _fetchall_impl(self):
         ret = self.__rowbuffer
-        self.__rowbuffer = []
+        self.__rowbuffer = collections.deque()
         return ret
 
 class BufferedColumnRow(RowProxy):
index 318bc15d502bc8f2abcb08a42dfaf43109b2acd9..37afbd4a3ec010b44620fbdfa728be3aca0fb227 100644 (file)
@@ -8,6 +8,7 @@ import sqlalchemy as tsa
 from test.lib import TestBase, testing, engines
 import logging
 from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam
+from sqlalchemy.engine import base, default
 
 users, metadata = None, None
 class ExecuteTest(TestBase):
@@ -434,6 +435,67 @@ class ResultProxyTest(TestBase):
         writer.writerow(row)
         assert s.getvalue().strip() == '1,Test'
 
+class AlternateResultProxyTest(TestBase):
+    requires = ('sqlite', )
+
+    @classmethod
+    def setup_class(cls):
+        from sqlalchemy.engine import base, create_engine, default
+        cls.engine = engine = create_engine('sqlite://')
+        m = MetaData()
+        cls.table = t = Table('test', m, 
+            Column('x', Integer, primary_key=True),
+            Column('y', String(50, convert_unicode='force'))
+        )
+        m.create_all(engine)
+        engine.execute(t.insert(), [
+            {'x':i, 'y':"t_%d" % i} for i in xrange(1, 12)
+        ])
+
+    def _test_proxy(self, cls):
+        class ExcCtx(default.DefaultExecutionContext):
+            def get_result_proxy(self):
+                return cls(self)
+        self.engine.dialect.execution_ctx_cls = ExcCtx
+        rows = []
+        r = self.engine.execute(select([self.table]))
+        assert isinstance(r, cls)
+        for i in range(5):
+            rows.append(r.fetchone())
+        eq_(rows, [(i, "t_%d" % i) for i in xrange(1, 6)])
+
+        rows = r.fetchmany(3)
+        eq_(rows, [(i, "t_%d" % i) for i in xrange(6, 9)])
+
+        rows = r.fetchall()
+        eq_(rows, [(i, "t_%d" % i) for i in xrange(9, 12)])
+
+        r = self.engine.execute(select([self.table]))
+        rows = r.fetchmany(None)
+        eq_(rows[0], (1, "t_1"))
+        # number of rows here could be one, or the whole thing
+        assert len(rows) == 1 or len(rows) == 11
+
+        r = self.engine.execute(select([self.table]).limit(1))
+        r.fetchone()
+        eq_(r.fetchone(), None)
+
+        r = self.engine.execute(select([self.table]).limit(5))
+        rows = r.fetchmany(6)
+        eq_(rows, [(i, "t_%d" % i) for i in xrange(1, 6)])
+
+    def test_plain(self):
+        self._test_proxy(base.ResultProxy)
+
+    def test_buffered_row_result_proxy(self):
+        self._test_proxy(base.BufferedRowResultProxy)
+
+    def test_fully_buffered_result_proxy(self):
+        self._test_proxy(base.FullyBufferedResultProxy)
+
+    def test_buffered_column_result_proxy(self):
+        self._test_proxy(base.BufferedColumnResultProxy)
+
 class EngineEventsTest(TestBase):
 
     def _assert_stmts(self, expected, received):