]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: generate cursor.py from cursor_async
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 21 Aug 2023 17:10:25 +0000 (18:10 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
tools/async_to_sync.py
tools/convert_async_to_sync.sh

index c26c73abfbdf24f68049b4b56631a408b2a4482e..c2544e3c8a5cf6cade5dbe1b8f23b63843c68de4 100644 (file)
@@ -1,18 +1,23 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'cursor_async.py'
+# DO NOT CHANGE! Change the original file instead.
 """
-Psycopg Cursor object
+Psycopg Cursor object.
 """
 
 # Copyright (C) 2020 The Psycopg Team
 
+from __future__ import annotations
+
 from types import TracebackType
-from typing import Any, Iterable, Iterator, List, Optional, Type, TypeVar
-from typing import overload, TYPE_CHECKING
+from typing import Any, Iterator, Iterable, List, Optional, Type, TypeVar
+from typing import TYPE_CHECKING, overload
 from contextlib import contextmanager
 
 from . import pq
 from . import errors as e
 from .abc import Query, Params
-from .copy import Copy, Writer as CopyWriter
+from .copy import Copy, Writer
 from .rows import Row, RowMaker, RowFactory
 from ._pipeline import Pipeline
 from ._cursor_base import BaseCursor
@@ -29,21 +34,18 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
     _Self = TypeVar("_Self", bound="Cursor[Any]")
 
     @overload
-    def __init__(self: "Cursor[Row]", connection: "Connection[Row]"):
+    def __init__(self: Cursor[Row], connection: Connection[Row]):
         ...
 
     @overload
     def __init__(
-        self: "Cursor[Row]",
-        connection: "Connection[Any]",
-        *,
-        row_factory: RowFactory[Row],
+        self: Cursor[Row], connection: Connection[Any], *, row_factory: RowFactory[Row]
     ):
         ...
 
     def __init__(
         self,
-        connection: "Connection[Any]",
+        connection: Connection[Any],
         *,
         row_factory: Optional[RowFactory[Row]] = None,
     ):
@@ -102,11 +104,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         return self
 
     def executemany(
-        self,
-        query: Query,
-        params_seq: Iterable[Params],
-        *,
-        returning: bool = False,
+        self, query: Query, params_seq: Iterable[Params], *, returning: bool = False
     ) -> None:
         """
         Execute the same command with a sequence of input data.
@@ -157,10 +155,8 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
                     rec: Row = self._tx.load_row(0, self._make_row)  # type: ignore
                     yield rec
                     first = False
-
             except e._NO_TRACEBACK as ex:
                 raise ex.with_traceback(None)
-
             finally:
                 if self._pgconn.transaction_status == ACTIVE:
                     # Try to cancel the query, then consume the results
@@ -209,9 +205,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         if not size:
             size = self.arraysize
         records = self._tx.load_rows(
-            self._pos,
-            min(self._pos + size, self.pgresult.ntuples),
-            self._make_row,
+            self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row
         )
         self._pos += len(records)
         return records
@@ -263,12 +257,10 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         statement: Query,
         params: Optional[Params] = None,
         *,
-        writer: Optional[CopyWriter] = None,
+        writer: Optional[Writer] = None,
     ) -> Iterator[Copy]:
         """
         Initiate a :sql:`COPY` operation and return an object to manage it.
-
-        :rtype: Copy
         """
         try:
             with self._conn.lock:
@@ -286,7 +278,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
     def _fetch_pipeline(self) -> None:
         if (
             self._execmany_returning is not False
-            and not self.pgresult
+            and (not self.pgresult)
             and self._conn._pipeline
         ):
             with self._conn.lock:
index 2589aefcabb56d044279d9d9d0121006ea030b33..f02cfc54028e4d34bf33089c17b0fbd4c0f03e5c 100644 (file)
@@ -1,9 +1,11 @@
 """
-Psycopg AsyncCursor object
+Psycopg AsyncCursor object.
 """
 
 # Copyright (C) 2020 The Psycopg Team
 
+from __future__ import annotations
+
 from types import TracebackType
 from typing import Any, AsyncIterator, Iterable, List, Optional, Type, TypeVar
 from typing import TYPE_CHECKING, overload
@@ -12,7 +14,7 @@ from contextlib import asynccontextmanager
 from . import pq
 from . import errors as e
 from .abc import Query, Params
-from .copy import AsyncCopy, AsyncWriter as AsyncCopyWriter
+from .copy import AsyncCopy, AsyncWriter
 from .rows import Row, RowMaker, AsyncRowFactory
 from ._pipeline import Pipeline
 from ._cursor_base import BaseCursor
@@ -29,13 +31,13 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     _Self = TypeVar("_Self", bound="AsyncCursor[Any]")
 
     @overload
-    def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"):
+    def __init__(self: AsyncCursor[Row], connection: AsyncConnection[Row]):
         ...
 
     @overload
     def __init__(
-        self: "AsyncCursor[Row]",
-        connection: "AsyncConnection[Any]",
+        self: AsyncCursor[Row],
+        connection: AsyncConnection[Any],
         *,
         row_factory: AsyncRowFactory[Row],
     ):
@@ -43,7 +45,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
 
     def __init__(
         self,
-        connection: "AsyncConnection[Any]",
+        connection: AsyncConnection[Any],
         *,
         row_factory: Optional[AsyncRowFactory[Row]] = None,
     ):
@@ -267,12 +269,10 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         statement: Query,
         params: Optional[Params] = None,
         *,
-        writer: Optional[AsyncCopyWriter] = None,
+        writer: Optional[AsyncWriter] = None,
     ) -> AsyncIterator[AsyncCopy]:
         """
         Initiate a :sql:`COPY` operation and return an object to manage it.
-
-        :rtype: AsyncCopy
         """
         try:
             async with self._conn.lock:
index f7f42da2d0512d5580336680d4afe3ef65c98c7a..3375656fcbc01cd6c3eae2ef3f0ab5dd0e46b9b5 100755 (executable)
@@ -97,6 +97,8 @@ class AsyncToSync(ast.NodeTransformer):
     def _is_async_call(self, test: ast.AST) -> bool:
         if not isinstance(test, ast.Call):
             return False
+        if not isinstance(test.func, ast.Name):
+            return False
         if test.func.id != "is_async":
             return False
         return True
@@ -107,14 +109,20 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "AsyncClientCursor": "ClientCursor",
         "AsyncConnection": "Connection",
         "AsyncCopy": "Copy",
+        "AsyncCopyWriter": "CopyWriter",
+        "AsyncCursor": "Cursor",
         "AsyncCursor": "Cursor",
         "AsyncFileWriter": "FileWriter",
+        "AsyncIterator": "Iterator",
         "AsyncLibpqWriter": "LibpqWriter",
         "AsyncQueuedLibpqWriter": "QueuedLibpqWriter",
         "AsyncRawCursor": "RawCursor",
+        "AsyncRowFactory": "RowFactory",
         "AsyncServerCursor": "ServerCursor",
+        "AsyncWriter": "Writer",
         "__aenter__": "__enter__",
         "__aexit__": "__exit__",
+        "__aiter__": "__iter__",
         "aclose": "close",
         "aclosing": "closing",
         "acommands": "commands",
@@ -124,6 +132,8 @@ class RenameAsyncToSync(ast.NodeTransformer):
         "alist": "list",
         "anext": "next",
         "apipeline": "pipeline",
+        "asynccontextmanager": "contextmanager",
+        "connection_async": "connection",
         "ensure_table_async": "ensure_table",
         "find_insert_problem_async": "find_insert_problem",
     }
@@ -159,6 +169,56 @@ class RenameAsyncToSync(ast.NodeTransformer):
 
     _skip_imports = {"alist", "anext"}
 
+    def visit_Call(self, node: ast.Call) -> ast.AST:
+        if isinstance(node.func, ast.Name) and node.func.id == "TypeVar":
+            node = self._visit_Call_TypeVar(node)
+
+        self.generic_visit(node)
+        return node
+
+    def _visit_Call_TypeVar(self, node: ast.Call) -> ast.AST:
+        for kw in node.keywords:
+            if kw.arg != "bound":
+                continue
+            if not isinstance(kw.value, ast.Constant):
+                continue
+            if not isinstance(kw.value.value, str):
+                continue
+            kw.value.value = self._visit_type_string(kw.value.value)
+
+        return node
+
+    def _visit_type_string(self, source: str) -> str:
+        # Convert the string to tree, visit, and convert it back to string
+        tree = ast.parse(source)
+        tree = async_to_sync(tree)
+        rv = unparse(tree)
+        return rv
+
+    def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST:
+        node.name = self.names_map.get(node.name, node.name)
+        node = self._fix_base_params(node)
+        self.generic_visit(node)
+        return node
+
+    def _fix_base_params(self, node: ast.ClassDef) -> ast.AST:
+        # Handle :
+        #   class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
+        # the base cannot be a token, even with __future__ annotation.
+        for base in node.bases:
+            if not isinstance(base, ast.Subscript):
+                continue
+            # if not isinstance(base.slice, ast.Tuple):
+            # ast.Tuple is typing.Tuple???
+            if type(base.slice).__name__ != "Tuple":
+                continue
+            for elt in base.slice.elts:
+                if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)):
+                    continue
+                elt.value = self._visit_type_string(elt.value)
+
+        return node
+
     def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None:
         # Remove import of async utils eclypsing builtins
         if node.module == "utils":
@@ -166,6 +226,7 @@ class RenameAsyncToSync(ast.NodeTransformer):
             if not node.names:
                 return None
 
+        node.module = self.names_map.get(node.module, node.module)
         for n in node.names:
             n.name = self.names_map.get(n.name, n.name)
         return node
index 76b0e07dff761f657ee18a1be4bed564d8492264..75274faa42d93d3cda03e713d707c0564500d152 100755 (executable)
@@ -18,6 +18,7 @@ fi
 outputs=""
 
 for async in \
+    psycopg/psycopg/cursor_async.py \
     tests/test_client_cursor_async.py \
     tests/test_connection_async.py \
     tests/test_copy_async.py \