]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
chore: update based on feedback (#3)
authorJack Wotherspoon <jackwoth@google.com>
Thu, 1 Jun 2023 14:08:54 +0000 (10:08 -0400)
committerGitHub <noreply@github.com>
Thu, 1 Jun 2023 14:08:54 +0000 (10:08 -0400)
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/sqlite/aiosqlite.py
lib/sqlalchemy/ext/asyncio/engine.py

index a4d6ea9c31da4a3aee92f1441dc05731554ec8e1..4c7efae702f95ace37bd771e60ff5e68e00a1a02 100644 (file)
@@ -34,8 +34,6 @@ This dialect should normally be used only with the
 
 
 """  # noqa
-from functools import partial
-
 from .pymysql import MySQLDialect_pymysql
 from ... import pool
 from ... import util
@@ -256,7 +254,7 @@ class AsyncAdapt_aiomysql_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
-        creator_fn = kw.pop("creator_fn", partial(self.aiomysql.connect))
+        creator_fn = kw.pop("creator_fn", self.aiomysql.connect)
 
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_aiomysql_connection(
index 6f5b506a71c6afb7e7a8fdbd4229cca7e5f8cdeb..bcec2f41df980d7a31920b603078abb203b3623e 100644 (file)
@@ -30,7 +30,6 @@ This dialect should normally be used only with the
 
 """  # noqa
 from contextlib import asynccontextmanager
-from functools import partial
 
 from .pymysql import MySQLDialect_pymysql
 from ... import pool
@@ -267,7 +266,7 @@ class AsyncAdapt_asyncmy_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
-        creator_fn = kw.pop("creator_fn", partial(self.asyncmy.connect))
+        creator_fn = kw.pop("creator_fn", self.asyncmy.connect)
 
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_asyncmy_connection(
index 783552dbf5d980d8cd3e4c3ec3e66658837375b9..a0447144737509c1f07b0c4b072e9e457227d21d 100644 (file)
@@ -161,7 +161,6 @@ from __future__ import annotations
 
 import collections
 import decimal
-from functools import partial
 import json as _py_json
 import re
 import time
@@ -873,7 +872,7 @@ class AsyncAdapt_asyncpg_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
-        creator_fn = kw.pop("creator_fn", partial(self.asyncpg.connect))
+        creator_fn = kw.pop("creator_fn", self.asyncpg.connect)
         prepared_statement_cache_size = kw.pop(
             "prepared_statement_cache_size", 100
         )
index f63bc45f15cc5a82019699b64535d1526edc65ed..132b9f65cb85cdd9b1a4bf5cf3e9af2eaaffc04c 100644 (file)
@@ -298,7 +298,7 @@ class AsyncAdapt_aiosqlite_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
-        creator_fn = kw.pop("creator_fn", partial(self.aiosqlite.connect))
+        creator_fn = kw.pop("creator_fn", self.aiosqlite.connect)
         connection = creator_fn(*arg, **kw)
 
         # it's a Thread.   you'll thank us later
index 7cb2e6e69239355d0f73a16de157f7c7e6720c9b..7bdb8e24f8aadae7021a30b94cc2da2343537466 100644 (file)
@@ -39,6 +39,7 @@ from ...engine import create_pool_from_url as _create_pool_from_url
 from ...engine import Engine
 from ...engine.base import NestedTransaction
 from ...engine.base import Transaction
+from ...exc import ArgumentError
 from ...util.concurrency import greenlet_spawn
 
 if TYPE_CHECKING:
@@ -84,17 +85,17 @@ def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine:
     kw["_is_async"] = True
     async_creator = kw.pop("async_creator", None)
     if async_creator:
-
-        async def wrap_async_creator() -> Any:
-            return await async_creator()
+        if kw.get("creator", None):
+            raise ArgumentError(
+                "can only specify one of 'async_creator' or"
+                " 'creator', not both."
+            )
 
         def creator() -> Any:
             # note that to send adapted arguments like
             # prepared_statement_cache_size, user would use
             # "creator" and emulate this form here
-            return sync_engine.dialect.dbapi.connect(
-                creator_fn=wrap_async_creator
-            )
+            return sync_engine.dialect.dbapi.connect(creator_fn=async_creator)
 
         kw["creator"] = creator
     sync_engine = _create_engine(url, **kw)