From: Daniele Varrazzo Date: Sat, 11 Apr 2020 15:55:32 +0000 (+1200) Subject: Added first implementation of executemany based on prepared queries X-Git-Tag: 3.0.dev0~566 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=d2a8f1fbc75f04406acf5d824fe8e60c456483a1;p=thirdparty%2Fpsycopg.git Added first implementation of executemany based on prepared queries There's a lot of repetition here --- diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index 545732c22..6dfad1152 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -191,7 +191,7 @@ class BaseCursor: return generators.execute(self.connection.pgconn) - def _execute_results(self, results: List[pq.PGresult]) -> None: + def _execute_results(self, results: Sequence[pq.PGresult]) -> None: """ Implement part of execute() after waiting common to sync and async """ @@ -202,7 +202,7 @@ class BaseCursor: statuses = {res.status for res in results} badstats = statuses - {S.TUPLES_OK, S.COMMAND_OK, S.EMPTY_QUERY} if not badstats: - self._results = results + self._results = list(results) self.pgresult = results[0] return @@ -219,6 +219,75 @@ class BaseCursor: f" {', '.join(sorted(s.name for s in sorted(badstats)))}" ) + def _send_prepare( + self, name: bytes, query: Query, vars: Optional[Params] + ) -> "PQGen[List[pq.PGresult]]": + """ + Implement part of execute() before waiting common to sync and async + """ + from .adapt import Transformer + + if self.closed: + raise e.OperationalError("the cursor is closed") + + if self.connection.closed: + raise e.OperationalError("the connection is closed") + + if self.connection.status != self.connection.ConnStatus.OK: + raise e.InterfaceError( + f"cannot execute operations: the connection is" + f" in status {self.connection.status}" + ) + + self._reset() + self._transformer = Transformer(self) + + codec = self.connection.codec + + if isinstance(query, str): + query = codec.encode(query)[0] + + # process %% -> % only if there are paramters, even if empty list + if vars is not None: + query, formats, order = query2pg(query, vars, codec) + + if order is not None: + assert isinstance(vars, Mapping) + vars = reorder_params(vars, order) + assert isinstance(vars, Sequence) + params, types = self._transformer.dump_sequence(vars, formats) + self.connection.pgconn.send_prepare( + name, query, param_types=types, + ) + self._order = order + self._formats = formats + return generators.execute(self.connection.pgconn) + + def _send_query_prepared( + self, name: bytes, vars: Optional[Params] + ) -> "PQGen[List[pq.PGresult]]": + if self.connection.closed: + raise e.OperationalError("the connection is closed") + + if self.connection.status != self.connection.ConnStatus.OK: + raise e.InterfaceError( + f"cannot execute operations: the connection is" + f" in status {self.connection.status}" + ) + + if self._order is not None: + assert isinstance(vars, Mapping) + vars = reorder_params(vars, self._order) + assert isinstance(vars, Sequence) + params, types = self._transformer.dump_sequence(vars, self._formats) + self.connection.pgconn.send_query_prepared( + name, + params, + param_formats=self._formats, + result_format=pq.Format(self.binary), + ) + return generators.execute(self.connection.pgconn) + def nextset(self) -> Optional[bool]: self._iresult += 1 if self._iresult < len(self._results): @@ -264,10 +333,17 @@ class Cursor(BaseCursor): self, query: Query, vars_seq: Sequence[Params] ) -> "Cursor": with self.connection.lock: - for vars in vars_seq: - gen = self._execute_send(query, vars) - results = self.connection.wait(gen) - self._execute_results(results) + for i, vars in enumerate(vars_seq): + if i == 0: + gen = self._send_prepare(b"", query, vars) + (result,) = self.connection.wait(gen) + if result.status == self.ExecStatus.FATAL_ERROR: + raise e.error_from_result(result) + + gen = self._send_query_prepared(b"", vars) + (result,) = self.connection.wait(gen) + self._execute_results((result,)) + return self def fetchone(self) -> Optional[Sequence[Any]]: diff --git a/tests/test_cursor.py b/tests/test_cursor.py index e9ca7a359..86b589cbe 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -100,3 +100,64 @@ def test_query_badenc(conn): cur = conn.cursor() with pytest.raises(UnicodeEncodeError): cur.execute("select '\u20ac'") + + +@pytest.fixture(scope="session") +def _execmany(svcconn): + cur = svcconn.cursor() + cur.execute( + """ + drop table if exists execmany; + create table execmany (id serial primary key, num integer, data text) + """ + ) + + +@pytest.fixture(scope="function") +def execmany(svcconn, _execmany): + cur = svcconn.cursor() + cur.execute("truncate table execmany") + + +def test_executemany(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(10, "hello"), (20, "world")] + + +def test_executemany_name(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%(num)s, %(data)s)", + [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}], + ) + cur.execute("select num, data from execmany order by 1") + assert cur.fetchall() == [(11, "hello"), (21, "world")] + + +@pytest.mark.xfail +def test_executemany_rowcount(conn, execmany): + cur = conn.cursor() + cur.executemany( + "insert into execmany(num, data) values (%s, %s)", + [(10, "hello"), (20, "world")], + ) + assert cur.rowcount == 2 + + +@pytest.mark.parametrize( + "query", + [ + "insert into nosuchtable values (%s, %s)", + "copy (select %s, %s) to stdout", + "wat (%s, %s)", + ], +) +def test_executemany_badquery(conn, query): + cur = conn.cursor() + with pytest.raises(psycopg3.DatabaseError): + cur.executemany(query, [(10, "hello"), (20, "world")])