]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
AsyncConnection.cursor() made async
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 8 Nov 2020 23:17:26 +0000 (23:17 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 8 Nov 2020 23:17:26 +0000 (23:17 +0000)
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/types/composite.py
tests/test_concurrency_async.py
tests/test_connection_async.py
tests/test_copy_async.py
tests/test_cursor_async.py

index d63cb393bd96e3d50a441c14d5a6925645c5d34b..0cde94d1b9fd112c42e9873a90f079fdf04000e7 100644 (file)
@@ -119,7 +119,7 @@ class BaseConnection:
             )
         self._autocommit = value
 
-    def cursor(
+    def _cursor(
         self, name: str = "", format: pq.Format = pq.Format.TEXT
     ) -> cursor.BaseCursor:
         if name:
@@ -240,7 +240,7 @@ class Connection(BaseConnection):
     def cursor(
         self, name: str = "", format: pq.Format = pq.Format.TEXT
     ) -> cursor.Cursor:
-        cur = super().cursor(name, format=format)
+        cur = self._cursor(name, format=format)
         return cast(cursor.Cursor, cur)
 
     def _start_query(self) -> None:
@@ -352,10 +352,10 @@ class AsyncConnection(BaseConnection):
     async def close(self) -> None:
         self.pgconn.finish()
 
-    def cursor(
+    async def cursor(
         self, name: str = "", format: pq.Format = pq.Format.TEXT
     ) -> cursor.AsyncCursor:
-        cur = super().cursor(name, format=format)
+        cur = self._cursor(name, format=format)
         return cast(cursor.AsyncCursor, cur)
 
     async def _start_query(self) -> None:
index c568438dce0b088f4b6ebe66dad0349d32059c80..f27505db9683097932b1cacdde8014732d436368 100644 (file)
@@ -54,7 +54,7 @@ def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]:
 async def fetch_info_async(
     conn: "AsyncConnection", name: str
 ) -> Optional[CompositeTypeInfo]:
-    cur = conn.cursor(format=pq.Format.BINARY)
+    cur = await conn.cursor(format=pq.Format.BINARY)
     await cur.execute(_type_info_query, {"name": name})
     rec = await cur.fetchone()
     return CompositeTypeInfo._from_record(rec)
index 5477e858e8720d7f108a4a099b2ea2fe50512d5d..7502574f680b79d273e020d88eb8e7336c8f8dc9 100644 (file)
@@ -27,7 +27,7 @@ async def test_commit_concurrency(aconn):
 
     async def runner():
         nonlocal stop
-        cur = aconn.cursor()
+        cur = await aconn.cursor()
         for i in range(1000):
             await cur.execute("select %s;", (i,))
             await aconn.commit()
@@ -43,7 +43,7 @@ async def test_commit_concurrency(aconn):
 async def test_concurrent_execution(dsn):
     async def worker():
         cnn = await psycopg3.AsyncConnection.connect(dsn)
-        cur = cnn.cursor()
+        cur = await cnn.cursor()
         await cur.execute("select pg_sleep(0.5)")
         await cur.close()
         await cnn.close()
@@ -106,7 +106,7 @@ async def test_cancel(aconn):
             errors.append(exc)
 
     async def worker():
-        cur = aconn.cursor()
+        cur = await aconn.cursor()
         with pytest.raises(psycopg3.DatabaseError):
             await cur.execute("select pg_sleep(2)")
 
@@ -121,6 +121,6 @@ async def test_cancel(aconn):
 
     # still working
     await aconn.rollback()
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute("select 1")
     assert await cur.fetchone() == (1,)
index 64109e5f9827d7152789227cd9d7d33634edc9a6..14735c78e1968eb175f65636f4bc8c88ba4489d4 100644 (file)
@@ -35,7 +35,7 @@ async def test_close(aconn):
     assert aconn.closed
     assert aconn.status == aconn.ConnStatus.BAD
 
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
 
     await aconn.close()
     assert aconn.closed
@@ -90,7 +90,7 @@ async def test_auto_transaction(aconn):
     aconn.pgconn.exec_(b"drop table if exists foo")
     aconn.pgconn.exec_(b"create table foo (id int primary key)")
 
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
 
     await cur.execute("insert into foo values (1)")
@@ -107,7 +107,7 @@ async def test_auto_transaction_fail(aconn):
     aconn.pgconn.exec_(b"drop table if exists foo")
     aconn.pgconn.exec_(b"create table foo (id int primary key)")
 
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
 
     await cur.execute("insert into foo values (1)")
@@ -132,7 +132,7 @@ async def test_autocommit(aconn):
 
     await aconn.set_autocommit(True)
     assert aconn.autocommit
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute("select 1")
     assert await cur.fetchone() == (1,)
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.IDLE
@@ -144,7 +144,7 @@ async def test_autocommit_connect(dsn):
 
 
 async def test_autocommit_intrans(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute("select 1")
     assert await cur.fetchone() == (1,)
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INTRANS
@@ -154,7 +154,7 @@ async def test_autocommit_intrans(aconn):
 
 
 async def test_autocommit_inerror(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     with pytest.raises(psycopg3.DatabaseError):
         await cur.execute("meh")
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
@@ -172,7 +172,7 @@ async def test_autocommit_unknown(aconn):
 
 
 async def test_get_encoding(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute("show client_encoding")
     (enc,) = await cur.fetchone()
     assert enc == aconn.client_encoding
@@ -186,7 +186,7 @@ async def test_set_encoding(aconn):
     assert aconn.client_encoding != newenc
     await aconn.set_client_encoding(newenc)
     assert aconn.client_encoding == newenc
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute("show client_encoding")
     (enc,) = await cur.fetchone()
     assert enc == newenc
@@ -227,8 +227,9 @@ async def test_encoding_env_var(dsn, monkeypatch, enc, out, codec):
 
 async def test_set_encoding_unsupported(aconn):
     await aconn.set_client_encoding("EUC_TW")
+    cur = await aconn.cursor()
     with pytest.raises(psycopg3.NotSupportedError):
-        await aconn.cursor().execute("select 1")
+        await cur.execute("select 1")
 
 
 async def test_set_encoding_bad(aconn):
@@ -278,7 +279,7 @@ async def test_connect_badargs(monkeypatch, pgconn, args, kwargs):
 
 
 async def test_broken_connection(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     with pytest.raises(psycopg3.DatabaseError):
         await cur.execute("select pg_terminate_backend(pg_backend_pid())")
     assert aconn.closed
@@ -301,7 +302,7 @@ async def test_notice_handlers(aconn, caplog):
     aconn.add_notice_handler(lambda diag: severities.append(diag.severity))
 
     aconn.pgconn.exec_(b"set client_min_messages to notice")
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute(
         "do $$begin raise notice 'hello notice'; end$$ language plpgsql"
     )
@@ -340,7 +341,7 @@ async def test_notify_handlers(aconn):
     aconn.add_notify_handler(lambda n: nots2.append(n))
 
     await aconn.set_autocommit(True)
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute("listen foo")
     await cur.execute("notify foo, 'n1'")
 
index a48d82bb5427d8d07f4ac759e422de83a8762c78..19987973efcf859f57095aab54e667d8e0e88f61 100644 (file)
@@ -12,7 +12,7 @@ pytestmark = pytest.mark.asyncio
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 async def test_copy_out_read(aconn, format):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     copy = await cur.copy(
         f"copy ({sample_values}) to stdout (format {format.name})"
     )
@@ -32,7 +32,7 @@ async def test_copy_out_read(aconn, format):
 
 @pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
 async def test_copy_out_iter(aconn, format):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     copy = await cur.copy(
         f"copy ({sample_values}) to stdout (format {format.name})"
     )
@@ -51,7 +51,7 @@ async def test_copy_out_iter(aconn, format):
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
 )
 async def test_copy_in_buffers(aconn, format, buffer):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     copy = await cur.copy(f"copy copy_in from stdin (format {format.name})")
     await copy.write(globals()[buffer])
@@ -62,7 +62,7 @@ async def test_copy_in_buffers(aconn, format, buffer):
 
 
 async def test_copy_in_buffers_pg_error(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     copy = await cur.copy("copy copy_in from stdin (format text)")
     await copy.write(sample_text)
@@ -72,10 +72,10 @@ async def test_copy_in_buffers_pg_error(aconn):
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.INERROR
 
 
-async def test_copy_bad_result(conn):
-    conn.autocommit = True
+async def test_copy_bad_result(aconn):
+    await aconn.set_autocommit(True)
 
-    cur = conn.cursor()
+    cur = await aconn.cursor()
 
     with pytest.raises(e.SyntaxError):
         await cur.copy("wat")
@@ -92,7 +92,7 @@ async def test_copy_bad_result(conn):
     [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
 )
 async def test_copy_in_buffers_with(aconn, format, buffer):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     async with (
         await cur.copy(f"copy copy_in from stdin (format {format.name})")
@@ -105,7 +105,7 @@ async def test_copy_in_buffers_with(aconn, format, buffer):
 
 
 async def test_copy_in_str(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     async with (
         await cur.copy("copy copy_in from stdin (format text)")
@@ -118,7 +118,7 @@ async def test_copy_in_str(aconn):
 
 
 async def test_copy_in_str_binary(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     with pytest.raises(e.QueryCanceled):
         async with (
@@ -130,7 +130,7 @@ async def test_copy_in_str_binary(aconn):
 
 
 async def test_copy_in_buffers_with_pg_error(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     with pytest.raises(e.UniqueViolation):
         async with (
@@ -143,7 +143,7 @@ async def test_copy_in_buffers_with_pg_error(aconn):
 
 
 async def test_copy_in_buffers_with_py_error(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
     with pytest.raises(e.QueryCanceled) as exc:
         async with (
@@ -161,7 +161,7 @@ async def test_copy_in_records(aconn, format):
     if format == Format.BINARY:
         pytest.skip("TODO: implement int binary adapter")
 
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
 
     async with (
@@ -180,7 +180,7 @@ async def test_copy_in_records_binary(aconn, format):
     if format == Format.TEXT:
         pytest.skip("TODO: remove after implementing int binary adapter")
 
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await ensure_table(cur, "col1 serial primary key, col2 int, data text")
 
     async with (
@@ -197,7 +197,7 @@ async def test_copy_in_records_binary(aconn, format):
 
 
 async def test_copy_in_allchars(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await ensure_table(cur, sample_tabledef)
 
     await aconn.set_client_encoding("utf8")
index 36e8b52678ecc0dd90541934853da1cd218cf027..79ca61b7be27db2ba01f8a57371b0a961bdf5b84 100644 (file)
@@ -11,7 +11,7 @@ pytestmark = pytest.mark.asyncio
 
 
 async def test_close(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     assert not cur.closed
     await cur.close()
     assert cur.closed
@@ -24,14 +24,14 @@ async def test_close(aconn):
 
 
 async def test_context(aconn):
-    async with aconn.cursor() as cur:
+    async with (await aconn.cursor()) as cur:
         assert not cur.closed
 
     assert cur.closed
 
 
 async def test_weakref(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     w = weakref.ref(cur)
     await cur.close()
     del cur
@@ -40,7 +40,7 @@ async def test_weakref(aconn):
 
 
 async def test_status(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     assert cur.status is None
     await cur.execute("reset all")
     assert cur.status == cur.ExecStatus.COMMAND_OK
@@ -51,7 +51,7 @@ async def test_status(aconn):
 
 
 async def test_execute_many_results(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     assert cur.nextset() is None
 
     rv = await cur.execute("select 'foo'; select generate_series(1,3)")
@@ -68,7 +68,7 @@ async def test_execute_many_results(aconn):
 
 
 async def test_execute_sequence(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     rv = await cur.execute(
         "select %s::int, %s::text, %s::text", [1, "foo", None]
     )
@@ -82,7 +82,7 @@ async def test_execute_sequence(aconn):
 
 @pytest.mark.parametrize("query", ["", " ", ";"])
 async def test_execute_empty_query(aconn, query):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute(query)
     assert cur.status == cur.ExecStatus.EMPTY_QUERY
     with pytest.raises(psycopg3.ProgrammingError):
@@ -90,7 +90,7 @@ async def test_execute_empty_query(aconn, query):
 
 
 async def test_fetchone(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
     assert cur.pgresult.fformat(0) == 0
 
@@ -103,7 +103,7 @@ async def test_fetchone(aconn):
 
 
 async def test_execute_binary_result(aconn):
-    cur = aconn.cursor(format=psycopg3.pq.Format.BINARY)
+    cur = await aconn.cursor(format=psycopg3.pq.Format.BINARY)
     await cur.execute("select %s::text, %s::text", ["foo", None])
     assert cur.pgresult.fformat(0) == 1
 
@@ -117,7 +117,7 @@ async def test_execute_binary_result(aconn):
 @pytest.mark.parametrize("encoding", ["utf8", "latin9"])
 async def test_query_encode(aconn, encoding):
     await aconn.set_client_encoding(encoding)
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute("select '\u20ac'")
     (res,) = await cur.fetchone()
     assert res == "\u20ac"
@@ -125,7 +125,7 @@ async def test_query_encode(aconn, encoding):
 
 async def test_query_badenc(aconn):
     await aconn.set_client_encoding("latin1")
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     with pytest.raises(UnicodeEncodeError):
         await cur.execute("select '\u20ac'")
 
@@ -142,7 +142,7 @@ async def execmany(svcconn):
 
 
 async def test_executemany(aconn, execmany):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.executemany(
         "insert into execmany(num, data) values (%s, %s)",
         [(10, "hello"), (20, "world")],
@@ -153,7 +153,7 @@ async def test_executemany(aconn, execmany):
 
 
 async def test_executemany_name(aconn, execmany):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.executemany(
         "insert into execmany(num, data) values (%(num)s, %(data)s)",
         [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
@@ -164,7 +164,7 @@ async def test_executemany_name(aconn, execmany):
 
 
 async def test_executemany_rowcount(aconn, execmany):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.executemany(
         "insert into execmany(num, data) values (%s, %s)",
         [(10, "hello"), (20, "world")],
@@ -181,13 +181,13 @@ async def test_executemany_rowcount(aconn, execmany):
     ],
 )
 async def test_executemany_badquery(aconn, query):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     with pytest.raises(psycopg3.DatabaseError):
         await cur.executemany(query, [(10, "hello"), (20, "world")])
 
 
 async def test_callproc_args(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute(
         """
         create function testfunc(a int, b text) returns text[] language sql as
@@ -199,7 +199,7 @@ async def test_callproc_args(aconn):
 
 
 async def test_callproc_badparam(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     with pytest.raises(TypeError):
         await cur.callproc("lower", 42)
     with pytest.raises(TypeError):
@@ -209,7 +209,7 @@ async def test_callproc_badparam(aconn):
 async def test_callproc_dict(aconn):
     testfunc = make_testfunc(aconn)
 
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
 
     await cur.callproc(testfunc.name, [2])
     assert (await cur.fetchone()) == (4,)
@@ -234,13 +234,13 @@ async def test_callproc_dict_bad(aconn, args, exc):
     if "_p" in args:
         args[testfunc.param] = args.pop("_p")
 
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     with pytest.raises(exc):
         await cur.callproc(testfunc.name, args)
 
 
 async def test_rowcount(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
 
     await cur.execute("select 1 from generate_series(1, 42)")
     assert cur.rowcount == 42
@@ -260,7 +260,7 @@ async def test_rowcount(aconn):
 
 
 async def test_iter(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute("select generate_series(1, 3)")
     res = []
     async for rec in cur:
@@ -269,7 +269,7 @@ async def test_iter(aconn):
 
 
 async def test_iter_stop(aconn):
-    cur = aconn.cursor()
+    cur = await aconn.cursor()
     await cur.execute("select generate_series(1, 3)")
     async for rec in cur:
         assert rec == (1,)