git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add support for server side cursors to mysqldb and pymysql
author: Roman Podoliaka <roman.podoliaka@gmail.com>
Thu, 3 Nov 2016 22:31:05 +0000 (00:31 +0200)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 10 Nov 2016 17:09:27 +0000 (12:09 -0500)
This makes it possible to skip buffering of the results on the client
side; e.g. the following snippet:

    table = sa.Table(
        'testtbl', sa.MetaData(),
        sa.Column('id', sa.Integer, primary_key=True),
        sa.Column('a', sa.Integer),
        sa.Column('b', sa.String(512))
    )
    table.create(eng, checkfirst=True)

    with eng.connect() as conn:
        result = conn.execute(table.select().limit(1)).fetchone()
        if result is None:
            for _ in range(1000):
                conn.execute(
                    table.insert(),
                    [{'a': random.randint(1, 100000),
                      'b': ''.join(random.choice(string.ascii_letters) for _ in range(100))}
                      for _ in range(1000)]
                )

    with eng.connect() as conn:
        for row in conn.execution_options(stream_results=True).execute(table.select()):
            pass

now uses ~23 MB of memory instead of ~327 MB on CPython 3.5.2 and
PyMySQL 0.7.9.

The existing psycopg2 implementation and its execution options
(stream_results, server_side_cursors) are reused.

Change-Id: I4dc23ce3094f027bdff51b896b050361991c62e2

doc/build/changelog/changelog_11.rst
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/dialects/mysql/pymysql.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_results.py
test/dialect/postgresql/test_query.py

index 4c8bdc7d96ba58f8c63215341abdba3c73e29e66..a25520ba3e28013c9b9ba260bd640534b34e334c 100644 (file)
         strategy would cause backrefs and/or back_populates options to be
         ignored.
 
+    .. change::
+        :tags: feature, mysql
+
+        Added support for server side cursors to the mysqlclient and
+        pymysql dialects.   This feature is available via the
+        :paramref:`.Connection.execution_options.stream_results` flag as well
+        as the ``server_side_cursors=True`` dialect argument in the
+        same way that it has been for psycopg2 on Postgresql.  Pull request
+        courtesy Roman Podoliaka.
+
     .. change::
         :tags: bug, mysql
         :tickets: 3841
index e7e5338905e77fa8758f6aa53a6b3bc47ecb136a..449fffabab1aa64552b6f4b841a71cd622ed386f 100644 (file)
@@ -177,6 +177,22 @@ multi-column key for some storage engines::
         Column('id', Integer, primary_key=True)
        )
 
+.. _mysql_ss_cursors:
+
+Server Side Cursors
+-------------------
+
+Server-side cursor support is available for the MySQLdb and PyMySQL dialects.
+From a MySQL point of view this means that the ``MySQLdb.cursors.SSCursor`` or
+``pymysql.cursors.SSCursor`` class is used when building up the cursor which
+will receive results.  The most typical way of invoking this feature is via the
+:paramref:`.Connection.execution_options.stream_results` connection execution
+option.   Server side cursors can also be enabled for all SELECT statements
+unconditionally by passing ``server_side_cursors=True`` to
+:func:`.create_engine`.
+
+.. versionadded:: 1.1.4 - added server-side cursor support.
+
 .. _mysql_unicode:
 
 Unicode
@@ -743,6 +759,12 @@ class MySQLExecutionContext(default.DefaultExecutionContext):
     def should_autocommit_text(self, statement):
         return AUTOCOMMIT_RE.match(statement)
 
+    def create_server_side_cursor(self):
+        if self.dialect.supports_server_side_cursors:
+            return self._dbapi_connection.cursor(self.dialect._sscursor)
+        else:
+            raise NotImplementedError()
+
 
 class MySQLCompiler(compiler.SQLCompiler):
 
index aa8377b27c3a6d3bf5b9969c9359d0f90c24eb43..568c05f62b38f6b1262f9db33876f1cd2df4d304 100644 (file)
@@ -38,6 +38,11 @@ using a URL like the following::
 
     mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>
 
