]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add name_func optional attribute for asyncpg adapter 9607/head
authorPavel Sirotkin <pav.pnz@gmail.com>
Thu, 6 Apr 2023 09:30:11 +0000 (11:30 +0200)
committerPavel Sirotkin <pav.pnz@gmail.com>
Thu, 20 Apr 2023 12:07:22 +0000 (14:07 +0200)
Fixes: #9608
lib/sqlalchemy/dialects/postgresql/asyncpg.py
test/dialect/postgresql/test_async_pg_py3k.py

index 2acc5fea300e0259c92d70e9080c8e70355a9fc2..9b1fdf8b2c08df77e2707921abca09bccbe65f79 100644 (file)
@@ -98,6 +98,44 @@ To disable the prepared statement cache, use a value of zero::
    stale, nor can it retry the statement as the PostgreSQL transaction is
    invalidated when these errors occur.
 
+.. _asyncpg_prepared_statement_name:
+
+Prepared Statement Name
+-----------------------
+
+By default, asyncpg enumerates prepared statements in numeric order, which 
+can lead to errors if a name has already been taken for another prepared 
+statement. This issue can arise if your application uses database proxies 
+such as PgBouncer to handle connections. One possible workaround is to 
+use dynamic prepared statement names, which asyncpg now supports through 
+an optional name value for the statement name. This allows you to 
+generate your own unique names that won't conflict with existing ones. 
+To achieve this, you can provide a function that will be called every time 
+a prepared statement is prepared::
+
+    from uuid import uuid4
+
+    engine = create_async_engine(
+        "postgresql+asyncpg://user:pass@hostname/dbname",     
+        poolclass=NullPool,
+        connect_args={
+            'pepared_statement_name_func': lambda:  f'__asyncpg_{uuid4()}__',
+        },
+    )
+
+.. seealso::
+
+   https://github.com/MagicStack/asyncpg/issues/837
+
+   https://github.com/sqlalchemy/sqlalchemy/issues/6467
+
+.. warning:: To prevent a buildup of useless prepared statements in 
+   your application, it's important to use the NullPool poolclass and 
+   PgBouncer with a configured `DISCARD https://www.postgresql.org/docs/current/sql-discard.html`_ 
+   setup. The DISCARD command is used to release resources held by the db connection, 
+   including prepared statements. Without proper setup, prepared statements can 
+   accumulate quickly and cause performance issues.
+
 Disabling the PostgreSQL JIT to improve ENUM datatype handling
 ---------------------------------------------------------------
 
@@ -642,13 +680,20 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
         "_transaction",
         "_started",
         "_prepared_statement_cache",
+        "_prepared_statement_name_func",
         "_invalidate_schema_cache_asof",
         "_execute_mutex",
     )
 
     await_ = staticmethod(await_only)
 
-    def __init__(self, dbapi, connection, prepared_statement_cache_size=100):
+    def __init__(
+        self,
+        dbapi,
+        connection,
+        prepared_statement_cache_size=100,
+        prepared_statement_name_func=None,
+    ):
         self.dbapi = dbapi
         self._connection = connection
         self.isolation_level = self._isolation_setting = "read_committed"
@@ -666,6 +711,11 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
         else:
             self._prepared_statement_cache = None
 
+        if prepared_statement_name_func:
+            self._prepared_statement_name_func = prepared_statement_name_func
+        else:
+            self._prepared_statement_name_func = self._default_name_func
+
     async def _check_type_cache_invalidation(self, invalidate_timestamp):
         if invalidate_timestamp > self._invalidate_schema_cache_asof:
             await self._connection.reload_schema_state()
@@ -676,7 +726,9 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
 
         cache = self._prepared_statement_cache
         if cache is None:
-            prepared_stmt = await self._connection.prepare(operation)
+            prepared_stmt = await self._connection.prepare(
+                operation, name=self._prepared_statement_name_func()
+            )
             attributes = prepared_stmt.get_attributes()
             return prepared_stmt, attributes
 
@@ -692,7 +744,9 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
             if cached_timestamp > invalidate_timestamp:
                 return prepared_stmt, attributes
 
-        prepared_stmt = await self._connection.prepare(operation)
+        prepared_stmt = await self._connection.prepare(
+            operation, name=self._prepared_statement_name_func()
+        )
         attributes = prepared_stmt.get_attributes()
         cache[operation] = (prepared_stmt, attributes, time.time())
 
@@ -792,6 +846,10 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
     def terminate(self):
         self._connection.terminate()
 
+    @staticmethod
+    def _default_name_func():
+        return None
+
 
 class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
     __slots__ = ()
@@ -809,17 +867,23 @@ class AsyncAdapt_asyncpg_dbapi:
         prepared_statement_cache_size = kw.pop(
             "prepared_statement_cache_size", 100
         )
+        prepared_statement_name_func = kw.pop(
+            "prepared_statement_name_func", None
+        )
+
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_asyncpg_connection(
                 self,
                 await_fallback(self.asyncpg.connect(*arg, **kw)),
                 prepared_statement_cache_size=prepared_statement_cache_size,
+                prepared_statement_name_func=prepared_statement_name_func,
             )
         else:
             return AsyncAdapt_asyncpg_connection(
                 self,
                 await_only(self.asyncpg.connect(*arg, **kw)),
                 prepared_statement_cache_size=prepared_statement_cache_size,
+                prepared_statement_name_func=prepared_statement_name_func,
             )
 
     class Error(Exception):
index d9116a7cecc815d8c9809caff97ad0fb705c085c..49014fcaf94144c1414803693ecefc24bdcf55b2 100644 (file)
@@ -1,4 +1,5 @@
 import random
+import uuid
 
 from sqlalchemy import Column
 from sqlalchemy import exc
@@ -272,3 +273,19 @@ class AsyncPgTest(fixtures.TestBase):
             await conn.close()
 
         eq_(codec_meth.mock_calls, [mock.call(adapted_conn)])
+
+    @async_test
+    async def test_name_connection_func(self, metadata, async_testing_engine):
+        cache = []
+
+        def name_f():
+            name = str(uuid.uuid4())
+            cache.append(name)
+            return name
+
+        engine = async_testing_engine(
+            options={"connect_args": {"prepared_statement_name_func": name_f}},
+        )
+        async with engine.begin() as conn:
+            await conn.execute(select(1))
+            assert len(cache) > 0