From: Daniele Varrazzo Date: Sat, 2 Sep 2023 21:22:09 +0000 (+0100) Subject: refactor: generate connection.py from connection_async X-Git-Tag: pool-3.2.0~12^2~38 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ddade696899923f229f315a66091b4074a77d6e2;p=thirdparty%2Fpsycopg.git refactor: generate connection.py from connection_async --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index ad8ff9e49..37086cc04 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -1,11 +1,15 @@ +# WARNING: this file is auto-generated by 'async_to_sync.py' +# from the original file 'connection_async.py' +# DO NOT CHANGE! Change the original file instead. """ -psycopg connection objects +psycopg async connection objects """ # Copyright (C) 2020 The Psycopg Team +from __future__ import annotations + import logging -import threading from types import TracebackType from typing import Any, Generator, Iterator, Dict, List, Optional from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING @@ -19,15 +23,17 @@ from ._tpc import Xid from .rows import Row, RowFactory, tuple_row, TupleRow, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel -from .cursor import Cursor from .conninfo import make_conninfo, conninfo_to_dict from ._pipeline import Pipeline from ._encodings import pgconn_encoding from .generators import notifies from .transaction import Transaction +from .cursor import Cursor from .server_cursor import ServerCursor from ._connection_base import BaseConnection, CursorRow, Notify +from threading import Lock + if TYPE_CHECKING: from .pq.abc import PGconn @@ -37,6 +43,8 @@ BINARY = pq.Format.BINARY IDLE = pq.TransactionStatus.IDLE INTRANS = pq.TransactionStatus.INTRANS +_INTERRUPTED = KeyboardInterrupt + logger = logging.getLogger("psycopg") @@ -60,7 +68,7 @@ class Connection(BaseConnection[Row]): ): super().__init__(pgconn) self.row_factory = row_factory - self.lock = threading.Lock() + self.lock = Lock() self.cursor_factory = Cursor self.server_cursor_factory = ServerCursor @@ -71,12 +79,12 @@ class Connection(BaseConnection[Row]): conninfo: str = "", *, autocommit: bool = False, - row_factory: RowFactory[Row], prepare_threshold: Optional[int] = 5, + row_factory: RowFactory[Row], cursor_factory: Optional[Type[Cursor[Row]]] = None, context: Optional[AdaptContext] = None, **kwargs: Union[None, int, str], - ) -> "Connection[Row]": + ) -> Connection[Row]: # TODO: returned type should be _Self. See #308. ... @@ -91,7 +99,7 @@ class Connection(BaseConnection[Row]): cursor_factory: Optional[Type[Cursor[Any]]] = None, context: Optional[AdaptContext] = None, **kwargs: Union[None, int, str], - ) -> "Connection[TupleRow]": + ) -> Connection[TupleRow]: ... @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004 @@ -101,11 +109,11 @@ class Connection(BaseConnection[Row]): *, autocommit: bool = False, prepare_threshold: Optional[int] = 5, + context: Optional[AdaptContext] = None, row_factory: Optional[RowFactory[Row]] = None, cursor_factory: Optional[Type[Cursor[Row]]] = None, - context: Optional[AdaptContext] = None, **kwargs: Any, - ) -> "Connection[Any]": + ) -> Connection[Any]: """ Connect to a database server and return a new `Connection` instance. """ @@ -345,7 +353,7 @@ class Connection(BaseConnection[Row]): """ try: return waiting.wait(gen, self.pgconn.socket, timeout=timeout) - except KeyboardInterrupt: + except _INTERRUPTED: # On Ctrl-C, try to cancel the query in the server, otherwise # the connection will remain stuck in ACTIVE state. self._try_cancel(self.pgconn) @@ -358,7 +366,7 @@ class Connection(BaseConnection[Row]): @classmethod def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: """Consume a connection generator.""" - return waiting.wait_conn(gen, timeout=timeout) + return waiting.wait_conn(gen, timeout) def _set_autocommit(self, value: bool) -> None: self.set_autocommit(value) diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 92709003f..70fa198e8 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -4,8 +4,8 @@ psycopg async connection objects # Copyright (C) 2020 The Psycopg Team -import sys -import asyncio +from __future__ import annotations + import logging from types import TracebackType from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional @@ -20,7 +20,7 @@ from ._tpc import Xid from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row from .adapt import AdaptersMap from ._enums import IsolationLevel -from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async +from .conninfo import make_conninfo, conninfo_to_dict from ._pipeline import AsyncPipeline from ._encodings import pgconn_encoding from .generators import notifies @@ -29,6 +29,14 @@ from .cursor_async import AsyncCursor from .server_cursor import AsyncServerCursor from ._connection_base import BaseConnection, CursorRow, Notify +if True: # ASYNC + import sys + import asyncio + from asyncio import Lock + from .conninfo import resolve_hostaddr_async +else: + from threading import Lock + if TYPE_CHECKING: from .pq.abc import PGconn @@ -38,12 +46,17 @@ BINARY = pq.Format.BINARY IDLE = pq.TransactionStatus.IDLE INTRANS = pq.TransactionStatus.INTRANS +if True: # ASYNC + _INTERRUPTED = (asyncio.CancelledError, KeyboardInterrupt) +else: + _INTERRUPTED = KeyboardInterrupt + logger = logging.getLogger("psycopg") class AsyncConnection(BaseConnection[Row]): """ - Asynchronous wrapper for a connection to the database. + Wrapper for a connection to the database. """ __module__ = "psycopg" @@ -61,7 +74,7 @@ class AsyncConnection(BaseConnection[Row]): ): super().__init__(pgconn) self.row_factory = row_factory - self.lock = asyncio.Lock() + self.lock = Lock() self.cursor_factory = AsyncCursor self.server_cursor_factory = AsyncServerCursor @@ -77,7 +90,7 @@ class AsyncConnection(BaseConnection[Row]): cursor_factory: Optional[Type[AsyncCursor[Row]]] = None, context: Optional[AdaptContext] = None, **kwargs: Union[None, int, str], - ) -> "AsyncConnection[Row]": + ) -> AsyncConnection[Row]: # TODO: returned type should be _Self. See #308. ... @@ -92,7 +105,7 @@ class AsyncConnection(BaseConnection[Row]): cursor_factory: Optional[Type[AsyncCursor[Any]]] = None, context: Optional[AdaptContext] = None, **kwargs: Union[None, int, str], - ) -> "AsyncConnection[TupleRow]": + ) -> AsyncConnection[TupleRow]: ... @classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004 @@ -106,16 +119,20 @@ class AsyncConnection(BaseConnection[Row]): row_factory: Optional[AsyncRowFactory[Row]] = None, cursor_factory: Optional[Type[AsyncCursor[Row]]] = None, **kwargs: Any, - ) -> "AsyncConnection[Any]": - if sys.platform == "win32": - loop = asyncio.get_running_loop() - if isinstance(loop, asyncio.ProactorEventLoop): - raise e.InterfaceError( - "Psycopg cannot use the 'ProactorEventLoop' to run in async" - " mode. Please use a compatible event loop, for instance by" - " setting 'asyncio.set_event_loop_policy" - "(WindowsSelectorEventLoopPolicy())'" - ) + ) -> AsyncConnection[Any]: + """ + Connect to a database server and return a new `AsyncConnection` instance. + """ + if True: # ASYNC + if sys.platform == "win32": + loop = asyncio.get_running_loop() + if isinstance(loop, asyncio.ProactorEventLoop): + raise e.InterfaceError( + "Psycopg cannot use the 'ProactorEventLoop' to run in async" + " mode. Please use a compatible event loop, for instance by" + " setting 'asyncio.set_event_loop_policy" + "(WindowsSelectorEventLoopPolicy())'" + ) params = await cls._get_connection_params(conninfo, **kwargs) conninfo = make_conninfo(**params) @@ -182,8 +199,9 @@ class AsyncConnection(BaseConnection[Row]): else: params["connect_timeout"] = None - # Resolve host addresses in non-blocking way - params = await resolve_hostaddr_async(params) + if True: # ASYNC + # Resolve host addresses in non-blocking way + params = await resolve_hostaddr_async(params) return params @@ -358,7 +376,7 @@ class AsyncConnection(BaseConnection[Row]): """ try: return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout) - except (asyncio.CancelledError, KeyboardInterrupt): + except _INTERRUPTED: # On Ctrl-C, try to cancel the query in the server, otherwise # the connection will remain stuck in ACTIVE state. self._try_cancel(self.pgconn) @@ -374,7 +392,10 @@ class AsyncConnection(BaseConnection[Row]): return await waiting.wait_conn_async(gen, timeout) def _set_autocommit(self, value: bool) -> None: - self._no_set_async("autocommit") + if True: # ASYNC + self._no_set_async("autocommit") + else: + self.set_autocommit(value) async def set_autocommit(self, value: bool) -> None: """Method version of the `~Connection.autocommit` setter.""" @@ -382,7 +403,10 @@ class AsyncConnection(BaseConnection[Row]): await self.wait(self._set_autocommit_gen(value)) def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None: - self._no_set_async("isolation_level") + if True: # ASYNC + self._no_set_async("isolation_level") + else: + self.set_isolation_level(value) async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None: """Method version of the `~Connection.isolation_level` setter.""" @@ -390,7 +414,10 @@ class AsyncConnection(BaseConnection[Row]): await self.wait(self._set_isolation_level_gen(value)) def _set_read_only(self, value: Optional[bool]) -> None: - self._no_set_async("read_only") + if True: # ASYNC + self._no_set_async("read_only") + else: + self.set_read_only(value) async def set_read_only(self, value: Optional[bool]) -> None: """Method version of the `~Connection.read_only` setter.""" @@ -398,18 +425,23 @@ class AsyncConnection(BaseConnection[Row]): await self.wait(self._set_read_only_gen(value)) def _set_deferrable(self, value: Optional[bool]) -> None: - self._no_set_async("deferrable") + if True: # ASYNC + self._no_set_async("deferrable") + else: + self.set_deferrable(value) async def set_deferrable(self, value: Optional[bool]) -> None: """Method version of the `~Connection.deferrable` setter.""" async with self.lock: await self.wait(self._set_deferrable_gen(value)) - def _no_set_async(self, attribute: str) -> None: - raise AttributeError( - f"'the {attribute!r} property is read-only on async connections:" - f" please use 'await .set_{attribute}()' instead." - ) + if True: # ASYNC + + def _no_set_async(self, attribute: str) -> None: + raise AttributeError( + f"'the {attribute!r} property is read-only on async connections:" + f" please use 'await .set_{attribute}()' instead." + ) async def tpc_begin(self, xid: Union[Xid, str]) -> None: """ diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index ada8b596d..ff70ca3a4 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -6,7 +6,9 @@ from __future__ import annotations import os import sys +from copy import deepcopy from typing import Any +from pathlib import Path from argparse import ArgumentParser, Namespace import ast_comments as ast @@ -27,12 +29,12 @@ ast.Tuple = ast_orig.Tuple def main() -> int: opt = parse_cmdline() - with open(opt.filename) as f: + with opt.filepath.open() as f: source = f.read() - tree = ast.parse(source, filename=opt.filename) - tree = async_to_sync(tree) - output = tree_to_str(tree, opt.filename) + tree = ast.parse(source, filename=str(opt.filepath)) + tree = async_to_sync(tree, filepath=opt.filepath) + output = tree_to_str(tree, opt.filepath) if opt.output: with open(opt.output, "w") as f: @@ -43,23 +45,27 @@ def main() -> int: return 0 -def async_to_sync(tree: ast.AST) -> ast.AST: +def async_to_sync(tree: ast.AST, filepath: Path | None = None) -> ast.AST: tree = BlanksInserter().visit(tree) tree = RenameAsyncToSync().visit(tree) tree = AsyncToSync().visit(tree) return tree -def tree_to_str(tree: ast.AST, filename: str) -> str: +def tree_to_str(tree: ast.AST, filepath: Path) -> str: rv = f"""\ # WARNING: this file is auto-generated by '{os.path.basename(sys.argv[0])}' -# from the original file '{os.path.basename(filename)}' +# from the original file '{filepath.name}' # DO NOT CHANGE! Change the original file instead. """ rv += unparse(tree) return rv +# Hint: in order to explore the AST of a module you can run: +# python -m ast path/tp/module.py + + class AsyncToSync(ast.NodeTransformer): def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: new_node = ast.FunctionDef( @@ -103,6 +109,13 @@ class AsyncToSync(ast.NodeTransformer): self.visit(child) return node.orelse + # Manage `if True: # ASYNC` + # drop the unneeded branch + if (stmts := self._async_test_statements(node)) is not None: + for child in stmts: + self.visit(child) + return stmts + self.generic_visit(node) return node @@ -115,6 +128,27 @@ class AsyncToSync(ast.NodeTransformer): return False return True + def _async_test_statements(self, node: ast.If) -> list[ast.AST] | None: + if not ( + isinstance(node.test, ast.Constant) and isinstance(node.test.value, bool) + ): + return None + + if not (node.body and isinstance(node.body[0], ast.Comment)): + return None + + comment = node.body[0].value + + if not comment.startswith("# ASYNC"): + return None + + stmts: list[ast.AST] + if node.test.value: + stmts = node.orelse + else: + stmts = node.body[1:] # skip the ASYNC comment + return stmts + class RenameAsyncToSync(ast.NodeTransformer): names_map = { @@ -123,14 +157,16 @@ class RenameAsyncToSync(ast.NodeTransformer): "AsyncCopy": "Copy", "AsyncCopyWriter": "CopyWriter", "AsyncCursor": "Cursor", - "AsyncCursor": "Cursor", "AsyncFileWriter": "FileWriter", + "AsyncGenerator": "Generator", "AsyncIterator": "Iterator", "AsyncLibpqWriter": "LibpqWriter", + "AsyncPipeline": "Pipeline", "AsyncQueuedLibpqWriter": "QueuedLibpqWriter", "AsyncRawCursor": "RawCursor", "AsyncRowFactory": "RowFactory", "AsyncServerCursor": "ServerCursor", + "AsyncTransaction": "Transaction", "AsyncWriter": "Writer", "__aenter__": "__enter__", "__aexit__": "__exit__", @@ -145,23 +181,23 @@ class RenameAsyncToSync(ast.NodeTransformer): "apipeline": "pipeline", "asynccontextmanager": "contextmanager", "connection_async": "connection", + "cursor_async": "cursor", "ensure_table_async": "ensure_table", "find_insert_problem_async": "find_insert_problem", + "wait_async": "wait", + "wait_conn_async": "wait_conn", + } + _skip_imports = { + "utils": {"alist", "anext"}, } def visit_Module(self, node: ast.Module) -> ast.AST: - # Replace the content of the module docstring. - if ( - node.body - and isinstance(node.body[0], ast.Expr) - and isinstance(node.body[0].value, ast.Constant) - ): - node.body[0].value.value = node.body[0].value.value.replace("Async", "") - + self._fix_docstring(node.body) self.generic_visit(node) return node - def visit_AsyncFunctionDef(self, node: ast.FunctionDef) -> ast.AST: + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: + self._fix_docstring(node.body) node.name = self.names_map.get(node.name, node.name) for arg in node.args.args: arg.arg = self.names_map.get(arg.arg, arg.arg) @@ -178,7 +214,19 @@ class RenameAsyncToSync(ast.NodeTransformer): self.generic_visit(node) return node - _skip_imports = {"alist", "anext"} + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: + self._fix_docstring(node.body) + self.generic_visit(node) + return node + + def _fix_docstring(self, body: list[ast.AST]) -> None: + if ( + body + and isinstance(body[0], ast.Expr) + and isinstance(body[0].value, ast.Constant) + and isinstance(body[0].value.value, str) + ): + body[0].value.value = body[0].value.value.replace("Async", "") def visit_Call(self, node: ast.Call) -> ast.AST: if isinstance(node.func, ast.Name) and node.func.id == "TypeVar": @@ -201,12 +249,13 @@ class RenameAsyncToSync(ast.NodeTransformer): def _visit_type_string(self, source: str) -> str: # Convert the string to tree, visit, and convert it back to string - tree = ast.parse(source) + tree = ast.parse(source, type_comments=False) tree = async_to_sync(tree) rv = unparse(tree) return rv def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST: + self._fix_docstring(node.body) node.name = self.names_map.get(node.name, node.name) node = self._fix_base_params(node) self.generic_visit(node) @@ -230,8 +279,8 @@ class RenameAsyncToSync(ast.NodeTransformer): def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None: # Remove import of async utils eclypsing builtins - if node.module == "utils": - node.names = [n for n in node.names if n.name not in self._skip_imports] + if skips := self._skip_imports.get(node.module): + node.names = [n for n in node.names if n.name not in skips] if not node.names: return None @@ -251,6 +300,27 @@ class RenameAsyncToSync(ast.NodeTransformer): self.generic_visit(node) return node + def visit_Subscript(self, node: ast.Subscript) -> ast.AST: + # Manage AsyncGenerator[X, Y] -> Generator[X, None, Y] + self._manage_async_generator(node) + # # Won't result in a recursion because we change the args number + # self.visit(node) + # return node + + self.generic_visit(node) + return node + + def _manage_async_generator(self, node: ast.Subscript) -> ast.AST | None: + if not (isinstance(node.value, ast.Name) and node.value.id == "AsyncGenerator"): + return None + + if not (isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 2): + return None + + node.slice.elts.insert(1, deepcopy(node.slice.elts[1])) + self.generic_visit(node) + return node + class BlanksInserter(ast.NodeTransformer): """ @@ -292,9 +362,45 @@ class BlanksInserter(ast.NodeTransformer): def unparse(tree: ast.AST) -> str: rv: str = Unparser().visit(tree) + rv = _fix_comment_on_decorators(rv) return rv +def _fix_comment_on_decorators(source: str) -> str: + """ + Re-associate comments to decorators. + + In a case like: + + 1 @deco # comment + 2 def func(x): + 3 pass + + it seems that Function lineno is 2 instead of 1 (Python 3.10). Because + the Comment lineno is 1, it ends up printed above the function, instead + of inline. This is a problem for '# type: ignore' comments. + + Maybe the problem could be fixed in the tree, but this solution is a + simpler way to start. + """ + lines = source.splitlines() + + comment_at = None + for i, line in enumerate(lines): + if line.lstrip().startswith("#"): + comment_at = i + elif not line.strip(): + pass + elif line.lstrip().startswith("@classmethod"): + if comment_at is not None: + lines[i] = lines[i] + " " + lines[comment_at].lstrip() + lines[comment_at] = "" + else: + comment_at = None + + return "\n".join(lines) + + class Unparser(ast._Unparser): """ Try to emit long strings as multiline. @@ -313,7 +419,9 @@ class Unparser(ast._Unparser): def parse_cmdline() -> Namespace: parser = ArgumentParser(description=__doc__) - parser.add_argument("filename", metavar="FILE", help="the file to process") + parser.add_argument( + "filepath", metavar="FILE", type=Path, help="the file to process" + ) parser.add_argument( "output", metavar="OUTPUT", nargs="?", help="file where to write (or stdout)" ) diff --git a/tools/convert_async_to_sync.sh b/tools/convert_async_to_sync.sh index 75274faa4..21b763526 100755 --- a/tools/convert_async_to_sync.sh +++ b/tools/convert_async_to_sync.sh @@ -18,6 +18,7 @@ fi outputs="" for async in \ + psycopg/psycopg/connection_async.py \ psycopg/psycopg/cursor_async.py \ tests/test_client_cursor_async.py \ tests/test_connection_async.py \