Cache asyncpg prepared statements
author    Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 1 Jan 2021 18:21:55 +0000 (13:21 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Sun, 3 Jan 2021 03:09:39 +0000 (22:09 -0500)
Enhanced the performance of the asyncpg dialect by caching the asyncpg
PreparedStatement objects on a per-connection basis. For a test case that
makes use of the same statement on a set of pooled connections this appears
to grant a 10-20% speed improvement.  The cache size is adjustable and may
also be disabled.

Unfortunately, the caching becomes more complicated when schema
changes are present.  An invalidation scheme is now also added that
covers prepared statements as well as asyncpg's cached OIDs.
However, the exceptions raised cannot be prevented if DDL has changed
database structures that were cached for a particular asyncpg
connection.  Logic is added to clear the caches when these errors occur.
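
For illustration, a minimal sketch of how the adjustable cache size may be
set; the credentials and hostname below are placeholders, and the
``prepared_statement_cache_size`` query parameter is the one documented in
the dialect change further down:

    from sqlalchemy.ext.asyncio import create_async_engine

    # cache up to 500 prepared statements per DBAPI connection
    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname"
        "?prepared_statement_cache_size=500"
    )

    # a value of zero disables the prepared statement cache entirely
    engine_uncached = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname"
        "?prepared_statement_cache_size=0"
    )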

Change-Id: Icf02aa4871eb192f245690f28be4e9f9c35656c6

doc/build/changelog/unreleased_14/asyncpg_prepared_cache.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/asyncpg.py
test/dialect/postgresql/test_async_pg_py3k.py [new file with mode: 0644]
test/engine/test_execute.py

diff --git a/doc/build/changelog/unreleased_14/asyncpg_prepared_cache.rst b/doc/build/changelog/unreleased_14/asyncpg_prepared_cache.rst
new file mode 100644 (file)
index 0000000..eee6fb1
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: postgresql, performance
+
+    Enhanced the performance of the asyncpg dialect by caching the asyncpg
+    PreparedStatement objects on a per-connection basis. For a test case that
+    makes use of the same statement on a set of pooled connections this appears
+    to grant a 10-20% speed improvement.  The cache size is adjustable and may
+    also be disabled.
+
+    .. seealso::
+
+        :ref:`asyncpg_prepared_statement_cache`
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index 6b7e78266476d085b55f57299ce9d77c061e4034..a79469f2e85759a3c2d04d94704d4a8baebc9d2e 100644 (file)
@@ -45,6 +45,57 @@ in conjunction with :func:`_sa.create_engine`::
     ``json_deserializer`` when creating the engine with
     :func:`create_engine` or :func:`create_async_engine`.
 
+
+.. _asyncpg_prepared_statement_cache:
+
+Prepared Statement Cache
+--------------------------
+
+The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()``
+for all statements.   The prepared statement objects are cached after
+construction, which appears to grant a 10% or greater performance improvement
+for statement invocation.   The cache is maintained on a per-DBAPI-connection
+basis, which means that the primary storage for prepared statements is within
+the DBAPI connections pooled within the connection pool.   The size of this
+cache defaults to 100 statements per DBAPI connection and may be adjusted
+using the ``prepared_statement_cache_size`` DBAPI argument (note that while
+this argument is implemented by SQLAlchemy, it is part of the DBAPI emulation
+portion of the asyncpg dialect and is therefore handled as a DBAPI argument,
+not a dialect argument)::
+
+
+    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500")
+
+To disable the prepared statement cache, use a value of zero::
+
+    engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0")
+
+.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.
+
+
+.. warning::  The ``asyncpg`` database driver necessarily uses caches for
+   PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes
+   such as ``ENUM`` objects are changed via DDL operations.   Additionally,
+   the prepared statements themselves, which are optionally cached by the
+   SQLAlchemy asyncpg dialect as described above, may also become "stale" when
+   DDL has been emitted to the PostgreSQL database that modifies the tables or
+   other objects involved in a particular prepared statement.
+
+   The SQLAlchemy asyncpg dialect will invalidate these caches within its local
+   process when statements that represent DDL are emitted on a local
+   connection, but this is only controllable within a single Python process /
+   database engine.  If DDL changes are made from other database engines
+   and/or processes, a running application may encounter the asyncpg exceptions
+   ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup
+   failed for type <oid>")`` if it uses pooled database connections that
+   operated upon the previous structures.  When the driver raises these
+   exceptions, the SQLAlchemy asyncpg dialect will recover by clearing its
+   internal caches as well as those of the asyncpg driver, but it cannot
+   prevent them from being raised in the first place if the cached prepared
+   statements or the asyncpg type caches have gone stale, nor can it retry the
+   statement, as the PostgreSQL transaction is invalidated when these errors
+   occur.
+
 """  # noqa
 
 import collections
@@ -52,6 +103,7 @@ import decimal
 import itertools
 import json as _py_json
 import re
+import time
 
 from . import json
 from .base import _DECIMAL_TYPES
@@ -235,9 +287,23 @@ class AsyncpgOID(OID):
 
 
 class PGExecutionContext_asyncpg(PGExecutionContext):
+    def handle_dbapi_exception(self, e):
+        if isinstance(
+            e,
+            (
+                self.dialect.dbapi.InvalidCachedStatementError,
+                self.dialect.dbapi.InternalServerError,
+            ),
+        ):
+            self.dialect._invalidate_schema_cache()
+
     def pre_exec(self):
         if self.isddl:
-            self._dbapi_connection.reset_schema_state()
+            self.dialect._invalidate_schema_cache()
+
+        self.cursor._invalidate_schema_cache_asof = (
+            self.dialect._invalidate_schema_cache_asof
+        )
 
         if not self.compiled:
             return
@@ -269,6 +335,7 @@ class AsyncAdapt_asyncpg_cursor:
         "rowcount",
         "_inputsizes",
         "_cursor",
+        "_invalidate_schema_cache_asof",
     )
 
     server_side = False
@@ -282,6 +349,7 @@ class AsyncAdapt_asyncpg_cursor:
         self.arraysize = 1
         self.rowcount = -1
         self._inputsizes = None
+        self._invalidate_schema_cache_asof = 0
 
     def close(self):
         self._rows[:] = []
@@ -302,25 +370,25 @@ class AsyncAdapt_asyncpg_cursor:
             )
 
     async def _prepare_and_execute(self, operation, parameters):
-        # TODO: I guess cache these in an LRU cache, or see if we can
-        # use some asyncpg concept
-
-        # TODO: would be nice to support the dollar numeric thing
-        # directly, this is much easier for now
 
         if not self._adapt_connection._started:
             await self._adapt_connection._start_transaction()
 
         params = self._parameters()
+
+        # TODO: would be nice to support the dollar numeric thing
+        # directly, this is much easier for now
         operation = re.sub(r"\?", lambda m: next(params), operation)
+
         try:
-            prepared_stmt = await self._connection.prepare(operation)
+            prepared_stmt, attributes = await self._adapt_connection._prepare(
+                operation, self._invalidate_schema_cache_asof
+            )
 
-            attributes = prepared_stmt.get_attributes()
             if attributes:
                 self.description = [
                     (attr.name, attr.type.oid, None, None, None, None, None)
-                    for attr in prepared_stmt.get_attributes()
+                    for attr in attributes
                 ]
             else:
                 self.description = None
@@ -350,15 +418,21 @@ class AsyncAdapt_asyncpg_cursor:
             self._handle_exception(error)
 
     def executemany(self, operation, seq_of_parameters):
-        if not self._adapt_connection._started:
-            self._adapt_connection.await_(
-                self._adapt_connection._start_transaction()
+        adapt_connection = self._adapt_connection
+
+        adapt_connection.await_(
+            adapt_connection._check_type_cache_invalidation(
+                self._invalidate_schema_cache_asof
             )
+        )
+
+        if not adapt_connection._started:
+            adapt_connection.await_(adapt_connection._start_transaction())
 
         params = self._parameters()
         operation = re.sub(r"\?", lambda m: next(params), operation)
         try:
-            return self._adapt_connection.await_(
+            return adapt_connection.await_(
                 self._connection.executemany(operation, seq_of_parameters)
             )
         except Exception as error:
@@ -485,11 +559,13 @@ class AsyncAdapt_asyncpg_connection:
         "deferrable",
         "_transaction",
         "_started",
+        "_prepared_statement_cache",
+        "_invalidate_schema_cache_asof",
     )
 
     await_ = staticmethod(await_only)
 
-    def __init__(self, dbapi, connection):
+    def __init__(self, dbapi, connection, prepared_statement_cache_size=100):
         self.dbapi = dbapi
         self._connection = connection
         self.isolation_level = self._isolation_setting = "read_committed"
@@ -497,6 +573,46 @@ class AsyncAdapt_asyncpg_connection:
         self.deferrable = False
         self._transaction = None
         self._started = False
+        self._invalidate_schema_cache_asof = time.time()
+
+        if prepared_statement_cache_size:
+            self._prepared_statement_cache = util.LRUCache(
+                prepared_statement_cache_size
+            )
+        else:
+            self._prepared_statement_cache = None
+
+    async def _check_type_cache_invalidation(self, invalidate_timestamp):
+        if invalidate_timestamp > self._invalidate_schema_cache_asof:
+            await self._connection.reload_schema_state()
+            self._invalidate_schema_cache_asof = invalidate_timestamp
+
+    async def _prepare(self, operation, invalidate_timestamp):
+        await self._check_type_cache_invalidation(invalidate_timestamp)
+
+        cache = self._prepared_statement_cache
+        if cache is None:
+            prepared_stmt = await self._connection.prepare(operation)
+            attributes = prepared_stmt.get_attributes()
+            return prepared_stmt, attributes
+
+        # asyncpg uses a type cache for the "attributes" which seems to go
+        # stale independently of the PreparedStatement itself, so place that
+        # collection in the cache as well.
+        if operation in cache:
+            prepared_stmt, attributes, cached_timestamp = cache[operation]
+
+            # preparedstatements themselves also go stale for certain DDL
+            # changes such as size of a VARCHAR changing, so there is also
+            # a cross-connection invalidation timestamp
+            if cached_timestamp > invalidate_timestamp:
+                return prepared_stmt, attributes
+
+        prepared_stmt = await self._connection.prepare(operation)
+        attributes = prepared_stmt.get_attributes()
+        cache[operation] = (prepared_stmt, attributes, time.time())
+
+        return prepared_stmt, attributes
 
     def _handle_exception(self, error):
         if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
@@ -551,9 +667,6 @@ class AsyncAdapt_asyncpg_connection:
         else:
             return AsyncAdapt_asyncpg_cursor(self)
 
-    def reset_schema_state(self):
-        self.await_(self._connection.reload_schema_state())
-
     def rollback(self):
         if self._started:
             self.await_(self._transaction.rollback())
@@ -586,16 +699,20 @@ class AsyncAdapt_asyncpg_dbapi:
 
     def connect(self, *arg, **kw):
         async_fallback = kw.pop("async_fallback", False)
-
+        prepared_statement_cache_size = kw.pop(
+            "prepared_statement_cache_size", 100
+        )
         if util.asbool(async_fallback):
             return AsyncAdaptFallback_asyncpg_connection(
                 self,
                 await_fallback(self.asyncpg.connect(*arg, **kw)),
+                prepared_statement_cache_size=prepared_statement_cache_size,
             )
         else:
             return AsyncAdapt_asyncpg_connection(
                 self,
                 await_only(self.asyncpg.connect(*arg, **kw)),
+                prepared_statement_cache_size=prepared_statement_cache_size,
             )
 
     class Error(Exception):
@@ -628,15 +745,29 @@ class AsyncAdapt_asyncpg_dbapi:
     class NotSupportedError(DatabaseError):
         pass
 
+    class InternalServerError(InternalError):
+        pass
+
+    class InvalidCachedStatementError(NotSupportedError):
+        def __init__(self, message):
+            super(
+                AsyncAdapt_asyncpg_dbapi.InvalidCachedStatementError, self
+            ).__init__(
+                message + " (SQLAlchemy asyncpg dialect will now invalidate "
+                "all prepared caches in response to this exception)",
+            )
+
     @util.memoized_property
     def _asyncpg_error_translate(self):
         import asyncpg
 
         return {
-            asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError,  # noqa
+            asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError,  # noqa: E501
             asyncpg.exceptions.PostgresError: self.Error,
             asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError,
             asyncpg.exceptions.InterfaceError: self.InterfaceError,
+            asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError,  # noqa: E501
+            asyncpg.exceptions.InternalServerError: self.InternalServerError,
         }
 
     def Binary(self, value):
@@ -730,6 +861,10 @@ class PGDialect_asyncpg(PGDialect):
         },
     )
     is_async = True
+    _invalidate_schema_cache_asof = 0
+
+    def _invalidate_schema_cache(self):
+        self._invalidate_schema_cache_asof = time.time()
 
     @util.memoized_property
     def _dbapi_version(self):
@@ -787,9 +922,10 @@ class PGDialect_asyncpg(PGDialect):
 
     def create_connect_args(self, url):
         opts = url.translate_connect_args(username="user")
-        if "port" in opts:
-            opts["port"] = int(opts["port"])
+
         opts.update(url.query)
+        util.coerce_kw_type(opts, "prepared_statement_cache_size", int)
+        util.coerce_kw_type(opts, "port", int)
         return ([], opts)
 
     @classmethod
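
A side note, not part of the patch itself: the following is a condensed,
standalone sketch of the timestamp-based invalidation idea implemented in
the diff above.  The class and helper names are invented for illustration;
each connection caches prepared statements along with the time they were
prepared, while the dialect bumps a shared timestamp whenever DDL is seen
so that older cache entries are re-prepared lazily.

    import time

    class PreparedStatementCacheSketch:
        """Simplified model of the per-connection cache keyed by SQL text."""

        def __init__(self):
            self._cache = {}            # operation -> (stmt, prepared_at)
            self.invalidated_asof = 0   # bumped when DDL is detected

        def invalidate(self):
            # called when DDL is emitted, or when a stale-cache error such
            # as InvalidCachedStatementError is raised by the driver
            self.invalidated_asof = time.time()

        def get_or_prepare(self, operation, prepare):
            entry = self._cache.get(operation)
            if entry is not None:
                stmt, prepared_at = entry
                if prepared_at > self.invalidated_asof:
                    return stmt                 # still fresh, reuse it
            # missing or stale: prepare again and refresh the cache entry
            stmt = prepare(operation)
            self._cache[operation] = (stmt, time.time())
            return stmt

    # usage: "prepare" stands in for asyncpg's connection.prepare()
    cache = PreparedStatementCacheSketch()
    stmt = cache.get_or_prepare("SELECT 1", lambda sql: ("handle", sql))
    cache.invalidate()  # e.g. after an ALTER TABLE was run elsewhere
    stmt = cache.get_or_prepare("SELECT 1", lambda sql: ("handle", sql))
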
diff --git a/test/dialect/postgresql/test_async_pg_py3k.py b/test/dialect/postgresql/test_async_pg_py3k.py
new file mode 100644 (file)
index 0000000..fadf939
--- /dev/null
@@ -0,0 +1,182 @@
+import random
+
+from sqlalchemy import Column
+from sqlalchemy import exc
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy.dialects.postgresql import ENUM
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import engines
+from sqlalchemy.testing import fixtures
+
+
+class AsyncPgTest(fixtures.TestBase):
+    __requires__ = ("async_dialect",)
+    __only_on__ = "postgresql+asyncpg"
+
+    @testing.fixture
+    def async_engine(self):
+        return create_async_engine(testing.db.url)
+
+    @testing.fixture()
+    def metadata(self):
+        # TODO: remove when Iae6ab95938a7e92b6d42086aec534af27b5577d3
+        # merges
+
+        from sqlalchemy.testing import engines
+        from sqlalchemy.sql import schema
+
+        metadata = schema.MetaData()
+
+        try:
+            yield metadata
+        finally:
+            engines.drop_all_tables(metadata, testing.db)
+
+    @async_test
+    async def test_detect_stale_ddl_cache_raise_recover(
+        self, metadata, async_engine
+    ):
+        async def async_setup(engine, strlen):
+            metadata.clear()
+            t1 = Table(
+                "t1",
+                metadata,
+                Column("id", Integer, primary_key=True),
+                Column("name", String(strlen)),
+            )
+
+            # conn is an instance of AsyncConnection
+            async with engine.begin() as conn:
+                await conn.run_sync(metadata.drop_all)
+                await conn.run_sync(metadata.create_all)
+                await conn.execute(
+                    t1.insert(),
+                    [{"name": "some name %d" % i} for i in range(500)],
+                )
+
+        meta = MetaData()
+
+        t1 = Table(
+            "t1",
+            meta,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+        )
+
+        await async_setup(async_engine, 30)
+
+        second_engine = engines.testing_engine(asyncio=True)
+
+        async with second_engine.connect() as conn:
+            result = await conn.execute(
+                t1.select()
+                .where(t1.c.name.like("some name%"))
+                .where(t1.c.id % 17 == 6)
+            )
+
+            rows = result.fetchall()
+            assert len(rows) >= 29
+
+        await async_setup(async_engine, 20)
+
+        async with second_engine.connect() as conn:
+            with testing.expect_raises_message(
+                exc.NotSupportedError,
+                r"cached statement plan is invalid due to a database schema "
+                r"or configuration change \(SQLAlchemy asyncpg dialect "
+                r"will now invalidate all prepared caches in response "
+                r"to this exception\)",
+            ):
+
+                result = await conn.execute(
+                    t1.select()
+                    .where(t1.c.name.like("some name%"))
+                    .where(t1.c.id % 17 == 6)
+                )
+
+        # works again
+        async with second_engine.connect() as conn:
+            result = await conn.execute(
+                t1.select()
+                .where(t1.c.name.like("some name%"))
+                .where(t1.c.id % 17 == 6)
+            )
+
+            rows = result.fetchall()
+            assert len(rows) >= 29
+
+    @async_test
+    async def test_detect_stale_type_cache_raise_recover(
+        self, metadata, async_engine
+    ):
+        async def async_setup(engine, enums):
+            metadata = MetaData()
+            Table(
+                "t1",
+                metadata,
+                Column("id", Integer, primary_key=True),
+                Column("name", ENUM(*enums, name="my_enum")),
+            )
+
+            # conn is an instance of AsyncConnection
+            async with engine.begin() as conn:
+                await conn.run_sync(metadata.drop_all)
+                await conn.run_sync(metadata.create_all)
+
+        t1 = Table(
+            "t1",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column(
+                "name",
+                ENUM(
+                    *("beans", "means", "keens", "faux", "beau", "flow"),
+                    name="my_enum"
+                ),
+            ),
+        )
+
+        await async_setup(async_engine, ("beans", "means", "keens"))
+
+        second_engine = engines.testing_engine(
+            asyncio=True,
+            options={"connect_args": {"prepared_statement_cache_size": 0}},
+        )
+
+        async with second_engine.connect() as conn:
+            await conn.execute(
+                t1.insert(),
+                [
+                    {"name": random.choice(("beans", "means", "keens"))}
+                    for i in range(10)
+                ],
+            )
+
+        await async_setup(async_engine, ("faux", "beau", "flow"))
+
+        async with second_engine.connect() as conn:
+            with testing.expect_raises_message(
+                exc.InternalError, "cache lookup failed for type"
+            ):
+                await conn.execute(
+                    t1.insert(),
+                    [
+                        {"name": random.choice(("faux", "beau", "flow"))}
+                        for i in range(10)
+                    ],
+                )
+
+        # works again
+        async with second_engine.connect() as conn:
+            await conn.execute(
+                t1.insert(),
+                [
+                    {"name": random.choice(("faux", "beau", "flow"))}
+                    for i in range(10)
+                ],
+            )
diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py
index 55a114409bd146c6bc1b837364f3ad356bfcbe9a..6239d1f18b86e3dd8d03f766db5cdb43e6b8d37d 100644 (file)
@@ -304,6 +304,9 @@ class ExecuteTest(fixtures.TablesTest):
         class NonStandardException(OperationalError):
             pass
 
+        # TODO: this test is assuming too much of arbitrary dialects and would
+        # be better suited tested against a single mock dialect that does not
+        # have any special behaviors
         with patch.object(
             testing.db.dialect, "dbapi", Mock(Error=DBAPIError)
         ), patch.object(
@@ -312,6 +315,10 @@ class ExecuteTest(fixtures.TablesTest):
             testing.db.dialect,
             "do_execute",
             Mock(side_effect=NonStandardException),
+        ), patch.object(
+            testing.db.dialect.execution_ctx_cls,
+            "handle_dbapi_exception",
+            Mock(),
         ):
             with testing.db.connect() as conn:
                 assert_raises(