]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: convert back query2pg from static method to regular function
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 11 Sep 2023 00:36:26 +0000 (01:36 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 11 Sep 2023 01:02:55 +0000 (02:02 +0100)
On Python < 3.10, a static method doesn't seem callable, at least during
class building, so applying lru_cache on it fails.

psycopg/psycopg/_queries.py
psycopg/psycopg/raw_cursor.py
tests/test_cursor.py

index 6da19fa76ae92b002e9af9c61bb332010f7abaf4..376012aec7bd9a4c53e22c193e18aea71d8afedd 100644 (file)
@@ -81,9 +81,9 @@ class PostgresQuery:
                 len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
                 and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
             ):
-                f: PostgresQuery._Query2Pg = PostgresQuery.query2pg
+                f: _Query2Pg = _query2pg
             else:
-                f = PostgresQuery.query2pg_nocache
+                f = _query2pg_nocache
 
             (self.query, self._want_formats, self._order, self._parts) = f(
                 bquery, self._encoding
@@ -165,63 +165,69 @@ class PostgresQuery:
                     f" {', '.join(sorted(i for i in order or () if i not in vars))}"
                 )
 
-    # The type of the query2pg() and query2pg_nocache() methods
-    _Query2Pg: TypeAlias = Callable[
-        [bytes, str],
-        Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]],
-    ]
 
-    @staticmethod
-    def query2pg_nocache(
-        query: bytes, encoding: str
-    ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]:
-        """
-        Convert Python query and params into something Postgres understands.
-
-        - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
-        format (``$1``, ``$2``)
-        - placeholders can be %s, %t, or %b (auto, text or binary)
-        - return ``query`` (bytes), ``formats`` (list of formats) ``order``
-        (sequence of names used in the query, in the position they appear)
-        ``parts`` (splits of queries and placeholders).
-        """
-        parts = _split_query(query, encoding)
-        order: Optional[List[str]] = None
-        chunks: List[bytes] = []
-        formats = []
-
-        if isinstance(parts[0].item, int):
-            for part in parts[:-1]:
-                assert isinstance(part.item, int)
-                chunks.append(part.pre)
-                chunks.append(b"$%d" % (part.item + 1))
+# The type of the _query2pg() and _query2pg_nocache() methods
+_Query2Pg: TypeAlias = Callable[
+    [bytes, str], Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]
+]
+
+
+def _query2pg_nocache(
+    query: bytes, encoding: str
+) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]:
+    """
+    Convert Python query and params into something Postgres understands.
+
+    - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
+      format (``$1``, ``$2``)
+    - placeholders can be %s, %t, or %b (auto, text or binary)
+    - return ``query`` (bytes), ``formats`` (list of formats) ``order``
+      (sequence of names used in the query, in the position they appear)
+      ``parts`` (splits of queries and placeholders).
+    """
+    parts = _split_query(query, encoding)
+    order: Optional[List[str]] = None
+    chunks: List[bytes] = []
+    formats = []
+
+    if isinstance(parts[0].item, int):
+        for part in parts[:-1]:
+            assert isinstance(part.item, int)
+            chunks.append(part.pre)
+            chunks.append(b"$%d" % (part.item + 1))
+            formats.append(part.format)
+
+    elif isinstance(parts[0].item, str):
+        seen: Dict[str, Tuple[bytes, PyFormat]] = {}
+        order = []
+        for part in parts[:-1]:
+            assert isinstance(part.item, str)
+            chunks.append(part.pre)
+            if part.item not in seen:
+                ph = b"$%d" % (len(seen) + 1)
+                seen[part.item] = (ph, part.format)
+                order.append(part.item)
+                chunks.append(ph)
                 formats.append(part.format)
+            else:
+                if seen[part.item][1] != part.format:
+                    raise e.ProgrammingError(
+                        f"placeholder '{part.item}' cannot have different formats"
+                    )
+                chunks.append(seen[part.item][0])
 
