git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: generate connection.py from connection_async
author    Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Sat, 2 Sep 2023 21:22:09 +0000 (22:22 +0100)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tools/async_to_sync.py
tools/convert_async_to_sync.sh

index ad8ff9e49c14e17d8d466b05659fb216f6ea2d75..37086cc04acfd814c097067022e3054665e89ccd 100644 (file)
@@ -1,11 +1,15 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'connection_async.py'
+# DO NOT CHANGE! Change the original file instead.
 """
-psycopg connection objects
+psycopg async connection objects
 """
 
 # Copyright (C) 2020 The Psycopg Team
 
+from __future__ import annotations
+
 import logging
-import threading
 from types import TracebackType
 from typing import Any, Generator, Iterator, Dict, List, Optional
 from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
@@ -19,15 +23,17 @@ from ._tpc import Xid
 from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from .cursor import Cursor
 from .conninfo import make_conninfo, conninfo_to_dict
 from ._pipeline import Pipeline
 from ._encodings import pgconn_encoding
 from .generators import notifies
 from .transaction import Transaction
+from .cursor import Cursor
 from .server_cursor import ServerCursor
 from ._connection_base import BaseConnection, CursorRow, Notify
 
+from threading import Lock
+
 if TYPE_CHECKING:
     from .pq.abc import PGconn
 
@@ -37,6 +43,8 @@ BINARY = pq.Format.BINARY
 IDLE = pq.TransactionStatus.IDLE
 INTRANS = pq.TransactionStatus.INTRANS
 
+_INTERRUPTED = KeyboardInterrupt
+
 logger = logging.getLogger("psycopg")
 
 
@@ -60,7 +68,7 @@ class Connection(BaseConnection[Row]):
     ):
         super().__init__(pgconn)
         self.row_factory = row_factory
-        self.lock = threading.Lock()
+        self.lock = Lock()
         self.cursor_factory = Cursor
         self.server_cursor_factory = ServerCursor
 
@@ -71,12 +79,12 @@ class Connection(BaseConnection[Row]):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: RowFactory[Row],
         prepare_threshold: Optional[int] = 5,
+        row_factory: RowFactory[Row],
         cursor_factory: Optional[Type[Cursor[Row]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
-    ) -> "Connection[Row]":
+    ) -> Connection[Row]:
         # TODO: returned type should be _Self. See #308.
         ...
 
@@ -91,7 +99,7 @@ class Connection(BaseConnection[Row]):
         cursor_factory: Optional[Type[Cursor[Any]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
-    ) -> "Connection[TupleRow]":
+    ) -> Connection[TupleRow]:
         ...
 
     @classmethod  # type: ignore[misc] # https://github.com/python/mypy/issues/11004
@@ -101,11 +109,11 @@ class Connection(BaseConnection[Row]):
         *,
         autocommit: bool = False,
         prepare_threshold: Optional[int] = 5,
+        context: Optional[AdaptContext] = None,
         row_factory: Optional[RowFactory[Row]] = None,
         cursor_factory: Optional[Type[Cursor[Row]]] = None,
-        context: Optional[AdaptContext] = None,
         **kwargs: Any,
-    ) -> "Connection[Any]":
+    ) -> Connection[Any]:
         """
         Connect to a database server and return a new `Connection` instance.
         """
@@ -345,7 +353,7 @@ class Connection(BaseConnection[Row]):
         """
         try:
             return waiting.wait(gen, self.pgconn.socket, timeout=timeout)
-        except KeyboardInterrupt:
+        except _INTERRUPTED:
             # On Ctrl-C, try to cancel the query in the server, otherwise
             # the connection will remain stuck in ACTIVE state.
             self._try_cancel(self.pgconn)
@@ -358,7 +366,7 @@ class Connection(BaseConnection[Row]):
     @classmethod
     def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
         """Consume a connection generator."""
-        return waiting.wait_conn(gen, timeout=timeout)
+        return waiting.wait_conn(gen, timeout)
 
     def _set_autocommit(self, value: bool) -> None:
         self.set_autocommit(value)
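
The changes to connection.py above are limited to the generated header, the import layout and keyword ordering, so the blocking API is unchanged. A minimal usage sketch of both entry points ("dbname=test" is a placeholder conninfo string):

import asyncio
import psycopg
from psycopg import AsyncConnection

# Sync side: served by the generated connection.py
with psycopg.Connection.connect("dbname=test") as conn:
    conn.execute("SELECT 1")

# Async side: the hand-written source the sync module is generated from
async def main() -> None:
    async with await AsyncConnection.connect("dbname=test") as aconn:
        await aconn.execute("SELECT 1")

asyncio.run(main())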
index 92709003f802338262134ec691304f94bd26dd26..70fa198e85dd573075e033e657074369a9765ad0 100644 (file)
@@ -4,8 +4,8 @@ psycopg async connection objects
 
 # Copyright (C) 2020 The Psycopg Team
 
-import sys
-import asyncio
+from __future__ import annotations
+
 import logging
 from types import TracebackType
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
@@ -20,7 +20,7 @@ from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async
+from .conninfo import make_conninfo, conninfo_to_dict
 from ._pipeline import AsyncPipeline
 from ._encodings import pgconn_encoding
 from .generators import notifies
@@ -29,6 +29,14 @@ from .cursor_async import AsyncCursor
 from .server_cursor import AsyncServerCursor
 from ._connection_base import BaseConnection, CursorRow, Notify
 
+if True:  # ASYNC
+    import sys
+    import asyncio
+    from asyncio import Lock
+    from .conninfo import resolve_hostaddr_async
+else:
+    from threading import Lock
+
 if TYPE_CHECKING:
     from .pq.abc import PGconn
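
For reference, the `if True:  # ASYNC` blocks added here are the pairing mechanism: the guarded branch is the asyncio code, and the `else` branch is what async_to_sync.py keeps when regenerating connection.py (compare the Lock import and _INTERRUPTED hunks above). A before/after sketch:

# connection_async.py (hand-written):
if True:  # ASYNC
    from asyncio import Lock
else:
    from threading import Lock

# connection.py (generated): only the else branch survives
from threading import Lock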
 
@@ -38,12 +46,17 @@ BINARY = pq.Format.BINARY
 IDLE = pq.TransactionStatus.IDLE
 INTRANS = pq.TransactionStatus.INTRANS
 
+if True:  # ASYNC
+    _INTERRUPTED = (asyncio.CancelledError, KeyboardInterrupt)
+else:
+    _INTERRUPTED = KeyboardInterrupt
+
 logger = logging.getLogger("psycopg")
 
 
 class AsyncConnection(BaseConnection[Row]):
     """
-    Asynchronous wrapper for a connection to the database.
+    Wrapper for a connection to the database.
     """
 
     __module__ = "psycopg"
@@ -61,7 +74,7 @@ class AsyncConnection(BaseConnection[Row]):
     ):
         super().__init__(pgconn)
         self.row_factory = row_factory
-        self.lock = asyncio.Lock()
+        self.lock = Lock()
         self.cursor_factory = AsyncCursor
         self.server_cursor_factory = AsyncServerCursor
 
@@ -77,7 +90,7 @@ class AsyncConnection(BaseConnection[Row]):
         cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
-    ) -> "AsyncConnection[Row]":
+    ) -> AsyncConnection[Row]:
         # TODO: returned type should be _Self. See #308.
         ...
 
@@ -92,7 +105,7 @@ class AsyncConnection(BaseConnection[Row]):
         cursor_factory: Optional[Type[AsyncCursor[Any]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
-    ) -> "AsyncConnection[TupleRow]":
+    ) -> AsyncConnection[TupleRow]:
         ...
 
     @classmethod  # type: ignore[misc] # https://github.com/python/mypy/issues/11004
@@ -106,16 +119,20 @@ class AsyncConnection(BaseConnection[Row]):
         row_factory: Optional[AsyncRowFactory[Row]] = None,
         cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
         **kwargs: Any,
-    ) -> "AsyncConnection[Any]":
-        if sys.platform == "win32":
-            loop = asyncio.get_running_loop()
-            if isinstance(loop, asyncio.ProactorEventLoop):
-                raise e.InterfaceError(
-                    "Psycopg cannot use the 'ProactorEventLoop' to run in async"
-                    " mode. Please use a compatible event loop, for instance by"
-                    " setting 'asyncio.set_event_loop_policy"
-                    "(WindowsSelectorEventLoopPolicy())'"
-                )
+    ) -> AsyncConnection[Any]:
+        """
+        Connect to a database server and return a new `AsyncConnection` instance.
+        """
+        if True:  # ASYNC
+            if sys.platform == "win32":
+                loop = asyncio.get_running_loop()
+                if isinstance(loop, asyncio.ProactorEventLoop):
+                    raise e.InterfaceError(
+                        "Psycopg cannot use the 'ProactorEventLoop' to run in async"
+                        " mode. Please use a compatible event loop, for instance by"
+                        " setting 'asyncio.set_event_loop_policy"
+                        "(WindowsSelectorEventLoopPolicy())'"
+                    )
 
         params = await cls._get_connection_params(conninfo, **kwargs)
         conninfo = make_conninfo(**params)
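
The ProactorEventLoop check above rejects the default Windows loop; the remedy named in the error message is to install the selector policy before connecting. A small sketch (the platform guard makes it a no-op elsewhere):

import sys
import asyncio

if sys.platform == "win32":
    # Use the selector-based event loop, which psycopg can wait on
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())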
@@ -182,8 +199,9 @@ class AsyncConnection(BaseConnection[Row]):
         else:
             params["connect_timeout"] = None
 
-        # Resolve host addresses in non-blocking way
-        params = await resolve_hostaddr_async(params)
+        if True:  # ASYNC
+            # Resolve host addresses in non-blocking way
+            params = await resolve_hostaddr_async(params)
 
         return params
 
@@ -358,7 +376,7 @@ class AsyncConnection(BaseConnection[Row]):
         """
         try:
             return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
-        except (asyncio.CancelledError, KeyboardInterrupt):
+        except _INTERRUPTED:
             # On Ctrl-C, try to cancel the query in the server, otherwise
             # the connection will remain stuck in ACTIVE state.
             self._try_cancel(self.pgconn)
@@ -374,7 +392,10 @@ class AsyncConnection(BaseConnection[Row]):
         return await waiting.wait_conn_async(gen, timeout)
 
     def _set_autocommit(self, value: bool) -> None:
-        self._no_set_async("autocommit")
+        if True:  # ASYNC
+            self._no_set_async("autocommit")
+        else:
+            self.set_autocommit(value)
 
     async def set_autocommit(self, value: bool) -> None:
         """Method version of the `~Connection.autocommit` setter."""
@@ -382,7 +403,10 @@ class AsyncConnection(BaseConnection[Row]):
             await self.wait(self._set_autocommit_gen(value))
 
     def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
-        self._no_set_async("isolation_level")
+        if True:  # ASYNC
+            self._no_set_async("isolation_level")
+        else:
+            self.set_isolation_level(value)
 
     async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
         """Method version of the `~Connection.isolation_level` setter."""
@@ -390,7 +414,10 @@ class AsyncConnection(BaseConnection[Row]):
             await self.wait(self._set_isolation_level_gen(value))
 
     def _set_read_only(self, value: Optional[bool]) -> None:
-        self._no_set_async("read_only")
+        if True:  # ASYNC
+            self._no_set_async("read_only")
+        else:
+            self.set_read_only(value)
 
     async def set_read_only(self, value: Optional[bool]) -> None:
         """Method version of the `~Connection.read_only` setter."""
@@ -398,18 +425,23 @@ class AsyncConnection(BaseConnection[Row]):
             await self.wait(self._set_read_only_gen(value))
 
     def _set_deferrable(self, value: Optional[bool]) -> None:
-        self._no_set_async("deferrable")
+        if True:  # ASYNC
+            self._no_set_async("deferrable")
+        else:
+            self.set_deferrable(value)
 
     async def set_deferrable(self, value: Optional[bool]) -> None:
         """Method version of the `~Connection.deferrable` setter."""
         async with self.lock:
             await self.wait(self._set_deferrable_gen(value))
 
-    def _no_set_async(self, attribute: str) -> None:
-        raise AttributeError(
-            f"'the {attribute!r} property is read-only on async connections:"
-            f" please use 'await .set_{attribute}()' instead."
-        )
+    if True:  # ASYNC
+
+        def _no_set_async(self, attribute: str) -> None:
+            raise AttributeError(
+                f"'the {attribute!r} property is read-only on async connections:"
+                f" please use 'await .set_{attribute}()' instead."
+            )
 
     async def tpc_begin(self, xid: Union[Xid, str]) -> None:
         """
index ada8b596d06a9a7595ff99a5923ce95d79cfd39a..ff70ca3a439cef413b464ccb4cd9e95b729833cd 100755 (executable)
@@ -6,7 +6,9 @@ from __future__ import annotations
 
 import os
 import sys
+from copy import deepcopy
 from typing import Any
+from pathlib import Path
 from argparse import ArgumentParser, Namespace
 
 import ast_comments as ast
@@ -27,12 +29,12 @@ ast.Tuple = ast_orig.Tuple
 
 def main() -> int:
     opt = parse_cmdline()
-    with open(opt.filename) as f:
+    with opt.filepath.open() as f:
         source = f.read()
 
-    tree = ast.parse(source, filename=opt.filename)
-    tree = async_to_sync(tree)
-    output = tree_to_str(tree, opt.filename)
+    tree = ast.parse(source, filename=str(opt.filepath))
+    tree = async_to_sync(tree, filepath=opt.filepath)
+    output = tree_to_str(tree, opt.filepath)
 
     if opt.output:
         with open(opt.output, "w") as f:
@@ -43,23 +45,27 @@ def main() -> int:
     return 0
 
 
-def async_to_sync(tree: ast.AST) -> ast.AST:
+def async_to_sync(tree: ast.AST, filepath: Path | None = None) -> ast.AST:
     tree = BlanksInserter().visit(tree)
     tree = RenameAsyncToSync().visit(tree)
     tree = AsyncToSync().visit(tree)
     return tree
 
 
-def tree_to_str(tree: ast.AST, filename: str) -> str:
+def tree_to_str(tree: ast.AST, filepath: Path) -> str:
     rv = f"""\
 # WARNING: this file is auto-generated by '{os.path.basename(sys.argv[0])}'
-# from the original file '{os.path.basename(filename)}'
+# from the original file '{filepath.name}'
 # DO NOT CHANGE! Change the original file instead.
 """
     rv += unparse(tree)
     return rv
 
 
+# Hint: in order to explore the AST of a module you can run:
+# python -m ast path/to/module.py
+
+
 class AsyncToSync(ast.NodeTransformer):
     def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
         new_node = ast.FunctionDef(
@@ -103,6 +109,13 @@ class AsyncToSync(ast.NodeTransformer):
                 self.visit(child)
             return node.orelse
 
+        # Manage `if True:  # ASYNC`
+        # drop the unneeded branch
+        if (stmts := self._async_test_statements(node)) is not None:
+            for child in stmts:
+                self.visit(child)
+            return stmts
+
         self.generic_visit(node)
         return node
 
@@ -115,6 +128,27 @@ class AsyncToSync(ast.NodeTransformer):
             return False
         return True
 
+    def _async_test_statements(self, node: ast.If) -> list[ast.AST] | None:
+        if not (
+            isinstance(node.test, ast.Constant) and isinstance(node.test.value, bool)
+        ):
+            return None
+
+        if not (node.body and isinstance(node.body[0], ast.Comment)):
+            return None
+
+        comment = node.body[0].value
+
+        if not comment.startswith("# ASYNC"):
+            return None
+
+        stmts: list[ast.AST]
+        if node.test.value:
+            stmts = node.orelse
+        else:
+            stmts = node.body[1:]  # skip the ASYNC comment
+        return stmts
+
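
This is not the tool's code path verbatim (that goes through ast_comments and requires the `# ASYNC` marker as the first body statement), but a stdlib-only sketch of the splice technique visit_If relies on: a NodeTransformer visitor may return a list of statements to replace the node, and an empty list (an `if` with no `else`, like the resolve_hostaddr_async block above) removes it altogether.

import ast

class DropTrueBranch(ast.NodeTransformer):
    def visit_If(self, node):
        self.generic_visit(node)
        if isinstance(node.test, ast.Constant) and node.test.value is True:
            # Returning a list splices these statements where the `if` was;
            # an empty `else` therefore removes the statement entirely.
            return node.orelse
        return node

src = "if True:\n    x = 'async'\nelse:\n    x = 'sync'\n"
print(ast.unparse(DropTrueBranch().visit(ast.parse(src))))  # prints: x = 'sync'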
 
 class RenameAsyncToSync(ast.NodeTransformer):
     names_map = {
@@ -123,14 +157,16 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "AsyncCopy": "Copy",
         "AsyncCopyWriter": "CopyWriter",
         "AsyncCursor": "Cursor",
-        "AsyncCursor": "Cursor",
         "AsyncFileWriter": "FileWriter",
+        "AsyncGenerator": "Generator",
         "AsyncIterator": "Iterator",
         "AsyncLibpqWriter": "LibpqWriter",
+        "AsyncPipeline": "Pipeline",
         "AsyncQueuedLibpqWriter": "QueuedLibpqWriter",
         "AsyncRawCursor": "RawCursor",
         "AsyncRowFactory": "RowFactory",
         "AsyncServerCursor": "ServerCursor",
+        "AsyncTransaction": "Transaction",
         "AsyncWriter": "Writer",
         "__aenter__": "__enter__",
         "__aexit__": "__exit__",
@@ -145,23 +181,23 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "apipeline": "pipeline",
         "asynccontextmanager": "contextmanager",
         "connection_async": "connection",
+        "cursor_async": "cursor",
         "ensure_table_async": "ensure_table",
         "find_insert_problem_async": "find_insert_problem",
+        "wait_async": "wait",
+        "wait_conn_async": "wait_conn",
+    }
+    _skip_imports = {
+        "utils": {"alist", "anext"},
     }
 
     def visit_Module(self, node: ast.Module) -> ast.AST:
-        # Replace the content of the module docstring.
-        if (
-            node.body
-            and isinstance(node.body[0], ast.Expr)
-            and isinstance(node.body[0].value, ast.Constant)
-        ):
-            node.body[0].value.value = node.body[0].value.value.replace("Async", "")
-
+        self._fix_docstring(node.body)
         self.generic_visit(node)
         return node
 
-    def visit_AsyncFunctionDef(self, node: ast.FunctionDef) -> ast.AST:
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
+        self._fix_docstring(node.body)
         node.name = self.names_map.get(node.name, node.name)
         for arg in node.args.args:
             arg.arg = self.names_map.get(arg.arg, arg.arg)
@@ -178,7 +214,19 @@ class RenameAsyncToSync(ast.NodeTransformer):
         self.generic_visit(node)
         return node
 
-    _skip_imports = {"alist", "anext"}
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
+        self._fix_docstring(node.body)
+        self.generic_visit(node)
+        return node
+
+    def _fix_docstring(self, body: list[ast.AST]) -> None:
+        if (
+            body
+            and isinstance(body[0], ast.Expr)
+            and isinstance(body[0].value, ast.Constant)
+            and isinstance(body[0].value.value, str)
+        ):
+            body[0].value.value = body[0].value.value.replace("Async", "")
 
     def visit_Call(self, node: ast.Call) -> ast.AST:
         if isinstance(node.func, ast.Name) and node.func.id == "TypeVar":
@@ -201,12 +249,13 @@ class RenameAsyncToSync(ast.NodeTransformer):
 
     def _visit_type_string(self, source: str) -> str:
         # Convert the string to tree, visit, and convert it back to string
-        tree = ast.parse(source)
+        tree = ast.parse(source, type_comments=False)
         tree = async_to_sync(tree)
         rv = unparse(tree)
         return rv
 
     def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST:
+        self._fix_docstring(node.body)
         node.name = self.names_map.get(node.name, node.name)
         node = self._fix_base_params(node)
         self.generic_visit(node)
@@ -230,8 +279,8 @@ class RenameAsyncToSync(ast.NodeTransformer):
 
     def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None:
         # Remove import of async utils eclipsing builtins
-        if node.module == "utils":
-            node.names = [n for n in node.names if n.name not in self._skip_imports]
+        if skips := self._skip_imports.get(node.module):
+            node.names = [n for n in node.names if n.name not in skips]
             if not node.names:
                 return None
 
@@ -251,6 +300,27 @@ class RenameAsyncToSync(ast.NodeTransformer):
         self.generic_visit(node)
         return node
 
+    def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
+        # Manage AsyncGenerator[X, Y] -> Generator[X, None, Y]
+        self._manage_async_generator(node)
+        # # Won't result in a recursion because we change the args number
+        # self.visit(node)
+        # return node
+
+        self.generic_visit(node)
+        return node
+
+    def _manage_async_generator(self, node: ast.Subscript) -> ast.AST | None:
+        if not (isinstance(node.value, ast.Name) and node.value.id == "AsyncGenerator"):
+            return None
+
+        if not (isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 2):
+            return None
+
+        node.slice.elts.insert(1, deepcopy(node.slice.elts[1]))
+        self.generic_visit(node)
+        return node
+
 
 class BlanksInserter(ast.NodeTransformer):
     """
@@ -292,9 +362,45 @@ class BlanksInserter(ast.NodeTransformer):
 
 def unparse(tree: ast.AST) -> str:
     rv: str = Unparser().visit(tree)
+    rv = _fix_comment_on_decorators(rv)
     return rv
 
 
+def _fix_comment_on_decorators(source: str) -> str:
+    """
+    Re-associate comments to decorators.
+
+    In a case like:
+
+        1  @deco  # comment
+        2  def func(x):
+        3     pass
+
+    it seems that Function lineno is 2 instead of 1 (Python 3.10). Because
+    the Comment lineno is 1, it ends up printed above the function, instead
+    of inline. This is a problem for '# type: ignore' comments.
+
+    Maybe the problem could be fixed in the tree, but this solution is a
+    simpler way to start.
+    """
+    lines = source.splitlines()
+
+    comment_at = None
+    for i, line in enumerate(lines):
+        if line.lstrip().startswith("#"):
+            comment_at = i
+        elif not line.strip():
+            pass
+        elif line.lstrip().startswith("@classmethod"):
+            if comment_at is not None:
+                lines[i] = lines[i] + "  " + lines[comment_at].lstrip()
+                lines[comment_at] = ""
+        else:
+            comment_at = None
+
+    return "\n".join(lines)
+
+
 class Unparser(ast._Unparser):
     """
     Try to emit long strings as multiline.
@@ -313,7 +419,9 @@ class Unparser(ast._Unparser):
 
 def parse_cmdline() -> Namespace:
     parser = ArgumentParser(description=__doc__)
-    parser.add_argument("filename", metavar="FILE", help="the file to process")
+    parser.add_argument(
+        "filepath", metavar="FILE", type=Path, help="the file to process"
+    )
     parser.add_argument(
         "output", metavar="OUTPUT", nargs="?", help="file where to write (or stdout)"
     )
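
A minimal sketch of driving the conversion from Python rather than through convert_async_to_sync.sh below; the file pair and the sys.path tweak are illustrative only, and the ast-comments package must be installed:

import sys
from pathlib import Path

sys.path.insert(0, "tools")                    # assumes the repo root as the working directory
import ast_comments as ast
from async_to_sync import async_to_sync, tree_to_str

src = Path("psycopg/psycopg/connection_async.py")
dst = Path("psycopg/psycopg/connection.py")

tree = ast.parse(src.read_text(), filename=str(src))
tree = async_to_sync(tree, filepath=src)
dst.write_text(tree_to_str(tree, src))         # writes the generated sync module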
index 75274faa42d93d3cda03e713d707c0564500d152..21b76352658d806e747d51796aa96d5e58230095 100755 (executable)
@@ -18,6 +18,7 @@ fi
 outputs=""
 
 for async in \
+    psycopg/psycopg/connection_async.py \
     psycopg/psycopg/cursor_async.py \
     tests/test_client_cursor_async.py \
     tests/test_connection_async.py \