+Server Side Cursors
+-------------------
+
+The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
+
 """
 
 from .base import (MySQLDialect, MySQLExecutionContext,
@@ -87,6 +92,19 @@ class MySQLDialect_mysqldb(MySQLDialect):
     statement_compiler = MySQLCompiler_mysqldb
     preparer = MySQLIdentifierPreparer_mysqldb
 
+    def __init__(self, server_side_cursors=False, **kwargs):
+        super(MySQLDialect_mysqldb, self).__init__(**kwargs)
+        self.server_side_cursors = server_side_cursors
+
+    @util.langhelpers.memoized_property
+    def supports_server_side_cursors(self):
+        try:
+            cursors = __import__('MySQLdb.cursors').cursors
+            self._sscursor = cursors.SSCursor
+            return True
+        except (ImportError, AttributeError):
+            return False
+
     @classmethod
     def dbapi(cls):
         return __import__('MySQLdb')
index 3c493fbfc25872cecd1cbf34b25b47b5ecda15e0..e29c17d8b6298f21cbd161f2a7a0bc8884d466bd 100644 (file)
@@ -30,7 +30,7 @@ to the pymysql driver as well.
 """
 
 from .mysqldb import MySQLDialect_mysqldb
-from ...util import py3k
+from ...util import langhelpers, py3k
 
 
 class MySQLDialect_pymysql(MySQLDialect_mysqldb):
@@ -44,6 +44,19 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb):
     supports_unicode_statements = True
     supports_unicode_binds = True
 
+    def __init__(self, server_side_cursors=False, **kwargs):
+        super(MySQLDialect_pymysql, self).__init__(**kwargs)
+        self.server_side_cursors = server_side_cursors
+
+    @langhelpers.memoized_property
+    def supports_server_side_cursors(self):
+        try:
+            cursors = __import__('pymysql.cursors').cursors
+            self._sscursor = cursors.SSCursor
+            return True
+        except (ImportError, AttributeError):
+            return False
+
     @classmethod
     def dbapi(cls):
         return __import__('pymysql')
index 8488da816a2f98ed5ce5ebec812e425486b048cc..27a1ec0990f0d34da2f1dc723895901d56808588 100644 (file)
@@ -28,7 +28,8 @@ psycopg2-specific keyword arguments which are accepted by
   :class:`~sqlalchemy.engine.ResultProxy` uses special row-buffering
   behavior when this feature is enabled, such that groups of 100 rows at a
   time are fetched over the wire to reduce conversational overhead.
-  Note that the ``stream_results=True`` execution option is a more targeted
+  Note that the :paramref:`.Connection.execution_options.stream_results`
+  execution option is a more targeted
   way of enabling this mode on a per-execution basis.
 * ``use_native_unicode``: Enable the usage of Psycopg2 "native unicode" mode
   per connection.  True by default.
@@ -422,53 +423,24 @@ class _PGUUID(UUID):
                 return value
             return process
 
-# When we're handed literal SQL, ensure it's a SELECT query. Since
-# 8.3, combining cursors and "FOR UPDATE" has been fine.
-SERVER_SIDE_CURSOR_RE = re.compile(
-    r'\s*SELECT',
-    re.I | re.UNICODE)
 
 _server_side_id = util.counter()
 
 
 class PGExecutionContext_psycopg2(PGExecutionContext):
