--- /dev/null
+.. change::
+ :tags: feature, mysql
+ :tickets: 5747
+
+ Added support for the aiomysql driver when using the asyncio SQLAlchemy
+ extension.
+
+ .. seealso::
+
+ :ref:`aiomysql`
\ No newline at end of file
.. automodule:: sqlalchemy.dialects.mysql.mysqlconnector
+.. _aiomysql:
+
+aiomysql
+--------
+
+.. automodule:: sqlalchemy.dialects.mysql.aiomysql
+
cymysql
-------
from .base import YEAR
from .dml import Insert
from .dml import insert
+from ...util import compat
+
+if compat.py3k:
+ from . import aiomysql # noqa
# default dialect
--- /dev/null
+# mysql/aiomysql.py
+# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors <see AUTHORS
+# file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+r"""
+.. dialect:: mysql+aiomysql
+ :name: aiomysql
+ :dbapi: aiomysql
+ :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...]
+ :url: https://github.com/aio-libs/aiomysql
+
+The aiomysql dialect is SQLAlchemy's second Python asyncio dialect.
+
+Using a special asyncio mediation layer, the aiomysql dialect is usable
+as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
+extension package.
+
+This dialect should normally be used only with the
+:func:`_asyncio.create_async_engine` engine creation function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname")
+
+Unicode
+-------
+
+Please see :ref:`mysql_unicode` for current recommendations on unicode
+handling.
+
+
+""" # noqa
+
+from .pymysql import MySQLDialect_pymysql
+from ... import pool
+from ...util.concurrency import await_fallback
+from ...util.concurrency import await_only
+
+
+class AsyncAdapt_aiomysql_cursor:
+ server_side = False
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor()
+
+ # see https://github.com/aio-libs/aiomysql/issues/543
+ self._cursor = self.await_(cursor.__aenter__())
+ self._rows = []
+
+ @property
+ def description(self):
+ return self._cursor.description
+
+ @property
+ def rowcount(self):
+ return self._cursor.rowcount
+
+ @property
+ def arraysize(self):
+ return self._cursor.arraysize
+
+ @arraysize.setter
+ def arraysize(self, value):
+ self._cursor.arraysize = value
+
+ @property
+ def lastrowid(self):
+ return self._cursor.lastrowid
+
+ def close(self):
+ self._rows[:] = []
+
+ def execute(self, operation, parameters=None):
+ if parameters is None:
+ result = self.await_(self._cursor.execute(operation))
+ else:
+ result = self.await_(self._cursor.execute(operation, parameters))
+
+ if not self.server_side:
+ # aiomysql has a "fake" async result, so we have to pull it out
+ # of that here since our default result is not async.
+ # we could just as easily grab "_rows" here and be done with it
+ # but this is safer.
+ self._rows = list(self.await_(self._cursor.fetchall()))
+ return result
+
+ def executemany(self, operation, seq_of_parameters):
+ return self.await_(
+ self._cursor.executemany(operation, seq_of_parameters)
+ )
+
+ def setinputsizes(self, *inputsizes):
+ pass
+
+ def __iter__(self):
+ while self._rows:
+ yield self._rows.pop(0)
+
+ def fetchone(self):
+ if self._rows:
+ return self._rows.pop(0)
+ else:
+ return None
+
+ def fetchmany(self, size=None):
+ if size is None:
+ size = self.arraysize
+
+ retval = self._rows[0:size]
+ self._rows[:] = self._rows[size:]
+ return retval
+
+ def fetchall(self):
+ retval = self._rows[:]
+ self._rows[:] = []
+ return retval
+
+
+class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
+
+ server_side = True
+
+ def __init__(self, adapt_connection):
+ self._adapt_connection = adapt_connection
+ self._connection = adapt_connection._connection
+ self.await_ = adapt_connection.await_
+
+ cursor = self._connection.cursor(
+ adapt_connection.dbapi.aiomysql.SSCursor
+ )
+
+ self._cursor = self.await_(cursor.__aenter__())
+
+ def close(self):
+ if self._cursor is not None:
+ self.await_(self._cursor.close())
+ self._cursor = None
+
+ def fetchone(self):
+ return self.await_(self._cursor.fetchone())
+
+ def fetchmany(self, size=None):
+ return self.await_(self._cursor.fetchmany(size=size))
+
+ def fetchall(self):
+ return self.await_(self._cursor.fetchall())
+
+
+class AsyncAdapt_aiomysql_connection:
+ await_ = staticmethod(await_only)
+
+ def __init__(self, dbapi, connection):
+ self.dbapi = dbapi
+ self._connection = connection
+
+ def ping(self, reconnect):
+ return self.await_(self._connection.ping(reconnect))
+
+ def character_set_name(self):
+ return self._connection.character_set_name()
+
+ def autocommit(self, value):
+ self.await_(self._connection.autocommit(value))
+
+ def cursor(self, server_side=False):
+ if server_side:
+ return AsyncAdapt_aiomysql_ss_cursor(self)
+ else:
+ return AsyncAdapt_aiomysql_cursor(self)
+
+ def rollback(self):
+ self.await_(self._connection.rollback())
+
+ def commit(self):
+ self.await_(self._connection.commit())
+
+ def close(self):
+ # it's not awaitable.
+ self._connection.close()
+
+
+class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
+ __slots__ = ()
+
+ await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_aiomysql_dbapi:
+ def __init__(self, aiomysql, pymysql):
+ self.aiomysql = aiomysql
+ self.pymysql = pymysql
+ self.paramstyle = "format"
+ self._init_dbapi_attributes()
+
+ def _init_dbapi_attributes(self):
+ for name in (
+ "Warning",
+ "Error",
+ "InterfaceError",
+ "DataError",
+ "DatabaseError",
+ "OperationalError",
+ "InterfaceError",
+ "IntegrityError",
+ "ProgrammingError",
+ "InternalError",
+ "NotSupportedError",
+ ):
+ setattr(self, name, getattr(self.aiomysql, name))
+
+ for name in (
+ "NUMBER",
+ "STRING",
+ "DATETIME",
+ "BINARY",
+ "TIMESTAMP",
+ "Binary",
+ ):
+ setattr(self, name, getattr(self.pymysql, name))
+
+ def connect(self, *arg, **kw):
+ async_fallback = kw.pop("async_fallback", False)
+
+ if async_fallback:
+ return AsyncAdaptFallback_aiomysql_connection(
+ self,
+ await_fallback(self.aiomysql.connect(*arg, **kw)),
+ )
+ else:
+ return AsyncAdapt_aiomysql_connection(
+ self,
+ await_only(self.aiomysql.connect(*arg, **kw)),
+ )
+
+
+class MySQLDialect_aiomysql(MySQLDialect_pymysql):
+ driver = "aiomysql"
+
+ supports_server_side_cursors = True
+ _sscursor = AsyncAdapt_aiomysql_ss_cursor
+
+ @classmethod
+ def dbapi(cls):
+ return AsyncAdapt_aiomysql_dbapi(
+ __import__("aiomysql"), __import__("pymysql")
+ )
+
+ @classmethod
+ def get_pool_class(self, url):
+ return pool.AsyncAdaptedQueuePool
+
+ def create_connect_args(self, url):
+ args, kw = super(MySQLDialect_aiomysql, self).create_connect_args(url)
+ if "passwd" in kw:
+ kw["password"] = kw.pop("passwd")
+ return args, kw
+
+ def is_disconnect(self, e, connection, cursor):
+ if super(MySQLDialect_aiomysql, self).is_disconnect(
+ e, connection, cursor
+ ):
+ return True
+ else:
+ str_e = str(e).lower()
+ return "not connected" in str_e
+
+ def _found_rows_client_flag(self):
+ from pymysql.constants import CLIENT
+
+ return CLIENT.FOUND_ROWS
+
+
+dialect = MySQLDialect_aiomysql
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
client_flag = opts.get("client_flag", 0)
+
+ client_flag_found_rows = self._found_rows_client_flag()
+ if client_flag_found_rows is not None:
+ client_flag |= client_flag_found_rows
+ opts["client_flag"] = client_flag
+ return [[], opts]
+
+ def _found_rows_client_flag(self):
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
self.dbapi.__name__ + ".constants.CLIENT"
).constants.CLIENT
- client_flag |= CLIENT_FLAGS.FOUND_ROWS
except (AttributeError, ImportError):
- self.supports_sane_rowcount = False
- opts["client_flag"] = client_flag
- return [[], opts]
+ return None
+ else:
+ return CLIENT_FLAGS.FOUND_ROWS
+ else:
+ return None
def _extract_error_code(self, exception):
return exception.args[0]
from ...engine.result import Row
-class AsyncResult(FilterResult):
+class AsyncCommon(FilterResult):
+ async def close(self):
+ """Close this result."""
+
+ await greenlet_spawn(self._real_result.close)
+
+
+class AsyncResult(AsyncCommon):
"""An asyncio wrapper around a :class:`_result.Result` object.
The :class:`_asyncio.AsyncResult` only applies to statement executions that
return AsyncMappingResult(self._real_result)
-class AsyncScalarResult(FilterResult):
+class AsyncScalarResult(AsyncCommon):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
rather than :class:`_row.Row` values.
return await greenlet_spawn(self._only_one_row, True, True, False)
-class AsyncMappingResult(FilterResult):
+class AsyncMappingResult(AsyncCommon):
"""A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary values
rather than :class:`_engine.Row` values.
class PercentSchemaNamesTest(fixtures.TablesTest):
"""tests using percent signs, spaces in table and column names.
- This is a very fringe use case, doesn't work for MySQL
- or PostgreSQL. the requirement, "percent_schema_names",
- is marked "skip" by default.
+ This didn't work for PostgreSQL / MySQL drivers for a long time
+ but is now supported.
"""
elif self.engine.dialect.driver == "pymysql":
sscursor = __import__("pymysql.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
+ elif self.engine.dialect.driver == "aiomysql":
+ return cursor.server_side
elif self.engine.dialect.driver == "mysqldb":
sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
return isinstance(cursor, sscursor)
warnings.filterwarnings(
"ignore", category=DeprecationWarning, message=".*inspect.get.*argspec"
)
+ warnings.filterwarnings(
+ "ignore",
+ category=DeprecationWarning,
+ message="The loop argument is deprecated",
+ )
# ignore things that are deprecated *as of* 2.0 :)
warnings.filterwarnings(
postgresql_psycopg2binary = psycopg2-binary
postgresql_psycopg2cffi = psycopg2cffi
pymysql = pymysql
+aiomysql = aiomysql
[egg_info]
tag_build = dev
postgresql_psycopg2cffi = postgresql+psycopg2cffi://scott:tiger@127.0.0.1:5432/test
mysql = mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
pymysql = mysql+pymysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
+aiomysql = mysql+aiomysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4&async_fallback=true
mariadb = mariadb://scott:tiger@127.0.0.1:3306/test
mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server
mssql_pymssql = mssql+pymssql://scott:tiger@ms_2008
self.engine.dispose()
def test_reconnect(self):
- conn = self.engine.connect()
+ with self.engine.connect() as conn:
- eq_(conn.execute(select(1)).scalar(), 1)
- assert not conn.closed
-
- self.engine.test_shutdown()
+ eq_(conn.execute(select(1)).scalar(), 1)
+ assert not conn.closed
- _assert_invalidated(conn.execute, select(1))
+ self.engine.test_shutdown()
- assert not conn.closed
- assert conn.invalidated
+ _assert_invalidated(conn.execute, select(1))
- assert conn.invalidated
- eq_(conn.execute(select(1)).scalar(), 1)
- assert not conn.invalidated
+ assert not conn.closed
+ assert conn.invalidated
- # one more time
- self.engine.test_shutdown()
- _assert_invalidated(conn.execute, select(1))
+ assert conn.invalidated
+ eq_(conn.execute(select(1)).scalar(), 1)
+ assert not conn.invalidated
- assert conn.invalidated
- eq_(conn.execute(select(1)).scalar(), 1)
- assert not conn.invalidated
+ # one more time
+ self.engine.test_shutdown()
+ _assert_invalidated(conn.execute, select(1))
- conn.close()
+ assert conn.invalidated
+ eq_(conn.execute(select(1)).scalar(), 1)
+ assert not conn.invalidated
@testing.requires.independent_connections
def test_multiple_invalidate(self):
assert self.engine.pool is p2
def test_branched_invalidate_branch_to_parent(self):
- c1 = self.engine.connect()
+ with self.engine.connect() as c1:
- with patch.object(self.engine.pool, "logger") as logger:
- c1_branch = c1.connect()
- eq_(c1_branch.execute(select(1)).scalar(), 1)
+ with patch.object(self.engine.pool, "logger") as logger:
+ c1_branch = c1.connect()
+ eq_(c1_branch.execute(select(1)).scalar(), 1)
- self.engine.test_shutdown()
+ self.engine.test_shutdown()
- _assert_invalidated(c1_branch.execute, select(1))
- assert c1.invalidated
- assert c1_branch.invalidated
+ _assert_invalidated(c1_branch.execute, select(1))
+ assert c1.invalidated
+ assert c1_branch.invalidated
- c1_branch._revalidate_connection()
- assert not c1.invalidated
- assert not c1_branch.invalidated
+ c1_branch._revalidate_connection()
+ assert not c1.invalidated
+ assert not c1_branch.invalidated
- assert "Invalidate connection" in logger.mock_calls[0][1][0]
+ assert "Invalidate connection" in logger.mock_calls[0][1][0]
def test_branched_invalidate_parent_to_branch(self):
- c1 = self.engine.connect()
+ with self.engine.connect() as c1:
- c1_branch = c1.connect()
- eq_(c1_branch.execute(select(1)).scalar(), 1)
+ c1_branch = c1.connect()
+ eq_(c1_branch.execute(select(1)).scalar(), 1)
- self.engine.test_shutdown()
+ self.engine.test_shutdown()
- _assert_invalidated(c1.execute, select(1))
- assert c1.invalidated
- assert c1_branch.invalidated
+ _assert_invalidated(c1.execute, select(1))
+ assert c1.invalidated
+ assert c1_branch.invalidated
- c1._revalidate_connection()
- assert not c1.invalidated
- assert not c1_branch.invalidated
+ c1._revalidate_connection()
+ assert not c1.invalidated
+ assert not c1_branch.invalidated
def test_branch_invalidate_state(self):
- c1 = self.engine.connect()
+ with self.engine.connect() as c1:
- c1_branch = c1.connect()
+ c1_branch = c1.connect()
- eq_(c1_branch.execute(select(1)).scalar(), 1)
+ eq_(c1_branch.execute(select(1)).scalar(), 1)
- self.engine.test_shutdown()
+ self.engine.test_shutdown()
- _assert_invalidated(c1_branch.execute, select(1))
- assert not c1_branch.closed
- assert not c1_branch._still_open_and_dbapi_connection_is_valid
+ _assert_invalidated(c1_branch.execute, select(1))
+ assert not c1_branch.closed
+ assert not c1_branch._still_open_and_dbapi_connection_is_valid
def test_ensure_is_disconnect_gets_connection(self):
def is_disconnect(e, conn, cursor):
# assert cursor is None
self.engine.dialect.is_disconnect = is_disconnect
- conn = self.engine.connect()
- self.engine.test_shutdown()
- with expect_warnings(
- "An exception has occurred during handling .*", py2konly=True
- ):
- assert_raises(tsa.exc.DBAPIError, conn.execute, select(1))
+
+ with self.engine.connect() as conn:
+ self.engine.test_shutdown()
+ with expect_warnings(
+ "An exception has occurred during handling .*", py2konly=True
+ ):
+ assert_raises(tsa.exc.DBAPIError, conn.execute, select(1))
def test_rollback_on_invalid_plain(self):
- conn = self.engine.connect()
- trans = conn.begin()
- conn.invalidate()
- trans.rollback()
+ with self.engine.connect() as conn:
+ trans = conn.begin()
+ conn.invalidate()
+ trans.rollback()
@testing.requires.two_phase_transactions
def test_rollback_on_invalid_twophase(self):
- conn = self.engine.connect()
- trans = conn.begin_twophase()
- conn.invalidate()
- trans.rollback()
+ with self.engine.connect() as conn:
+ trans = conn.begin_twophase()
+ conn.invalidate()
+ trans.rollback()
@testing.requires.savepoints
def test_rollback_on_invalid_savepoint(self):
- conn = self.engine.connect()
- conn.begin()
- trans2 = conn.begin_nested()
- conn.invalidate()
- trans2.rollback()
+ with self.engine.connect() as conn:
+ conn.begin()
+ trans2 = conn.begin_nested()
+ conn.invalidate()
+ trans2.rollback()
def test_invalidate_twice(self):
- conn = self.engine.connect()
- conn.invalidate()
- conn.invalidate()
+ with self.engine.connect() as conn:
+ conn.invalidate()
+ conn.invalidate()
@testing.skip_if(
[lambda: util.py3k, "oracle+cx_oracle"], "Crashes on py3k+cx_oracle"
engine = engines.reconnecting_engine(
options=dict(poolclass=pool.NullPool)
)
- conn = engine.connect()
- eq_(conn.execute(select(1)).scalar(), 1)
- assert not conn.closed
- engine.test_shutdown()
- _assert_invalidated(conn.execute, select(1))
- assert not conn.closed
- assert conn.invalidated
- eq_(conn.execute(select(1)).scalar(), 1)
- assert not conn.invalidated
+ with engine.connect() as conn:
+ eq_(conn.execute(select(1)).scalar(), 1)
+ assert not conn.closed
+ engine.test_shutdown()
+ _assert_invalidated(conn.execute, select(1))
+ assert not conn.closed
+ assert conn.invalidated
+ eq_(conn.execute(select(1)).scalar(), 1)
+ assert not conn.invalidated
def test_close(self):
- conn = self.engine.connect()
- eq_(conn.execute(select(1)).scalar(), 1)
- assert not conn.closed
+ with self.engine.connect() as conn:
+ eq_(conn.execute(select(1)).scalar(), 1)
+ assert not conn.closed
- self.engine.test_shutdown()
+ self.engine.test_shutdown()
- _assert_invalidated(conn.execute, select(1))
+ _assert_invalidated(conn.execute, select(1))
- conn.close()
- conn = self.engine.connect()
- eq_(conn.execute(select(1)).scalar(), 1)
+ with self.engine.connect() as conn:
+ eq_(conn.execute(select(1)).scalar(), 1)
def test_with_transaction(self):
- conn = self.engine.connect()
- trans = conn.begin()
- assert trans.is_valid
- eq_(conn.execute(select(1)).scalar(), 1)
- assert not conn.closed
- self.engine.test_shutdown()
- _assert_invalidated(conn.execute, select(1))
- assert not conn.closed
- assert conn.invalidated
- assert trans.is_active
- assert not trans.is_valid
+ with self.engine.connect() as conn:
+ trans = conn.begin()
+ assert trans.is_valid
+ eq_(conn.execute(select(1)).scalar(), 1)
+ assert not conn.closed
+ self.engine.test_shutdown()
+ _assert_invalidated(conn.execute, select(1))
+ assert not conn.closed
+ assert conn.invalidated
+ assert trans.is_active
+ assert not trans.is_valid
- assert_raises_message(
- tsa.exc.PendingRollbackError,
- "Can't reconnect until invalid transaction is rolled back",
- conn.execute,
- select(1),
- )
- assert trans.is_active
- assert not trans.is_valid
+ assert_raises_message(
+ tsa.exc.PendingRollbackError,
+ "Can't reconnect until invalid transaction is rolled back",
+ conn.execute,
+ select(1),
+ )
+ assert trans.is_active
+ assert not trans.is_valid
- assert_raises_message(
- tsa.exc.PendingRollbackError,
- "Can't reconnect until invalid transaction is rolled back",
- trans.commit,
- )
+ assert_raises_message(
+ tsa.exc.PendingRollbackError,
+ "Can't reconnect until invalid transaction is rolled back",
+ trans.commit,
+ )
- # becomes inactive
- assert not trans.is_active
- assert not trans.is_valid
+ # becomes inactive
+ assert not trans.is_active
+ assert not trans.is_valid
- # still asks us to rollback
- assert_raises_message(
- tsa.exc.PendingRollbackError,
- "Can't reconnect until invalid transaction is rolled back",
- conn.execute,
- select(1),
- )
+ # still asks us to rollback
+ assert_raises_message(
+ tsa.exc.PendingRollbackError,
+ "Can't reconnect until invalid transaction is rolled back",
+ conn.execute,
+ select(1),
+ )
- # still asks us..
- assert_raises_message(
- tsa.exc.PendingRollbackError,
- "Can't reconnect until invalid transaction is rolled back",
- trans.commit,
- )
+ # still asks us..
+ assert_raises_message(
+ tsa.exc.PendingRollbackError,
+ "Can't reconnect until invalid transaction is rolled back",
+ trans.commit,
+ )
- # still...it's being consistent in what it is asking.
- assert_raises_message(
- tsa.exc.PendingRollbackError,
- "Can't reconnect until invalid transaction is rolled back",
- conn.execute,
- select(1),
- )
+ # still...it's being consistent in what it is asking.
+ assert_raises_message(
+ tsa.exc.PendingRollbackError,
+ "Can't reconnect until invalid transaction is rolled back",
+ conn.execute,
+ select(1),
+ )
- # OK!
- trans.rollback()
- assert not trans.is_active
- assert not trans.is_valid
+ # OK!
+ trans.rollback()
+ assert not trans.is_active
+ assert not trans.is_valid
- # conn still invalid but we can reconnect
- assert conn.invalidated
- eq_(conn.execute(select(1)).scalar(), 1)
- assert not conn.invalidated
+ # conn still invalid but we can reconnect
+ assert conn.invalidated
+ eq_(conn.execute(select(1)).scalar(), 1)
+ assert not conn.invalidated
class RecycleTest(fixtures.TestBase):
"+pymysql",
"+pg8000",
"+asyncpg",
+ "+aiomysql",
],
"Buffers the result set and doesn't check for connection close",
)
dbapi_connection = connection_fairy.connection
await conn.invalidate()
- assert dbapi_connection._connection.is_closed()
+
+ if testing.against("postgresql+asyncpg"):
+ assert dbapi_connection._connection.is_closed()
new_fairy = await conn.get_raw_connection()
is_not(new_fairy.connection, dbapi_connection)
eq_(result.keys(), ["user_id", "user_name"])
+ await result.close()
+
@async_test
async def test_unique_all(self, async_engine):
users = self.tables.users
eq_(result.scalars().all(), self.static.user_address_result)
@async_test
+ @testing.requires.independent_cursors
async def test_stream_partitions(self, async_session):
User = self.classes.User
result = await async_session.execute(select(User))
eq_(result.scalar(), u1)
+ await outer_conn.rollback()
eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
@async_test
await async_session.commit()
+ await outer_conn.rollback()
eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
@async_test
finally:
await trans.commit()
+ await outer_conn.rollback()
eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
@async_test
def async_dialect(self):
"""dialect makes use of await_() to invoke operations on the DBAPI."""
- return only_on(["postgresql+asyncpg"])
+ return only_on(
+ ["postgresql+asyncpg", "mysql+aiomysql", "mariadb+aiomysql"]
+ )
@property
def oracle_test_dblink(self):
@property
def percent_schema_names(self):
- return exclusions.open()
+ return skip_if(
+ ["mysql+aiomysql", "mariadb+aiomysql"],
+ "see pr https://github.com/aio-libs/aiomysql/pull/545",
+ )
@property
def order_by_label_with_expression(self):
def test_fetchmany(self, connection):
users = self.tables.users
- connection.execute(users.insert(), user_id=7, user_name="jack")
- connection.execute(users.insert(), user_id=8, user_name="ed")
- connection.execute(users.insert(), user_id=9, user_name="fred")
+ connection.execute(
+ users.insert(),
+ [{"user_id": i, "user_name": "n%d" % i} for i in range(7, 15)],
+ )
r = connection.execute(users.select())
rows = []
for row in r.fetchmany(size=2):
rows.append(row)
eq_(len(rows), 2)
+ def test_fetchmany_arraysize_default(self, connection):
+ users = self.tables.users
+
+ connection.execute(
+ users.insert(),
+ [{"user_id": i, "user_name": "n%d" % i} for i in range(1, 150)],
+ )
+ r = connection.execute(users.select())
+ arraysize = r.cursor.arraysize
+ rows = list(r.fetchmany())
+
+ eq_(len(rows), min(arraysize, 150))
+
+ def test_fetchmany_arraysize_set(self, connection):
+ users = self.tables.users
+
+ connection.execute(
+ users.insert(),
+ [{"user_id": i, "user_name": "n%d" % i} for i in range(7, 15)],
+ )
+ r = connection.execute(users.select())
+ r.cursor.arraysize = 4
+ rows = list(r.fetchmany())
+ eq_(len(rows), 4)
+
def test_column_slices(self, connection):
users = self.tables.users
addresses = self.tables.addresses
postgresql: .[postgresql_pg8000]; python_version >= '3'
mysql: .[mysql]
mysql: .[pymysql]
+ mysql: .[aiomysql]; python_version >= '3'
mysql: .[mariadb_connector]; python_version >= '3'
# we should probably try to get mysql_connector back in the mix
mysql: MYSQL={env:TOX_MYSQL:--db mysql}
mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql}
- py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector}
+ py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql?async_fallback=true}
mssql: MSSQL={env:TOX_MSSQL:--db mssql}