From: Mike Bayer Date: Fri, 4 Feb 2011 23:33:49 +0000 (-0500) Subject: - apply optimizations to alternate row proxies, [ticket:1787] X-Git-Tag: rel_0_7b1~36 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=637fdd32510c7e04dbf279482057f3fa19c97456;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - apply optimizations to alternate row proxies, [ticket:1787] - add check to fetchmany() for None, don't send argument if not present, helps DBAPIs which don't accept "None" for default (ie. pysqlite, maybe others) - add tests to test_execute to provide 100% coverage for the three alternate result proxy classes --- diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 04f3aab954..bc1c877037 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -127,6 +127,7 @@ from sqlalchemy.engine import base from sqlalchemy import types as sqltypes, util, exc, processors from datetime import datetime import random +import collections from sqlalchemy.util.compat import decimal import re @@ -423,8 +424,8 @@ class ReturningResultProxy(base.FullyBufferedResultProxy): return ret def _buffer_rows(self): - return [tuple(self._returning_params["ret_%d" % i] - for i, c in enumerate(self._returning_params))] + return collections.deque([tuple(self._returning_params["ret_%d" % i] + for i, c in enumerate(self._returning_params))]) class OracleDialect_cx_oracle(OracleDialect): execution_ctx_cls = OracleExecutionContext_cx_oracle diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index ccce58b9c2..b78a305374 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -26,6 +26,7 @@ from sqlalchemy import exc, schema, util, types, log, interfaces, \ event, events from sqlalchemy.sql import expression from sqlalchemy import processors +import collections class Dialect(object): """Define the behavior of a specific database and DB-API combination. @@ -2603,7 +2604,10 @@ class ResultProxy(object): def _fetchmany_impl(self, size=None): try: - return self.cursor.fetchmany(size) + if size is None: + return self.cursor.fetchmany() + else: + return self.cursor.fetchmany(size) except AttributeError: self._non_result() @@ -2756,24 +2760,29 @@ class BufferedRowResultProxy(ResultProxy): 5 : 10, 10 : 20, 20 : 50, - 50 : 100 + 50 : 100, + 100 : 250, + 250 : 500, + 500 : 1000 } def __buffer_rows(self): size = getattr(self, '_bufsize', 1) - self.__rowbuffer = self.cursor.fetchmany(size) + self.__rowbuffer = collections.deque(self.cursor.fetchmany(size)) self._bufsize = self.size_growth.get(size, size) def _fetchone_impl(self): if self.closed: return None - if len(self.__rowbuffer) == 0: + if not self.__rowbuffer: self.__buffer_rows() - if len(self.__rowbuffer) == 0: + if not self.__rowbuffer: return None - return self.__rowbuffer.pop(0) + return self.__rowbuffer.popleft() def _fetchmany_impl(self, size=None): + if size is None: + return self._fetchall_impl() result = [] for x in range(0, size): row = self._fetchone_impl() @@ -2783,8 +2792,9 @@ class BufferedRowResultProxy(ResultProxy): return result def _fetchall_impl(self): - ret = self.__rowbuffer + list(self.cursor.fetchall()) - self.__rowbuffer[:] = [] + self.__rowbuffer.extend(self.cursor.fetchall()) + ret = self.__rowbuffer + self.__rowbuffer = collections.deque() return ret class FullyBufferedResultProxy(ResultProxy): @@ -2800,15 +2810,17 @@ class FullyBufferedResultProxy(ResultProxy): self.__rowbuffer = self._buffer_rows() def _buffer_rows(self): - return self.cursor.fetchall() + return collections.deque(self.cursor.fetchall()) def _fetchone_impl(self): if self.__rowbuffer: - return self.__rowbuffer.pop(0) + return self.__rowbuffer.popleft() else: return None def _fetchmany_impl(self, size=None): + if size is None: + return self._fetchall_impl() result = [] for x in range(0, size): row = self._fetchone_impl() @@ -2819,7 +2831,7 @@ class FullyBufferedResultProxy(ResultProxy): def _fetchall_impl(self): ret = self.__rowbuffer - self.__rowbuffer = [] + self.__rowbuffer = collections.deque() return ret class BufferedColumnRow(RowProxy): diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 318bc15d50..37afbd4a3e 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -8,6 +8,7 @@ import sqlalchemy as tsa from test.lib import TestBase, testing, engines import logging from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam +from sqlalchemy.engine import base, default users, metadata = None, None class ExecuteTest(TestBase): @@ -434,6 +435,67 @@ class ResultProxyTest(TestBase): writer.writerow(row) assert s.getvalue().strip() == '1,Test' +class AlternateResultProxyTest(TestBase): + requires = ('sqlite', ) + + @classmethod + def setup_class(cls): + from sqlalchemy.engine import base, create_engine, default + cls.engine = engine = create_engine('sqlite://') + m = MetaData() + cls.table = t = Table('test', m, + Column('x', Integer, primary_key=True), + Column('y', String(50, convert_unicode='force')) + ) + m.create_all(engine) + engine.execute(t.insert(), [ + {'x':i, 'y':"t_%d" % i} for i in xrange(1, 12) + ]) + + def _test_proxy(self, cls): + class ExcCtx(default.DefaultExecutionContext): + def get_result_proxy(self): + return cls(self) + self.engine.dialect.execution_ctx_cls = ExcCtx + rows = [] + r = self.engine.execute(select([self.table])) + assert isinstance(r, cls) + for i in range(5): + rows.append(r.fetchone()) + eq_(rows, [(i, "t_%d" % i) for i in xrange(1, 6)]) + + rows = r.fetchmany(3) + eq_(rows, [(i, "t_%d" % i) for i in xrange(6, 9)]) + + rows = r.fetchall() + eq_(rows, [(i, "t_%d" % i) for i in xrange(9, 12)]) + + r = self.engine.execute(select([self.table])) + rows = r.fetchmany(None) + eq_(rows[0], (1, "t_1")) + # number of rows here could be one, or the whole thing + assert len(rows) == 1 or len(rows) == 11 + + r = self.engine.execute(select([self.table]).limit(1)) + r.fetchone() + eq_(r.fetchone(), None) + + r = self.engine.execute(select([self.table]).limit(5)) + rows = r.fetchmany(6) + eq_(rows, [(i, "t_%d" % i) for i in xrange(1, 6)]) + + def test_plain(self): + self._test_proxy(base.ResultProxy) + + def test_buffered_row_result_proxy(self): + self._test_proxy(base.BufferedRowResultProxy) + + def test_fully_buffered_result_proxy(self): + self._test_proxy(base.FullyBufferedResultProxy) + + def test_buffered_column_result_proxy(self): + self._test_proxy(base.BufferedColumnResultProxy) + class EngineEventsTest(TestBase): def _assert_stmts(self, expected, received):