From dc612d836a387e4c1a8c8733b47db396b51d8c2d Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 8 Nov 2020 23:17:26 +0000 Subject: [PATCH] AsyncConnection.cursor() made async --- psycopg3/psycopg3/connection.py | 8 ++--- psycopg3/psycopg3/types/composite.py | 2 +- tests/test_concurrency_async.py | 8 ++--- tests/test_connection_async.py | 25 ++++++++-------- tests/test_copy_async.py | 30 +++++++++---------- tests/test_cursor_async.py | 44 ++++++++++++++-------------- 6 files changed, 59 insertions(+), 58 deletions(-) diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index d63cb393b..0cde94d1b 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -119,7 +119,7 @@ class BaseConnection: ) self._autocommit = value - def cursor( + def _cursor( self, name: str = "", format: pq.Format = pq.Format.TEXT ) -> cursor.BaseCursor: if name: @@ -240,7 +240,7 @@ class Connection(BaseConnection): def cursor( self, name: str = "", format: pq.Format = pq.Format.TEXT ) -> cursor.Cursor: - cur = super().cursor(name, format=format) + cur = self._cursor(name, format=format) return cast(cursor.Cursor, cur) def _start_query(self) -> None: @@ -352,10 +352,10 @@ class AsyncConnection(BaseConnection): async def close(self) -> None: self.pgconn.finish() - def cursor( + async def cursor( self, name: str = "", format: pq.Format = pq.Format.TEXT ) -> cursor.AsyncCursor: - cur = super().cursor(name, format=format) + cur = self._cursor(name, format=format) return cast(cursor.AsyncCursor, cur) async def _start_query(self) -> None: diff --git a/psycopg3/psycopg3/types/composite.py b/psycopg3/psycopg3/types/composite.py index c568438dc..f27505db9 100644 --- a/psycopg3/psycopg3/types/composite.py +++ b/psycopg3/psycopg3/types/composite.py @@ -54,7 +54,7 @@ def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]: async def fetch_info_async( conn: "AsyncConnection", name: str ) -> Optional[CompositeTypeInfo]: - cur = conn.cursor(format=pq.Format.BINARY) + cur = await conn.cursor(format=pq.Format.BINARY) await cur.execute(_type_info_query, {"name": name}) rec = await cur.fetchone() return CompositeTypeInfo._from_record(rec) diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py index 5477e858e..7502574f6 100644 --- a/tests/test_concurrency_async.py +++ b/tests/test_concurrency_async.py @@ -27,7 +27,7 @@ async def test_commit_concurrency(aconn): async def runner(): nonlocal stop - cur = aconn.cursor() + cur = await aconn.cursor() for i in range(1000): await cur.execute("select %s;", (i,)) await aconn.commit() @@ -43,7 +43,7 @@ async def test_commit_concurrency(aconn): async def test_concurrent_execution(dsn): async def worker(): cnn = await psycopg3.AsyncConnection.connect(dsn) - cur = cnn.cursor() + cur = await cnn.cursor() await cur.execute("select pg_sleep(0.5)") await cur.close() await cnn.close() @@ -106,7 +106,7 @@ async def test_cancel(aconn): errors.append(exc) async def worker(): - cur = aconn.cursor() + cur = await aconn.cursor() with pytest.raises(psycopg3.DatabaseError): await cur.execute("select pg_sleep(2)") @@ -121,6 +121,6 @@ async def test_cancel(aconn): # still working await aconn.rollback() - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute("select 1") assert await cur.fetchone() == (1,) diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 64109e5f9..14735c78e 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -35,7 +35,7 @@ async def test_close(aconn): assert aconn.closed assert aconn.status == aconn.ConnStatus.BAD - cur = aconn.cursor() + cur = await aconn.cursor() await aconn.close() assert aconn.closed @@ -90,7 +90,7 @@ async def test_auto_transaction(aconn): aconn.pgconn.exec_(b"drop table if exists foo") aconn.pgconn.exec_(b"create table foo (id int primary key)") - cur = aconn.cursor() + cur = await aconn.cursor() assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE await cur.execute("insert into foo values (1)") @@ -107,7 +107,7 @@ async def test_auto_transaction_fail(aconn): aconn.pgconn.exec_(b"drop table if exists foo") aconn.pgconn.exec_(b"create table foo (id int primary key)") - cur = aconn.cursor() + cur = await aconn.cursor() assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE await cur.execute("insert into foo values (1)") @@ -132,7 +132,7 @@ async def test_autocommit(aconn): await aconn.set_autocommit(True) assert aconn.autocommit - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute("select 1") assert await cur.fetchone() == (1,) assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE @@ -144,7 +144,7 @@ async def test_autocommit_connect(dsn): async def test_autocommit_intrans(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute("select 1") assert await cur.fetchone() == (1,) assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS @@ -154,7 +154,7 @@ async def test_autocommit_intrans(aconn): async def test_autocommit_inerror(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() with pytest.raises(psycopg3.DatabaseError): await cur.execute("meh") assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR @@ -172,7 +172,7 @@ async def test_autocommit_unknown(aconn): async def test_get_encoding(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute("show client_encoding") (enc,) = await cur.fetchone() assert enc == aconn.client_encoding @@ -186,7 +186,7 @@ async def test_set_encoding(aconn): assert aconn.client_encoding != newenc await aconn.set_client_encoding(newenc) assert aconn.client_encoding == newenc - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute("show client_encoding") (enc,) = await cur.fetchone() assert enc == newenc @@ -227,8 +227,9 @@ async def test_encoding_env_var(dsn, monkeypatch, enc, out, codec): async def test_set_encoding_unsupported(aconn): await aconn.set_client_encoding("EUC_TW") + cur = await aconn.cursor() with pytest.raises(psycopg3.NotSupportedError): - await aconn.cursor().execute("select 1") + await cur.execute("select 1") async def test_set_encoding_bad(aconn): @@ -278,7 +279,7 @@ async def test_connect_badargs(monkeypatch, pgconn, args, kwargs): async def test_broken_connection(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() with pytest.raises(psycopg3.DatabaseError): await cur.execute("select pg_terminate_backend(pg_backend_pid())") assert aconn.closed @@ -301,7 +302,7 @@ async def test_notice_handlers(aconn, caplog): aconn.add_notice_handler(lambda diag: severities.append(diag.severity)) aconn.pgconn.exec_(b"set client_min_messages to notice") - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute( "do $$begin raise notice 'hello notice'; end$$ language plpgsql" ) @@ -340,7 +341,7 @@ async def test_notify_handlers(aconn): aconn.add_notify_handler(lambda n: nots2.append(n)) await aconn.set_autocommit(True) - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute("listen foo") await cur.execute("notify foo, 'n1'") diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index a48d82bb5..19987973e 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -12,7 +12,7 @@ pytestmark = pytest.mark.asyncio @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) async def test_copy_out_read(aconn, format): - cur = aconn.cursor() + cur = await aconn.cursor() copy = await cur.copy( f"copy ({sample_values}) to stdout (format {format.name})" ) @@ -32,7 +32,7 @@ async def test_copy_out_read(aconn, format): @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) async def test_copy_out_iter(aconn, format): - cur = aconn.cursor() + cur = await aconn.cursor() copy = await cur.copy( f"copy ({sample_values}) to stdout (format {format.name})" ) @@ -51,7 +51,7 @@ async def test_copy_out_iter(aconn, format): [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) async def test_copy_in_buffers(aconn, format, buffer): - cur = aconn.cursor() + cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) copy = await cur.copy(f"copy copy_in from stdin (format {format.name})") await copy.write(globals()[buffer]) @@ -62,7 +62,7 @@ async def test_copy_in_buffers(aconn, format, buffer): async def test_copy_in_buffers_pg_error(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) copy = await cur.copy("copy copy_in from stdin (format text)") await copy.write(sample_text) @@ -72,10 +72,10 @@ async def test_copy_in_buffers_pg_error(aconn): assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR -async def test_copy_bad_result(conn): - conn.autocommit = True +async def test_copy_bad_result(aconn): + await aconn.set_autocommit(True) - cur = conn.cursor() + cur = await aconn.cursor() with pytest.raises(e.SyntaxError): await cur.copy("wat") @@ -92,7 +92,7 @@ async def test_copy_bad_result(conn): [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) async def test_copy_in_buffers_with(aconn, format, buffer): - cur = aconn.cursor() + cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) async with ( await cur.copy(f"copy copy_in from stdin (format {format.name})") @@ -105,7 +105,7 @@ async def test_copy_in_buffers_with(aconn, format, buffer): async def test_copy_in_str(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) async with ( await cur.copy("copy copy_in from stdin (format text)") @@ -118,7 +118,7 @@ async def test_copy_in_str(aconn): async def test_copy_in_str_binary(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) with pytest.raises(e.QueryCanceled): async with ( @@ -130,7 +130,7 @@ async def test_copy_in_str_binary(aconn): async def test_copy_in_buffers_with_pg_error(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) with pytest.raises(e.UniqueViolation): async with ( @@ -143,7 +143,7 @@ async def test_copy_in_buffers_with_pg_error(aconn): async def test_copy_in_buffers_with_py_error(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) with pytest.raises(e.QueryCanceled) as exc: async with ( @@ -161,7 +161,7 @@ async def test_copy_in_records(aconn, format): if format == Format.BINARY: pytest.skip("TODO: implement int binary adapter") - cur = aconn.cursor() + cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) async with ( @@ -180,7 +180,7 @@ async def test_copy_in_records_binary(aconn, format): if format == Format.TEXT: pytest.skip("TODO: remove after implementing int binary adapter") - cur = aconn.cursor() + cur = await aconn.cursor() await ensure_table(cur, "col1 serial primary key, col2 int, data text") async with ( @@ -197,7 +197,7 @@ async def test_copy_in_records_binary(aconn, format): async def test_copy_in_allchars(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await ensure_table(cur, sample_tabledef) await aconn.set_client_encoding("utf8") diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 36e8b5267..79ca61b7b 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -11,7 +11,7 @@ pytestmark = pytest.mark.asyncio async def test_close(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() assert not cur.closed await cur.close() assert cur.closed @@ -24,14 +24,14 @@ async def test_close(aconn): async def test_context(aconn): - async with aconn.cursor() as cur: + async with (await aconn.cursor()) as cur: assert not cur.closed assert cur.closed async def test_weakref(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() w = weakref.ref(cur) await cur.close() del cur @@ -40,7 +40,7 @@ async def test_weakref(aconn): async def test_status(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() assert cur.status is None await cur.execute("reset all") assert cur.status == cur.ExecStatus.COMMAND_OK @@ -51,7 +51,7 @@ async def test_status(aconn): async def test_execute_many_results(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() assert cur.nextset() is None rv = await cur.execute("select 'foo'; select generate_series(1,3)") @@ -68,7 +68,7 @@ async def test_execute_many_results(aconn): async def test_execute_sequence(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() rv = await cur.execute( "select %s::int, %s::text, %s::text", [1, "foo", None] ) @@ -82,7 +82,7 @@ async def test_execute_sequence(aconn): @pytest.mark.parametrize("query", ["", " ", ";"]) async def test_execute_empty_query(aconn, query): - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute(query) assert cur.status == cur.ExecStatus.EMPTY_QUERY with pytest.raises(psycopg3.ProgrammingError): @@ -90,7 +90,7 @@ async def test_execute_empty_query(aconn, query): async def test_fetchone(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None]) assert cur.pgresult.fformat(0) == 0 @@ -103,7 +103,7 @@ async def test_fetchone(aconn): async def test_execute_binary_result(aconn): - cur = aconn.cursor(format=psycopg3.pq.Format.BINARY) + cur = await aconn.cursor(format=psycopg3.pq.Format.BINARY) await cur.execute("select %s::text, %s::text", ["foo", None]) assert cur.pgresult.fformat(0) == 1 @@ -117,7 +117,7 @@ async def test_execute_binary_result(aconn): @pytest.mark.parametrize("encoding", ["utf8", "latin9"]) async def test_query_encode(aconn, encoding): await aconn.set_client_encoding(encoding) - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute("select '\u20ac'") (res,) = await cur.fetchone() assert res == "\u20ac" @@ -125,7 +125,7 @@ async def test_query_encode(aconn, encoding): async def test_query_badenc(aconn): await aconn.set_client_encoding("latin1") - cur = aconn.cursor() + cur = await aconn.cursor() with pytest.raises(UnicodeEncodeError): await cur.execute("select '\u20ac'") @@ -142,7 +142,7 @@ async def execmany(svcconn): async def test_executemany(aconn, execmany): - cur = aconn.cursor() + cur = await aconn.cursor() await cur.executemany( "insert into execmany(num, data) values (%s, %s)", [(10, "hello"), (20, "world")], @@ -153,7 +153,7 @@ async def test_executemany(aconn, execmany): async def test_executemany_name(aconn, execmany): - cur = aconn.cursor() + cur = await aconn.cursor() await cur.executemany( "insert into execmany(num, data) values (%(num)s, %(data)s)", [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], @@ -164,7 +164,7 @@ async def test_executemany_name(aconn, execmany): async def test_executemany_rowcount(aconn, execmany): - cur = aconn.cursor() + cur = await aconn.cursor() await cur.executemany( "insert into execmany(num, data) values (%s, %s)", [(10, "hello"), (20, "world")], @@ -181,13 +181,13 @@ async def test_executemany_rowcount(aconn, execmany): ], ) async def test_executemany_badquery(aconn, query): - cur = aconn.cursor() + cur = await aconn.cursor() with pytest.raises(psycopg3.DatabaseError): await cur.executemany(query, [(10, "hello"), (20, "world")]) async def test_callproc_args(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute( """ create function testfunc(a int, b text) returns text[] language sql as @@ -199,7 +199,7 @@ async def test_callproc_args(aconn): async def test_callproc_badparam(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() with pytest.raises(TypeError): await cur.callproc("lower", 42) with pytest.raises(TypeError): @@ -209,7 +209,7 @@ async def test_callproc_badparam(aconn): async def test_callproc_dict(aconn): testfunc = make_testfunc(aconn) - cur = aconn.cursor() + cur = await aconn.cursor() await cur.callproc(testfunc.name, [2]) assert (await cur.fetchone()) == (4,) @@ -234,13 +234,13 @@ async def test_callproc_dict_bad(aconn, args, exc): if "_p" in args: args[testfunc.param] = args.pop("_p") - cur = aconn.cursor() + cur = await aconn.cursor() with pytest.raises(exc): await cur.callproc(testfunc.name, args) async def test_rowcount(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute("select 1 from generate_series(1, 42)") assert cur.rowcount == 42 @@ -260,7 +260,7 @@ async def test_rowcount(aconn): async def test_iter(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute("select generate_series(1, 3)") res = [] async for rec in cur: @@ -269,7 +269,7 @@ async def test_iter(aconn): async def test_iter_stop(aconn): - cur = aconn.cursor() + cur = await aconn.cursor() await cur.execute("select generate_series(1, 3)") async for rec in cur: assert rec == (1,) -- 2.47.2