+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_cursor_async.py'
+# DO NOT CHANGE! Change the original file instead.
"""
Tests common to psycopg.Cursor and its subclasses.
"""
@pytest.fixture(params=[psycopg.Cursor, psycopg.ClientCursor, psycopg.RawCursor])
-def conn(conn, request):
+def conn(conn, request, anyio_backend):
conn.cursor_factory = request.param
return conn
cache.cache_clear()
ci = cache.cache_info()
- h0, m0 = ci.hits, ci.misses
+ (h0, m0) = (ci.hits, ci.misses)
tests = [
(f"select 1 -- {'x' * 3500}", (), h0, m0 + 1),
(f"select 1 -- {'x' * 3500}", (), h0 + 1, m0 + 1),
cur = conn.cursor()
sql = ph(cur, "insert into bug_112 (num) values (%s)")
cur.execute(sql, (1,))
- cur.execute(sql, (100_000,))
+ cur.execute(sql, (100000,))
cur.execute("select num from bug_112 order by num")
- assert cur.fetchall() == [(1,), (100_000,)]
+ assert cur.fetchall() == [(1,), (100000,)]
def test_executemany_type_change(conn):
conn.execute("create table bug_112 (num integer)")
cur = conn.cursor()
sql = ph(cur, "insert into bug_112 (num) values (%s)")
- cur.executemany(sql, [(1,), (100_000,)])
+ cur.executemany(sql, [(1,), (100000,)])
cur.execute("select num from bug_112 order by num")
- assert cur.fetchall() == [(1,), (100_000,)]
+ assert cur.fetchall() == [(1,), (100000,)]
@pytest.mark.parametrize(
def test_query_encode(conn, encoding):
conn.execute(f"set client_encoding to {encoding}")
cur = conn.cursor()
- (res,) = cur.execute("select '\u20ac'").fetchone()
- assert res == "\u20ac"
+ cur.execute("select '€'")
+ (res,) = cur.fetchone()
+ assert res == "€"
@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
conn.execute(f"set client_encoding to {encoding}")
cur = conn.cursor()
with pytest.raises(UnicodeEncodeError):
- cur.execute("select '\u20ac'")
+ cur.execute("select '€'")
def test_executemany(conn, execmany):
[(10, "hello"), (20, "world")],
)
cur.execute("select num, data from execmany order by 1")
- assert cur.fetchall() == [(10, "hello"), (20, "world")]
+ rv = cur.fetchall()
+ assert rv == [(10, "hello"), (20, "world")]
def test_executemany_name(conn, execmany):
[{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
)
cur.execute("select num, data from execmany order by 1")
- assert cur.fetchall() == [(11, "hello"), (21, "world")]
+ rv = cur.fetchall()
+ assert rv == [(11, "hello"), (21, "world")]
def test_executemany_no_data(conn, execmany):
)
with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)):
cur.executemany(
- f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})",
+ ph(cur, f"insert into testmany values (%{fmt_in.value}, %{fmt_in.value})"),
[[1, ""], [3, 4]],
)
def test_rownumber_mixed(conn):
cur = conn.cursor()
- cur.execute(
- """
-select x from generate_series(1, 3) x;
-set timezone to utc;
-select x from generate_series(4, 6) x;
-"""
- )
+ queries = [
+ "select x from generate_series(1, 3) x",
+ "set timezone to utc",
+ "select x from generate_series(4, 6) x",
+ ]
+ cur.execute(";\n".join(queries))
assert cur.rownumber == 0
assert cur.fetchone() == (1,)
assert cur.rownumber == 1
def test_row_factory_none(conn):
cur = conn.cursor(row_factory=None)
assert cur.row_factory is rows.tuple_row
- r = cur.execute("select 1 as a, 2 as b").fetchone()
+ cur.execute("select 1 as a, 2 as b")
+ r = cur.fetchone()
assert type(r) is tuple
assert r == (1, 2)
@pytest.mark.parametrize(
- "query",
- [
- "create table test_stream_badq ()",
- "copy (select 1) to stdout",
- "wat?",
- ],
+ "query", ["create table test_stream_badq ()", "copy (select 1) to stdout", "wat?"]
)
def test_stream_badquery(conn, query):
cur = conn.cursor()
gen = cur.stream("select 1")
for rec in gen:
1 / 0
+
gen.close()
assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
--- /dev/null
+#!/usr/bin/env python
+"""Convert an async module to a sync module.
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+from argparse import ArgumentParser, Namespace
+
+import ast_comments as ast
+
+
+def main() -> int:
+ opt = parse_cmdline()
+ with open(opt.filename) as f:
+ source = f.read()
+
+ tree = ast.parse(source, filename=opt.filename)
+ tree = async_to_sync(tree)
+ output = tree_to_str(tree, opt.filename)
+
+ if opt.output:
+ with open(opt.output, "w") as f:
+ print(output, file=f)
+ else:
+ print(output)
+
+ return 0
+
+
+def async_to_sync(tree: ast.AST) -> ast.AST:
+ tree = BlanksInserter().visit(tree)
+ tree = AsyncToSync().visit(tree)
+ tree = RenameAsyncToSync().visit(tree)
+ tree = FixSetAutocommit().visit(tree)
+ return tree
+
+
+def tree_to_str(tree: ast.AST, filename: str) -> str:
+ rv = f"""\
+# WARNING: this file is auto-generated by '{os.path.basename(sys.argv[0])}'
+# from the original file '{os.path.basename(filename)}'
+# DO NOT CHANGE! Change the original file instead.
+"""
+ rv += ast.unparse(tree)
+ return rv
+
+
+class AsyncToSync(ast.NodeTransformer):
+ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
+ new_node = ast.FunctionDef(
+ name=node.name,
+ args=node.args,
+ body=node.body,
+ decorator_list=node.decorator_list,
+ returns=node.returns,
+ )
+ ast.copy_location(new_node, node)
+ self.visit(new_node)
+ return new_node
+
+ def visit_AsyncFor(self, node: ast.AsyncFor) -> ast.AST:
+ new_node = ast.For(
+ target=node.target, iter=node.iter, body=node.body, orelse=node.orelse
+ )
+ ast.copy_location(new_node, node)
+ self.visit(new_node)
+ return new_node
+
+ def visit_AsyncWith(self, node: ast.AsyncWith) -> ast.AST:
+ new_node = ast.With(items=node.items, body=node.body)
+ ast.copy_location(new_node, node)
+ self.visit(new_node)
+ return new_node
+
+ def visit_Await(self, node: ast.Await) -> ast.AST:
+ new_node = node.value
+ self.visit(new_node)
+ return new_node
+
+
+class RenameAsyncToSync(ast.NodeTransformer):
+ names_map = {
+ "AsyncClientCursor": "ClientCursor",
+ "AsyncCursor": "Cursor",
+ "AsyncRawCursor": "RawCursor",
+ "aclose": "close",
+ "aclosing": "closing",
+ "aconn": "conn",
+ "alist": "list",
+ "anext": "next",
+ }
+
+ def visit_Module(self, node: ast.Module) -> ast.AST:
+ # Replace the content of the module docstring.
+ if (
+ node.body
+ and isinstance(node.body[0], ast.Expr)
+ and isinstance(node.body[0].value, ast.Constant)
+ ):
+ node.body[0].value.value = node.body[0].value.value.replace("Async", "")
+
+ self.generic_visit(node)
+ return node
+
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
+ node.name = self.names_map.get(node.name, node.name)
+ for arg in node.args.args:
+ arg.arg = self.names_map.get(arg.arg, arg.arg)
+ self.generic_visit(node)
+ return node
+
+ _skip_imports = {"alist", "anext"}
+
+ def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | None:
+ # Remove import of async utils eclypsing builtings
+ if node.module == "utils":
+ if {n.name for n in node.names} <= self._skip_imports:
+ return None
+
+ for n in node.names:
+ n.name = self.names_map.get(n.name, n.name)
+ return node
+
+ def visit_Name(self, node: ast.Name) -> ast.AST:
+ if node.id in self.names_map:
+ node.id = self.names_map[node.id]
+ return node
+
+ def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
+ if node.attr in self.names_map:
+ node.attr = self.names_map[node.attr]
+ self.generic_visit(node)
+ return node
+
+
+class FixSetAutocommit(ast.NodeTransformer):
+ def visit_Call(self, node: ast.Call) -> ast.AST:
+ new_node = self._fix_autocommit(node)
+ if new_node:
+ return new_node
+
+ self.generic_visit(node)
+ return node
+
+ def _fix_autocommit(self, node: ast.Call) -> ast.AST | None:
+ if not isinstance(node.func, ast.Attribute):
+ return None
+ if node.func.attr != "set_autocommit":
+ return None
+ obj = node.func.value
+ arg = node.args[0]
+ new_node = ast.Assign(
+ targets=[ast.Attribute(value=obj, attr="autocommit")],
+ value=arg,
+ )
+ ast.copy_location(new_node, node)
+ return new_node
+
+
+class BlanksInserter(ast.NodeTransformer):
+ """
+ Restore the missing spaces in the source (or something similar)
+ """
+
+ def generic_visit(self, node: ast.AST) -> ast.AST:
+ if isinstance(getattr(node, "body", None), list):
+ node.body = self._inject_blanks(node.body)
+ super().generic_visit(node)
+ return node
+
+ def _inject_blanks(self, body: list[ast.Node]) -> list[ast.AST]:
+ if not body:
+ return body
+
+ new_body = []
+ before = body[0]
+ new_body.append(before)
+ for i in range(1, len(body)):
+ after = body[i]
+ nblanks = after.lineno - before.end_lineno - 1
+ if nblanks > 0:
+ # Inserting one blank is enough.
+ blank = ast.Comment(
+ value="",
+ inline=False,
+ lineno=before.end_lineno + 1,
+ end_lineno=before.end_lineno + 1,
+ col_offset=0,
+ end_col_offset=0,
+ )
+ new_body.append(blank)
+ new_body.append(after)
+ before = after
+
+ return new_body
+
+
+def parse_cmdline() -> Namespace:
+ parser = ArgumentParser(description=__doc__)
+ parser.add_argument("filename", metavar="FILE", help="the file to process")
+ parser.add_argument(
+ "output", metavar="OUTPUT", nargs="?", help="file where to write (or stdout)"
+ )
+ opt = parser.parse_args()
+
+ return opt
+
+
+if __name__ == "__main__":
+ sys.exit(main())
--- /dev/null
+#!/bin/bash
+
+# Convert all the auto-generated sync files from their async counterparts.
+
+set -euo pipefail
+
+dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+cd "${dir}/.."
+
+python "${dir}/async_to_sync.py" tests/test_cursor_async.py > tests/test_cursor.py
+black -q tests/test_cursor.py