]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test(crdb): make prepared statements tests portable between PG and CRDB
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 May 2022 22:22:46 +0000 (00:22 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:34 +0000 (12:58 +0100)
tests/test_prepared.py
tests/test_prepared_async.py

index 8cacddd46e5561ef527f67ce8aacaaefc1230498..716496eab42583e74292e40735613a1c01ae51fa 100644 (file)
@@ -8,6 +8,9 @@ from decimal import Decimal
 import pytest
 
 import psycopg
+from psycopg.rows import namedtuple_row
+
+from .fix_crdb import is_crdb
 
 
 @pytest.mark.parametrize("value", [None, 0, 3])
@@ -21,23 +24,23 @@ def test_dont_prepare(conn):
     for i in range(10):
         cur.execute("select %s::int", [i], prepare=False)
 
-    cur.execute("select count(*) from pg_prepared_statements")
-    assert cur.fetchone() == (0,)
+    stmts = get_prepared_statements(conn)
+    assert len(stmts) == 0
 
 
 def test_do_prepare(conn):
     cur = conn.cursor()
     cur.execute("select %s::int", [10], prepare=True)
-    cur.execute("select count(*) from pg_prepared_statements")
-    assert cur.fetchone() == (1,)
+    stmts = get_prepared_statements(conn)
+    assert len(stmts) == 1
 
 
 def test_auto_prepare(conn):
-    cur = conn.cursor()
     res = []
     for i in range(10):
-        cur.execute("select count(*) from pg_prepared_statements")
-        res.append(cur.fetchone()[0])
+        conn.execute("select %s::int", [0])
+        stmts = get_prepared_statements(conn)
+        res.append(len(stmts))
 
     assert res == [0] * 5 + [1] * 5
 
@@ -46,21 +49,22 @@ def test_dont_prepare_conn(conn):
     for i in range(10):
         conn.execute("select %s::int", [i], prepare=False)
 
-    cur = conn.execute("select count(*) from pg_prepared_statements")
-    assert cur.fetchone() == (0,)
+    stmts = get_prepared_statements(conn)
+    assert len(stmts) == 0
 
 
 def test_do_prepare_conn(conn):
     conn.execute("select %s::int", [10], prepare=True)
-    cur = conn.execute("select count(*) from pg_prepared_statements")
-    assert cur.fetchone() == (1,)
+    stmts = get_prepared_statements(conn)
+    assert len(stmts) == 1
 
 
 def test_auto_prepare_conn(conn):
     res = []
     for i in range(10):
-        cur = conn.execute("select count(*) from pg_prepared_statements")
-        res.append(cur.fetchone()[0])
+        conn.execute("select %s", [0])
+        stmts = get_prepared_statements(conn)
+        res.append(len(stmts))
 
     assert res == [0] * 5 + [1] * 5
 
@@ -69,8 +73,9 @@ def test_prepare_disable(conn):
     conn.prepare_threshold = None
     res = []
     for i in range(10):
-        cur = conn.execute("select count(*) from pg_prepared_statements")
-        res.append(cur.fetchone()[0])
+        conn.execute("select %s", [0])
+        stmts = get_prepared_statements(conn)
+        res.append(len(stmts))
 
     assert res == [0] * 10
     assert not conn._prepared._names
@@ -80,8 +85,9 @@ def test_prepare_disable(conn):
 def test_no_prepare_multi(conn):
     res = []
     for i in range(10):
-        cur = conn.execute("select count(*) from pg_prepared_statements; select 1")
-        res.append(cur.fetchone()[0])
+        conn.execute("select 1; select 2")
+        stmts = get_prepared_statements(conn)
+        res.append(len(stmts))
 
     assert res == [0] * 10
 
@@ -92,8 +98,8 @@ def test_no_prepare_multi_with_drop(conn):
     for i in range(10):
         conn.execute("drop table if exists noprep; create table noprep()")
 
-    cur = conn.execute("select count(*) from pg_prepared_statements")
-    assert cur.fetchone() == (0,)
+    stmts = get_prepared_statements(conn)
+    assert len(stmts) == 0
 
 
 def test_no_prepare_error(conn):
@@ -102,15 +108,17 @@ def test_no_prepare_error(conn):
         with pytest.raises(conn.ProgrammingError):
             conn.execute("select wat")
 
-    cur = conn.execute("select count(*) from pg_prepared_statements")
-    assert cur.fetchone() == (0,)
+    stmts = get_prepared_statements(conn)
+    assert len(stmts) == 0
 
 
 @pytest.mark.parametrize(
     "query",
     [
         "create table test_no_prepare ()",
-        "notify foo, 'bar'",
+        pytest.param(
+            "notify foo, 'bar'", marks=pytest.mark.crdb("skip", reason="notify")
+        ),
         "set timezone = utc",
         "select num from prepared_test",
         "insert into prepared_test (num) values (1)",
@@ -122,8 +130,8 @@ def test_misc_statement(conn, query):
     conn.execute("create table prepared_test (num int)", prepare=False)
     conn.prepare_threshold = 0
     conn.execute(query)
-    cur = conn.execute("select count(*) from pg_prepared_statements", prepare=False)
-    assert cur.fetchone() == (1,)
+    stmts = get_prepared_statements(conn)
+    assert len(stmts) == 1
 
 
 def test_params_types(conn):
@@ -132,9 +140,9 @@ def test_params_types(conn):
         [dt.date(2020, 12, 10), 42, Decimal(42)],
         prepare=True,
     )
-    cur = conn.execute("select parameter_types from pg_prepared_statements")
-    (rec,) = cur.fetchall()
-    assert rec[0] == ["date", "smallint", "numeric"]
+    stmts = get_prepared_statements(conn)
+    want = [stmt.parameter_types for stmt in stmts]
+    assert want == [["date", "smallint", "numeric"]]
 
 
 def test_evict_lru(conn):
@@ -148,8 +156,9 @@ def test_evict_lru(conn):
     for i in [9, 8, 7, 6]:
         assert conn._prepared._counts[f"select {i}".encode(), ()] == 1
 
-    cur = conn.execute("select statement from pg_prepared_statements")
-    assert cur.fetchall() == [("select 'a'",)]
+    stmts = get_prepared_statements(conn)
+    assert len(stmts) == 1
+    assert stmts[0].statement == "select 'a'"
 
 
 def test_evict_lru_deallocate(conn):
@@ -164,25 +173,26 @@ def test_evict_lru_deallocate(conn):
         name = conn._prepared._names[f"select {j}".encode(), ()]
         assert name.startswith(b"_pg3_")
 
-    cur = conn.execute(
-        "select statement from pg_prepared_statements order by prepare_time",
-        prepare=False,
-    )
-    assert cur.fetchall() == [(f"select {i}",) for i in ["'a'", 6, 7, 8, 9]]
+    stmts = get_prepared_statements(conn)
+    stmts.sort(key=lambda rec: rec.prepare_time)
+    got = [stmt.statement for stmt in stmts]
+    assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]]
 
 
 def test_different_types(conn):
     conn.prepare_threshold = 0
