]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(tests/crdb): generate sync from async
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 10 Jun 2024 14:24:21 +0000 (16:24 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 12 Jun 2024 13:26:06 +0000 (15:26 +0200)
tests/acompat.py
tests/crdb/test_connection.py
tests/crdb/test_connection_async.py
tests/crdb/test_copy.py
tests/crdb/test_copy_async.py
tests/crdb/test_cursor.py
tests/crdb/test_cursor_async.py
tools/async_to_sync.py

index 41c2f40b4712ea22dd6208031e97d427fdfef3e7..868b682ec0e19d7e34ee748f46d7a40388840876 100644 (file)
@@ -6,8 +6,11 @@ script async_to_sync.py will replace the async names with the sync names
 when generating the sync version.
 """
 
+from __future__ import annotations
+
 import sys
 import time
+import queue
 import asyncio
 import inspect
 import builtins
@@ -113,3 +116,19 @@ class AEvent(asyncio.Event):
 
     async def wait_timeout(self, timeout):
         await asyncio.wait_for(self.wait(), timeout)
+
+
+class Queue(queue.Queue):  # type: ignore[type-arg]  # can be dropped after Python 3.8
+    """
+    A Queue subclass with an interruptible get() method.
+    """
+
+    def get(self, block: bool = True, timeout: float | None = None) -> Any:
+        # Always specify a timeout to make the wait interruptible.
+        if timeout is None:
+            timeout = 24.0 * 60.0 * 60.0
+        return super().get(block=block, timeout=timeout)
+
+
+class AQueue(asyncio.Queue):  # type: ignore[type-arg]
+    pass
index c45c801e8db1750b2ba71b766ada8abb50b60446..87a0bcf7f412db893fb1a5963e1b6f84adb02db8 100644 (file)
@@ -1,13 +1,17 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_connection_async.py'
+# DO NOT CHANGE! Change the original file instead.
 import time
-import threading
+
+import pytest
 
 import psycopg.crdb
 from psycopg import errors as e
 from psycopg.crdb import CrdbConnection
 
-import pytest
+from ..acompat import sleep, spawn, gather
 
-pytestmark = pytest.mark.crdb
+pytestmark = [pytest.mark.crdb]
 
 
 def test_is_crdb(conn):
@@ -44,7 +48,8 @@ def test_tpc_recover(dsn):
 @pytest.mark.slow
 def test_broken_connection(conn):
     cur = conn.cursor()
-    (session_id,) = cur.execute("select session_id from [show session_id]").fetchone()
+    cur.execute("select session_id from [show session_id]")
+    (session_id,) = cur.fetchone()
     with pytest.raises(psycopg.DatabaseError):
         cur.execute("cancel session %s", [session_id])
     assert conn.closed
@@ -52,7 +57,8 @@ def test_broken_connection(conn):
 
 @pytest.mark.slow
 def test_broken(conn):
-    (session_id,) = conn.execute("show session_id").fetchone()
+    cur = conn.execute("show session_id")
+    (session_id,) = cur.fetchone()
     with pytest.raises(psycopg.OperationalError):
         conn.execute("cancel session %s", [session_id])
 
@@ -68,14 +74,14 @@ def test_broken(conn):
 def test_identify_closure(conn_cls, dsn):
     with conn_cls.connect(dsn, autocommit=True) as conn:
         with conn_cls.connect(dsn, autocommit=True) as conn2:
-            (session_id,) = conn.execute("show session_id").fetchone()
+            cur = conn.execute("show session_id")
+            (session_id,) = cur.fetchone()
 
             def closer():
-                time.sleep(0.2)
+                sleep(0.2)
                 conn2.execute("cancel session %s", [session_id])
 
-            t = threading.Thread(target=closer)
-            t.start()
+            t = spawn(closer)
             t0 = time.time()
             try:
                 with pytest.raises(psycopg.OperationalError):
@@ -84,4 +90,4 @@ def test_identify_closure(conn_cls, dsn):
                 # CRDB seems to take not less than 1s
                 assert 0.2 < dt < 2
             finally:
-                t.join()
+                gather(t)
index c28e73f581b3b9b0e10adf246550fd0367a4eec3..6de10fe52b855c52d6716c81c650ef04fe4b29e6 100644 (file)
@@ -1,13 +1,16 @@
 import time
-import asyncio
+
+import pytest
 
 import psycopg.crdb
 from psycopg import errors as e
 from psycopg.crdb import AsyncCrdbConnection
 
-import pytest
+from ..acompat import asleep, spawn, gather
 
-pytestmark = [pytest.mark.crdb, pytest.mark.anyio]
+pytestmark = [pytest.mark.crdb]
+if True:  # ASYNC
+    pytestmark.append(pytest.mark.anyio)
 
 
 async def test_is_crdb(aconn):
@@ -17,7 +20,11 @@ async def test_is_crdb(aconn):
 
 async def test_connect(dsn):
     async with await AsyncCrdbConnection.connect(dsn) as conn:
-        assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection)
+        assert isinstance(conn, AsyncCrdbConnection)
+
+    if False:  # ASYNC
+        with psycopg.crdb.connect(dsn) as conn:
+            assert isinstance(conn, AsyncCrdbConnection)
 
 
 async def test_xid(dsn):
@@ -65,21 +72,22 @@ async def test_broken(aconn):
 @pytest.mark.slow
 @pytest.mark.timing
 async def test_identify_closure(aconn_cls, dsn):
-    async with await aconn_cls.connect(dsn) as conn:
-        async with await aconn_cls.connect(dsn) as conn2:
+    async with await aconn_cls.connect(dsn, autocommit=True) as conn:
+        async with await aconn_cls.connect(dsn, autocommit=True) as conn2:
             cur = await conn.execute("show session_id")
             (session_id,) = await cur.fetchone()
 
             async def closer():
-                await asyncio.sleep(0.2)
+                await asleep(0.2)
                 await conn2.execute("cancel session %s", [session_id])
 
-            t = asyncio.create_task(closer())
+            t = spawn(closer)
             t0 = time.time()
             try:
                 with pytest.raises(psycopg.OperationalError):
                     await conn.execute("select pg_sleep(3.0)")
                 dt = time.time() - t0
+                # CRDB seems to take not less than 1s
                 assert 0.2 < dt < 2
             finally:
-                await asyncio.gather(t)
+                await gather(t)
index 4100c33f605a6478d93111c9c822c23e43187849..8c09fea84af0dedbaf31db3ecb1d38a6ee485623 100644 (file)
@@ -1,6 +1,10 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_copy_async.py'
+# DO NOT CHANGE! Change the original file instead.
 import pytest
 import string
 from random import randrange, choice
+from typing import Any  # noqa: ignore
 
 from psycopg import sql, errors as e
 from psycopg.pq import Format
@@ -15,12 +19,11 @@ from .._test_copy import sample_tabledef as sample_tabledef_pg
 # CRDB int/serial are int8
 sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4")
 
-pytestmark = pytest.mark.crdb
+pytestmark = [pytest.mark.crdb]
 
 
 @pytest.mark.parametrize(
-    "format, buffer",
-    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
+    "format, buffer", [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")]
 )
 def test_copy_in_buffers(conn, format, buffer):
     cur = conn.cursor()
@@ -28,7 +31,8 @@ def test_copy_in_buffers(conn, format, buffer):
     with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
         copy.write(globals()[buffer])
 
-    data = cur.execute("select * from copy_in order by 1").fetchall()
+    cur.execute("select * from copy_in order by 1")
+    data = cur.fetchall()
     assert data == sample_records
 
 
@@ -48,7 +52,8 @@ def test_copy_in_str(conn):
     with cur.copy("copy copy_in from stdin") as copy:
         copy.write(sample_text.decode())
 
-    data = cur.execute("select * from copy_in order by 1").fetchall()
+    cur.execute("select * from copy_in order by 1")
+    data = cur.fetchall()
     assert data == sample_records
 
 
@@ -78,7 +83,7 @@ def test_copy_in_empty(conn, format):
 def test_copy_big_size_record(conn):
     cur = conn.cursor()
     ensure_table(cur, "id serial primary key, data text")
-    data = "".join(chr(randrange(1, 256)) for i in range(10 * 1024 * 1024))
+    data = "".join((chr(randrange(1, 256)) for i in range(10 * 1024 * 1024)))
     with cur.copy("copy copy_in (data) from stdin") as copy:
         copy.write_row([data])
 
@@ -90,7 +95,7 @@ def test_copy_big_size_record(conn):
 def test_copy_big_size_block(conn):
     cur = conn.cursor()
     ensure_table(cur, "id serial primary key, data text")
-    data = "".join(choice(string.ascii_letters) for i in range(10 * 1024 * 1024))
+    data = "".join((choice(string.ascii_letters) for i in range(10 * 1024 * 1024)))
     copy_data = data + "\n"
     with cur.copy("copy copy_in (data) from stdin") as copy:
         copy.write(copy_data)
@@ -116,14 +121,14 @@ def test_copy_in_records(conn, format):
     ensure_table(cur, sample_tabledef)
 
     with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+        row: "tuple[Any, ...]"
         for row in sample_records:
             if format == Format.BINARY:
-                row = tuple(
-                    Int4(i) if isinstance(i, int) else i for i in row
-                )  # type: ignore[assignment]
+                row = tuple((Int4(i) if isinstance(i, int) else i for i in row))
             copy.write_row(row)
 
-    data = cur.execute("select * from copy_in order by 1").fetchall()
+    cur.execute("select * from copy_in order by 1")
+    data = cur.fetchall()
     assert data == sample_records
 
 
@@ -137,7 +142,8 @@ def test_copy_in_records_set_types(conn, format):
         for row in sample_records:
             copy.write_row(row)
 
-    data = cur.execute("select * from copy_in order by 1").fetchall()
+    cur.execute("select * from copy_in order by 1")
+    data = cur.fetchall()
     assert data == sample_records
 
 
@@ -150,7 +156,8 @@ def test_copy_in_records_binary(conn, format):
         for row in sample_records:
             copy.write_row((None, row[2]))
 
-    data = cur.execute("select col2, data from copy_in order by 2").fetchall()
+    cur.execute("select col2, data from copy_in order by 2")
+    data = cur.fetchall()
     assert data == [(None, "hello"), (None, "world")]
 
 
@@ -176,19 +183,19 @@ def test_copy_in_allchars(conn):
             copy.write_row((i, None, chr(i)))
         copy.write_row((ord(eur), None, eur))
 
-    data = cur.execute(
+    cur.execute(
         """
 select col1 = ascii(data), col2 is null, length(data), count(*)
 from copy_in group by 1, 2, 3
 """
-    ).fetchall()
+    )
+    data = cur.fetchall()
     assert data == [(True, True, 1, 256)]
 
 
 @pytest.mark.slow
 @pytest.mark.parametrize(
-    "fmt, set_types",
-    [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
+    "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)]
 )
 @pytest.mark.crdb_skip("copy array")
 def test_copy_from_leaks(conn_cls, dsn, faker, fmt, set_types, gc):
index 17a37a95fa1517fdb3875df81a997818c0826568..1199e3a7b58a3fa6d779201a54ca7cda6f688372 100644 (file)
@@ -1,18 +1,24 @@
 import pytest
 import string
 from random import randrange, choice
+from typing import Any  # noqa: ignore
 
-from psycopg.pq import Format
 from psycopg import sql, errors as e
+from psycopg.pq import Format
 from psycopg.adapt import PyFormat
 from psycopg.types.numeric import Int4
 
 from ..utils import eur
 from .._test_copy import sample_text, sample_binary  # noqa
 from .._test_copy import ensure_table_async, sample_records
-from .test_copy import sample_tabledef, copyopt
+from .._test_copy import sample_tabledef as sample_tabledef_pg
+
+# CRDB int/serial are int8
+sample_tabledef = sample_tabledef_pg.replace("int", "int4").replace("serial", "int4")
 
-pytestmark = [pytest.mark.crdb, pytest.mark.anyio]
+pytestmark = [pytest.mark.crdb]
+if True:  # ASYNC
+    pytestmark.append(pytest.mark.anyio)
 
 
 @pytest.mark.parametrize(
@@ -115,11 +121,10 @@ async def test_copy_in_records(aconn, format):
     await ensure_table_async(cur, sample_tabledef)
 
     async with cur.copy(f"copy copy_in from stdin {copyopt(format)}") as copy:
+        row: "tuple[Any, ...]"
         for row in sample_records:
             if format == Format.BINARY:
-                row = tuple(
-                    Int4(i) if isinstance(i, int) else i for i in row
-                )  # type: ignore[assignment]
+                row = tuple(Int4(i) if isinstance(i, int) else i for i in row)
             await copy.write_row(row)
 
     await cur.execute("select * from copy_in order by 1")
@@ -232,3 +237,7 @@ async def test_copy_from_leaks(aconn_cls, dsn, faker, fmt, set_types, gc):
         n.append(gc.count())
 
     assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+
+
+def copyopt(format):
+    return "with binary" if format == Format.BINARY else ""
index d3c10e52454e7f25b0ff6808a06dd8c35701e203..3ba0a9f7c142c34e436d255105ed81b48aedd34f 100644 (file)
@@ -1,14 +1,17 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_cursor_async.py'
+# DO NOT CHANGE! Change the original file instead.
+from __future__ import annotations
+
 import json
-import threading
 from uuid import uuid4
-from queue import Queue
-from typing import Any
 
 import pytest
 from psycopg import pq, errors as e
 from psycopg.rows import namedtuple_row
+from ..acompat import Queue, spawn, gather
 
-pytestmark = pytest.mark.crdb
+pytestmark = [pytest.mark.crdb]
 
 
 @pytest.fixture
@@ -23,8 +26,8 @@ def testfeed(svcconn):
 @pytest.mark.slow
 @pytest.mark.parametrize("fmt_out", pq.Format)
 def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out):
-    conn.autocommit = True
-    q: "Queue[Any]" = Queue()
+    conn.set_autocommit(True)
+    q = Queue()
 
     def worker():
         try:
@@ -32,26 +35,25 @@ def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out):
                 cur = conn.cursor(binary=fmt_out, row_factory=namedtuple_row)
                 try:
                     for row in cur.stream(f"experimental changefeed for {testfeed}"):
-                        q.put(row)
+                        q.put_nowait(row)
                 except e.QueryCanceled:
                     assert conn.info.transaction_status == conn.TransactionStatus.IDLE
-                    q.put(None)
+                    q.put_nowait(None)
         except Exception as ex:
-            q.put(ex)
+            q.put_nowait(ex)
 
-    t = threading.Thread(target=worker)
-    t.start()
+    t = spawn(worker)
 
     cur = conn.cursor()
     cur.execute(f"insert into {testfeed} (data) values ('hello') returning id")
     (key,) = cur.fetchone()
-    row = q.get(timeout=1)
+    row = q.get()
     assert row.table == testfeed
     assert json.loads(row.key) == [key]
     assert json.loads(row.value)["after"] == {"id": key, "data": "hello"}
 
     cur.execute(f"delete from {testfeed} where id = %s", [key])
-    row = q.get(timeout=1)
+    row = q.get()
     assert row.table == testfeed
     assert json.loads(row.key) == [key]
     assert json.loads(row.value)["after"] is None
@@ -64,11 +66,11 @@ def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out):
     # We often find the record with {"after": null} at least another time
     # in the queue. Let's tolerate an extra one.
     for i in range(2):
-        row = q.get(timeout=1)
+        row = q.get()
         if row is None:
             break
         assert json.loads(row.value)["after"] is None, json
     else:
         pytest.fail("keep on receiving messages")
 
-    t.join()
+    gather(t)
index c643a40c3c138f0614b3b25e2c46c072fa57b49a..c7848160146f1dd4fe9c2a7ef0b1a4c4c4fea278 100644 (file)
@@ -1,24 +1,32 @@
+from __future__ import annotations
+
 import json
-import asyncio
-from typing import Any
-from asyncio.queues import Queue
+from uuid import uuid4
 
 import pytest
 from psycopg import pq, errors as e
 from psycopg.rows import namedtuple_row
+from ..acompat import AQueue, spawn, gather
 
-from .test_cursor import testfeed
+pytestmark = [pytest.mark.crdb]
+if True:  # ASYNC
+    pytestmark.append(pytest.mark.anyio)
 
-testfeed  # fixture
 
-pytestmark = [pytest.mark.crdb, pytest.mark.anyio]
+@pytest.fixture
+def testfeed(svcconn):
+    name = f"test_feed_{str(uuid4()).replace('-', '')}"
+    svcconn.execute("set cluster setting kv.rangefeed.enabled to true")
+    svcconn.execute(f"create table {name} (id serial primary key, data text)")
+    yield name
+    svcconn.execute(f"drop table {name}")
 
 
 @pytest.mark.slow
 @pytest.mark.parametrize("fmt_out", pq.Format)
 async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out):
     await aconn.set_autocommit(True)
-    q: "Queue[Any]" = Queue()
+    q = AQueue()
 
     async def worker():
         try:
@@ -35,18 +43,18 @@ async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out):
         except Exception as ex:
             q.put_nowait(ex)
 
-    t = asyncio.create_task(worker())
+    t = spawn(worker)
 
     cur = aconn.cursor()
     await cur.execute(f"insert into {testfeed} (data) values ('hello') returning id")
     (key,) = await cur.fetchone()
-    row = await asyncio.wait_for(q.get(), 1.0)
+    row = await q.get()
     assert row.table == testfeed
     assert json.loads(row.key) == [key]
     assert json.loads(row.value)["after"] == {"id": key, "data": "hello"}
 
     await cur.execute(f"delete from {testfeed} where id = %s", [key])
-    row = await asyncio.wait_for(q.get(), 1.0)
+    row = await q.get()
     assert row.table == testfeed
     assert json.loads(row.key) == [key]
     assert json.loads(row.value)["after"] is None
@@ -59,11 +67,11 @@ async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out):
     # We often find the record with {"after": null} at least another time
     # in the queue. Let's tolerate an extra one.
     for i in range(2):
-        row = await asyncio.wait_for(q.get(), 1.0)
+        row = await q.get()
         if row is None:
             break
         assert json.loads(row.value)["after"] is None, json
     else:
         pytest.fail("keep on receiving messages")
 
-    await asyncio.gather(t)
+    await gather(t)
index ba6fd94be48c16c71a85634d4a65b371fc435464..aef571c7d0cb6d00c4b697ca8ff4a17afdb58082 100755 (executable)
@@ -37,6 +37,9 @@ ALL_INPUTS = """
     psycopg_pool/psycopg_pool/null_pool_async.py
     psycopg_pool/psycopg_pool/pool_async.py
     psycopg_pool/psycopg_pool/sched_async.py
+    tests/crdb/test_connection_async.py
+    tests/crdb/test_copy_async.py
+    tests/crdb/test_cursor_async.py
     tests/pool/test_pool_async.py
     tests/pool/test_pool_common_async.py
     tests/pool/test_pool_null_async.py
@@ -277,6 +280,7 @@ class RenameAsyncToSync(ast.NodeTransformer):  # type: ignore
         "AsyncConnectionPool": "ConnectionPool",
         "AsyncCopy": "Copy",
         "AsyncCopyWriter": "CopyWriter",
+        "AsyncCrdbConnection": "CrdbConnection",
         "AsyncCursor": "Cursor",
         "AsyncFileWriter": "FileWriter",
         "AsyncGenerator": "Generator",