]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: avoid caching the parsing of large queries
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 4 Sep 2023 13:42:00 +0000 (14:42 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 11 Sep 2023 00:18:10 +0000 (01:18 +0100)
These queries are typically generated by ORMs and have poor
cacheability, but can result in a lot of memory being used.

Close #628.

See also <https://github.com/sqlalchemy/sqlalchemy/discussions/10270>.

docs/news.rst
psycopg/psycopg/_queries.py
psycopg/psycopg/client_cursor.py
psycopg/psycopg/cursor.py
psycopg/psycopg/raw_cursor.py
tests/test_cursor.py

index d68780c2e7b7bfc7d81132961902b6127a0b7932..6b23c6cabdbdd6f3aeb73ee731939d9733e046f1 100644 (file)
@@ -31,6 +31,8 @@ Psycopg 3.2 (unreleased)
 Psycopg 3.1.11 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
+- Avoid caching the parsing results of large queries to avoid excessive memory
+  usage (:ticket:`#628`).
 - Fix integer overflow in C/binary extension with OID > 2^31 (:ticket:`#630`).
 - Fix building on Solaris and derivatives (:ticket:`#632`).
 - Fix possible lack of critical section guard in async
index caf19a7cc3b9a8b87ca209649fe4b256cdcfc9fb..6da19fa76ae92b002e9af9c61bb332010f7abaf4 100644 (file)
@@ -5,9 +5,10 @@ Utility module to manipulate queries
 # Copyright (C) 2020 The Psycopg Team
 
 import re
-from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional
+from typing import Any, Callable, Dict, List, Mapping, Match, NamedTuple, Optional
 from typing import Sequence, Tuple, Union, TYPE_CHECKING
 from functools import lru_cache
+from typing_extensions import TypeAlias
 
 from . import pq
 from . import errors as e
@@ -20,6 +21,9 @@ from ._encodings import conn_encoding
 if TYPE_CHECKING:
     from .abc import Transformer
 
+MAX_CACHED_STATEMENT_LENGTH = 4096
+MAX_CACHED_STATEMENT_PARAMS = 50
+
 
 class QueryPart(NamedTuple):
     pre: bytes
@@ -68,12 +72,22 @@ class PostgresQuery:
             bquery = query
 
         if vars is not None:
-            (
-                self.query,
-                self._want_formats,
-                self._order,
-                self._parts,
-            ) = self.query2pg(bquery, self._encoding)
+            # Avoid caching queries extremely long or with a huge number of
+            # parameters. They are usually generated by ORMs and have poor
+            # cacheablility (e.g. INSERT ... VALUES (...), (...) with varying
+            # numbers of tuples.
+            # see https://github.com/psycopg/psycopg/discussions/628
+            if (
+                len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
+                and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
+            ):
+                f: PostgresQuery._Query2Pg = PostgresQuery.query2pg
+            else:
+                f = PostgresQuery.query2pg_nocache
+
+            (self.query, self._want_formats, self._order, self._parts) = f(
+                bquery, self._encoding
+            )
         else:
             self.query = bquery
             self._want_formats = self._order = None
@@ -88,10 +102,7 @@ class PostgresQuery:
         """
         if vars is not None:
             params = self.validate_and_reorder_params(self._parts, vars, self._order)
-            num_params = len(params)
-            if self._want_formats is None:
-                self._want_formats = [PyFormat.AUTO] * num_params
-            assert len(self._want_formats) == num_params
+            assert self._want_formats is not None
             self.params = self._tx.dump_sequence(params, self._want_formats)
             self.types = self._tx.types or ()
             self.formats = self._tx.formats
@@ -154,9 +165,14 @@ class PostgresQuery:
                     f" {', '.join(sorted(i for i in order or () if i not in vars))}"
                 )
 
+    # The type of the query2pg() and query2pg_nocache() methods
+    _Query2Pg: TypeAlias = Callable[
+        [bytes, str],
+        Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]],
+    ]
+
     @staticmethod
-    @lru_cache()
-    def query2pg(
+    def query2pg_nocache(
         query: bytes, encoding: str
     ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]:
         """
@@ -205,6 +221,8 @@ class PostgresQuery:
 
         return b"".join(chunks), formats, order, parts
 
+    query2pg = lru_cache()(query2pg_nocache)
+
 
 class PostgresClientQuery(PostgresQuery):
     """
@@ -228,9 +246,20 @@ class PostgresClientQuery(PostgresQuery):
             bquery = query
 
         if vars is not None:
-            (self.template, _, self._order, self._parts) = PostgresClientQuery.query2pg(
-                bquery, self._encoding
-            )
+            # Avoid caching queries extremely long or with a huge number of
+            # parameters. They are usually generated by ORMs and have poor
+            # cacheablility (e.g. INSERT ... VALUES (...), (...) with varying
+            # numbers of tuples.
+            # see https://github.com/psycopg/psycopg/discussions/628
+            if (
+                len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
+                and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
+            ):
+                f: PostgresQuery._Query2Pg = PostgresClientQuery.query2pg
+            else:
+                f = PostgresClientQuery.query2pg_nocache
+
+            (self.template, _, self._order, self._parts) = f(bquery, self._encoding)
         else:
             self.query = bquery
             self._order = None
@@ -253,8 +282,7 @@ class PostgresClientQuery(PostgresQuery):
             self.params = None
 
     @staticmethod
-    @lru_cache()
-    def query2pg(
+    def query2pg_nocache(
         query: bytes, encoding: str
     ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]:
         """
@@ -290,6 +318,8 @@ class PostgresClientQuery(PostgresQuery):
 
         return b"".join(chunks), None, order, parts
 
+    query2pg = lru_cache()(query2pg_nocache)
+
 
 _re_placeholder = re.compile(
     rb"""(?x)
index 6271ec5086554038cf72566c3b5c6c73a16685da..77cdd4416abcc8ab135a4dff9f1d73e4dc7c4c52 100644 (file)
@@ -28,6 +28,8 @@ BINARY = pq.Format.BINARY
 
 
 class ClientCursorMixin(BaseCursor[ConnectionType, Row]):
+    _query_cls = PostgresClientQuery
+
     def mogrify(self, query: Query, params: Optional[Params] = None) -> str:
         """
         Return the query and parameters merged.
@@ -72,13 +74,6 @@ class ClientCursorMixin(BaseCursor[ConnectionType, Row]):
             # as it can execute more than one statement in a single query.
             self._pgconn.send_query(query.query)
 
-    def _convert_query(
-        self, query: Query, params: Optional[Params] = None
-    ) -> PostgresQuery:
-        pgq = PostgresClientQuery(self._tx)
-        pgq.convert(query, params)
-        return pgq
-
     def _get_prepared(
         self, pgq: PostgresQuery, prepare: Optional[bool] = None
     ) -> Tuple[Prepare, bytes]:
index 7b8395c6cde0a8cb08f9d64501f55c21cc169e05..7353f558d356164b185ce163f36f780c58b9387e 100644 (file)
@@ -59,6 +59,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
     _tx: "Transformer"
     _make_row: RowMaker[Row]
     _pgconn: "PGconn"
+    _query_cls: Type[PostgresQuery] = PostgresQuery
 
     def __init__(self, connection: ConnectionType):
         self._conn = connection
@@ -450,7 +451,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
     def _convert_query(
         self, query: Query, params: Optional[Params] = None
     ) -> PostgresQuery:
-        pgq = PostgresQuery(self._tx)
+        pgq = self._query_cls(self._tx)
         pgq.convert(query, params)
         return pgq
 
index 9bc3164d4b19859baa9fc5c301b64f95ca421168..067f66b09d01712b867c89b6c2c3732b27221d2d 100644 (file)
@@ -4,52 +4,61 @@ psycopg raw queries cursors
 
 # Copyright (C) 2023 The Psycopg Team
 
-from typing import Any, Optional, Sequence, Tuple, List, TYPE_CHECKING
+from typing import Optional, List, Tuple, TYPE_CHECKING
 from functools import lru_cache
 
-from ._queries import PostgresQuery, QueryPart
-
 from .abc import ConnectionType, Query, Params
+from .sql import Composable
 from .rows import Row
-from .cursor import BaseCursor, Cursor
 from ._enums import PyFormat
+from .cursor import BaseCursor, Cursor
 from .cursor_async import AsyncCursor
+from ._queries import PostgresQuery, QueryPart
 
 if TYPE_CHECKING:
+    from typing import Any  # noqa: F401
     from .connection import Connection  # noqa: F401
     from .connection_async import AsyncConnection  # noqa: F401
 
 
-class RawPostgresQuery(PostgresQuery):
+class PostgresRawQuery(PostgresQuery):
+    def convert(self, query: Query, vars: Optional[Params]) -> None:
+        if isinstance(query, str):
+            bquery = query.encode(self._encoding)
+        elif isinstance(query, Composable):
+            bquery = query.as_bytes(self._tx)
+        else:
+            bquery = query
+
+        self.query = bquery
+        self._want_formats = self._order = None
+        self.dump(vars)
+
+    def dump(self, vars: Optional[Params]) -> None:
+        if vars is not None:
+            if not PostgresQuery.is_params_sequence(vars):
+                raise TypeError("raw queries require a sequence of parameters")
+            self._want_formats = [PyFormat.AUTO] * len(vars)
+
+            self.params = self._tx.dump_sequence(vars, self._want_formats)
+            self.types = self._tx.types or ()
+            self.formats = self._tx.formats
+        else:
+            self.params = None
+            self.types = ()
+            self.formats = None
+
     @staticmethod
-    @lru_cache()
-    def query2pg(
+    def query2pg_nocache(
         query: bytes, encoding: str
     ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]:
-        """
-        Noop; Python raw query is already in the format Postgres understands.
-        """
-        return query, None, None, []
+        raise NotImplementedError()
 
-    @staticmethod
-    def validate_and_reorder_params(
-        parts: List[QueryPart], vars: Params, order: Optional[List[str]]
-    ) -> Sequence[Any]:
-        """
-        Verify the compatibility; params must be a sequence for raw query.
-        """
-        if not PostgresQuery.is_params_sequence(vars):
-            raise TypeError("raw query require a sequence of parameters")
-        return vars
+    query2pg = lru_cache()(query2pg_nocache)
 
 
 class RawCursorMixin(BaseCursor[ConnectionType, Row]):
-    def _convert_query(
-        self, query: Query, params: Optional[Params] = None
-    ) -> PostgresQuery:
-        pgq = RawPostgresQuery(self._tx)
-        pgq.convert(query, params)
-        return pgq
+    _query_cls = PostgresRawQuery
 
 
 class RawCursor(RawCursorMixin["Connection[Any]", Row], Cursor[Row]):
index 3fb15fc250499758772f90fcd7909fc783de37cf..a54f2f7b405c8285c224a2006eab89f98f5d321e 100644 (file)
@@ -169,6 +169,36 @@ def test_execute_sql(conn):
     assert cur.fetchone() == ("hello",)
 
 
+def test_query_parse_cache_size(conn):
+    cur = conn.cursor()
+    query_cls = cur._query_cls
+
+    # Warning: testing internal structures. Test have to be refactored.
+    query_cls.query2pg.cache_clear()
+    ci = query_cls.query2pg.cache_info()
+    h0, m0 = ci.hits, ci.misses
+    tests = [
+        (f"select 1 -- {'x' * 3500}", (), h0, m0 + 1),
+        (f"select 1 -- {'x' * 3500}", (), h0 + 1, m0 + 1),
+        (f"select 1 -- {'x' * 4500}", (), h0 + 1, m0 + 1),
+        (f"select 1 -- {'x' * 4500}", (), h0 + 1, m0 + 1),
+        (f"select 1 -- {'%s' * 40}", ("x",) * 40, h0 + 1, m0 + 2),
+        (f"select 1 -- {'%s' * 40}", ("x",) * 40, h0 + 2, m0 + 2),
+        (f"select 1 -- {'%s' * 60}", ("x",) * 60, h0 + 2, m0 + 2),
+        (f"select 1 -- {'%s' * 60}", ("x",) * 60, h0 + 2, m0 + 2),
+    ]
+    for i, (query, params, hits, misses) in enumerate(tests):
+        pq = query_cls(psycopg.adapt.Transformer())
+        pq.convert(query, params)
+        ci = query_cls.query2pg.cache_info()
+        if not isinstance(cur, psycopg.RawCursor):
+            assert ci.hits == hits, f"at {i}"
+            assert ci.misses == misses, f"at {i}"
+        else:
+            assert ci.hits == 0, f"at {i}"
+            assert ci.misses == 0, f"at {i}"
+
+
 def test_execute_many_results(conn):
     cur = conn.cursor()
     assert cur.nextset() is None