-    def create_cursor(self):
-        # TODO: coverage for server side cursors + select.for_update()
-
-        if self.dialect.server_side_cursors:
-            is_server_side = \
-                self.execution_options.get('stream_results', True) and (
-                    (self.compiled and isinstance(self.compiled.statement,
-                                                  expression.Selectable)
-                     or
-                     (
-                        (not self.compiled or
-                         isinstance(self.compiled.statement,
-                                    expression.TextClause))
-                        and self.statement and SERVER_SIDE_CURSOR_RE.match(
-                            self.statement))
-                     )
-                )
-        else:
-            is_server_side = \
-                self.execution_options.get('stream_results', False)
-
-        self.__is_server_side = is_server_side
-        if is_server_side:
-            # use server-side cursors:
-            # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
-            ident = "c_%s_%s" % (hex(id(self))[2:],
-                                 hex(_server_side_id())[2:])
-            return self._dbapi_connection.cursor(ident)
-        else:
-            return self._dbapi_connection.cursor()
+    def create_server_side_cursor(self):
+        # use server-side cursors:
+        # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html
+        ident = "c_%s_%s" % (hex(id(self))[2:],
+                             hex(_server_side_id())[2:])
+        return self._dbapi_connection.cursor(ident)
 
     def get_result_proxy(self):
         # TODO: ouch
         if logger.isEnabledFor(logging.INFO):
             self._log_notices(self.cursor)
 
-        if self.__is_server_side:
+        if self._is_server_side:
             return _result.BufferedRowResultProxy(self)
         else:
             return _result.ResultProxy(self)
@@ -502,6 +474,8 @@ class PGDialect_psycopg2(PGDialect):
     if util.py2k:
         supports_unicode_statements = False
 
+    supports_server_side_cursors = True
+
     default_paramstyle = 'pyformat'
     # set to true based on psycopg2 version
     supports_sane_multi_rowcount = False
index 1d23c66b3c25a53d6e54b156191a2e0ae1202479..f071abaa15863e70901e730f9c46855728811eec 100644 (file)
@@ -295,7 +295,7 @@ class Connection(Connectable):
           Indicate to the dialect that results should be
           "streamed" and not pre-buffered, if possible.  This is a limitation
           of many DBAPIs.  The flag is currently understood only by the
-          psycopg2 dialect.
+          psycopg2, mysqldb and pymysql dialects.
 
         :param schema_translate_map: Available on: Connection, Engine.
           A dictionary mapping schema names to schema names, that will be
index 3ee240383c4c205650d0ab431755d9f1d8d546bc..719178f7ef4a3d8503123a9bc0ab8e22ca07d2db 100644 (file)
@@ -27,6 +27,11 @@ AUTOCOMMIT_REGEXP = re.compile(
     r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)',
     re.I | re.UNICODE)
 
+# When we're handed literal SQL, ensure it's a SELECT query
+SERVER_SIDE_CURSOR_RE = re.compile(
+    r'\s*SELECT',
+    re.I | re.UNICODE)
+
 
 class DefaultDialect(interfaces.Dialect):
     """Default implementation of Dialect"""
@@ -108,6 +113,8 @@ class DefaultDialect(interfaces.Dialect):
     supports_empty_insert = True
     supports_multivalues_insert = False
 
+    supports_server_side_cursors = False
+
     server_version_info = None
 
     construct_arguments = None
@@ -780,8 +787,40 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
     def should_autocommit_text(self, statement):
         return AUTOCOMMIT_REGEXP.match(statement)
 
+    def _use_server_side_cursor(self):
+        if not self.dialect.supports_server_side_cursors:
+            return False
+
+        if self.dialect.server_side_cursors:
+            use_server_side = \
+                self.execution_options.get('stream_results', True) and (
+                    (self.compiled and isinstance(self.compiled.statement,
+                                                  expression.Selectable)
+                     or
+                     (
+                        (not self.compiled or
+                         isinstance(self.compiled.statement,
+                                    expression.TextClause))
+                        and self.statement and SERVER_SIDE_CURSOR_RE.match(
+                            self.statement))
+                     )
+                )
+        else:
+            use_server_side = \
+                self.execution_options.get('stream_results', False)
+
+        return use_server_side
+
     def create_cursor(self):
-        return self._dbapi_connection.cursor()
+        if self._use_server_side_cursor():
+            self._is_server_side = True
+            return self.create_server_side_cursor()
+        else:
+            self._is_server_side = False
+            return self._dbapi_connection.cursor()
+
+    def create_server_side_cursor(self):
+        raise NotImplementedError()
 
     def pre_exec(self):
         pass
