From: Daniele Varrazzo Date: Tue, 14 Apr 2020 07:03:02 +0000 (+1200) Subject: Added async implementation for executemany, fetchmany, fetchall X-Git-Tag: 3.0.dev0~558 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7b3b7f9fd3ab47522833d24568ba97a71a74ef7d;p=thirdparty%2Fpsycopg.git Added async implementation for executemany, fetchmany, fetchall --- diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index e1b3f8777..13856b50a 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -344,12 +344,61 @@ class AsyncCursor(BaseCursor): self._execute_results(results) return self + async def executemany( + self, query: Query, vars_seq: Sequence[Params] + ) -> "AsyncCursor": + async with self.connection.lock: + self._start_query() + for i, vars in enumerate(vars_seq): + if i == 0: + pgq = self._send_prepare(b"", query, vars) + gen = generators.execute(self.connection.pgconn) + (result,) = await self.connection.wait(gen) + if result.status == self.ExecStatus.FATAL_ERROR: + raise e.error_from_result(result) + else: + pgq.dump(vars) + + self._send_query_prepared(b"", pgq) + gen = generators.execute(self.connection.pgconn) + (result,) = await self.connection.wait(gen) + self._execute_results((result,)) + + return self + async def fetchone(self) -> Optional[Sequence[Any]]: rv = self._load_row(self._pos) if rv is not None: self._pos += 1 return rv + async def fetchmany( + self, size: Optional[int] = None + ) -> List[Sequence[Any]]: + if size is None: + size = self.arraysize + + rv: List[Sequence[Any]] = [] + while len(rv) < size: + row = self._load_row(self._pos) + if row is None: + break + self._pos += 1 + rv.append(row) + + return rv + + async def fetchall(self) -> List[Sequence[Any]]: + rv: List[Sequence[Any]] = [] + while 1: + row = self._load_row(self._pos) + if row is None: + break + self._pos += 1 + rv.append(row) + + return rv + class NamedCursorMixin: pass diff --git a/tests/test_async_cursor.py b/tests/test_async_cursor.py index 1f785bdb8..aa749ee4b 100644 --- a/tests/test_async_cursor.py +++ b/tests/test_async_cursor.py @@ -75,3 +75,109 @@ def test_fetchone(aconn, loop): assert row[2] is None row = loop.run_until_complete(cur.fetchone()) assert row is None + + +def test_execute_binary_result(aconn, loop): + cur = aconn.cursor(binary=True) + loop.run_until_complete(cur.execute("select %s, %s", ["foo", None])) + assert cur.pgresult.fformat(0) == 1 + + row = loop.run_until_complete(cur.fetchone()) + assert row[0] == "foo" + assert row[1] is None + row = loop.run_until_complete(cur.fetchone()) + assert row is None + + +@pytest.mark.parametrize("encoding", ["utf8", "latin9"]) +def test_query_encode(aconn, loop, encoding): + loop.run_until_complete(aconn.set_client_encoding(encoding)) + cur = aconn.cursor() + loop.run_until_complete(cur.execute("select '\u20ac'")) + (res,) = loop.run_until_complete(cur.fetchone()) + assert res == "\u20ac" + + +def test_query_badenc(aconn, loop): + loop.run_until_complete(aconn.set_client_encoding("latin1")) + cur = aconn.cursor() + with pytest.raises(UnicodeEncodeError): + loop.run_until_complete(cur.execute("select '\u20ac'")) + + +@pytest.fixture(scope="session") +def _execmany(svcconn): + cur = svcconn.cursor() + cur.execute( + """ + drop table if exists execmany; + create table execmany (id serial primary key, num integer, data text) + """ + ) + + +@pytest.fixture(scope="function") +def execmany(svcconn, _execmany): + cur = svcconn.cursor() + cur.execute("truncate table execmany") + + +def test_executemany(aconn, loop, execmany): + cur = aconn.cursor() + loop.run_until_complete( + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + ) + loop.run_until_complete( + cur.execute("select num, data from execmany order by 1") + ) + rv = loop.run_until_complete(cur.fetchall()) + assert rv == [(10, "hello"), (20, "world")] + + +def test_executemany_name(aconn, loop, execmany): + cur = aconn.cursor() + loop.run_until_complete( + cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [ + {"num": 11, "data": "hello", "x": 1}, + {"num": 21, "data": "world"}, + ], + ) + ) + loop.run_until_complete( + cur.execute("select num, data from execmany order by 1") + ) + rv = loop.run_until_complete(cur.fetchall()) + assert rv == [(11, "hello"), (21, "world")] + + +@pytest.mark.xfail +def test_executemany_rowcount(aconn, loop, execmany): + cur = aconn.cursor() + loop.run_until_complete( + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + ) + assert cur.rowcount == 2 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +def test_executemany_badquery(aconn, loop, query): + cur = aconn.cursor() + with pytest.raises(psycopg3.DatabaseError): + loop.run_until_complete( + cur.executemany(query, [(10, "hello"), (20, "world")]) + )