# Copyright (C) 2020-2021 The Psycopg Team
from enum import IntEnum, auto
-from typing import Optional, Sequence, Tuple, TYPE_CHECKING, Union
+from typing import Optional, Sequence, Tuple, TYPE_CHECKING
from collections import OrderedDict
from .pq import ExecStatus
Key = Tuple[bytes, Tuple[int, ...]]
-Value = Union[int, bytes]
class PrepareManager:
prepared_max: int = 100
def __init__(self) -> None:
- # Number of times each query was seen in order to prepare it.
- # Map (query, types) -> name or number of times seen
- #
- # Note: with this implementation we keep the tally of up to 100
- # queries, but most likely we will prepare way less than that. We might
- # change that if we think it would be better.
- self._prepared: OrderedDict[Key, Value] = OrderedDict()
+ # Map (query, types) to the number of times the query was seen.
+ self._counts: OrderedDict[Key, int] = OrderedDict()
+
+ # Map (query, types) to the name of the statement if prepared.
+ self._names: OrderedDict[Key, bytes] = OrderedDict()
# Counter to generate prepared statements names
self._prepared_idx = 0
return Prepare.NO, b""
key = self.key(query)
- value: Union[bytes, int] = self._prepared.get(key, 0)
- if isinstance(value, bytes):
+ name = self._names.get(key)
+ if name:
# The query was already prepared in this session
- return Prepare.YES, value
+ return Prepare.YES, name
- if value >= self.prepare_threshold or prepare:
+ count = self._counts.get(key, 0)
+ if count >= self.prepare_threshold or prepare:
# The query has been executed enough times and needs to be prepared
name = f"_pg3_{self._prepared_idx}".encode()
self._prepared_idx += 1
rollback or on dropping objects, because the same object may get
recreated and postgres would fail internal lookups.
"""
- if self._prepared or prep == Prepare.SHOULD:
+ if self._names or prep == Prepare.SHOULD:
for result in results:
if result.status != ExecStatus.COMMAND_OK:
continue
If it was prepared, deallocate it. Do it only once: if the cache was
resized, deallocate gradually.
"""
- if len(self._prepared) <= self.prepared_max:
- return None
+ if len(self._counts) > self.prepared_max:
+ self._counts.popitem(last=False)
- old_val = self._prepared.popitem(last=False)[1]
- if isinstance(old_val, bytes):
- return b"DEALLOCATE " + old_val
+ if len(self._names) > self.prepared_max:
+ name = self._names.popitem(last=False)[1]
+ return b"DEALLOCATE " + name
else:
return None
return None
key = self.key(query)
- if key in self._prepared:
- if isinstance(self._prepared[key], int):
- if prep is Prepare.SHOULD:
- self._prepared[key] = name
- else:
- self._prepared[key] += 1 # type: ignore[operator]
- self._prepared.move_to_end(key)
+ if key in self._counts:
+ if prep is Prepare.SHOULD:
+ del self._counts[key]
+ self._names[key] = name
+ else:
+ self._counts[key] += 1
+ self._counts.move_to_end(key)
+ return None
+
+ elif key in self._names:
+ self._names.move_to_end(key)
return None
+
else:
- value: Value = name if prep is Prepare.SHOULD else 1
- self._prepared[key] = value
+ if prep is Prepare.SHOULD:
+ self._names[key] = name
+ else:
+ self._counts[key] = 1
return key
def validate(
cmd = self._should_discard(prep, results)
if cmd:
return cmd
+
if not self._check_results(results):
- del self._prepared[key]
+ self._names.pop(key, None)
+ self._counts.pop(key, None)
return None
+
return self._rotate()
def clear(self) -> Optional[bytes]:
- if self._prepared_idx:
- self._prepared.clear()
- self._prepared_idx = 0
+ if self._names:
+ self._names.clear()
return b"DEALLOCATE ALL"
else:
return None
res.append(cur.fetchone()[0])
assert res == [0] * 10
- assert not conn._prepared._prepared
+ assert not conn._prepared._names
+ assert not conn._prepared._counts
def test_no_prepare_multi(conn):
conn.execute("select 'a'")
conn.execute(f"select {i}")
- assert len(conn._prepared._prepared) == 5
- assert conn._prepared._prepared[b"select 'a'", ()] == b"_pg3_0"
+ assert len(conn._prepared._names) == 1
+ assert conn._prepared._names[b"select 'a'", ()] == b"_pg3_0"
for i in [9, 8, 7, 6]:
- assert conn._prepared._prepared[f"select {i}".encode(), ()] == 1
+ assert conn._prepared._counts[f"select {i}".encode(), ()] == 1
cur = conn.execute("select statement from pg_prepared_statements")
assert cur.fetchall() == [("select 'a'",)]
conn.execute("select 'a'")
conn.execute(f"select {i}")
- assert len(conn._prepared._prepared) == 5
+ assert len(conn._prepared._names) == 5
for j in [9, 8, 7, 6, "'a'"]:
- name = conn._prepared._prepared[f"select {j}".encode(), ()]
+ name = conn._prepared._names[f"select {j}".encode(), ()]
assert name.startswith(b"_pg3_")
cur = conn.execute(
res.append((await cur.fetchone())[0])
assert res == [0] * 10
- assert not aconn._prepared._prepared
+ assert not aconn._prepared._names
+ assert not aconn._prepared._counts
async def test_no_prepare_multi(aconn):
await aconn.execute("select 'a'")
await aconn.execute(f"select {i}")
- assert len(aconn._prepared._prepared) == 5
- assert aconn._prepared._prepared[b"select 'a'", ()] == b"_pg3_0"
+ assert len(aconn._prepared._names) == 1
+ assert aconn._prepared._names[b"select 'a'", ()] == b"_pg3_0"
for i in [9, 8, 7, 6]:
- assert aconn._prepared._prepared[f"select {i}".encode(), ()] == 1
+ assert aconn._prepared._counts[f"select {i}".encode(), ()] == 1
cur = await aconn.execute("select statement from pg_prepared_statements")
assert await cur.fetchall() == [("select 'a'",)]
await aconn.execute("select 'a'")
await aconn.execute(f"select {i}")
- assert len(aconn._prepared._prepared) == 5
+ assert len(aconn._prepared._names) == 5
for j in [9, 8, 7, 6, "'a'"]:
- assert aconn._prepared._prepared[
- f"select {j}".encode(), ()
- ].startswith(b"_pg3_")
+ name = aconn._prepared._names[f"select {j}".encode(), ()]
+ assert name.startswith(b"_pg3_")
cur = await aconn.execute(
"select statement from pg_prepared_statements order by prepare_time",