git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
mutex asyncpg / aiomysql connection state changes
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 23 Feb 2021 02:49:09 +0000 (21:49 -0500)
committermike bayer <mike_mp@zzzcomputing.com>
Thu, 25 Feb 2021 20:49:52 +0000 (20:49 +0000)
Added an ``asyncio.Lock()`` within SQLAlchemy's emulated DBAPI cursor,
local to the connection, for the asyncpg dialect, so that concurrent
executions on the connection during the window between the call to
``prepare()`` and ``fetch()`` are prevented from causing interface error
exceptions, as well as preventing race conditions when starting a new
transaction. Other PostgreSQL DBAPIs are threadsafe at the connection level
so this intends to provide a similar behavior, outside the realm of server
side cursors.

Apply the same idea to the aiomysql dialect which also would
otherwise be subject to corruption if the connection were used
concurrently.

While this is an issue which can also occur with the threaded
connection libraries, we anticipate asyncio users are more likely
to attempt using the same connection in multiple awaitables
at a time, even though this won't achieve concurrency for that
use case, as the asyncio programming style is very encouraging
of this.  As the failure modes are also more complicated under
asyncio, we'd rather not have these failures being reported.

Fixes: #5967
Change-Id: I3670ba0c8f0b593c587c5aa7c6c61f9e8c5eb93a

doc/build/changelog/unreleased_14/5967.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py

diff --git a/doc/build/changelog/unreleased_14/5967.rst b/doc/build/changelog/unreleased_14/5967.rst
new file mode 100644 (file)
index 0000000..10f4fe8
--- /dev/null
@@ -0,0 +1,29 @@
+.. change::
+    :tags: usecase, postgresql, mysql, asyncio
+    :tickets: 5967
+
+    Added an ``asyncio.Lock()`` within SQLAlchemy's emulated DBAPI cursor,
+    local to the connection, for the asyncpg and aiomysql dialects for the
+    scope of the ``cursor.execute()`` and ``cursor.executemany()`` methods. The
+    rationale is to prevent failures and corruption for the case where the
+    connection is used in multiple awaitables at once.
+
+    While this use case can also occur with threaded code and non-asyncio
+    dialects, we anticipate this kind of use will be more common under asyncio,
+    as the asyncio API is encouraging of such use. It's definitely better to
+    use a distinct connection per concurrent awaitable, however, as
+    concurrency will not be achieved otherwise.
+
+    For the asyncpg dialect, this is so that concurrent executions on the
+    connection during the window between the call to ``prepare()`` and
+    ``fetch()`` are prevented from causing interface error exceptions, as
+    well as preventing race conditions when starting a new
+    transaction. Other PostgreSQL DBAPIs are threadsafe at the connection level
+    so this intends to provide a similar behavior, outside the realm of server
+    side cursors.
+
+    For the aiomysql dialect, the mutex will provide safety such that
+    the statement execution and the result set fetch, which are two distinct
+    steps at the connection level, won't get corrupted by concurrent
+    executions on the same connection.
+
index 6c968a1e7ce6dbc7e25ff7f574125662f0541de9..cab6df499f980ce4e5faa629ba6a7341c83ff772 100644 (file)
@@ -35,6 +35,7 @@ handling.
 from .pymysql import MySQLDialect_pymysql
 from ... import pool
 from ... import util
+from ...util.concurrency import asyncio
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
@@ -84,24 +85,32 @@ class AsyncAdapt_aiomysql_cursor:
         self._rows[:] = []
 
     def execute(self, operation, parameters=None):
-        if parameters is None:
-            result = self.await_(self._cursor.execute(operation))
-        else:
-            result = self.await_(self._cursor.execute(operation, parameters))
-
-        if not self.server_side:
-            # aiomysql has a "fake" async result, so we have to pull it out
-            # of that here since our default result is not async.
-            # we could just as easily grab "_rows" here and be done with it
-            # but this is safer.
-            self._rows = list(self.await_(self._cursor.fetchall()))
-        return result
+        return self.await_(self._execute_async(operation, parameters))
 
     def executemany(self, operation, seq_of_parameters):
         return self.await_(
-            self._cursor.executemany(operation, seq_of_parameters)
+            self._executemany_async(operation, seq_of_parameters)
         )
 
+    async def _execute_async(self, operation, parameters):
+        async with self._adapt_connection._execute_mutex:
+            if parameters is None:
+                result = await self._cursor.execute(operation)
+            else:
+                result = await self._cursor.execute(operation, parameters)
+
+            if not self.server_side:
+                # aiomysql has a "fake" async result, so we have to pull it out
+                # of that here since our default result is not async.
+                # we could just as easily grab "_rows" here and be done with it
+                # but this is safer.
+                self._rows = list(await self._cursor.fetchall())
+            return result
+
+    async def _executemany_async(self, operation, seq_of_parameters):
+        async with self._adapt_connection._execute_mutex:
+            return await self._cursor.executemany(operation, seq_of_parameters)
+
     def setinputsizes(self, *inputsizes):
         pass
 
@@ -161,11 +170,12 @@ class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
 
 class AsyncAdapt_aiomysql_connection:
     await_ = staticmethod(await_only)
-    __slots__ = ("dbapi", "_connection")
+    __slots__ = ("dbapi", "_connection", "_execute_mutex")
 
     def __init__(self, dbapi, connection):
         self.dbapi = dbapi
         self._connection = connection
