From: Jack Wotherspoon Date: Thu, 1 Jun 2023 14:08:54 +0000 (-0400) Subject: chore: update based on feedback (#3) X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=5c163c946be326cce25099b3f1194c51c4a6ff4f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git chore: update based on feedback (#3) --- diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index a4d6ea9c31..4c7efae702 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -34,8 +34,6 @@ This dialect should normally be used only with the """ # noqa -from functools import partial - from .pymysql import MySQLDialect_pymysql from ... import pool from ... import util @@ -256,7 +254,7 @@ class AsyncAdapt_aiomysql_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) - creator_fn = kw.pop("creator_fn", partial(self.aiomysql.connect)) + creator_fn = kw.pop("creator_fn", self.aiomysql.connect) if util.asbool(async_fallback): return AsyncAdaptFallback_aiomysql_connection( diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index 6f5b506a71..bcec2f41df 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -30,7 +30,6 @@ This dialect should normally be used only with the """ # noqa from contextlib import asynccontextmanager -from functools import partial from .pymysql import MySQLDialect_pymysql from ... import pool @@ -267,7 +266,7 @@ class AsyncAdapt_asyncmy_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) - creator_fn = kw.pop("creator_fn", partial(self.asyncmy.connect)) + creator_fn = kw.pop("creator_fn", self.asyncmy.connect) if util.asbool(async_fallback): return AsyncAdaptFallback_asyncmy_connection( diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 783552dbf5..a044714473 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -161,7 +161,6 @@ from __future__ import annotations import collections import decimal -from functools import partial import json as _py_json import re import time @@ -873,7 +872,7 @@ class AsyncAdapt_asyncpg_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) - creator_fn = kw.pop("creator_fn", partial(self.asyncpg.connect)) + creator_fn = kw.pop("creator_fn", self.asyncpg.connect) prepared_statement_cache_size = kw.pop( "prepared_statement_cache_size", 100 ) diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py index f63bc45f15..132b9f65cb 100644 --- a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -298,7 +298,7 @@ class AsyncAdapt_aiosqlite_dbapi: def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) - creator_fn = kw.pop("creator_fn", partial(self.aiosqlite.connect)) + creator_fn = kw.pop("creator_fn", self.aiosqlite.connect) connection = creator_fn(*arg, **kw) # it's a Thread. you'll thank us later diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 7cb2e6e692..7bdb8e24f8 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -39,6 +39,7 @@ from ...engine import create_pool_from_url as _create_pool_from_url from ...engine import Engine from ...engine.base import NestedTransaction from ...engine.base import Transaction +from ...exc import ArgumentError from ...util.concurrency import greenlet_spawn if TYPE_CHECKING: @@ -84,17 +85,17 @@ def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine: kw["_is_async"] = True async_creator = kw.pop("async_creator", None) if async_creator: - - async def wrap_async_creator() -> Any: - return await async_creator() + if kw.get("creator", None): + raise ArgumentError( + "can only specify one of 'async_creator' or" + " 'creator', not both." + ) def creator() -> Any: # note that to send adapted arguments like # prepared_statement_cache_size, user would use # "creator" and emulate this form here - return sync_engine.dialect.dbapi.connect( - creator_fn=wrap_async_creator - ) + return sync_engine.dialect.dbapi.connect(creator_fn=async_creator) kw["creator"] = creator sync_engine = _create_engine(url, **kw)