From: Daniele Varrazzo Date: Wed, 23 Dec 2020 02:35:00 +0000 (+0100) Subject: Added first implementation of prepared statements support X-Git-Tag: 3.0.dev0~253^2~4 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=e0e2d80cf25c1bd6be553bf11a3860120141b8e8;p=thirdparty%2Fpsycopg.git Added first implementation of prepared statements support --- diff --git a/psycopg3/psycopg3/_queries.py b/psycopg3/psycopg3/_queries.py index 15085332d..b7067d472 100644 --- a/psycopg3/psycopg3/_queries.py +++ b/psycopg3/psycopg3/_queries.py @@ -5,9 +5,9 @@ Utility module to manipulate queries # Copyright (C) 2020 The Psycopg Team import re -from functools import lru_cache from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional from typing import Sequence, Tuple, Union, TYPE_CHECKING +from functools import lru_cache from . import errors as e from .pq import Format @@ -30,12 +30,13 @@ class PostgresQuery: """ _parts: List[QueryPart] + _query = b"" def __init__(self, transformer: "Transformer"): self._tx = transformer - self.query: bytes = b"" self.params: Optional[List[Optional[bytes]]] = None - self.types: Optional[List[int]] = None + # these are tuples so they can be used as keys e.g. in prepared stmts + self.types: Tuple[int, ...] = () self.formats: Optional[List[Format]] = None self._order: Optional[List[str]] = None @@ -74,7 +75,7 @@ class PostgresQuery: ) assert self.formats is not None ps = self.params = [] - ts = self.types = [] + ts = [] for i in range(len(params)): param = params[i] if param is not None: @@ -84,8 +85,10 @@ class PostgresQuery: else: ps.append(None) ts.append(0) + self.types = tuple(ts) else: - self.params = self.types = None + self.params = None + self.types = () @lru_cache() diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 9af0bc691..60a05a474 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -10,10 +10,11 @@ import logging import threading from types import TracebackType from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple -from typing import Optional, Type, TYPE_CHECKING, TypeVar +from typing import Optional, Tuple, Type, TYPE_CHECKING, TypeVar, Union from weakref import ref, ReferenceType from functools import partial from contextlib import contextmanager +from collections import OrderedDict if sys.version_info >= (3, 7): from contextlib import asynccontextmanager @@ -70,11 +71,10 @@ class Notify(NamedTuple): Notify.__module__ = "psycopg3" +C = TypeVar("C", bound="BaseConnection") NoticeHandler = Callable[[e.Diagnostic], None] NotifyHandler = Callable[[Notify], None] -C = TypeVar("C", bound="BaseConnection") - class BaseConnection: """ @@ -102,6 +102,20 @@ class BaseConnection: cursor_factory: Type["BaseCursor[Any]"] + prepare_threshold: Optional[int] = 5 + """ + Number of times a query is executed before it is prepared. + + `!None` to disable preparing queries automatically. + """ + + prepared_max = 100 + """ + Maximum number of prepared statements on the connection. + + If more are prepared, the least used are deallocated. + """ + def __init__(self, pgconn: "PGconn"): self.pgconn = pgconn # TODO: document this self._autocommit = False @@ -115,6 +129,15 @@ class BaseConnection: # only a begin/commit and not a savepoint. self._savepoints: List[str] = [] + # Number of times each query was seen in order to prepare it. 
+ # Map (query, types) -> name or number of times seen + self._prepared_statements: OrderedDict[ + Tuple[bytes, Tuple[int, ...]], Union[int, bytes] + ] = OrderedDict() + + # Counter to generate prepared statements names + self._prepared_idx = 0 + wself = ref(self) pgconn.notice_handler = partial(BaseConnection._notice_handler, wself) @@ -386,11 +409,14 @@ class Connection(BaseConnection): return self.cursor_factory(self, format=format) def execute( - self, query: Query, params: Optional[Params] = None + self, + query: Query, + params: Optional[Params] = None, + prepare: Optional[bool] = None, ) -> "Cursor": """Execute a query and return a cursor to read its results.""" cur = self.cursor() - return cur.execute(query, params) + return cur.execute(query, params, prepare=prepare) def commit(self) -> None: """Commit any pending transaction to the database.""" @@ -511,10 +537,13 @@ class AsyncConnection(BaseConnection): return self.cursor_factory(self, format=format) async def execute( - self, query: Query, params: Optional[Params] = None + self, + query: Query, + params: Optional[Params] = None, + prepare: Optional[bool] = None, ) -> "AsyncCursor": cur = await self.cursor() - return await cur.execute(query, params) + return await cur.execute(query, params, prepare=prepare) async def commit(self) -> None: async with self.lock: diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index af161b4e7..86a368061 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -5,9 +5,10 @@ psycopg3 cursor objects # Copyright (C) 2020 The Psycopg Team import sys +from enum import IntEnum, auto from types import TracebackType from typing import Any, AsyncIterator, Callable, Generic, Iterator, List -from typing import Optional, Sequence, Type, TYPE_CHECKING +from typing import Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union from contextlib import contextmanager from . 
import errors as e @@ -41,6 +42,12 @@ else: execute = generators.execute +class Prepare(IntEnum): + NO = auto() + YES = auto() + SHOULD = auto() + + class BaseCursor(Generic[ConnectionType]): ExecStatus = pq.ExecStatus @@ -157,14 +164,121 @@ class BaseCursor(Generic[ConnectionType]): # def _execute_gen( - self, query: Query, params: Optional[Params] = None + self, + query: Query, + params: Optional[Params] = None, + prepare: Optional[bool] = None, ) -> PQGen[None]: """Generator implementing `Cursor.execute()`.""" yield from self._start_query() - self._execute_send(query, params) + pgq = self._convert_query(query, params) + + # Check if the query is prepared or needs preparing + prep, name = self._get_prepared(pgq, prepare) + if prep is Prepare.YES: + # The query is already prepared + self._send_query_prepared(name, pgq) + + elif prep is Prepare.NO: + # The query must be executed without preparing + self._execute_send(pgq) + + else: + # The query must be prepared and executed + self._send_prepare(name, pgq) + (result,) = yield from execute(self._conn.pgconn) + if result.status == ExecStatus.FATAL_ERROR: + raise e.error_from_result( + result, encoding=self._conn.client_encoding + ) + self._send_query_prepared(name, pgq) + + # run the query results = yield from execute(self._conn.pgconn) + + # Update the prepare state of the query + if prepare is not False: + yield from self._maintain_prepared(pgq, results, prep, name) + self._execute_results(results) + def _get_prepared( + self, query: PostgresQuery, prepare: Optional[bool] = None + ) -> Tuple[Prepare, bytes]: + """ + Check if a query is prepared, tell back whether to prepare it. + """ + conn = self._conn + if prepare is False or conn.prepare_threshold is None: + # The user doesn't want this query to be prepared + return Prepare.NO, b"" + + key = (query.query, query.types) + value: Union[bytes, int] = conn._prepared_statements.get(key, 0) + if isinstance(value, bytes): + # The query was already prepared in this session + return Prepare.YES, value + + if value >= conn.prepare_threshold or prepare: + # The query has been executed enough times and needs to be prepared + name = f"_pg3_{conn._prepared_idx}".encode("utf-8") + conn._prepared_idx += 1 + return Prepare.SHOULD, name + else: + # The query is not to be prepared yet + return Prepare.NO, b"" + + def _maintain_prepared( + self, + query: PostgresQuery, + results: Sequence["PGresult"], + prep: Prepare, + name: bytes, + ) -> PQGen[None]: + """Maintain the cache of he prepared statements.""" + # don't do anything if prepared statements are disabled + if self._conn.prepare_threshold is None: + return + + cache = self._conn._prepared_statements + key = (query.query, query.types) + + # If we know the query already the cache size won't change + # So just update the count and record as last used + if key in cache: + if isinstance(cache[key], int): + if prep is Prepare.SHOULD: + cache[key] = name + else: + cache[key] += 1 # type: ignore # operator + cache.move_to_end(key) + return + + # The query is not in cache. 
Let's see if we must add it + if len(results) != 1: + # We cannot prepare a multiple statement + return + + result = results[0] + if ( + result.status != ExecStatus.TUPLES_OK + and result.status != ExecStatus.COMMAND_OK + ): + # We don't prepare failed queries or other weird results + return + + # Ok, we got to the conclusion that this query is genuinely to prepare + cache[key] = name if prep is Prepare.SHOULD else 1 + + # Evict an old value from the cache; if it was prepared, deallocate it + # Do it only once: if the cache was resized, deallocate gradually + if len(cache) <= self._conn.prepared_max: + return + + old_val = cache.popitem(last=False)[1] + if isinstance(old_val, bytes): + yield from self._conn._exec_command(b"DEALLOCATE " + old_val) + def _executemany_gen( self, query: Query, params_seq: Sequence[Params] ) -> PQGen[None]: @@ -173,7 +287,8 @@ class BaseCursor(Generic[ConnectionType]): first = True for params in params_seq: if first: - pgq = self._send_prepare(b"", query, params) + pgq = self._convert_query(query, params) + self._send_prepare(b"", pgq) (result,) = yield from execute(self._conn.pgconn) if result.status == ExecStatus.FATAL_ERROR: raise e.error_from_result( @@ -204,40 +319,46 @@ class BaseCursor(Generic[ConnectionType]): def _start_copy_gen(self, statement: Query) -> PQGen[None]: """Generator implementing sending a command for `Cursor.copy().""" yield from self._start_query() + query = self._convert_query(statement) + # Make sure to avoid PQexec to avoid receiving a mix of COPY and # other operations. - self._execute_send(statement, None, no_pqexec=True) + self._execute_send(query, no_pqexec=True) (result,) = yield from execute(self._conn.pgconn) self._check_copy_result(result) self.pgresult = result # will set it on the transformer too def _execute_send( - self, query: Query, params: Optional[Params], no_pqexec: bool = False + self, query: PostgresQuery, no_pqexec: bool = False ) -> None: """ Implement part of execute() before waiting common to sync and async. - This is not a generator, but a normal, non-blocking function. + This is not a generator, but a normal non-blocking function. """ - pgq = PostgresQuery(self._transformer) - pgq.convert(query, params) - - if pgq.params or no_pqexec or self.format == Format.BINARY: - self._query = pgq.query - self._params = pgq.params + if query.params or no_pqexec or self.format == Format.BINARY: + self._query = query.query + self._params = query.params self._conn.pgconn.send_query_params( - pgq.query, - pgq.params, - param_formats=pgq.formats, - param_types=pgq.types, + query.query, + query.params, + param_formats=query.formats, + param_types=query.types, result_format=self.format, ) else: # if we don't have to, let's use exec_ as it can run more than # one query in one go - self._query = pgq.query + self._query = query.query self._params = None - self._conn.pgconn.send_query(pgq.query) + self._conn.pgconn.send_query(query.query) + + def _convert_query( + self, query: Query, params: Optional[Params] = None + ) -> PostgresQuery: + pgq = PostgresQuery(self._transformer) + pgq.convert(query, params) + return pgq _status_ok = { ExecStatus.TUPLES_OK, @@ -254,7 +375,7 @@ class BaseCursor(Generic[ConnectionType]): """ Implement part of execute() after waiting common to sync and async - This is not a generator, but a normal, non-blocking function. + This is not a generator, but a normal non-blocking function. 
""" if not results: raise e.InternalError("got no result from the query") @@ -287,16 +408,11 @@ class BaseCursor(Generic[ConnectionType]): f" {', '.join(sorted(s.name for s in sorted(badstats)))}" ) - def _send_prepare( - self, name: bytes, query: Query, params: Optional[Params] - ) -> PostgresQuery: - pgq = PostgresQuery(self._transformer) - pgq.convert(query, params) - - self._query = pgq.query - self._conn.pgconn.send_prepare(name, pgq.query, param_types=pgq.types) - - return pgq + def _send_prepare(self, name: bytes, query: PostgresQuery) -> None: + self._query = query.query + self._conn.pgconn.send_prepare( + name, query.query, param_types=query.types + ) def _send_query_prepared(self, name: bytes, pgq: PostgresQuery) -> None: self._params = pgq.params @@ -356,13 +472,16 @@ class Cursor(BaseCursor["Connection"]): self._reset() def execute( - self, query: Query, params: Optional[Params] = None + self, + query: Query, + params: Optional[Params] = None, + prepare: Optional[bool] = None, ) -> "Cursor": """ Execute a query or command to the database. """ with self._conn.lock: - self._conn.wait(self._execute_gen(query, params)) + self._conn.wait(self._execute_gen(query, params, prepare=prepare)) return self def executemany(self, query: Query, params_seq: Sequence[Params]) -> None: @@ -457,10 +576,15 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): self._reset() async def execute( - self, query: Query, params: Optional[Params] = None + self, + query: Query, + params: Optional[Params] = None, + prepare: Optional[bool] = None, ) -> "AsyncCursor": async with self._conn.lock: - await self._conn.wait(self._execute_gen(query, params)) + await self._conn.wait( + self._execute_gen(query, params, prepare=prepare) + ) return self async def executemany( diff --git a/psycopg3_c/psycopg3_c/pq_cython.pyx b/psycopg3_c/psycopg3_c/pq_cython.pyx index 6dd973ac1..e7eaa806b 100644 --- a/psycopg3_c/psycopg3_c/pq_cython.pyx +++ b/psycopg3_c/psycopg3_c/pq_cython.pyx @@ -522,16 +522,24 @@ cdef PGconn _connect_start(const char *conninfo): cdef (int, Oid *, char * const*, int *, int *) _query_params_args( list param_values: Optional[Sequence[Optional[bytes]]], - list param_types: Optional[Sequence[int]], + param_types: Optional[Sequence[int]], list param_formats: Optional[Sequence[Format]], ) except *: cdef int i + # the PostgresQuery convers the param_types to tuple, so this operation + # is most often no-op + cdef tuple tparam_types + if param_types is not None and not isinstance(param_types, tuple): + tparam_types = tuple(param_types) + else: + tparam_types = param_types + cdef int nparams = len(param_values) if param_values else 0 - if param_types is not None and len(param_types) != nparams: + if tparam_types is not None and len(tparam_types) != nparams: raise ValueError( "got %d param_values but %d param_types" - % (nparams, len(param_types)) + % (nparams, len(tparam_types)) ) if param_formats is not None and len(param_formats) != nparams: raise ValueError( @@ -560,10 +568,10 @@ cdef (int, Oid *, char * const*, int *, int *) _query_params_args( alenghts[i] = length cdef Oid *atypes = NULL - if param_types is not None: + if tparam_types: atypes = PyMem_Malloc(nparams * sizeof(Oid)) for i in range(nparams): - atypes[i] = param_types[i] + atypes[i] = tparam_types[i] cdef int *aformats = NULL if param_formats is not None: diff --git a/tests/test_prepared.py b/tests/test_prepared.py new file mode 100644 index 000000000..96f851491 --- /dev/null +++ b/tests/test_prepared.py @@ -0,0 +1,183 @@ +""" +Prepared 
statements tests +""" + +import datetime as dt +from decimal import Decimal + +import pytest + + +def test_connection_attributes(conn, monkeypatch): + assert conn.prepare_threshold == 5 + assert conn.prepared_max == 100 + + # They are on the class + monkeypatch.setattr(conn.__class__, "prepare_threshold", 10) + assert conn.prepare_threshold == 10 + + monkeypatch.setattr(conn.__class__, "prepared_max", 200) + assert conn.prepared_max == 200 + + +def test_dont_prepare(conn): + cur = conn.cursor() + for i in range(10): + cur.execute("select %s::int", [i], prepare=False) + + cur.execute("select count(*) from pg_prepared_statements") + assert cur.fetchone() == (0,) + + +def test_do_prepare(conn): + cur = conn.cursor() + cur.execute("select %s::int", [10], prepare=True) + cur.execute("select count(*) from pg_prepared_statements") + assert cur.fetchone() == (1,) + + +def test_auto_prepare(conn): + cur = conn.cursor() + res = [] + for i in range(10): + cur.execute("select count(*) from pg_prepared_statements") + res.append(cur.fetchone()[0]) + + assert res == [0] * 5 + [1] * 5 + + +def test_dont_prepare_conn(conn): + for i in range(10): + conn.execute("select %s::int", [i], prepare=False) + + cur = conn.execute("select count(*) from pg_prepared_statements") + assert cur.fetchone() == (0,) + + +def test_do_prepare_conn(conn): + conn.execute("select %s::int", [10], prepare=True) + cur = conn.execute("select count(*) from pg_prepared_statements") + assert cur.fetchone() == (1,) + + +def test_auto_prepare_conn(conn): + res = [] + for i in range(10): + cur = conn.execute("select count(*) from pg_prepared_statements") + res.append(cur.fetchone()[0]) + + assert res == [0] * 5 + [1] * 5 + + +def test_prepare_disable(conn): + conn.prepare_threshold = None + res = [] + for i in range(10): + cur = conn.execute("select count(*) from pg_prepared_statements") + res.append(cur.fetchone()[0]) + + assert res == [0] * 10 + assert not conn._prepared_statements + + +def test_no_prepare_multi(conn): + res = [] + for i in range(10): + cur = conn.execute( + "select count(*) from pg_prepared_statements; select 1" + ) + res.append(cur.fetchone()[0]) + + assert res == [0] * 10 + + +def test_no_prepare_error(conn): + conn.autocommit = True + for i in range(10): + with pytest.raises(conn.ProgrammingError): + conn.execute("select wat") + + cur = conn.execute("select count(*) from pg_prepared_statements") + assert cur.fetchone() == (0,) + + +@pytest.mark.parametrize( + "query", + [ + "create table test_no_prepare ()", + "notify foo, 'bar'", + "set timezone = utc", + "select num from prepared_test", + "insert into prepared_test (num) values (1)", + "update prepared_test set num = num * 2", + "delete from prepared_test where num > 10", + ], +) +def test_misc_statement(conn, query): + conn.execute("create table prepared_test (num int)", prepare=False) + conn.prepare_threshold = 0 + conn.execute(query) + cur = conn.execute( + "select count(*) from pg_prepared_statements", prepare=False + ) + assert cur.fetchone() == (1,) + + +def test_params_types(conn): + conn.execute( + "select %s, %s, %s", + [dt.date(2020, 12, 10), 42, Decimal(42)], + prepare=True, + ) + cur = conn.execute("select parameter_types from pg_prepared_statements") + (rec,) = cur.fetchall() + assert rec[0] == ["date", "bigint", "numeric"] + + +def test_evict_lru(conn): + conn.prepared_max = 5 + for i in range(10): + conn.execute("select 'a'") + conn.execute(f"select {i}") + + assert len(conn._prepared_statements) == 5 + assert conn._prepared_statements[b"select 'a'", 
()] == b"_pg3_0" + for i in [9, 8, 7, 6]: + assert conn._prepared_statements[f"select {i}".encode("utf8"), ()] == 1 + + cur = conn.execute("select statement from pg_prepared_statements") + assert cur.fetchall() == [("select 'a'",)] + + +def test_evict_lru_deallocate(conn): + conn.prepared_max = 5 + conn.prepare_threshold = 0 + for i in range(10): + conn.execute("select 'a'") + conn.execute(f"select {i}") + + assert len(conn._prepared_statements) == 5 + for i in [9, 8, 7, 6, "'a'"]: + assert conn._prepared_statements[ + f"select {i}".encode("utf8"), () + ].startswith(b"_pg3_") + + cur = conn.execute( + "select statement from pg_prepared_statements order by prepare_time", + prepare=False, + ) + assert cur.fetchall() == [(f"select {i}",) for i in ["'a'", 6, 7, 8, 9]] + + +def test_different_types(conn): + conn.prepare_threshold = 0 + conn.execute("select %s", [None]) + conn.execute("select %s", [dt.date(2000, 1, 1)]) + conn.execute("select %s", [42]) + conn.execute("select %s", [41]) + conn.execute("select %s", [dt.date(2000, 1, 2)]) + cur = conn.execute( + "select parameter_types from pg_prepared_statements order by prepare_time", + prepare=False, + ) + assert cur.fetchall() == [(["text"],), (["date"],), (["bigint"],)] diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py new file mode 100644 index 000000000..3a4c573a6 --- /dev/null +++ b/tests/test_prepared_async.py @@ -0,0 +1,195 @@ +""" +Prepared statements tests on async connections +""" + +import datetime as dt +from decimal import Decimal + +import pytest + +pytestmark = pytest.mark.asyncio + + +async def test_connection_attributes(aconn, monkeypatch): + assert aconn.prepare_threshold == 5 + assert aconn.prepared_max == 100 + + # They are on the class + monkeypatch.setattr(aconn.__class__, "prepare_threshold", 10) + assert aconn.prepare_threshold == 10 + + monkeypatch.setattr(aconn.__class__, "prepared_max", 200) + assert aconn.prepared_max == 200 + + +async def test_dont_prepare(aconn): + cur = await aconn.cursor() + for i in range(10): + await cur.execute("select %s::int", [i], prepare=False) + + await cur.execute("select count(*) from pg_prepared_statements") + assert await cur.fetchone() == (0,) + + +async def test_do_prepare(aconn): + cur = await aconn.cursor() + await cur.execute("select %s::int", [10], prepare=True) + await cur.execute("select count(*) from pg_prepared_statements") + assert await cur.fetchone() == (1,) + + +async def test_auto_prepare(aconn): + cur = await aconn.cursor() + res = [] + for i in range(10): + await cur.execute("select count(*) from pg_prepared_statements") + res.append((await cur.fetchone())[0]) + + assert res == [0] * 5 + [1] * 5 + + +async def test_dont_prepare_conn(aconn): + for i in range(10): + await aconn.execute("select %s::int", [i], prepare=False) + + cur = await aconn.execute("select count(*) from pg_prepared_statements") + assert await cur.fetchone() == (0,) + + +async def test_do_prepare_conn(aconn): + await aconn.execute("select %s::int", [10], prepare=True) + cur = await aconn.execute("select count(*) from pg_prepared_statements") + assert await cur.fetchone() == (1,) + + +async def test_auto_prepare_conn(aconn): + res = [] + for i in range(10): + cur = await aconn.execute( + "select count(*) from pg_prepared_statements" + ) + res.append((await cur.fetchone())[0]) + + assert res == [0] * 5 + [1] * 5 + + +async def test_prepare_disable(aconn): + aconn.prepare_threshold = None + res = [] + for i in range(10): + cur = await aconn.execute( + "select count(*) from 
pg_prepared_statements" + ) + res.append((await cur.fetchone())[0]) + + assert res == [0] * 10 + assert not aconn._prepared_statements + + +async def test_no_prepare_multi(aconn): + res = [] + for i in range(10): + cur = await aconn.execute( + "select count(*) from pg_prepared_statements; select 1" + ) + res.append((await cur.fetchone())[0]) + + assert res == [0] * 10 + + +async def test_no_prepare_error(aconn): + await aconn.set_autocommit(True) + for i in range(10): + with pytest.raises(aconn.ProgrammingError): + await aconn.execute("select wat") + + cur = await aconn.execute("select count(*) from pg_prepared_statements") + assert await cur.fetchone() == (0,) + + +@pytest.mark.parametrize( + "query", + [ + "create table test_no_prepare ()", + "notify foo, 'bar'", + "set timezone = utc", + "select num from prepared_test", + "insert into prepared_test (num) values (1)", + "update prepared_test set num = num * 2", + "delete from prepared_test where num > 10", + ], +) +async def test_misc_statement(aconn, query): + await aconn.execute("create table prepared_test (num int)", prepare=False) + aconn.prepare_threshold = 0 + await aconn.execute(query) + cur = await aconn.execute( + "select count(*) from pg_prepared_statements", prepare=False + ) + assert await cur.fetchone() == (1,) + + +async def test_params_types(aconn): + await aconn.execute( + "select %s, %s, %s", + [dt.date(2020, 12, 10), 42, Decimal(42)], + prepare=True, + ) + cur = await aconn.execute( + "select parameter_types from pg_prepared_statements" + ) + (rec,) = await cur.fetchall() + assert rec[0] == ["date", "bigint", "numeric"] + + +async def test_evict_lru(aconn): + aconn.prepared_max = 5 + for i in range(10): + await aconn.execute("select 'a'") + await aconn.execute(f"select {i}") + + assert len(aconn._prepared_statements) == 5 + assert aconn._prepared_statements[b"select 'a'", ()] == b"_pg3_0" + for i in [9, 8, 7, 6]: + assert ( + aconn._prepared_statements[f"select {i}".encode("utf8"), ()] == 1 + ) + + cur = await aconn.execute("select statement from pg_prepared_statements") + assert await cur.fetchall() == [("select 'a'",)] + + +async def test_evict_lru_deallocate(aconn): + aconn.prepared_max = 5 + aconn.prepare_threshold = 0 + for i in range(10): + await aconn.execute("select 'a'") + await aconn.execute(f"select {i}") + + assert len(aconn._prepared_statements) == 5 + for i in [9, 8, 7, 6, "'a'"]: + assert aconn._prepared_statements[ + f"select {i}".encode("utf8"), () + ].startswith(b"_pg3_") + + cur = await aconn.execute( + "select statement from pg_prepared_statements order by prepare_time", + prepare=False, + ) + assert await cur.fetchall() == [ + (f"select {i}",) for i in ["'a'", 6, 7, 8, 9] + ] + + +async def test_different_types(aconn): + aconn.prepare_threshold = 0 + await aconn.execute("select %s", [None]) + await aconn.execute("select %s", [dt.date(2000, 1, 1)]) + await aconn.execute("select %s", [42]) + await aconn.execute("select %s", [41]) + await aconn.execute("select %s", [dt.date(2000, 1, 2)]) + cur = await aconn.execute( + "select parameter_types from pg_prepared_statements order by prepare_time", + prepare=False, + ) + assert await cur.fetchall() == [(["text"],), (["date"],), (["bigint"],)]