-        elif isinstance(parts[0].item, str):
-            seen: Dict[str, Tuple[bytes, PyFormat]] = {}
-            order = []
-            for part in parts[:-1]:
-                assert isinstance(part.item, str)
-                chunks.append(part.pre)
-                if part.item not in seen:
-                    ph = b"$%d" % (len(seen) + 1)
-                    seen[part.item] = (ph, part.format)
-                    order.append(part.item)
-                    chunks.append(ph)
-                    formats.append(part.format)
-                else:
-                    if seen[part.item][1] != part.format:
-                        raise e.ProgrammingError(
-                            f"placeholder '{part.item}' cannot have different formats"
-                        )
-                    chunks.append(seen[part.item][0])
+    # last part
+    chunks.append(parts[-1].pre)
 
-        # last part
-        chunks.append(parts[-1].pre)
+    return b"".join(chunks), formats, order, parts
 
-        return b"".join(chunks), formats, order, parts
 
-    query2pg = lru_cache()(query2pg_nocache)
+# Note: the cache size is 128 items, but someone has reported throwing ~12k
+# queries (of type `INSERT ... VALUES (...), (...)` with a varying amount of
+# records), and the resulting cache size is >100Mb. So, we will avoid to cache
+# large queries or queries with a large number of params. See
+# https://github.com/sqlalchemy/sqlalchemy/discussions/10270
+_query2pg = lru_cache()(_query2pg_nocache)
 
 
 class PostgresClientQuery(PostgresQuery):
@@ -246,20 +252,15 @@ class PostgresClientQuery(PostgresQuery):
             bquery = query
 
         if vars is not None:
-            # Avoid caching queries extremely long or with a huge number of
-            # parameters. They are usually generated by ORMs and have poor
-            # cacheablility (e.g. INSERT ... VALUES (...), (...) with varying
-            # numbers of tuples.
-            # see https://github.com/psycopg/psycopg/discussions/628
             if (
                 len(bquery) <= MAX_CACHED_STATEMENT_LENGTH
                 and len(vars) <= MAX_CACHED_STATEMENT_PARAMS
             ):
-                f: PostgresQuery._Query2Pg = PostgresClientQuery.query2pg
+                f: _Query2PgClient = _query2pg_client
             else:
-                f = PostgresClientQuery.query2pg_nocache
+                f = _query2pg_client_nocache
 
-            (self.template, _, self._order, self._parts) = f(bquery, self._encoding)
+            (self.template, self._order, self._parts) = f(bquery, self._encoding)
         else:
             self.query = bquery
             self._order = None
@@ -281,44 +282,50 @@ class PostgresClientQuery(PostgresQuery):
         else:
             self.params = None
 
-    @staticmethod
-    def query2pg_nocache(
-        query: bytes, encoding: str
-    ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]:
-        """
-        Convert Python query and params into a template to perform client-side binding
-        """
-        parts = _split_query(query, encoding, collapse_double_percent=False)
-        order: Optional[List[str]] = None
-        chunks: List[bytes] = []
-
-        if isinstance(parts[0].item, int):
-            for part in parts[:-1]:
-                assert isinstance(part.item, int)
-                chunks.append(part.pre)
-                chunks.append(b"%s")
-
-        elif isinstance(parts[0].item, str):
-            seen: Dict[str, Tuple[bytes, PyFormat]] = {}
-            order = []
-            for part in parts[:-1]:
-                assert isinstance(part.item, str)
-                chunks.append(part.pre)
-                if part.item not in seen:
-                    ph = b"%s"
-                    seen[part.item] = (ph, part.format)
-                    order.append(part.item)
-                    chunks.append(ph)
-                else:
-                    chunks.append(seen[part.item][0])
-                    order.append(part.item)
 
-        # last part
-        chunks.append(parts[-1].pre)
+_Query2PgClient: TypeAlias = Callable[
+    [bytes, str], Tuple[bytes, Optional[List[str]], List[QueryPart]]
+]
+
+
+def _query2pg_client_nocache(
+    query: bytes, encoding: str
+) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]:
+    """
+    Convert Python query and params into a template to perform client-side binding
+    """
+    parts = _split_query(query, encoding, collapse_double_percent=False)
+    order: Optional[List[str]] = None
+    chunks: List[bytes] = []
+
+    if isinstance(parts[0].item, int):
+        for part in parts[:-1]:
+            assert isinstance(part.item, int)
+            chunks.append(part.pre)
+            chunks.append(b"%s")
+
+    elif isinstance(parts[0].item, str):
+        seen: Dict[str, Tuple[bytes, PyFormat]] = {}
+        order = []
+        for part in parts[:-1]:
+            assert isinstance(part.item, str)
+            chunks.append(part.pre)
+            if part.item not in seen:
+                ph = b"%s"
+                seen[part.item] = (ph, part.format)
+                order.append(part.item)
+                chunks.append(ph)
+            else:
+                chunks.append(seen[part.item][0])
+                order.append(part.item)
+
+    # last part
+    chunks.append(parts[-1].pre)
+
+    return b"".join(chunks), order, parts
 