@@ -831,7 +870,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         pass
 
     def get_result_proxy(self):
-        return result.ResultProxy(self)
+        if self._is_server_side:
+            return result.BufferedRowResultProxy(self)
+        else:
+            return result.ResultProxy(self)
 
     @property
     def rowcount(self):
index 23d33b0d14f96fb565fc2f818fb4f96cdce7a05d..139b61afbb865eac2e0c7e97501a5a3c3fb9ce45 100644 (file)
@@ -751,7 +751,9 @@ class Query(object):
             :meth:`~sqlalchemy.orm.query.Query.yield_per` will set the
             ``stream_results`` execution option to True, currently
             this is only understood by
-            :mod:`~sqlalchemy.dialects.postgresql.psycopg2` dialect
+            :mod:`~sqlalchemy.dialects.postgresql.psycopg2`,
+            :mod:`~sqlalchemy.dialects.mysql.mysqldb` and
+            :mod:`~sqlalchemy.dialects.mysql.pymysql` dialects
             which will stream results using server side cursors
             instead of pre-buffer all rows for this query. Other
             DBAPIs **pre-buffer all rows** before making them
index af148a3b91e59f98b1f442bc5767db8b8fa4e5c3..b001aaf755d46720730ab81178d90eb09b6662f9 100644 (file)
@@ -287,6 +287,14 @@ class SuiteRequirements(Requirements):
 
         return exclusions.closed()
 
+    @property
+    def server_side_cursors(self):
+        """Target dialect must support server side cursors."""
+
+        return exclusions.only_if([
+            lambda config: config.db.dialect.supports_server_side_cursors
+        ], "no server side cursors support")
+
     @property
     def sequences(self):
         """Target database must support SEQUENCEs."""
index f40d9a04c90f9c7d6363a00d7a5c0906c32039b3..98ddc7efcc8e6479be5bc52c8ea51cac9aa4f786 100644 (file)
@@ -3,8 +3,9 @@ from ..config import requirements
 from .. import exclusions
 from ..assertions import eq_
 from .. import engines
+from ... import testing
 
-from sqlalchemy import Integer, String, select, util, sql, DateTime
+from sqlalchemy import Integer, String, select, util, sql, DateTime, text, func
 import datetime
 from ..schema import Table, Column
 
@@ -218,3 +219,149 @@ class PercentSchemaNamesTest(fixtures.TablesTest):
             ),
             [(5, 15), (7, 15), (9, 15), (11, 15)]
         )
