+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_pipeline_async.py'
+# DO NOT CHANGE! Change the original file instead.
import logging
-import concurrent.futures
from typing import Any
from operator import attrgetter
from itertools import groupby
from psycopg import pq
from psycopg import errors as e
+from .utils import is_async
+
pytestmark = [
pytest.mark.pipeline,
pytest.mark.skipif("not psycopg.Pipeline.is_supported()"),
]
-
pipeline_aborted = pytest.mark.flakey("the server might get in pipeline aborted state")
def test_repr(conn):
with conn.pipeline() as p:
- assert "psycopg.Pipeline" in repr(p)
+ name = "psycopg.AsyncPipeline" if is_async(conn) else "psycopg.Pipeline"
+ assert name in repr(p)
assert "[IDLE, pipeline=ON]" in repr(p)
conn.close()
def test_cursor_stream(conn):
with conn.pipeline(), conn.cursor() as cur:
with pytest.raises(psycopg.ProgrammingError):
- cur.stream("select 1").__next__()
+ next(cur.stream("select 1"))
def test_server_cursor(conn):
with conn.pipeline():
cur = conn.cursor()
with pytest.raises(e.NotSupportedError):
- with cur.copy("copy (select 1) to stdout"):
- pass
+ with cur.copy("copy (select 1) to stdout") as copy:
+ copy.read()
def test_pipeline_processed_at_exit(conn):
conn.autocommit = True
conn.execute("drop table if exists execmanypipeline")
conn.execute(
- "create unlogged table execmanypipeline ("
- " id serial primary key, num integer)"
+ "create unlogged table execmanypipeline (id serial primary key, num integer)"
)
with conn.pipeline(), conn.cursor() as cur:
cur.executemany(
conn.autocommit = True
conn.execute("drop table if exists execmanypipelinenoreturning")
conn.execute(
- "create unlogged table execmanypipelinenoreturning ("
- " id serial primary key, num integer)"
+ """create unlogged table execmanypipelinenoreturning
+ (id serial primary key, num integer)"""
)
with conn.pipeline(), conn.cursor() as cur:
cur.executemany(
items = list(t)
assert items[-1].type == "Terminate"
del items[-1]
- roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+ roundtrips = [k for (k, g) in groupby(items, key=attrgetter("direction"))]
assert roundtrips == ["F", "B"]
assert len([i for i in items if i.type == "Sync"]) == 1
items = list(t)
assert items[-1].type == "Terminate"
del items[-1]
- roundtrips = [k for k, g in groupby(items, key=attrgetter("direction"))]
+ roundtrips = [k for (k, g) in groupby(items, key=attrgetter("direction"))]
assert roundtrips == ["F", "B"] * 3
assert items[-2].direction == "F" # last 2 items are F B
assert len([i for i in items if i.type == "Sync"]) == 1
def test_auto_prepare(conn):
- conn.autocommit = True
conn.prepared_threshold = 5
with conn.pipeline():
cursors = [
An invalid prepared statement, in a pipeline, should be discarded at exit
and not reused.
"""
+
conn.autocommit = True
stmt = "INSERT INTO nosuchtable(data) VALUES (%s)"
with pytest.raises(psycopg.errors.UndefinedTable):
conn.execute("drop table if exists accessed")
with conn.transaction():
conn.execute(
- "create unlogged table pipeline_concurrency ("
- " id serial primary key,"
- " value integer"
- ")"
+ """create unlogged table pipeline_concurrency (
+ id serial primary key,
+ value integer)"""
)
conn.execute("create unlogged table accessed as (select now() as value)")
values = range(1, 10)
with conn.pipeline():
- with concurrent.futures.ThreadPoolExecutor() as e:
+ from concurrent.futures import ThreadPoolExecutor
+
+ with ThreadPoolExecutor() as e:
cursors = e.map(update, values, timeout=len(values))
- assert sum(cur.fetchone()[0] for cur in cursors) == sum(values)
- (s,) = conn.execute("select sum(value) from pipeline_concurrency").fetchone()
+ assert sum([cur.fetchone()[0] for cur in cursors]) == sum(values)
+
+ cur = conn.execute("select sum(value) from pipeline_concurrency")
+ (s,) = cur.fetchone()
assert s == sum(values)
(after,) = conn.execute("select value from accessed").fetchone()
assert after > before
-import asyncio
import logging
from typing import Any
from operator import attrgetter
from psycopg import pq
from psycopg import errors as e
-from .test_pipeline import pipeline_aborted
+from .utils import is_async, anext
pytestmark = [
pytest.mark.pipeline,
- pytest.mark.skipif("not psycopg.AsyncPipeline.is_supported()"),
+ pytest.mark.skipif("not psycopg.Pipeline.is_supported()"),
]
+pipeline_aborted = pytest.mark.flakey("the server might get in pipeline aborted state")
async def test_repr(aconn):
async with aconn.pipeline() as p:
- assert "psycopg.AsyncPipeline" in repr(p)
+ name = "psycopg.AsyncPipeline" if is_async(aconn) else "psycopg.Pipeline"
+ assert name in repr(p)
assert "[IDLE, pipeline=ON]" in repr(p)
await aconn.close()
async def test_cursor_stream(aconn):
async with aconn.pipeline(), aconn.cursor() as cur:
with pytest.raises(psycopg.ProgrammingError):
- await cur.stream("select 1").__anext__()
+ await anext(cur.stream("select 1"))
async def test_server_cursor(aconn):
await aconn.set_autocommit(True)
await aconn.execute("drop table if exists execmanypipeline")
await aconn.execute(
- "create unlogged table execmanypipeline ("
- " id serial primary key, num integer)"
+ "create unlogged table execmanypipeline (id serial primary key, num integer)"
)
async with aconn.pipeline(), aconn.cursor() as cur:
await cur.executemany(
await aconn.set_autocommit(True)
await aconn.execute("drop table if exists execmanypipelinenoreturning")
await aconn.execute(
- "create unlogged table execmanypipelinenoreturning ("
- " id serial primary key, num integer)"
+ """create unlogged table execmanypipelinenoreturning
+ (id serial primary key, num integer)"""
)
async with aconn.pipeline(), aconn.cursor() as cur:
await cur.executemany(
await aconn.execute("drop table if exists accessed")
async with aconn.transaction():
await aconn.execute(
- "create unlogged table pipeline_concurrency ("
- " id serial primary key,"
- " value integer"
- ")"
+ """create unlogged table pipeline_concurrency (
+ id serial primary key,
+ value integer)"""
)
await aconn.execute("create unlogged table accessed as (select now() as value)")
values = range(1, 10)
async with aconn.pipeline():
- cursors = await asyncio.wait_for(
- asyncio.gather(*[update(value) for value in values]),
- timeout=len(values),
- )
+ if is_async(aconn):
+ import asyncio
+
+ cursors = await asyncio.wait_for(
+ asyncio.gather(*[update(value) for value in values]),
+ timeout=len(values),
+ )
+ else:
+ from concurrent.futures import ThreadPoolExecutor
+
+ with ThreadPoolExecutor() as e:
+ cursors = e.map(update, values, timeout=len(values))
- assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values)
+ assert sum([(await cur.fetchone())[0] for cur in cursors]) == sum(values)
- (s,) = await (
- await aconn.execute("select sum(value) from pipeline_concurrency")
- ).fetchone()
+ cur = await aconn.execute("select sum(value) from pipeline_concurrency")
+ (s,) = await cur.fetchone()
assert s == sum(values)
(after,) = await (await aconn.execute("select value from accessed")).fetchone()
assert after > before
def async_to_sync(tree: ast.AST) -> ast.AST:
tree = BlanksInserter().visit(tree)
- tree = AsyncToSync().visit(tree)
tree = RenameAsyncToSync().visit(tree)
+ tree = AsyncToSync().visit(tree)
tree = FixAsyncSetters().visit(tree)
return tree
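For orientation, a minimal sketch of how these passes might be driven end to end, assuming it sits next to async_to_sync() above; the convert() helper is illustrative, not part of the patch, and ast.unparse() needs Python 3.9+:

import ast

def convert(async_path: str) -> str:
    # Parse the async source, run the transformation passes, emit sync source.
    with open(async_path) as f:
        tree = ast.parse(f.read(), filename=async_path)
    tree = async_to_sync(tree)
    ast.fix_missing_locations(tree)  # nodes added by the passes may lack positions
    return ast.unparse(tree)

# e.g. sync_source = convert("tests/test_pipeline_async.py")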
self.visit(new_node)
return new_node
+ def visit_If(self, node: ast.If) -> ast.AST:
+ # Drop `if is_async()` branch.
+ #
+ # Assume that the test guards an async object becoming sync and remove
+ # the async side, because it will likely contain `await` constructs
+        # illegal in a sync function.
+ if self._is_async_call(node.test):
+ for child in node.orelse:
+ self.visit(child)
+ return node.orelse
+
+ self.generic_visit(node)
+ return node
+
+ def _is_async_call(self, test: ast.AST) -> bool:
+ if not isinstance(test, ast.Call):
+ return False
+        if not isinstance(test.func, ast.Name) or test.func.id != "is_async":
+ return False
+ return True
+
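To see the effect of the visit_If() pass in isolation, a self-contained sketch; DropIsAsyncBranch is a stand-in mirroring the logic above, not the patch's own code:

import ast

SRC = '''\
if is_async(aconn):
    go_async()
else:
    go_sync()
'''

class DropIsAsyncBranch(ast.NodeTransformer):
    # Same idea as visit_If above: keep only the else branch of `if is_async(...)`.
    def visit_If(self, node: ast.If):
        if (
            isinstance(node.test, ast.Call)
            and isinstance(node.test.func, ast.Name)
            and node.test.func.id == "is_async"
        ):
            return node.orelse
        return self.generic_visit(node)

tree = DropIsAsyncBranch().visit(ast.parse(SRC))
print(ast.unparse(ast.fix_missing_locations(tree)))  # prints: go_sync()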
class RenameAsyncToSync(ast.NodeTransformer):
names_map = {
"AsyncClientCursor": "ClientCursor",
+ "AsyncConnection": "Connection",
"AsyncCursor": "Cursor",
"AsyncRawCursor": "RawCursor",
"AsyncServerCursor": "ServerCursor",
self.generic_visit(node)
return node
- def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
node.name = self.names_map.get(node.name, node.name)
for arg in node.args.args:
arg.arg = self.names_map.get(arg.arg, arg.arg)
+ for arg in node.args.args:
+ ann = arg.annotation
+ if not ann:
+ continue
+ if isinstance(ann, ast.Subscript):
+                # Look through a subscript (e.g. AsyncConnection[Any]):
+                # the base type is renamed below, the subscript is kept.
+ ann = ann.value
+ if isinstance(ann, ast.Attribute):
+ ann.attr = self.names_map.get(ann.attr, ann.attr)
+
self.generic_visit(node)
return node
dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "${dir}/.."
-python "${dir}/async_to_sync.py" tests/test_connection_async.py > tests/test_connection.py
-black -q tests/test_connection.py
-python "${dir}/async_to_sync.py" tests/test_cursor_async.py > tests/test_cursor.py
-black -q tests/test_cursor.py
+for async in \
+ tests/test_connection_async.py \
+ tests/test_cursor_async.py \
+ tests/test_pipeline_async.py
+do
+ sync=${async/_async/}
+ echo "converting '${async}' -> '${sync}'" >&2
+    python "${dir}/async_to_sync.py" "${async}" > "${sync}"
+    black -q "${sync}"
+done