From: Daniele Varrazzo Date: Mon, 21 Aug 2023 17:10:25 +0000 (+0100) Subject: refactor: generate cursor.py from cursor_async X-Git-Tag: pool-3.2.0~12^2~45 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0c4499a4dfaf59b19d9f70d73b96c30002d16f8c;p=thirdparty%2Fpsycopg.git refactor: generate cursor.py from cursor_async --- diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index c26c73abf..c2544e3c8 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -1,18 +1,23 @@ +# WARNING: this file is auto-generated by 'async_to_sync.py' +# from the original file 'cursor_async.py' +# DO NOT CHANGE! Change the original file instead. """ -Psycopg Cursor object +Psycopg Cursor object. """ # Copyright (C) 2020 The Psycopg Team +from __future__ import annotations + from types import TracebackType -from typing import Any, Iterable, Iterator, List, Optional, Type, TypeVar -from typing import overload, TYPE_CHECKING +from typing import Any, Iterator, Iterable, List, Optional, Type, TypeVar +from typing import TYPE_CHECKING, overload from contextlib import contextmanager from . import pq from . import errors as e from .abc import Query, Params -from .copy import Copy, Writer as CopyWriter +from .copy import Copy, Writer from .rows import Row, RowMaker, RowFactory from ._pipeline import Pipeline from ._cursor_base import BaseCursor @@ -29,21 +34,18 @@ class Cursor(BaseCursor["Connection[Any]", Row]): _Self = TypeVar("_Self", bound="Cursor[Any]") @overload - def __init__(self: "Cursor[Row]", connection: "Connection[Row]"): + def __init__(self: Cursor[Row], connection: Connection[Row]): ... @overload def __init__( - self: "Cursor[Row]", - connection: "Connection[Any]", - *, - row_factory: RowFactory[Row], + self: Cursor[Row], connection: Connection[Any], *, row_factory: RowFactory[Row] ): ... def __init__( self, - connection: "Connection[Any]", + connection: Connection[Any], *, row_factory: Optional[RowFactory[Row]] = None, ): @@ -102,11 +104,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]): return self def executemany( - self, - query: Query, - params_seq: Iterable[Params], - *, - returning: bool = False, + self, query: Query, params_seq: Iterable[Params], *, returning: bool = False ) -> None: """ Execute the same command with a sequence of input data. @@ -157,10 +155,8 @@ class Cursor(BaseCursor["Connection[Any]", Row]): rec: Row = self._tx.load_row(0, self._make_row) # type: ignore yield rec first = False - except e._NO_TRACEBACK as ex: raise ex.with_traceback(None) - finally: if self._pgconn.transaction_status == ACTIVE: # Try to cancel the query, then consume the results @@ -209,9 +205,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]): if not size: size = self.arraysize records = self._tx.load_rows( - self._pos, - min(self._pos + size, self.pgresult.ntuples), - self._make_row, + self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row ) self._pos += len(records) return records @@ -263,12 +257,10 @@ class Cursor(BaseCursor["Connection[Any]", Row]): statement: Query, params: Optional[Params] = None, *, - writer: Optional[CopyWriter] = None, + writer: Optional[Writer] = None, ) -> Iterator[Copy]: """ Initiate a :sql:`COPY` operation and return an object to manage it. - - :rtype: Copy """ try: with self._conn.lock: @@ -286,7 +278,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]): def _fetch_pipeline(self) -> None: if ( self._execmany_returning is not False - and not self.pgresult + and (not self.pgresult) and self._conn._pipeline ): with self._conn.lock: diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 2589aefca..f02cfc540 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -1,9 +1,11 @@ """ -Psycopg AsyncCursor object +Psycopg AsyncCursor object. """ # Copyright (C) 2020 The Psycopg Team +from __future__ import annotations + from types import TracebackType from typing import Any, AsyncIterator, Iterable, List, Optional, Type, TypeVar from typing import TYPE_CHECKING, overload @@ -12,7 +14,7 @@ from contextlib import asynccontextmanager from . import pq from . import errors as e from .abc import Query, Params -from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter +from .copy import AsyncCopy, AsyncWriter from .rows import Row, RowMaker, AsyncRowFactory from ._pipeline import Pipeline from ._cursor_base import BaseCursor @@ -29,13 +31,13 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): _Self = TypeVar("_Self", bound="AsyncCursor[Any]") @overload - def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"): + def __init__(self: AsyncCursor[Row], connection: AsyncConnection[Row]): ... @overload def __init__( - self: "AsyncCursor[Row]", - connection: "AsyncConnection[Any]", + self: AsyncCursor[Row], + connection: AsyncConnection[Any], *, row_factory: AsyncRowFactory[Row], ): @@ -43,7 +45,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): def __init__( self, - connection: "AsyncConnection[Any]", + connection: AsyncConnection[Any], *, row_factory: Optional[AsyncRowFactory[Row]] = None, ): @@ -267,12 +269,10 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): statement: Query, params: Optional[Params] = None, *, - writer: Optional[AsyncCopyWriter] = None, + writer: Optional[AsyncWriter] = None, ) -> AsyncIterator[AsyncCopy]: """ Initiate a :sql:`COPY` operation and return an object to manage it. - - :rtype: AsyncCopy """ try: async with self._conn.lock: diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index f7f42da2d..3375656fc 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -97,6 +97,8 @@ class AsyncToSync(ast.NodeTransformer): def _is_async_call(self, test: ast.AST) -> bool: if not isinstance(test, ast.Call): return False + if not isinstance(test.func, ast.Name): + return False if test.func.id != "is_async": return False return True @@ -107,14 +109,20 @@ class RenameAsyncToSync(ast.NodeTransformer): "AsyncClientCursor": "ClientCursor", "AsyncConnection": "Connection", "AsyncCopy": "Copy", + "AsyncCopyWriter": "CopyWriter", + "AsyncCursor": "Cursor", "AsyncCursor": "Cursor", "AsyncFileWriter": "FileWriter", + "AsyncIterator": "Iterator", "AsyncLibpqWriter": "LibpqWriter", "AsyncQueuedLibpqWriter": "QueuedLibpqWriter", "AsyncRawCursor": "RawCursor", + "AsyncRowFactory": "RowFactory", "AsyncServerCursor": "ServerCursor", + "AsyncWriter": "Writer", "__aenter__": "__enter__", "__aexit__": "__exit__", + "__aiter__": "__iter__", "aclose": "close", "aclosing": "closing", "acommands": "commands", @@ -124,6 +132,8 @@ class RenameAsyncToSync(ast.NodeTransformer): "alist": "list", "anext": "next", "apipeline": "pipeline", + "asynccontextmanager": "contextmanager", + "connection_async": "connection", "ensure_table_async": "ensure_table", "find_insert_problem_async": "find_insert_problem", } @@ -159,6 +169,56 @@ class RenameAsyncToSync(ast.NodeTransformer): _skip_imports = {"alist", "anext"} + def visit_Call(self, node: ast.Call) -> ast.AST: + if isinstance(node.func, ast.Name) and node.func.id == "TypeVar": + node = self._visit_Call_TypeVar(node) + + self.generic_visit(node) + return node + + def _visit_Call_TypeVar(self, node: ast.Call) -> ast.AST: + for kw in node.keywords: + if kw.arg != "bound": + continue + if not isinstance(kw.value, ast.Constant): + continue + if not isinstance(kw.value.value, str): + continue + kw.value.value = self._visit_type_string(kw.value.value) + + return node + + def _visit_type_string(self, source: str) -> str: + # Convert the string to tree, visit, and convert it back to string + tree = ast.parse(source) + tree = async_to_sync(tree) + rv = unparse(tree) + return rv + + def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST: + node.name = self.names_map.get(node.name, node.name) + node = self._fix_base_params(node) + self.generic_visit(node) + return node + + def _fix_base_params(self, node: ast.ClassDef) -> ast.AST: + # Handle : + # class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): + # the base cannot be a token, even with __future__ annotation. + for base in node.bases: + if not isinstance(base, ast.Subscript): + continue + # if not isinstance(base.slice, ast.Tuple): + # ast.Tuple is typing.Tuple??? + if type(base.slice).__name__ != "Tuple": + continue + for elt in base.slice.elts: + if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + continue + elt.value = self._visit_type_string(elt.value) + + return node + def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None: # Remove import of async utils eclypsing builtins if node.module == "utils": @@ -166,6 +226,7 @@ class RenameAsyncToSync(ast.NodeTransformer): if not node.names: return None + node.module = self.names_map.get(node.module, node.module) for n in node.names: n.name = self.names_map.get(n.name, n.name) return node diff --git a/tools/convert_async_to_sync.sh b/tools/convert_async_to_sync.sh index 76b0e07df..75274faa4 100755 --- a/tools/convert_async_to_sync.sh +++ b/tools/convert_async_to_sync.sh @@ -18,6 +18,7 @@ fi outputs="" for async in \ + psycopg/psycopg/cursor_async.py \ tests/test_client_cursor_async.py \ tests/test_connection_async.py \ tests/test_copy_async.py \