-    conn.execute("select %s", [None])
+    # CRDB can't roundtrip None
+    unk = "foo" if is_crdb(conn) else None
+    conn.execute("select %s", [unk])
     conn.execute("select %s", [dt.date(2000, 1, 1)])
     conn.execute("select %s", [42])
     conn.execute("select %s", [41])
     conn.execute("select %s", [dt.date(2000, 1, 2)])
-    cur = conn.execute(
-        "select parameter_types from pg_prepared_statements order by prepare_time",
-        prepare=False,
-    )
-    assert cur.fetchall() == [(["text"],), (["date"],), (["smallint"],)]
+
+    stmts = get_prepared_statements(conn)
+    stmts.sort(key=lambda rec: rec.prepare_time)
+    got = [stmt.parameter_types for stmt in stmts]
+    assert got == [["text"], ["date"], ["smallint"]]
 
 
 def test_untyped_json(conn):
@@ -192,8 +202,9 @@ def test_untyped_json(conn):
     for i in range(2):
         conn.execute("insert into testjson (data) values (%s)", ["{}"])
 
-    cur = conn.execute("select parameter_types from pg_prepared_statements")
-    assert cur.fetchall() == [(["jsonb"],)]
+    stmts = get_prepared_statements(conn)
+    got = [stmt.parameter_types for stmt in stmts]
+    assert got == [["jsonb"]]
 
 
 def test_change_type_execute(conn):
@@ -219,6 +230,7 @@ def test_change_type_executemany(conn):
         conn.rollback()
 
 
+@pytest.mark.crdb("skip", reason="can't re-create a type")
 def test_change_type(conn):
     conn.prepare_threshold = 0
     conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
@@ -236,8 +248,8 @@ def test_change_type(conn):
         {"enum_col": ["foo"]},
     )
 
-    cur = conn.execute("select count(*) from pg_prepared_statements", prepare=False)
-    assert cur.fetchone()[0] == 3
+    stmts = get_prepared_statements(conn)
+    assert len(stmts) == 3
 
 
 def test_change_type_savepoint(conn):
