From: Mike Bayer Date: Fri, 4 Dec 2020 21:26:44 +0000 (-0500) Subject: add aiomysql support X-Git-Tag: rel_1_4_0b2~108^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5f333762db4b72c604e44f20f86d171dc249b741;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add aiomysql support This is a re-gerrit of the original gerrit merged in Ia8ad3efe3b50ce75a3bed1e020e1b82acb5f2eda Reverted due to ongoing issues. Fixes: #5747 Change-Id: I2b57e76b817eed8f89457a2146b523a1cab656a8 --- diff --git a/doc/build/changelog/unreleased_14/5747.rst b/doc/build/changelog/unreleased_14/5747.rst new file mode 100644 index 0000000000..47cf648cda --- /dev/null +++ b/doc/build/changelog/unreleased_14/5747.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: feature, mysql + :tickets: 5747 + + Added support for the aiomysql driver when using the asyncio SQLAlchemy + extension. + + .. seealso:: + + :ref:`aiomysql` \ No newline at end of file diff --git a/doc/build/dialects/mysql.rst b/doc/build/dialects/mysql.rst index 1f2236155b..c0bfa7bc62 100644 --- a/doc/build/dialects/mysql.rst +++ b/doc/build/dialects/mysql.rst @@ -181,6 +181,13 @@ MySQL-Connector .. automodule:: sqlalchemy.dialects.mysql.mysqlconnector +.. _aiomysql: + +aiomysql +-------- + +.. automodule:: sqlalchemy.dialects.mysql.aiomysql + cymysql ------- diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py index 9fdc96f6fb..c6781c1685 100644 --- a/lib/sqlalchemy/dialects/mysql/__init__.py +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -49,6 +49,10 @@ from .base import VARCHAR from .base import YEAR from .dml import Insert from .dml import insert +from ...util import compat + +if compat.py3k: + from . import aiomysql # noqa # default dialect diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py new file mode 100644 index 0000000000..f560ece332 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -0,0 +1,278 @@ +# mysql/aiomysql.py +# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: mysql+aiomysql + :name: aiomysql + :dbapi: aiomysql + :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...] + :url: https://github.com/aio-libs/aiomysql + +The aiomysql dialect is SQLAlchemy's second Python asyncio dialect. + +Using a special asyncio mediation layer, the aiomysql dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio ` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname") + +Unicode +------- + +Please see :ref:`mysql_unicode` for current recommendations on unicode +handling. + + +""" # noqa + +from .pymysql import MySQLDialect_pymysql +from ... import pool +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_aiomysql_cursor: + server_side = False + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor() + + # see https://github.com/aio-libs/aiomysql/issues/543 + self._cursor = self.await_(cursor.__aenter__()) + self._rows = [] + + @property + def description(self): + return self._cursor.description + + @property + def rowcount(self): + return self._cursor.rowcount + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + @property + def lastrowid(self): + return self._cursor.lastrowid + + def close(self): + self._rows[:] = [] + + def execute(self, operation, parameters=None): + if parameters is None: + result = self.await_(self._cursor.execute(operation)) + else: + result = self.await_(self._cursor.execute(operation, parameters)) + + if not self.server_side: + # aiomysql has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. + self._rows = list(self.await_(self._cursor.fetchall())) + return result + + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._cursor.executemany(operation, seq_of_parameters) + ) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor): + + server_side = True + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor( + adapt_connection.dbapi.aiomysql.SSCursor + ) + + self._cursor = self.await_(cursor.__aenter__()) + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_aiomysql_connection: + await_ = staticmethod(await_only) + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + + def ping(self, reconnect): + return self.await_(self._connection.ping(reconnect)) + + def character_set_name(self): + return self._connection.character_set_name() + + def autocommit(self, value): + self.await_(self._connection.autocommit(value)) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_aiomysql_ss_cursor(self) + else: + return AsyncAdapt_aiomysql_cursor(self) + + def rollback(self): + self.await_(self._connection.rollback()) + + def commit(self): + self.await_(self._connection.commit()) + + def close(self): + # it's not awaitable. + self._connection.close() + + +class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_aiomysql_dbapi: + def __init__(self, aiomysql, pymysql): + self.aiomysql = aiomysql + self.pymysql = pymysql + self.paramstyle = "format" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name in ( + "Warning", + "Error", + "InterfaceError", + "DataError", + "DatabaseError", + "OperationalError", + "InterfaceError", + "IntegrityError", + "ProgrammingError", + "InternalError", + "NotSupportedError", + ): + setattr(self, name, getattr(self.aiomysql, name)) + + for name in ( + "NUMBER", + "STRING", + "DATETIME", + "BINARY", + "TIMESTAMP", + "Binary", + ): + setattr(self, name, getattr(self.pymysql, name)) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + + if async_fallback: + return AsyncAdaptFallback_aiomysql_connection( + self, + await_fallback(self.aiomysql.connect(*arg, **kw)), + ) + else: + return AsyncAdapt_aiomysql_connection( + self, + await_only(self.aiomysql.connect(*arg, **kw)), + ) + + +class MySQLDialect_aiomysql(MySQLDialect_pymysql): + driver = "aiomysql" + + supports_server_side_cursors = True + _sscursor = AsyncAdapt_aiomysql_ss_cursor + + @classmethod + def dbapi(cls): + return AsyncAdapt_aiomysql_dbapi( + __import__("aiomysql"), __import__("pymysql") + ) + + @classmethod + def get_pool_class(self, url): + return pool.AsyncAdaptedQueuePool + + def create_connect_args(self, url): + args, kw = super(MySQLDialect_aiomysql, self).create_connect_args(url) + if "passwd" in kw: + kw["password"] = kw.pop("passwd") + return args, kw + + def is_disconnect(self, e, connection, cursor): + if super(MySQLDialect_aiomysql, self).is_disconnect( + e, connection, cursor + ): + return True + else: + str_e = str(e).lower() + return "not connected" in str_e + + def _found_rows_client_flag(self): + from pymysql.constants import CLIENT + + return CLIENT.FOUND_ROWS + + +dialect = MySQLDialect_aiomysql diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index b20e061fb5..605407f463 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -211,16 +211,25 @@ class MySQLDialect_mysqldb(MySQLDialect): # FOUND_ROWS must be set in CLIENT_FLAGS to enable # supports_sane_rowcount. client_flag = opts.get("client_flag", 0) + + client_flag_found_rows = self._found_rows_client_flag() + if client_flag_found_rows is not None: + client_flag |= client_flag_found_rows + opts["client_flag"] = client_flag + return [[], opts] + + def _found_rows_client_flag(self): if self.dbapi is not None: try: CLIENT_FLAGS = __import__( self.dbapi.__name__ + ".constants.CLIENT" ).constants.CLIENT - client_flag |= CLIENT_FLAGS.FOUND_ROWS except (AttributeError, ImportError): - self.supports_sane_rowcount = False - opts["client_flag"] = client_flag - return [[], opts] + return None + else: + return CLIENT_FLAGS.FOUND_ROWS + else: + return None def _extract_error_code(self, exception): return exception.args[0] diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index 7f8a707d52..9c7e0420fb 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -17,7 +17,14 @@ if util.TYPE_CHECKING: from ...engine.result import Row -class AsyncResult(FilterResult): +class AsyncCommon(FilterResult): + async def close(self): + """Close this result.""" + + await greenlet_spawn(self._real_result.close) + + +class AsyncResult(AsyncCommon): """An asyncio wrapper around a :class:`_result.Result` object. The :class:`_asyncio.AsyncResult` only applies to statement executions that @@ -370,7 +377,7 @@ class AsyncResult(FilterResult): return AsyncMappingResult(self._real_result) -class AsyncScalarResult(FilterResult): +class AsyncScalarResult(AsyncCommon): """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values rather than :class:`_row.Row` values. @@ -500,7 +507,7 @@ class AsyncScalarResult(FilterResult): return await greenlet_spawn(self._only_one_row, True, True, False) -class AsyncMappingResult(FilterResult): +class AsyncMappingResult(AsyncCommon): """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary values rather than :class:`_engine.Row` values. diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index 9484d41d09..f31c7c1377 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -114,9 +114,8 @@ class RowFetchTest(fixtures.TablesTest): class PercentSchemaNamesTest(fixtures.TablesTest): """tests using percent signs, spaces in table and column names. - This is a very fringe use case, doesn't work for MySQL - or PostgreSQL. the requirement, "percent_schema_names", - is marked "skip" by default. + This didn't work for PostgreSQL / MySQL drivers for a long time + but is now supported. """ @@ -233,6 +232,8 @@ class ServerSideCursorsTest( elif self.engine.dialect.driver == "pymysql": sscursor = __import__("pymysql.cursors").cursors.SSCursor return isinstance(cursor, sscursor) + elif self.engine.dialect.driver == "aiomysql": + return cursor.server_side elif self.engine.dialect.driver == "mysqldb": sscursor = __import__("MySQLdb.cursors").cursors.SSCursor return isinstance(cursor, sscursor) diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index b230bad6f0..34a968aff1 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -30,6 +30,11 @@ def setup_filters(): warnings.filterwarnings( "ignore", category=DeprecationWarning, message=".*inspect.get.*argspec" ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message="The loop argument is deprecated", + ) # ignore things that are deprecated *as of* 2.0 :) warnings.filterwarnings( diff --git a/setup.cfg b/setup.cfg index 1912fd3cd6..46fe781044 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,6 +65,7 @@ postgresql_asyncpg = postgresql_psycopg2binary = psycopg2-binary postgresql_psycopg2cffi = psycopg2cffi pymysql = pymysql +aiomysql = aiomysql [egg_info] tag_build = dev @@ -124,6 +125,7 @@ pg8000 = postgresql+pg8000://scott:tiger@127.0.0.1:5432/test postgresql_psycopg2cffi = postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/test mysql = mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 pymysql = mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4 +aiomysql = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true mariadb = mariadb://scott:tiger@127.0.0.1:3306/test mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server mssql_pymssql = mssql+pymssql://scott:tiger@ms_2008 diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 0dc35f99e8..937a574c89 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1011,31 +1011,29 @@ class RealReconnectTest(fixtures.TestBase): self.engine.dispose() def test_reconnect(self): - conn = self.engine.connect() + with self.engine.connect() as conn: - eq_(conn.execute(select(1)).scalar(), 1) - assert not conn.closed - - self.engine.test_shutdown() + eq_(conn.execute(select(1)).scalar(), 1) + assert not conn.closed - _assert_invalidated(conn.execute, select(1)) + self.engine.test_shutdown() - assert not conn.closed - assert conn.invalidated + _assert_invalidated(conn.execute, select(1)) - assert conn.invalidated - eq_(conn.execute(select(1)).scalar(), 1) - assert not conn.invalidated + assert not conn.closed + assert conn.invalidated - # one more time - self.engine.test_shutdown() - _assert_invalidated(conn.execute, select(1)) + assert conn.invalidated + eq_(conn.execute(select(1)).scalar(), 1) + assert not conn.invalidated - assert conn.invalidated - eq_(conn.execute(select(1)).scalar(), 1) - assert not conn.invalidated + # one more time + self.engine.test_shutdown() + _assert_invalidated(conn.execute, select(1)) - conn.close() + assert conn.invalidated + eq_(conn.execute(select(1)).scalar(), 1) + assert not conn.invalidated @testing.requires.independent_connections def test_multiple_invalidate(self): @@ -1056,52 +1054,52 @@ class RealReconnectTest(fixtures.TestBase): assert self.engine.pool is p2 def test_branched_invalidate_branch_to_parent(self): - c1 = self.engine.connect() + with self.engine.connect() as c1: - with patch.object(self.engine.pool, "logger") as logger: - c1_branch = c1.connect() - eq_(c1_branch.execute(select(1)).scalar(), 1) + with patch.object(self.engine.pool, "logger") as logger: + c1_branch = c1.connect() + eq_(c1_branch.execute(select(1)).scalar(), 1) - self.engine.test_shutdown() + self.engine.test_shutdown() - _assert_invalidated(c1_branch.execute, select(1)) - assert c1.invalidated - assert c1_branch.invalidated + _assert_invalidated(c1_branch.execute, select(1)) + assert c1.invalidated + assert c1_branch.invalidated - c1_branch._revalidate_connection() - assert not c1.invalidated - assert not c1_branch.invalidated + c1_branch._revalidate_connection() + assert not c1.invalidated + assert not c1_branch.invalidated - assert "Invalidate connection" in logger.mock_calls[0][1][0] + assert "Invalidate connection" in logger.mock_calls[0][1][0] def test_branched_invalidate_parent_to_branch(self): - c1 = self.engine.connect() + with self.engine.connect() as c1: - c1_branch = c1.connect() - eq_(c1_branch.execute(select(1)).scalar(), 1) + c1_branch = c1.connect() + eq_(c1_branch.execute(select(1)).scalar(), 1) - self.engine.test_shutdown() + self.engine.test_shutdown() - _assert_invalidated(c1.execute, select(1)) - assert c1.invalidated - assert c1_branch.invalidated + _assert_invalidated(c1.execute, select(1)) + assert c1.invalidated + assert c1_branch.invalidated - c1._revalidate_connection() - assert not c1.invalidated - assert not c1_branch.invalidated + c1._revalidate_connection() + assert not c1.invalidated + assert not c1_branch.invalidated def test_branch_invalidate_state(self): - c1 = self.engine.connect() + with self.engine.connect() as c1: - c1_branch = c1.connect() + c1_branch = c1.connect() - eq_(c1_branch.execute(select(1)).scalar(), 1) + eq_(c1_branch.execute(select(1)).scalar(), 1) - self.engine.test_shutdown() + self.engine.test_shutdown() - _assert_invalidated(c1_branch.execute, select(1)) - assert not c1_branch.closed - assert not c1_branch._still_open_and_dbapi_connection_is_valid + _assert_invalidated(c1_branch.execute, select(1)) + assert not c1_branch.closed + assert not c1_branch._still_open_and_dbapi_connection_is_valid def test_ensure_is_disconnect_gets_connection(self): def is_disconnect(e, conn, cursor): @@ -1112,38 +1110,39 @@ class RealReconnectTest(fixtures.TestBase): # assert cursor is None self.engine.dialect.is_disconnect = is_disconnect - conn = self.engine.connect() - self.engine.test_shutdown() - with expect_warnings( - "An exception has occurred during handling .*", py2konly=True - ): - assert_raises(tsa.exc.DBAPIError, conn.execute, select(1)) + + with self.engine.connect() as conn: + self.engine.test_shutdown() + with expect_warnings( + "An exception has occurred during handling .*", py2konly=True + ): + assert_raises(tsa.exc.DBAPIError, conn.execute, select(1)) def test_rollback_on_invalid_plain(self): - conn = self.engine.connect() - trans = conn.begin() - conn.invalidate() - trans.rollback() + with self.engine.connect() as conn: + trans = conn.begin() + conn.invalidate() + trans.rollback() @testing.requires.two_phase_transactions def test_rollback_on_invalid_twophase(self): - conn = self.engine.connect() - trans = conn.begin_twophase() - conn.invalidate() - trans.rollback() + with self.engine.connect() as conn: + trans = conn.begin_twophase() + conn.invalidate() + trans.rollback() @testing.requires.savepoints def test_rollback_on_invalid_savepoint(self): - conn = self.engine.connect() - conn.begin() - trans2 = conn.begin_nested() - conn.invalidate() - trans2.rollback() + with self.engine.connect() as conn: + conn.begin() + trans2 = conn.begin_nested() + conn.invalidate() + trans2.rollback() def test_invalidate_twice(self): - conn = self.engine.connect() - conn.invalidate() - conn.invalidate() + with self.engine.connect() as conn: + conn.invalidate() + conn.invalidate() @testing.skip_if( [lambda: util.py3k, "oracle+cx_oracle"], "Crashes on py3k+cx_oracle" @@ -1182,93 +1181,92 @@ class RealReconnectTest(fixtures.TestBase): engine = engines.reconnecting_engine( options=dict(poolclass=pool.NullPool) ) - conn = engine.connect() - eq_(conn.execute(select(1)).scalar(), 1) - assert not conn.closed - engine.test_shutdown() - _assert_invalidated(conn.execute, select(1)) - assert not conn.closed - assert conn.invalidated - eq_(conn.execute(select(1)).scalar(), 1) - assert not conn.invalidated + with engine.connect() as conn: + eq_(conn.execute(select(1)).scalar(), 1) + assert not conn.closed + engine.test_shutdown() + _assert_invalidated(conn.execute, select(1)) + assert not conn.closed + assert conn.invalidated + eq_(conn.execute(select(1)).scalar(), 1) + assert not conn.invalidated def test_close(self): - conn = self.engine.connect() - eq_(conn.execute(select(1)).scalar(), 1) - assert not conn.closed + with self.engine.connect() as conn: + eq_(conn.execute(select(1)).scalar(), 1) + assert not conn.closed - self.engine.test_shutdown() + self.engine.test_shutdown() - _assert_invalidated(conn.execute, select(1)) + _assert_invalidated(conn.execute, select(1)) - conn.close() - conn = self.engine.connect() - eq_(conn.execute(select(1)).scalar(), 1) + with self.engine.connect() as conn: + eq_(conn.execute(select(1)).scalar(), 1) def test_with_transaction(self): - conn = self.engine.connect() - trans = conn.begin() - assert trans.is_valid - eq_(conn.execute(select(1)).scalar(), 1) - assert not conn.closed - self.engine.test_shutdown() - _assert_invalidated(conn.execute, select(1)) - assert not conn.closed - assert conn.invalidated - assert trans.is_active - assert not trans.is_valid + with self.engine.connect() as conn: + trans = conn.begin() + assert trans.is_valid + eq_(conn.execute(select(1)).scalar(), 1) + assert not conn.closed + self.engine.test_shutdown() + _assert_invalidated(conn.execute, select(1)) + assert not conn.closed + assert conn.invalidated + assert trans.is_active + assert not trans.is_valid - assert_raises_message( - tsa.exc.PendingRollbackError, - "Can't reconnect until invalid transaction is rolled back", - conn.execute, - select(1), - ) - assert trans.is_active - assert not trans.is_valid + assert_raises_message( + tsa.exc.PendingRollbackError, + "Can't reconnect until invalid transaction is rolled back", + conn.execute, + select(1), + ) + assert trans.is_active + assert not trans.is_valid - assert_raises_message( - tsa.exc.PendingRollbackError, - "Can't reconnect until invalid transaction is rolled back", - trans.commit, - ) + assert_raises_message( + tsa.exc.PendingRollbackError, + "Can't reconnect until invalid transaction is rolled back", + trans.commit, + ) - # becomes inactive - assert not trans.is_active - assert not trans.is_valid + # becomes inactive + assert not trans.is_active + assert not trans.is_valid - # still asks us to rollback - assert_raises_message( - tsa.exc.PendingRollbackError, - "Can't reconnect until invalid transaction is rolled back", - conn.execute, - select(1), - ) + # still asks us to rollback + assert_raises_message( + tsa.exc.PendingRollbackError, + "Can't reconnect until invalid transaction is rolled back", + conn.execute, + select(1), + ) - # still asks us.. - assert_raises_message( - tsa.exc.PendingRollbackError, - "Can't reconnect until invalid transaction is rolled back", - trans.commit, - ) + # still asks us.. + assert_raises_message( + tsa.exc.PendingRollbackError, + "Can't reconnect until invalid transaction is rolled back", + trans.commit, + ) - # still...it's being consistent in what it is asking. - assert_raises_message( - tsa.exc.PendingRollbackError, - "Can't reconnect until invalid transaction is rolled back", - conn.execute, - select(1), - ) + # still...it's being consistent in what it is asking. + assert_raises_message( + tsa.exc.PendingRollbackError, + "Can't reconnect until invalid transaction is rolled back", + conn.execute, + select(1), + ) - # OK! - trans.rollback() - assert not trans.is_active - assert not trans.is_valid + # OK! + trans.rollback() + assert not trans.is_active + assert not trans.is_valid - # conn still invalid but we can reconnect - assert conn.invalidated - eq_(conn.execute(select(1)).scalar(), 1) - assert not conn.invalidated + # conn still invalid but we can reconnect + assert conn.invalidated + eq_(conn.execute(select(1)).scalar(), 1) + assert not conn.invalidated class RecycleTest(fixtures.TestBase): @@ -1369,6 +1367,7 @@ class InvalidateDuringResultTest(fixtures.TestBase): "+pymysql", "+pg8000", "+asyncpg", + "+aiomysql", ], "Buffers the result set and doesn't check for connection close", ) diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index a361ff835a..6df8a0e006 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -111,7 +111,9 @@ class AsyncEngineTest(EngineFixture): dbapi_connection = connection_fairy.connection await conn.invalidate() - assert dbapi_connection._connection.is_closed() + + if testing.against("postgresql+asyncpg"): + assert dbapi_connection._connection.is_closed() new_fairy = await conn.get_raw_connection() is_not(new_fairy.connection, dbapi_connection) @@ -429,6 +431,8 @@ class AsyncResultTest(EngineFixture): eq_(result.keys(), ["user_id", "user_name"]) + await result.close() + @async_test async def test_unique_all(self, async_engine): users = self.tables.users diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index a3b8add677..44e2955428 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -55,6 +55,7 @@ class AsyncSessionQueryTest(AsyncFixture): eq_(result.scalars().all(), self.static.user_address_result) @async_test + @testing.requires.independent_cursors async def test_stream_partitions(self, async_session): User = self.classes.User @@ -99,6 +100,7 @@ class AsyncSessionTransactionTest(AsyncFixture): result = await async_session.execute(select(User)) eq_(result.scalar(), u1) + await outer_conn.rollback() eq_(await outer_conn.scalar(select(func.count(User.id))), 1) @async_test @@ -118,6 +120,7 @@ class AsyncSessionTransactionTest(AsyncFixture): await async_session.commit() + await outer_conn.rollback() eq_(await outer_conn.scalar(select(func.count(User.id))), 1) @async_test @@ -139,6 +142,7 @@ class AsyncSessionTransactionTest(AsyncFixture): finally: await trans.commit() + await outer_conn.rollback() eq_(await outer_conn.scalar(select(func.count(User.id))), 1) @async_test diff --git a/test/requirements.py b/test/requirements.py index 5911d87af8..d8be25238a 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1306,7 +1306,9 @@ class DefaultRequirements(SuiteRequirements): def async_dialect(self): """dialect makes use of await_() to invoke operations on the DBAPI.""" - return only_on(["postgresql+asyncpg"]) + return only_on( + ["postgresql+asyncpg", "mysql+aiomysql", "mariadb+aiomysql"] + ) @property def oracle_test_dblink(self): @@ -1353,7 +1355,10 @@ class DefaultRequirements(SuiteRequirements): @property def percent_schema_names(self): - return exclusions.open() + return skip_if( + ["mysql+aiomysql", "mariadb+aiomysql"], + "see pr https://github.com/aio-libs/aiomysql/pull/545", + ) @property def order_by_label_with_expression(self): diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 9ef533be3a..545cf96abb 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -291,15 +291,41 @@ class CursorResultTest(fixtures.TablesTest): def test_fetchmany(self, connection): users = self.tables.users - connection.execute(users.insert(), user_id=7, user_name="jack") - connection.execute(users.insert(), user_id=8, user_name="ed") - connection.execute(users.insert(), user_id=9, user_name="fred") + connection.execute( + users.insert(), + [{"user_id": i, "user_name": "n%d" % i} for i in range(7, 15)], + ) r = connection.execute(users.select()) rows = [] for row in r.fetchmany(size=2): rows.append(row) eq_(len(rows), 2) + def test_fetchmany_arraysize_default(self, connection): + users = self.tables.users + + connection.execute( + users.insert(), + [{"user_id": i, "user_name": "n%d" % i} for i in range(1, 150)], + ) + r = connection.execute(users.select()) + arraysize = r.cursor.arraysize + rows = list(r.fetchmany()) + + eq_(len(rows), min(arraysize, 150)) + + def test_fetchmany_arraysize_set(self, connection): + users = self.tables.users + + connection.execute( + users.insert(), + [{"user_id": i, "user_name": "n%d" % i} for i in range(7, 15)], + ) + r = connection.execute(users.select()) + r.cursor.arraysize = 4 + rows = list(r.fetchmany()) + eq_(len(rows), 4) + def test_column_slices(self, connection): users = self.tables.users addresses = self.tables.addresses diff --git a/tox.ini b/tox.ini index 6cfcf62efc..9d0c75f77f 100644 --- a/tox.ini +++ b/tox.ini @@ -25,6 +25,7 @@ deps=pytest>=4.6.11 # this can be 6.x once we are on python 3 only postgresql: .[postgresql_pg8000]; python_version >= '3' mysql: .[mysql] mysql: .[pymysql] + mysql: .[aiomysql]; python_version >= '3' mysql: .[mariadb_connector]; python_version >= '3' # we should probably try to get mysql_connector back in the mix @@ -78,7 +79,7 @@ setenv= mysql: MYSQL={env:TOX_MYSQL:--db mysql} mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql} - py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector} + py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql?async_fallback=true} mssql: MSSQL={env:TOX_MSSQL:--db mssql}