From: Daniele Varrazzo Date: Thu, 5 Oct 2023 23:58:54 +0000 (+0200) Subject: refactor(test): add acompat module X-Git-Tag: pool-3.2.0~12^2~14 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=f9660d8d182c1b27cdb692be9bcfda69a31a4877;p=thirdparty%2Fpsycopg.git refactor(test): add acompat module --- diff --git a/tests/acompat.py b/tests/acompat.py new file mode 100644 index 000000000..256e29f7c --- /dev/null +++ b/tests/acompat.py @@ -0,0 +1,109 @@ +""" +Utilities to ease the differences between async and sync code. + +These object offer a similar interface between sync and async versions; the +script async_to_sync.py will replace the async names with the sync names +when generating the sync version. +""" + +import sys +import time +import asyncio +import inspect +import builtins +import threading +import contextlib +from typing import Any + +# Re-exports +sleep = time.sleep +Event = threading.Event +closing = contextlib.closing + + +def is_async(obj): + """Return true if obj is an async object (class, instance, module name)""" + if isinstance(obj, str): + # coming from is_async(__name__) + return "async" in obj + + if not isinstance(obj, type): + obj = type(obj) + return "Async" in obj.__name__ + + +if sys.version_info >= (3, 10): + anext = builtins.anext + aclosing = contextlib.aclosing + +else: + + async def anext(it): + return await it.__anext__() + + @contextlib.asynccontextmanager + async def aclosing(thing): + try: + yield thing + finally: + await thing.aclose() + + +async def alist(it): + """Consume an async iterator into a list. Async equivalent of list(it).""" + return [i async for i in it] + + +def spawn(f, args=None): + """ + Equivalent to asyncio.create_task or creating and running a Thread. + """ + if not args: + args = () + + if inspect.iscoroutinefunction(f): + return asyncio.create_task(f(*args)) + else: + t = threading.Thread(target=f, args=args, daemon=True) + t.start() + return t + + +def gather(*ts, return_exceptions=False, timeout=None): + """ + Equivalent to asyncio.gather or Thread.join() + """ + if ts and inspect.isawaitable(ts[0]): + rv: Any = asyncio.gather(*ts, return_exceptions=return_exceptions) + if timeout is None: + rv = asyncio.wait_for(rv, timeout) + return rv + else: + for t in ts: + t.join(timeout) + assert not t.is_alive() + + +def asleep(s): + """ + Equivalent to asyncio.sleep(), converted to time.sleep() by async_to_sync. + """ + return asyncio.sleep(s) + + +def is_alive(t): + """ + Return true if an asyncio.Task or threading.Thread is alive. + """ + return t.is_alive() if isinstance(t, threading.Thread) else not t.done() + + +class AEvent(asyncio.Event): + """ + Subclass of asyncio.Event adding a wait with timeout like threading.Event. + + wait_timeout() is converted to wait() by async_to_sync. + """ + + async def wait_timeout(self, timeout): + await asyncio.wait_for(self.wait(), timeout) diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 10b9616b7..7d96d3f3d 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -13,7 +13,7 @@ from psycopg.pq import TransactionStatus from psycopg.rows import class_row, Row, TupleRow from psycopg._compat import assert_type, Counter -from ..utils import Event, spawn, gather, sleep, is_async +from ..acompat import Event, spawn, gather, sleep, is_async from .test_pool_common import delay_connection try: diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 734a0ad1f..be018f2b6 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -10,7 +10,7 @@ from psycopg.pq import TransactionStatus from psycopg.rows import class_row, Row, TupleRow from psycopg._compat import assert_type, Counter -from ..utils import AEvent, spawn, gather, asleep, is_async +from ..acompat import AEvent, spawn, gather, asleep, is_async from .test_pool_common_async import delay_connection try: diff --git a/tests/pool/test_pool_common.py b/tests/pool/test_pool_common.py index 1ae013eff..e80550c32 100644 --- a/tests/pool/test_pool_common.py +++ b/tests/pool/test_pool_common.py @@ -9,7 +9,7 @@ import pytest import psycopg -from ..utils import Event, spawn, gather, sleep, is_alive, is_async +from ..acompat import Event, spawn, gather, sleep, is_alive, is_async try: import psycopg_pool as pool diff --git a/tests/pool/test_pool_common_async.py b/tests/pool/test_pool_common_async.py index 61472e0f0..6d1b9fe1b 100644 --- a/tests/pool/test_pool_common_async.py +++ b/tests/pool/test_pool_common_async.py @@ -6,7 +6,7 @@ import pytest import psycopg -from ..utils import AEvent, spawn, gather, asleep, is_alive, is_async +from ..acompat import AEvent, spawn, gather, asleep, is_alive, is_async try: import psycopg_pool as pool diff --git a/tests/pool/test_pool_null.py b/tests/pool/test_pool_null.py index 7fdd055bd..3a28beac3 100644 --- a/tests/pool/test_pool_null.py +++ b/tests/pool/test_pool_null.py @@ -12,7 +12,7 @@ from psycopg.pq import TransactionStatus from psycopg.rows import class_row, Row, TupleRow from psycopg._compat import assert_type -from ..utils import Event, sleep, spawn, gather, is_async +from ..acompat import Event, sleep, spawn, gather, is_async from .test_pool_common import delay_connection, ensure_waiting try: diff --git a/tests/pool/test_pool_null_async.py b/tests/pool/test_pool_null_async.py index 8525c356b..106437eb5 100644 --- a/tests/pool/test_pool_null_async.py +++ b/tests/pool/test_pool_null_async.py @@ -9,7 +9,7 @@ from psycopg.pq import TransactionStatus from psycopg.rows import class_row, Row, TupleRow from psycopg._compat import assert_type -from ..utils import AEvent, asleep, spawn, gather, is_async +from ..acompat import AEvent, asleep, spawn, gather, is_async from .test_pool_common_async import delay_connection, ensure_waiting try: diff --git a/tests/pool/test_sched.py b/tests/pool/test_sched.py index 2639dc1f8..8cfff6017 100644 --- a/tests/pool/test_sched.py +++ b/tests/pool/test_sched.py @@ -8,7 +8,7 @@ from contextlib import contextmanager import pytest -from ..utils import spawn, gather, sleep +from ..acompat import spawn, gather, sleep try: from psycopg_pool.sched import Scheduler diff --git a/tests/pool/test_sched_async.py b/tests/pool/test_sched_async.py index 23c7cefbd..16366735d 100644 --- a/tests/pool/test_sched_async.py +++ b/tests/pool/test_sched_async.py @@ -5,7 +5,7 @@ from contextlib import asynccontextmanager import pytest -from ..utils import spawn, gather, asleep +from ..acompat import spawn, gather, asleep try: from psycopg_pool.sched_async import AsyncScheduler diff --git a/tests/test_connection.py b/tests/test_connection.py index f995e5b0b..dd27e6e39 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -13,7 +13,8 @@ from psycopg import Notify, pq, errors as e from psycopg.rows import tuple_row from psycopg.conninfo import conninfo_to_dict, make_conninfo -from .utils import gc_collect, is_async +from .utils import gc_collect +from .acompat import is_async from ._test_cursor import my_row_factory from ._test_connection import tx_params, tx_params_isolation, tx_values_map from ._test_connection import conninfo_params_timeout diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 7830cf9a2..b68e36a26 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -10,7 +10,8 @@ from psycopg import Notify, pq, errors as e from psycopg.rows import tuple_row from psycopg.conninfo import conninfo_to_dict, make_conninfo -from .utils import gc_collect, is_async +from .utils import gc_collect +from .acompat import is_async from ._test_cursor import my_row_factory from ._test_connection import tx_params, tx_params_isolation, tx_values_map from ._test_connection import conninfo_params_timeout diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 343522473..c62d9f2d9 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -17,7 +17,8 @@ from psycopg.types import TypeInfo from psycopg.types.hstore import register_hstore from psycopg.types.numeric import Int4 -from .utils import alist, eur, gc_collect, gc_count +from .utils import eur, gc_collect, gc_count +from .acompat import alist from ._test_copy import sample_text, sample_binary, sample_binary_rows # noqa from ._test_copy import sample_values, sample_records, sample_tabledef from ._test_copy import ensure_table_async, py_to_raw, special_chars, AsyncFileWriter diff --git a/tests/test_cursor_common.py b/tests/test_cursor_common.py index ccc6513d1..a5090f92c 100644 --- a/tests/test_cursor_common.py +++ b/tests/test_cursor_common.py @@ -16,7 +16,8 @@ from psycopg import sql, rows from psycopg.adapt import PyFormat from psycopg.types import TypeInfo -from .utils import gc_collect, raiseif, closing +from .utils import gc_collect, raiseif +from .acompat import closing from .fix_crdb import crdb_encoding from ._test_cursor import my_row_factory, ph from ._test_cursor import execmany, _execmany # noqa: F401 diff --git a/tests/test_cursor_common_async.py b/tests/test_cursor_common_async.py index ad3673f45..cad01cfbf 100644 --- a/tests/test_cursor_common_async.py +++ b/tests/test_cursor_common_async.py @@ -13,7 +13,8 @@ from psycopg import sql, rows from psycopg.adapt import PyFormat from psycopg.types import TypeInfo -from .utils import gc_collect, raiseif, aclosing, alist, anext +from .utils import gc_collect, raiseif +from .acompat import aclosing, alist, anext from .fix_crdb import crdb_encoding from ._test_cursor import my_row_factory, ph from ._test_cursor import execmany, _execmany # noqa: F401 diff --git a/tests/test_cursor_server_async.py b/tests/test_cursor_server_async.py index 26db29ef3..7317c52d8 100644 --- a/tests/test_cursor_server_async.py +++ b/tests/test_cursor_server_async.py @@ -4,7 +4,7 @@ import psycopg from psycopg import rows, errors as e from psycopg.pq import Format -from .utils import alist +from .acompat import alist pytestmark = pytest.mark.crdb_skip("server-side cursor") diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index da46dbc40..a40035a33 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -12,7 +12,7 @@ import psycopg from psycopg import pq from psycopg import errors as e -from .utils import is_async +from .acompat import is_async pytestmark = [ pytest.mark.pipeline, diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py index 5dcdc685a..ee0efccd9 100644 --- a/tests/test_pipeline_async.py +++ b/tests/test_pipeline_async.py @@ -9,7 +9,7 @@ import psycopg from psycopg import pq from psycopg import errors as e -from .utils import is_async, anext +from .acompat import is_async, anext pytestmark = [ pytest.mark.pipeline, diff --git a/tests/utils.py b/tests/utils.py index 060388672..c1dc86dd1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,17 +1,9 @@ import gc import re import sys -import asyncio -import inspect import operator -from typing import Any, Callable, Optional, Tuple -from threading import Thread -from contextlib import contextmanager, asynccontextmanager - -# Re-exports -from time import sleep as sleep # noqa: F401 -from threading import Event as Event # noqa: F401 -from contextlib import closing as closing # noqa: F401 +from typing import Callable, Optional, Tuple +from contextlib import contextmanager import pytest @@ -152,17 +144,6 @@ def gc_collect(): gc.collect() -def is_async(obj): - """Return true if obj is an async object (class, instance, module name)""" - if isinstance(obj, str): - # coming from is_async(__name__) - return "async" in obj - - if not isinstance(obj, type): - obj = type(obj) - return "Async" in obj.__name__ - - NO_COUNT_TYPES: Tuple[type, ...] = () if sys.version_info[:2] == (3, 10): @@ -195,28 +176,6 @@ def gc_count() -> int: return rv -async def alist(it): - """Consume an async iterator into a list. Async equivalent of list(it).""" - return [i async for i in it] - - -if sys.version_info >= (3, 10): - from builtins import anext as anext - from contextlib import aclosing as aclosing - -else: - - async def anext(it): - return await it.__anext__() - - @asynccontextmanager - async def aclosing(thing): - try: - yield thing - finally: - await thing.aclose() - - @contextmanager def raiseif(cond, *args, **kwargs): """ @@ -233,58 +192,3 @@ def raiseif(cond, *args, **kwargs): with pytest.raises(*args, **kwargs) as ex: yield ex return - - -def spawn(f, args=None): - """ - Equivalent to asyncio.create_task or creating and running a Thread. - """ - if not args: - args = () - - if inspect.iscoroutinefunction(f): - return asyncio.create_task(f(*args)) - else: - t = Thread(target=f, args=args, daemon=True) - t.start() - return t - - -def gather(*ts, return_exceptions=False, timeout=None): - """ - Equivalent to asyncio.gather or Thread.join() - """ - if ts and inspect.isawaitable(ts[0]): - rv: Any = asyncio.gather(*ts, return_exceptions=return_exceptions) - if timeout is None: - rv = asyncio.wait_for(rv, timeout) - return rv - else: - for t in ts: - t.join(timeout) - assert not t.is_alive() - - -def asleep(s): - """ - Equivalent to asyncio.sleep(), converted to time.sleep() by async_to_sync. - """ - return asyncio.sleep(s) - - -def is_alive(t): - """ - Return true if an asyncio.Task or threading.Thread is alive. - """ - return t.is_alive() if isinstance(t, Thread) else not t.done() - - -class AEvent(asyncio.Event): - """ - Subclass of asyncio.Event adding a wait with timeout like threading.Event. - - wait_timeout() is converted to wait() by async_to_sync. - """ - - async def wait_timeout(self, timeout): - await asyncio.wait_for(self.wait(), timeout) diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index 300fc0ab0..6b028d53e 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -209,7 +209,7 @@ class RenameAsyncToSync(ast.NodeTransformer): "wait_timeout": "wait", } _skip_imports = { - "utils": {"alist", "anext"}, + "acompat": {"alist", "anext"}, } def visit_Module(self, node: ast.Module) -> ast.AST: