+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'connection_async.py'
+# DO NOT CHANGE! Change the original file instead.
"""
-psycopg connection objects
+psycopg async connection objects
"""
# Copyright (C) 2020 The Psycopg Team
+from __future__ import annotations
+
import logging
-import threading
from types import TracebackType
from typing import Any, Generator, Iterator, Dict, List, Optional
from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
-from .cursor import Cursor
from .conninfo import make_conninfo, conninfo_to_dict
from ._pipeline import Pipeline
from ._encodings import pgconn_encoding
from .generators import notifies
from .transaction import Transaction
+from .cursor import Cursor
from .server_cursor import ServerCursor
from ._connection_base import BaseConnection, CursorRow, Notify
+from threading import Lock
+
if TYPE_CHECKING:
from .pq.abc import PGconn
IDLE = pq.TransactionStatus.IDLE
INTRANS = pq.TransactionStatus.INTRANS
+_INTERRUPTED = KeyboardInterrupt
+
logger = logging.getLogger("psycopg")
):
super().__init__(pgconn)
self.row_factory = row_factory
- self.lock = threading.Lock()
+ self.lock = Lock()
self.cursor_factory = Cursor
self.server_cursor_factory = ServerCursor
conninfo: str = "",
*,
autocommit: bool = False,
- row_factory: RowFactory[Row],
prepare_threshold: Optional[int] = 5,
+ row_factory: RowFactory[Row],
cursor_factory: Optional[Type[Cursor[Row]]] = None,
context: Optional[AdaptContext] = None,
**kwargs: Union[None, int, str],
- ) -> "Connection[Row]":
+ ) -> Connection[Row]:
# TODO: returned type should be _Self. See #308.
...
cursor_factory: Optional[Type[Cursor[Any]]] = None,
context: Optional[AdaptContext] = None,
**kwargs: Union[None, int, str],
- ) -> "Connection[TupleRow]":
+ ) -> Connection[TupleRow]:
...
@classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004
*,
autocommit: bool = False,
prepare_threshold: Optional[int] = 5,
+ context: Optional[AdaptContext] = None,
row_factory: Optional[RowFactory[Row]] = None,
cursor_factory: Optional[Type[Cursor[Row]]] = None,
- context: Optional[AdaptContext] = None,
**kwargs: Any,
- ) -> "Connection[Any]":
+ ) -> Connection[Any]:
"""
Connect to a database server and return a new `Connection` instance.
"""
"""
try:
return waiting.wait(gen, self.pgconn.socket, timeout=timeout)
- except KeyboardInterrupt:
+ except _INTERRUPTED:
# On Ctrl-C, try to cancel the query in the server, otherwise
# the connection will remain stuck in ACTIVE state.
self._try_cancel(self.pgconn)
@classmethod
def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
"""Consume a connection generator."""
- return waiting.wait_conn(gen, timeout=timeout)
+ return waiting.wait_conn(gen, timeout)
def _set_autocommit(self, value: bool) -> None:
self.set_autocommit(value)
# Copyright (C) 2020 The Psycopg Team
-import sys
-import asyncio
+from __future__ import annotations
+
import logging
from types import TracebackType
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
from .adapt import AdaptersMap
from ._enums import IsolationLevel
-from .conninfo import make_conninfo, conninfo_to_dict, resolve_hostaddr_async
+from .conninfo import make_conninfo, conninfo_to_dict
from ._pipeline import AsyncPipeline
from ._encodings import pgconn_encoding
from .generators import notifies
from .server_cursor import AsyncServerCursor
from ._connection_base import BaseConnection, CursorRow, Notify
+if True: # ASYNC
+ import sys
+ import asyncio
+ from asyncio import Lock
+ from .conninfo import resolve_hostaddr_async
+else:
+ from threading import Lock
+
if TYPE_CHECKING:
from .pq.abc import PGconn
IDLE = pq.TransactionStatus.IDLE
INTRANS = pq.TransactionStatus.INTRANS
+if True: # ASYNC
+ _INTERRUPTED = (asyncio.CancelledError, KeyboardInterrupt)
+else:
+ _INTERRUPTED = KeyboardInterrupt
+
logger = logging.getLogger("psycopg")
class AsyncConnection(BaseConnection[Row]):
"""
- Asynchronous wrapper for a connection to the database.
+ Wrapper for a connection to the database.
"""
__module__ = "psycopg"
):
super().__init__(pgconn)
self.row_factory = row_factory
- self.lock = asyncio.Lock()
+ self.lock = Lock()
self.cursor_factory = AsyncCursor
self.server_cursor_factory = AsyncServerCursor
cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
context: Optional[AdaptContext] = None,
**kwargs: Union[None, int, str],
- ) -> "AsyncConnection[Row]":
+ ) -> AsyncConnection[Row]:
# TODO: returned type should be _Self. See #308.
...
cursor_factory: Optional[Type[AsyncCursor[Any]]] = None,
context: Optional[AdaptContext] = None,
**kwargs: Union[None, int, str],
- ) -> "AsyncConnection[TupleRow]":
+ ) -> AsyncConnection[TupleRow]:
...
@classmethod # type: ignore[misc] # https://github.com/python/mypy/issues/11004
row_factory: Optional[AsyncRowFactory[Row]] = None,
cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
**kwargs: Any,
- ) -> "AsyncConnection[Any]":
- if sys.platform == "win32":
- loop = asyncio.get_running_loop()
- if isinstance(loop, asyncio.ProactorEventLoop):
- raise e.InterfaceError(
- "Psycopg cannot use the 'ProactorEventLoop' to run in async"
- " mode. Please use a compatible event loop, for instance by"
- " setting 'asyncio.set_event_loop_policy"
- "(WindowsSelectorEventLoopPolicy())'"
- )
+ ) -> AsyncConnection[Any]:
+ """
+ Connect to a database server and return a new `AsyncConnection` instance.
+ """
+ if True: # ASYNC
+ if sys.platform == "win32":
+ loop = asyncio.get_running_loop()
+ if isinstance(loop, asyncio.ProactorEventLoop):
+ raise e.InterfaceError(
+ "Psycopg cannot use the 'ProactorEventLoop' to run in async"
+ " mode. Please use a compatible event loop, for instance by"
+ " setting 'asyncio.set_event_loop_policy"
+ "(WindowsSelectorEventLoopPolicy())'"
+ )
params = await cls._get_connection_params(conninfo, **kwargs)
conninfo = make_conninfo(**params)
else:
params["connect_timeout"] = None
- # Resolve host addresses in non-blocking way
- params = await resolve_hostaddr_async(params)
+ if True: # ASYNC
+ # Resolve host addresses in non-blocking way
+ params = await resolve_hostaddr_async(params)
return params
"""
try:
return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
- except (asyncio.CancelledError, KeyboardInterrupt):
+ except _INTERRUPTED:
# On Ctrl-C, try to cancel the query in the server, otherwise
# the connection will remain stuck in ACTIVE state.
self._try_cancel(self.pgconn)
return await waiting.wait_conn_async(gen, timeout)
def _set_autocommit(self, value: bool) -> None:
- self._no_set_async("autocommit")
+ if True: # ASYNC
+ self._no_set_async("autocommit")
+ else:
+ self.set_autocommit(value)
async def set_autocommit(self, value: bool) -> None:
"""Method version of the `~Connection.autocommit` setter."""
await self.wait(self._set_autocommit_gen(value))
def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
- self._no_set_async("isolation_level")
+ if True: # ASYNC
+ self._no_set_async("isolation_level")
+ else:
+ self.set_isolation_level(value)
async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
"""Method version of the `~Connection.isolation_level` setter."""
await self.wait(self._set_isolation_level_gen(value))
def _set_read_only(self, value: Optional[bool]) -> None:
- self._no_set_async("read_only")
+ if True: # ASYNC
+ self._no_set_async("read_only")
+ else:
+ self.set_read_only(value)
async def set_read_only(self, value: Optional[bool]) -> None:
"""Method version of the `~Connection.read_only` setter."""
await self.wait(self._set_read_only_gen(value))
def _set_deferrable(self, value: Optional[bool]) -> None:
- self._no_set_async("deferrable")
+ if True: # ASYNC
+ self._no_set_async("deferrable")
+ else:
+ self.set_deferrable(value)
async def set_deferrable(self, value: Optional[bool]) -> None:
"""Method version of the `~Connection.deferrable` setter."""
async with self.lock:
await self.wait(self._set_deferrable_gen(value))
- def _no_set_async(self, attribute: str) -> None:
- raise AttributeError(
- f"'the {attribute!r} property is read-only on async connections:"
- f" please use 'await .set_{attribute}()' instead."
- )
+ if True: # ASYNC
+
+ def _no_set_async(self, attribute: str) -> None:
+ raise AttributeError(
+            f"the {attribute!r} property is read-only on async connections:"
+ f" please use 'await .set_{attribute}()' instead."
+ )
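For context, a minimal sketch of how this asymmetry surfaces in user code (illustrative only, not part of the patch; the conninfo string is a placeholder): the async class only exposes the `set_*()` coroutines, while the generated sync class keeps the plain property setters.

    import asyncio
    import psycopg

    async def async_demo(conninfo: str) -> None:
        async with await psycopg.AsyncConnection.connect(conninfo) as aconn:
            await aconn.set_autocommit(True)  # goes through _set_autocommit_gen()
            try:
                aconn.autocommit = True  # routed to _no_set_async("autocommit")
            except AttributeError as exc:
                print(exc)

    def sync_demo(conninfo: str) -> None:
        with psycopg.connect(conninfo) as conn:
            conn.autocommit = True  # the generated sync class keeps the setter

    # asyncio.run(async_demo("dbname=example"))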
async def tpc_begin(self, xid: Union[Xid, str]) -> None:
"""
import os
import sys
+from copy import deepcopy
from typing import Any
+from pathlib import Path
from argparse import ArgumentParser, Namespace
import ast_comments as ast
def main() -> int:
opt = parse_cmdline()
- with open(opt.filename) as f:
+ with opt.filepath.open() as f:
source = f.read()
- tree = ast.parse(source, filename=opt.filename)
- tree = async_to_sync(tree)
- output = tree_to_str(tree, opt.filename)
+ tree = ast.parse(source, filename=str(opt.filepath))
+ tree = async_to_sync(tree, filepath=opt.filepath)
+ output = tree_to_str(tree, opt.filepath)
if opt.output:
with open(opt.output, "w") as f:
return 0
-def async_to_sync(tree: ast.AST) -> ast.AST:
+def async_to_sync(tree: ast.AST, filepath: Path | None = None) -> ast.AST:
tree = BlanksInserter().visit(tree)
tree = RenameAsyncToSync().visit(tree)
tree = AsyncToSync().visit(tree)
return tree
-def tree_to_str(tree: ast.AST, filename: str) -> str:
+def tree_to_str(tree: ast.AST, filepath: Path) -> str:
rv = f"""\
# WARNING: this file is auto-generated by '{os.path.basename(sys.argv[0])}'
-# from the original file '{os.path.basename(filename)}'
+# from the original file '{filepath.name}'
# DO NOT CHANGE! Change the original file instead.
"""
rv += unparse(tree)
return rv
+# Hint: in order to explore the AST of a module you can run:
+# python -m ast path/to/module.py
+
+
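As a companion to the hint above, a rough snippet (not part of the script) to dump the comment-aware tree that `ast_comments` produces, which is what the transformers below operate on; the file path is only an example.

    import ast
    import ast_comments

    with open("psycopg/psycopg/connection_async.py") as f:
        tree = ast_comments.parse(f.read())
    # Comment nodes, including the "# ASYNC" markers, appear in the dump.
    print(ast.dump(tree, indent=2))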
class AsyncToSync(ast.NodeTransformer):
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
new_node = ast.FunctionDef(
self.visit(child)
return node.orelse
+        # Manage the `if True:  # ASYNC` convention:
+        # drop the branch not needed in the generated sync code.
+ if (stmts := self._async_test_statements(node)) is not None:
+ for child in stmts:
+ self.visit(child)
+ return stmts
+
self.generic_visit(node)
return node
return False
return True
+ def _async_test_statements(self, node: ast.If) -> list[ast.AST] | None:
+ if not (
+ isinstance(node.test, ast.Constant) and isinstance(node.test.value, bool)
+ ):
+ return None
+
+ if not (node.body and isinstance(node.body[0], ast.Comment)):
+ return None
+
+ comment = node.body[0].value
+
+ if not comment.startswith("# ASYNC"):
+ return None
+
+ stmts: list[ast.AST]
+ if node.test.value:
+ stmts = node.orelse
+ else:
+ stmts = node.body[1:] # skip the ASYNC comment
+ return stmts
+
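To make the convention concrete, a rough sketch of what the pipeline should do with it (illustrative, not part of the patch; it assumes `async_to_sync` and `unparse` from this script are in scope):

    import ast_comments as ast

    SRC = """\
    if True:  # ASYNC
        from asyncio import Lock
    else:
        from threading import Lock
    """

    tree = async_to_sync(ast.parse(SRC))
    print(unparse(tree))
    # Expected sync output: only the else branch survives, i.e.
    #     from threading import Lock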
class RenameAsyncToSync(ast.NodeTransformer):
names_map = {
"AsyncCopy": "Copy",
"AsyncCopyWriter": "CopyWriter",
"AsyncCursor": "Cursor",
- "AsyncCursor": "Cursor",
"AsyncFileWriter": "FileWriter",
+ "AsyncGenerator": "Generator",
"AsyncIterator": "Iterator",
"AsyncLibpqWriter": "LibpqWriter",
+ "AsyncPipeline": "Pipeline",
"AsyncQueuedLibpqWriter": "QueuedLibpqWriter",
"AsyncRawCursor": "RawCursor",
"AsyncRowFactory": "RowFactory",
"AsyncServerCursor": "ServerCursor",
+ "AsyncTransaction": "Transaction",
"AsyncWriter": "Writer",
"__aenter__": "__enter__",
"__aexit__": "__exit__",
"apipeline": "pipeline",
"asynccontextmanager": "contextmanager",
"connection_async": "connection",
+ "cursor_async": "cursor",
"ensure_table_async": "ensure_table",
"find_insert_problem_async": "find_insert_problem",
+ "wait_async": "wait",
+ "wait_conn_async": "wait_conn",
+ }
+ _skip_imports = {
+ "utils": {"alist", "anext"},
}
def visit_Module(self, node: ast.Module) -> ast.AST:
- # Replace the content of the module docstring.
- if (
- node.body
- and isinstance(node.body[0], ast.Expr)
- and isinstance(node.body[0].value, ast.Constant)
- ):
- node.body[0].value.value = node.body[0].value.value.replace("Async", "")
-
+ self._fix_docstring(node.body)
self.generic_visit(node)
return node
- def visit_AsyncFunctionDef(self, node: ast.FunctionDef) -> ast.AST:
+ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
+ self._fix_docstring(node.body)
node.name = self.names_map.get(node.name, node.name)
for arg in node.args.args:
arg.arg = self.names_map.get(arg.arg, arg.arg)
self.generic_visit(node)
return node
- _skip_imports = {"alist", "anext"}
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
+ self._fix_docstring(node.body)
+ self.generic_visit(node)
+ return node
+
+ def _fix_docstring(self, body: list[ast.AST]) -> None:
+ if (
+ body
+ and isinstance(body[0], ast.Expr)
+ and isinstance(body[0].value, ast.Constant)
+ and isinstance(body[0].value.value, str)
+ ):
+ body[0].value.value = body[0].value.value.replace("Async", "")
def visit_Call(self, node: ast.Call) -> ast.AST:
if isinstance(node.func, ast.Name) and node.func.id == "TypeVar":
def _visit_type_string(self, source: str) -> str:
# Convert the string to tree, visit, and convert it back to string
- tree = ast.parse(source)
+ tree = ast.parse(source, type_comments=False)
tree = async_to_sync(tree)
rv = unparse(tree)
return rv
def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST:
+ self._fix_docstring(node.body)
node.name = self.names_map.get(node.name, node.name)
node = self._fix_base_params(node)
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None:
# Remove import of async utils eclipsing builtins
- if node.module == "utils":
- node.names = [n for n in node.names if n.name not in self._skip_imports]
+ if skips := self._skip_imports.get(node.module):
+ node.names = [n for n in node.names if n.name not in skips]
if not node.names:
return None
self.generic_visit(node)
return node
+ def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
+        # Manage AsyncGenerator[X, Y] -> Generator[X, Y, Y]
+        # (in psycopg Y is always None, so this becomes Generator[X, None, None]).
+ self._manage_async_generator(node)
+ # # Won't result in a recursion because we change the args number
+ # self.visit(node)
+ # return node
+
+ self.generic_visit(node)
+ return node
+
+ def _manage_async_generator(self, node: ast.Subscript) -> ast.AST | None:
+ if not (isinstance(node.value, ast.Name) and node.value.id == "AsyncGenerator"):
+ return None
+
+ if not (isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 2):
+ return None
+
+ node.slice.elts.insert(1, deepcopy(node.slice.elts[1]))
+ self.generic_visit(node)
+ return node
+
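A rough illustration of the rewrite (not part of the patch; it assumes the transformer's name-mapping visitors, not all shown in this excerpt, also rename the surrounding identifiers): in psycopg the second parameter is always `None`, so duplicating it yields the three-parameter sync form.

    import ast_comments as ast

    src = "def _cursors(self) -> AsyncGenerator[AsyncCursor[Row], None]: ..."
    tree = RenameAsyncToSync().visit(ast.parse(src))
    print(ast.unparse(tree))
    # Expected, roughly:
    #     def _cursors(self) -> Generator[Cursor[Row], None, None]:
    #         ...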
class BlanksInserter(ast.NodeTransformer):
"""
def unparse(tree: ast.AST) -> str:
rv: str = Unparser().visit(tree)
+ rv = _fix_comment_on_decorators(rv)
return rv
+def _fix_comment_on_decorators(source: str) -> str:
+ """
+ Re-associate comments to decorators.
+
+ In a case like:
+
+ 1 @deco # comment
+ 2 def func(x):
+ 3 pass
+
+    it seems that the FunctionDef lineno is 2 instead of 1 (Python 3.10).
+    Because the Comment lineno is 1, the comment ends up printed above the
+    function instead of inline. This is a problem for '# type: ignore'
+    comments.
+
+    Maybe the problem could be fixed in the tree itself, but this
+    post-processing step is a simpler way to start.
+ """
+ lines = source.splitlines()
+
+ comment_at = None
+ for i, line in enumerate(lines):
+ if line.lstrip().startswith("#"):
+ comment_at = i
+ elif not line.strip():
+ pass
+ elif line.lstrip().startswith("@classmethod"):
+ if comment_at is not None:
+ lines[i] = lines[i] + " " + lines[comment_at].lstrip()
+ lines[comment_at] = ""
+ else:
+ comment_at = None
+
+ return "\n".join(lines)
+
+
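A quick check of the intended behaviour (illustrative; it assumes `_fix_comment_on_decorators` is importable from this script): the comment that the unparser floated above the decorator is folded back onto the `@classmethod` line, leaving a blank line where the comment was.

    before = "\n".join([
        "# type: ignore[misc]",
        "@classmethod",
        "def connect(cls): ...",
    ])
    print(_fix_comment_on_decorators(before))
    # Expected output:
    #     (blank line where the comment used to be)
    #     @classmethod # type: ignore[misc]
    #     def connect(cls): ...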
class Unparser(ast._Unparser):
"""
Try to emit long strings as multiline.
def parse_cmdline() -> Namespace:
parser = ArgumentParser(description=__doc__)
- parser.add_argument("filename", metavar="FILE", help="the file to process")
+ parser.add_argument(
+ "filepath", metavar="FILE", type=Path, help="the file to process"
+ )
parser.add_argument(
"output", metavar="OUTPUT", nargs="?", help="file where to write (or stdout)"
)
outputs=""
for async in \
+ psycopg/psycopg/connection_async.py \
psycopg/psycopg/cursor_async.py \
tests/test_client_cursor_async.py \
tests/test_connection_async.py \