add aiomysql support
author     Mike Bayer <mike_mp@zzzcomputing.com>
           Fri, 4 Dec 2020 21:26:44 +0000 (16:26 -0500)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Thu, 10 Dec 2020 22:11:46 +0000 (17:11 -0500)
This is a re-submission (re-gerrit) of the original change merged in
Ia8ad3efe3b50ce75a3bed1e020e1b82acb5f2eda, which was reverted due to
ongoing issues.

Fixes: #5747
Change-Id: I2b57e76b817eed8f89457a2146b523a1cab656a8
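
A minimal usage sketch of the new dialect through the asyncio extension
(the URL and credentials are placeholders mirroring the test configuration
added to setup.cfg below):

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine

    engine = create_async_engine(
        "mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4"
    )

    async def main():
        async with engine.connect() as conn:
            # plain text statement; Core / ORM constructs work the same way
            result = await conn.execute(text("SELECT 1"))
            print(result.scalar())

    asyncio.run(main())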

15 files changed:
doc/build/changelog/unreleased_14/5747.rst [new file with mode: 0644]
doc/build/dialects/mysql.rst
lib/sqlalchemy/dialects/mysql/__init__.py
lib/sqlalchemy/dialects/mysql/aiomysql.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/mysqldb.py
lib/sqlalchemy/ext/asyncio/result.py
lib/sqlalchemy/testing/suite/test_results.py
lib/sqlalchemy/testing/warnings.py
setup.cfg
test/engine/test_reconnect.py
test/ext/asyncio/test_engine_py3k.py
test/ext/asyncio/test_session_py3k.py
test/requirements.py
test/sql/test_resultset.py
tox.ini

diff --git a/doc/build/changelog/unreleased_14/5747.rst b/doc/build/changelog/unreleased_14/5747.rst
new file mode 100644 (file)
index 0000000..47cf648
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: feature, mysql
+    :tickets: 5747
+
+    Added support for the aiomysql driver when using the asyncio SQLAlchemy
+    extension.
+
+    .. seealso::
+
+      :ref:`aiomysql`
\ No newline at end of file
diff --git a/doc/build/dialects/mysql.rst b/doc/build/dialects/mysql.rst
index 1f2236155b4cc22953ad8ed475894be4bd2890d9..c0bfa7bc6220ef55582e7dbe91d873782e70c42f 100644 (file)
@@ -181,6 +181,13 @@ MySQL-Connector
 
 .. automodule:: sqlalchemy.dialects.mysql.mysqlconnector
 
+.. _aiomysql:
+
+aiomysql
+--------
+
+.. automodule:: sqlalchemy.dialects.mysql.aiomysql
+
 cymysql
 -------
 
diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py
index 9fdc96f6fb7b0e6441f52d7592e7018974a3edf5..c6781c1685d67d7f7d35e2d07a92245aad0dd1c6 100644 (file)
@@ -49,6 +49,10 @@ from .base import VARCHAR
 from .base import YEAR
 from .dml import Insert
 from .dml import insert
+from ...util import compat
+
+if compat.py3k:
+    from . import aiomysql  # noqa
 
 
 # default dialect
diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py
new file mode 100644 (file)
index 0000000..f560ece
--- /dev/null
@@ -0,0 +1,278 @@
+# mysql/aiomysql.py
+# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: mysql+aiomysql
+    :name: aiomysql
+    :dbapi: aiomysql
+    :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...]
+    :url: https://github.com/aio-libs/aiomysql
+
+The aiomysql dialect is SQLAlchemy's second Python asyncio dialect.
+
+Using a special asyncio mediation layer, the aiomysql dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+    from sqlalchemy.ext.asyncio import create_async_engine
+    engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname")
+
+Unicode
+-------
+
+Please see :ref:`mysql_unicode` for current recommendations on unicode
+handling.
+
+
+"""  # noqa
+
+from .pymysql import MySQLDialect_pymysql
+from ... import pool
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_aiomysql_cursor:
+    server_side = False
+
+    def __init__(self, adapt_connection):
+        self._adapt_connection = adapt_connection
+        self._connection = adapt_connection._connection
+        self.await_ = adapt_connection.await_
+
+        cursor = self._connection.cursor()
+
+        # see https://github.com/aio-libs/aiomysql/issues/543
+        self._cursor = self.await_(cursor.__aenter__())
+        self._rows = []
+
+    @property
+    def description(self):
+        return self._cursor.description
+
+    @property
+    def rowcount(self):
+        return self._cursor.rowcount
+
+    @property
+    def arraysize(self):
+        return self._cursor.arraysize
+
+    @arraysize.setter
+    def arraysize(self, value):
+        self._cursor.arraysize = value
+
+    @property
+    def lastrowid(self):
+        return self._cursor.lastrowid
+
+    def close(self):
+        self._rows[:] = []
+
+    def execute(self, operation, parameters=None):
+        if parameters is None:
+            result = self.await_(self._cursor.execute(operation))
+        else:
+            result = self.await_(self._cursor.execute(operation, parameters))
+
+        if not self.server_side:
+            # aiomysql has a "fake" async result, so we have to pull it out
+            # of that here since our default result is not async.
+            # we could just as easily grab "_rows" here and be done with it
+            # but this is safer.
+            self._rows = list(self.await_(self._cursor.fetchall()))
+        return result
+
+    def executemany(self, operation, seq_of_parameters):
+        return self.await_(
+            self._cursor.executemany(operation, seq_of_parameters)
+        )
+
+    def setinputsizes(self, *inputsizes):
+        pass
+
+    def __iter__(self):
+        while self._rows:
+            yield self._rows.pop(0)
+
+    def fetchone(self):
+        if self._rows:
+            return self._rows.pop(0)
+        else:
+            return None
+
+    def fetchmany(self, size=None):
+        if size is None:
+            size = self.arraysize
+
+        retval = self._rows[0:size]
+        self._rows[:] = self._rows[size:]
+        return retval
+
+    def fetchall(self):
+        retval = self._rows[:]
+        self._rows[:] = []
+        return retval
+
+
+class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
+
+    server_side = True
+
+    def __init__(self, adapt_connection):
+        self._adapt_connection = adapt_connection
+        self._connection = adapt_connection._connection
+        self.await_ = adapt_connection.await_
+
+        cursor = self._connection.cursor(
+            adapt_connection.dbapi.aiomysql.SSCursor
+        )
+
+        self._cursor = self.await_(cursor.__aenter__())
+
+    def close(self):
+        if self._cursor is not None:
+            self.await_(self._cursor.close())
+            self._cursor = None
+
+    def fetchone(self):
+        return self.await_(self._cursor.fetchone())
+
+    def fetchmany(self, size=None):
+        return self.await_(self._cursor.fetchmany(size=size))
+
+    def fetchall(self):
+        return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_aiomysql_connection:
+    await_ = staticmethod(await_only)
+
+    def __init__(self, dbapi, connection):
+        self.dbapi = dbapi
+        self._connection = connection
+
+    def ping(self, reconnect):
+        return self.await_(self._connection.ping(reconnect))
+
+    def character_set_name(self):
+        return self._connection.character_set_name()
+
+    def autocommit(self, value):
+        self.await_(self._connection.autocommit(value))
+
+    def cursor(self, server_side=False):
+        if server_side:
+            return AsyncAdapt_aiomysql_ss_cursor(self)
+        else:
+            return AsyncAdapt_aiomysql_cursor(self)
+
+    def rollback(self):
+        self.await_(self._connection.rollback())
+
+    def commit(self):
+        self.await_(self._connection.commit())
+
+    def close(self):
+        # it's not awaitable.
+        self._connection.close()
+
+
+class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
+    __slots__ = ()
+
+    await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_aiomysql_dbapi:
+    def __init__(self, aiomysql, pymysql):
+        self.aiomysql = aiomysql
+        self.pymysql = pymysql
+        self.paramstyle = "format"
+        self._init_dbapi_attributes()
+
+    def _init_dbapi_attributes(self):
+        for name in (
+            "Warning",
+            "Error",
+            "InterfaceError",
+            "DataError",
+            "DatabaseError",
+            "OperationalError",
+            "IntegrityError",
+            "ProgrammingError",
+            "InternalError",
+            "NotSupportedError",
+        ):
+            setattr(self, name, getattr(self.aiomysql, name))
+
+        for name in (
+            "NUMBER",
+            "STRING",
+            "DATETIME",
+            "BINARY",
+            "TIMESTAMP",
+            "Binary",
+        ):
+            setattr(self, name, getattr(self.pymysql, name))
+
+    def connect(self, *arg, **kw):
+        async_fallback = kw.pop("async_fallback", False)
+
+        if async_fallback:
+            return AsyncAdaptFallback_aiomysql_connection(
+                self,
+                await_fallback(self.aiomysql.connect(*arg, **kw)),
+            )
+        else:
+            return AsyncAdapt_aiomysql_connection(
+                self,
+                await_only(self.aiomysql.connect(*arg, **kw)),
+            )
+
+
+class MySQLDialect_aiomysql(MySQLDialect_pymysql):
+    driver = "aiomysql"
+
+    supports_server_side_cursors = True
+    _sscursor = AsyncAdapt_aiomysql_ss_cursor
+
+    @classmethod
+    def dbapi(cls):
+        return AsyncAdapt_aiomysql_dbapi(
+            __import__("aiomysql"), __import__("pymysql")
+        )
+
+    @classmethod
+    def get_pool_class(cls, url):
+        return pool.AsyncAdaptedQueuePool
+
+    def create_connect_args(self, url):
+        args, kw = super(MySQLDialect_aiomysql, self).create_connect_args(url)
+        if "passwd" in kw:
+            kw["password"] = kw.pop("passwd")
+        return args, kw
+
+    def is_disconnect(self, e, connection, cursor):
+        if super(MySQLDialect_aiomysql, self).is_disconnect(
+            e, connection, cursor
+        ):
+            return True
+        else:
+            str_e = str(e).lower()
+            return "not connected" in str_e
+
+    def _found_rows_client_flag(self):
+        from pymysql.constants import CLIENT
+
+        return CLIENT.FOUND_ROWS
+
+
+dialect = MySQLDialect_aiomysql
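
Because the dialect enables supports_server_side_cursors and maps
AsyncAdapt_aiomysql_ss_cursor onto aiomysql.SSCursor, streaming through the
asyncio extension should take the unbuffered cursor path; a rough sketch
(URL and the "users" table are placeholders taken from the test setup):

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine

    engine = create_async_engine("mysql+aiomysql://scott:tiger@127.0.0.1:3306/test")

    async def stream_users():
        async with engine.connect() as conn:
            # AsyncConnection.stream() requests a server-side (SSCursor) result
            result = await conn.stream(text("SELECT user_id, user_name FROM users"))
            async for row in result:
                print(row.user_id, row.user_name)

    asyncio.run(stream_users())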
diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py
index b20e061fb50e50fa8a2841823986e9d54da016f0..605407f46363409825b64654ab25aa4f2777fda2 100644 (file)
@@ -211,16 +211,25 @@ class MySQLDialect_mysqldb(MySQLDialect):
         # FOUND_ROWS must be set in CLIENT_FLAGS to enable
         # supports_sane_rowcount.
         client_flag = opts.get("client_flag", 0)
+
+        client_flag_found_rows = self._found_rows_client_flag()
+        if client_flag_found_rows is not None:
+            client_flag |= client_flag_found_rows
+            opts["client_flag"] = client_flag
+        return [[], opts]
+
+    def _found_rows_client_flag(self):
         if self.dbapi is not None:
             try:
                 CLIENT_FLAGS = __import__(
                     self.dbapi.__name__ + ".constants.CLIENT"
                 ).constants.CLIENT
-                client_flag |= CLIENT_FLAGS.FOUND_ROWS
             except (AttributeError, ImportError):
-                self.supports_sane_rowcount = False
-            opts["client_flag"] = client_flag
-        return [[], opts]
+                return None
+            else:
+                return CLIENT_FLAGS.FOUND_ROWS
+        else:
+            return None
 
     def _extract_error_code(self, exception):
         return exception.args[0]
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
index 7f8a707d5283e0ba97ea1a5fdcf155554b518ae7..9c7e0420fb46d63cc5cdd94312d9ecab66bb0477 100644 (file)
@@ -17,7 +17,14 @@ if util.TYPE_CHECKING:
     from ...engine.result import Row
 
 
-class AsyncResult(FilterResult):
+class AsyncCommon(FilterResult):
+    async def close(self):
+        """Close this result."""
+
+        await greenlet_spawn(self._real_result.close)
+
+
+class AsyncResult(AsyncCommon):
     """An asyncio wrapper around a :class:`_result.Result` object.
 
     The :class:`_asyncio.AsyncResult` only applies to statement executions that
@@ -370,7 +377,7 @@ class AsyncResult(FilterResult):
         return AsyncMappingResult(self._real_result)
 
 
-class AsyncScalarResult(FilterResult):
+class AsyncScalarResult(AsyncCommon):
     """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
     rather than :class:`_row.Row` values.
 
@@ -500,7 +507,7 @@ class AsyncScalarResult(FilterResult):
         return await greenlet_spawn(self._only_one_row, True, True, False)
 
 
-class AsyncMappingResult(FilterResult):
+class AsyncMappingResult(AsyncCommon):
     """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary values
     rather than :class:`_engine.Row` values.
 
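With the new AsyncCommon base, close() is awaitable on AsyncResult as well as
its scalars()/mappings() variants; a brief sketch of abandoning a streamed
result early (assumes an AsyncConnection "conn" and the "users" table from
the test fixtures):

    from sqlalchemy import text

    async def first_page(conn):
        result = await conn.stream(text("SELECT user_id, user_name FROM users"))
        rows = await result.fetchmany(10)
        # releases the underlying cursor without exhausting the result
        await result.close()
        return rows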
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
index 9484d41d09c4639cbca7aa4ea86cd483dcee4b15..f31c7c13775968167a518c522511db0051e268fc 100644 (file)
@@ -114,9 +114,8 @@ class RowFetchTest(fixtures.TablesTest):
 class PercentSchemaNamesTest(fixtures.TablesTest):
     """tests using percent signs, spaces in table and column names.
 
-    This is a very fringe use case, doesn't work for MySQL
-    or PostgreSQL.  the requirement, "percent_schema_names",
-    is marked "skip" by default.
+    This didn't work for PostgreSQL / MySQL drivers for a long time
+    but is now supported.
 
     """
 
@@ -233,6 +232,8 @@ class ServerSideCursorsTest(
         elif self.engine.dialect.driver == "pymysql":
             sscursor = __import__("pymysql.cursors").cursors.SSCursor
             return isinstance(cursor, sscursor)
+        elif self.engine.dialect.driver == "aiomysql":
+            return cursor.server_side
         elif self.engine.dialect.driver == "mysqldb":
             sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
             return isinstance(cursor, sscursor)
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
index b230bad6f09dd4d6861e20bd176ef44e0be45c63..34a968aff13495e198ae520f6bf1f57a965184ea 100644 (file)
@@ -30,6 +30,11 @@ def setup_filters():
     warnings.filterwarnings(
         "ignore", category=DeprecationWarning, message=".*inspect.get.*argspec"
     )
+    warnings.filterwarnings(
+        "ignore",
+        category=DeprecationWarning,
+        message="The loop argument is deprecated",
+    )
 
     # ignore things that are deprecated *as of* 2.0 :)
     warnings.filterwarnings(
diff --git a/setup.cfg b/setup.cfg
index 1912fd3cd6fe4b86a6de23486477696d66a60a9d..46fe781044233673ebf34c514c1b41229355cdae 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -65,6 +65,7 @@ postgresql_asyncpg =
 postgresql_psycopg2binary = psycopg2-binary
 postgresql_psycopg2cffi = psycopg2cffi
 pymysql = pymysql
+aiomysql = aiomysql
 
 [egg_info]
 tag_build = dev
@@ -124,6 +125,7 @@ pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test
 postgresql_psycopg2cffi = postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/test
 mysql = mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
 pymysql = mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
+aiomysql = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true
 mariadb = mariadb://scott:tiger@127.0.0.1:3306/test
 mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server
 mssql_pymssql = mssql+pymssql://scott:tiger@ms_2008
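
The async_fallback=true query parameter on the aiomysql dburi selects
AsyncAdaptFallback_aiomysql_connection, which lets the synchronous test
harness drive the asyncio driver; roughly how that dburi is consumed
(a sketch under that assumption):

    from sqlalchemy import create_engine, text

    # await_fallback() runs the driver's awaitables on an event loop it
    # manages itself, so the blocking Engine API can be used with aiomysql
    engine = create_engine(
        "mysql+aiomysql://scott:tiger@127.0.0.1:3306/test"
        "?charset=utf8mb4&async_fallback=true"
    )

    with engine.connect() as conn:
        print(conn.execute(text("SELECT 1")).scalar())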
diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py
index 0dc35f99e8507a549a19ee2a2a2068d51dd711d6..937a574c89cced5e415839bc59a16a2a5a853367 100644 (file)
@@ -1011,31 +1011,29 @@ class RealReconnectTest(fixtures.TestBase):
         self.engine.dispose()
 
     def test_reconnect(self):
-        conn = self.engine.connect()
+        with self.engine.connect() as conn:
 
-        eq_(conn.execute(select(1)).scalar(), 1)
-        assert not conn.closed
-
-        self.engine.test_shutdown()
+            eq_(conn.execute(select(1)).scalar(), 1)
+            assert not conn.closed
 
-        _assert_invalidated(conn.execute, select(1))
+            self.engine.test_shutdown()
 
-        assert not conn.closed
-        assert conn.invalidated
+            _assert_invalidated(conn.execute, select(1))
 
-        assert conn.invalidated
-        eq_(conn.execute(select(1)).scalar(), 1)
-        assert not conn.invalidated
+            assert not conn.closed
+            assert conn.invalidated
 
-        # one more time
-        self.engine.test_shutdown()
-        _assert_invalidated(conn.execute, select(1))
+            assert conn.invalidated
+            eq_(conn.execute(select(1)).scalar(), 1)
+            assert not conn.invalidated
 
-        assert conn.invalidated
-        eq_(conn.execute(select(1)).scalar(), 1)
-        assert not conn.invalidated
+            # one more time
+            self.engine.test_shutdown()
+            _assert_invalidated(conn.execute, select(1))
 
-        conn.close()
+            assert conn.invalidated
+            eq_(conn.execute(select(1)).scalar(), 1)
+            assert not conn.invalidated
 
     @testing.requires.independent_connections
     def test_multiple_invalidate(self):
@@ -1056,52 +1054,52 @@ class RealReconnectTest(fixtures.TestBase):
         assert self.engine.pool is p2
 
     def test_branched_invalidate_branch_to_parent(self):
-        c1 = self.engine.connect()
+        with self.engine.connect() as c1:
 
-        with patch.object(self.engine.pool, "logger") as logger:
-            c1_branch = c1.connect()
-            eq_(c1_branch.execute(select(1)).scalar(), 1)
+            with patch.object(self.engine.pool, "logger") as logger:
+                c1_branch = c1.connect()
+                eq_(c1_branch.execute(select(1)).scalar(), 1)
 
-            self.engine.test_shutdown()
+                self.engine.test_shutdown()
 
-            _assert_invalidated(c1_branch.execute, select(1))
-            assert c1.invalidated
-            assert c1_branch.invalidated
+                _assert_invalidated(c1_branch.execute, select(1))
+                assert c1.invalidated
+                assert c1_branch.invalidated
 
-            c1_branch._revalidate_connection()
-            assert not c1.invalidated
-            assert not c1_branch.invalidated
+                c1_branch._revalidate_connection()
+                assert not c1.invalidated
+                assert not c1_branch.invalidated
 
-        assert "Invalidate connection" in logger.mock_calls[0][1][0]
+            assert "Invalidate connection" in logger.mock_calls[0][1][0]
 
     def test_branched_invalidate_parent_to_branch(self):
-        c1 = self.engine.connect()
+        with self.engine.connect() as c1:
 
-        c1_branch = c1.connect()
-        eq_(c1_branch.execute(select(1)).scalar(), 1)
+            c1_branch = c1.connect()
+            eq_(c1_branch.execute(select(1)).scalar(), 1)
 
-        self.engine.test_shutdown()
+            self.engine.test_shutdown()
 
-        _assert_invalidated(c1.execute, select(1))
-        assert c1.invalidated
-        assert c1_branch.invalidated
+            _assert_invalidated(c1.execute, select(1))
+            assert c1.invalidated
+            assert c1_branch.invalidated
 
-        c1._revalidate_connection()
-        assert not c1.invalidated
-        assert not c1_branch.invalidated
+            c1._revalidate_connection()
+            assert not c1.invalidated
+            assert not c1_branch.invalidated
 
     def test_branch_invalidate_state(self):
-        c1 = self.engine.connect()
+        with self.engine.connect() as c1:
 
-        c1_branch = c1.connect()
+            c1_branch = c1.connect()
 
-        eq_(c1_branch.execute(select(1)).scalar(), 1)
+            eq_(c1_branch.execute(select(1)).scalar(), 1)
 
-        self.engine.test_shutdown()
+            self.engine.test_shutdown()
 
-        _assert_invalidated(c1_branch.execute, select(1))
-        assert not c1_branch.closed
-        assert not c1_branch._still_open_and_dbapi_connection_is_valid
+            _assert_invalidated(c1_branch.execute, select(1))
+            assert not c1_branch.closed
+            assert not c1_branch._still_open_and_dbapi_connection_is_valid
 
     def test_ensure_is_disconnect_gets_connection(self):
         def is_disconnect(e, conn, cursor):
@@ -1112,38 +1110,39 @@ class RealReconnectTest(fixtures.TestBase):
             # assert cursor is None
 
         self.engine.dialect.is_disconnect = is_disconnect
-        conn = self.engine.connect()
-        self.engine.test_shutdown()
-        with expect_warnings(
-            "An exception has occurred during handling .*", py2konly=True
-        ):
-            assert_raises(tsa.exc.DBAPIError, conn.execute, select(1))
+
+        with self.engine.connect() as conn:
+            self.engine.test_shutdown()
+            with expect_warnings(
+                "An exception has occurred during handling .*", py2konly=True
+            ):
+                assert_raises(tsa.exc.DBAPIError, conn.execute, select(1))
 
     def test_rollback_on_invalid_plain(self):
-        conn = self.engine.connect()
-        trans = conn.begin()
-        conn.invalidate()
-        trans.rollback()
+        with self.engine.connect() as conn:
+            trans = conn.begin()
+            conn.invalidate()
+            trans.rollback()
 
     @testing.requires.two_phase_transactions
     def test_rollback_on_invalid_twophase(self):
-        conn = self.engine.connect()
-        trans = conn.begin_twophase()
-        conn.invalidate()
-        trans.rollback()
+        with self.engine.connect() as conn:
+            trans = conn.begin_twophase()
+            conn.invalidate()
+            trans.rollback()
 
     @testing.requires.savepoints
     def test_rollback_on_invalid_savepoint(self):
-        conn = self.engine.connect()
-        conn.begin()
-        trans2 = conn.begin_nested()
-        conn.invalidate()
-        trans2.rollback()
+        with self.engine.connect() as conn:
+            conn.begin()
+            trans2 = conn.begin_nested()
+            conn.invalidate()
+            trans2.rollback()
 
     def test_invalidate_twice(self):
-        conn = self.engine.connect()
-        conn.invalidate()
-        conn.invalidate()
+        with self.engine.connect() as conn:
+            conn.invalidate()
+            conn.invalidate()
 
     @testing.skip_if(
         [lambda: util.py3k, "oracle+cx_oracle"], "Crashes on py3k+cx_oracle"
@@ -1182,93 +1181,92 @@ class RealReconnectTest(fixtures.TestBase):
         engine = engines.reconnecting_engine(
             options=dict(poolclass=pool.NullPool)
         )
-        conn = engine.connect()
-        eq_(conn.execute(select(1)).scalar(), 1)
-        assert not conn.closed
-        engine.test_shutdown()
-        _assert_invalidated(conn.execute, select(1))
-        assert not conn.closed
-        assert conn.invalidated
-        eq_(conn.execute(select(1)).scalar(), 1)
-        assert not conn.invalidated
+        with engine.connect() as conn:
+            eq_(conn.execute(select(1)).scalar(), 1)
+            assert not conn.closed
+            engine.test_shutdown()
+            _assert_invalidated(conn.execute, select(1))
+            assert not conn.closed
+            assert conn.invalidated
+            eq_(conn.execute(select(1)).scalar(), 1)
+            assert not conn.invalidated
 
     def test_close(self):
-        conn = self.engine.connect()
-        eq_(conn.execute(select(1)).scalar(), 1)
-        assert not conn.closed
+        with self.engine.connect() as conn:
+            eq_(conn.execute(select(1)).scalar(), 1)
+            assert not conn.closed
 
-        self.engine.test_shutdown()
+            self.engine.test_shutdown()
 
-        _assert_invalidated(conn.execute, select(1))
+            _assert_invalidated(conn.execute, select(1))
 
-        conn.close()
-        conn = self.engine.connect()
-        eq_(conn.execute(select(1)).scalar(), 1)
+        with self.engine.connect() as conn:
+            eq_(conn.execute(select(1)).scalar(), 1)
 
     def test_with_transaction(self):
-        conn = self.engine.connect()
-        trans = conn.begin()
-        assert trans.is_valid
-        eq_(conn.execute(select(1)).scalar(), 1)
-        assert not conn.closed
-        self.engine.test_shutdown()
-        _assert_invalidated(conn.execute, select(1))
-        assert not conn.closed
-        assert conn.invalidated
-        assert trans.is_active
-        assert not trans.is_valid
+        with self.engine.connect() as conn:
+            trans = conn.begin()
+            assert trans.is_valid
+            eq_(conn.execute(select(1)).scalar(), 1)
+            assert not conn.closed
+            self.engine.test_shutdown()
+            _assert_invalidated(conn.execute, select(1))
+            assert not conn.closed
+            assert conn.invalidated
+            assert trans.is_active
+            assert not trans.is_valid
 
-        assert_raises_message(
-            tsa.exc.PendingRollbackError,
-            "Can't reconnect until invalid transaction is rolled back",
-            conn.execute,
-            select(1),
-        )
-        assert trans.is_active
-        assert not trans.is_valid
+            assert_raises_message(
+                tsa.exc.PendingRollbackError,
+                "Can't reconnect until invalid transaction is rolled back",
+                conn.execute,
+                select(1),
+            )
+            assert trans.is_active
+            assert not trans.is_valid
 
-        assert_raises_message(
-            tsa.exc.PendingRollbackError,
-            "Can't reconnect until invalid transaction is rolled back",
-            trans.commit,
-        )
+            assert_raises_message(
+                tsa.exc.PendingRollbackError,
+                "Can't reconnect until invalid transaction is rolled back",
+                trans.commit,
+            )
 
-        # becomes inactive
-        assert not trans.is_active
-        assert not trans.is_valid
+            # becomes inactive
+            assert not trans.is_active
+            assert not trans.is_valid
 
-        # still asks us to rollback
-        assert_raises_message(
-            tsa.exc.PendingRollbackError,
-            "Can't reconnect until invalid transaction is rolled back",
-            conn.execute,
-            select(1),
-        )
+            # still asks us to rollback
+            assert_raises_message(
+                tsa.exc.PendingRollbackError,
+                "Can't reconnect until invalid transaction is rolled back",
+                conn.execute,
+                select(1),
+            )
 
-        # still asks us..
-        assert_raises_message(
-            tsa.exc.PendingRollbackError,
-            "Can't reconnect until invalid transaction is rolled back",
-            trans.commit,
-        )
+            # still asks us..
+            assert_raises_message(
+                tsa.exc.PendingRollbackError,
+                "Can't reconnect until invalid transaction is rolled back",
+                trans.commit,
+            )
 
-        # still...it's being consistent in what it is asking.
-        assert_raises_message(
-            tsa.exc.PendingRollbackError,
-            "Can't reconnect until invalid transaction is rolled back",
-            conn.execute,
-            select(1),
-        )
+            # still...it's being consistent in what it is asking.
+            assert_raises_message(
+                tsa.exc.PendingRollbackError,
+                "Can't reconnect until invalid transaction is rolled back",
+                conn.execute,
+                select(1),
+            )
 
-        #  OK!
-        trans.rollback()
-        assert not trans.is_active
-        assert not trans.is_valid
+            #  OK!
+            trans.rollback()
+            assert not trans.is_active
+            assert not trans.is_valid
 
-        # conn still invalid but we can reconnect
-        assert conn.invalidated
-        eq_(conn.execute(select(1)).scalar(), 1)
-        assert not conn.invalidated
+            # conn still invalid but we can reconnect
+            assert conn.invalidated
+            eq_(conn.execute(select(1)).scalar(), 1)
+            assert not conn.invalidated
 
 
 class RecycleTest(fixtures.TestBase):
@@ -1369,6 +1367,7 @@ class InvalidateDuringResultTest(fixtures.TestBase):
             "+pymysql",
             "+pg8000",
             "+asyncpg",
+            "+aiomysql",
         ],
         "Buffers the result set and doesn't check for connection close",
     )
diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py
index a361ff835a9f93a1b0941410a60a298b7847a766..6df8a0e006e03a905efafe4af231cb717d2ab9b1 100644 (file)
@@ -111,7 +111,9 @@ class AsyncEngineTest(EngineFixture):
         dbapi_connection = connection_fairy.connection
 
         await conn.invalidate()
-        assert dbapi_connection._connection.is_closed()
+
+        if testing.against("postgresql+asyncpg"):
+            assert dbapi_connection._connection.is_closed()
 
         new_fairy = await conn.get_raw_connection()
         is_not(new_fairy.connection, dbapi_connection)
@@ -429,6 +431,8 @@ class AsyncResultTest(EngineFixture):
 
             eq_(result.keys(), ["user_id", "user_name"])
 
+            await result.close()
+
     @async_test
     async def test_unique_all(self, async_engine):
         users = self.tables.users
diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py
index a3b8add6774e26a8fb264bb56ffd570fb6867dc9..44e2955428b2898fbb0655c1d080e40e6a5a458a 100644 (file)
@@ -55,6 +55,7 @@ class AsyncSessionQueryTest(AsyncFixture):
         eq_(result.scalars().all(), self.static.user_address_result)
 
     @async_test
+    @testing.requires.independent_cursors
     async def test_stream_partitions(self, async_session):
         User = self.classes.User
 
@@ -99,6 +100,7 @@ class AsyncSessionTransactionTest(AsyncFixture):
                 result = await async_session.execute(select(User))
                 eq_(result.scalar(), u1)
 
+            await outer_conn.rollback()
             eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
 
     @async_test
@@ -118,6 +120,7 @@ class AsyncSessionTransactionTest(AsyncFixture):
 
             await async_session.commit()
 
+            await outer_conn.rollback()
             eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
 
     @async_test
@@ -139,6 +142,7 @@ class AsyncSessionTransactionTest(AsyncFixture):
             finally:
                 await trans.commit()
 
+            await outer_conn.rollback()
             eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
 
     @async_test
diff --git a/test/requirements.py b/test/requirements.py
index 5911d87af8a350f755419087c73208878a2e00d8..d8be25238a76f00dd390a6f372596b18834d0bd2 100644 (file)
@@ -1306,7 +1306,9 @@ class DefaultRequirements(SuiteRequirements):
     def async_dialect(self):
         """dialect makes use of await_() to invoke operations on the DBAPI."""
 
-        return only_on(["postgresql+asyncpg"])
+        return only_on(
+            ["postgresql+asyncpg", "mysql+aiomysql", "mariadb+aiomysql"]
+        )
 
     @property
     def oracle_test_dblink(self):
@@ -1353,7 +1355,10 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def percent_schema_names(self):
-        return exclusions.open()
+        return skip_if(
+            ["mysql+aiomysql", "mariadb+aiomysql"],
+            "see pr https://github.com/aio-libs/aiomysql/pull/545",
+        )
 
     @property
     def order_by_label_with_expression(self):
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index 9ef533be3a555e988822757683fca7816ab6eecb..545cf96abbe3146dab6e0a4758118ade6c4768d8 100644 (file)
@@ -291,15 +291,41 @@ class CursorResultTest(fixtures.TablesTest):
     def test_fetchmany(self, connection):
         users = self.tables.users
 
-        connection.execute(users.insert(), user_id=7, user_name="jack")
-        connection.execute(users.insert(), user_id=8, user_name="ed")
-        connection.execute(users.insert(), user_id=9, user_name="fred")
+        connection.execute(
+            users.insert(),
+            [{"user_id": i, "user_name": "n%d" % i} for i in range(7, 15)],
+        )
         r = connection.execute(users.select())
         rows = []
         for row in r.fetchmany(size=2):
             rows.append(row)
         eq_(len(rows), 2)
 
+    def test_fetchmany_arraysize_default(self, connection):
+        users = self.tables.users
+
+        connection.execute(
+            users.insert(),
+            [{"user_id": i, "user_name": "n%d" % i} for i in range(1, 150)],
+        )
+        r = connection.execute(users.select())
+        arraysize = r.cursor.arraysize
+        rows = list(r.fetchmany())
+
+        eq_(len(rows), min(arraysize, 149))
+
+    def test_fetchmany_arraysize_set(self, connection):
+        users = self.tables.users
+
+        connection.execute(
+            users.insert(),
+            [{"user_id": i, "user_name": "n%d" % i} for i in range(7, 15)],
+        )
+        r = connection.execute(users.select())
+        r.cursor.arraysize = 4
+        rows = list(r.fetchmany())
+        eq_(len(rows), 4)
+
     def test_column_slices(self, connection):
         users = self.tables.users
         addresses = self.tables.addresses
diff --git a/tox.ini b/tox.ini
index 6cfcf62efc25acd3287a7625d3a849d2d51580de..9d0c75f77fed170a642355a28270cd5e893d09e2 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -25,6 +25,7 @@ deps=pytest>=4.6.11 # this can be 6.x once we are on python 3 only
      postgresql: .[postgresql_pg8000]; python_version >= '3'
      mysql: .[mysql]
      mysql: .[pymysql]
+     mysql: .[aiomysql]; python_version >= '3'
      mysql: .[mariadb_connector]; python_version >= '3'
 
      # we should probably try to get mysql_connector back in the mix
@@ -78,7 +79,7 @@ setenv=
 
     mysql: MYSQL={env:TOX_MYSQL:--db mysql}
     mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql}
-    py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector}
+    py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql?async_fallback=true}
 
 
     mssql: MSSQL={env:TOX_MSSQL:--db mssql}