From: Daniele Varrazzo Date: Tue, 24 May 2022 22:22:46 +0000 (+0200) Subject: test(crdb): make prepared statements tests portable between PG and CRDB X-Git-Tag: 3.1~49^2~49 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fbe291ee07c3c391775d40034f9c6ee3ab81fd55;p=thirdparty%2Fpsycopg.git test(crdb): make prepared statements tests portable between PG and CRDB --- diff --git a/tests/test_prepared.py b/tests/test_prepared.py index 8cacddd46..716496eab 100644 --- a/tests/test_prepared.py +++ b/tests/test_prepared.py @@ -8,6 +8,9 @@ from decimal import Decimal import pytest import psycopg +from psycopg.rows import namedtuple_row + +from .fix_crdb import is_crdb @pytest.mark.parametrize("value", [None, 0, 3]) @@ -21,23 +24,23 @@ def test_dont_prepare(conn): for i in range(10): cur.execute("select %s::int", [i], prepare=False) - cur.execute("select count(*) from pg_prepared_statements") - assert cur.fetchone() == (0,) + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 def test_do_prepare(conn): cur = conn.cursor() cur.execute("select %s::int", [10], prepare=True) - cur.execute("select count(*) from pg_prepared_statements") - assert cur.fetchone() == (1,) + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 def test_auto_prepare(conn): - cur = conn.cursor() res = [] for i in range(10): - cur.execute("select count(*) from pg_prepared_statements") - res.append(cur.fetchone()[0]) + conn.execute("select %s::int", [0]) + stmts = get_prepared_statements(conn) + res.append(len(stmts)) assert res == [0] * 5 + [1] * 5 @@ -46,21 +49,22 @@ def test_dont_prepare_conn(conn): for i in range(10): conn.execute("select %s::int", [i], prepare=False) - cur = conn.execute("select count(*) from pg_prepared_statements") - assert cur.fetchone() == (0,) + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 def test_do_prepare_conn(conn): conn.execute("select %s::int", [10], prepare=True) - cur = conn.execute("select count(*) from pg_prepared_statements") - assert cur.fetchone() == (1,) + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 def test_auto_prepare_conn(conn): res = [] for i in range(10): - cur = conn.execute("select count(*) from pg_prepared_statements") - res.append(cur.fetchone()[0]) + conn.execute("select %s", [0]) + stmts = get_prepared_statements(conn) + res.append(len(stmts)) assert res == [0] * 5 + [1] * 5 @@ -69,8 +73,9 @@ def test_prepare_disable(conn): conn.prepare_threshold = None res = [] for i in range(10): - cur = conn.execute("select count(*) from pg_prepared_statements") - res.append(cur.fetchone()[0]) + conn.execute("select %s", [0]) + stmts = get_prepared_statements(conn) + res.append(len(stmts)) assert res == [0] * 10 assert not conn._prepared._names @@ -80,8 +85,9 @@ def test_prepare_disable(conn): def test_no_prepare_multi(conn): res = [] for i in range(10): - cur = conn.execute("select count(*) from pg_prepared_statements; select 1") - res.append(cur.fetchone()[0]) + conn.execute("select 1; select 2") + stmts = get_prepared_statements(conn) + res.append(len(stmts)) assert res == [0] * 10 @@ -92,8 +98,8 @@ def test_no_prepare_multi_with_drop(conn): for i in range(10): conn.execute("drop table if exists noprep; create table noprep()") - cur = conn.execute("select count(*) from pg_prepared_statements") - assert cur.fetchone() == (0,) + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 def test_no_prepare_error(conn): @@ -102,15 +108,17 @@ def test_no_prepare_error(conn): with pytest.raises(conn.ProgrammingError): conn.execute("select wat") - cur = conn.execute("select count(*) from pg_prepared_statements") - assert cur.fetchone() == (0,) + stmts = get_prepared_statements(conn) + assert len(stmts) == 0 @pytest.mark.parametrize( "query", [ "create table test_no_prepare ()", - "notify foo, 'bar'", + pytest.param( + "notify foo, 'bar'", marks=pytest.mark.crdb("skip", reason="notify") + ), "set timezone = utc", "select num from prepared_test", "insert into prepared_test (num) values (1)", @@ -122,8 +130,8 @@ def test_misc_statement(conn, query): conn.execute("create table prepared_test (num int)", prepare=False) conn.prepare_threshold = 0 conn.execute(query) - cur = conn.execute("select count(*) from pg_prepared_statements", prepare=False) - assert cur.fetchone() == (1,) + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 def test_params_types(conn): @@ -132,9 +140,9 @@ def test_params_types(conn): [dt.date(2020, 12, 10), 42, Decimal(42)], prepare=True, ) - cur = conn.execute("select parameter_types from pg_prepared_statements") - (rec,) = cur.fetchall() - assert rec[0] == ["date", "smallint", "numeric"] + stmts = get_prepared_statements(conn) + want = [stmt.parameter_types for stmt in stmts] + assert want == [["date", "smallint", "numeric"]] def test_evict_lru(conn): @@ -148,8 +156,9 @@ def test_evict_lru(conn): for i in [9, 8, 7, 6]: assert conn._prepared._counts[f"select {i}".encode(), ()] == 1 - cur = conn.execute("select statement from pg_prepared_statements") - assert cur.fetchall() == [("select 'a'",)] + stmts = get_prepared_statements(conn) + assert len(stmts) == 1 + assert stmts[0].statement == "select 'a'" def test_evict_lru_deallocate(conn): @@ -164,25 +173,26 @@ def test_evict_lru_deallocate(conn): name = conn._prepared._names[f"select {j}".encode(), ()] assert name.startswith(b"_pg3_") - cur = conn.execute( - "select statement from pg_prepared_statements order by prepare_time", - prepare=False, - ) - assert cur.fetchall() == [(f"select {i}",) for i in ["'a'", 6, 7, 8, 9]] + stmts = get_prepared_statements(conn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.statement for stmt in stmts] + assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]] def test_different_types(conn): conn.prepare_threshold = 0 - conn.execute("select %s", [None]) + # CRDB can't roundtrip None + unk = "foo" if is_crdb(conn) else None + conn.execute("select %s", [unk]) conn.execute("select %s", [dt.date(2000, 1, 1)]) conn.execute("select %s", [42]) conn.execute("select %s", [41]) conn.execute("select %s", [dt.date(2000, 1, 2)]) - cur = conn.execute( - "select parameter_types from pg_prepared_statements order by prepare_time", - prepare=False, - ) - assert cur.fetchall() == [(["text"],), (["date"],), (["smallint"],)] + + stmts = get_prepared_statements(conn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["text"], ["date"], ["smallint"]] def test_untyped_json(conn): @@ -192,8 +202,9 @@ def test_untyped_json(conn): for i in range(2): conn.execute("insert into testjson (data) values (%s)", ["{}"]) - cur = conn.execute("select parameter_types from pg_prepared_statements") - assert cur.fetchall() == [(["jsonb"],)] + stmts = get_prepared_statements(conn) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["jsonb"]] def test_change_type_execute(conn): @@ -219,6 +230,7 @@ def test_change_type_executemany(conn): conn.rollback() +@pytest.mark.crdb("skip", reason="can't re-create a type") def test_change_type(conn): conn.prepare_threshold = 0 conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')") @@ -236,8 +248,8 @@ def test_change_type(conn): {"enum_col": ["foo"]}, ) - cur = conn.execute("select count(*) from pg_prepared_statements", prepare=False) - assert cur.fetchone()[0] == 3 + stmts = get_prepared_statements(conn) + assert len(stmts) == 3 def test_change_type_savepoint(conn): @@ -253,3 +265,20 @@ def test_change_type_savepoint(conn): {"enum_col": ["foo"]}, ) raise ZeroDivisionError() + + +def get_prepared_statements(conn): + cur = conn.cursor(row_factory=namedtuple_row) + cur.execute( + # CRDB has 'PREPARE name AS' in the statement. + r""" +select name, + regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement, + prepare_time, + parameter_types +from pg_prepared_statements +where name != '' + """, + prepare=False, + ) + return cur.fetchall() diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py index 330635cb4..2983f2188 100644 --- a/tests/test_prepared_async.py +++ b/tests/test_prepared_async.py @@ -8,6 +8,9 @@ from decimal import Decimal import pytest import psycopg +from psycopg.rows import namedtuple_row + +from .fix_crdb import is_crdb pytestmark = pytest.mark.asyncio @@ -25,23 +28,23 @@ async def test_dont_prepare(aconn): for i in range(10): await cur.execute("select %s::int", [i], prepare=False) - await cur.execute("select count(*) from pg_prepared_statements") - assert await cur.fetchone() == (0,) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 async def test_do_prepare(aconn): cur = aconn.cursor() await cur.execute("select %s::int", [10], prepare=True) - await cur.execute("select count(*) from pg_prepared_statements") - assert await cur.fetchone() == (1,) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 async def test_auto_prepare(aconn): - cur = aconn.cursor() res = [] for i in range(10): - await cur.execute("select count(*) from pg_prepared_statements") - res.append((await cur.fetchone())[0]) + await aconn.execute("select %s::int", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) assert res == [0] * 5 + [1] * 5 @@ -50,21 +53,22 @@ async def test_dont_prepare_conn(aconn): for i in range(10): await aconn.execute("select %s::int", [i], prepare=False) - cur = await aconn.execute("select count(*) from pg_prepared_statements") - assert await cur.fetchone() == (0,) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 async def test_do_prepare_conn(aconn): await aconn.execute("select %s::int", [10], prepare=True) - cur = await aconn.execute("select count(*) from pg_prepared_statements") - assert await cur.fetchone() == (1,) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 async def test_auto_prepare_conn(aconn): res = [] for i in range(10): - cur = await aconn.execute("select count(*) from pg_prepared_statements") - res.append((await cur.fetchone())[0]) + await aconn.execute("select %s", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) assert res == [0] * 5 + [1] * 5 @@ -73,8 +77,9 @@ async def test_prepare_disable(aconn): aconn.prepare_threshold = None res = [] for i in range(10): - cur = await aconn.execute("select count(*) from pg_prepared_statements") - res.append((await cur.fetchone())[0]) + await aconn.execute("select %s", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) assert res == [0] * 10 assert not aconn._prepared._names @@ -84,10 +89,9 @@ async def test_prepare_disable(aconn): async def test_no_prepare_multi(aconn): res = [] for i in range(10): - cur = await aconn.execute( - "select count(*) from pg_prepared_statements; select 1" - ) - res.append((await cur.fetchone())[0]) + await aconn.execute("select 1; select 2") + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) assert res == [0] * 10 @@ -98,15 +102,17 @@ async def test_no_prepare_error(aconn): with pytest.raises(aconn.ProgrammingError): await aconn.execute("select wat") - cur = await aconn.execute("select count(*) from pg_prepared_statements") - assert await cur.fetchone() == (0,) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 @pytest.mark.parametrize( "query", [ "create table test_no_prepare ()", - "notify foo, 'bar'", + pytest.param( + "notify foo, 'bar'", marks=pytest.mark.crdb("skip", reason="notify") + ), "set timezone = utc", "select num from prepared_test", "insert into prepared_test (num) values (1)", @@ -118,10 +124,8 @@ async def test_misc_statement(aconn, query): await aconn.execute("create table prepared_test (num int)", prepare=False) aconn.prepare_threshold = 0 await aconn.execute(query) - cur = await aconn.execute( - "select count(*) from pg_prepared_statements", prepare=False - ) - assert await cur.fetchone() == (1,) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 async def test_params_types(aconn): @@ -130,9 +134,9 @@ async def test_params_types(aconn): [dt.date(2020, 12, 10), 42, Decimal(42)], prepare=True, ) - cur = await aconn.execute("select parameter_types from pg_prepared_statements") - (rec,) = await cur.fetchall() - assert rec[0] == ["date", "smallint", "numeric"] + stmts = await get_prepared_statements(aconn) + want = [stmt.parameter_types for stmt in stmts] + assert want == [["date", "smallint", "numeric"]] async def test_evict_lru(aconn): @@ -146,8 +150,9 @@ async def test_evict_lru(aconn): for i in [9, 8, 7, 6]: assert aconn._prepared._counts[f"select {i}".encode(), ()] == 1 - cur = await aconn.execute("select statement from pg_prepared_statements") - assert await cur.fetchall() == [("select 'a'",)] + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + assert stmts[0].statement == "select 'a'" async def test_evict_lru_deallocate(aconn): @@ -162,25 +167,26 @@ async def test_evict_lru_deallocate(aconn): name = aconn._prepared._names[f"select {j}".encode(), ()] assert name.startswith(b"_pg3_") - cur = await aconn.execute( - "select statement from pg_prepared_statements order by prepare_time", - prepare=False, - ) - assert await cur.fetchall() == [(f"select {i}",) for i in ["'a'", 6, 7, 8, 9]] + stmts = await get_prepared_statements(aconn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.statement for stmt in stmts] + assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]] async def test_different_types(aconn): aconn.prepare_threshold = 0 - await aconn.execute("select %s", [None]) + # CRDB can't roundtrip None + unk = "foo" if is_crdb(aconn) else None + await aconn.execute("select %s", [unk]) await aconn.execute("select %s", [dt.date(2000, 1, 1)]) await aconn.execute("select %s", [42]) await aconn.execute("select %s", [41]) await aconn.execute("select %s", [dt.date(2000, 1, 2)]) - cur = await aconn.execute( - "select parameter_types from pg_prepared_statements order by prepare_time", - prepare=False, - ) - assert await cur.fetchall() == [(["text"],), (["date"],), (["smallint"],)] + + stmts = await get_prepared_statements(aconn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["text"], ["date"], ["smallint"]] async def test_untyped_json(aconn): @@ -189,5 +195,22 @@ async def test_untyped_json(aconn): for i in range(2): await aconn.execute("insert into testjson (data) values (%s)", ["{}"]) - cur = await aconn.execute("select parameter_types from pg_prepared_statements") - assert await cur.fetchall() == [(["jsonb"],)] + stmts = await get_prepared_statements(aconn) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["jsonb"]] + + +async def get_prepared_statements(aconn): + cur = aconn.cursor(row_factory=namedtuple_row) + await cur.execute( + r""" +select name, + regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement, + prepare_time, + parameter_types +from pg_prepared_statements +where name != '' + """, + prepare=False, + ) + return await cur.fetchall()