From: Daniele Varrazzo Date: Mon, 4 Sep 2023 13:42:00 +0000 (+0100) Subject: fix: avoid caching the parsing of large queries X-Git-Tag: pool-3.2.0~55 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=082babf150914ab3fa77947d520d11a976d16739;p=thirdparty%2Fpsycopg.git fix: avoid caching the parsing of large queries These queries are typically generated by ORMs and have poor cacheability, but can result in a lot of memory being used. Close #628. See also . --- diff --git a/docs/news.rst b/docs/news.rst index d68780c2e..6b23c6cab 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -31,6 +31,8 @@ Psycopg 3.2 (unreleased) Psycopg 3.1.11 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ +- Avoid caching the parsing results of large queries to avoid excessive memory + usage (:ticket:`#628`). - Fix integer overflow in C/binary extension with OID > 2^31 (:ticket:`#630`). - Fix building on Solaris and derivatives (:ticket:`#632`). - Fix possible lack of critical section guard in async diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py index caf19a7cc..6da19fa76 100644 --- a/psycopg/psycopg/_queries.py +++ b/psycopg/psycopg/_queries.py @@ -5,9 +5,10 @@ Utility module to manipulate queries # Copyright (C) 2020 The Psycopg Team import re -from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional +from typing import Any, Callable, Dict, List, Mapping, Match, NamedTuple, Optional from typing import Sequence, Tuple, Union, TYPE_CHECKING from functools import lru_cache +from typing_extensions import TypeAlias from . import pq from . import errors as e @@ -20,6 +21,9 @@ from ._encodings import conn_encoding if TYPE_CHECKING: from .abc import Transformer +MAX_CACHED_STATEMENT_LENGTH = 4096 +MAX_CACHED_STATEMENT_PARAMS = 50 + class QueryPart(NamedTuple): pre: bytes @@ -68,12 +72,22 @@ class PostgresQuery: bquery = query if vars is not None: - ( - self.query, - self._want_formats, - self._order, - self._parts, - ) = self.query2pg(bquery, self._encoding) + # Avoid caching queries extremely long or with a huge number of + # parameters. They are usually generated by ORMs and have poor + # cacheablility (e.g. INSERT ... VALUES (...), (...) with varying + # numbers of tuples. + # see https://github.com/psycopg/psycopg/discussions/628 + if ( + len(bquery) <= MAX_CACHED_STATEMENT_LENGTH + and len(vars) <= MAX_CACHED_STATEMENT_PARAMS + ): + f: PostgresQuery._Query2Pg = PostgresQuery.query2pg + else: + f = PostgresQuery.query2pg_nocache + + (self.query, self._want_formats, self._order, self._parts) = f( + bquery, self._encoding + ) else: self.query = bquery self._want_formats = self._order = None @@ -88,10 +102,7 @@ class PostgresQuery: """ if vars is not None: params = self.validate_and_reorder_params(self._parts, vars, self._order) - num_params = len(params) - if self._want_formats is None: - self._want_formats = [PyFormat.AUTO] * num_params - assert len(self._want_formats) == num_params + assert self._want_formats is not None self.params = self._tx.dump_sequence(params, self._want_formats) self.types = self._tx.types or () self.formats = self._tx.formats @@ -154,9 +165,14 @@ class PostgresQuery: f" {', '.join(sorted(i for i in order or () if i not in vars))}" ) + # The type of the query2pg() and query2pg_nocache() methods + _Query2Pg: TypeAlias = Callable[ + [bytes, str], + Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]], + ] + @staticmethod - @lru_cache() - def query2pg( + def query2pg_nocache( query: bytes, encoding: str ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]: """ @@ -205,6 +221,8 @@ class PostgresQuery: return b"".join(chunks), formats, order, parts + query2pg = lru_cache()(query2pg_nocache) + class PostgresClientQuery(PostgresQuery): """ @@ -228,9 +246,20 @@ class PostgresClientQuery(PostgresQuery): bquery = query if vars is not None: - (self.template, _, self._order, self._parts) = PostgresClientQuery.query2pg( - bquery, self._encoding - ) + # Avoid caching queries extremely long or with a huge number of + # parameters. They are usually generated by ORMs and have poor + # cacheablility (e.g. INSERT ... VALUES (...), (...) with varying + # numbers of tuples. + # see https://github.com/psycopg/psycopg/discussions/628 + if ( + len(bquery) <= MAX_CACHED_STATEMENT_LENGTH + and len(vars) <= MAX_CACHED_STATEMENT_PARAMS + ): + f: PostgresQuery._Query2Pg = PostgresClientQuery.query2pg + else: + f = PostgresClientQuery.query2pg_nocache + + (self.template, _, self._order, self._parts) = f(bquery, self._encoding) else: self.query = bquery self._order = None @@ -253,8 +282,7 @@ class PostgresClientQuery(PostgresQuery): self.params = None @staticmethod - @lru_cache() - def query2pg( + def query2pg_nocache( query: bytes, encoding: str ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]: """ @@ -290,6 +318,8 @@ class PostgresClientQuery(PostgresQuery): return b"".join(chunks), None, order, parts + query2pg = lru_cache()(query2pg_nocache) + _re_placeholder = re.compile( rb"""(?x) diff --git a/psycopg/psycopg/client_cursor.py b/psycopg/psycopg/client_cursor.py index 6271ec508..77cdd4416 100644 --- a/psycopg/psycopg/client_cursor.py +++ b/psycopg/psycopg/client_cursor.py @@ -28,6 +28,8 @@ BINARY = pq.Format.BINARY class ClientCursorMixin(BaseCursor[ConnectionType, Row]): + _query_cls = PostgresClientQuery + def mogrify(self, query: Query, params: Optional[Params] = None) -> str: """ Return the query and parameters merged. @@ -72,13 +74,6 @@ class ClientCursorMixin(BaseCursor[ConnectionType, Row]): # as it can execute more than one statement in a single query. self._pgconn.send_query(query.query) - def _convert_query( - self, query: Query, params: Optional[Params] = None - ) -> PostgresQuery: - pgq = PostgresClientQuery(self._tx) - pgq.convert(query, params) - return pgq - def _get_prepared( self, pgq: PostgresQuery, prepare: Optional[bool] = None ) -> Tuple[Prepare, bytes]: diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 7b8395c6c..7353f558d 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -59,6 +59,7 @@ class BaseCursor(Generic[ConnectionType, Row]): _tx: "Transformer" _make_row: RowMaker[Row] _pgconn: "PGconn" + _query_cls: Type[PostgresQuery] = PostgresQuery def __init__(self, connection: ConnectionType): self._conn = connection @@ -450,7 +451,7 @@ class BaseCursor(Generic[ConnectionType, Row]): def _convert_query( self, query: Query, params: Optional[Params] = None ) -> PostgresQuery: - pgq = PostgresQuery(self._tx) + pgq = self._query_cls(self._tx) pgq.convert(query, params) return pgq diff --git a/psycopg/psycopg/raw_cursor.py b/psycopg/psycopg/raw_cursor.py index 9bc3164d4..067f66b09 100644 --- a/psycopg/psycopg/raw_cursor.py +++ b/psycopg/psycopg/raw_cursor.py @@ -4,52 +4,61 @@ psycopg raw queries cursors # Copyright (C) 2023 The Psycopg Team -from typing import Any, Optional, Sequence, Tuple, List, TYPE_CHECKING +from typing import Optional, List, Tuple, TYPE_CHECKING from functools import lru_cache -from ._queries import PostgresQuery, QueryPart - from .abc import ConnectionType, Query, Params +from .sql import Composable from .rows import Row -from .cursor import BaseCursor, Cursor from ._enums import PyFormat +from .cursor import BaseCursor, Cursor from .cursor_async import AsyncCursor +from ._queries import PostgresQuery, QueryPart if TYPE_CHECKING: + from typing import Any # noqa: F401 from .connection import Connection # noqa: F401 from .connection_async import AsyncConnection # noqa: F401 -class RawPostgresQuery(PostgresQuery): +class PostgresRawQuery(PostgresQuery): + def convert(self, query: Query, vars: Optional[Params]) -> None: + if isinstance(query, str): + bquery = query.encode(self._encoding) + elif isinstance(query, Composable): + bquery = query.as_bytes(self._tx) + else: + bquery = query + + self.query = bquery + self._want_formats = self._order = None + self.dump(vars) + + def dump(self, vars: Optional[Params]) -> None: + if vars is not None: + if not PostgresQuery.is_params_sequence(vars): + raise TypeError("raw queries require a sequence of parameters") + self._want_formats = [PyFormat.AUTO] * len(vars) + + self.params = self._tx.dump_sequence(vars, self._want_formats) + self.types = self._tx.types or () + self.formats = self._tx.formats + else: + self.params = None + self.types = () + self.formats = None + @staticmethod - @lru_cache() - def query2pg( + def query2pg_nocache( query: bytes, encoding: str ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]: - """ - Noop; Python raw query is already in the format Postgres understands. - """ - return query, None, None, [] + raise NotImplementedError() - @staticmethod - def validate_and_reorder_params( - parts: List[QueryPart], vars: Params, order: Optional[List[str]] - ) -> Sequence[Any]: - """ - Verify the compatibility; params must be a sequence for raw query. - """ - if not PostgresQuery.is_params_sequence(vars): - raise TypeError("raw query require a sequence of parameters") - return vars + query2pg = lru_cache()(query2pg_nocache) class RawCursorMixin(BaseCursor[ConnectionType, Row]): - def _convert_query( - self, query: Query, params: Optional[Params] = None - ) -> PostgresQuery: - pgq = RawPostgresQuery(self._tx) - pgq.convert(query, params) - return pgq + _query_cls = PostgresRawQuery class RawCursor(RawCursorMixin["Connection[Any]", Row], Cursor[Row]): diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 3fb15fc25..a54f2f7b4 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -169,6 +169,36 @@ def test_execute_sql(conn): assert cur.fetchone() == ("hello",) +def test_query_parse_cache_size(conn): + cur = conn.cursor() + query_cls = cur._query_cls + + # Warning: testing internal structures. Test have to be refactored. + query_cls.query2pg.cache_clear() + ci = query_cls.query2pg.cache_info() + h0, m0 = ci.hits, ci.misses + tests = [ + (f"select 1 -- {'x' * 3500}", (), h0, m0 + 1), + (f"select 1 -- {'x' * 3500}", (), h0 + 1, m0 + 1), + (f"select 1 -- {'x' * 4500}", (), h0 + 1, m0 + 1), + (f"select 1 -- {'x' * 4500}", (), h0 + 1, m0 + 1), + (f"select 1 -- {'%s' * 40}", ("x",) * 40, h0 + 1, m0 + 2), + (f"select 1 -- {'%s' * 40}", ("x",) * 40, h0 + 2, m0 + 2), + (f"select 1 -- {'%s' * 60}", ("x",) * 60, h0 + 2, m0 + 2), + (f"select 1 -- {'%s' * 60}", ("x",) * 60, h0 + 2, m0 + 2), + ] + for i, (query, params, hits, misses) in enumerate(tests): + pq = query_cls(psycopg.adapt.Transformer()) + pq.convert(query, params) + ci = query_cls.query2pg.cache_info() + if not isinstance(cur, psycopg.RawCursor): + assert ci.hits == hits, f"at {i}" + assert ci.misses == misses, f"at {i}" + else: + assert ci.hits == 0, f"at {i}" + assert ci.misses == 0, f"at {i}" + + def test_execute_many_results(conn): cur = conn.cursor() assert cur.nextset() is None