@@ -253,3 +265,20 @@ def test_change_type_savepoint(conn):
                         {"enum_col": ["foo"]},
                     )
                     raise ZeroDivisionError()
+
+
+def get_prepared_statements(conn):
+    cur = conn.cursor(row_factory=namedtuple_row)
+    cur.execute(
+        # CRDB has 'PREPARE name AS' in the statement.
+        r"""
+select name,
+    regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement,
+    prepare_time,
+    parameter_types
+from pg_prepared_statements
+where name != ''
+        """,
+        prepare=False,
+    )
+    return cur.fetchall()
index 330635cb41b0d3bd765ef845e97e38e42ca5210c..2983f2188179c746dfc4890ded99e0cda48f857b 100644 (file)
@@ -8,6 +8,9 @@ from decimal import Decimal
 import pytest
 
 import psycopg
+from psycopg.rows import namedtuple_row
+
+from .fix_crdb import is_crdb
 
 pytestmark = pytest.mark.asyncio
 
@@ -25,23 +28,23 @@ async def test_dont_prepare(aconn):
     for i in range(10):
         await cur.execute("select %s::int", [i], prepare=False)
 
-    await cur.execute("select count(*) from pg_prepared_statements")
-    assert await cur.fetchone() == (0,)
+    stmts = await get_prepared_statements(aconn)
+    assert len(stmts) == 0
 
 
 async def test_do_prepare(aconn):
     cur = aconn.cursor()
     await cur.execute("select %s::int", [10], prepare=True)
-    await cur.execute("select count(*) from pg_prepared_statements")
-    assert await cur.fetchone() == (1,)
+    stmts = await get_prepared_statements(aconn)
+    assert len(stmts) == 1
 
 
 async def test_auto_prepare(aconn):
-    cur = aconn.cursor()
     res = []
     for i in range(10):
-        await cur.execute("select count(*) from pg_prepared_statements")
-        res.append((await cur.fetchone())[0])
+        await aconn.execute("select %s::int", [0])
+        stmts = await get_prepared_statements(aconn)
+        res.append(len(stmts))
 
     assert res == [0] * 5 + [1] * 5
 
@@ -50,21 +53,22 @@ async def test_dont_prepare_conn(aconn):
     for i in range(10):
         await aconn.execute("select %s::int", [i], prepare=False)
 
-    cur = await aconn.execute("select count(*) from pg_prepared_statements")
-    assert await cur.fetchone() == (0,)
+    stmts = await get_prepared_statements(aconn)
+    assert len(stmts) == 0
 
 
 async def test_do_prepare_conn(aconn):
     await aconn.execute("select %s::int", [10], prepare=True)
-    cur = await aconn.execute("select count(*) from pg_prepared_statements")
-    assert await cur.fetchone() == (1,)
+    stmts = await get_prepared_statements(aconn)
+    assert len(stmts) == 1
 
 
 async def test_auto_prepare_conn(aconn):
     res = []
     for i in range(10):
-        cur = await aconn.execute("select count(*) from pg_prepared_statements")
-        res.append((await cur.fetchone())[0])
+        await aconn.execute("select %s", [0])
+        stmts = await get_prepared_statements(aconn)
+        res.append(len(stmts))
 
     assert res == [0] * 5 + [1] * 5
 
@@ -73,8 +77,9 @@ async def test_prepare_disable(aconn):
     aconn.prepare_threshold = None
     res = []
     for i in range(10):
-        cur = await aconn.execute("select count(*) from pg_prepared_statements")
-        res.append((await cur.fetchone())[0])
+        await aconn.execute("select %s", [0])
+        stmts = await get_prepared_statements(aconn)
+        res.append(len(stmts))
 
     assert res == [0] * 10
     assert not aconn._prepared._names
@@ -84,10 +89,9 @@ async def test_prepare_disable(aconn):
 async def test_no_prepare_multi(aconn):
     res = []
     for i in range(10):
-        cur = await aconn.execute(
-            "select count(*) from pg_prepared_statements; select 1"
-        )
-        res.append((await cur.fetchone())[0])
+        await aconn.execute("select 1; select 2")
+        stmts = await get_prepared_statements(aconn)
+        res.append(len(stmts))
 
     assert res == [0] * 10
 
@@ -98,15 +102,17 @@ async def test_no_prepare_error(aconn):
         with pytest.raises(aconn.ProgrammingError):
             await aconn.execute("select wat")
 
