]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(test): add acompat module
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 5 Oct 2023 23:58:54 +0000 (01:58 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:39 +0000 (23:45 +0200)
19 files changed:
tests/acompat.py [new file with mode: 0644]
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tests/pool/test_pool_common.py
tests/pool/test_pool_common_async.py
tests/pool/test_pool_null.py
tests/pool/test_pool_null_async.py
tests/pool/test_sched.py
tests/pool/test_sched_async.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_copy_async.py
tests/test_cursor_common.py
tests/test_cursor_common_async.py
tests/test_cursor_server_async.py
tests/test_pipeline.py
tests/test_pipeline_async.py
tests/utils.py
tools/async_to_sync.py

diff --git a/tests/acompat.py b/tests/acompat.py
new file mode 100644 (file)
index 0000000..256e29f
--- /dev/null
@@ -0,0 +1,109 @@
+"""
+Utilities to ease the differences between async and sync code.
+
+These object offer a similar interface between sync and async versions; the
+script async_to_sync.py will replace the async names with the sync names
+when generating the sync version.
+"""
+
+import sys
+import time
+import asyncio
+import inspect
+import builtins
+import threading
+import contextlib
+from typing import Any
+
+# Re-exports
+sleep = time.sleep
+Event = threading.Event
+closing = contextlib.closing
+
+
+def is_async(obj):
+    """Return true if obj is an async object (class, instance, module name)"""
+    if isinstance(obj, str):
+        # coming from is_async(__name__)
+        return "async" in obj
+
+    if not isinstance(obj, type):
+        obj = type(obj)
+    return "Async" in obj.__name__
+
+
+if sys.version_info >= (3, 10):
+    anext = builtins.anext
+    aclosing = contextlib.aclosing
+
+else:
+
+    async def anext(it):
+        return await it.__anext__()
+
+    @contextlib.asynccontextmanager
+    async def aclosing(thing):
+        try:
+            yield thing
+        finally:
+            await thing.aclose()
+
+
+async def alist(it):
+    """Consume an async iterator into a list. Async equivalent of list(it)."""
+    return [i async for i in it]
+
+
+def spawn(f, args=None):
+    """
+    Equivalent to asyncio.create_task or creating and running a Thread.
+    """
+    if not args:
+        args = ()
+
+    if inspect.iscoroutinefunction(f):
+        return asyncio.create_task(f(*args))
+    else:
+        t = threading.Thread(target=f, args=args, daemon=True)
+        t.start()
+        return t
+
+
+def gather(*ts, return_exceptions=False, timeout=None):
+    """
+    Equivalent to asyncio.gather or Thread.join()
+    """
+    if ts and inspect.isawaitable(ts[0]):
+        rv: Any = asyncio.gather(*ts, return_exceptions=return_exceptions)
+        if timeout is None:
+            rv = asyncio.wait_for(rv, timeout)
+        return rv
+    else:
+        for t in ts:
+            t.join(timeout)
+            assert not t.is_alive()
+
+
+def asleep(s):
+    """
+    Equivalent to asyncio.sleep(), converted to time.sleep() by async_to_sync.
+    """
+    return asyncio.sleep(s)
+
+
+def is_alive(t):
+    """
+    Return true if an asyncio.Task or threading.Thread is alive.
+    """
+    return t.is_alive() if isinstance(t, threading.Thread) else not t.done()
+
+
+class AEvent(asyncio.Event):
+    """
+    Subclass of asyncio.Event adding a wait with timeout like threading.Event.
+
+    wait_timeout() is converted to wait() by async_to_sync.
+    """
+
+    async def wait_timeout(self, timeout):
+        await asyncio.wait_for(self.wait(), timeout)
index 10b9616b701944e8a93fd7f7a4dd5fc9ae5ef0ac..7d96d3f3d910999a29654d995f41784d865a80c6 100644 (file)
@@ -13,7 +13,7 @@ from psycopg.pq import TransactionStatus
 from psycopg.rows import class_row, Row, TupleRow
 from psycopg._compat import assert_type, Counter
 
-from ..utils import Event, spawn, gather, sleep, is_async
+from ..acompat import Event, spawn, gather, sleep, is_async
 from .test_pool_common import delay_connection
 
 try:
index 734a0ad1f8f642236e37f8f1579bc68587ede3ca..be018f2b6afa721b8773a10336f65ba655573eb2 100644 (file)
@@ -10,7 +10,7 @@ from psycopg.pq import TransactionStatus
 from psycopg.rows import class_row, Row, TupleRow
 from psycopg._compat import assert_type, Counter
 
-from ..utils import AEvent, spawn, gather, asleep, is_async
+from ..acompat import AEvent, spawn, gather, asleep, is_async
 from .test_pool_common_async import delay_connection
 
 try:
index 1ae013effcb354effa38241c2c6e234c7314293a..e80550c3233d7f19827fe1c206491bed8150de33 100644 (file)
@@ -9,7 +9,7 @@ import pytest
 
 import psycopg
 
-from ..utils import Event, spawn, gather, sleep, is_alive, is_async
+from ..acompat import Event, spawn, gather, sleep, is_alive, is_async
 
 try:
     import psycopg_pool as pool
index 61472e0f0c8a9925b199ea207bcbd15969623e2e..6d1b9fe1b836e41c3965e22d11c90de349711c9f 100644 (file)
@@ -6,7 +6,7 @@ import pytest
 
 import psycopg
 
-from ..utils import AEvent, spawn, gather, asleep, is_alive, is_async
+from ..acompat import AEvent, spawn, gather, asleep, is_alive, is_async
 
 try:
     import psycopg_pool as pool
index 7fdd055bdb89edde37548be6c50f9bb733c9aa79..3a28beac3064aed5086b3450c3c99631899aebde 100644 (file)
@@ -12,7 +12,7 @@ from psycopg.pq import TransactionStatus
 from psycopg.rows import class_row, Row, TupleRow
 from psycopg._compat import assert_type
 
-from ..utils import Event, sleep, spawn, gather, is_async
+from ..acompat import Event, sleep, spawn, gather, is_async
 from .test_pool_common import delay_connection, ensure_waiting
 
 try:
index 8525c356bbd303514a61387f183a5613a119aaa9..106437eb5a9f8ed5d250a42cfd958a6153d1a316 100644 (file)
@@ -9,7 +9,7 @@ from psycopg.pq import TransactionStatus
 from psycopg.rows import class_row, Row, TupleRow
 from psycopg._compat import assert_type
 
-from ..utils import AEvent, asleep, spawn, gather, is_async
+from ..acompat import AEvent, asleep, spawn, gather, is_async
 from .test_pool_common_async import delay_connection, ensure_waiting
 
 try:
index 2639dc1f8c7825f86d3be329bbd7fc25be50aa2e..8cfff6017b11a59397bc48c51380c15ba8951cdf 100644 (file)
@@ -8,7 +8,7 @@ from contextlib import contextmanager
 
 import pytest
 
-from ..utils import spawn, gather, sleep
+from ..acompat import spawn, gather, sleep
 
 try:
     from psycopg_pool.sched import Scheduler
index 23c7cefbd914a18b3ea305676455238b9feb8cd7..16366735d774504e1d4c95bc46d8b0667165c9e0 100644 (file)
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 
 import pytest
 
-from ..utils import spawn, gather, asleep
+from ..acompat import spawn, gather, asleep
 
 try:
     from psycopg_pool.sched_async import AsyncScheduler
index f995e5b0befbd3322e98339fa628ae70079b1f4c..dd27e6e396468aa22b59944bbbdd930a18da250c 100644 (file)
@@ -13,7 +13,8 @@ from psycopg import Notify, pq, errors as e
 from psycopg.rows import tuple_row
 from psycopg.conninfo import conninfo_to_dict, make_conninfo
 
-from .utils import gc_collect, is_async
+from .utils import gc_collect
+from .acompat import is_async
 from ._test_cursor import my_row_factory
 from ._test_connection import tx_params, tx_params_isolation, tx_values_map
 from ._test_connection import conninfo_params_timeout
index 7830cf9a212dc6fc27085811ec5bb378f9e04caf..b68e36a262111e4be2f6bc4496c11ccafcc0b94c 100644 (file)
@@ -10,7 +10,8 @@ from psycopg import Notify, pq, errors as e
 from psycopg.rows import tuple_row
 from psycopg.conninfo import conninfo_to_dict, make_conninfo
 
-from .utils import gc_collect, is_async
+from .utils import gc_collect
+from .acompat import is_async
 from ._test_cursor import my_row_factory
 from ._test_connection import tx_params, tx_params_isolation, tx_values_map
 from ._test_connection import conninfo_params_timeout
index 3435224738ffa63bd610feec526712d94ac4e0c5..c62d9f2d9c763b91692c7d8d251479f435b07c53 100644 (file)
@@ -17,7 +17,8 @@ from psycopg.types import TypeInfo
 from psycopg.types.hstore import register_hstore
 from psycopg.types.numeric import Int4
 
-from .utils import alist, eur, gc_collect, gc_count
+from .utils import eur, gc_collect, gc_count
+from .acompat import alist
 from ._test_copy import sample_text, sample_binary, sample_binary_rows  # noqa
 from ._test_copy import sample_values, sample_records, sample_tabledef
 from ._test_copy import ensure_table_async, py_to_raw, special_chars, AsyncFileWriter
index ccc6513d16ade73dade5d51c35ec72b612f7beae..a5090f92c8e1fa6298415e70166591abaee71c1a 100644 (file)
@@ -16,7 +16,8 @@ from psycopg import sql, rows
 from psycopg.adapt import PyFormat
 from psycopg.types import TypeInfo
 
-from .utils import gc_collect, raiseif, closing
+from .utils import gc_collect, raiseif
+from .acompat import closing
 from .fix_crdb import crdb_encoding
 from ._test_cursor import my_row_factory, ph
 from ._test_cursor import execmany, _execmany  # noqa: F401
index ad3673f45a44c24d776233ba867f07930c9b8a06..cad01cfbf7e5adc4d26b6fecfaefcb3eb34aec8b 100644 (file)
@@ -13,7 +13,8 @@ from psycopg import sql, rows
 from psycopg.adapt import PyFormat
 from psycopg.types import TypeInfo
 
-from .utils import gc_collect, raiseif, aclosing, alist, anext
+from .utils import gc_collect, raiseif
+from .acompat import aclosing, alist, anext
 from .fix_crdb import crdb_encoding
 from ._test_cursor import my_row_factory, ph
 from ._test_cursor import execmany, _execmany  # noqa: F401
index 26db29ef36ecf2f8e8b6d37b4a00c7fa0f3676a4..7317c52d8bdc78731fc1ed8478f17848b03f8569 100644 (file)
@@ -4,7 +4,7 @@ import psycopg
 from psycopg import rows, errors as e
 from psycopg.pq import Format
 
-from .utils import alist
+from .acompat import alist
 
 pytestmark = pytest.mark.crdb_skip("server-side cursor")
 
index da46dbc40e015e670bff2ac204add05ea2db9a3a..a40035a3313e8ed926f9d1f7facdfa0374498100 100644 (file)
@@ -12,7 +12,7 @@ import psycopg
 from psycopg import pq
 from psycopg import errors as e
 
-from .utils import is_async
+from .acompat import is_async
 
 pytestmark = [
     pytest.mark.pipeline,
index 5dcdc685a0563aa86a100a49c021f66e70f613d3..ee0efccd93987b3e2b672239aea9da817c57072d 100644 (file)
@@ -9,7 +9,7 @@ import psycopg
 from psycopg import pq
 from psycopg import errors as e
 
-from .utils import is_async, anext
+from .acompat import is_async, anext
 
 pytestmark = [
     pytest.mark.pipeline,
index 0603886724e8863f516acc313990407faac45842..c1dc86dd14c62d6198aae826bb6fda7d0b338501 100644 (file)
@@ -1,17 +1,9 @@
 import gc
 import re
 import sys
-import asyncio
-import inspect
 import operator
-from typing import Any, Callable, Optional, Tuple
-from threading import Thread
-from contextlib import contextmanager, asynccontextmanager
-
-# Re-exports
-from time import sleep as sleep  # noqa: F401
-from threading import Event as Event  # noqa: F401
-from contextlib import closing as closing  # noqa: F401
+from typing import Callable, Optional, Tuple
+from contextlib import contextmanager
 
 import pytest
 
@@ -152,17 +144,6 @@ def gc_collect():
         gc.collect()
 
 
-def is_async(obj):
-    """Return true if obj is an async object (class, instance, module name)"""
-    if isinstance(obj, str):
-        # coming from is_async(__name__)
-        return "async" in obj
-
-    if not isinstance(obj, type):
-        obj = type(obj)
-    return "Async" in obj.__name__
-
-
 NO_COUNT_TYPES: Tuple[type, ...] = ()
 
 if sys.version_info[:2] == (3, 10):
@@ -195,28 +176,6 @@ def gc_count() -> int:
     return rv
 
 
-async def alist(it):
-    """Consume an async iterator into a list. Async equivalent of list(it)."""
-    return [i async for i in it]
-
-
-if sys.version_info >= (3, 10):
-    from builtins import anext as anext
-    from contextlib import aclosing as aclosing
-
-else:
-
-    async def anext(it):
-        return await it.__anext__()
-
-    @asynccontextmanager
-    async def aclosing(thing):
-        try:
-            yield thing
-        finally:
-            await thing.aclose()
-
-
 @contextmanager
 def raiseif(cond, *args, **kwargs):
     """
@@ -233,58 +192,3 @@ def raiseif(cond, *args, **kwargs):
         with pytest.raises(*args, **kwargs) as ex:
             yield ex
         return
-
-
-def spawn(f, args=None):
-    """
-    Equivalent to asyncio.create_task or creating and running a Thread.
-    """
-    if not args:
-        args = ()
-
-    if inspect.iscoroutinefunction(f):
-        return asyncio.create_task(f(*args))
-    else:
-        t = Thread(target=f, args=args, daemon=True)
-        t.start()
-        return t
-
-
-def gather(*ts, return_exceptions=False, timeout=None):
-    """
-    Equivalent to asyncio.gather or Thread.join()
-    """
-    if ts and inspect.isawaitable(ts[0]):
-        rv: Any = asyncio.gather(*ts, return_exceptions=return_exceptions)
-        if timeout is None:
-            rv = asyncio.wait_for(rv, timeout)
-        return rv
-    else:
-        for t in ts:
-            t.join(timeout)
-            assert not t.is_alive()
-
-
-def asleep(s):
-    """
-    Equivalent to asyncio.sleep(), converted to time.sleep() by async_to_sync.
-    """
-    return asyncio.sleep(s)
-
-
-def is_alive(t):
-    """
-    Return true if an asyncio.Task or threading.Thread is alive.
-    """
-    return t.is_alive() if isinstance(t, Thread) else not t.done()
-
-
-class AEvent(asyncio.Event):
-    """
-    Subclass of asyncio.Event adding a wait with timeout like threading.Event.
-
-    wait_timeout() is converted to wait() by async_to_sync.
-    """
-
-    async def wait_timeout(self, timeout):
-        await asyncio.wait_for(self.wait(), timeout)
index 300fc0ab07522303b8f42948fc51497308213f5a..6b028d53e5f76ce2497b882f25b96a7248263535 100755 (executable)
@@ -209,7 +209,7 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "wait_timeout": "wait",
     }
     _skip_imports = {
-        "utils": {"alist", "anext"},
+        "acompat": {"alist", "anext"},
     }
 
     def visit_Module(self, node: ast.Module) -> ast.AST: