self._execute_results(results)
return self
+ async def executemany(
+ self, query: Query, vars_seq: Sequence[Params]
+ ) -> "AsyncCursor":
+ async with self.connection.lock:
+ self._start_query()
+ for i, vars in enumerate(vars_seq):
+ if i == 0:
+ pgq = self._send_prepare(b"", query, vars)
+ gen = generators.execute(self.connection.pgconn)
+ (result,) = await self.connection.wait(gen)
+ if result.status == self.ExecStatus.FATAL_ERROR:
+ raise e.error_from_result(result)
+ else:
+ pgq.dump(vars)
+
+ self._send_query_prepared(b"", pgq)
+ gen = generators.execute(self.connection.pgconn)
+ (result,) = await self.connection.wait(gen)
+ self._execute_results((result,))
+
+ return self
+
async def fetchone(self) -> Optional[Sequence[Any]]:
rv = self._load_row(self._pos)
if rv is not None:
self._pos += 1
return rv
+ async def fetchmany(
+ self, size: Optional[int] = None
+ ) -> List[Sequence[Any]]:
+ if size is None:
+ size = self.arraysize
+
+ rv: List[Sequence[Any]] = []
+ while len(rv) < size:
+ row = self._load_row(self._pos)
+ if row is None:
+ break
+ self._pos += 1
+ rv.append(row)
+
+ return rv
+
+ async def fetchall(self) -> List[Sequence[Any]]:
+ rv: List[Sequence[Any]] = []
+ while 1:
+ row = self._load_row(self._pos)
+ if row is None:
+ break
+ self._pos += 1
+ rv.append(row)
+
+ return rv
+
class NamedCursorMixin:
pass
assert row[2] is None
row = loop.run_until_complete(cur.fetchone())
assert row is None
+
+
+def test_execute_binary_result(aconn, loop):
+ cur = aconn.cursor(binary=True)
+ loop.run_until_complete(cur.execute("select %s, %s", ["foo", None]))
+ assert cur.pgresult.fformat(0) == 1
+
+ row = loop.run_until_complete(cur.fetchone())
+ assert row[0] == "foo"
+ assert row[1] is None
+ row = loop.run_until_complete(cur.fetchone())
+ assert row is None
+
+
+@pytest.mark.parametrize("encoding", ["utf8", "latin9"])
+def test_query_encode(aconn, loop, encoding):
+ loop.run_until_complete(aconn.set_client_encoding(encoding))
+ cur = aconn.cursor()
+ loop.run_until_complete(cur.execute("select '\u20ac'"))
+ (res,) = loop.run_until_complete(cur.fetchone())
+ assert res == "\u20ac"
+
+
+def test_query_badenc(aconn, loop):
+ loop.run_until_complete(aconn.set_client_encoding("latin1"))
+ cur = aconn.cursor()
+ with pytest.raises(UnicodeEncodeError):
+ loop.run_until_complete(cur.execute("select '\u20ac'"))
+
+
+@pytest.fixture(scope="session")
+def _execmany(svcconn):
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ drop table if exists execmany;
+ create table execmany (id serial primary key, num integer, data text)
+ """
+ )
+
+
+@pytest.fixture(scope="function")
+def execmany(svcconn, _execmany):
+ cur = svcconn.cursor()
+ cur.execute("truncate table execmany")
+
+
+def test_executemany(aconn, loop, execmany):
+ cur = aconn.cursor()
+ loop.run_until_complete(
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ )
+ loop.run_until_complete(
+ cur.execute("select num, data from execmany order by 1")
+ )
+ rv = loop.run_until_complete(cur.fetchall())
+ assert rv == [(10, "hello"), (20, "world")]
+
+
+def test_executemany_name(aconn, loop, execmany):
+ cur = aconn.cursor()
+ loop.run_until_complete(
+ cur.executemany(
+ "insert into execmany(num, data) values (%(num)s, %(data)s)",
+ [
+ {"num": 11, "data": "hello", "x": 1},
+ {"num": 21, "data": "world"},
+ ],
+ )
+ )
+ loop.run_until_complete(
+ cur.execute("select num, data from execmany order by 1")
+ )
+ rv = loop.run_until_complete(cur.fetchall())
+ assert rv == [(11, "hello"), (21, "world")]
+
+
+@pytest.mark.xfail
+def test_executemany_rowcount(aconn, loop, execmany):
+ cur = aconn.cursor()
+ loop.run_until_complete(
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ )
+ assert cur.rowcount == 2
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "insert into nosuchtable values (%s, %s)",
+ "copy (select %s, %s) to stdout",
+ "wat (%s, %s)",
+ ],
+)
+def test_executemany_badquery(aconn, loop, query):
+ cur = aconn.cursor()
+ with pytest.raises(psycopg3.DatabaseError):
+ loop.run_until_complete(
+ cur.executemany(query, [(10, "hello"), (20, "world")])
+ )