-    cur = await aconn.execute("select count(*) from pg_prepared_statements")
-    assert await cur.fetchone() == (0,)
+    stmts = await get_prepared_statements(aconn)
+    assert len(stmts) == 0
 
 
 @pytest.mark.parametrize(
     "query",
     [
         "create table test_no_prepare ()",
-        "notify foo, 'bar'",
+        pytest.param(
+            "notify foo, 'bar'", marks=pytest.mark.crdb("skip", reason="notify")
+        ),
         "set timezone = utc",
         "select num from prepared_test",
         "insert into prepared_test (num) values (1)",
@@ -118,10 +124,8 @@ async def test_misc_statement(aconn, query):
     await aconn.execute("create table prepared_test (num int)", prepare=False)
     aconn.prepare_threshold = 0
     await aconn.execute(query)
-    cur = await aconn.execute(
-        "select count(*) from pg_prepared_statements", prepare=False
-    )
-    assert await cur.fetchone() == (1,)
+    stmts = await get_prepared_statements(aconn)
+    assert len(stmts) == 1
 
 
 async def test_params_types(aconn):
@@ -130,9 +134,9 @@ async def test_params_types(aconn):
         [dt.date(2020, 12, 10), 42, Decimal(42)],
         prepare=True,
     )
-    cur = await aconn.execute("select parameter_types from pg_prepared_statements")
-    (rec,) = await cur.fetchall()
-    assert rec[0] == ["date", "smallint", "numeric"]
+    stmts = await get_prepared_statements(aconn)
+    want = [stmt.parameter_types for stmt in stmts]
+    assert want == [["date", "smallint", "numeric"]]
 
 
 async def test_evict_lru(aconn):
@@ -146,8 +150,9 @@ async def test_evict_lru(aconn):
     for i in [9, 8, 7, 6]:
         assert aconn._prepared._counts[f"select {i}".encode(), ()] == 1
 
-    cur = await aconn.execute("select statement from pg_prepared_statements")
-    assert await cur.fetchall() == [("select 'a'",)]
+    stmts = await get_prepared_statements(aconn)
+    assert len(stmts) == 1
+    assert stmts[0].statement == "select 'a'"
 
 
 async def test_evict_lru_deallocate(aconn):
@@ -162,25 +167,26 @@ async def test_evict_lru_deallocate(aconn):
         name = aconn._prepared._names[f"select {j}".encode(), ()]
         assert name.startswith(b"_pg3_")
 
-    cur = await aconn.execute(
-        "select statement from pg_prepared_statements order by prepare_time",
-        prepare=False,
-    )
-    assert await cur.fetchall() == [(f"select {i}",) for i in ["'a'", 6, 7, 8, 9]]
+    stmts = await get_prepared_statements(aconn)
+    stmts.sort(key=lambda rec: rec.prepare_time)
+    got = [stmt.statement for stmt in stmts]
+    assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]]
 
 
 async def test_different_types(aconn):
     aconn.prepare_threshold = 0
-    await aconn.execute("select %s", [None])
+    # CRDB can't roundtrip None
+    unk = "foo" if is_crdb(aconn) else None
+    await aconn.execute("select %s", [unk])
     await aconn.execute("select %s", [dt.date(2000, 1, 1)])
     await aconn.execute("select %s", [42])
     await aconn.execute("select %s", [41])
     await aconn.execute("select %s", [dt.date(2000, 1, 2)])
-    cur = await aconn.execute(
-        "select parameter_types from pg_prepared_statements order by prepare_time",
-        prepare=False,
-    )
-    assert await cur.fetchall() == [(["text"],), (["date"],), (["smallint"],)]
+
+    stmts = await get_prepared_statements(aconn)
+    stmts.sort(key=lambda rec: rec.prepare_time)
+    got = [stmt.parameter_types for stmt in stmts]
+    assert got == [["text"], ["date"], ["smallint"]]
 
 
 async def test_untyped_json(aconn):
@@ -189,5 +195,22 @@ async def test_untyped_json(aconn):
     for i in range(2):
         await aconn.execute("insert into testjson (data) values (%s)", ["{}"])
 
-    cur = await aconn.execute("select parameter_types from pg_prepared_statements")
-    assert await cur.fetchall() == [(["jsonb"],)]
+    stmts = await get_prepared_statements(aconn)
+    got = [stmt.parameter_types for stmt in stmts]
+    assert got == [["jsonb"]]
+
+
+async def get_prepared_statements(aconn):
+    cur = aconn.cursor(row_factory=namedtuple_row)
+    await cur.execute(
+        r"""
+select name,
+    regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement,
+    prepare_time,
+    parameter_types
+from pg_prepared_statements
+where name != ''
+        """,
+        prepare=False,
+    )
+    return await cur.fetchall()