-        return b"".join(chunks), None, order, parts
 
-    query2pg = lru_cache()(query2pg_nocache)
+_query2pg_client = lru_cache()(_query2pg_client_nocache)
 
 
 _re_placeholder = re.compile(
index 067f66b09d01712b867c89b6c2c3732b27221d2d..7fc4090bf34685777329ef83cf719690a26121d7 100644 (file)
@@ -4,8 +4,7 @@ psycopg raw queries cursors
 
 # Copyright (C) 2023 The Psycopg Team
 
-from typing import Optional, List, Tuple, TYPE_CHECKING
-from functools import lru_cache
+from typing import Optional, TYPE_CHECKING
 
 from .abc import ConnectionType, Query, Params
 from .sql import Composable
@@ -13,7 +12,7 @@ from .rows import Row
 from ._enums import PyFormat
 from .cursor import BaseCursor, Cursor
 from .cursor_async import AsyncCursor
-from ._queries import PostgresQuery, QueryPart
+from ._queries import PostgresQuery
 
 if TYPE_CHECKING:
     from typing import Any  # noqa: F401
@@ -48,14 +47,6 @@ class PostgresRawQuery(PostgresQuery):
             self.types = ()
             self.formats = None
 
-    @staticmethod
-    def query2pg_nocache(
-        query: bytes, encoding: str
-    ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]:
-        raise NotImplementedError()
-
-    query2pg = lru_cache()(query2pg_nocache)
-
 
 class RawCursorMixin(BaseCursor[ConnectionType, Row]):
     _query_cls = PostgresRawQuery
index a54f2f7b405c8285c224a2006eab89f98f5d321e..2e607ec54e63cd3a8b83912280f796c25c3e6a11 100644 (file)
@@ -171,11 +171,21 @@ def test_execute_sql(conn):
 
 def test_query_parse_cache_size(conn):
     cur = conn.cursor()
-    query_cls = cur._query_cls
+    cls = type(cur)
+
+    # Warning: testing internal structures. Test might need refactoring with the code.
+    cache: Any
+    if cls is psycopg.Cursor:
+        cache = psycopg._queries._query2pg
+    elif cls is psycopg.ClientCursor:
+        cache = psycopg._queries._query2pg_client
+    elif cls is psycopg.RawCursor:
+        pytest.skip("RawCursor has no query parse cache")
+    else:
+        assert False, cls
 
-    # Warning: testing internal structures. Test have to be refactored.
-    query_cls.query2pg.cache_clear()
-    ci = query_cls.query2pg.cache_info()
+    cache.cache_clear()
+    ci = cache.cache_info()
     h0, m0 = ci.hits, ci.misses
     tests = [
         (f"select 1 -- {'x' * 3500}", (), h0, m0 + 1),
@@ -188,15 +198,11 @@ def test_query_parse_cache_size(conn):
         (f"select 1 -- {'%s' * 60}", ("x",) * 60, h0 + 2, m0 + 2),
     ]
     for i, (query, params, hits, misses) in enumerate(tests):
-        pq = query_cls(psycopg.adapt.Transformer())
+        pq = cur._query_cls(psycopg.adapt.Transformer())
         pq.convert(query, params)
-        ci = query_cls.query2pg.cache_info()
-        if not isinstance(cur, psycopg.RawCursor):
-            assert ci.hits == hits, f"at {i}"
-            assert ci.misses == misses, f"at {i}"
-        else:
-            assert ci.hits == 0, f"at {i}"
-            assert ci.misses == 0, f"at {i}"
+        ci = cache.cache_info()
+        assert ci.hits == hits, f"at {i}"
+        assert ci.misses == misses, f"at {i}"
 
 
 def test_execute_many_results(conn):