+        self._execute_mutex = asyncio.Lock()
 
     def ping(self, reconnect):
         return self.await_(self._connection.ping(reconnect))
index 7ef5e441cbf0eeb69985de25dff4ec1fc255de7b..4580421f68be36b5f5d5696fdbaae9213f3a918b 100644 (file)
@@ -122,6 +122,7 @@ from ... import pool
 from ... import processors
 from ... import util
 from ...sql import sqltypes
+from ...util.concurrency import asyncio
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
@@ -369,74 +370,90 @@ class AsyncAdapt_asyncpg_cursor:
             )
 
     async def _prepare_and_execute(self, operation, parameters):
+        adapt_connection = self._adapt_connection
 
-        if not self._adapt_connection._started:
-            await self._adapt_connection._start_transaction()
-
-        if parameters is not None:
-            operation = operation % self._parameter_placeholders(parameters)
-        else:
-            parameters = ()
+        async with adapt_connection._execute_mutex:
 
-        try:
-            prepared_stmt, attributes = await self._adapt_connection._prepare(
-                operation, self._invalidate_schema_cache_asof
-            )
+            if not adapt_connection._started:
+                await adapt_connection._start_transaction()
 
-            if attributes:
-                self.description = [
-                    (attr.name, attr.type.oid, None, None, None, None, None)
-                    for attr in attributes
-                ]
+            if parameters is not None:
+                operation = operation % self._parameter_placeholders(
+                    parameters
+                )
             else:
-                self.description = None
+                parameters = ()
 
-            if self.server_side:
-                self._cursor = await prepared_stmt.cursor(*parameters)
-                self.rowcount = -1
-            else:
-                self._rows = await prepared_stmt.fetch(*parameters)
-                status = prepared_stmt.get_statusmsg()
+            try:
+                prepared_stmt, attributes = await adapt_connection._prepare(
+                    operation, self._invalidate_schema_cache_asof
+                )
 
-                reg = re.match(r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status)
-                if reg:
-                    self.rowcount = int(reg.group(1))
+                if attributes:
+                    self.description = [
+                        (
+                            attr.name,
+                            attr.type.oid,
+                            None,
+                            None,
+                            None,
+                            None,
+                            None,
+                        )
+                        for attr in attributes
+                    ]
                 else:
+                    self.description = None
+
+                if self.server_side:
+                    self._cursor = await prepared_stmt.cursor(*parameters)
                     self.rowcount = -1
+                else:
+                    self._rows = await prepared_stmt.fetch(*parameters)
+                    status = prepared_stmt.get_statusmsg()
 
-        except Exception as error:
-            self._handle_exception(error)
+                    reg = re.match(
+                        r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status
+                    )
+                    if reg:
+                        self.rowcount = int(reg.group(1))
+                    else:
+                        self.rowcount = -1
 
-    def execute(self, operation, parameters=None):
-        try:
-            self._adapt_connection.await_(
-                self._prepare_and_execute(operation, parameters)
-            )
-        except Exception as error:
-            self._handle_exception(error)
+            except Exception as error:
+                self._handle_exception(error)
 
-    def executemany(self, operation, seq_of_parameters):
+    async def _executemany(self, operation, seq_of_parameters):
         adapt_connection = self._adapt_connection
 
-        adapt_connection.await_(
-            adapt_connection._check_type_cache_invalidation(
+        async with adapt_connection._execute_mutex:
+            await adapt_connection._check_type_cache_invalidation(
                 self._invalidate_schema_cache_asof
             )
-        )
 
-        if not adapt_connection._started:
-            adapt_connection.await_(adapt_connection._start_transaction())
+            if not adapt_connection._started:
+                await adapt_connection._start_transaction()
 
-        operation = operation % self._parameter_placeholders(
-            seq_of_parameters[0]
+            operation = operation % self._parameter_placeholders(
+                seq_of_parameters[0]
+            )
+
+            try:
+                return await self._connection.executemany(
+                    operation, seq_of_parameters
+                )
+            except Exception as error:
+                self._handle_exception(error)
+
+    def execute(self, operation, parameters=None):
+        self._adapt_connection.await_(
+            self._prepare_and_execute(operation, parameters)
         )
 
-        try:
-            return adapt_connection.await_(
-                self._connection.executemany(operation, seq_of_parameters)
-            )
-        except Exception as error:
-            self._handle_exception(error)
+    def executemany(self, operation, seq_of_parameters):
+        return self._adapt_connection.await_(
+            self._executemany(operation, seq_of_parameters)
+        )
 
     def setinputsizes(self, *inputsizes):
         self._inputsizes = inputsizes
@@ -561,6 +578,7 @@ class AsyncAdapt_asyncpg_connection:
         "_started",
         "_prepared_statement_cache",
         "_invalidate_schema_cache_asof",
+        "_execute_mutex",
     )
 
     await_ = staticmethod(await_only)
@@ -574,6 +592,7 @@ class AsyncAdapt_asyncpg_connection:
         self._transaction = None
         self._started = False
         self._invalidate_schema_cache_asof = time.time()
+        self._execute_mutex = asyncio.Lock()
 
         if prepared_statement_cache_size:
             self._prepared_statement_cache = util.LRUCache(