+
+
+class ServerSideCursorsTest(fixtures.TestBase, testing.AssertsExecutionResults):
+
+    __requires__ = ('server_side_cursors', )
+
+    __backend__ = True
+
+    def _is_server_side(self, cursor):
+        if self.engine.url.drivername == 'postgresql':
+            return cursor.name
+        elif self.engine.url.drivername == 'mysql':
+            sscursor = __import__('MySQLdb.cursors').cursors.SSCursor
+            return isinstance(cursor, sscursor)
+        elif self.engine.url.drivername == 'mysql+pymysql':
+            sscursor = __import__('pymysql.cursors').cursors.SSCursor
+            return isinstance(cursor, sscursor)
+        else:
+            return False
+
+    def _fixture(self, server_side_cursors):
+        self.engine = engines.testing_engine(
+            options={'server_side_cursors': server_side_cursors}
+        )
+        return self.engine
+
+    def tearDown(self):
+        engines.testing_reaper.close_all()
+        self.engine.dispose()
+
+    def test_global_string(self):
+        engine = self._fixture(True)
+        result = engine.execute('select 1')
+        assert self._is_server_side(result.cursor)
+
+    def test_global_text(self):
+        engine = self._fixture(True)
+        result = engine.execute(text('select 1'))
+        assert self._is_server_side(result.cursor)
+
+    def test_global_expr(self):
+        engine = self._fixture(True)
+        result = engine.execute(select([1]))
+        assert self._is_server_side(result.cursor)
+
+    def test_global_off_explicit(self):
+        engine = self._fixture(False)
+        result = engine.execute(text('select 1'))
+
+        # It should be off globally ...
+
+        assert not self._is_server_side(result.cursor)
+
+    def test_stmt_option(self):
+        engine = self._fixture(False)
+
+        s = select([1]).execution_options(stream_results=True)
+        result = engine.execute(s)
+
+        # ... but enabled for this one.
+
+        assert self._is_server_side(result.cursor)
+
+    def test_conn_option(self):
+        engine = self._fixture(False)
+
+        # and this one
+        result = \
+            engine.connect().execution_options(stream_results=True).\
+            execute('select 1'
+                    )
+        assert self._is_server_side(result.cursor)
+
+    def test_stmt_enabled_conn_option_disabled(self):
+        engine = self._fixture(False)
+
+        s = select([1]).execution_options(stream_results=True)
+
+        # not this one
+        result = \
+            engine.connect().execution_options(stream_results=False).\
+            execute(s)
+        assert not self._is_server_side(result.cursor)
+
+    def test_stmt_option_disabled(self):
+        engine = self._fixture(True)
+        s = select([1]).execution_options(stream_results=False)
+        result = engine.execute(s)
+        assert not self._is_server_side(result.cursor)
+
+    def test_aliases_and_ss(self):
+        engine = self._fixture(False)
+        s1 = select([1]).execution_options(stream_results=True).alias()
+        result = engine.execute(s1)
+        assert self._is_server_side(result.cursor)
+
+        # s1's options shouldn't affect s2 when s2 is used as a
+        # from_obj.
+        s2 = select([1], from_obj=s1)
+        result = engine.execute(s2)
+        assert not self._is_server_side(result.cursor)
+
+    def test_for_update_expr(self):
+        engine = self._fixture(True)
+        s1 = select([1], for_update=True)
+        result = engine.execute(s1)
+        assert self._is_server_side(result.cursor)
+
+    def test_for_update_string(self):
+        engine = self._fixture(True)
+        result = engine.execute('SELECT 1 FOR UPDATE')
+        assert self._is_server_side(result.cursor)
+
+    def test_text_no_ss(self):
+        engine = self._fixture(False)
+        s = text('select 42')
+        result = engine.execute(s)
+        assert not self._is_server_side(result.cursor)
+
+    def test_text_ss_option(self):
+        engine = self._fixture(False)
+        s = text('select 42').execution_options(stream_results=True)
+        result = engine.execute(s)
+        assert self._is_server_side(result.cursor)
+
+    @testing.provide_metadata
+    def test_roundtrip(self):
+        md = self.metadata
+
+        engine = self._fixture(True)
+        test_table = Table('test_table', md,
+                           Column('id', Integer, primary_key=True),
+                           Column('data', String(50)))
+        test_table.create(checkfirst=True)
+        test_table.insert().execute(data='data1')
+        test_table.insert().execute(data='data2')
+        eq_(test_table.select().execute().fetchall(), [(1, 'data1'
+                                                        ), (2, 'data2')])
+        test_table.update().where(
+            test_table.c.id == 2).values(
+            data=test_table.c.data +
+            ' updated').execute()
+        eq_(test_table.select().execute().fetchall(),
+            [(1, 'data1'), (2, 'data2 updated')])
+        test_table.delete().execute()
+        eq_(select([func.count('*')]).select_from(test_table).scalar(), 0)
index b8129f1e36a94cac5e5dc31cbbc3942858e23756..47a12afecece105143f6f04874dea0d60295727e 100644 (file)
@@ -595,139 +595,6 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
                 (33, 'd4')])
 
 
