""" # noqa
+from functools import partial
from .pymysql import MySQLDialect_pymysql
from ... import pool
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
+ creator_fn = kw.pop("creator_fn", partial(self.aiomysql.connect))
if util.asbool(async_fallback):
return AsyncAdaptFallback_aiomysql_connection(
self,
- await_fallback(self.aiomysql.connect(*arg, **kw)),
+ await_fallback(creator_fn(*arg, **kw)),
)
else:
return AsyncAdapt_aiomysql_connection(
self,
- await_only(self.aiomysql.connect(*arg, **kw)),
+ await_only(creator_fn(*arg, **kw)),
)
""" # noqa
-
from contextlib import asynccontextmanager
+from functools import partial
from .pymysql import MySQLDialect_pymysql
from ... import pool
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
+ creator_fn = kw.pop("creator_fn", partial(self.asyncmy.connect))
if util.asbool(async_fallback):
return AsyncAdaptFallback_asyncmy_connection(
self,
- await_fallback(self.asyncmy.connect(*arg, **kw)),
+ await_fallback(creator_fn(*arg, **kw)),
)
else:
return AsyncAdapt_asyncmy_connection(
self,
- await_only(self.asyncmy.connect(*arg, **kw)),
+ await_only(creator_fn(*arg, **kw)),
)
import collections
import decimal
-import functools
+from functools import partial
import json as _py_json
import re
import time
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
- creator_fn = kw.pop(
- "creator_fn", functools.partial(self.asyncpg.connect)
- )
+ creator_fn = kw.pop("creator_fn", partial(self.asyncpg.connect))
prepared_statement_cache_size = kw.pop(
"prepared_statement_cache_size", 100
)
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
-
- connection = self.aiosqlite.connect(*arg, **kw)
+ creator_fn = kw.pop("creator_fn", partial(self.aiosqlite.connect))
+ connection = creator_fn(*arg, **kw)
# it's a Thread. you'll thank us later
connection.daemon = True
import datetime
+from functools import partial
from sqlalchemy import bindparam
from sqlalchemy import Column
from sqlalchemy.dialects import mysql
from sqlalchemy.engine.url import make_url
from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_warnings
def test_sysdate(self, connection):
d = connection.execute(func.sysdate()).scalar()
assert isinstance(d, datetime.datetime)
+
+
+class AsyncMySQLTest(fixtures.TestBase):
+ __requires__ = ("async_dialect",)
+ __only_on__ = "mysql+aiomysql", "mysql+asyncmy"
+
+ @async_test
+ async def test_async_creator(self, async_testing_engine):
+ if testing.against("mysql+aiomysql"):
+ import aiomysql
+
+ connect_func = partial(aiomysql.connect)
+ if testing.against("mysql+asyncmy"):
+ import asyncmy
+
+ connect_func = partial(asyncmy.connect)
+
+ async def async_creator():
+ conn = await connect_func(
+ host=config.db.url.host,
+ port=config.db.url.port,
+ user=config.db.url.username,
+ password=config.db.url.password,
+ db=config.db.url.database,
+ )
+ return conn
+
+ engine = async_testing_engine(
+ options={"async_creator": async_creator},
+ )
+ async with engine.connect() as conn:
+ result = await conn.execute(select(1))
+ eq_(result.scalar(), 1)
from sqlalchemy import testing
from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy.testing import async_test
+from sqlalchemy.testing import config
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
assert len(cache) > 0
@async_test
- async def test_async_creator(self, metadata, async_testing_engine):
+ async def test_async_creator(self, async_testing_engine):
import asyncpg
+ url = config.db.url.render_as_string(hide_password=False)
+ # format URL properly, strip driver
+ url = url.replace("+asyncpg", "")
+
async def async_creator():
- conn = await asyncpg.connect(
- "postgresql://scott:tiger@127.0.0.1:5432/test"
- )
+ conn = await asyncpg.connect(url)
return conn
engine = async_testing_engine(
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
+from sqlalchemy.testing import async_test
from sqlalchemy.testing import combinations
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
eq_(res, ["sqlitetempview"])
finally:
connection.exec_driver_sql("DROP VIEW sqlitetempview")
+
+
+class AsyncSqlliteTest(fixtures.TestBase):
+ __requires__ = ("async_dialect",)
+ __only_on__ = "sqlite+aiosqlite"
+
+ @async_test
+ async def test_async_creator(self, async_testing_engine):
+ import aiosqlite
+
+ async def async_creator():
+ conn = await aiosqlite.connect(":memory:")
+ return conn
+
+ engine = async_testing_engine(
+ options={"async_creator": async_creator},
+ )
+ async with engine.connect() as conn:
+ result = await conn.execute(select(1))
+ eq_(result.scalar(), 1)