From c955059ac8f40d4a2dbad6d376935c2675b74b7f Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Wed, 17 Nov 2021 14:15:20 +0100 Subject: [PATCH] Split the internal state of prepared statements into names and counts --- psycopg/psycopg/_preparing.py | 71 +++++++++++++++++++---------------- tests/test_prepared.py | 13 ++++--- tests/test_prepared_async.py | 16 ++++---- 3 files changed, 54 insertions(+), 46 deletions(-) diff --git a/psycopg/psycopg/_preparing.py b/psycopg/psycopg/_preparing.py index 5ae79ae0c..b5097c166 100644 --- a/psycopg/psycopg/_preparing.py +++ b/psycopg/psycopg/_preparing.py @@ -5,7 +5,7 @@ Support for prepared statements # Copyright (C) 2020-2021 The Psycopg Team from enum import IntEnum, auto -from typing import Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Optional, Sequence, Tuple, TYPE_CHECKING from collections import OrderedDict from .pq import ExecStatus @@ -22,7 +22,6 @@ class Prepare(IntEnum): Key = Tuple[bytes, Tuple[int, ...]] -Value = Union[int, bytes] class PrepareManager: @@ -33,13 +32,11 @@ class PrepareManager: prepared_max: int = 100 def __init__(self) -> None: - # Number of times each query was seen in order to prepare it. - # Map (query, types) -> name or number of times seen - # - # Note: with this implementation we keep the tally of up to 100 - # queries, but most likely we will prepare way less than that. We might - # change that if we think it would be better. - self._prepared: OrderedDict[Key, Value] = OrderedDict() + # Map (query, types) to the number of times the query was seen. + self._counts: OrderedDict[Key, int] = OrderedDict() + + # Map (query, types) to the name of the statement if prepared. + self._names: OrderedDict[Key, bytes] = OrderedDict() # Counter to generate prepared statements names self._prepared_idx = 0 @@ -59,12 +56,13 @@ class PrepareManager: return Prepare.NO, b"" key = self.key(query) - value: Union[bytes, int] = self._prepared.get(key, 0) - if isinstance(value, bytes): + name = self._names.get(key) + if name: # The query was already prepared in this session - return Prepare.YES, value + return Prepare.YES, name - if value >= self.prepare_threshold or prepare: + count = self._counts.get(key, 0) + if count >= self.prepare_threshold or prepare: # The query has been executed enough times and needs to be prepared name = f"_pg3_{self._prepared_idx}".encode() self._prepared_idx += 1 @@ -80,7 +78,7 @@ class PrepareManager: rollback or on dropping objects, because the same object may get recreated and postgres would fail internal lookups. """ - if self._prepared or prep == Prepare.SHOULD: + if self._names or prep == Prepare.SHOULD: for result in results: if result.status != ExecStatus.COMMAND_OK: continue @@ -111,12 +109,12 @@ class PrepareManager: If it was prepared, deallocate it. Do it only once: if the cache was resized, deallocate gradually. """ - if len(self._prepared) <= self.prepared_max: - return None + if len(self._counts) > self.prepared_max: + self._counts.popitem(last=False) - old_val = self._prepared.popitem(last=False)[1] - if isinstance(old_val, bytes): - return b"DEALLOCATE " + old_val + if len(self._names) > self.prepared_max: + name = self._names.popitem(last=False)[1] + return b"DEALLOCATE " + name else: return None @@ -135,17 +133,24 @@ class PrepareManager: return None key = self.key(query) - if key in self._prepared: - if isinstance(self._prepared[key], int): - if prep is Prepare.SHOULD: - self._prepared[key] = name - else: - self._prepared[key] += 1 # type: ignore[operator] - self._prepared.move_to_end(key) + if key in self._counts: + if prep is Prepare.SHOULD: + del self._counts[key] + self._names[key] = name + else: + self._counts[key] += 1 + self._counts.move_to_end(key) + return None + + elif key in self._names: + self._names.move_to_end(key) return None + else: - value: Value = name if prep is Prepare.SHOULD else 1 - self._prepared[key] = value + if prep is Prepare.SHOULD: + self._names[key] = name + else: + self._counts[key] = 1 return key def validate( @@ -164,15 +169,17 @@ class PrepareManager: cmd = self._should_discard(prep, results) if cmd: return cmd + if not self._check_results(results): - del self._prepared[key] + self._names.pop(key, None) + self._counts.pop(key, None) return None + return self._rotate() def clear(self) -> Optional[bytes]: - if self._prepared_idx: - self._prepared.clear() - self._prepared_idx = 0 + if self._names: + self._names.clear() return b"DEALLOCATE ALL" else: return None diff --git a/tests/test_prepared.py b/tests/test_prepared.py index ef7b7ee4e..fec835547 100644 --- a/tests/test_prepared.py +++ b/tests/test_prepared.py @@ -77,7 +77,8 @@ def test_prepare_disable(conn): res.append(cur.fetchone()[0]) assert res == [0] * 10 - assert not conn._prepared._prepared + assert not conn._prepared._names + assert not conn._prepared._counts def test_no_prepare_multi(conn): @@ -140,10 +141,10 @@ def test_evict_lru(conn): conn.execute("select 'a'") conn.execute(f"select {i}") - assert len(conn._prepared._prepared) == 5 - assert conn._prepared._prepared[b"select 'a'", ()] == b"_pg3_0" + assert len(conn._prepared._names) == 1 + assert conn._prepared._names[b"select 'a'", ()] == b"_pg3_0" for i in [9, 8, 7, 6]: - assert conn._prepared._prepared[f"select {i}".encode(), ()] == 1 + assert conn._prepared._counts[f"select {i}".encode(), ()] == 1 cur = conn.execute("select statement from pg_prepared_statements") assert cur.fetchall() == [("select 'a'",)] @@ -156,9 +157,9 @@ def test_evict_lru_deallocate(conn): conn.execute("select 'a'") conn.execute(f"select {i}") - assert len(conn._prepared._prepared) == 5 + assert len(conn._prepared._names) == 5 for j in [9, 8, 7, 6, "'a'"]: - name = conn._prepared._prepared[f"select {j}".encode(), ()] + name = conn._prepared._names[f"select {j}".encode(), ()] assert name.startswith(b"_pg3_") cur = conn.execute( diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py index b3477bc52..d6217b42c 100644 --- a/tests/test_prepared_async.py +++ b/tests/test_prepared_async.py @@ -83,7 +83,8 @@ async def test_prepare_disable(aconn): res.append((await cur.fetchone())[0]) assert res == [0] * 10 - assert not aconn._prepared._prepared + assert not aconn._prepared._names + assert not aconn._prepared._counts async def test_no_prepare_multi(aconn): @@ -148,10 +149,10 @@ async def test_evict_lru(aconn): await aconn.execute("select 'a'") await aconn.execute(f"select {i}") - assert len(aconn._prepared._prepared) == 5 - assert aconn._prepared._prepared[b"select 'a'", ()] == b"_pg3_0" + assert len(aconn._prepared._names) == 1 + assert aconn._prepared._names[b"select 'a'", ()] == b"_pg3_0" for i in [9, 8, 7, 6]: - assert aconn._prepared._prepared[f"select {i}".encode(), ()] == 1 + assert aconn._prepared._counts[f"select {i}".encode(), ()] == 1 cur = await aconn.execute("select statement from pg_prepared_statements") assert await cur.fetchall() == [("select 'a'",)] @@ -164,11 +165,10 @@ async def test_evict_lru_deallocate(aconn): await aconn.execute("select 'a'") await aconn.execute(f"select {i}") - assert len(aconn._prepared._prepared) == 5 + assert len(aconn._prepared._names) == 5 for j in [9, 8, 7, 6, "'a'"]: - assert aconn._prepared._prepared[ - f"select {j}".encode(), () - ].startswith(b"_pg3_") + name = aconn._prepared._names[f"select {j}".encode(), ()] + assert name.startswith(b"_pg3_") cur = await aconn.execute( "select statement from pg_prepared_statements order by prepare_time", -- 2.47.2