-class ServerSideCursorsTest(fixtures.TestBase, AssertsExecutionResults):
-
-    __requires__ = 'psycopg2_compatibility',
-
-    def _fixture(self, server_side_cursors):
-        self.engine = engines.testing_engine(
-            options={'server_side_cursors': server_side_cursors}
-        )
-        return self.engine
-
-    def tearDown(self):
-        engines.testing_reaper.close_all()
-        self.engine.dispose()
-
-    def test_global_string(self):
-        engine = self._fixture(True)
-        result = engine.execute('select 1')
-        assert result.cursor.name
-
-    def test_global_text(self):
-        engine = self._fixture(True)
-        result = engine.execute(text('select 1'))
-        assert result.cursor.name
-
-    def test_global_expr(self):
-        engine = self._fixture(True)
-        result = engine.execute(select([1]))
-        assert result.cursor.name
-
-    def test_global_off_explicit(self):
-        engine = self._fixture(False)
-        result = engine.execute(text('select 1'))
-
-        # It should be off globally ...
-
-        assert not result.cursor.name
-
-    def test_stmt_option(self):
-        engine = self._fixture(False)
-
-        s = select([1]).execution_options(stream_results=True)
-        result = engine.execute(s)
-
-        # ... but enabled for this one.
-
-        assert result.cursor.name
-
-    def test_conn_option(self):
-        engine = self._fixture(False)
-
-        # and this one
-        result = \
-            engine.connect().execution_options(stream_results=True).\
-            execute('select 1'
-                    )
-        assert result.cursor.name
-
-    def test_stmt_enabled_conn_option_disabled(self):
-        engine = self._fixture(False)
-
-        s = select([1]).execution_options(stream_results=True)
-
-        # not this one
-        result = \
-            engine.connect().execution_options(stream_results=False).\
-            execute(s)
-        assert not result.cursor.name
-
-    def test_stmt_option_disabled(self):
-        engine = self._fixture(True)
-        s = select([1]).execution_options(stream_results=False)
-        result = engine.execute(s)
-        assert not result.cursor.name
-
-    def test_aliases_and_ss(self):
-        engine = self._fixture(False)
-        s1 = select([1]).execution_options(stream_results=True).alias()
-        result = engine.execute(s1)
-        assert result.cursor.name
-
-        # s1's options shouldn't affect s2 when s2 is used as a
-        # from_obj.
-        s2 = select([1], from_obj=s1)
-        result = engine.execute(s2)
-        assert not result.cursor.name
-
-    def test_for_update_expr(self):
-        engine = self._fixture(True)
-        s1 = select([1], for_update=True)
-        result = engine.execute(s1)
-        assert result.cursor.name
-
-    def test_for_update_string(self):
-        engine = self._fixture(True)
-        result = engine.execute('SELECT 1 FOR UPDATE')
-        assert result.cursor.name
-
-    def test_text_no_ss(self):
-        engine = self._fixture(False)
-        s = text('select 42')
-        result = engine.execute(s)
-        assert not result.cursor.name
-
-    def test_text_ss_option(self):
-        engine = self._fixture(False)
-        s = text('select 42').execution_options(stream_results=True)
-        result = engine.execute(s)
-        assert result.cursor.name
-
-    @testing.provide_metadata
-    def test_roundtrip(self):
-        md = self.metadata
-
-        engine = self._fixture(True)
-        test_table = Table('test_table', md,
-                           Column('id', Integer, primary_key=True),
-                           Column('data', String(50)))
-        test_table.create(checkfirst=True)
-        test_table.insert().execute(data='data1')
-        nextid = engine.execute(Sequence('test_table_id_seq'))
-        test_table.insert().execute(id=nextid, data='data2')
-        eq_(test_table.select().execute().fetchall(), [(1, 'data1'
-                                                        ), (2, 'data2')])
-        test_table.update().where(
-            test_table.c.id == 2).values(
-            data=test_table.c.data +
-            ' updated').execute()
-        eq_(test_table.select().execute().fetchall(),
-            [(1, 'data1'), (2, 'data2 updated')])
-        test_table.delete().execute()
-        eq_(select([func.count('*')]).select_from(test_table).scalar(), 0)
-
-
 class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
 
     __only_on__ = 'postgresql >= 8.3'