return generators.execute(self.connection.pgconn)
- def _execute_results(self, results: List[pq.PGresult]) -> None:
+ def _execute_results(self, results: Sequence[pq.PGresult]) -> None:
"""
Implement part of execute() after waiting common to sync and async
"""
statuses = {res.status for res in results}
badstats = statuses - {S.TUPLES_OK, S.COMMAND_OK, S.EMPTY_QUERY}
if not badstats:
- self._results = results
+ self._results = list(results)
self.pgresult = results[0]
return
f" {', '.join(sorted(s.name for s in sorted(badstats)))}"
)
+ def _send_prepare(
+ self, name: bytes, query: Query, vars: Optional[Params]
+ ) -> "PQGen[List[pq.PGresult]]":
+ """
+ Implement part of execute() before waiting common to sync and async
+ """
+ from .adapt import Transformer
+
+ if self.closed:
+ raise e.OperationalError("the cursor is closed")
+
+ if self.connection.closed:
+ raise e.OperationalError("the connection is closed")
+
+ if self.connection.status != self.connection.ConnStatus.OK:
+ raise e.InterfaceError(
+ f"cannot execute operations: the connection is"
+ f" in status {self.connection.status}"
+ )
+
+ self._reset()
+ self._transformer = Transformer(self)
+
+ codec = self.connection.codec
+
+ if isinstance(query, str):
+ query = codec.encode(query)[0]
+
+ # process %% -> % only if there are paramters, even if empty list
+ if vars is not None:
+ query, formats, order = query2pg(query, vars, codec)
+
+ if order is not None:
+ assert isinstance(vars, Mapping)
+ vars = reorder_params(vars, order)
+ assert isinstance(vars, Sequence)
+ params, types = self._transformer.dump_sequence(vars, formats)
+ self.connection.pgconn.send_prepare(
+ name, query, param_types=types,
+ )
+ self._order = order
+ self._formats = formats
+ return generators.execute(self.connection.pgconn)
+
+ def _send_query_prepared(
+ self, name: bytes, vars: Optional[Params]
+ ) -> "PQGen[List[pq.PGresult]]":
+ if self.connection.closed:
+ raise e.OperationalError("the connection is closed")
+
+ if self.connection.status != self.connection.ConnStatus.OK:
+ raise e.InterfaceError(
+ f"cannot execute operations: the connection is"
+ f" in status {self.connection.status}"
+ )
+
+ if self._order is not None:
+ assert isinstance(vars, Mapping)
+ vars = reorder_params(vars, self._order)
+ assert isinstance(vars, Sequence)
+ params, types = self._transformer.dump_sequence(vars, self._formats)
+ self.connection.pgconn.send_query_prepared(
+ name,
+ params,
+ param_formats=self._formats,
+ result_format=pq.Format(self.binary),
+ )
+ return generators.execute(self.connection.pgconn)
+
def nextset(self) -> Optional[bool]:
self._iresult += 1
if self._iresult < len(self._results):
self, query: Query, vars_seq: Sequence[Params]
) -> "Cursor":
with self.connection.lock:
- for vars in vars_seq:
- gen = self._execute_send(query, vars)
- results = self.connection.wait(gen)
- self._execute_results(results)
+ for i, vars in enumerate(vars_seq):
+ if i == 0:
+ gen = self._send_prepare(b"", query, vars)
+ (result,) = self.connection.wait(gen)
+ if result.status == self.ExecStatus.FATAL_ERROR:
+ raise e.error_from_result(result)
+
+ gen = self._send_query_prepared(b"", vars)
+ (result,) = self.connection.wait(gen)
+ self._execute_results((result,))
+
return self
def fetchone(self) -> Optional[Sequence[Any]]:
cur = conn.cursor()
with pytest.raises(UnicodeEncodeError):
cur.execute("select '\u20ac'")
+
+
+@pytest.fixture(scope="session")
+def _execmany(svcconn):
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ drop table if exists execmany;
+ create table execmany (id serial primary key, num integer, data text)
+ """
+ )
+
+
+@pytest.fixture(scope="function")
+def execmany(svcconn, _execmany):
+ cur = svcconn.cursor()
+ cur.execute("truncate table execmany")
+
+
+def test_executemany(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ cur.execute("select num, data from execmany order by 1")
+ assert cur.fetchall() == [(10, "hello"), (20, "world")]
+
+
+def test_executemany_name(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%(num)s, %(data)s)",
+ [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}],
+ )
+ cur.execute("select num, data from execmany order by 1")
+ assert cur.fetchall() == [(11, "hello"), (21, "world")]
+
+
+@pytest.mark.xfail
+def test_executemany_rowcount(conn, execmany):
+ cur = conn.cursor()
+ cur.executemany(
+ "insert into execmany(num, data) values (%s, %s)",
+ [(10, "hello"), (20, "world")],
+ )
+ assert cur.rowcount == 2
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "insert into nosuchtable values (%s, %s)",
+ "copy (select %s, %s) to stdout",
+ "wat (%s, %s)",
+ ],
+)
+def test_executemany_badquery(conn, query):
+ cur = conn.cursor()
+ with pytest.raises(psycopg3.DatabaseError):
+ cur.executemany(query, [(10, "hello"), (20, "world")])