From 3587beb04cb61bb93a78a9afd4b581072c03f2bc Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Mon, 10 Jun 2024 16:24:21 +0200 Subject: [PATCH] refactor(tests/crdb): generate sync from async --- tests/acompat.py | 19 +++++++++++++ tests/crdb/test_connection.py | 26 +++++++++++------- tests/crdb/test_connection_async.py | 26 +++++++++++------- tests/crdb/test_copy.py | 41 +++++++++++++++++------------ tests/crdb/test_copy_async.py | 21 ++++++++++----- tests/crdb/test_cursor.py | 32 +++++++++++----------- tests/crdb/test_cursor_async.py | 32 +++++++++++++--------- tools/async_to_sync.py | 4 +++ 8 files changed, 132 insertions(+), 69 deletions(-) diff --git a/tests/acompat.py b/tests/acompat.py index 41c2f40b4..868b682ec 100644 --- a/tests/acompat.py +++ b/tests/acompat.py @@ -6,8 +6,11 @@ script async_to_sync.py will replace the async names with the sync names when generating the sync version. """ +from __future__ import annotations + import sys import time +import queue import asyncio import inspect import builtins @@ -113,3 +116,19 @@ class AEvent(asyncio.Event): async def wait_timeout(self, timeout): await asyncio.wait_for(self.wait(), timeout) + + +class Queue(queue.Queue): # type: ignore[type-arg] # can be dropped after Python 3.8 + """ + A Queue subclass with an interruptible get() method. + """ + + def get(self, block: bool = True, timeout: float | None = None) -> Any: + # Always specify a timeout to make the wait interruptible. + if timeout is None: + timeout = 24.0 * 60.0 * 60.0 + return super().get(block=block, timeout=timeout) + + +class AQueue(asyncio.Queue): # type: ignore[type-arg] + pass diff --git a/tests/crdb/test_connection.py b/tests/crdb/test_connection.py index c45c801e8..87a0bcf7f 100644 --- a/tests/crdb/test_connection.py +++ b/tests/crdb/test_connection.py @@ -1,13 +1,17 @@ +# WARNING: this file is auto-generated by 'async_to_sync.py' +# from the original file 'test_connection_async.py' +# DO NOT CHANGE! Change the original file instead. import time -import threading + +import pytest import psycopg.crdb from psycopg import errors as e from psycopg.crdb import CrdbConnection -import pytest +from ..acompat import sleep, spawn, gather -pytestmark = pytest.mark.crdb +pytestmark = [pytest.mark.crdb] def test_is_crdb(conn): @@ -44,7 +48,8 @@ def test_tpc_recover(dsn): @pytest.mark.slow def test_broken_connection(conn): cur = conn.cursor() - (session_id,) = cur.execute("select session_id from [show session_id]").fetchone() + cur.execute("select session_id from [show session_id]") + (session_id,) = cur.fetchone() with pytest.raises(psycopg.DatabaseError): cur.execute("cancel session %s", [session_id]) assert conn.closed @@ -52,7 +57,8 @@ def test_broken_connection(conn): @pytest.mark.slow def test_broken(conn): - (session_id,) = conn.execute("show session_id").fetchone() + cur = conn.execute("show session_id") + (session_id,) = cur.fetchone() with pytest.raises(psycopg.OperationalError): conn.execute("cancel session %s", [session_id]) @@ -68,14 +74,14 @@ def test_broken(conn): def test_identify_closure(conn_cls, dsn): with conn_cls.connect(dsn, autocommit=True) as conn: with conn_cls.connect(dsn, autocommit=True) as conn2: - (session_id,) = conn.execute("show session_id").fetchone() + cur = conn.execute("show session_id") + (session_id,) = cur.fetchone() def closer(): - time.sleep(0.2) + sleep(0.2) conn2.execute("cancel session %s", [session_id]) - t = threading.Thread(target=closer) - t.start() + t = spawn(closer) t0 = time.time() try: with pytest.raises(psycopg.OperationalError): @@ -84,4 +90,4 @@ def test_identify_closure(conn_cls, dsn): # CRDB seems to take not less than 1s assert 0.2 < dt < 2 finally: - t.join() + gather(t) diff --git a/tests/crdb/test_connection_async.py b/tests/crdb/test_connection_async.py index c28e73f58..6de10fe52 100644 --- a/tests/crdb/test_connection_async.py +++ b/tests/crdb/test_connection_async.py @@ -1,13 +1,16 @@ import time -import asyncio + +import pytest import psycopg.crdb from psycopg import errors as e from psycopg.crdb import AsyncCrdbConnection -import pytest +from ..acompat import asleep, spawn, gather -pytestmark = [pytest.mark.crdb, pytest.mark.anyio] +pytestmark = [pytest.mark.crdb] +if True: # ASYNC + pytestmark.append(pytest.mark.anyio) async def test_is_crdb(aconn): @@ -17,7 +20,11 @@ async def test_is_crdb(aconn): async def test_connect(dsn): async with await AsyncCrdbConnection.connect(dsn) as conn: - assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection) + assert isinstance(conn, AsyncCrdbConnection) + + if False: # ASYNC + with psycopg.crdb.connect(dsn) as conn: + assert isinstance(conn, AsyncCrdbConnection) async def test_xid(dsn): @@ -65,21 +72,22 @@ async def test_broken(aconn): @pytest.mark.slow @pytest.mark.timing async def test_identify_closure(aconn_cls, dsn): - async with await aconn_cls.connect(dsn) as conn: - async with await aconn_cls.connect(dsn) as conn2: + async with await aconn_cls.connect(dsn, autocommit=True) as conn: + async with await aconn_cls.connect(dsn, autocommit=True) as conn2: cur = await conn.execute("show session_id") (session_id,) = await cur.fetchone() async def closer(): - await asyncio.sleep(0.2) + await asleep(0.2) await conn2.execute("cancel session %s", [session_id]) - t = asyncio.create_task(closer()) + t = spawn(closer) t0 = time.time() try: with pytest.raises(psycopg.OperationalError): await conn.execute("select pg_sleep(3.0)") dt = time.time() - t0 + # CRDB seems to take not less than 1s assert 0.2 < dt < 2 finally: - await asyncio.gather(t) + await gather(t) diff --git a/tests/crdb/test_copy.py b/tests/crdb/test_copy.py index 4100c33f6..8c09fea84 100644 --- a/tests/crdb/test_copy.py +++ b/tests/crdb/test_copy.py @@ -1,6 +1,10 @@ +# WARNING: this file is auto-generated by 'async_to_sync.py' +# from the original file 'test_copy_async.py' +# DO NOT CHANGE! Change the original file instead. import pytest import string from random import randrange, choice +from typing import Any # noqa: ignore from psycopg import sql, errors as e from psycopg.pq import Format @@ -15,12 +19,11 @@ from .._test_copy import sample_tabledef as sample_tabledef_pg # CRDB int/serial are int8 sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4") -pytestmark = pytest.mark.crdb +pytestmark = [pytest.mark.crdb] @pytest.mark.parametrize( - "format, buffer", - [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], + "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")] ) def test_copy_in_buffers(conn, format, buffer): cur = conn.cursor() @@ -28,7 +31,8 @@ def test_copy_in_buffers(conn, format, buffer): with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: copy.write(globals()[buffer]) - data = cur.execute("select * from copy_in order by 1").fetchall() + cur.execute("select * from copy_in order by 1") + data = cur.fetchall() assert data == sample_records @@ -48,7 +52,8 @@ def test_copy_in_str(conn): with cur.copy("copy copy_in from stdin") as copy: copy.write(sample_text.decode()) - data = cur.execute("select * from copy_in order by 1").fetchall() + cur.execute("select * from copy_in order by 1") + data = cur.fetchall() assert data == sample_records @@ -78,7 +83,7 @@ def test_copy_in_empty(conn, format): def test_copy_big_size_record(conn): cur = conn.cursor() ensure_table(cur, "id serial primary key, data text") - data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)) + data = "".join((chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))) with cur.copy("copy copy_in (data) from stdin") as copy: copy.write_row([data]) @@ -90,7 +95,7 @@ def test_copy_big_size_record(conn): def test_copy_big_size_block(conn): cur = conn.cursor() ensure_table(cur, "id serial primary key, data text") - data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024)) + data = "".join((choice(string.ascii_letters) for i in range(10 * 1024 * 1024))) copy_data = data + "\n" with cur.copy("copy copy_in (data) from stdin") as copy: copy.write(copy_data) @@ -116,14 +121,14 @@ def test_copy_in_records(conn, format): ensure_table(cur, sample_tabledef) with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + row: "tuple[Any, ...]" for row in sample_records: if format == Format.BINARY: - row = tuple( - Int4(i) if isinstance(i, int) else i for i in row - ) # type: ignore[assignment] + row = tuple((Int4(i) if isinstance(i, int) else i for i in row)) copy.write_row(row) - data = cur.execute("select * from copy_in order by 1").fetchall() + cur.execute("select * from copy_in order by 1") + data = cur.fetchall() assert data == sample_records @@ -137,7 +142,8 @@ def test_copy_in_records_set_types(conn, format): for row in sample_records: copy.write_row(row) - data = cur.execute("select * from copy_in order by 1").fetchall() + cur.execute("select * from copy_in order by 1") + data = cur.fetchall() assert data == sample_records @@ -150,7 +156,8 @@ def test_copy_in_records_binary(conn, format): for row in sample_records: copy.write_row((None, row[2])) - data = cur.execute("select col2, data from copy_in order by 2").fetchall() + cur.execute("select col2, data from copy_in order by 2") + data = cur.fetchall() assert data == [(None, "hello"), (None, "world")] @@ -176,19 +183,19 @@ def test_copy_in_allchars(conn): copy.write_row((i, None, chr(i))) copy.write_row((ord(eur), None, eur)) - data = cur.execute( + cur.execute( """ select col1 = ascii(data), col2 is null, length(data), count(*) from copy_in group by 1, 2, 3 """ - ).fetchall() + ) + data = cur.fetchall() assert data == [(True, True, 1, 256)] @pytest.mark.slow @pytest.mark.parametrize( - "fmt, set_types", - [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], + "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)] ) @pytest.mark.crdb_skip("copy array") def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types, gc): diff --git a/tests/crdb/test_copy_async.py b/tests/crdb/test_copy_async.py index 17a37a95f..1199e3a7b 100644 --- a/tests/crdb/test_copy_async.py +++ b/tests/crdb/test_copy_async.py @@ -1,18 +1,24 @@ import pytest import string from random import randrange, choice +from typing import Any # noqa: ignore -from psycopg.pq import Format from psycopg import sql, errors as e +from psycopg.pq import Format from psycopg.adapt import PyFormat from psycopg.types.numeric import Int4 from ..utils import eur from .._test_copy import sample_text, sample_binary # noqa from .._test_copy import ensure_table_async, sample_records -from .test_copy import sample_tabledef, copyopt +from .._test_copy import sample_tabledef as sample_tabledef_pg + +# CRDB int/serial are int8 +sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4") -pytestmark = [pytest.mark.crdb, pytest.mark.anyio] +pytestmark = [pytest.mark.crdb] +if True: # ASYNC + pytestmark.append(pytest.mark.anyio) @pytest.mark.parametrize( @@ -115,11 +121,10 @@ async def test_copy_in_records(aconn, format): await ensure_table_async(cur, sample_tabledef) async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy: + row: "tuple[Any, ...]" for row in sample_records: if format == Format.BINARY: - row = tuple( - Int4(i) if isinstance(i, int) else i for i in row - ) # type: ignore[assignment] + row = tuple(Int4(i) if isinstance(i, int) else i for i in row) await copy.write_row(row) await cur.execute("select * from copy_in order by 1") @@ -232,3 +237,7 @@ async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types, gc): n.append(gc.count()) assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +def copyopt(format): + return "with binary" if format == Format.BINARY else "" diff --git a/tests/crdb/test_cursor.py b/tests/crdb/test_cursor.py index d3c10e524..3ba0a9f7c 100644 --- a/tests/crdb/test_cursor.py +++ b/tests/crdb/test_cursor.py @@ -1,14 +1,17 @@ +# WARNING: this file is auto-generated by 'async_to_sync.py' +# from the original file 'test_cursor_async.py' +# DO NOT CHANGE! Change the original file instead. +from __future__ import annotations + import json -import threading from uuid import uuid4 -from queue import Queue -from typing import Any import pytest from psycopg import pq, errors as e from psycopg.rows import namedtuple_row +from ..acompat import Queue, spawn, gather -pytestmark = pytest.mark.crdb +pytestmark = [pytest.mark.crdb] @pytest.fixture @@ -23,8 +26,8 @@ def testfeed(svcconn): @pytest.mark.slow @pytest.mark.parametrize("fmt_out", pq.Format) def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out): - conn.autocommit = True - q: "Queue[Any]" = Queue() + conn.set_autocommit(True) + q = Queue() def worker(): try: @@ -32,26 +35,25 @@ def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out): cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row) try: for row in cur.stream(f"experimental changefeed for {testfeed}"): - q.put(row) + q.put_nowait(row) except e.QueryCanceled: assert conn.info.transaction_status == conn.TransactionStatus.IDLE - q.put(None) + q.put_nowait(None) except Exception as ex: - q.put(ex) + q.put_nowait(ex) - t = threading.Thread(target=worker) - t.start() + t = spawn(worker) cur = conn.cursor() cur.execute(f"insert into {testfeed} (data) values ('hello') returning id") (key,) = cur.fetchone() - row = q.get(timeout=1) + row = q.get() assert row.table == testfeed assert json.loads(row.key) == [key] assert json.loads(row.value)["after"] == {"id": key, "data": "hello"} cur.execute(f"delete from {testfeed} where id = %s", [key]) - row = q.get(timeout=1) + row = q.get() assert row.table == testfeed assert json.loads(row.key) == [key] assert json.loads(row.value)["after"] is None @@ -64,11 +66,11 @@ def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out): # We often find the record with {"after": null} at least another time # in the queue. Let's tolerate an extra one. for i in range(2): - row = q.get(timeout=1) + row = q.get() if row is None: break assert json.loads(row.value)["after"] is None, json else: pytest.fail("keep on receiving messages") - t.join() + gather(t) diff --git a/tests/crdb/test_cursor_async.py b/tests/crdb/test_cursor_async.py index c643a40c3..c78481601 100644 --- a/tests/crdb/test_cursor_async.py +++ b/tests/crdb/test_cursor_async.py @@ -1,24 +1,32 @@ +from __future__ import annotations + import json -import asyncio -from typing import Any -from asyncio.queues import Queue +from uuid import uuid4 import pytest from psycopg import pq, errors as e from psycopg.rows import namedtuple_row +from ..acompat import AQueue, spawn, gather -from .test_cursor import testfeed +pytestmark = [pytest.mark.crdb] +if True: # ASYNC + pytestmark.append(pytest.mark.anyio) -testfeed # fixture -pytestmark = [pytest.mark.crdb, pytest.mark.anyio] +@pytest.fixture +def testfeed(svcconn): + name = f"test_feed_{str(uuid4()).replace('-', '')}" + svcconn.execute("set cluster setting kv.rangefeed.enabled to true") + svcconn.execute(f"create table {name} (id serial primary key, data text)") + yield name + svcconn.execute(f"drop table {name}") @pytest.mark.slow @pytest.mark.parametrize("fmt_out", pq.Format) async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out): await aconn.set_autocommit(True) - q: "Queue[Any]" = Queue() + q = AQueue() async def worker(): try: @@ -35,18 +43,18 @@ async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out): except Exception as ex: q.put_nowait(ex) - t = asyncio.create_task(worker()) + t = spawn(worker) cur = aconn.cursor() await cur.execute(f"insert into {testfeed} (data) values ('hello') returning id") (key,) = await cur.fetchone() - row = await asyncio.wait_for(q.get(), 1.0) + row = await q.get() assert row.table == testfeed assert json.loads(row.key) == [key] assert json.loads(row.value)["after"] == {"id": key, "data": "hello"} await cur.execute(f"delete from {testfeed} where id = %s", [key]) - row = await asyncio.wait_for(q.get(), 1.0) + row = await q.get() assert row.table == testfeed assert json.loads(row.key) == [key] assert json.loads(row.value)["after"] is None @@ -59,11 +67,11 @@ async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out): # We often find the record with {"after": null} at least another time # in the queue. Let's tolerate an extra one. for i in range(2): - row = await asyncio.wait_for(q.get(), 1.0) + row = await q.get() if row is None: break assert json.loads(row.value)["after"] is None, json else: pytest.fail("keep on receiving messages") - await asyncio.gather(t) + await gather(t) diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index ba6fd94be..aef571c7d 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -37,6 +37,9 @@ ALL_INPUTS = """ psycopg_pool/psycopg_pool/null_pool_async.py psycopg_pool/psycopg_pool/pool_async.py psycopg_pool/psycopg_pool/sched_async.py + tests/crdb/test_connection_async.py + tests/crdb/test_copy_async.py + tests/crdb/test_cursor_async.py tests/pool/test_pool_async.py tests/pool/test_pool_common_async.py tests/pool/test_pool_null_async.py @@ -277,6 +280,7 @@ class RenameAsyncToSync(ast.NodeTransformer): # type: ignore "AsyncConnectionPool": "ConnectionPool", "AsyncCopy": "Copy", "AsyncCopyWriter": "CopyWriter", + "AsyncCrdbConnection": "CrdbConnection", "AsyncCursor": "Cursor", "AsyncFileWriter": "FileWriter", "AsyncGenerator": "Generator", -- 2.47.2